专栏首页深度学习那些事儿提升python项目完成效率的调试方法技巧(上)

提升python项目完成效率的调试方法技巧(上)

效率提升是极为重要的事情,我们的时间本来就不充裕,不应该过多将时间浪费在调试过程中。对于大型项目光有dubug是不够的,如果需要提高产品调试进度,必须需要采取一些其他的方法,这里打算利用python的特性以及一些规范方法来说明一下,当然类似的方法不仅仅可以用于python,其他编程语言也是通用的。

注意,我们要吸收的不是方法,而是思想。

使用assert断言来判断代码是否正常运行

assert断言从C语言中就已经有了,在C++中也经常使用。当然python语言也继承了这个特性。总体来说assert就是首先假设一个语句没有问题,如果有问题则抛出Assert异常。

断言一般用于在调试中使用,也就是用于debug代码的时候。个人建议写在一眼看不出来是否正确需要仔细检查的部分,比如:

def style_hook(self, module, grad_input, grad_output):
    self.mask = self.mask[:, 0:1, :, :]
    # 这段代码对grad_input[0]和self.mask的shape进行了匹配,当然如果输入到这个函数时,前两者变量就是错
    # 误的那么这段代码无论如何也不可能正确执行,因此在调试的时候在这个项目的其他代码中对这两个变量进行了
    # 修改从而使这段代码正确执行,修改完后,断言部分可注释掉,但是如果保留可以预防其他数据导致的bug
    assert grad_input[0].shape == self.mask.shape, \
        'grad_input:{} is not matchable with mask:{}'.format(grad_input[0].shape, self.mask.shape)

    grad_input_1 = grad_input[0].div(torch.norm(grad_input[0], 1) + 1e-8)
    grad_input_1 = grad_input_1 * self.weight
    grad_input_1 = grad_input_1 * self.mask
    grad_input = tuple([grad_input_1, grad_input[1], grad_input[2]])

    return grad_input

当然Assert断言操作也不是越多越好,容易造成程序的混乱和效率问题,总之我们可以在一些重要的数据,需要作注释来检查的地方使用断言进行替代,可以有效对代码进行调试。

更加详细阅读:http://blog.jobbole.com/76285/

使用tqdm来代替print显示结果

tqdm是一个比python自带的原生的print打印函数打印信息更好的打印进度工具。

下面的动图来源自官网。

看起来比我们普通使用print打印更加“好看一些”,其实使用这个工具主要作用在于时间预测和速度估计,在训练的时候可以自己设定一个速度的标准,然后在训练的时候可以通过观察速度变化做出一些判断,或者根据估计出来的需要时间来估计训练所需要的时间。

项目地址:https://github.com/tqdm/tqdm

在训练中保存记录log

保存Log记录是很重要的,我们不应当只在训练中实时观察输出结果,也应该在训练过程中对中间数据进行统计。一般是将每一次训练的时间、训练参数、以及训练结果打包放到一个目录下,目录的名称根据训练参数和结果而定,保证之后的查阅方便以及便于分析:

# 下方的代码返回一个我们要生成目录的名称,并且保存了参数信息
def get_log_dir(model_name, config_id, cfg):
    # load config
    name = 'MODEL-%s_CFG-%03d' % (model_name, config_id)
    for k, v in cfg.items():
        v = str(v)
        if '/' in v:
            continue
        name += '_%s-%s' % (k.upper(), v)
    now = datetime.datetime.now(pytz.timezone('Asia/Shanghai'))
    name += '_VCS-%s' % git_hash()
    name += '_TIME-%s' % now.strftime('%Y%m%d-%H%M%S')
    # create out
    log_dir = osp.join(here, 'logs', name)
    if not osp.exists(log_dir):
        os.makedirs(log_dir)
    with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
        yaml.safe_dump(cfg, f, default_flow_style=False)
    return log_dir

参数使用字典或者list包含起来或者使用命令进行输入

如果训练参数数量多,建议将训练参数写成字典的形式:

configurations = {
    1: dict(
        max_iteration=100000,
        lr=1.0e-10,
        momentum=0.99,
        weight_decay=0.0005,
        interval_validate=4000,
    )
}

