I found an IO function in apache_beam made specifically for TFRecords, which can be used like this:
    import logging

    import apache_beam as beam
    from apache_beam.io.tfrecordio import ReadFromTFRecord

    class VerifyOutput(beam.DoFn):
        def process(self, pcollection):
            # Unwrap the element if it arrives inside a wrapper object.
            try:
                pcollection = pcollection.element
            except AttributeError:
                pass
            logging.info(pcollection.subject_id)

    (pipeline
     | ReadFromTFRecord(opt.input_path, compression_type='auto', validate=True)
     | beam.ParDo(VerifyOutput()))
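This prints the serialized byte string of each TFRecord. How can I parse it inside a Beam pipeline so that I get the individual elements back?
I did find the TensorFlow idiom for reading TFRecords here, but it did not work for me. I think that is because Beam is not "thread safe" (it hangs forever).
Also, is it possible to run ReadFromTFRecord without actually running a pipeline (for debugging/learning/testing)? For example, in a notebook?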
Posted on 2019-06-06 23:19:09
You need to specify a coder to decode the features. Something like this should do the job:
    import apache_beam as beam
    import tensorflow as tf
    import tensorflow_transform as tft
    from apache_beam.io.tfrecordio import ReadFromTFRecord
    from tensorflow_transform.beam import impl as beam_impl
    from tensorflow_transform.beam import tft_beam_io
    from tensorflow_transform.tf_metadata import dataset_metadata
    from tensorflow_transform.tf_metadata import dataset_schema

    ...

    # For each feature, adapt the name, dtype and shape accordingly:
    column_schemas['FEATURE_NAME'] = dataset_schema.ColumnSchema(
        tf.int64, [], dataset_schema.FixedColumnRepresentation())
    raw_data_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.Schema(column_schemas))
    # The coder is passed to the reader so records are decoded with this schema.
    data_coder = tft.coders.ExampleProtoCoder(raw_data_metadata.schema)

    _ = (pipeline
         | ReadFromTFRecord(opt.input_path, coder=data_coder,
                            compression_type='auto', validate=True)
         | beam.ParDo(VerifyOutput()))
For a more detailed example, see here.
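As for looking at records without running a pipeline (for debugging in a notebook, say), one option is to read the TFRecord framing by hand, with no Beam or TensorFlow involved. The following is only a minimal sketch under some assumptions: the file is uncompressed, the two CRC fields are skipped rather than verified, and `read_tfrecords` is a hypothetical helper name, not part of either library. It relies on the documented record layout: a little-endian uint64 length, a uint32 CRC of the length, the payload bytes, and a uint32 CRC of the payload.

```python
import struct


def read_tfrecords(path):
    """Yield the raw payload bytes of each record in an uncompressed TFRecord file.

    Hypothetical debugging helper: the CRC fields are read but NOT verified.
    """
    with open(path, "rb") as f:
        while True:
            # Header: uint64 payload length + uint32 masked CRC of the length.
            header = f.read(12)
            if not header:
                return  # clean end of file
            if len(header) < 12:
                raise ValueError("truncated record header")
            (length,) = struct.unpack("<Q", header[:8])
            data = f.read(length)
            footer = f.read(4)  # uint32 masked CRC of the payload
            if len(data) < length or len(footer) < 4:
                raise ValueError("truncated record body")
            yield data
```

Each yielded value should be the same kind of serialized byte string that ReadFromTFRecord emits when no coder is given. If the records are tf.train.Example protos, `tf.train.Example.FromString(record)` would turn one into a message with named features.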
https://stackoverflow.com/questions/46307498