Holes showing up while training #156

Open
bear96 opened this issue Apr 13, 2023 · 4 comments

bear96 commented Apr 13, 2023

Hi, so I've been trying to replicate your paper by creating a PyTorch model from scratch and training it on the original vangogh2photo dataset provided by Berkeley. Admittedly, it's for fun and not for any research, but I still hate it when it doesn't work out. So, this is the architecture of the model I've made:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random

'''This is a memory storage that stores 50 previously created images.
This is in accordance with the paper that introduced CycleGAN, Unpaired Image to Image translation.'''
class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                # With probability 0.5, return a previously stored image and
                # store the newly generated image in its place.
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    # Otherwise, return the newly generated image unchanged.
                    to_return.append(element)
        return Variable(torch.cat(to_return))
    

'''Linear learning rate scheduler.'''

class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        # Reject equality as well: step() divides by (n_epochs - decay_start_epoch).
        if (n_epochs - decay_start_epoch) <= 0:
            raise ValueError("Decay should start before training ends. Change decay_start_epoch to a value less than {}.".format(n_epochs))
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
    
'''Single Residual Block. InstanceNorm2d produces blob artefacts. Consider changing it to modulated convolutions later.
Currently using augmentation and a low number of epochs to stop Generator from producing artefacts.'''

class ResNetBlock(nn.Module):
    def __init__(self, channels):
        super(ResNetBlock, self).__init__()

        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.conv_block(x)

class GeneratorResNet(nn.Module):
    def __init__(self, input_channels, output_channels, num_resnet_blocks=9):
        super(GeneratorResNet, self).__init__()

        # Initial convolutional layer
        self.initial_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Downsampling layers
        self.downsampling_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.downsampling_2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Residual layers
        self.residual_layers = nn.Sequential(
            *[ResNetBlock(256) for _ in range(num_resnet_blocks)]
        )

        # Upsampling layers
        self.upsampling_1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.upsampling_2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Final convolutional layer
        self.final_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_channels, kernel_size=7, padding=0, bias=True),
            nn.Tanh()
        )

    def forward(self, x):
        # Apply initial convolutional layer
        x = self.initial_conv(x)

        # Apply downsampling layers
        x = self.downsampling_1(x)
        x = self.downsampling_2(x)

        # Apply residual layers
        x = self.residual_layers(x)

        # Apply upsampling layers
        x = self.upsampling_1(x)
        x = self.upsampling_2(x)

        # Apply final convolutional layer
        x = self.final_conv(x)

        return x

    
'''PatchGAN Discriminator'''

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_channels, out_channels, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # C64 -> C128 -> C256 -> C512
        self.model = nn.Sequential(
            *discriminator_block(channels, out_channels=64, normalize=False),
            *discriminator_block(64, out_channels=128),
            *discriminator_block(128, out_channels=256),
            *discriminator_block(256, out_channels=512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, padding=1)
        )

    def forward(self, img):
        return self.model(img)
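
The `output_shape` computed in the discriminator above can be sanity-checked with plain arithmetic. The sketch below applies the standard convolution output-size formula to the layer settings listed in `Discriminator`. The helper names `conv_out` and `patchgan_output_size` are made up for this illustration and are not part of the posted code.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def patchgan_output_size(size):
    """Trace one spatial dimension through the posted Discriminator layers."""
    # Four blocks of Conv2d(kernel_size=4, stride=2, padding=1)
    for _ in range(4):
        size = conv_out(size, kernel=4, stride=2, padding=1)
    # ZeroPad2d((1, 0, 1, 0)) adds one pixel on the left and one on the top
    size += 1
    # Final Conv2d(kernel_size=4, padding=1) with the default stride of 1
    return conv_out(size, kernel=4, stride=1, padding=1)


for img_size in (128, 256):
    print(img_size, patchgan_output_size(img_size), img_size // 2 ** 4)
```

For the posted `img_size` of 256, both expressions agree, so targets built from `output_shape` should match the shape of the discriminator output. A shape mismatch is therefore unlikely to be the cause of the artifacts.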

Now around the third epoch of training, I am getting these "holes" in the generated pictures. Could anyone tell me why these are showing up and how I can prevent it? These are my hyperparameters:

{'name': 'CycleGan_VanGogh_Checkpoint', 'n_epochs': 20, 'batch_size': 4, 'lr': 0.0002, 'decay_start_epoch': 19, 'b1': 0.5, 'b2': 0.999, 'img_size': 256, 'channels': 3, 'num_residual_blocks': 9, 'lambda_cyc': 10.0, 'lambda_id': 5.0}

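![outputs-2](https://user-images.githubusercontent.com/73417041/231626093-fc0ce59a-7e10-42b0-b38c-ab672f70b0bf.png)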
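
One thing worth checking against these hyperparameters is what the linear schedule actually does. The sketch below re-evaluates the formula from `LambdaLR.step` in plain Python. It assumes `offset=0`, epochs counted from 0, and one scheduler step per epoch; none of these are shown in the post. The function name `lr_multiplier` is made up for the illustration.

```python
def lr_multiplier(epoch, n_epochs, decay_start_epoch, offset=0):
    """Same expression as LambdaLR.step in the post (offset=0 is an assumption)."""
    return 1.0 - max(0, epoch + offset - decay_start_epoch) / (n_epochs - decay_start_epoch)


base_lr = 0.0002

# Posted settings: n_epochs=20, decay_start_epoch=19
posted = [lr_multiplier(e, n_epochs=20, decay_start_epoch=19) for e in range(20)]

# Hypothetical alternative for comparison: start decaying halfway through
earlier = [lr_multiplier(e, n_epochs=20, decay_start_epoch=10) for e in range(20)]

for epoch, (m_posted, m_earlier) in enumerate(zip(posted, earlier)):
    print(epoch, base_lr * m_posted, base_lr * m_earlier)
```

Under those assumptions the multiplier stays at 1.0 for every epoch from 0 to 19, so the learning rate never drops below 0.0002 during training. Starting the decay earlier, as in the second list, is one way to try a smaller learning rate late in training.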

@Joechann0831

Hi! I've been plagued by the same problem for a long time. Have you solved it?

bear96 commented Jun 13, 2023

Unfortunately not. I am still getting the same problem.

@Joechann0831

> Unfortunately not. I am still getting the same problem.

Sorry to hear that. After my comment on this issue, I found two issues closely related to our problem, and the author of CycleGAN has answered them. He thinks these artifacts are caused by mode collapse, and that more training data, larger weights on the identity/cycle-consistency losses, or a smaller learning rate can solve the problem. I'm trying these now; maybe you can try them, too. Good luck to us!

BTW, here are the issues I mentioned:

junyanz/pytorch-CycleGAN-and-pix2pix#725

junyanz/pytorch-CycleGAN-and-pix2pix#446
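
To make the "larger loss weights" suggestion concrete, here is a minimal sketch of how `lambda_cyc` and `lambda_id` usually enter the generator objective in CycleGAN-style training. The training loop was not posted, so the weighted sum below is an assumption about it, and the loss values are arbitrary placeholders rather than measurements from any run.

```python
def generator_loss(loss_gan, loss_cycle, loss_identity, lambda_cyc=10.0, lambda_id=5.0):
    """Weighted sum commonly used for the CycleGAN generator objective (assumed here)."""
    return loss_gan + lambda_cyc * loss_cycle + lambda_id * loss_identity


# Placeholder loss values, chosen only to illustrate the effect of the weights
loss_gan, loss_cycle, loss_identity = 0.5, 0.1, 0.1

posted_total = generator_loss(loss_gan, loss_cycle, loss_identity)
heavier_total = generator_loss(loss_gan, loss_cycle, loss_identity,
                               lambda_cyc=20.0, lambda_id=10.0)

# Fraction of the total objective contributed by the adversarial term
print(loss_gan / posted_total)
print(loss_gan / heavier_total)
```

Raising both weights shrinks the share of the adversarial term in the total, which pushes the generator toward reconstructing its input and gives it less room to produce degenerate outputs.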

bear96 commented Jun 13, 2023

Thank you! That helps a lot!
