前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习实战篇之 ( 十) -- TensorFlow学习之路(七)

深度学习实战篇之 ( 十) -- TensorFlow学习之路(七)

作者头像
用户5410712
发布2022-06-01 20:11:16
2600
发布2022-06-01 20:11:16
举报
文章被收录于专栏:居士说AI居士说AI

知识之窗

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。

2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:1、具有强大的GPU加速的张量计算(如NumPy)。2、包含自动求导系统的深度神经网络。

PyTorch的前身是Torch,其底层和Torch框架一样,但是使用Python重新写了很多内容,不仅更加灵活,支持动态图,而且提供了Python接口。它是由Torch7团队开发,是一个以Python优先的深度学习框架,不仅能够实现强大的GPU加速,同时还支持动态神经网络。

PyTorch既可以看作加入了GPU支持的numpy,同时也可以看成一个拥有自动求导功能的强大的深度神经网络。除了Facebook外,它已经被Twitter、CMU和Salesforce等机构采用。

回顾

在上周的文章中, 我们学习了整合所有的代码(数据预处理,网络模型,训练代码),然后进行了实际的训练,我们必须知道,神经网络的训练结果小除了知道模型的好坏以及有效性以外,我们还需要考虑将训练好的模型进行实际的测试,也需要后期需要用来部署为应用也说不定,当然不会直接部署,还需呀考虑优化,压缩,剪枝等问题。

一、模型预测

实现步骤:

1.在训练过程中保存模型

2.编写测试代码(数据处理,模型调用,数据测试)

4.输出模型结果,映射为真实标签

1.训练过程中保存模型

代码语言:javascript
复制
#在训练之前添加
# 产生一个saver来存储训练好的模型
saver = tf.train.Saver()

在每训练一个batch后,开始整个验证集的测试(现在一般训练一个epoch后,才进行验证),验证集测试后,如果大于上一次的测试准确率并且大于80%才考虑保存模型,即最终保存最好的模型。

代码语言:javascript
复制
 if avg_test_acc > pre_test_acc and avg_test_acc > 0.80:
checkpoint_path = os.path.join(logs_checkpoint,
 'model.ckpt')
saver.save(sess,

2.测试代码

1.数据预处理:

这个地方与训练的时候一样

代码语言:javascript
复制
# 获取一张图片
def get_one_image(img_dir):
    # 输入参数:train,训练图片的路径
    # 返回参数:image,从训练图片中随机抽取一张图片
    #print("train", train)
    #n = len(train)
    #ind = np.random.randint(0, n)
    #img_dir = train[ind]  # 随机选择测试的图片
    # img_dir = train

    img = Image.open(img_dir)
    #plt.imshow(img)
    #imag = img.resize([150, 150])  # 由于图片在预处理阶段以及resize,因此该命令可略
    imge = tf.image.resize_images(img, (150, 150))
    image = tf.reshape(imge, [1, 150, 150, 3])
    #image = np.array(imge)

    image = image/255
    image = tf.cast(image, tf.float32)

    return image

2.模型调用

其实就是回复保存模型的参数后导入到现在的网络中,进行测试。

现在的网络只进行前向传播,不进行反向传播。

代码语言:javascript
复制
saver = tf.train.Saver()

with tf.Session() as sess:
img_array = sess.run(image_array)

print("Reading checkpoints...")
ckpt = tf.train.get_checkpoint_state(logs_train_dir)
if ckpt and ckpt.model_checkpoint_path:
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('Loading success, global_step is %s' % global_step)
else:

3.数据测试

代码语言:javascript
复制
# 测试图片
def evaluate_one_image(image_array):
global graph
graph = tf.get_default_graph()
with graph.as_default():
  BATCH_SIZE = 1
  N_CLASSES = 2
  #image = tf.cast(image_array, tf.float32)

  x = tf.placeholder(tf.float32, shape=[1,150, 150, 3])

  logit = model.inference(x, BATCH_SIZE, N_CLASSES,1)

  logit = tf.nn.softmax(logit)

4.输出结果:

代码语言:javascript
复制
prediction = sess.run(logit,feed_dict={x: img_array})
max_index = np.argmax(prediction)
# print(max_index)
# 标签映射可以选择字典或者列表
label_dict = {0: 'cat', 1: 'dog'}
label_list = ['cat','dog']
print("模型的输出为{},对应的真实标签为:{}".format(max_index,label_list[max_index]))

全部的测试代码:

实际预测展示

可以看到我们读取的是测试中的dog的图片,随后网络的预测标签是1,当初给dog的标签为1,即映射实际标签为dog,预测正确。

结语

本次分享结束了,算是图像分类项目的一个完整流程的项目,从数据处理到网络搭建,到训练,到调用模型做预测,我们都进行了分享,同时对代码细节进行了注释,相信聪敏的你一定可以看懂,如有疑惑请随时后台哦。

虽然本次项目结束,但我相信,其中或多或少有些地方大家不太理解,不管数据处理还是网络的搭建等等都或许不是那么简单,没关系,下次,小编会针对本次项目中的漏洞进行一个总结,算是图像分类项目的总结篇吧,同时也欢迎各位老铁,多多提问,以促使我们一起进步。

周末愉快,我们下期再见!

编辑:玥怡居士|审核:小圈圈居士

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-06-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 IT进阶之旅 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
NLP 服务
NLP 服务(Natural Language Process,NLP)深度整合了腾讯内部的 NLP 技术,提供多项智能文本处理和文本生成能力,包括词法分析、相似词召回、词相似度、句子相似度、文本润色、句子纠错、文本补全、句子生成等。满足各行业的文本智能需求。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档