首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >我在加载Tensorflow (python3.7)中的数据集时遇到了问题

我在加载Tensorflow (python3.7)中的数据集时遇到了问题
EN

Stack Overflow用户
提问于 2019-01-26 01:18:57
回答 3查看 722关注 0票数 2

我正在尝试将图像数据集加载到tensorflow中,但我面临一个问题,要正确加载它。实际上,我在C驱动器中有一个名为PetImages的文件夹,其中包含两个名为cat和dog的文件夹。每个文件夹都保存更多的12450图像,因此总共是24500加上图像。我正在用以下代码加载它们:

代码语言:javascript
运行
复制
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
DATADIR = "C:\Datasets\PetImages"
CATEGORIES = ["Dog","Cat"]
for the category in CATEGORIES:
path = os.path.join(DATADIR, category)

for img in os.listdir(path):
    img_array = cv2.imread(os.path.join(path,img), cv2.IMREAD_GRAYSCALE)
    plt.imshow(img_array, cmap="gray")
    plt.show()
    break
break

代码的结果看起来非常好,它显示了文件夹的第一个图像。然后,使用以下代码将整个数组的形状转换为所需的像素速率:

代码语言:javascript
运行
复制
IMG_SIZE=50
new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
plt.imshow(new_array, cmap = "gray")
plt.show()

这个部分也很好,但是我想混合(洗牌)图像,这样我就可以对系统进行困惑并以这种方式检查准确性,但问题是它只显示了12450图像,在这段代码之后:

代码语言:javascript
运行
复制
training_data = []
def create_training_data():
for category in CATEGORIES:
    path = os.path.join(DATADIR, category)
    class_num = CATEGORIES.index(category)
for img in os.listdir(path):
    try:
        img_array = cv2.imread(os.path.join(path,img), 
cv2.IMREAD_GRAYSCALE)
        new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
        training_data.append([new_array, class_num])
    except Exception as e:
        pass       
create_training_data()
print(len(training_data)

然后使用随机的,我没有获得成功的洗牌图像从两个文件夹,它只显示一个文件夹的值。

代码语言:javascript
运行
复制
import random   
random.shuffle(training_data)
for the sample in training_data[:10]:  
print(sample[1])

但是我的结果是111-1,而不是随机生成的,比如011001,这个样式,我的意思是,下一个是1,或者。

你的帮助对我很有价值。提前感谢

EN

回答 3

Stack Overflow用户

发布于 2019-01-26 05:44:26

在我看来就像一个缩进错误。第二个for循环位于第一个for循环之外,这将导致第一个循环完全终止,并在输入第二个循环之前将class_num设置为1。你可能想把它们筑巢。尝试:

代码语言:javascript
运行
复制
def create_training_data():
    for category in CATEGORIES:
        path = os.path.join(DATADIR, category)
        class_num = CATEGORIES.index(category)
        for img in os.listdir(path):
            try:
                img_array = cv2.imread(os.path.join(path,img), cv2.IMREAD_GRAYSCALE)
                new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
                training_data.append([new_array, class_num])
            except Exception as e:
                pass       
create_training_data()
print(len(training_data)
票数 3
EN

Stack Overflow用户

发布于 2019-01-26 06:23:00

您可以尝试对训练数据的掩码或索引进行洗牌。

代码语言:javascript
运行
复制
import random
index=[k for k in range(len(training_data))]
shuffIndex=random.shuffle(index)
shuffTrainigData=[training_data[val] for val in shuffIndex]

希望它能帮上忙

票数 0
EN

Stack Overflow用户

发布于 2022-06-10 18:10:35

您的代码只加载了Dog数据--训练数据,因此是用于训练lentgh的12450。这意味着你只是洗牌狗的形象,这将给你1s。你的训练时间应该是25000。修补你的契约,你应该会没事的。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54374825

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档