前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习工程模板

深度学习工程模板

作者头像
机器学习AI算法工程
发布2019-10-28 16:09:17
5630
发布2019-10-28 16:09:17
举报
使用方式

下载工程

代码语言:javascript
复制
git clone https://github.com/SpikeKing/DL-Project-Template

创建和激活虚拟环境

代码语言:javascript
复制
virtualenv venv
source venv/bin/activate

安装Python依赖库

代码语言:javascript
复制
pip install -r requirements.txt

开发流程

  1. 定义自己的数据加载类,继承DataLoaderBase;
  2. 定义自己的网络结构类,继承ModelBase;
  3. 定义自己的模型训练类,继承TrainerBase;
  4. 定义自己的样本预测类,继承InferBase;
  5. 定义自己的配置文件,写入实验的相关参数;

执行训练模型和预测样本操作。

示例工程

识别MNIST库中手写数字,工程simple_mnist

训练:

代码语言:javascript
复制
python main_train.py -c configs/simple_mnist_config.json

预测:

代码语言:javascript
复制
python main_test.py -c configs/simple_mnist_config.json -m simple_m
nist.weights.10-0.24.hdf5

TensorBoard

工程架构

主要组件

DataLoader

操作步骤:

  1. 创建自己的加载数据类,继承DataLoaderBase基类;
  2. 覆写get_train_data()get_test_data(),返回训练和测试数据;

Model

操作步骤:

  1. 创建自己的网络结构类,继承ModelBase基类;
  2. 覆写build_model(),创建网络结构;
  3. 在构造器中,调用build_model()

注意:plot_model()支持绘制网络结构;

Trainer

操作步骤:

  1. 创建自己的训练类,继承TrainerBase基类;
  2. 参数:网络结构model、训练数据data;
  3. 覆写train(),fit数据,训练网络结构;

注意:支持在训练中调用callbacks,额外添加模型存储、TensorBoard、FPR度量等。

Infer

操作步骤:

  1. 创建自己的预测类,继承InferBase基类;
  2. 覆写load_model(),提供模型加载功能;
  3. 覆写predict(),提供样本预测功能;

Config

定义在模型训练过程中所需的参数,JSON格式,支持:学习率、Epoch、Batch等参数。

Main

训练:

  1. 创建配置文件config;
  2. 创建数据加载类dataloader;
  3. 创建网络结构类model;
  4. 创建训练类trainer,参数是训练和测试数据、模型;
  5. 执行训练类trainer的train();

预测:

  1. 创建配置文件config;
  2. 处理预测样本test;
  3. 创建预测类infer;
  4. 执行预测类infer的predict();

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-10-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习AI算法工程 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 示例工程
  • 工程架构
  • 主要组件
    • DataLoader
      • Model
        • Trainer
          • Infer
            • Config
              • Main
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档