
Trainers: num_workers > 0 results in pickling error on macOS/Windows #886

Closed
mohscorpion opened this issue Nov 3, 2022 · 33 comments · Fixed by #992
Labels: trainers (PyTorch Lightning trainers)
Milestone: 0.4.0

Comments

@mohscorpion

mohscorpion commented Nov 3, 2022

I have a problem running torchgeo on EuroSAT MS TIFF images. I wrote some simple code:

eurosat = EuroSAT(euro_root, split="train", download=False)
# sampler = RandomGeoSampler(eurosat, size=64, length=10000)
# dataloader = DataLoader(eurosat, batch_size=128, sampler=sampler, collate_fn=stack_samples)
dataloader = DataLoader(eurosat, batch_size=128, collate_fn=stack_samples)
num_classes = 10
channels = 13
num_workers = 4
batch_size = 4
backbone = "resnet50"
weights = "imagenet"
lr = 0.01
lr_schedule_patience = 5
epochs = 50
datamodule = EuroSATDataModule(
     root_dir=euro_root,
     batch_size=batch_size,
     num_workers=num_workers,
)
task = ClassificationTask(
    classification_model=backbone,
    weights=weights,
    num_classes=num_classes,
    in_channels=channels,
    loss="ce",
    learning_rate=lr,
    learning_rate_schedule_patience=lr_schedule_patience
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    save_last=True,
)
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
)

# Train
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=epochs
)
trainer.fit(model=task, datamodule=datamodule)

but I get this error on the fit line:

.........
_pickle.PicklingError: Can't pickle <class 'nn.BatchNorm2d.BatchNorm2d'>: import of module 'nn.BatchNorm2d' failed
..........
OSError: [Errno 22] Invalid argument: 'C:\\Users\\mohsc\\PycharmProjects\\pythonProject\\<input>'

Am I missing a preparation step or something?
Help would be appreciated. By the way, the tutorials are still lacking essential sample code and documentation for cases like this one with MS images.
Thanks

@calebrob6
Member

Hi @mohscorpion, I think the formatting of your paste is messed up, do you mind trying again? Also, can you post the entire file you are using (including imports) and the entire stack trace of the error?

@isaaccorley
Collaborator

For now, the EuroSAT dataset doesn't require a geo sampler, so removing that should work.

@calebrob6
Member

If you follow this tutorial but use the EuroSAT dataset, then it should work: https://torchgeo.readthedocs.io/en/latest/tutorials/trainers.html

@mohscorpion
Author

For now, the EuroSAT dataset doesn't require a geo sampler, so removing that should work.

I commented that line out because of the error, but the problem remains.

@mohscorpion
Author

If you follow this tutorial but use the EuroSAT dataset, then it should work: https://torchgeo.readthedocs.io/en/latest/tutorials/trainers.html

My code is not essentially different; the only difference is the classification task instead of the regression task.

@calebrob6
Member

Can you post a properly formatted version of your code and error?

@calebrob6
Member

Here's a demo script https://gist.github.com/calebrob6/2e111a61fe8e6b531d9a0844a79e9d30 that uses torchgeo version 0.3.1

I used a conda environment I created with conda create -n torchgeo python pip torchgeo -c conda-forge

@adamjstewart
Collaborator

Fixed the formatting. Code blocks require triple backticks, see here.

@adamjstewart added the trainers (PyTorch Lightning trainers) label on Nov 3, 2022
@calebrob6
Member

I can't reproduce this; the above code (minus the dataset/dataloader, which aren't needed) works fine for me.

@adamjstewart
Collaborator

This is likely a multiprocessing issue. The reason @calebrob6 can't reproduce this is because multiprocessing uses a different start method on macOS/Windows vs. Linux: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods

If I'm correct, I should be able to reproduce this on macOS. Let me give it a shot.
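
As a quick check, the platform's default start method can be printed with the standard library (a minimal sketch):

import multiprocessing

# 'fork' on Linux, 'spawn' on Windows and (since Python 3.8) on macOS.
# Under 'spawn', everything handed to a worker process must be pickleable.
print(multiprocessing.get_start_method())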

@mohscorpion
Author

mohscorpion commented Nov 3, 2022

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import stack_samples, EuroSAT
from torchgeo.samplers import RandomGeoSampler
from torchgeo.trainers import ClassificationTask

euro_root = "./Eurosat/"
eurosat = EuroSAT(euro_root, split="train", download=False)
dataloader = DataLoader(eurosat, batch_size=128, collate_fn=stack_samples)
num_classes = 10
channels = 13
num_workers = 4
batch_size = 4
backbone = "resnet50"
weights = "imagenet"
lr = 0.01
lr_schedule_patience = 5
epochs = 50
datamodule = EuroSATDataModule(
     root_dir=euro_root,
     batch_size=batch_size,
     num_workers=num_workers,
)
task = ClassificationTask(
    classification_model=backbone,
    weights=weights,
    num_classes=num_classes,
    in_channels=channels,
    loss="ce",
    learning_rate=lr,
    learning_rate_schedule_patience=lr_schedule_patience
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    save_last=True,
)
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
)

trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=epochs
)
trainer.fit(model=task, datamodule=datamodule)
test_metrics = trainer.test(model=task, datamodule=datamodule)

@mohscorpion
Author

This is likely a multiprocessing issue. The reason @calebrob6 can't reproduce this is because multiprocessing uses a different start method on macOS/Windows vs. Linux: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods

If I'm correct, I should be able to reproduce this on macOS. Let me give it a shot.

I will try on Linux and report back.

@adamjstewart
Collaborator

Yep, I'm seeing the same issue on macOS. For now, a quick workaround is to set num_workers = 0. Let me see if I can figure out why this isn't working.
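
In code, the workaround is simply (a sketch reusing the variable names from the script above):

datamodule = EuroSATDataModule(
    root_dir=euro_root,
    batch_size=batch_size,
    num_workers=0,  # single-process data loading, so nothing needs to be pickled
)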

@adamjstewart
Collaborator

Full stack trace for anyone curious:

Traceback (most recent call last):
  File "/Users/ajstewart/torchgeo/test.py", line 52, in <module>
    trainer.fit(model=task, datamodule=datamodule)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _run_train
    self._run_sanity_check()
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1343, in _run_sanity_check
    val_loop.run()
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 155, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 195, in run
    self.on_run_start(*args, **kwargs)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 88, in on_run_start
    self._data_fetcher = iter(data_fetcher)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 178, in __iter__
    self.dataloader_iter = iter(self.dataloader)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 444, in __iter__
    return self._get_iterator()
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 390, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/ajstewart/.spack/.spack-env/view/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1077, in __init__
    w.start()
  File "/Users/ajstewart/.spack/.spack-env/._view/lf2rw2hmyoq3rkmbrmtnww2k2j6ynhmy/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/ajstewart/.spack/.spack-env/._view/lf2rw2hmyoq3rkmbrmtnww2k2j6ynhmy/lib/python3.9/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/ajstewart/.spack/.spack-env/._view/lf2rw2hmyoq3rkmbrmtnww2k2j6ynhmy/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/Users/ajstewart/.spack/.spack-env/._view/lf2rw2hmyoq3rkmbrmtnww2k2j6ynhmy/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/ajstewart/.spack/.spack-env/._view/lf2rw2hmyoq3rkmbrmtnww2k2j6ynhmy/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/ajstewart/.spack/.spack-env/._view/lf2rw2hmyoq3rkmbrmtnww2k2j6ynhmy/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/ajstewart/.spack/.spack-env/._view/lf2rw2hmyoq3rkmbrmtnww2k2j6ynhmy/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <class 'nn.BatchNorm2d.BatchNorm2d'>: import of module 'nn.BatchNorm2d' failed

@adamjstewart
Collaborator

The thing that's odd to me is that multiprocessing (and therefore pickling) only happens within the data loader, but there is no batch norm in the dataset/data module. It's almost like it's trying to pickle the ResNet inside the dataset for some reason...

@mohscorpion
Author

I can confirm now that it runs OK on Linux.

@FlorisCalkoen

A bit obvious, but this error now also happens in the trainers tutorial for folks running macOS (and probably also Windows).

This seems to work:

trainer = pl.Trainer(
    accelerator="mps",
    devices=1,
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=[csv_logger],
    default_root_dir=experiment_dir,
    min_epochs=1,
    max_epochs=10,
    fast_dev_run=in_tests
)

Although there are warnings:

/Users/calkoen/miniconda3/envs/torchgeo/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

@adamjstewart
Collaborator

TorchGeo has two types of tests:

  • unit tests: fast, run on every commit, every Python, every OS
  • integration tests: slow, run only when creating release, latest Python, Linux-only

The former run in serial, but the latter run in parallel and include our tutorials. Once we fix this issue, we should probably also start running our integration tests on macOS and Windows as well so we can prevent this issue from coming back.

@Seyed-Ali-Ahmadi

Seyed-Ali-Ahmadi commented Nov 12, 2022

I get the same error on Windows, but I'm using the xView2 dataset for semantic segmentation. It works on Colab, but not on Windows.

PicklingError                             Traceback (most recent call last)
Cell In [25], line 11
      1 trainer = pl.Trainer(
      2     callbacks=[checkpoint_callback, early_stopping_callback],
      3     logger=[tb_logger],
   (...)
      8     devices=[gpu_id]
      9 )
---> 11 _ = trainer.fit(model=task, datamodule=datamodule)

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:579, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    577     raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
    578 self.strategy._lightning_module = model
--> 579 call._call_and_handle_interrupt(
    580     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    581 )

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\trainer\call.py:38, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37     else:
---> 38         return trainer_fn(*args, **kwargs)
     40 except _TunerExitException:
     41     trainer._call_teardown_hook()

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:621, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    614 ckpt_path = ckpt_path or self.resume_from_checkpoint
    615 self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
    616     self.state.fn,
    617     ckpt_path,  # type: ignore[arg-type]
    618     model_provided=True,
    619     model_connected=self.lightning_module is not None,
    620 )
--> 621 self._run(model, ckpt_path=self.ckpt_path)
    623 assert self.state.stopped
    624 self.training = False

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1058, in Trainer._run(self, model, ckpt_path)
   1054 self._checkpoint_connector.restore_training_state()
   1056 self._checkpoint_connector.resume_end()
-> 1058 results = self._run_stage()
   1060 log.detail(f"{self.__class__.__name__}: trainer tearing down")
   1061 self._teardown()

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1137, in Trainer._run_stage(self)
   1135 if self.predicting:
   1136     return self._run_predict()
-> 1137 self._run_train()

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1150, in Trainer._run_train(self)
   1147 self._pre_training_routine()
   1149 with isolate_rng():
-> 1150     self._run_sanity_check()
   1152 # enable train mode
   1153 assert self.model is not None

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1222, in Trainer._run_sanity_check(self)
   1220 # run eval step
   1221 with torch.no_grad():
-> 1222     val_loop.run()
   1224 self._call_callback_hooks("on_sanity_check_end")
   1226 # reset logger connector

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\loops\loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py:152, in EvaluationLoop.advance(self, *args, **kwargs)
    150 if self.num_dataloaders > 1:
    151     kwargs["dataloader_idx"] = dataloader_idx
--> 152 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
    154 # store batch level output per dataloader
    155 self._outputs.append(dl_outputs)

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\loops\loop.py:194, in Loop.run(self, *args, **kwargs)
    190     return self.on_skip()
    192 self.reset()
--> 194 self.on_run_start(*args, **kwargs)
    196 while not self.done:
    197     try:

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\loops\epoch\evaluation_epoch_loop.py:84, in EvaluationEpochLoop.on_run_start(self, data_fetcher, dl_max_batches, kwargs)
     82 self._reload_dataloader_state_dict(data_fetcher)
     83 # creates the iterator inside the fetcher but returns `self`
---> 84 self._data_fetcher = iter(data_fetcher)
     85 # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
     86 data_fetcher.fetched += self.batch_progress.current.ready

File ~\Desktop\Place\TorchEnv\lib\site-packages\pytorch_lightning\utilities\fetching.py:178, in AbstractDataFetcher.__iter__(self)
    176 def __iter__(self) -> "AbstractDataFetcher":
    177     self.reset()
--> 178     self.dataloader_iter = iter(self.dataloader)
    179     self._apply_patch()
    180     self.prefetching()

File ~\Desktop\Place\TorchEnv\lib\site-packages\torch\utils\data\dataloader.py:444, in DataLoader.__iter__(self)
    442     return self._iterator
    443 else:
--> 444     return self._get_iterator()

File ~\Desktop\Place\TorchEnv\lib\site-packages\torch\utils\data\dataloader.py:390, in DataLoader._get_iterator(self)
    388 else:
    389     self.check_worker_number_rationality()
--> 390     return _MultiProcessingDataLoaderIter(self)

File ~\Desktop\Place\TorchEnv\lib\site-packages\torch\utils\data\dataloader.py:1077, in _MultiProcessingDataLoaderIter.__init__(self, loader)
   1070 w.daemon = True
   1071 # NB: Process.start() actually take some time as it needs to
   1072 #     start a process and pass the arguments over via a pipe.
   1073 #     Therefore, we only add a worker to self._workers list after
   1074 #     it started, so that we do not call .join() if program dies
   1075 #     before it starts, and __del__ tries to join but will get:
   1076 #     AssertionError: can only join a started process.
-> 1077 w.start()
   1078 self._index_queues.append(index_queue)
   1079 self._workers.append(w)

File ~\Desktop\Place\TorchEnv\lib\multiprocessing\process.py:121, in BaseProcess.start(self)
    118 assert not _current_process._config.get('daemon'), \
    119        'daemonic processes are not allowed to have children'
    120 _cleanup()
--> 121 self._popen = self._Popen(self)
    122 self._sentinel = self._popen.sentinel
    123 # Avoid a refcycle if the target function holds an indirect
    124 # reference to the process object (see bpo-30775)

File ~\Desktop\Place\TorchEnv\lib\multiprocessing\context.py:224, in Process._Popen(process_obj)
    222 @staticmethod
    223 def _Popen(process_obj):
--> 224     return _default_context.get_context().Process._Popen(process_obj)

File ~\Desktop\Place\TorchEnv\lib\multiprocessing\context.py:327, in SpawnProcess._Popen(process_obj)
    324 @staticmethod
    325 def _Popen(process_obj):
    326     from .popen_spawn_win32 import Popen
--> 327     return Popen(process_obj)

File ~\Desktop\Place\TorchEnv\lib\multiprocessing\popen_spawn_win32.py:93, in Popen.__init__(self, process_obj)
     91 try:
     92     reduction.dump(prep_data, to_child)
---> 93     reduction.dump(process_obj, to_child)
     94 finally:
     95     set_spawning_popen(None)

File ~\Desktop\Place\TorchEnv\lib\multiprocessing\reduction.py:60, in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)

PicklingError: Can't pickle <class 'nn.BatchNorm2d.BatchNorm2d'>: import of module 'nn.BatchNorm2d' failed.

@adamjstewart
Collaborator

adamjstewart commented Dec 25, 2022

Further minimized the bug reproducer:

from pytorch_lightning import Trainer
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.trainers import ClassificationTask


datamodule = EuroSATDataModule(
     root="tests/data/eurosat",
     num_workers=4,
)
model = ClassificationTask(
    model="resnet18",
    weights="random",
    num_classes=2,
    in_channels=13,
    loss="ce",
    learning_rate=0.01,
    learning_rate_schedule_patience=5,
)
trainer = Trainer(max_epochs=1)
trainer.fit(model=model, datamodule=datamodule)

Doesn't get much smaller than that. Interestingly, the following does not exhibit the same issue:

datamodule.setup()
trainer.fit(model=model, train_dataloaders=datamodule.train_dataloader())

I'm pretty confident this is a PyTorch Lightning bug. I'm trying to reproduce this outside of TorchGeo, but haven't gotten it working yet. Will let you know if I figure this out.

@adamjstewart
Collaborator

The following also raises the same error, which confirms that PyTorch Lightning is trying to pickle the model for some reason:

import pickle
from torchgeo.trainers import ClassificationTask


model = ClassificationTask(
    model="resnet18",
    weights="random",
    num_classes=10,
    in_channels=3,
    loss="ce",
    learning_rate=0.01,
    learning_rate_schedule_patience=5,
)

pickle.dumps(model)

@adamjstewart
Collaborator

Interestingly, the following does not raise an error:

import pickle
from torchvision.models import resnet18

model = resnet18()
pickle.dumps(model)

@adamjstewart
Collaborator

@calebrob6 you're going to love this. Remember when I complained the other day about how some of our transforms are done in on_after_batch_transfer and some are done in preprocess? Well, it turns out that this preprocess function is the source of the bug.

On macOS/Windows, the default multiprocessing start method requires all objects necessary to run the subprocess to be pickleable. The LightningModule isn't pickleable, but normally it isn't necessary for parallel data loading, so that's fine. However, our DataLoaders use transforms that include a reference to an instance method of the LightningDataModule. The LightningDataModule itself is pickleable, but during training it acquires a reference back to the LightningModule, making it no longer pickleable. Chaos ensues.

So the real "bug" is that LightningModules aren't pickleable even though the models they contain are pickleable. However, there's also an obvious workaround for this which is to not use instance methods during data loading. I'll open a bug report with the PyTorch Lightning folks, but I'll also open a PR here to remove all preprocess functions.
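
Here is a minimal sketch of that reference chain with stand-in classes (not TorchGeo's or Lightning's real ones), using the same ForkingPickler that multiprocessing uses under the spawn start method:

import io
from multiprocessing.reduction import ForkingPickler


class FakeModel:
    """Stands in for the LightningModule; pretend it cannot be pickled."""

    def __reduce__(self):
        raise TypeError("pretend this cannot be pickled")


class FakeDataModule:
    """Stands in for the LightningDataModule."""

    def __init__(self):
        self.model = None  # assigned once training starts in the real setup

    def preprocess(self, sample):
        return sample


dm = FakeDataModule()
transform = dm.preprocess                     # a bound method keeps a reference to dm
ForkingPickler(io.BytesIO()).dump(transform)  # fine: dm is still pickleable

dm.model = FakeModel()                        # the datamodule now points back at the model
try:
    ForkingPickler(io.BytesIO()).dump(transform)  # pickling dm drags in FakeModel
except TypeError as err:
    print(err)  # pretend this cannot be pickled

Replacing the bound-method transform with a module-level function (or a picklable callable object) breaks the chain, which is essentially what removing the preprocess functions accomplishes.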

Thanks @mohscorpion @FlorisCalkoen @Seyed-Ali-Ahmadi for reporting this bug, and sorry it took so long to track down!

@adamjstewart
Collaborator

Update: LightningModules are pickleable as long as you don't hack your import space. It turns out all of the:

BatchNorm2d.__module__ = "nn.BatchNorm2d"

stuff we have littered throughout TorchGeo is the reason that our trainers can't be pickled. This stuff was in there to fix the docs. I'm going to try to remove as much of it as I can and see if it's still needed with the latest version of Sphinx. Even if that fixes it, I might still remove the preprocess functions to speed up multiprocessing on macOS/Windows. No reason to pickle the entire LightningModule/LightningDataModule just to logically separate preprocessing and data augmentation.
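
A standalone illustration of why that assignment breaks pickling (a sketch using only torch and the standard library):

import pickle

import torch.nn as nn

pickle.dumps(nn.BatchNorm2d(3))  # fine: the class is found under torch.nn.modules.batchnorm

# Rewriting __module__ (as the docs hack did) breaks pickling by reference, because
# pickle now tries to import a module literally named "nn.BatchNorm2d":
nn.BatchNorm2d.__module__ = "nn.BatchNorm2d"
try:
    pickle.dumps(nn.BatchNorm2d(3))
except pickle.PicklingError as err:
    print(err)  # Can't pickle <class 'nn.BatchNorm2d.BatchNorm2d'>: import of module 'nn.BatchNorm2d' failed
finally:
    nn.BatchNorm2d.__module__ = "torch.nn.modules.batchnorm"  # undo the hack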

P.S. Apologies to the PyTorch Lightning devs for assuming that the bug was in their code and not ours!

@adamjstewart
Collaborator

#976 is sufficient to fix ClassificationTask pickling, but I run into a new pickling issue when training on EuroSAT:

Traceback (most recent call last):
  File "/Users/Adam/torchgeo/test.py", line 20, in <module>
    trainer.fit(model=model, datamodule=datamodule)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1190, in _run_train
    self._run_sanity_check()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1262, in _run_sanity_check
    val_loop.run()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 194, in run
    self.on_run_start(*args, **kwargs)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 84, in on_run_start
    self._data_fetcher = iter(data_fetcher)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 178, in __iter__
    self.dataloader_iter = iter(self.dataloader)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 435, in __iter__
    return self._get_iterator()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 381, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1034, in __init__
    w.start()
  File "/Users/Adam/.spack/.spack-env/._view/cbdygghyh2sfuqaua6oj37kwgduijeq7/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/Adam/.spack/.spack-env/._view/cbdygghyh2sfuqaua6oj37kwgduijeq7/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/Adam/.spack/.spack-env/._view/cbdygghyh2sfuqaua6oj37kwgduijeq7/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/Users/Adam/.spack/.spack-env/._view/cbdygghyh2sfuqaua6oj37kwgduijeq7/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/Adam/.spack/.spack-env/._view/cbdygghyh2sfuqaua6oj37kwgduijeq7/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/Adam/.spack/.spack-env/._view/cbdygghyh2sfuqaua6oj37kwgduijeq7/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/Adam/.spack/.spack-env/._view/cbdygghyh2sfuqaua6oj37kwgduijeq7/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'EvaluationLoop.advance.<locals>.batch_to_device'

Didn't bother digging into this too much, but replacing preprocess solves the issue.
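
For completeness, that second error is a generic CPython limitation: a function defined inside another function is a "local object" and cannot be pickled by reference (a minimal sketch, not Lightning's actual code):

import pickle


def advance():
    def batch_to_device(batch):  # only reachable through advance(), hence a "local object"
        return batch
    return batch_to_device


try:
    pickle.dumps(advance())
except (AttributeError, pickle.PicklingError) as err:
    print(err)  # Can't pickle local object 'advance.<locals>.batch_to_device'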

@adamjstewart added this to the 0.3.2 milestone on Dec 25, 2022
@adamjstewart changed the title from "eurosat multiband classification problem" to "Trainers: num_workers > 0 results in pickling error on macOS/Windows" on Dec 25, 2022
@calebrob6
Member

calebrob6 commented Jan 3, 2023

@adamjstewart what about EuroSAT is different from the other datamodules that work? I'm asking because EuroSAT seems like a very simple datamodule that shouldn't cause problems (which makes me suspicious that we don't fully understand what is going on).

@adamjstewart
Collaborator

None of our data modules work in parallel on macOS/Windows. #992 will fix that.

@calebrob6
Member

I know, but what is going on with "I run into a new pickling issue when training on EuroSAT" -- this smells suspicious.

@adamjstewart
Collaborator

Let me do some digging...

@adamjstewart
Collaborator

So the above error happens during the sanity check on the validation set. If you add num_sanity_val_steps=0 to the Trainer, you get a similar error during fit, but with a little more detail:

Traceback (most recent call last):
  File "/Users/Adam/torchgeo/test.py", line 23, in <module>
    trainer.fit(model=model, datamodule=datamodule)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1200, in _run_train
    self.fit_loop.run()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 194, in run
    self.on_run_start(*args, **kwargs)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 161, in on_run_start
    _ = iter(data_fetcher)  # creates the iterator inside the fetcher
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 179, in __iter__
    self._apply_patch()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 120, in _apply_patch
    apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 156, in loader_iters
    return self.dataloader_iter.loader_iters
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/supporters.py", line 555, in loader_iters
    self._loader_iters = self.create_loader_iters(self.loaders)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/pytorch_lightning/trainer/supporters.py", line 595, in create_loader_iters
    return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/lightning_utilities/core/apply_func.py", line 47, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 435, in __iter__
    return self._get_iterator()
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 381, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/Adam/.spack/.spack-env/view/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1034, in __init__
    w.start()
  File "/Users/Adam/.spack/.spack-env/._view/rkeajhsveq36c3uivlwt2atdsppn5r3v/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/Adam/.spack/.spack-env/._view/rkeajhsveq36c3uivlwt2atdsppn5r3v/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/Adam/.spack/.spack-env/._view/rkeajhsveq36c3uivlwt2atdsppn5r3v/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/Users/Adam/.spack/.spack-env/._view/rkeajhsveq36c3uivlwt2atdsppn5r3v/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/Adam/.spack/.spack-env/._view/rkeajhsveq36c3uivlwt2atdsppn5r3v/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/Adam/.spack/.spack-env/._view/rkeajhsveq36c3uivlwt2atdsppn5r3v/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/Adam/.spack/.spack-env/._view/rkeajhsveq36c3uivlwt2atdsppn5r3v/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'FitLoop.advance.<locals>.batch_to_device'. Did you mean: '_loader_iters'?

This didn't turn out to be helpful, but while digging around, I did notice a couple things that look sus:

Could it be possible that one of the components (Trainer, LightningModule, LightningDataModule) becomes unpickleable, but only during fit/validate? All 3 components are pickleable before and after fit/validate. The only reason that this bug surfaces at the moment is because the data loader points back to the LightningDataModule (because of preprocess), which points to the Trainer/LightningModule.

I'm not really sure how to dig further without help from the PL devs. We know one possible fix (#992). It's possible that Trainer/LightningModule is temporarily unpickleable, but I also don't think they ever intended for anyone to try to pickle it in the first place.

@isaaccorley
Collaborator

Is this even unique to torchgeo, or is it just a PL problem in general? I wonder if this happens when simply training a PL CIFAR10 example on Windows.

@adamjstewart
Collaborator

If you use data module instance methods in your parallel data loader (like we do in TorchGeo) then you should be able to reproduce it.
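
A TorchGeo-free sketch along those lines (illustrative names throughout; the spawn start method is forced so it also fails on Linux):

import torch
from torch.utils.data import DataLoader, Dataset


class Owner:
    """Stands in for a LightningDataModule."""

    def __init__(self):
        # Stands in for the Trainer/LightningModule reference gained during fit();
        # a lambda is a convenient stand-in for something unpicklable.
        self.unpicklable = lambda x: x

    def preprocess(self, sample):
        return sample


class Squares(Dataset):
    """A toy dataset whose transform is a bound method of Owner."""

    def __init__(self, transform):
        self.transform = transform

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return self.transform(torch.tensor(idx) ** 2)


if __name__ == "__main__":
    loader = DataLoader(
        Squares(Owner().preprocess),
        num_workers=2,
        multiprocessing_context="spawn",  # the default on macOS/Windows
    )
    # Spawned workers must pickle the dataset, which drags in Owner and its
    # unpicklable lambda, so this raises a pickling error.
    next(iter(loader))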

@adamjstewart modified the milestone: 0.3.2 → 0.4.0 on Jan 23, 2023
@adamjstewart
Collaborator

adamjstewart commented Mar 30, 2023
