深度学习工程模板:简化加载数据、构建网络、训练模型和预测样本的流程

使用方式

下载工程

创建和激活虚拟环境

安装Python依赖库

开发流程

定义自己的数据加载类,继承DataLoaderBase;

定义自己的网络结构类,继承ModelBase;

定义自己的模型训练类,继承TrainerBase;

定义自己的样本预测类,继承InferBase;

定义自己的配置文件,写入实验的相关参数;

执行训练模型和预测样本操作。

示例工程

识别MNIST库中手写数字,工程

训练:

预测:

网络结构

TensorBoard

工程架构

框架图

文件夹结构

主要组件

DataLoader

操作步骤:

创建自己的加载数据类,继承DataLoaderBase基类;

覆写和,返回训练和测试数据;

Model

操作步骤:

创建自己的网络结构类,继承ModelBase基类;

覆写,创建网络结构;

在构造器中,调用;

注意:支持绘制网络结构;

Trainer

操作步骤:

创建自己的训练类,继承TrainerBase基类;

参数:网络结构model、训练数据data;

覆写,fit数据,训练网络结构;

注意:支持在训练中调用callbacks,额外添加模型存储、TensorBoard、FPR度量等。

Infer

操作步骤:

创建自己的预测类,继承InferBase基类;

覆写,提供模型加载功能;

覆写,提供样本预测功能;

Config

定义在模型训练过程中所需的参数,JSON格式,支持:学习率、Epoch、Batch等参数。

Main

训练:

创建配置文件config;

创建数据加载类dataloader;

创建网络结构类model;

创建训练类trainer,参数是训练和测试数据、模型;

执行训练类trainer的train();

预测:

创建配置文件config;

处理预测样本test;

创建预测类infer;

执行预测类infer的predict();

原文:https://github.com/SpikeKing/DL-Project-Template

- 加入人工智能学院系统学习 -

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180505A0M9DX00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。

扫码关注云+社区

领取腾讯云代金券