首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在使用TFrecords的TPU中,使用图像分割来消除训练损失

在使用TFrecords的TPU中,使用图像分割来消除训练损失
EN

Stack Overflow用户
提问于 2020-06-16 00:40:28
回答 1查看 334关注 0票数 1

我是一个初学者,试图在Kaggle内核中使用Tensorflow处理TPU。我之前在GPU中使用数据集训练了一个Unet模型,现在我正在尝试在TPU中实现它。我从数据集图像和蒙版中创建了一个tfrecord,然后TFrecord返回图像和蒙版。当我尝试在TPU中训练时,损失总是Nan,即使度量精度是正常的。由于这是我在GPU中使用的相同模型和损失,我猜测问题出在tfrecord或加载数据集上。加载数据的代码如下:

代码语言:javascript
运行
复制
 def decode_image(image_data):
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32) / (255.0)  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
    return image

def decode_image_mask(image_data):
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float64) / (255.0)  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
    image=tf.image.rgb_to_grayscale(image)
    image=tf.math.round(image)
    return image

def read_tfrecord(example):
    TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "mask": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, TFREC_FORMAT)
    image = decode_image(example['image'])
    mask=decode_image_mask(example['mask'])
    return image, mask 



def load_dataset(filenames, ordered=False):
    # Read from TFRecords. For optimal performance, reading from multiple files at once and
    # disregarding data order. Order does not matter since we will be shuffling the data anyway.

    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
    return dataset



def get_training_dataset():
    dataset = load_dataset(TRAINING_FILENAMES)
    dataset = dataset.repeat() # the training dataset must repeat for several epochs
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE,drop_remainder=True)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset(ordered=False):
    dataset = load_dataset(VALIDATION_FILENAMES, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset


def count_data_items(filenames):
    # the number of data items is written in the name of the .tfrec files, i.e. flowers00-230.tfrec = 230 data items
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

那么,我做错了什么呢?

EN

回答 1

Stack Overflow用户

发布于 2020-06-16 19:29:28

事实证明,问题是我将数据分批,并将其分批到20以正确查看matplotlib中的图像和掩码,这搞乱了数据发送到模型的方式,因此出现了Nan损失。制作另一个数据集的副本并使用该副本查看图像,同时发送原始数据集进行训练,解决了这个问题。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62392904

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档