首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何应用max_length来截断HuggingFace令牌程序中左边的令牌序列?

如何应用max_length来截断HuggingFace令牌程序中左边的令牌序列?
EN

Stack Overflow用户
提问于 2022-05-11 13:52:23
回答 2查看 844关注 0票数 1

在HuggingFace标记器中,应用max_length参数指定标记化文本的长度。我相信它通过从max_length-2右中剪切多余的令牌来截断序列到(如果是)。为了进行语音分类,我需要从left (即序列的开始)中删除多余的令牌,以保留最后的标记。我怎么能这么做?

代码语言:javascript
复制
from transformers import AutoTokenizer

train_texts = ['text 1', ...]
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
encodings = tokenizer(train_texts, max_length=128, truncation=True)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-06-29 08:47:38

托卡器有一个truncation_side参数,它应该精确地设置这个参数。见文档

票数 1
EN

Stack Overflow用户

发布于 2022-05-13 08:33:49

我写了一个解决方案,它不是很健壮。还在找更好的方法。这是用代码中提到的模型进行测试的。

代码语言:javascript
复制
from typing import Tuple
from transformers import AutoTokenizer

# also tested with: ufal/robeczech-base, Seznam/small-e-czech
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base', use_fast=False)
texts = ["Do not meddle in the affairs of wizards for they are unpredictable.", "Did you meddle?"]
encoded_input = tokenizer(texts)


def cut_seq_left(seq: list, max_length: int, special_ids: dict) -> Tuple[int,int]:
    # cut from left if longer. Keep special tokens.
    normal_idx = 0
    while seq[normal_idx] in special_ids and normal_idx < len(seq)-1:
        normal_idx += 1
    if normal_idx >= len(seq)-1:
        normal_idx = 1
        #raise Exception('normal_idx longer for seq:' + str(seq))
    rest_idx = normal_idx + len(seq) - max_length
    seq[:] = seq[0:normal_idx] + seq[rest_idx:]
    return normal_idx, rest_idx


def pad_seq_right(seq: list, max_length: int, pad_id: int):
    # pad if shorter
    seq.extend(pad_id for _ in range(max_length - len(seq)))


def get_pad_token(tokenizerr) -> str:
    specials = [t.lower() for t in tokenizerr.all_special_tokens]
    pad_candidates = [t for t in specials if 'pad' in t]
    if len(pad_candidates) < 1:
        raise Exception('Cannot find PAD token in: ' + str(tokenizerr.all_special_tokens))
    return tokenizerr.all_special_tokens[specials.index(pad_candidates[0])]


def cut_pad_encodings_left(encodingz, tokenizerr, max_length: int):
    specials = dict(zip(tokenizerr.all_special_ids, tokenizerr.all_special_tokens))
    pad_code = get_pad_token(tokenizerr)
    padd_idx = tokenizerr.all_special_tokens.index(pad_code)
    for i, e in enumerate(encodingz.data['input_ids']):
        if len(e) < max_length:
            pad_seq_right(e, max_length, tokenizerr.all_special_ids[padd_idx])
            pad_seq_right(encodingz.data['attention_mask'][i], max_length, 0)
            if 'token_type_ids' in encodingz.data:
                pad_seq_right(encodingz.data['token_type_ids'][i], max_length, 0)
        elif len(e) > max_length:
            fro, to = cut_seq_left(e, max_length, specials)
            encodingz.data['attention_mask'][i] = encodingz.data['attention_mask'][i][:fro] \
                                                  + encodingz.data['attention_mask'][i][to:]
            if 'token_type_ids' in encodingz.data:
                encodingz.data['token_type_ids'][i] = encodingz.data['token_type_ids'][i][:fro] \
                                                      + encodingz.data['token_type_ids'][i][to:]


cut_pad_encodings_left(encoded_input, tokenizer, 10) # returns nothing: works in-place
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72202295

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档