调优哪家强——tensorflow命令行参数

深度学习神经网络往往有过多的Hyperparameter需要调优,优化算法、学习率、卷积核尺寸等很多参数都需要不断调整,使用命令行参数是非常方便的。有两种实现方式,一是利用python的argparse包,二是调用tensorflow自带的app.flags实现。

利用python的argparse包

argparse介绍及基本使用:

http://www.jianshu.com/p/b8b09084bd1a

下面代码用argparse实现了命令行参数的输入。

import argparse
import sys
parser = argparse.ArgumentParser()
parser.add_argument('--fake_data', nargs='?', const=True, type=bool,                       
default=False,                       
help='If true, uses fake data for unit testing.')
parser.add_argument('--max_steps', type=int, default=1000,                       
help='Number of steps to run trainer.')
parser.add_argument('--learning_rate', type=float, default=0.001,                       
help='Initial learning rate')
parser.add_argument('--dropout', type=float, default=0.9,                       
help='Keep probability for training dropout.')
parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',                       help='Directory for storing input data') parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',                       
help='Summaries log directory') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

通过调用python的argparse包,调用函数parser.parse_known_args()解析命令行参数。代码运行后得到的FLAGS是一个结构体,内部参数分别为:

FLAGS.data_dir
Out[5]: '/tmp/tensorflow/mnist/input_data'
 FLAGS.fake_data Out[6]: False  FLAGS.max_steps
Out[7]: 1000
 FLAGS.learning_rate
Out[8]: 0.001
 FLAGS.dropout
Out[9]: 0.9
 FLAGS.data_dir
Out[10]: '/tmp/tensorflow/mnist/input_data'
 FLAGS.log_dir
Out[11]: '/tmp/tensorflow/mnist/logs/mnist_with_summaries'

利用tf.app.flags组件

首先需要定义一个tf.app.flags对象,调用自带的DEFINE_string, DEFINE_boolean, DEFINE_integer, DEFINE_float设置不同类型的命令行参数及其默认值。当然,也可以在终端用命令行参数修改这些默认值。

# Define hyperparameters
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean("enable_colored_log", False, "Enable colored log")                     
"The glob pattern of train TFRecords files")
flags.DEFINE_string("validate_tfrecords_file",                     
"./data/a8a/a8a_test.libsvm.tfrecords",     
"The glob pattern of validate TFRecords files")
flags.DEFINE_integer("label_size", 2, "Number of label size")
flags.DEFINE_float("learning_rate", 0.01, "The learning rate")
 def main():    
 # Get hyperparameters     
if FLAGS.enable_colored_log:         
import coloredlogs         
coloredlogs.install()     
logging.basicConfig(level=logging.INFO)     
FEATURE_SIZE = FLAGS.feature_size     
LABEL_SIZE = FLAGS.label_size       
...   
return 0
if __name__ == ‘__main__’:     main()

这段代码采用的是tensorflow库中自带的tf.app.flags模块实现命令行参数的解析。如果用终端运行tf程序,用上述两种方式都可以,如果用spyder之类的工具,那么只有第一种方式有用,第二种方式会报错。

其中有个tf.app.flags组件,还有个tf.app.run()函数。官网帮助文件是这么说的:

flags module: Implementation of the flags interface.
run(...): Runs the program with an optional 'main' function and 'argv' list.

tf.app.run的源代码:

1."""Generic entry point script."""   
2.from __future__ import absolute_import   
3.from __future__ import division   
4.from __future__ import print_function   
5.   
6.import sys   
7.   
8.from tensorflow.python.platform import flags   
9.   
10.   
11.def run(main=None):   
12.  f = flags.FLAGS   
13.  f._parse_flags()   
14.  main = main or sys.modules['__main__'].main   
15.  sys.exit(main(sys.argv))

也就是处理flag解析,然后执行main函数。

用shell脚本实现训练代码的执行

在终端执行python代码,首先需要在代码文件开头写入shebang,告诉系统环境变量如何设置,用python2还是用python3来编译这段代码。然后修改代码权限为可执行,用 ./python_code.py 就可以执行。同理,这段代码也可以用shell脚本来实现。创建.sh文件,运行python_code.py并设置参数max_steps=100

python python_code.py --max_steps 100

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2017-10-31

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏张善友的专栏

.NET Core系列 :4 测试

2016.6.27 微软已经正式发布了.NET Core 1.0 RTM,但是工具链还是预览版,同样的大量的开源测试库也都是至少发布了Alpha测试版支持.NE...

19010
来自专栏我和未来有约会

如何在silverlihgt中使用右键

一般我们在silverlight中点击右键会出现如下的对话筐. ? ? 在flash中 其提供了一个可定制话的右键菜单系统.(ContextMenu) 这个...

1827
来自专栏草根专栏

使用xUnit为.net core程序进行单元测试(上)

纵轴表示测试的深度,也就是说测试的细致程度。

3309
来自专栏xingoo, 一个梦想做发明家的程序员

【插件开发】—— 13 GEF双击模型事件

前文回顾: 1 插件学习篇 2 简单的建立插件工程以及模型文件分析 3 利用扩展点,开发透视图 4 SWT编程须知 5 SWT简单控件的使用与...

1908
来自专栏谭伟华的专栏

文件上传那些事儿

最近把产品目前使用的FileUploader从老的组件库分离出来的,自己也查阅了相关的各种资料,对文件上传的这些事有了更进一步的了解。

5.7K4
来自专栏听雨堂

Repeater,DataList,DataGrid

   输出表:    string a="Provider=Microsoft.Jet.OLEDB.4.0;Data Source=c:\\data.mdb;...

1699
来自专栏GuZhenYin

使用localResizeIMG3+WebAPI实现手机端图片上传

前言 惯例~惯例~昨天发表的使用OWIN作为WebAPI的宿主..嗯..有很多人问..是不是缺少了什么 - - 好吧,如果你要把OWIN寄宿在其他的地方...代...

1908
来自专栏james大数据架构

在ASPNET中使用JS集锦

(一).确认删除用法: 1. BtnDel.Attributes.Add("onclick","return confirm('"+"确认删除?"+"')")...

1837
来自专栏逸鹏说道

模块式开发

这两天看到同事的一个小工具,用的是模块式开发,也就是俗称的插件开发,用的是反射+接口的方式实现的。感觉挺好的,也就学习了一下,写个小Demo,在此记录下。 一、...

3156
来自专栏技术之路

【权限的思考】(一)使用反射实现动态权限

  每一个业务系统都会根据业务需要配置各种各样的权限,实现方式也是千差万别,各有各的优缺点。今天我们 利用反射来做一个小的权限管理Demo。也可以说是插件化的权...

1919

扫码关注云+社区