首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >用BatchDataset和numpy数组的混合输入拟合keras模型

用BatchDataset和numpy数组的混合输入拟合keras模型
EN

Stack Overflow用户
提问于 2021-04-18 17:02:18
回答 1查看 1.1K关注 0票数 1

我做了一个接受两个输入的模型。当我用两个numpy数组来拟合模型时,它可以工作。下面是一个例子:

代码语言:javascript
运行
复制
model.fit(x=[image_input, other_features], y = y, epochs=epochs)

但是,我的问题是,other_features是一个numpy数组,而image_input使用tf.keras.preprocessing.image_dataset_from_directory加载keras。我面临的问题是:

  1. 如何从image_input中正确地给y?当我只使用一个输入image_input训练模型时,y被打包在其中,所以我不必在另一个y中指定它才能将BatchDatasetnumpy.array放在一起?无论如何,当我这样做时,我收到了一个错误:

代码语言:javascript
运行
复制
ValueError: Failed to find data adapter that can handle input: (<class 'list'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>", "<class 'numpy.ndarray'>"}), <class 'NoneType'>
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-19 12:56:59

好吧,我解决了。我会写详细的解决方案,因为我看到了类似的问题,多次发帖,没有答案。这是混合输入,解决方案是依赖自定义生成器。

第一步是制作自定义生成器。您必须返回输入+输出的列表/数据。I followed this link to make mine。下面是我的生成器的示例代码:

代码语言:javascript
运行
复制
def generator(subset, batch_size=256):
    i = 0
    DIR = f"data/{subset}"
    image_files = pd.read_csv(f"{DIR}.csv")
    while True:
        batch_x = [list(), list()] # I have two input: image + feature vector
        batch_y = list() # output
        for b in range(batch_size):
            if i == len(image_files):
                i = 0
            filename = image_files.loc[i, "filename"]
            label = image_files.loc[i, "Class"]
            image_file_path = f'{DIR}/{label}/{filename}'
            i += 1
            image = cv2.imread(image_file_path, 0)
            batch_x[0].append(image)
            feat = get_feature_vector(filename)
            batch_x[1].append(feat)
            batch_y.append(one_hot(label))

        batch_x[0] = np.array(batch_x[0])  # convert each list to array
        batch_x[1] = np.array(batch_x[1])
        batch_y = np.array(batch_y)
        yield batch_x, batch_y

然后,利用函数tensorflow建立模型。当您匹配数据时,使用所需的args调用生成器:

代码语言:javascript
运行
复制
history = model.fit(generator('train'),
                             validation_data = generator('validate'))
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67151256

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档