首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >具有不同嵌入尺寸的经过训练模型上的resize_token_embeddings

具有不同嵌入尺寸的经过训练模型上的resize_token_embeddings
EN

Stack Overflow用户
提问于 2022-06-27 16:38:00
回答 1查看 1.2K关注 0票数 2

我想问一下如何改变经过训练的模型的嵌入大小。

我有一个训练有素的模特models/BERT-pretrain-1-step-5000.pkl。现在,我将向令牌添加一个新的令牌[TRA],并尝试将resize_token_embeddings使用到已接收的令牌。

代码语言:javascript
运行
复制
from pytorch_pretrained_bert_inset import BertModel #BertTokenizer 
from transformers import AutoTokenizer
from torch.nn.utils.rnn import pad_sequence
import tqdm

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model_bert = BertModel.from_pretrained('bert-base-uncased', state_dict=torch.load('models/BERT-pretrain-1-step-5000.pkl', map_location=torch.device('cpu')))

#print(tokenizer.all_special_tokens) #--> ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
#print(tokenizer.all_special_ids)    #--> [100, 102, 0, 101, 103]

num_added_toks = tokenizer.add_tokens(['[TRA]'], special_tokens=True)
model_bert.resize_token_embeddings(len(tokenizer))  # --> Embedding(30523, 768)
print('[TRA] token id: ', tokenizer.convert_tokens_to_ids('[TRA]'))  # --> 30522

但我遇到了一个错误:

代码语言:javascript
运行
复制
AttributeError: 'BertModel' object has no attribute 'resize_token_embeddings'

我认为这是因为我所拥有的model_bert(BERT-pretrain-1-step-5000.pkl)具有不同的嵌入大小。我想知道是否有任何方法来适应我修改的令牌的嵌入大小,以及我想使用的模型作为初始权重。

非常感谢!!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-03 15:22:37

嵌入是一种抱面变压器方法。您使用的是来自BERTModel的pytorch_pretrained_bert_inset类,它没有提供这样的方法。从代码的角度来看,他们似乎是在一段时间前从拥抱中抄袭了伯特的代码。

您可以从INSET等待更新(可能会创建github问题),或者编写自己的代码来扩展word_embedding层:

代码语言:javascript
运行
复制
from torch import nn 

embedding_layer = model.embeddings.word_embeddings

old_num_tokens, old_embedding_dim = embedding_layer.weight.shape

num_new_tokens = 1

# Creating new embedding layer with more entries
new_embeddings = nn.Embedding(
        old_num_tokens + num_new_tokens, old_embedding_dim
)

# Setting device and type accordingly
new_embeddings.to(
    embedding_layer.weight.device,
    dtype=embedding_layer.weight.dtype,
)

# Copying the old entries
new_embeddings.weight.data[:old_num_tokens, :] = embedding_layer.weight.data[
    :old_num_tokens, :
]

model.embeddings.word_embeddings = new_embeddings
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72775559

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档