首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >CV新进展 | 迭代视觉推理框架 | 李飞飞团队 | Tensorflow的MNIST案例

CV新进展 | 迭代视觉推理框架 | 李飞飞团队 | Tensorflow的MNIST案例

作者头像
用户7623498
发布2020-08-04 15:14:25
发布2020-08-04 15:14:25
5010
举报

技术引领

陈鑫磊、李佳、李飞飞、Abhinav Gupta等人提出了一种新的迭代视觉推理框架

李飞飞团队提出了一种新的迭代视觉推理框架。该框架包括两个核心模块:一个局部模块,用空间记忆来存储之前并行更新的认知;一个全局的图推理模块。除了卷积之外,它还使用图来编码区域和类之间的空间和语义关系,并在图上传递消息。与普通ConvNets相比,其性能表现更加优越,在ADE上实现了8.4 %的绝对提升,在COCO上实现了3.7 %的绝对提升。分析还表明,我们的推理框架对当前区域分割方法造成的区域缺失具有很强的适应性。

该框架引入了全局模块进行局域外的推理。在全局模块中,推理是基于图模型展开的。它有三个组成部分:

(a)一个知识图谱,我们把类当做结点,建立边来对它们之间不同类型的语义关系进行编码;

(b)一个当前图像的区域图,图中的区域是结点,区域间的空间关系是边;

(c)一个工作分配图,将区域分配给类别。利用这种结构的优势,我们开发了一个推理模型,专门用于在图中传递信息。局部模块和全局模块迭代工作,交叉互递预测结果来调整预期。

局部模块和全局模块不是分离的,对图像的深刻理解通常是先验的背景知识和对图像的具体观察间的折中。因此,我们用注意力机制联合两个模块,使模型在做最终预测时使用相关性最大的特征。

案例应用

TensorFlow的输入流水线

在训练模型时,我们首先要处理的就是训练数据的加载与预处理的问题,这里称这个过程为输入流水线。在TensorFlow中,典型的输入流水线包含三个流程(ETL流程):

1、提取(Extract):从存储介质(如硬盘)中读取数据,可能是本地读取,也可能是远程读取(比如在分布式存储系统HDFS)

2、预处理(Transform):利用CPU处理器解析和预处理提取的数据,如图像解压缩,数据扩增或者变换,然后会做random shuffle,并形成batch。

3、加载(load):将预处理后的数据加载到加速设备中(如GPUs)来执行模型的训练。

采用feedable Iterator来实现mnist数据集的训练过程,分别创建两个Dataset,一个为训练集,一个为验证集,对于验证集不需要shuffle操作。首先我们创建Dataset对象的辅助函数,主要是解析TFRecords文件,并对image做归一化处理:

defdecode(serialized_example):

"""decode the serialized example"""

features =tf.parse_single_example(serialized_example,

features={"image":tf.FixedLenFeature([],tf.string), "label": tf.FixedLenFeature([], tf.int64)})

image = tf.decode_raw(features["image"], tf.uint8)

image = tf.cast(image,tf.float32)

image =tf.reshape(image, [784])

label =tf.cast(features["label"], tf.int64) return image, label defnormalize(image, label):

"""normalize the image to [-0.5, 0.5]"""

image = image / 255.0 - 0.5

return image, label

然后定义创建Dataset的函数,对于训练集和验证集,两者的参数会不同:

defcreate_dataset(filename, batch_size=64,is_shuffle=False, n_repeats=0):

"""create dataset for train and validationdataset"""

dataset =tf.data.TFRecordDataset(filename) ifn_repeats>0:

dataset= dataset.repeat(n_repeats) # for train

dataset =dataset.map(decode).map(normalize) # decode and normalize

ifis_shuffle:

dataset= dataset.shuffle(1000 +3 * batch_size) # shuffle

dataset = dataset.batch(batch_size)

我们使用一个简单的全连接层网络来实现mnist的分类模型:

defmodel(inputs,hidden_sizes=(500, 500)):

h1, h2 = hidden_sizes

net =tf.layers.dense(inputs, h1, activation=tf.nn.relu)

net =tf.layers.dense(net, h2, activation=tf.nn.relu)

net = tf.layers.dense(net, 10,activation=None) return net

训练的主体代码

n_train_examples=55000n_val_examples=5000n_epochs=50batch_size=64train_dataset= create_dataset("TFRecords/train.tfrecords",batch_size=batch_size, is_shuffle=True,n_repeats=n_epochs)

val_dataset = create_dataset("TFRecords/validation.tfrecords", batch_size=batch_size) # 创建一个feedable iterator handle = tf.placeholder(tf.string, [])

feed_iterator = tf.data.Iterator.from_string_handle(handle,train_dataset.output_types,

train_dataset.output_shapes)

images, labels = feed_iterator.get_next() # 创建不同的iterator train_iterator =train_dataset.make_one_shot_iterator()

val_iterator = val_dataset.make_initializable_iterator() # 创建模型 logits = model(images, [500, 500])

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,logits=logits)

loss = tf.reduce_mean(loss)

train_op = tf.train.AdamOptimizer(learning_rate=1e-04).minimize(loss)

predictions = tf.argmax(logits, axis=1)

accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, labels),tf.float32))

init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) withtf.Session()assess:

sess.run(init_op) # 生成对应的handle

train_handle =sess.run(train_iterator.string_handle())

val_handle =sess.run(val_iterator.string_handle()) # 训练

for n inrange(n_epochs):

ls = [] foriin range(n_train_examples // batch_size):

_, l = sess.run([train_op, loss], feed_dict={handle: train_handle})

ls.append(l)

print("Epoch %d, train loss: %f" %(n, np.mean(ls))) if (n + 1) % 10 == 0:

sess.run(val_iterator.initializer)

accs = [] foriin range(n_val_examples //batch_size):

acc = sess.run(accuracy, feed_dict={handle: val_handle})

accs.append(acc)

print("\t validation accuracy: %f"% (np.mean(accs)))

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2018-04-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 决策智能与机器学习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档