1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
| import torch import torch.nn as nn import torch.nn.functional as F
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv -> BatchNorm -> ReLU) stages.

    Spatial size is preserved (kernel 3, padding 1). The first conv maps
    ``in_channels -> out_channels``; the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Build the six sub-modules in a loop; the resulting nn.Sequential
        # has the same structure (and state_dict keys conv.0 .. conv.5)
        # as writing the layers out literally.
        layers = []
        channels = in_channels
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages; returns (N, out_channels, H, W)."""
        return self.conv(x)
class UNet(nn.Module):
    """U-Net segmentation network (Ronneberger et al., 2015).

    Encoder: DoubleConv + 2x2 max-pool at each scale in ``features``.
    Decoder: transposed-conv upsampling, concatenation with the matching
    encoder skip connection, then DoubleConv. A final 1x1 conv maps to
    ``out_channels`` class logits.

    Args:
        in_channels: channels of the input image (default 1, e.g. grayscale).
        out_channels: number of output classes (default 2).
        features: channel widths of the encoder stages, shallow to deep.
    """

    # NOTE: default is a tuple, not a list — avoids the mutable-default-argument
    # pitfall; callers passing a list still work since it is only iterated.
    def __init__(self, in_channels=1, out_channels=2, features=(64, 128, 256, 512)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder: one DoubleConv per scale; each stage's output feeds the next.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Deepest representation, twice the widest encoder width.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder: pairs of (upsample, DoubleConv). The DoubleConv input is
        # feature * 2 because the upsampled map is concatenated with the skip.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, 2, 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, 1)

    def forward(self, x):
        """Return per-pixel logits of shape (N, out_channels, H, W)."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # self.ups holds alternating (upsample, DoubleConv) pairs.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip = skip_connections[idx // 2]
            # Bug fix: when the input H/W is not divisible by 2**len(features),
            # max-pool floors the size and the upsampled map ends up smaller
            # than its skip connection, making torch.cat fail. Resize to the
            # skip's spatial size so arbitrary input sizes work.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode="bilinear",
                                  align_corners=False)
            x = torch.cat([skip, x], dim=1)
            x = self.ups[idx + 1](x)

        return self.final_conv(x)
|