1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
class UNet(nn.Module):
    """UNet used for denoising: predicts the noise eps_theta(x_t, t).

    Layout: stem conv -> encoder levels (residual blocks, stride-2
    downsample between levels) -> middle residual block -> decoder levels
    (residual blocks fed with skip connections, transposed-conv upsample
    between levels) -> output head.

    Skip connections: the forward pass pushes the stem output, every
    encoder residual-block output and every downsample output onto a
    stack. That is ``len(channel_multipliers) * (num_res_blocks + 1)``
    tensors in total. Each decoder residual block pops exactly one, which
    is why every decoder level has ``num_res_blocks + 1`` blocks.

    Args:
        channels: number of channels of the input ``x`` and of the
            predicted noise.
        base_channels: width of the stem; each level's width is
            ``base_channels * mult``. Also the embedding size handed to
            ``SinusoidalPositionEmbeddings``. Must be divisible by 8
            (``nn.GroupNorm(8, base_channels)`` in the output head).
        channel_multipliers: one width multiplier per resolution level.
            Every level except the last is followed by a 2x downsample.
            Any iterable of ints is accepted.
        num_res_blocks: residual blocks per encoder level.

    Note:
        The stride-2 conv (k=3, p=1) followed by the transposed conv
        (k=4, s=2, p=1) only restores the original size when it is even.
        Input height and width must therefore be divisible by
        ``2 ** (len(channel_multipliers) - 1)``. Otherwise ``torch.cat``
        with the skip tensor fails on a size mismatch.
    """

    def __init__(self, channels=3, base_channels=128,
                 channel_multipliers=(1, 2, 4, 8), num_res_blocks=2):
        super().__init__()
        # Materialize once: the multipliers are iterated twice below, and a
        # tuple default avoids the shared-mutable-default pitfall.
        channel_multipliers = tuple(channel_multipliers)

        self.channels = channels
        self.num_levels = len(channel_multipliers)
        self.num_res_blocks = num_res_blocks
        self.time_channels = base_channels * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(base_channels),
            nn.Linear(base_channels, self.time_channels),
            nn.GELU(),
            nn.Linear(self.time_channels, self.time_channels),
        )

        self.input_conv = nn.Conv2d(channels, base_channels, 3, padding=1)

        # Channel count of each tensor forward() pushes onto the skip stack,
        # in push order. The decoder pops from this list in the same order
        # as forward() pops real tensors, so every concat width is exact.
        skip_channels = [base_channels]

        self.encoder_blocks = nn.ModuleList()
        self.encoder_downs = nn.ModuleList()
        in_ch = base_channels
        for level, mult in enumerate(channel_multipliers):
            out_ch = base_channels * mult
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResidualBlock(in_ch, out_ch, self.time_channels)
                )
                in_ch = out_ch
                skip_channels.append(in_ch)
            if level < self.num_levels - 1:
                # Sized from the running width (not out_ch) so that
                # num_res_blocks == 0 still wires up correctly.
                self.encoder_downs.append(
                    nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
                )
                skip_channels.append(in_ch)

        self.mid_block = ResidualBlock(in_ch, in_ch, self.time_channels)

        self.decoder_blocks = nn.ModuleList()
        self.decoder_ups = nn.ModuleList()
        for level, mult in reversed(list(enumerate(channel_multipliers))):
            out_ch = base_channels * mult
            for _ in range(num_res_blocks + 1):
                skip_ch = skip_channels.pop()
                self.decoder_blocks.append(
                    ResidualBlock(in_ch + skip_ch, out_ch, self.time_channels)
                )
                in_ch = out_ch
            if level > 0:
                self.decoder_ups.append(
                    nn.ConvTranspose2d(in_ch, in_ch, 4, stride=2, padding=1)
                )

        # in_ch is now the decoder's final width (base_channels * mult[0]).
        self.output_conv = nn.Sequential(
            nn.Conv2d(in_ch, base_channels, 3, padding=1),
            nn.GroupNorm(8, base_channels),
            nn.SiLU(),
            nn.Conv2d(base_channels, channels, 3, padding=1),
        )

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep(s) ``t``.

        Args:
            x: noisy images, batched as (batch, channels, H, W), since skips
                are concatenated on dim 1.
            t: timesteps fed to ``SinusoidalPositionEmbeddings``. Presumably
                shape (batch,); TODO confirm against that module.

        Returns:
            Tensor with ``self.channels`` channels and the spatial size of ``x``.
        """
        t_emb = self.time_mlp(t)

        # ---- encoder: record every intermediate activation as a skip ----
        h = self.input_conv(x)
        hs = [h]
        enc_blocks = iter(self.encoder_blocks)
        for level in range(self.num_levels):
            for _ in range(self.num_res_blocks):
                h = next(enc_blocks)(h, t_emb)
                hs.append(h)
            if level < self.num_levels - 1:
                h = self.encoder_downs[level](h)
                hs.append(h)

        h = self.mid_block(h, t_emb)

        # ---- decoder: one skip per block, upsample between levels ----
        dec_blocks = iter(self.decoder_blocks)
        ups = iter(self.decoder_ups)
        for level in reversed(range(self.num_levels)):
            for _ in range(self.num_res_blocks + 1):
                h = next(dec_blocks)(torch.cat([h, hs.pop()], dim=1), t_emb)
            if level > 0:
                h = next(ups)(h)

        return self.output_conv(h)
|