首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >是否有优化SpaCy培训的方法?

是否有优化SpaCy培训的方法?
EN

Stack Overflow用户
提问于 2020-09-04 20:08:58
回答 1查看 920关注 0票数 1

我目前正在训练一个多标签文本分类的SpaCy模型。有6个标签:愤怒、期待、厌恶、恐惧、joy、悲伤、惊讶和信任。数据集超过200 K。然而,每一时期都需要4个小时。我想知道是否有一种方法可以优化训练并加快速度,也许我在这里跳过了一些可以改进模型的东西。

TRAINING_DATA

代码语言:javascript
运行
复制
TRAIN_DATA = list(zip(train_texts, [{"cats": cats} for cats in final_train_cats]))

[...
  {'cats': {'anger': 1,
    'anticipation': 0,
    'disgust': 0,
    'fear': 0,
    'joy': 0,
    'sadness': 0,
    'surprise': 0,
    'trust': 0}}),
 ('mausoleum',
  {'cats': {'anger': 1,
    'anticipation': 0,
    'disgust': 0,
    'fear': 0,
    'joy': 0,
    'sadness': 0,
    'surprise': 0,
    'trust': 0}}),
 ...]

训练

代码语言:javascript
运行
复制
nlp = spacy.load("en_core_web_sm")
category = nlp.create_pipe("textcat", config={"exclusive_classes": True})
nlp.add_pipe(category)

# add label to text classifier
category.add_label("trust")
category.add_label("fear")
category.add_label("disgust")
category.add_label("surprise")
category.add_label("anticipation")
category.add_label("anger")
category.add_label("joy")

optimizer = nlp.begin_training()
losses = {}

for i in range(100):
    random.shuffle(TRAIN_DATA)

    print('...')
    for batch in minibatch(TRAIN_DATA, size=8):
        texts = [nlp(text) for text, entities in batch]
        annotations = [{"cats": entities} for text, entities in batch]
        nlp.update(texts, annotations, sgd=optimizer, losses=losses)
    print(i, losses)

...
0 {'parser': 0.0, 'tagger': 27.018985521040854, 'textcat': 0.0, 'ner': 0.0}
...
1 {'parser': 0.0, 'tagger': 27.01898552104131, 'textcat': 0.0, 'ner': 0.0}
...
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-09-06 03:27:54

“200 K记录数据每一时代需要4小时”并没有告诉我们多少:

  1. 确保你没有把记忆吹灭(是吗?)这需要多少内存?
  2. 您大概是在运行单线程,这是由于GIL。参见关于如何关闭GIL以运行多核训练的。你有多少个核心?
  • 在您的内环texts = [nlp(text) ...] for batch in minibatch(TRAIN_DATA, size=8): 中放置看起来很麻烦,因为您的代码将始终保持GIL,即使您只需要它来处理输入文本的C库字符串调用,即 parser 阶段,而不是用于培训。
  • 重构您的代码,以便您首先在所有输入上运行nlp()管道,然后保存一些中间表示形式(数组或其他什么)。将这些代码分离到您的培训循环中,这样培训就可以被多线程化了。
  1. 我不能评论您对minibatch()参数的选择,但是8看起来很小,而且这些参数似乎对性能很重要,所以尝试调整它们(/grid-搜索几个值)。
  2. 最后,一旦您首先检查了以上所有内容,就可以找到最快的单芯/多核盒,并且有足够的RAM。
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63747454

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档