An Example of Training a Generative Model with a Transformer
I've been reading about ChatGPT lately and wondering how a model like it gets trained. Without going deep into the underlying math, we can play with a tiny Transformer to see the basic idea.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Build the vocabulary (the Chinese tokens mean "tomorrow", "weather", "very", "good")
vocab = ["<PAD>", "<BOS>", "<EOS>", "明天", "天气", "很", "好"]
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
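As a quick sanity check on the two mappings, here is a minimal sketch (using the vocab defined above) that encodes a token sequence to ids and decodes it back:

# Round-trip: tokens -> ids -> tokens, using word2idx/idx2word from above
sentence = ["<BOS>", "明天", "天气", "很", "好", "<EOS>"]
ids = [word2idx[w] for w in sentence]
print(ids)                         # [1, 3, 4, 5, 6, 2]
print([idx2word[i] for i in ids])  # recovers the original tokens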
# Toy training data: the labels are the inputs shifted one position to the left
# input:  <BOS> 明天 天气 很
# output: 明天 天气 很 好
inputs = torch.tensor([
    [word2idx["<BOS>"], word2idx["明天"], word2idx["天气"], word2idx["很"]]
])
labels = torch.tensor([
    [word2idx["明天"], word2idx["天气"], word2idx["很"], word2idx["好"]]
])
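The shift-by-one pattern generalizes to any sequence: the input is everything but the last token, the label is everything but the first. A minimal sketch of that data preparation (make_pair is my own helper name, not part of the script):

def make_pair(token_ids):
    # Next-token prediction: the model sees token_ids[:-1] and must predict token_ids[1:]
    seq = torch.tensor([token_ids])
    return seq[:, :-1], seq[:, 1:]

full = [word2idx[w] for w in ["<BOS>", "明天", "天气", "很", "好"]]
x, y = make_pair(full)  # reproduces the inputs/labels tensors above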
class TinyTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=32, nhead=2, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embedding, up to 100 tokens
        self.pos_encoding = nn.Parameter(torch.randn(1, 100, d_model))
        # batch_first=True so tensors are [batch_size, seq_len, d_model]
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1)]
        # Causal mask (-inf above the diagonal) so each position only attends to
        # earlier positions; without it the model could peek at the next token
        mask = torch.triu(torch.full((x.size(1), x.size(1)), float("-inf")), diagonal=1)
        x = self.transformer(x, mask=mask)
        return self.fc(x)
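Before training, it is worth confirming that the shapes line up. A quick sketch, reusing the inputs tensor defined earlier:

m = TinyTransformerModel(vocab_size=len(vocab))
print(m(inputs).shape)  # torch.Size([1, 4, 7]) -> batch=1, seq_len=4, vocab_size=7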
model = TinyTransformerModel(vocab_size=len(vocab))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(50):
    model.train()
    out = model(inputs)  # [batch_size, seq_len, vocab_size]
    # Flatten to [batch_size * seq_len, vocab_size] vs. [batch_size * seq_len]
    loss = loss_fn(out.view(-1, len(vocab)), labels.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, loss: {loss.item():.4f}")
def generate(model, start_tokens, max_len=5):
    model.eval()
    input_ids = torch.tensor([start_tokens])
    for _ in range(max_len):
        with torch.no_grad():
            logits = model(input_ids)
        # Greedy decoding: pick the highest-scoring token at the last position
        next_token = logits[0, -1].argmax().item()
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
        if next_token == word2idx["<EOS>"]:
            break
    return [idx2word[i] for i in input_ids[0].tolist()]
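Greedy argmax always produces the same continuation. A common alternative is temperature sampling; here is a minimal sketch (the temperature value is a free parameter, not something from the script above):

def sample_next(logits, temperature=1.0):
    # Scale logits, turn them into probabilities, then draw one token at random
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

Swapping argmax() for sample_next(logits[0, -1]) inside generate turns the greedy decoder into a stochastic one.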
# Try it out
generated = generate(model, [word2idx["<BOS>"]])
print("Generated:", " ".join(generated))