Skip to content

Commit

Permalink
Remove __module__ hacks (#976)
Browse files Browse the repository at this point in the history
* Remove __module__ hacks

* Remove unused imports

* Fix typo in reference

* Explicit link
  • Loading branch information
adamjstewart committed Dec 26, 2022
1 parent b7add14 commit fd06985
Show file tree
Hide file tree
Showing 39 changed files with 3 additions and 174 deletions.
5 changes: 0 additions & 5 deletions tests/models/test_changestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,9 @@
import pytest
import torch
import torch.nn as nn
from torch.nn.modules import Module

from torchgeo.models import ChangeMixin, ChangeStar, ChangeStarFarSeg

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "torch.nn"

BACKBONE = ["resnet18", "resnet34", "resnet50", "resnet101"]
IN_CHANNELS = [64, 128]
INNNR_CHANNELS = [16, 32, 64]
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,3 @@
"Vaihingen2DDataModule",
"XView2DataModule",
)

# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.datamodules"
4 changes: 0 additions & 4 deletions torchgeo/datamodules/bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@

from ..datasets import BigEarthNet

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class BigEarthNetDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the BigEarthNet dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
from ..samplers.batch import RandomBatchGeoSampler
from ..samplers.single import GridGeoSampler

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class ChesapeakeCVPRDataModule(LightningDataModule):
"""LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/cowc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@

from ..datasets import COWCCounting

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class COWCCountingDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the COWC Counting dataset."""
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@

from ..datasets import TropicalCyclone

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class TropicalCycloneDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NASA Cyclone dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from ..datasets import FAIR1M
from .utils import dataset_split

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
"""Custom object detection collate fn to handle variable number of boxes.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@

from ..datasets import LandCoverAI

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class LandCoverAIDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the LandCover.ai dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@

from ..datasets import LoveDA

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class LoveDADataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the LoveDA dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/naip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from ..samplers.batch import RandomBatchGeoSampler
from ..samplers.single import GridGeoSampler

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class NAIPChesapeakeDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the NAIP and Chesapeake datasets.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from ..datasets import NASAMarineDebris
from .utils import dataset_split

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
"""Custom object detection collate fn to handle variable boxes.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/resisc45.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@

from ..datasets import RESISC45

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class RESISC45DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the RESISC45 dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/sen12ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@

from ..datasets import SEN12MS

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class SEN12MSDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the SEN12MS dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/so2sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@

from ..datasets import So2Sat

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class So2SatDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the So2Sat dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/spacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
from ..datasets import SpaceNet1
from .utils import dataset_split

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class SpaceNet1DataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the SpaceNet1 dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datamodules/ucmerced.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@

from ..datasets import UCMerced

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class UCMercedDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the UC Merced dataset.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,3 @@
"stack_samples",
"unbind_samples",
)

# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.datasets"
9 changes: 2 additions & 7 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@

from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Dataset.__module__ = "torch.utils.data"
ImageFolder.__module__ = "torchvision.datasets"


class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
"""Abstract base class for datasets containing geospatial information.
Expand Down Expand Up @@ -275,8 +270,8 @@ class RasterDataset(GeoDataset):
#:
#: * ``date``: used to calculate ``mint`` and ``maxt`` for ``index`` insertion
#:
#: When :attr:`separate_files`` is True, the following additional groups are
#: searched for to find other files:
#: When :attr:`~RasterDataset.separate_files` is True, the following additional
#: groups are searched for to find other files:
#:
#: * ``band``: replaced with requested band name
filename_regex = ".*"
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,3 @@
from .qr import QRLoss, RQLoss

__all__ = ("QRLoss", "RQLoss")

# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.losses"
4 changes: 0 additions & 4 deletions torchgeo/losses/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
import torch.nn.functional as F
from torch.nn.modules import Module

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "torch.nn"


class QRLoss(Module):
"""The QR (forward) loss between class probabilities and predictions.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,3 @@
"RCF",
"resnet50",
)

# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.models"
4 changes: 0 additions & 4 deletions torchgeo/models/changestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@

from .farseg import FarSeg

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "torch.nn"


class ChangeMixin(Module):
"""This module enables any segmentation model to detect binary change.
Expand Down
12 changes: 0 additions & 12 deletions torchgeo/models/farseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,6 @@
from torchvision.models import resnet
from torchvision.ops import FeaturePyramidNetwork as FPN

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "torch.nn"
ModuleList.__module__ = "nn.ModuleList"
Sequential.__module__ = "nn.Sequential"
Conv2d.__module__ = "nn.Conv2d"
BatchNorm2d.__module__ = "nn.BatchNorm2d"
ReLU.__module__ = "nn.ReLU"
UpsamplingBilinear2d.__module__ = "nn.UpsamplingBilinear2d"
Sigmoid.__module__ = "nn.Sigmoid"
Identity.__module__ = "nn.Identity"


class FarSeg(Module):
"""Foreground-Aware Relation Network (FarSeg).
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/models/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from torch import Tensor
from torch.nn.modules import Module

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "torch.nn"


class FCN(Module):
"""A simple 5 layer FCN with leaky relus and 'same' padding."""
Expand Down
2 changes: 0 additions & 2 deletions torchgeo/models/fcsiam.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from segmentation_models_pytorch.base.model import SegmentationModel
from torch import Tensor

Unet.__module__ = "segmentation_models_pytorch"


class FCSiamConc(SegmentationModel): # type: ignore[misc]
"""Fully-convolutional Siamese Concatenation (FC-Siam-conc).
Expand Down
5 changes: 1 addition & 4 deletions torchgeo/models/rcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Conv2d, Module

Module.__module__ = "torch.nn"
Conv2d.__module__ = "torch.nn"
from torch.nn.modules import Module


class RCF(Module):
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,3 @@
# Constants
"Units",
)

# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.samplers"
4 changes: 0 additions & 4 deletions torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from .constants import Units
from .utils import _to_tuple, get_random_bounding_box, tile_to_chips

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Sampler.__module__ = "torch.utils.data"


class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC):
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from .constants import Units
from .utils import _to_tuple, get_random_bounding_box, tile_to_chips

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Sampler.__module__ = "torch.utils.data"


class GeoSampler(Sampler[BoundingBox], abc.ABC):
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
Expand Down
4 changes: 0 additions & 4 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,3 @@
"RegressionTask",
"SemanticSegmentationTask",
)

# https://stackoverflow.com/questions/40018681
for module in __all__:
globals()[module].__module__ = "torchgeo.trainers"
4 changes: 0 additions & 4 deletions torchgeo/trainers/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@

from . import utils

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Module.__module__ = "torch.nn"


def normalized_mse(x: Tensor, y: Tensor) -> Tensor:
"""Computes the normalized mean squared error between x and y.
Expand Down
6 changes: 0 additions & 6 deletions torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch.nn as nn
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import (
Expand All @@ -27,11 +26,6 @@
from ..datasets.utils import unbind_samples
from . import utils

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"


class ClassificationTask(pl.LightningModule):
"""LightningModule for image classification.
Expand Down
5 changes: 0 additions & 5 deletions torchgeo/trainers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from packaging.version import parse
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
Expand All @@ -36,10 +35,6 @@

from ..datasets.utils import unbind_samples

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
DataLoader.__module__ = "torch.utils.data"


class ObjectDetectionTask(pl.LightningModule):
"""LightningModule for object detection of images.
Expand Down

0 comments on commit fd06985

Please sign in to comment.