对 keras : tensorflow https://github.com/zdx3578/DeepLearningImplementations/tree/master/WassersteinGAN 代码进行了运行测试,及环境配置等
内容目录:
celebA人脸数据集训练效果
mnist 数字训练学习效果
环境搭建要点。
训练显示训练过程的确很稳定,很快出现可识别有意义的图像。
celebA 人脸数据集训练
下面两行是标准照片。
loss:
mnist:
效果:
loss:
一个epoch内的训练loss下降:
epoch0
epoch1
一 环境准备:
主机aws云,镜像采样之前文章介绍的镜像ami-97ba3a80,已经安装好tensoflow及GPU配置等。
环境碰到需要安装:
#source activate tensorflow进入anaconda python环境:
conda install opencv
conda install -c yikelu parmap=1.2.0
conda install pydot keras(如果执行这一步会安装cpu版本tensorflow!需重新安装0.9GPU版本tf加快训练速度)tensorflow 版本重回此镜像的0.9GPU版本ok。
conda install natsort
tensorflow安装:https://www.tensorflow.org/get_started/os_setup#anaconda_installation
测试tensorlfow版本:python -c "import tensorflow;print(tensorflow.__version__)"
如果跑celebA数据集需要64G内存,因为有一个数据一次性的计算操作未优化。
二 数据准备:参考https://github.com/tdeboissiere/DeepLearningImplementations/tree/master/WassersteinGAN
python make_dataset.py --img_size 64 /home/ubuntu/celeba/img_align_celeba 生成H5数据文件。
三 训练:
python -c "import tensorflow;print(tensorflow.__version__)" 确认使用GPU版本tensorflow;确认有相关cuda信息输出。
Only the CPU version of TensorFlow is available at the moment and can be installed in the conda environment for Python 2 or Python 3. 所有需使用镜像默认GPU版本或手动安装0.9GPU版本tensorflow。
训练脚本命令:
celebA
python main.py --backend tensorflow --generator deconv --dset celebA --img_dim 64
mnist:
python main.py --backend tensorflow --generator deconv
更多参考官方代码readme。