```python
import torch
import torch.nn as nn


class PostTrainingQuantizer:
    """Post-training quantizer."""

    def __init__(self, model, dataloader, num_bits=8):
        self.model = model
        self.dataloader = dataloader
        self.num_bits = num_bits

    def calibrate(self):
        """Calibration: collect activation statistics over calibration batches."""
        self.model.eval()
        activation_ranges = {}
        hooks = []

        def hook_fn(name):
            def hook(module, input, output):
                if isinstance(output, tuple):
                    output = output[0]
                out_min = output.detach().min()
                out_max = output.detach().max()
                if name in activation_ranges:
                    # Keep a running min/max across batches instead of
                    # overwriting with the last batch's statistics
                    activation_ranges[name]['min'] = torch.min(
                        activation_ranges[name]['min'], out_min)
                    activation_ranges[name]['max'] = torch.max(
                        activation_ranges[name]['max'], out_max)
                else:
                    activation_ranges[name] = {'min': out_min, 'max': out_max}
            return hook

        # Hook every conv and linear layer to observe its outputs
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                hooks.append(module.register_forward_hook(hook_fn(name)))

        # A limited number of batches is enough for calibration
        with torch.no_grad():
            for i, (inputs, _) in enumerate(self.dataloader):
                if i >= 100:
                    break
                self.model(inputs)

        for hook in hooks:
            hook.remove()
        return activation_ranges

    def quantize_model(self):
        """Quantize model weights; activation ranges are kept for later use."""
        self.activation_ranges = self.calibrate()
        quantized_state_dict = {}
        for name, param in self.model.state_dict().items():
            quantizer = SymmetricQuantizer(self.num_bits)
            quantized_param = quantizer.quantize(param)
            quantized_state_dict[name] = {
                'data': quantized_param,
                'scale': quantizer.scale,
            }
        return quantized_state_dict
```
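`quantize_model` relies on a `SymmetricQuantizer` that is not shown in this section. A minimal sketch of what such a class might look like: the class name and its `quantize`/`scale` interface come from the usage above, while the internals (symmetric per-tensor quantization with int8 storage) are an assumption.

```python
class SymmetricQuantizer:
    """Minimal sketch: symmetric per-tensor quantization (assumed implementation)."""

    def __init__(self, num_bits=8):
        self.num_bits = num_bits
        self.scale = None

    def quantize(self, tensor):
        # Symmetric scheme: map [-max|x|, max|x|] onto [-qmax, qmax]
        qmax = 2 ** (self.num_bits - 1) - 1           # e.g. 127 for 8 bits
        max_val = tensor.abs().max().clamp(min=1e-8)  # avoid division by zero
        self.scale = max_val / qmax
        # int8 storage assumes num_bits <= 8
        return torch.round(tensor / self.scale).clamp(-qmax, qmax).to(torch.int8)

    def dequantize(self, q_tensor):
        # Recover a floating-point approximation of the original tensor
        return q_tensor.float() * self.scale
```

Because the scheme is symmetric, no zero point is stored: dequantization is just a multiply by `scale`, which keeps the quantized state dict small and the arithmetic simple.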
```python
def static_quantize_model(model, dataloader):
    """PyTorch eager-mode static quantization."""
    model.eval()
    # The model must implement layer fusion itself (e.g. Conv + BN + ReLU)
    model.fuse_model()

    # Attach the qconfig to the model; 'fbgemm' targets x86 CPUs
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)

    # Calibrate on representative data so the observers can record ranges
    with torch.no_grad():
        for inputs, _ in dataloader:
            model(inputs)

    quantized_model = torch.quantization.convert(model, inplace=False)
    return quantized_model
```
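A usage sketch: the call below assumes one of torchvision's quantizable models, which already define `fuse_model()`; `calibration_loader` is a hypothetical `DataLoader` of representative `(inputs, labels)` pairs.

```python
# Usage sketch (assumes torchvision's quantizable ResNet, which defines fuse_model()).
# `calibration_loader` is a hypothetical DataLoader of representative inputs.
from torchvision.models.quantization import resnet18

model = resnet18(quantize=False)
quantized = static_quantize_model(model, calibration_loader)
```

Unlike dynamic quantization, static quantization fixes the quantization parameters for both weights and activations ahead of time, which is why a calibration pass over representative data is required before `convert`.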