首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在Keras中定义DQN模型的输出层形状

如何在Keras中定义DQN模型的输出层形状
EN

Stack Overflow用户
提问于 2019-07-24 00:16:17
回答 1查看 593关注 0票数 0

我正在尝试学习DQN代理使用Keras来玩Tic Tac Toe。问题是我的输出与我预期的形状不同。

详细信息:输入形状:(BOARD_SIZE ^ 2) * 3 -->这是一个热编码的游戏板输出形状:我希望输出将列出(BOARD_SIZE^2)的大小,因为它应该有可用动作的数量

问题:输出具有输入层[(BOARD_SIZE ^ 2) *3] * Number of actions (BOARD_SIZE^2)的形状大小

我试图寻找解决方案,但Keras文档相当糟糕。请帮助

这是我的模型

代码语言:javascript
运行
复制
    def create_model(self, game: GameController) -> Sequential:
    input_size = (game.shape ** 2) * 3

    model = Sequential()
    model.add(Dense(input_size, input_dim=1, activation='relu'))
    model.add(Dense(int(input_size / 2), activation='relu'))
    model.add(Dense(int(input_size / 2), activation='relu'))
    model.add(Dense((game.shape ** 2), activation='linear'))
    model.compile(loss="mean_squared_error", optimizer=Adam(self.alpha))

    return model

这就是我试图获得输出的方式

代码语言:javascript
运行
复制
q_values = self.model.predict(processed_input)

这是BOAD预处理(一个热编码)

代码语言:javascript
运行
复制
def preprocess_input(self, game: GameController) -> list:
    encoded_x = copy.deepcopy(game.board)
    encoded_o = copy.deepcopy(game.board)
    encoded_blank = copy.deepcopy(game.board)

    for row in range(game.shape):
        for col in range(game.shape):
            if encoded_x[row][col] == 'X':
                encoded_x[row][col] = 1
            else:
                encoded_x[row][col] = 0

            if encoded_o[row][col] == 'O':
                encoded_o[row][col] = 1
            else:
                encoded_o[row][col] = 0

            if encoded_blank[row][col] == '-':
                encoded_blank[row][col] = 1
            else:
                encoded_blank[row][col] = 0

    chained_x = list(chain.from_iterable(encoded_x))
    chained_o = list(chain.from_iterable(encoded_o))
    chained_blank = list(chain.from_iterable(encoded_blank))

    string_board = list(chain(chained_x, chained_o, chained_blank))
    board_to_int = [int(element) for element in string_board]

    return board_to_int
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-07-24 23:21:54

好吧,经过几次尝试,我发现我的输入被转置了,所以我将input_dim设置为((BOARD_SIZE^2)*3),并将input_board重塑为(1,(BOARD_SIZE^2)*3)修复了问题。希望它能在未来帮助其他人:)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57168356

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档