前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch学习笔记(二十): ignite (写更少的代码训练模型)

pytorch学习笔记(二十): ignite (写更少的代码训练模型)

作者头像
ke1th
发布2019-05-26 12:24:07
1.6K0
发布2019-05-26 12:24:07
举报

最近自己想写一个高级一点的抽象来更方便的训练 pytorch 网络, 无意间发现, pytorch 用户下面有个 ignite repo, 好奇就看了一下这是个什么东西. 原来是 pytorch 已经提供了一个高级抽象库来训练 pytorch模型了, 既然有了轮子, 那就没必要自己造了, 好好用着就行了. 没事读读源码, 也可以学习一下大佬们是怎么抽象的. 本博文主要是对 ignite 做一个宏观上的介绍.

官方文档

例子

为了减少源码篇幅, 特地将与 ignite 关系不大的代码给删除了, 如果想跑完整示例的话, 可以查看上面提到的链接.

代码语言:javascript
复制
from argparse import ArgumentParser
from torch import nn
from torch.optim import SGD
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engines import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss

def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    cuda = torch.cuda.is_available()
    device = torch.device("gpu") if torch.cuda.is_avaliable() else ("cpu")
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)

    model = Net()
    
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': CategoricalAccuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(engine.state.epoch, iter, len(train_loader), engine.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        metrics = evaluator.run(val_loader).metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(engine.state.epoch, avg_accuracy, avg_nll))

    trainer.run(train_loader, max_epochs=epochs)

先对流程做一下总结, 再看API做了些什么

  • 创建模型, 创建 Dataloader
  • 创建 trainer
  • 创建 evaluator
  • 为一些事件注册事件处理函数, @trainer.on()
  • trainer.run()

Event

代码语言:javascript
复制
"""
类似枚举类, 定义了几个事件
"""
class Events(Enum):
    EPOCH_STARTED = "epoch_started"               # 当一个新的 epoch 开始时会触发此事件
    EPOCH_COMPLETED = "epoch_completed"           # 当一个 epoch 结束时, 会触发此事件
    STARTED = "started"                           # 开始训练模型是, 会触发此事件
    COMPLETED = "completed"                       # 当训练结束时, 会触发此事件
    ITERATION_STARTED = "iteration_started"       # 当一个 iteration 开始时, 会触发此事件
    ITERATION_COMPLETED = "iteration_completed"   # 当一个 iteration 结束时, 会触发此事件
    EXCEPTION_RAISED = "exception_raised"         # 当有异常发生时, 会触发此事件

State

代码语言:javascript
复制
class State(object):
    def __init__(self, **kwargs):
        self.iteration = 0            # 记录 当前的 iteration
        self.output = None            # 当前 iteration 的 输出. 对于 Supervised Trainer 来说, 是 loss.
        self.batch = None             # 本次 iteration 的 mini-batch 样本
        for k, v in kwargs.items():   # 其它一些希望 State 记录下来的 状态
            setattr(self, k, v)

Engine

代码语言:javascript
复制
def __init__(self, process_function):
	pass 

"""
对于 训练过程 来说,process_function 是一个 前向+反向+参数更新 过程
process_function 的 signature 是 func(batch)->anything
def func(self, batch): # batch会保存在 state.batch 中, self用来接收当前对象
	1. process batch
	2. forward compution
	3. compute loss
	4. computer gradient
	5. update parameters
	6. return loss or else # 返回的值会被保存在 state.output 中

对于 评估过程来说, process_function 是一个 前向+计算 metrics 的过程。
def func(self, batch): # batch会保存在 state.batch 中
	1. process batch
	2. forward compution
	3. return something # 返回的值会被保存在 state.output 中,
	#  用来计算 Metric
"""


""" 为某事件注册函数, 当事件发生时, 此函数就会被调用
函数的 signature 必须是 def func(trainer, state)
"""
@engine.on(...)
def some_func(trainer):
    pass

Engine.run() # 训练/评估 模型

Metric

定义了一些模型评估标准

  • 在创建 evaluator 的时候会指定一些 metric,这些metric 会由框架自动注册成为 evaluator 的 event handler,会在相应的事件发生时调用相应的处理函数
代码语言:javascript
复制
# reset : epoch 开始之前调用一次
# update: 每次iteration 结束时调用
# compute : epoch 结束时调用

工厂方法,用来创建 Engine

可以根据自己的需要改写以下两个函数

create_supervised_trainer

代码语言:javascript
复制
def create_supervised_trainer(model, optimizer, loss_fn, cuda=False):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (torch.nn.Module): the model to train
        optimizer (torch.optim.Optimizer): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Trainer: a trainer instance with supervised update function
    """

create_supervised_evaluator

代码语言:javascript
复制
def create_supervised_evaluator(model, metrics={}, cuda=False):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (torch.nn.Module): the model to train
        metrics (dict of str: Metric): a map of metric names to Metrics
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Evaluator: a evaluator instance with supervised inference function
    """
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年03月16日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 例子
  • 工厂方法,用来创建 Engine
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档