GPT-2 Overview
GPT-2 is a large language model developed by OpenAI, known for its strong text-generation capabilities.
Model Architecture
GPT-2 is built on the Transformer decoder (a decoder-only architecture). Loading the pretrained model and tokenizer:
```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
```
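The decoder-only design means each position may attend only to itself and earlier positions. A minimal single-head sketch of this causal (masked) self-attention in plain Python, with toy vectors — the real model uses learned projections, multiple heads, and many stacked layers:

```python
import math

def causal_attention(queries, keys, values):
    """Single-head attention where position i only sees positions <= i."""
    outputs = []
    for i, q in enumerate(queries):
        # Scaled dot-product scores against the current and earlier positions
        # only — the causal mask simply excludes j > i.
        scores = [sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(len(q))
                  for j in range(i + 1)]
        # Softmax over the visible positions.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the visible value vectors.
        outputs.append([sum(w * values[j][d] for j, w in enumerate(weights))
                        for d in range(len(values[0]))])
    return outputs
```

Position 0 can only attend to itself, so its output is exactly its own value vector; later positions mix in everything before them.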
Core Principles
Autoregressive Generation
GPT-2 generates text autoregressively: each new token is predicted from all the tokens before it, then appended to the input for the next step.
```python
def generate_text(prompt, max_length=100):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=5,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        do_sample=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
```
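What `generate` does internally is a loop: feed the prefix, read off the next-token distribution, pick a token, append it, and repeat. A toy sketch of that loop, with a hand-written bigram table standing in for the model (the names here are illustrative, not part of the transformers API):

```python
# Toy "model": next-token probabilities conditioned on the last token only.
# A real GPT-2 conditions on the entire prefix.
BIGRAMS = {
    '<s>': {'the': 0.9, 'a': 0.1},
    'the': {'cat': 0.6, 'dog': 0.4},
    'cat': {'sat': 1.0},
    'dog': {'ran': 1.0},
    'sat': {'</s>': 1.0},
    'ran': {'</s>': 1.0},
}

def greedy_generate(start='<s>', max_length=10):
    tokens = [start]
    while len(tokens) < max_length and tokens[-1] in BIGRAMS:
        dist = BIGRAMS[tokens[-1]]
        # Greedy decoding: always take the most probable next token.
        tokens.append(max(dist, key=dist.get))
        if tokens[-1] == '</s>':
            break
    return tokens
```

Swapping the `max(...)` for a random draw weighted by `dist` turns this into sampling — which is exactly the `do_sample` switch below.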
Sampling Strategies
```python
# Greedy decoding: always pick the most probable token
output = model.generate(input_ids, max_length=50, do_sample=False)

# Beam search: keep the 5 most probable partial sequences
output = model.generate(input_ids, max_length=50, num_beams=5)

# Top-k sampling: sample from the 50 most probable tokens
output = model.generate(input_ids, max_length=50, do_sample=True, top_k=50)

# Top-p (nucleus) sampling: sample from the smallest set covering 95% probability
output = model.generate(input_ids, max_length=50, do_sample=True, top_p=0.95)
```
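The difference between `top_k` and `top_p` is just how the candidate set is truncated before sampling. A sketch of both filters over a plain probability dictionary (toy numbers, not model output):

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize."""
    kept = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}

def top_p_filter(probs, p):
    """Keep the smallest set of tokens whose cumulative probability >= p."""
    kept, cum = {}, 0.0
    for tok, prob in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = prob
        cum += prob
        if cum >= p:
            break
    total = sum(kept.values())
    return {tok: q / total for tok, q in kept.items()}
```

Top-k keeps a fixed number of candidates regardless of how the probability mass is spread; top-p adapts, keeping few candidates when the model is confident and many when it is not.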
Chinese Text Generation
```python
from transformers import GPT2LMHeadModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2-medium-chinese-cluecorpussmall')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium-chinese-cluecorpussmall')

def generate_chinese(prompt, length=100):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    output = model.generate(input_ids, max_length=length, do_sample=True, top_p=0.9)
    return tokenizer.decode(output[0], skip_special_tokens=True)
```
Training Your Own GPT-2
```python
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='train.txt',
    block_size=128,
)

# Causal LM objective: mlm=False means plain next-token prediction.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir='./gpt2',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=500,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train()
```
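The objective behind `DataCollatorForLanguageModeling(mlm=False)` is plain next-token prediction: labels are the input tokens shifted one position left, and the loss is the average cross-entropy. A sketch on toy per-position probabilities (illustrative numbers, not real model output):

```python
import math

def causal_lm_loss(token_ids, probs_per_position):
    """Average negative log-likelihood of each token given its prefix.

    probs_per_position[i] is the model's distribution over the vocabulary
    after seeing tokens 0..i, so it is scored against token i+1
    (labels = inputs shifted one position to the left).
    """
    losses = []
    for i in range(len(token_ids) - 1):
        target = token_ids[i + 1]
        losses.append(-math.log(probs_per_position[i][target]))
    return sum(losses) / len(losses)
```

If the model assigns probability 1.0 to the first target and 0.5 to the second, the loss is `(0 + ln 2) / 2`.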
Summary
```mermaid
graph LR
    A[Input text] --> B[Tokenization]
    B --> C[GPT-2 forward pass]
    C --> D[Next-token probabilities]
    D --> E[Sampling / greedy pick]
    E --> F[Generated text]
    F --> C
```