前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >AI: 从零开始训练一个最小化的Transformer聊天机器人

AI: 从零开始训练一个最小化的Transformer聊天机器人

作者头像
运维开发王义杰
发布2024-06-27 13:29:59
730
发布2024-06-27 13:29:59
举报

这里将介绍如何从零开始,使用Transformer模型训练一个最小化的聊天机器人。该流程将尽量简化,不依赖预训练模型,并手动实现关键步骤,确保每一步都容易理解。

1. 环境准备

首先,确保安装了必要的Python库。我们只需要基本的Numpy和PyTorch库来实现我们的Transformer模型。

代码语言:javascript
复制


pip install numpy torch

2. 数据准备

创建一个简单的对话数据集。对于最小化实现,我们使用手工编写的对话数据集。

代码语言:javascript
复制

python
data = [
    ("你好", "你好!有什么我可以帮助你的?"),
    ("今天天气怎么样?", "今天天气很好,阳光明媚。"),
    ("你会做什么?", "我可以和你聊天,回答你的问题。")
]

3. 数据预处理

手动实现一个简单的分词和编码器。

代码语言:javascript
复制

python
# 建立词汇表
vocab = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2}
for pair in data:
    for sentence in pair:
        for word in sentence:
            if word not in vocab:
                vocab[word] = len(vocab)

# 编码函数
def encode(sentence, vocab):
    return [vocab["<SOS>"]] + [vocab[word] for word in sentence] + [vocab["<EOS>"]]

# 编码数据
encoded_data = [(encode(pair[0], vocab), encode(pair[1], vocab)) for pair in data]

# 确保所有句子长度一致(填充或截断)
max_len = max(max(len(pair[0]), len(pair[1])) for pair in encoded_data)

def pad_sequence(seq, max_len, pad_value):
    return seq + [pad_value] * (max_len - len(seq))

padded_data = [(pad_sequence(pair[0], max_len, vocab["<PAD>"]),
                pad_sequence(pair[1], max_len, vocab["<PAD>"])) for pair in encoded_data]

4. 模型定义

定义一个简单的Transformer模型。

代码语言:javascript
复制

python
import torch
import torch.nn as nn

class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SimpleTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.attention = nn.MultiheadAttention(embedding_dim, num_heads=1)
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, src, tgt):
        src = self.embedding(src).permute(1, 0, 2)
        tgt = self.embedding(tgt).permute(1, 0, 2)
        attn_output, _ = self.attention(tgt, src, src)
        output = self.fc(attn_output.permute(1, 0, 2))
        return output

# 参数设置
vocab_size = len(vocab)
embedding_dim = 16

# 初始化模型
model = SimpleTransformer(vocab_size, embedding_dim)

5. 模型训练

使用简单的交叉熵损失函数和随机梯度下降(SGD)优化器训练模型。

代码语言:javascript
复制

python
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<PAD>"])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 简单的数据生成器
def data_generator(data, batch_size=1):
    for src, tgt in data:
        yield torch.tensor([src], dtype=torch.long), torch.tensor([tgt], dtype=torch.long)

# 训练模型
epochs = 100
for epoch in range(epochs):
    total_loss = 0
    for src, tgt in data_generator(padded_data):
        optimizer.zero_grad()
        output = model(src, tgt[:, :-1])
        loss = criterion(output.view(-1, vocab_size), tgt[:, 1:].contiguous().view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(padded_data)}")

6. 模型评估

评估模型性能,并测试生成回复。

代码语言:javascript
复制

python
def generate_reply(model, input_sentence, vocab, max_length=20):
    model.eval()
    input_encoded = torch.tensor([pad_sequence(encode(input_sentence, vocab), max_len, vocab["<PAD>"])], dtype=torch.long)
    output_encoded = torch.tensor([[vocab["<SOS>"]]], dtype=torch.long)
    for _ in range(max_length):
        output = model(input_encoded, output_encoded)
        next_word = torch.argmax(output[:, -1, :], dim=-1).item()
        output_encoded = torch.cat([output_encoded, torch.tensor([[next_word]], dtype=torch.long)], dim=1)
        if next_word == vocab["<EOS>"]:
            break
    return "".join([list(vocab.keys())[list(vocab.values()).index(i)] for i in output_encoded[0].numpy()[1:-1]])

# 测试生成回复
print(generate_reply(model, "你好", vocab))

7. 保存模型

保存训练好的模型,以便后续加载和使用。

代码语言:javascript
复制

python
# 保存模型
torch.save(model.state_dict(), "simple_transformer_model.pth")

8. 加载模型

需要时加载之前保存的模型权重,可以继续使用。

代码语言:javascript
复制

python
# 加载模型
model = SimpleTransformer(vocab_size, embedding_dim)
model.load_state_dict(torch.load("simple_transformer_model.pth"))
model.eval()  # 设置模型为评估模式

总结

本文介绍了如何从零开始构建一个最小化的Transformer聊天机器人。从数据准备、模型定义到训练和评估,每一步都尽量简化,以便于理解。希望这个例子能够帮助大家了解Transformer模型在聊天机器人中的基本应用。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-06-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 运维开发王义杰 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 环境准备
  • 2. 数据准备
    • 3. 数据预处理
      • 4. 模型定义
        • 5. 模型训练
          • 6. 模型评估
            • 7. 保存模型
              • 8. 加载模型
              • 总结
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档