1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
class CustomOpModel(nn.Module):
    """Small convolutional encoder/decoder over 3-channel images.

    The encoder applies a 3x3 conv (3 -> 64 channels) and then a stride-2 3x3 conv
    (64 -> 128 channels), each followed by BatchNorm2d and SiLU. The decoder
    upsamples with a stride-2 transposed conv (128 -> 64) plus BatchNorm2d/SiLU.
    It then projects back to 3 channels and squashes the result with a Sigmoid.

    For an input of spatial size H x W, the output is 2*ceil(H/2) x 2*ceil(W/2).
    That equals the input size when H and W are even.
    """

    @staticmethod
    def _with_bn_silu(conv, channels):
        # Pair a convolution with the BatchNorm2d + SiLU that follows it.
        return [conv, nn.BatchNorm2d(channels), nn.SiLU()]

    def __init__(self):
        super().__init__()
        # Layers are created in the same order and at the same Sequential
        # indices as before. This keeps state_dict keys (encoder.0 ... decoder.4)
        # and seeded initialization unchanged.
        self.encoder = nn.Sequential(
            *self._with_bn_silu(nn.Conv2d(3, 64, 3, padding=1), 64),
            *self._with_bn_silu(nn.Conv2d(64, 128, 3, stride=2, padding=1), 128),
        )
        self.decoder = nn.Sequential(
            *self._with_bn_silu(
                nn.ConvTranspose2d(
                    128, 64, 3, stride=2, padding=1, output_padding=1
                ),
                64,
            ),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)
def export_with_custom_config(
    model,
    output_path,
    *,
    input_shape=(1, 3, 256, 256),
    opset_version=13,
):
    """Export ``model`` to ONNX at ``output_path`` and validate the written file.

    The model is switched to eval mode and left in it. It is traced with a random
    dummy input and exported with dynamic batch/height/width axes. The file is
    then reloaded and checked with ``onnx.checker.check_model``; on success,
    "验证通过" ("validation passed") is printed.

    Args:
        model: The ``torch.nn.Module`` to export.
        output_path: Destination path for the ``.onnx`` file.
        input_shape: Shape of the random tracing input. Defaults to the
            previously hard-coded ``(1, 3, 256, 256)``.
        opset_version: ONNX opset to target (default 13, as before).

    Raises:
        onnx.checker.ValidationError: If the exported model fails the check.
    """
    model.eval()

    # Create the tracing input on the same device as the model's parameters.
    # A CPU tensor fed to a CUDA model makes tracing fail with a device
    # mismatch. A parameter-less model falls back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(input_shape, device=device)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        # NOTE(review): the output reuses the input's symbolic "height"/"width"
        # names, which asserts the output has the input's spatial size. That is
        # not true for every model; for example, CustomOpModel turns an odd H
        # into H+1. Confirm this before relying on these names downstream.
        dynamic_axes={
            'input': {0: 'batch', 2: 'height', 3: 'width'},
            'output': {0: 'batch', 2: 'height', 3: 'width'},
        },
        verbose=False,
    )

    import onnx

    # Reload from disk so the check covers exactly what was written.
    model_onnx = onnx.load(output_path)
    onnx.checker.check_model(model_onnx)
    # "验证通过" means "validation passed".
    print("验证通过")
|