
An Example of Training a Generative Model with a Transformer

Date: 2025-06-10

I've been looking at ChatGPT recently and wondering how models like it are trained. Without getting into the internals of neural-network algorithms, we can play with a tiny Transformer to get a feel for it: the example below builds a small vocabulary, trains on a single sentence with next-token labels, and then generates text one token at a time.

```python
import torch
import torch.nn as nn

# Build the vocabulary
vocab = ["<PAD>", "<BOS>", "<EOS>", "明天", "天气", "很", "好"]
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}

# Example training data: the labels are the inputs shifted one position
# Input:  <BOS> 明天 天气 很
# Output: 明天  天气  很  好
inputs = torch.tensor([
    [word2idx["<BOS>"], word2idx["明天"], word2idx["天气"], word2idx["很"]]
])
labels = torch.tensor([
    [word2idx["明天"], word2idx["天气"], word2idx["很"], word2idx["好"]]
])

class TinyTransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=32, nhead=2, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional encoding, up to 100 tokens
        self.pos_encoding = nn.Parameter(torch.randn(1, 100, d_model))
        # batch_first=True so tensors are [batch, seq, d_model]
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1)]
        # Causal mask so each position attends only to earlier positions,
        # as required for autoregressive generation
        mask = nn.Transformer.generate_square_subsequent_mask(x.size(1))
        x = self.transformer(x, mask=mask)
        return self.fc(x)

model = TinyTransformerModel(vocab_size=len(vocab))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(50):
    model.train()
    out = model(inputs)  # [batch_size, seq_len, vocab_size]
    loss = loss_fn(out.view(-1, len(vocab)), labels.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, loss: {loss.item():.4f}")

def generate(model, start_tokens, max_len=5):
    model.eval()
    input_ids = torch.tensor([start_tokens])
    for _ in range(max_len):
        with torch.no_grad():
            logits = model(input_ids)
        # Greedily pick the most likely next token
        next_token = logits[0, -1].argmax().item()
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
        if next_token == word2idx["<EOS>"]:
            break
    return [idx2word[i] for i in input_ids[0].tolist()]

# Try it out
generated = generate(model, [word2idx["<BOS>"]])
print("Generated:", " ".join(generated))
```
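The `generate` function above always picks the single most likely token (greedy decoding), so the output is deterministic. Real chat models typically sample from the predicted distribution instead. As a minimal sketch, here is a standalone sampling helper with temperature and top-k filtering; the function name `sample_next_token` and the example logits are my own illustration, not part of the original article:

```python
import torch

def sample_next_token(logits, temperature=1.0, top_k=None):
    # Lower temperature sharpens the distribution; higher flattens it
    logits = logits / temperature
    if top_k is not None:
        # Keep only the top_k highest logits; mask out everything else
        topk_vals, _ = torch.topk(logits, top_k)
        logits = logits.clone()
        logits[logits < topk_vals[-1]] = float("-inf")
    probs = torch.softmax(logits, dim=-1)
    # Draw one token index from the resulting distribution
    return torch.multinomial(probs, num_samples=1).item()

# Example: logits over a 7-word vocabulary
logits = torch.tensor([0.1, 0.2, 3.0, 0.0, -1.0, 0.5, 0.3])
token = sample_next_token(logits, temperature=0.8, top_k=3)
```

With `top_k=3` here, the sampled index can only be 2, 5, or 6 (the three highest logits); swapping this in for `argmax` in `generate` would make the output vary from run to run.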
Original article: https://www.oschina.net/news/354538