前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >一个简单的Trainer项目(一)

一个简单的Trainer项目(一)

作者头像
带萝卜
发布2022-04-29 14:42:16
2640
发布2022-04-29 14:42:16
举报

整理Trainer的目的就是为了在偷懒的同时减少返工的可能,有一个好的trainer可以给我们省出不少喝茶的时间。

那么一个Trainer应该由哪些功能呢?

我认为主要有如下几个方面:

  1. 超参可配,将所有超参提取出来,使用配置文件进行配置,训练时只修改配置文件,不修改代码;
  2. 中间结果可查看,要留出调试接口,避免在调试的时候改动核心代码;
  3. 模块化,模块清晰明了,且相互不干扰。
  4. 日志可溯源,实验做多了可能喝口水就忘了刚才提交的任务是什么配置,所以训练日志里面要有尽量详细的信息;

一个基于PyTorch的Trainer由以下部分构成:

  1. 主流程

即训练、验证及模型推理调试的流程,包括forward, backward, optimizer, LRscheduler, 模型存取以及多机多卡训练等机制。

2. 数据加载

PyTorch中一般用自定义DataLoader来实现,其中包含数据增强、Sampler数据采样、collate_fn数据分批等,如果是lmdb数据,还需要用到worker_init_fn

3. 函数库

包括评价函数,loss函数等等

4. 网络模块库

包括网络中的各种layer、block、module等

5. 模型

使用各种网络模块和函数搭建起的网络,输入数据,输出loss及acc, 保证统一接口。

6. 配置文件

将包括模型在内的所有参数都提出来,写在配置文件里面,一般用cfg或者yaml,也可以利用argparse库以命令行参数的形式实现。

7. 日志模块

包括日志文件写入、屏幕输出等,包含时间和配置信息,也可以嵌入tensorboard做展示用。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2022-04-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档