前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Github 项目推荐 | PyTorch 实现的 GAN 文本生成框架

Github 项目推荐 | PyTorch 实现的 GAN 文本生成框架

作者头像
AI研习社
发布2019-06-19 11:34:07
3.1K0
发布2019-06-19 11:34:07
举报
文章被收录于专栏:AI研习社AI研习社

Github项目地址:https://github.com/williamSYSU/TextGAN-PyTorch

TextGAN是一个用于生成基于GANs的文本生成模型的PyTorch框架。TextGAN是一个基准测试平台,支持基于GAN的文本生成模型的研究。由于大多数基于GAN的文本生成模型都是由Tensorflow实现的,TextGAN可以帮助那些习惯了PyTorch的人更快地进入文本生成领域。

目前,只有少数基于GAN的模型被实现,包括 SeqGAN (Yu et. al, 2017), LeakGAN (Guo et. al, 2018) 和 RelGAN (Nie et. al, 2018)。

环境要求

  • PyTorch >= 1.0.0
  • Python 3.6
  • Numpy 1.14.5
  • CUDA 7.5+ (For GPU)
  • nltk 3.4
  • tqdm 4.32.1

运行 pip install -r requirements.txt 即可安装。 如果出现了CUDA问题,请查看PyTorch官方的入门指南(https://pytorch.org/get-started/locally/)。

实现模型和原始论文

  • SeqGAN - SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient https://arxiv.org/abs/1609.05473
  • LeakGAN - Long Text Generation via Adversarial Training with Leaked Information https://arxiv.org/abs/1709.08624
  • RelGAN - RelGAN: Relational Generative Adversarial Networks for Text Generation https://openreview.net/forum?id=rJedV3R5tm

入门

  • 开始
代码语言:javascript
复制
git clone 
cd TextGAN-PyTorch

对于真实数据实验,可以从下载Image COCO和EMNLP新闻数据集,下载链接:

https://drive.google.com/drive/folders/1XvT3GqbK1wh3XhTgqBLWUtH_mLzGnKZP?usp=sharing

  • 使用SeqGAN运行
代码语言:javascript
复制
cd run
python3 run_seqgan.py 0 0 # The first 0 is job_id, the second 0 is gpu_id
  • 使用LeakGAN运行
代码语言:javascript
复制
cd run
python3 run_leakgan.py 0 0
  • 使用RelGAN运行
代码语言:javascript
复制
cd run
python3 run_relgan.py 0 0

特点

1.Instructor

对于每个模型,整个运行过程在instructor/oracle_data/seqgan_instructor.py中定义。 (以合成数据实验中的SeqGAN为例)。 init_model()和optimize()等基本函数在instructor.py的基类BasicInstructor中定义。 如果要添加新的基于GAN的文本生成模型,请在Instructor/oracle_data下创建一个新的Instructor,并定义模型的训练过程。

2.可视化

使用utils/visualization.py可视化日志文件,包括模型丢失和度量标准分数。 在log_file_list中自定义日志文件,不超过 len(color_list)。 日志文件名应排除.txt。

3.日志记录

TextGAN-PyTorch使用Python中的logging(日志记录)模块来记录正在运行的进程,如生成器的丢失和度量标准分数。 为了便于可视化,将分别在log/log _****_ ****。txt和save/**/log.txt中保存两个相同的日志文件。 此外,代码将自动保存模型的状态字典和批量大小的生成器样本,每个日志步骤为./save/**/models和./save/**/samples,其中**取决于您的超级参数。

4.运行信号

你可以使用基于字典文件run_signal.txt的Signal类(请查看utils/helpers.py)轻松控制训练过程。

如果要使用Signal,只需编辑本地文件run_signal.txt并将pre_sig设置为Fasle,程序将停止预训练过程并进入下一个训练阶段。 如果你认为当前的训练已经足够,可以非常方便地提前停止训练。

5.自动选择GPU

在config.py中,程序会自动选择nvidia-smi中GPU-Util最少的GPU设备。 默认情况下启用此功能。 如果要手动选择GPU设备,请取消注释run_[run_model].py中的--device args并使用命令指定GPU设备。

TODO

  • 添加实验结果
  • 修复LeakGAN模型中的错误
  • 在instrutor/real_data中添加SeqGAN和LeakGAN的instructors
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-06-10,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI研习社 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档