也可以使用命令行参数的功能单独对特定的参数进行调整:

parser.add_argument("-content_weight", type=int, default=8) 
parser.add_argument("-style_weight", type=int, default=2000)
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=8000)
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-init", default="random", choices=["random", "image"])
parser.add_argument("-init_image", help="initial image")
parser.add_argument("-optimizer", help="optimiser", default="lbfgs", choices=["lbfgs", "adam"])
parser.add_argument("-learning_rate", type=float, default=1e0)

# 使用时调用 args.content_weight args.init 即可

关于如何使用命令行参数:https://oldpan.me/archives/argparse-python-order-command

使用git进行版本控制

git不多说,版本控制神器,不管是公司中正式的项目代码还是你个人使用的练习项目,使用git的的好处都是不言而喻。

我们还可以通过使用python中的subprocess模块来自动执行git命令来获取当前的commot信息从而让我们对我们代码的版本和修改过程有一个直观的过程。

# 下面的代码返回当前一个commit的commit哈希码值和提交作者的信息
def git_hash():
    cmd = 'git log -n 1 --pretty="%h -%ar"'
    hash = subprocess.check_output(shlex.split(cmd)).strip()
    return hash

注意

好的习惯并不能直接帮助我们提高效率,需要我们在习惯这些技巧的过程中逐渐掌握一些窍门。另外,虽然这些方法需要我们投入时间去学习,但是如果熟练了这些技巧,对调试程序的好处是巨大的。

文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

Oldpan博客

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • pytorch中autograd以及hook函数详解

    有些公式为图片,如果这个页面加载不出来,请看这里:https://oldpan.me/archives/pytorch-autograd-hook

    OLDPAN
  • pytorch中autograd以及hook函数详解

    pytorch中的Autograd mechanics(自动求梯度机制)是实现前向以及后向反馈运算极为重要的一环,pytorch官方专门针对这个机制进行了一个版...

    OLDPAN
  • 随身GPU服务器:Kaggle中kernels的快速入门指南

    对于很多入门深度学习领域的小伙伴来说啊,拥有一款合适的显卡是必要的,只有拥有好的装备才能更好更快地进行神经网络的训练、调试网络结构、改善我们的代码,进而更快地产...

    OLDPAN
  • 如何用pytorch打印出隐藏层梯度

    我们在训练神经网络时, 有时会发现自己的网络学习不到东西,loss不下降或者下降很慢,这时除了检查修改学习率以外还有可能是碰见了梯度消失的问题。检...

    我爱自然语言处理
  • 奇数阶魔方阵

    印"魔方阵"。所谓"魔方阵"是指这样的方阵,它的每一行、每一列以及对角线之和均相等。例如,三阶魔方阵为: 8   1   6...

    lop
  • 分享一个shell脚本用于“基于已安装的rpm包,创建tar压缩包”

    有时候,我们不想在系统中安装rpm包,但是又需要这个rpm包的功能,那么可否实现呢? 这个需求是可以实现的,当我们用tar包去安装这个功能的时候就可以实现了....

    qsjs
  • 链表相交

    找到两个链表相交的起始节点,交点表示一个链表的结尾与另一个链表中的某个节点链接,形成Y形。如图所示:

    九州暮云
  • 你认可《后浪》吗

    前几天 B 站上线了一个小视频《后浪》,在全网引起了热烈反响,有赞扬也有批评,视频地址:https://www.bilibili.com/video/BV1FV...

    Python小二
  • 捷报再传!腾讯T-block拿下CDCC中国2016数据中心年度能效奖

    2016年11月16日星期三下午的北京,CDCC在中国数据中心年度论坛上,宣布腾讯T-block凭借出色的能效表现,从数十个申报项目中脱颖而出,获得2016数据...

    腾讯数据中心
  • 腾讯汤道生:与更多生态合作伙伴共建智能教育生态

    ? 5月22日上午,2019腾讯全球数字生态大会智慧教育分论坛在昆明滇池国际会展中心召开。腾讯高级执行副总裁、云与智慧产业事业群总裁汤道生以《共创智能教育时代...

    腾讯智慧教育

扫码关注云+社区

领取腾讯云代金券