首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Tensorflow学习(3)——迁移学习

目前,并不是每一种机器学习任务都具有大样本,因此如何在小样本上训练一个可靠的模型,也是机器学习领域的一个重点方向。而迁移学习在很多领域取得很大的成功,尤其是基于网络的迁移学习。

迁移学习是指在训练新的目标任务时,并不是从无到有的训练,而是将源任务学习到的知识迁移到目标任务上,进而再训练。其中最常见的就是基于网络的迁移学习,即将一个在源任务中训练得到的网络,我们只需要修改其网络顶层的全连接层结构,重新训练参数,而对于其他层的参数只需要进行微调即可。因此这里我尝试通过迁移学习,来对人类蛋白图谱图像进行分类(来自kaggle的一个比赛)。

这里使用tensorflow中一个库slim,这个库可以简化我们定义,训练,评价一个复杂模型的工作。这里我们迁移的是resnet101层的网络,首先定义网络结构。

这里resnet_v1.resnet_v1_101直接定义了整个网络结构,由于这里是一个多标签的任务,因此我使用sigmoid交叉熵定义代价函数。接下来我们需要定义从训练好的模型中加载哪些参数。

exclusions=['resnet_v1/Logits','resnet_v1/AuxLogits']

variables_to_restore=[]

forvarinslim.get_model_variables():

forexclusioninexclusions:

ifvar.op.name.startswith(exclusion):

break

else:

variables_to_restore.append(var)

这里,因为最后一层需要基于自己任务训练,因此选择不加载这些参数。加下来只需要初始化参数不断循环执行train_op这个操作就好。

withtf.Session()assess:

sess.run(tf.global_variables_initializer())

slim.assign_from_checkpoint_fn(checkpoint_path,variables_to_restore,ignore_missing_vars=True)

foriinrange(,1942):

ifi%20==:

print(i)

img_batch,label_batch=dataset.next()

sess.run(train_op,feed_dict={x:img_batch,y:label_batch})

saver=tf.train.Saver()

saver.save(sess,'models/mymodel.ckpt')

这里首先通过global_variables_initializer初始化全局参数,在通过slim.assign_from_checkpoint_fn加载全连接层之前的参数。接下来就是测试,我们这里,先对logits进行sigmoid计算,定义值大于0.5则为预测的标签。然后将测试样本的Id和预测结果写入到csv文件中。

net=tf.sigmoid(net)

#Test

test_data=pd.read_csv('../input/sample_submission.csv')

pred_list=[]

withtf.Session()assess:

saver=tf.train.Saver()

saver.restore(sess,tf.train.latest_checkpoint('./models'))

foridintest_data['Id']:

img=open_test_rgby(id)

img=img[np.newaxis,:]

predict=sess.run(net,feed_dict={x:img})

predict=np.squeeze(predict)

label=np.argwhere(predict>0.5)

iflabel.size==:

label=np.argwhere(predict==np.max(predict))

label=' '.join([str(i)foriinlabel])

pred_list.append(label)

df=pd.DataFrame({'Id':test_data['Id'],'Predicted':pred_list})

df.to_csv('protein_classification.csv',header=True,index=False)

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181121G1ZFQA00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券