dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
train_images = dataset['train']
test_images = dataset['test']
train_batches = (
train_images
.cache()
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
.prefetch(buffer_size=tf.data.AUTOTUNE))
test_batches = test_images.batch(BATCH_SIZE)
现在,我想将test_images大小缩小到100个图像。我期待一些代码,如:
test_images = test_images[100]
但这会带来一个错误:
'ParallelMapDataset' object is not subscriptable
发布于 2021-10-18 06:10:20
使用take()
方法,您可以从目标数据集中获取批处理或项。
如果数据集是批处理的:
test_images.take((100 // BATCH_SIZE) + 1)
对数据集进行批处理时,它将包含批处理或组。
因此,假设您将数据批处理为大小为32的数据,test_images.take(1)
将返回32个元素,换句话说,是一个批处理。test_images.take(2)
将返回64个元素等。
如果不是批次的话:
test_images.take(100)
与批处理数据集不同,dataset将返回传递到take()
方法的元素数量。
https://stackoverflow.com/questions/69611411
复制