In [ ]:
import torch.nn as nn
class Reshape(nn.Module):
    """Layer that reshapes its input to a fixed target shape.

    Lets a reshape step be placed inside an ``nn.Sequential`` pipeline.

    Args:
        shape: Iterable of ints giving the target shape. It is unpacked and
            passed to ``Tensor.reshape``, so a single ``-1`` entry may be
            used to have that dimension inferred.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` reshaped to ``self.shape``."""
        # ``reshape`` rather than ``view``: it yields the same zero-copy view
        # whenever the memory layout allows it, but falls back to a copy for
        # non-contiguous inputs (e.g. after a transpose) instead of raising.
        return x.reshape(*self.shape)

    def extra_repr(self):
        """Show the target shape when the module is printed."""
        return f"shape={self.shape}"
In [ ]:
class Generator(nn.Module):
    """Generator network: latent vectors in, 3-channel images in [-1, 1] out.

    Pipeline held in ``self.gen_model``:

    1. Linear projection of the ``Z``-dimensional input to ``1024 * 8 * 8``
       features, followed by batch norm and LeakyReLU(0.2).
    2. Reshape to ``(-1, 1024, 8, 8)``.
    3. Four transposed-convolution stages (see ``_make_deconv_blocks``)
       that reduce channels 1024 -> 64 while enlarging the feature map
       8 -> 17 -> 33 -> 65 -> 130.
    4. A 5x5 convolution (stride 1, padding 1) down to 3 channels, which
       trims the map to 128x128, then ``Tanh``.

    Args:
        Z: Size of the latent vector fed to the first linear layer.
    """

    def __init__(self, Z):
        super().__init__()
        self.Z = Z

        stem_features = 1024 * 8 * 8

        # Dense stem: project the latent vector and fold it into an 8x8 map.
        stem = [
            nn.Linear(Z, stem_features),
            nn.BatchNorm1d(stem_features),
            nn.LeakyReLU(0.2),
            Reshape((-1, 1024, 8, 8)),
        ]
        upsampling = self._make_deconv_blocks()
        # Output head: 64 feature channels -> 3 channels squashed by tanh.
        head = [
            nn.Conv2d(64, 3, 5, 1, 1),
            nn.Tanh(),
        ]

        self.gen_model = nn.Sequential(*stem, *upsampling, *head)

    def _make_deconv_blocks(self):
        """Build the four upsampling stages as a flat list of layers.

        Each stage is ``ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2)``.

        Returns:
            list: 12 modules (3 per stage), in forward order.
        """
        widths = (1024, 512, 256, 128, 64)
        # One (kernel, stride, padding, output_padding) tuple per stage.
        geometry = (
            (5, 2, 1, 0),
            (5, 2, 2, 0),
            (5, 2, 2, 0),
            (5, 2, 2, 1),
        )

        layers = []
        stage_widths = zip(widths, widths[1:])
        for (c_in, c_out), (kernel, stride, pad, out_pad) in zip(stage_widths, geometry):
            layers.extend(
                (
                    nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, out_pad),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(0.2),
                )
            )
        return layers

    def forward(self, noise):
        """Run ``noise`` through the generator pipeline and return the result."""
        return self.gen_model(noise)
In [ ]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one sigmoid score per image.

    Five ``Conv2d -> BatchNorm2d -> LeakyReLU(0.2)`` stages widen the
    channels 3 -> 64 -> 128 -> 256 -> 512 -> 1024. The result is flattened
    and passed through ``Linear(1024 * 8 * 8, 1)`` and ``Sigmoid``.

    The linear layer fixes the final feature map at 8x8; a 128x128 input
    satisfies this (spatial size 128 -> 63 -> 31 -> 15 -> 15 -> 8).
    """

    def __init__(self):
        super().__init__()

        widths = (3, 64, 128, 256, 512, 1024)
        # One (kernel, stride, padding) tuple per convolution stage.
        geometry = (
            (5, 2, 1),
            (5, 2, 1),
            (5, 2, 1),
            (5, 1, 2),
            (5, 2, 2),
        )

        feature_layers = []
        for c_in, c_out, (kernel, stride, pad) in zip(widths, widths[1:], geometry):
            feature_layers.append(nn.Conv2d(c_in, c_out, kernel, stride, pad))
            feature_layers.append(nn.BatchNorm2d(c_out))
            feature_layers.append(nn.LeakyReLU(0.2))

        # Classifier head: flatten the 1024x8x8 map and score it in (0, 1).
        classifier = [
            nn.Flatten(),
            nn.Linear(1024 * 8 * 8, 1),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*feature_layers, *classifier)

    def forward(self, x):
        """Return the discriminator score for each item in batch ``x``."""
        return self.model(x)