```python
import torch
import torch.nn as nn
from torch.optim import AdamW  # transformers' AdamW is deprecated; use the PyTorch optimizer
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizer


class PretrainFineTunePipeline:
    """Complete pretrain-finetune pipeline."""

    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # BertForMaskedLM wraps the BERT encoder with an LM head, so the
        # forward pass yields the vocabulary logits the MLM loss needs.
        self.model = BertForMaskedLM.from_pretrained(model_name).to(self.device)

    def pretrain(self, corpus, epochs=4, batch_size=32, max_length=512):
        """Stage 1: self-supervised pretraining (masked language modeling)."""
        print("=== Stage 1: Pretraining ===")
        dataset = PretrainDataset(corpus, self.tokenizer, max_length)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = AdamW(self.model.parameters(), lr=2e-5)
        self.model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            for batch in dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                # Corrupt the inputs and keep the originals as targets;
                # predicting uncorrupted tokens from themselves would be trivial.
                masked_ids, labels = self._mask_tokens(input_ids)
                outputs = self.model(masked_ids, attention_mask=attention_mask)
                loss = self._mlm_loss(outputs.logits, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f"Epoch {epoch+1}: Loss = {total_loss/len(dataloader):.4f}")

    def finetune(self, train_data, val_data, task='classification', epochs=3):
        """Stage 2: fine-tuning on a downstream task."""
        print(f"=== Stage 2: {task} fine-tuning ===")
        hidden_size = self.model.config.hidden_size
        if task == 'classification':
            self.task_head = ClassificationHead(hidden_size, num_labels=2)
        elif task == 'ner':
            self.task_head = NERHead(hidden_size, num_labels=10)
        elif task == 'qa':
            self.task_head = QAHead(hidden_size)
        else:
            raise ValueError(f"Unsupported task: {task}")
        self.task_head.to(self.device)
        self._freeze_layers(freeze_ratio=0.6)
        # Only pass still-trainable parameters to the optimizer.
        params = [p for p in list(self.model.parameters()) + list(self.task_head.parameters())
                  if p.requires_grad]
        optimizer = AdamW(params, lr=2e-5)
        for epoch in range(epochs):
            self.model.train()
            train_loss = self._train_epoch(train_data, optimizer)  # task-specific, defined elsewhere
            val_metrics = self._evaluate(val_data)                 # task-specific, defined elsewhere
            print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, "
                  f"Val Acc={val_metrics['accuracy']:.4f}")

    def _mask_tokens(self, input_ids, mlm_prob=0.15):
        """Randomly replace 15% of non-special tokens with [MASK]; labels are
        -100 everywhere else so the loss ignores unmasked positions.
        (BERT's full 80/10/10 mask/random/keep scheme is simplified here.)"""
        labels = input_ids.clone()
        prob = torch.full(labels.shape, mlm_prob, device=labels.device)
        special = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
             for ids in labels.tolist()],
            dtype=torch.bool, device=labels.device)
        prob.masked_fill_(special, 0.0)
        masked = torch.bernoulli(prob).bool()
        labels[~masked] = -100
        masked_ids = input_ids.clone()
        masked_ids[masked] = self.tokenizer.mask_token_id
        return masked_ids, labels

    def _mlm_loss(self, logits, labels):
        """Masked language model loss: cross-entropy over masked positions only."""
        loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        return loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

    def _freeze_layers(self, freeze_ratio=0.6):
        """Freeze the bottom encoder layers: lower layers capture generic
        features and need the least task-specific adaptation."""
        total_layers = self.model.config.num_hidden_layers
        freeze_layers = int(total_layers * freeze_ratio)
        for i, layer in enumerate(self.model.bert.encoder.layer):
            if i < freeze_layers:
                for param in layer.parameters():
                    param.requires_grad = False
```
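The listing references a `PretrainDataset` that it never defines. Below is a minimal sketch of what such a dataset might look like, assuming `corpus` is an iterable of raw text strings and each example is padded to a fixed length; the constructor signature matches the call above, but the internals are an assumption:

```python
class PretrainDataset(Dataset):
    """Hypothetical minimal dataset: tokenizes raw text to fixed-length tensors."""

    def __init__(self, corpus, tokenizer, max_length=512):
        # Assumption: `corpus` is an iterable of raw text strings.
        self.encodings = tokenizer(
            list(corpus),
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt',
        )

    def __len__(self):
        return self.encodings['input_ids'].size(0)

    def __getitem__(self, idx):
        # Returns the keys the training loop expects:
        # 'input_ids' and 'attention_mask'.
        return {key: tensor[idx] for key, tensor in self.encodings.items()}
```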
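The three task heads (`ClassificationHead`, `NERHead`, `QAHead`) are likewise referenced but not shown. One plausible set of implementations, assuming the classification head reads the [CLS] vector, the NER head scores every token, and the QA head emits start/end logits; these internals are assumptions, not code from the original:

```python
class ClassificationHead(nn.Module):
    """Sentence classification: dropout + linear over the [CLS] vector."""
    def __init__(self, hidden_size, num_labels=2):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden); position 0 is [CLS].
        return self.classifier(self.dropout(hidden_states[:, 0]))


class NERHead(nn.Module):
    """Token classification: a label score for every position."""
    def __init__(self, hidden_size, num_labels=10):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states):
        return self.classifier(hidden_states)  # (batch, seq_len, num_labels)


class QAHead(nn.Module):
    """Extractive QA: start/end logits over every token."""
    def __init__(self, hidden_size):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, hidden_states):
        start_logits, end_logits = self.qa_outputs(hidden_states).split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)
```

With these pieces in place, the pipeline would be driven by constructing `PretrainFineTunePipeline('bert-base-chinese')`, calling `pretrain` on the raw corpus, and then `finetune` on labeled task data.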