首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >设置迭代次数gpt-2

设置迭代次数gpt-2
EN

Stack Overflow用户
提问于 2019-09-04 06:09:29
回答 1查看 705关注 0票数 0

按照本教程,我正在微调gpt-2模型:

https://medium.com/@ngwaifoong92/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f

与其关联的GitHub存储库:

https://github.com/nshepperd/gpt-2

我已经能够复制这些例子,我的问题是我没有找到一个参数来设置迭代次数。基本上,培训脚本每100个迭代显示一个示例,每1000个迭代保存一个模型版本。但是我没有找到一个参数来训练它,比如说,5000次迭代,然后关闭它。

用于培训的脚本如下:https://github.com/nshepperd/gpt-2/blob/finetuning/train.py

编辑:

正如cronoik所建议的那样,我正在尝试将while替换为for循环。

我要添加这些更改:

  1. 增加一个额外的论点: Parser.add_argument(‘-训练步骤“,metavar=' steps ',type=int,default=1000,help=’一个代表模型要训练多少个训练步骤‘的数字’)
  2. 更改循环: try: for iter_count in range(training_steps):如果计数器% args.save_every == 0: save()
  3. 使用新的论点: python3 train.py --训练步骤300

但我发现了一个错误:

代码语言:javascript
运行
复制
  File "train.py", line 259, in main
    for iter_count in range(training_steps):
NameError: name 'training_steps' is not defined
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-09-07 02:11:11

您所要做的就是将while True循环修改为for循环:

代码语言:javascript
运行
复制
try:
    #replaced
    #while True:
    for i in range(5000):
        if counter % args.save_every == 0:
            save()
        if counter % args.sample_every == 0:
            generate_samples()
        if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
            validation()

        if args.accumulate_gradients > 1:
            sess.run(opt_reset)
            for _ in range(args.accumulate_gradients):
                sess.run(
                    opt_compute, feed_dict={context: sample_batch()})
            (v_loss, v_summary) = sess.run((opt_apply, summaries))
        else:
            (_, v_loss, v_summary) = sess.run(
                (opt_apply, loss, summaries),
                feed_dict={context: sample_batch()})

        summary_log.add_summary(v_summary, counter)

        avg_loss = (avg_loss[0] * 0.99 + v_loss,
                    avg_loss[1] * 0.99 + 1.0)

        print(
            '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
            .format(
                counter=counter,
                time=time.time() - start_time,
                loss=v_loss,
                avg=avg_loss[0] / avg_loss[1]))

        counter += 1
except KeyboardInterrupt:
    print('interrupted')
    save()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57782409

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档