GPT-2 Text Generation: Principles and Practice


GPT-2 Overview

GPT-2 is a large language model developed by OpenAI, well known for its strong text-generation ability. It is released in several sizes, from the 124M-parameter base model up to 1.5B parameters, all sharing the same decoder-only architecture.

Model Architecture

GPT-2 is a decoder-only Transformer: a stack of masked self-attention blocks with no encoder. Loading the pretrained model via Hugging Face transformers:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# load the pretrained byte-level BPE tokenizer and the LM-head model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
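The decoder-only layout can be checked directly on the loaded config; a minimal sketch (the printed values are those of the base 'gpt2' checkpoint):

config = model.config
print(config.n_layer)     # 12 Transformer decoder blocks
print(config.n_head)      # 12 attention heads per block
print(config.n_embd)      # 768-dimensional hidden states
print(config.vocab_size)  # 50257 byte-level BPE tokens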

Core Principles

Autoregressive Generation

GPT-2 generates text one token at a time: at each step the model predicts a distribution over the next token conditioned on everything generated so far, one token is chosen from that distribution, appended to the input, and the process repeats. The function below wraps this loop via model.generate; a hand-written version follows the code.

def generate_text(prompt, max_length=100):
    # encode the prompt into token ids
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    output = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=5,       # with do_sample=True this becomes beam-search multinomial sampling
        temperature=0.8,   # <1.0 sharpens the next-token distribution
        top_k=50,          # keep only the 50 most likely tokens
        top_p=0.95,        # nucleus sampling: smallest set with cumulative prob >= 0.95
        do_sample=True
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)
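Under the hood, generate runs exactly this kind of loop. A minimal hand-written greedy version, to make the autoregressive principle explicit (a sketch, not the library's internal implementation):

import torch

def greedy_generate(prompt, steps=20):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    for _ in range(steps):
        with torch.no_grad():
            logits = model(input_ids).logits          # shape: (1, seq_len, vocab_size)
        next_id = logits[0, -1].argmax()              # greedy: pick the most likely next token
        input_ids = torch.cat([input_ids, next_id.view(1, 1)], dim=1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)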

Sampling Strategies

How the next token is picked from the predicted distribution determines the character of the output; model.generate exposes the common strategies through its arguments:

# Greedy Search
output = model.generate(input_ids, max_length=50, do_sample=False)

# Beam Search
output = model.generate(input_ids, max_length=50, num_beams=5)

# Top-K Sampling
output = model.generate(input_ids, max_length=50, do_sample=True, top_k=50)

# Nucleus Sampling
output = model.generate(input_ids, max_length=50, do_sample=True, top_p=0.95)
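All of these strategies act on the same next-token distribution; temperature and top-k merely reshape or truncate it before sampling. A standalone sketch with made-up logits to show the effect:

import torch

logits = torch.tensor([4.0, 2.0, 1.0, 0.5])  # hypothetical next-token logits

# temperature: divide logits before softmax; <1 sharpens, >1 flattens the distribution
print(torch.softmax(logits / 0.7, dim=-1))
print(torch.softmax(logits / 1.5, dim=-1))

# top-k: keep only the k most likely tokens, renormalize, then sample among them
k = 2
topk = torch.topk(logits, k)
probs = torch.softmax(topk.values, dim=-1)
next_idx = topk.indices[torch.multinomial(probs, 1)]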

Chinese Text Generation

The stock GPT-2 tokenizer is trained on English text, so for Chinese it is better to use a checkpoint pretrained on a Chinese corpus, such as the UER models trained on CLUECorpusSmall:

from transformers import GPT2LMHeadModel, AutoTokenizer

# a GPT-2 checkpoint pretrained on the Chinese CLUECorpusSmall corpus;
# on the Hugging Face Hub these are published under the 'uer' organization
tokenizer = AutoTokenizer.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
model = GPT2LMHeadModel.from_pretrained('uer/gpt2-chinese-cluecorpussmall')

def generate_chinese(prompt, length=100):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    output = model.generate(input_ids, max_length=length, do_sample=True, top_p=0.9)
    return tokenizer.decode(output[0], skip_special_tokens=True)
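A quick usage check (the prompt is arbitrary):

print(generate_chinese('今天天气很好,'))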

Training Your Own GPT-2

Fine-tuning GPT-2 on a plain-text corpus with the Trainer API:

from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

# note: TextDataset is deprecated in recent transformers releases;
# the datasets library is the recommended replacement
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='train.txt',
    block_size=128
)

# mlm=False selects causal language modeling, the GPT-2 objective
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir='./gpt2',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=500,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train()
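After training finishes, the fine-tuned weights can be saved and reloaded like any other checkpoint; a sketch using the standard save/load calls (the output directory name is arbitrary):

trainer.save_model('./gpt2-finetuned')
tokenizer.save_pretrained('./gpt2-finetuned')

# reload later for inference
model = GPT2LMHeadModel.from_pretrained('./gpt2-finetuned')
tokenizer = AutoTokenizer.from_pretrained('./gpt2-finetuned')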

Summary

The whole pipeline forms a loop: each chosen token is appended to the input and fed back for the next forward pass.

graph LR
A[Input text] --> B[Tokenization]
B --> C[GPT-2 forward pass]
C --> D[Next-token probabilities]
D --> E[Sampling / greedy decoding]
E --> F[Generated text]
F --> C