1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
class VAE(nn.Module):
    """Variational autoencoder mapping images to a latent space and back.

    The encoder applies three ``Downsample`` stages (interleaved with
    ``ResBlock``s) and a final 1x1 convolution that predicts, per spatial
    position, a ``mean`` and a ``logvar`` map of ``z_channels`` channels each.
    The decoder mirrors this with three ``Upsample`` stages and ends in a
    ``Tanh``.

    Args:
        z_channels: Number of latent channels.
        in_channels: Number of image channels, for both the input and the
            reconstruction.
    """

    # Bounds applied to the predicted log-variance before it is exponentiated,
    # so that exp(0.5 * logvar) can neither overflow to inf nor collapse to 0.
    LOGVAR_MIN = -30.0
    LOGVAR_MAX = 20.0

    def __init__(self, z_channels=4, in_channels=3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, 1, 1),
            ResBlock(128, 128),
            Downsample(128),
            ResBlock(128, 256),
            Downsample(256),
            ResBlock(256, 512),
            Downsample(512),
            ResBlock(512, 512),
            # Emits mean and log-variance stacked along the channel axis;
            # they are split apart again in `encode`.
            nn.Conv2d(512, z_channels * 2, 1),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(z_channels, 512, 1),
            ResBlock(512, 512),
            Upsample(512),
            ResBlock(512, 512),
            Upsample(512),
            ResBlock(512, 256),
            Upsample(256),
            ResBlock(256, 128),
            nn.Conv2d(128, in_channels, 3, 1, 1),
            # Squashes reconstructions into (-1, 1). NOTE(review): presumably
            # training images are normalized to [-1, 1] — confirm with the
            # data pipeline.
            nn.Tanh(),
        )

    def encode(self, x, *, sample=True, return_stats=False):
        """Encode images into latents.

        Args:
            x: Image batch. Presumably shaped (N, in_channels, H, W), with H
                and W compatible with the three downsampling stages — TODO
                confirm against ``Downsample``.
            sample: If True (the default, matching the original behavior),
                draw ``z`` with the reparameterization trick
                ``mean + eps * std``. If False, return the posterior mean
                deterministically.
            return_stats: If True, also return ``mean`` and the clamped
                ``logvar`` (e.g. for computing a KL penalty).

        Returns:
            ``z``, or the tuple ``(z, mean, logvar)`` when ``return_stats``
            is True. All three tensors share the same shape.
        """
        h = self.encoder(x)
        # Channel axis 1 holds [mean | logvar], in that order.
        mean, logvar = h.chunk(2, dim=1)
        # Unclamped, a large logvar makes exp(0.5 * logvar) overflow to inf
        # and turns z (and any loss built from it) into inf/nan.
        logvar = logvar.clamp(self.LOGVAR_MIN, self.LOGVAR_MAX)
        if sample:
            z = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)
        else:
            z = mean
        if return_stats:
            return z, mean, logvar
        return z

    def decode(self, z):
        """Map latents ``z`` back to image space through the decoder stack."""
        return self.decoder(z)
|