首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >带Keras神经网络的8x8棋盘游戏的input_shape

带Keras神经网络的8x8棋盘游戏的input_shape
EN

Stack Overflow用户
提问于 2018-09-18 17:24:30
回答 1查看 217关注 0票数 0

当涉及到建立一个NN时,我的耳朵是非常绿色的。现在,我收到以下错误:

ValueError:检查时出错:期望dense_1_input具有三维,但得到形状为(8,8)的数组

背景:我使用的是8x8板--这是我初始化它的方式:

代码语言:javascript
复制
self.state = np.zeros((LENGTH, LENGTH))

下面是我构建模型的代码:

代码语言:javascript
复制
def build_model(self):
    #builds the NN for Deep-Q Model
    model = Sequential() 
    model.add(Dense(24,input_shape = (LENGTH, LENGTH), activation='relu'))
    model.add(Flatten())
    model.add(Dense(24, activation='relu'))
    model.add(Dense(self.action_size, activation = 'linear'))
    model.compile(loss='mse', optimizer='Adam')

    return model

我想,因为董事会的形状是(8,8),input_size应该是一样的。不知道我做错了什么?

以防万一这是有用的:

我制作的游戏非常简单,包括棋盘上的5块:

  • player1有1块,只需1步就可以向前和向后移动。
  • player2有4块,只能从对角线的位置向前移动1步。

player1的目标是到达板子的另一边,player2的目标是诱捕1名球员,这样他就不能移动。

任何帮助都将不胜感激!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-09-18 20:18:41

我设法让它跑了..。但这就是:我把input_shape改成了这个input_shape = (LENGTH, ),问题是我不知道为什么它会接受这个?如果我做得对吗?

不,您使用的第一个输入形状(即(LENGTH,LENGTH))是正确的。注意,input_shape参数指定了的形状--一个并且只有一个训练样本,而不是所有的培训数据。例如,如果您有1000个8x8的板,那么训练数据的形状为(1000, 8, 8),但是input_shape参数必须指定为(8,8),即一个训练样本的形状。

此外,由于您可能知道或可能不知道the Dense layer is applied on the last axis,并且由于您已经将密集层的输入形状定义为(LENGTH,LENGTH),所以稠密层将不应用于所有输入(即板),而是应用于第二轴(即板的行)。我想这不是您想要的,所以这里有两个选项: 1)您可以将扁平层移到顶部,并将其作为模型的第一层:

代码语言:javascript
复制
model = Sequential()
model.add(Flatten(input_shape=(LENGTH, LENGTH)))
model.add(Dense(24, activation='relu'))
# the rest of the model

或者2)您可以重塑训练数据,使其具有(num_boards, LENGTH*LENGTH)的形状,并相应地调整input_shape参数(在这种情况下,您不需要模型中的平坦层,您可以删除它):

代码语言:javascript
复制
training_data = np.reshape(training_data, (num_boards, LENGTH*LENGTH))

model = Sequential() 
model.add(Dense(24,input_shape=(LENGTH*LENGTH,), activation='relu'))

另外,如果您只有一个培训/测试样本(这是奇怪的!)或者无论有多少训练/测试样本,训练/测试数据数组的第一轴必须对应于样本,即所有训练/测试数据数组的形状必须是(num_sample, ...)。否则,在调用fit/predict/evaluate方法时,您可能会收到抱怨形状的错误,就像已经得到的一样。同样的情况也适用于包含培训/测试标签的数组。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52391757

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档