Introduction

In 2018, Google released BERT (Bidirectional Encoder Representations from Transformers), setting new records on a range of NLP benchmarks. This article takes a close look at BERT's core principles, its pre-training method, and its fine-tuning strategy.
BERT's Core Innovation

BERT's core innovation is bidirectional context modeling. Traditional language models (such as GPT) model text in a single direction, either left-to-right or right-to-left, whereas BERT uses the left and right context simultaneously.

BERT is built on the encoder part of the Transformer and uses bidirectional self-attention:
```python
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

# Load the pre-trained Chinese BERT tokenizer and encoder
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese')

# Encode a sentence and run it through the model
text = "今天天气真好,适合出去游玩"
encoded = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    outputs = model(**encoded)

# Per-token representations and the pooled [CLS] representation
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output

print(f"Last hidden state shape: {last_hidden_state.shape}")
print(f"Pooled output shape: {pooled_output.shape}")
```
Pre-training Tasks

BERT is trained with two pre-training tasks:

1. Masked Language Model (MLM)

15% of the input tokens are masked at random, and the model must predict the masked words:
```
Input:  今天[MASK]气很好
Output: 今天天气很好
```
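In the original BERT recipe, the positions chosen for prediction are not all replaced by [MASK]: roughly 80% become [MASK], 10% are swapped for a random token, and 10% are left unchanged. Below is a minimal sketch of that selection logic (it ignores special tokens and padding, which a real data pipeline must skip):

```python
import torch
from transformers import BertTokenizer

def mask_tokens(input_ids, tokenizer, mlm_probability=0.15):
    """BERT-style masking: of the selected positions, 80% -> [MASK],
    10% -> random token, 10% -> kept unchanged."""
    input_ids = input_ids.clone()
    labels = input_ids.clone()

    # Choose which positions participate in the MLM loss
    masked_indices = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked_indices] = -100  # CrossEntropyLoss ignores label -100

    # 80% of the chosen positions -> [MASK]
    replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    input_ids[replaced] = tokenizer.mask_token_id

    # 10% of the chosen positions -> a random token (half of the remaining 20%)
    random_pos = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~replaced
    input_ids[random_pos] = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)[random_pos]

    # The final 10% keep their original token
    return input_ids, labels

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
ids = tokenizer("今天天气很好", return_tensors='pt')['input_ids']
masked_ids, labels = mask_tokens(ids, tokenizer)
print(tokenizer.decode(masked_ids[0]))
```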
2. Next Sentence Prediction (NSP)

The model judges whether sentence B is the actual next sentence of sentence A:
```
Sentence A: 小明去了学校
Sentence B: 他在图书馆学习
Label: IsNext

Sentence A: 小明去了学校
Sentence B: 今天天气晴朗
Label: NotNext
```
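For NSP, the two sentences are packed into one sequence as [CLS] A [SEP] B [SEP], with token_type_ids marking which segment each token belongs to. A quick way to see this with the tokenizer:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Passing two texts produces [CLS] A [SEP] B [SEP]
encoded = tokenizer("小明去了学校", "他在图书馆学习", return_tensors='pt')

# token_type_ids are 0 for sentence A (plus [CLS] and the first [SEP]) and 1 for sentence B
print(tokenizer.decode(encoded['input_ids'][0]))
print(encoded['token_type_ids'])
```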
BERT Pre-training Code
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertForPreTraining, BertConfig

class BERTPretrainingDataset(torch.utils.data.Dataset):
    """Tokenize raw texts into fixed-length inputs for pre-training."""

    def __init__(self, texts, tokenizer, max_length=512):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'token_type_ids': encoding['token_type_ids'].squeeze()
        }

def mlm_loss(outputs, masked_lm_labels, vocab_size):
    """Compute the masked language model (MLM) loss."""
    prediction_scores = outputs.prediction_logits
    loss_fct = nn.CrossEntropyLoss()
    masked_lm_loss = loss_fct(
        prediction_scores.view(-1, vocab_size),
        masked_lm_labels.view(-1)
    )
    return masked_lm_loss

def nsp_loss(outputs, next_sentence_labels):
    """Compute the next sentence prediction (NSP) loss."""
    seq_relationship_score = outputs.seq_relationship_logits
    loss_fct = nn.CrossEntropyLoss()
    next_sentence_loss = loss_fct(
        seq_relationship_score.view(-1, 2),
        next_sentence_labels.view(-1)
    )
    return next_sentence_loss
```
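The code above defines the components; the sketch below shows one way they might be wired into a training step. The label construction here is a deliberate placeholder (real pre-training derives MLM labels from the masking procedure and NSP labels from sentence-pair sampling), and `BertForPreTraining` can alternatively compute both losses itself when given `labels` and `next_sentence_label`:

```python
import torch
from torch.utils.data import DataLoader
from transformers import BertConfig, BertForPreTraining, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
config = BertConfig(vocab_size=tokenizer.vocab_size)
model = BertForPreTraining(config)          # randomly initialized, bert-base sized
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

texts = ["今天天气真好,适合出去游玩", "小明去了学校,他在图书馆学习"]  # toy corpus
dataset = BERTPretrainingDataset(texts, tokenizer, max_length=32)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

model.train()
for batch in loader:
    outputs = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        token_type_ids=batch['token_type_ids']
    )
    # Placeholder labels: unmasked inputs as MLM targets, "IsNext" (0) for every example.
    # In real pre-training these come from masking and sentence-pair sampling.
    masked_lm_labels = batch['input_ids']
    next_sentence_labels = torch.zeros(batch['input_ids'].size(0), dtype=torch.long)

    loss = mlm_loss(outputs, masked_lm_labels, config.vocab_size) \
         + nsp_loss(outputs, next_sentence_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"pre-training step loss: {loss.item():.4f}")
```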
BERT Variants and Improvements

Chinese BERT

For Chinese, BERT tokenizes at the character level rather than the word level:
```python
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained(
    'bert-base-chinese',
    num_labels=2
)

text = "人工智能将改变未来"
tokens = tokenizer.tokenize(text)
print(tokens)
```
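With bert-base-chinese, each Chinese character becomes its own token, so the printed list should look like `['人', '工', '智', '能', '将', '改', '变', '未', '来']`.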
RoBERTa
RoBERTa is an optimized version of BERT with the following changes:
- Larger batch sizes and more training data
- A dynamic masking strategy (see the sketch after this list)
- Removal of the NSP task
- Longer training
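Static masking, as in the original BERT data pipeline, fixes the masked positions when the training data is generated, so every epoch sees the same mask for a given example; dynamic masking re-samples the mask each time an example is drawn. In the transformers library this behaviour can be approximated with DataCollatorForLanguageModeling, which masks at batch-construction time (a conceptual sketch using a Chinese BERT tokenizer, not RoBERTa's actual pipeline):

```python
from transformers import BertTokenizer, DataCollatorForLanguageModeling

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Masking happens when each batch is built, so masked positions change across epochs
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

examples = [tokenizer("今天天气真好,适合出去游玩")]
batch = collator(examples)
print(batch['input_ids'])   # some positions replaced by the [MASK] id
print(batch['labels'])      # original ids at masked positions, -100 elsewhere
```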
ALBERT
ALBERT reduces the number of model parameters through cross-layer parameter sharing and factorized embedding parameterization:
```python
from transformers import AlbertModel, AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = AlbertModel.from_pretrained('albert-base-v2')
```
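The effect of the parameter reduction is easy to verify by counting parameters; bert-base has roughly 110M while albert-base-v2 has roughly 12M (a quick check, exact counts depend on the checkpoint):

```python
from transformers import AlbertModel, BertModel

bert = BertModel.from_pretrained('bert-base-uncased')
albert = AlbertModel.from_pretrained('albert-base-v2')

def count_params(model):
    # Total number of parameters in the model
    return sum(p.numel() for p in model.parameters())

print(f"BERT-base:   {count_params(bert) / 1e6:.1f}M parameters")
print(f"ALBERT-base: {count_params(albert) / 1e6:.1f}M parameters")
```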
Fine-tuning for Downstream Tasks

BERT's pretrain-then-fine-tune paradigm is the key to its success:
```python
import torch
import torch.nn as nn

class BertFineTuner(nn.Module):
    """BERT encoder with a linear classification head on the pooled [CLS] output."""

    def __init__(self, bert_model, num_classes):
        super(BertFineTuner, self).__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

def fine_tune_model(train_loader, model, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            logits = model(
                batch['input_ids'],
                batch['attention_mask'],
                batch['token_type_ids']
            )
            loss = nn.CrossEntropyLoss()(logits, batch['labels'])
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader)}")
```
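One way to use these pieces end to end, with a tiny toy sentiment dataset included only to make the example self-contained; AdamW with a small learning rate such as 2e-5 is the common choice for BERT fine-tuning:

```python
import torch
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
bert = BertModel.from_pretrained('bert-base-chinese')
model = BertFineTuner(bert, num_classes=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Toy labeled data (1 = positive, 0 = negative), for illustration only
texts = ["今天天气真好", "这部电影太无聊了"]
labels = [1, 0]
enc = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
examples = [
    {
        'input_ids': enc['input_ids'][i],
        'attention_mask': enc['attention_mask'][i],
        'token_type_ids': enc['token_type_ids'][i],
        'labels': torch.tensor(labels[i])
    }
    for i in range(len(texts))
]
train_loader = DataLoader(examples, batch_size=2, shuffle=True)

fine_tune_model(train_loader, model, optimizer, num_epochs=3)
```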
Practical Applications

BERT and its variants perform strongly on the following tasks:
- Text classification: sentiment analysis, news categorization
- Named entity recognition: extracting entities such as person and place names
- Question answering: extracting answer spans from documents (see the sketch after this list)
- Text generation: used in combination with other models
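As an illustration of the extractive question-answering case above: a BERT QA model predicts start and end logits over the context tokens, and the answer is the decoded span between the argmax positions. The checkpoint below is the plain bert-base-chinese encoder, so its QA head is untrained and the answer will be meaningless until the model has been fine-tuned on a QA dataset:

```python
import torch
from transformers import BertForQuestionAnswering, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForQuestionAnswering.from_pretrained('bert-base-chinese')  # QA head untrained here

question = "BERT是哪一年发布的?"
context = "2018年,Google发布了BERT,在多项NLP基准测试中创下了新纪录。"
inputs = tokenizer(question, context, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)

# The most likely start/end positions define the answer span
start = torch.argmax(outputs.start_logits)
end = torch.argmax(outputs.end_logits)
answer = tokenizer.decode(inputs['input_ids'][0][start:end + 1])
print(answer)
```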
Summary

BERT established the pretrain-then-fine-tune paradigm: large-scale unsupervised pre-training followed by task-specific fine-tuning, which substantially improved performance across NLP tasks. Its bidirectional context modeling and Transformer architecture have become foundations of modern NLP.
References