
An Example of Training a Generative Model with a Transformer

Date: 2025-06-10

I've been looking at ChatGPT recently and wondering how a model like ChatGPT is actually trained. Without getting into the neural-network algorithms themselves, we can play with a tiny Transformer to get a feel for the idea.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Build the vocabulary
vocab = ["<PAD>", "<BOS>", "<EOS>", "明天", "天气", "很", "好"]
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
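
# (Illustration, not in the original post.) word2idx and idx2word map back and forth
# between tokens and ids, e.g. word2idx["天气"] == 4 and idx2word[4] == "天气".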

# Example training data: the labels are the inputs shifted by one position
# input:  <BOS> 明天 天气 很
# output: 明天 天气 很 好

inputs = torch.tensor([
    [word2idx["<BOS>"], word2idx["明天"], word2idx["天气"], word2idx["很"]]
])

labels = torch.tensor([
    [word2idx["明天"], word2idx["天气"], word2idx["很"], word2idx["好"]]
])
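
# (Sketch, not in the original post.) For real data one would build the shifted
# input/label pair from a token sequence rather than hard-coding it; make_pair below
# is a hypothetical helper illustrating the same next-token setup, with <BOS>/<EOS>
# wrapped around the sentence.
def make_pair(token_ids):
    seq = [word2idx["<BOS>"]] + token_ids + [word2idx["<EOS>"]]
    return torch.tensor([seq[:-1]]), torch.tensor([seq[1:]])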

class TinyTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=32, nhead=2, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(1, 100, d_model))  # learned positional encoding, up to 100 tokens
        # batch_first=True so tensors are laid out as [batch, seq_len, d_model]
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1)]
        # causal mask: each position may only attend to itself and earlier positions,
        # otherwise the model could simply read the next token from the input
        mask = torch.triu(torch.full((x.size(1), x.size(1)), float("-inf")), diagonal=1)
        x = self.transformer(x, mask=mask)
        return self.fc(x)

model = TinyTransformerModel(vocab_size=len(vocab))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
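
# (Sanity check, not in the original post.) For a [batch, seq_len] input the model
# returns [batch, seq_len, vocab_size] logits, here [1, 4, 7]:
with torch.no_grad():
    print("logits shape:", model(inputs).shape)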

for epoch in range(50):
    model.train()
    out = model(inputs)  # [batch_size, seq_len, vocab_size]
    loss = loss_fn(out.view(-1, len(vocab)), labels.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, loss: {loss.item():.4f}")



# Greedy autoregressive generation: start from the given tokens and repeatedly
# append the model's most likely next token until <EOS> or max_len is reached
def generate(model, start_tokens, max_len=5):
    model.eval()
    input_ids = torch.tensor([start_tokens])
    for _ in range(max_len):
        with torch.no_grad():
            logits = model(input_ids)
        next_token = logits[0, -1].argmax().item()
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
        if next_token == word2idx["<EOS>"] or len(input_ids[0]) > max_len:
            break
    return [idx2word[i] for i in input_ids[0].tolist()]

# Test: generate starting from <BOS>
generated = generate(model, [word2idx["<BOS>"]])
print("Generated:", " ".join(generated))
Original article: https://www.oschina.net/news/354538