首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Google最新发布大规模分布式机器学习架构Tensor2Robot

Github项目推荐 | Google-Research发布基于TensorFlow的大规模分布式机器学习架构Tensor2Robot

‍Tensor2Robot - Distributed machine learning infrastructure for large-scale robotics research

本项目包含分布式机器学习和强化学习基础结构。

它在Alphabet内部使用,开源的目的是使Robotics @ Google的研究为更广泛的机器人和计算机视觉社区更具可重现性。

Github项目地址:

https://github.com/google-research/tensor2robot

Site:

https://ai.google/research/teams/brain/robotics/

已经使用Tensor2Robot的项目和文章有:QT-Opt、Grasp2Vec 等等。

特点

Tensor2Robot(以下简称:T2R)是一个用于大规模深度神经网络的训练、评估和推理的库,专门针对与机器人感知和控制相关的神经网络而定制。它基于TensorFlow深度学习框架进行开发。

机器人研究中的常见任务涉及向神经网络图添加新的传感器模态或新的标签张量。这涉及到:

1)更改保存的数据,

2)在训练时更改数据管道代码以便读取新模态,

3)添加新的tf.placeholder以在测试时处理新的输入模态。

Tensor2Robot的主要功能是为步骤 2 和 3 自动生成TensorFlow代码。Tensor2Robot可以自动生成模型的占位符以匹配其输入,或者导出可以与TFExportedSavedModelPolicy一起使用的SavedModel,这样原始的图形就不必重新构建。。

机器学习中遇到的另一个常见任务涉及到裁剪/变换输入模态,例如jpeg解码和在训练时应用随机图像失真。Preprocessor(预处理器类)声明它自己的输入特性和标签,并且期望输出与输入要素和模型标签兼容的形状。你可以在 预处理器的链接中找到相关示例。

快速开始

环境要求:Python 3

T2R 模型

要使用Tensor2Robot,用户需要定义一个T2RModel对象,该对象按规范定义其输入要求——一个用于其功能(feature_spec),另一个用于其标签(label_spec):

这些规范定义了所有必需和可选的张量,以便于调用model_fn。 使用模型的输入管道参数化的输入管道将确保满足所有必需的规范。注意:我们总是会省略批量维度,只指定单个元素的形状。

在训练时,T2RModel提供model_train_fn或model_eval_fn作为model_fn参数中的tf.estimator.Estimator类。model_train_fn和model_eval_fn都是根据inference_network_fn的特征、标签和输出来定义的,它们可能实现了训练/评估图的共享部分。

请注意左侧如何具有name的值,这个值与ExtendedTensorSpec中右侧不同。 左侧的键在model_fn中用于访问加载的张量,而在创建parse_tf_example_fnnumpy_feed_dict时会使用这个名称。 我们确保这个名称在整个规范中是唯一的,除非这个规范是匹配的,否则我们无法保证映射功能。

继承T2RModel的好处

功能和标签的独立输入规范。

自动生成tf.train.Examples和tf.train.SequenceExamples的tf.data.Dataset管道。

对于策略推断,T2RModel可以生成占位符或导出密封的SavedModel,并且可以与ExportSavedModelPolicy一起使用。

为Estimator自动构建model_fn,用于共享单个inference_network_fn的训练和评估图。

可以在单个模型下将多个模型的inference_network_fn和model_train_fn组合在一起。 这种抽象允许我们实现调用其子模型的model_train_fn的通用元学习模型(例如MAML)。

自动支持GPU和TPU上的分布式训练。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20190516A0K3HQ00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券