Model collapse from the get go on grumpycat dataset #178

Open
Matans844 opened this issue Jan 21, 2024 · 6 comments

@Matans844

Matans844 commented Jan 21, 2024

Steps to reproduce the issue:

git clone https://github.com/taesungp/contrastive-unpaired-translation CUT
cd CUT
conda env create -f environment.yml
conda activate contrastive-unpaired-translation
bash ./datasets/download_cut_dataset.sh grumpifycat
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT_default --CUT_mode CUT 
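
A quick sanity check that the download produced the layout the unaligned loader reads. This is a minimal sketch that makes two assumptions: the images live in trainA/trainB folders, and the reported dataset length is max(size_A, size_B), as in the pytorch-CycleGAN-and-pix2pix code that CUT builds on. The helper names and the extension list are hypothetical. If the result matches the "The number of training images = 214" line in the logs below, the data layout is probably not what causes the NaNs.

```python
import os

# Hypothetical extension list; adjust if your copy of the repo accepts others.
IMG_EXTENSIONS = {".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".tif", ".tiff"}


def count_images(folder):
    """Count files under `folder` (recursively) whose extension looks like an image."""
    total = 0
    for _root, _dirs, files in os.walk(folder):  # yields nothing if folder is missing
        total += sum(
            1 for name in files if os.path.splitext(name)[1].lower() in IMG_EXTENSIONS
        )
    return total


def unaligned_dataset_size(dataroot, phase="train"):
    """Return (size_A, size_B, reported_length) for <dataroot>/<phase>A and <phase>B.

    Assumption: the unaligned dataset reports max(size_A, size_B) as its length.
    """
    size_a = count_images(os.path.join(dataroot, phase + "A"))
    size_b = count_images(os.path.join(dataroot, phase + "B"))
    return size_a, size_b, max(size_a, size_b)


if __name__ == "__main__":
    print(unaligned_dataset_size("./datasets/grumpifycat"))
```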

Here is the output (I replaced absolute paths to the clone location with <MYPATH>).

I first got:

----------------- Options ---------------                                                                                                                     
                 CUT_mode: CUT                                                                                                                                
               batch_size: 1                                                                                                                                  
                    beta1: 0.5                                                                                                                                
                    beta2: 0.999                                                                                                                              
          checkpoints_dir: ./checkpoints                                                                                                                      
           continue_train: False                                                                                                                              
                crop_size: 256                                                                                                                                
                 dataroot: ./datasets/grumpifycat               [default: placeholder]                                                                        
             dataset_mode: unaligned                                                                                                                          
                direction: AtoB                                                                                                                               
              display_env: main                                                                                                                               
             display_freq: 400                                                                                                                                
               display_id: None                                                                                                                               
            display_ncols: 4                                                                                                                                  
             display_port: 8097                                                                                                                               
           display_server: http://localhost                                                                                                                   
          display_winsize: 256                                                                                                                                
               easy_label: experiment_name                                                                                                                    
                    epoch: latest                                                                                                                             
              epoch_count: 1                                                                                                                                  
          evaluation_freq: 5000                                                                                                                               
        flip_equivariance: False                                                                                                                              
                 gan_mode: lsgan                                                                                                                              
                  gpu_ids: 0                                                                                                                                  
                init_gain: 0.02                                                                                                                               
                init_type: xavier                                                                                                                             
                 input_nc: 3                                                                                                                                  
                  isTrain: True                                 [default: None]
               lambda_GAN: 1.0                                                 
               lambda_NCE: 1.0                                                 
                load_size: 286                                                 
                       lr: 0.0002                                              
           lr_decay_iters: 50                                                  
                lr_policy: linear                                              
         max_dataset_size: inf                                                 
                    model: cut                                                 
                 n_epochs: 200                                                 
           n_epochs_decay: 200                                                 
               n_layers_D: 3                                                   
                     name: grumpycat_CUT_default                [default: experiment_name]                                                                    
                    nce_T: 0.07                                                
                  nce_idt: True                                                
nce_includes_all_negatives_from_minibatch: False                               
               nce_layers: 0,4,8,12,16                                         
                      ndf: 64                                                  
                     netD: basic                                               
                     netF: mlp_sample                                          
                  netF_nc: 256                                                 
                     netG: resnet_9blocks                                      
                      ngf: 64                                                  
             no_antialias: False                                               
          no_antialias_up: False                                               
               no_dropout: True                                                
                  no_flip: False                                               
                  no_html: False                                               
                    normD: instance                                            
                    normG: instance                                            
              num_patches: 256                                                 
              num_threads: 4                                                   
                output_nc: 3                                                   
                    phase: train
                pool_size: 0                                                   
               preprocess: resize_and_crop                                     
          pretrained_name: None                                                
               print_freq: 100                                                 
         random_scale_max: 3.0                                                 
             save_by_iter: False                                               
          save_epoch_freq: 5                                                   
         save_latest_freq: 5000                                                
           serial_batches: False                                               
stylegan2_G_num_downsampling: 1                                                
                   suffix:                                                     
         update_html_freq: 1000                                                
                  verbose: False                                               
----------------- End -------------------
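
With n_epochs: 200 and n_epochs_decay: 200, the run is 400 epochs long, which matches "End of epoch 1 / 400" in the log below. Under lr_policy: linear, the rate is expected to stay at lr: 0.0002 for the first 200 epochs and then fall linearly toward zero, so the constant "learning rate = 0.0002000" lines are normal and unrelated to the NaNs. The sketch below models that rule. The formula is an assumption based on the linear lambda rule in pytorch-CycleGAN-and-pix2pix, not code copied from this repo.

```python
def linear_lr(epoch, base_lr=0.0002, epoch_count=1, n_epochs=200, n_epochs_decay=200):
    """Learning rate at 0-based scheduler step `epoch` under an assumed 'linear' policy.

    Keeps base_lr for the first n_epochs, then decays linearly over n_epochs_decay.
    Defaults mirror the options dump above (lr, epoch_count, n_epochs, n_epochs_decay).
    """
    factor = 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)
    return base_lr * factor


if __name__ == "__main__":
    for e in (0, 199, 200, 300, 399):
        print(e, f"{linear_lr(e):.7f}")
```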

And then I got:

dataset [UnalignedDataset] was created                                         
model [CUTModel] was created           
The number of training images = 214    
Socket had error 'TypeError' object has no attribute 'errno', attempting restart                                                                              


Could not connect to Visdom server.    
 Trying to start a server....          
Command: <MYPATH>/.conda/envs/contrastive-unpaired-translation/bin/python -m visdom.server -p 8097 &>/dev/null &                                           
create web directory ./checkpoints/grumpycat_CUT_default/web...                
<MYPATH>/CUT/models/networks.py:569: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).                             
  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)      
---------- Networks initialized -------------                                  
[Network G] Total number of parameters : 11.378 M                              
[Network F] Total number of parameters : 0.560 M                               
[Network D] Total number of parameters : 2.765 M                               
-----------------------------------------------                                
(epoch: 1, iters: 100, time: 2.393, data: 0.275) G_GAN: nan D_real: nan D_fake: nan G: nan NCE: nan NCE_Y: nan                                                
(epoch: 1, iters: 200, time: 1.537, data: 0.002) G_GAN: nan D_real: nan D_fake: nan G: nan NCE: nan NCE_Y: nan                                                
End of epoch 1 / 400     Time Taken: 786 sec                                   
learning rate = 0.0002000              
(epoch: 2, iters: 86, time: 1.019, data: 0.001) G_GAN: nan D_real: nan D_fake: nan G: nan NCE: nan NCE_Y: nan                                                 
(epoch: 2, iters: 186, time: 0.705, data: 0.001) G_GAN: nan D_real: nan D_fake: nan G: nan NCE: nan NCE_Y: nan                                                
End of epoch 2 / 400     Time Taken: 48 sec                                    
learning rate = 0.0002000
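
The Visdom message is probably unrelated to the NaNs. The second run further down prints "Setting up a new session...", which suggests the display connected, and it still produces nan losses. To check whether anything is listening on display_port 8097, here is a minimal standard-library sketch (is_port_open is a hypothetical helper, not part of CUT):

```python
import socket


def is_port_open(host="localhost", port=8097, timeout=1.0):
    """Return True if a TCP connection to (host, port) succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host unreachable.
        return False


if __name__ == "__main__":
    # 8097 is the display_port from the options dump above.
    print("visdom reachable:", is_port_open("localhost", 8097))
```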

So from the very first epoch, every loss term is nan.
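
To catch the problem at the first bad iteration, rather than after a full epoch (786 seconds here), the losses can be checked every step. This is a minimal sketch. nonfinite_losses is a hypothetical helper, and feeding it the dict returned by model.get_current_losses() assumes the CycleGAN-style BaseModel API. PyTorch's torch.autograd.set_detect_anomaly(True) is another way to locate the operation that first produces a NaN, at a noticeable speed cost.

```python
import math


def nonfinite_losses(losses):
    """Return the names of losses whose value is NaN or +/-inf.

    `losses` is any mapping of loss name -> number (or numeric string).
    """
    return [name for name, value in losses.items() if not math.isfinite(float(value))]


if __name__ == "__main__":
    # Shaped like the log lines above; the string "nan" parses to float NaN.
    logged = {"G_GAN": "nan", "D_real": "nan", "D_fake": "nan",
              "G": "nan", "NCE": "nan", "NCE_Y": "nan"}
    bad = nonfinite_losses(logged)
    if bad:
        print("non-finite losses:", bad)
```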

In addition, conda env create -f environment.yml and pip install -r requirements.txt do not install the same packages. Specifically, visdom and networkx are not installed with the environment.yml setup, but they are installed with the pip setup.
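
One quick way to see which modules each setup actually provides is the standard library's importlib.util.find_spec. The helper missing_modules below is hypothetical:

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # visdom and networkx are the two packages reported missing after
    # `conda env create -f environment.yml`.
    print(missing_modules(["torch", "torchvision", "visdom", "networkx"]))
```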

Edit: Also after using pip install -r requirements.txt and then running:

python train.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT

the model collapses:

dataset [UnalignedDataset] was created
model [CUTModel] was created
The number of training images = 214
Setting up a new session...
create web directory ./checkpoints/grumpycat_CUT/web...
/seedoodata/datasets/data/asio/drone_image_analysis/2_04_models/CUT/models/networks.py:569: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)
---------- Networks initialized -------------
[Network G] Total number of parameters : 11.378 M
[Network F] Total number of parameters : 0.560 M
[Network D] Total number of parameters : 2.765 M
-----------------------------------------------
(epoch: 1, iters: 100, time: 3.653, data: 0.128) G_GAN: nan D_real: nan D_fake: nan G: nan NCE: nan NCE_Y: nan 
(epoch: 1, iters: 200, time: 2.259, data: 0.002) G_GAN: nan D_real: nan D_fake: nan G: nan NCE: nan NCE_Y: nan 
End of epoch 1 / 400     Time Taken: 1190 sec
learning rate = 0.0002000

@TattooPro

Have you solved your problem yet?

@Air1000thsummer

Have you solved your problem yet?

@TattooPro

> Have you solved your problem yet?

My code works fine now, odds are it's the version of the package. Just change it all to zhuyanjun's version.

@Air1000thsummer

My sincere thanks!

@Luo-YaFei

> Have you solved your problem yet?
>
> My code works fine now, odds are it's the version of the package. Just change it all to zhuyanjun's version.

What do you mean by zhuyanjun's version? Can you please tell me which versions you used?

@Luo-YaFei

Problem resolved. Don't use the PyTorch and torchvision versions listed in the requirements file; I was able to run it with newer versions: torch==1.12 and torchvision==0.13.1.
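
Here is a small check of the installed versions against the ones reported to work above. Treating those versions as a minimum is an assumption, because the thread only confirms torch 1.12 and torchvision 0.13.1. The parser also compares only the numeric release components. packaging.version.Version is the more robust choice if the packaging library is available.

```python
import re
from importlib import metadata


def version_tuple(version):
    """Turn '1.12.1+cu113' or '0.13.1' into (1, 12, 1); local and pre-release tags are ignored."""
    release = version.split("+", 1)[0]
    parts = []
    for piece in release.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def at_least(installed, minimum):
    """True if `installed` is the same as or newer than `minimum` (release numbers only)."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


if __name__ == "__main__":
    # Versions reported to work in this thread.
    for dist, minimum in (("torch", "1.12"), ("torchvision", "0.13.1")):
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            print(f"{dist}: not installed")
            continue
        print(f"{dist} {installed}: >= {minimum}? {at_least(installed, minimum)}")
```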
