首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用keras在神经网络中输入n个项目的数组并输出大小为k的数组?

如何使用keras在神经网络中输入n个项目的数组并输出大小为k的数组?
EN

Stack Overflow用户
提问于 2020-05-03 07:41:05
回答 2查看 487关注 0票数 0

我对机器学习和在keras中使用神经网络是个新手。我正在尝试在神经网络的帮助下使用强化学习,如果机器人要与人类比赛,它可能最终会预测机器人在垄断游戏中采取的正确行动。

为此,我尝试使用神经网络,它接收23个浮点数的数组(定义玩家状态),并输出7个浮点数的数组(在给定时间可以采取的最大可能操作数)。我当前的NN如下:

代码语言:javascript
复制
model = Sequential()
model.add(Dense(150, input_dim=23, activation='relu'))
model.add(Dense(7, activation='sigmoid'))
model.compile(loss='mse', optimizer=Adam(lr=0.2))

我的意图是有一个3层的nn,有150个(神经元)隐藏层,最后一层有7个神经元。

代码语言:javascript
复制
#An input example would be:
state = [0.35,0.65,0.35,3.53...] # array of 23 items, float numbers.
output = model.predict(state)

#I expect output to be:
[0.21,0.12,0.98,0.32,0.44,0.12,0.41] #array size of 7

#Then I could simply just use the index with the highest number as the action to take. 
action = output.index(max(output))

我不知道为什么,但我得到了这个错误:ValueError: Error when checking input: expected dense_23_input to have shape (23,) but got array with shape (1,)

我确信,如果我只有一个最后一层的神经元预测一个范围内的整数会更好,例如数字1到7。然而,我不知道有任何激活函数可以做到这一点。请随时为此目的建议更好的神经网络模型,我将非常感谢。我知道这可能不是最好的模型。

但本质上,这里的主要问题是,如何输入单个大小为23的数组,并输出大小为7的数组?

谢谢!

EN

回答 2

Stack Overflow用户

发布于 2020-05-03 08:08:59

我对keras不是很熟悉,但是在pytorch中,所有的东西都是批量工作的,这就是为什么你得到了比你想要的更多的维度。

第一个线性图层的输入应具有尺寸(batch_size,23)。如果你想看看一个例子是如何在整个网络中运行的,首先像input.reshape(1,-1)一样重塑它。输出将显示dims (1,7)。您应该将最后一个图层激活更改为softmax

票数 0
EN

Stack Overflow用户

发布于 2020-05-03 23:51:42

感谢您的贡献!我设法解决了这个问题。我最终使用了一个具有单个神经元输出和一个sigmoid激活函数的3层nn,如下所示:

代码语言:javascript
复制
model = Sequential()
model.add(Dense(150, input_shape=(23,), activation='relu'))
model.add(Dense(1, input_shape=(7,), activation='sigmoid'))
model.compile(loss='mse', optimizer=Adam(lr=0.2))

#Required input would look something like this:
input =np.array([0.2,0.1,0.5,0.5,0.8,0.3,0.2,0.2,0.2,0.9,0.2,0.8,0.6,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.1,0.5,0.4])
input = np.reshape(input,(1,-1))

#The output would look like something like this:
print(saved_model.predict(input))
#[[0.00249215 0.15893634 0.50805619 0.86176279 0.34417642 0.29258215
  0.131994  ]]

从这里开始,我将简单地获得概率最高的索引,以确定我的输入的类别。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61567766

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档