Getting rid of pin_memory=False and proper typing of kwargs
calebrob6 committed Sep 9, 2021
1 parent 9a180e6 commit 9780216
Showing 5 changed files with 5 additions and 20 deletions.
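Context for the first half of the change (not part of the diff itself): torch.utils.data.DataLoader already defaults pin_memory to False, so deleting the explicit keyword from each dataloader leaves behavior unchanged. A minimal sketch of that equivalence, using a throwaway TensorDataset rather than any torchgeo dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))

# Before this commit: the default value was passed explicitly.
explicit = DataLoader(dataset, batch_size=4, num_workers=0, pin_memory=False)

# After this commit: the keyword is simply omitted; the default (False) applies.
implicit = DataLoader(dataset, batch_size=4, num_workers=0)

assert explicit.pin_memory is False and implicit.pin_memory is False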
5 changes: 1 addition & 4 deletions torchgeo/trainers/chesapeake.py
@@ -40,7 +40,7 @@ class ChesapeakeCVPRSegmentationTask(LightningModule):
     ``pytorch_segmentation_models`` package.
     """

-    def config_task(self, kwargs: Dict[str, Any]) -> None:
+    def config_task(self, kwargs: Any) -> None:
         """Configures the task based on kwargs parameters."""
         if kwargs["segmentation_model"] == "unet":
             self.model = smp.Unet(
@@ -341,7 +341,6 @@ def train_dataloader(self) -> DataLoader[Any]:
             self.train_dataset,
             batch_sampler=sampler,  # type: ignore[arg-type]
             num_workers=self.num_workers,
-            pin_memory=False,
         )

     def val_dataloader(self) -> DataLoader[Any]:
@@ -356,7 +355,6 @@ def val_dataloader(self) -> DataLoader[Any]:
             self.val_dataset,
             batch_sampler=sampler,  # type: ignore[arg-type]
             num_workers=self.num_workers,
-            pin_memory=False,
         )

     def test_dataloader(self) -> DataLoader[Any]:
@@ -371,5 +369,4 @@ def test_dataloader(self) -> DataLoader[Any]:
             self.test_dataset,
             batch_sampler=sampler,  # type: ignore[arg-type]
             num_workers=self.num_workers,
-            pin_memory=False,
         )
5 changes: 1 addition & 4 deletions torchgeo/trainers/cyclone.py
@@ -29,7 +29,7 @@ class CycloneSimpleRegressionTask(pl.LightningModule):
     This does not take into account other per-sample features available in this dataset.
     """

-    def config_task(self, kwargs: Dict[str, Any]) -> None:
+    def config_task(self, kwargs: Any) -> None:
         """Configures the task based on kwargs parameters."""
         if kwargs["model"] == "resnet18":
             self.model = models.resnet18(pretrained=False, num_classes=1)
@@ -237,7 +237,6 @@ def train_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=True,
-            pin_memory=False,
         )

     def val_dataloader(self) -> DataLoader[Any]:
@@ -247,7 +246,6 @@ def val_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=False,
-            pin_memory=False,
         )

     def test_dataloader(self) -> DataLoader[Any]:
@@ -257,5 +255,4 @@ def test_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=False,
-            pin_memory=False,
         )
5 changes: 1 addition & 4 deletions torchgeo/trainers/landcoverai.py
@@ -34,7 +34,7 @@ class LandcoverAISegmentationTask(pl.LightningModule):
     ``pytorch_segmentation_models`` package.
     """

-    def config_task(self, kwargs: Dict[str, Any]) -> None:
+    def config_task(self, kwargs: Any) -> None:
         """Configures the task based on kwargs parameters."""
         if kwargs["segmentation_model"] == "unet":
             self.model = smp.Unet(
@@ -300,7 +300,6 @@ def train_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=True,
-            pin_memory=False,
         )

     def val_dataloader(self) -> DataLoader[Any]:
@@ -310,7 +309,6 @@ def val_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=False,
-            pin_memory=False,
         )

     def test_dataloader(self) -> DataLoader[Any]:
@@ -320,5 +318,4 @@ def test_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=False,
-            pin_memory=False,
         )
5 changes: 1 addition & 4 deletions torchgeo/trainers/naipchesapeake.py
@@ -40,7 +40,7 @@ class NAIPChesapeakeSegmentationTask(pl.LightningModule):
     # TODO: tune this hyperparam
     num_filters = 64

-    def config_task(self, kwargs: Dict[str, Any]) -> None:
+    def config_task(self, kwargs: Any) -> None:
         """Configures the task based on kwargs parameters."""
         if kwargs["segmentation_model"] == "unet":
             self.model = smp.Unet(
@@ -327,7 +327,6 @@ def train_dataloader(self) -> DataLoader[Any]:
             self.dataset,
             batch_sampler=self.train_sampler,  # type: ignore[arg-type]
             num_workers=self.num_workers,
-            pin_memory=False,
         )

     def val_dataloader(self) -> DataLoader[Any]:
@@ -337,7 +336,6 @@ def val_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             sampler=self.val_sampler,  # type: ignore[arg-type]
             num_workers=self.num_workers,
-            pin_memory=False,
         )

     def test_dataloader(self) -> DataLoader[Any]:
@@ -347,5 +345,4 @@ def test_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             sampler=self.test_sampler,  # type: ignore[arg-type]
             num_workers=self.num_workers,
-            pin_memory=False,
         )
5 changes: 1 addition & 4 deletions torchgeo/trainers/sen12ms.py
@@ -31,7 +31,7 @@ class SEN12MSSegmentationTask(pl.LightningModule):
     ``pytorch_segmentation_models`` package.
     """

-    def config_task(self, kwargs: Dict[str, Any]) -> None:
+    def config_task(self, kwargs: Any) -> None:
         """Configures the task based on kwargs parameters."""
         if kwargs["segmentation_model"] == "unet":
             self.model = smp.Unet(
@@ -305,7 +305,6 @@ def train_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=True,
-            pin_memory=False,
         )

     def val_dataloader(self) -> DataLoader[Any]:
@@ -315,7 +314,6 @@ def val_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=False,
-            pin_memory=False,
         )

     def test_dataloader(self) -> DataLoader[Any]:
@@ -325,5 +323,4 @@ def test_dataloader(self) -> DataLoader[Any]:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=False,
-            pin_memory=False,
         )
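For the typing half of the change, a small self-contained sketch of the kind of pattern a config_task method with the new kwargs: Any annotation can support. The ToyRegressionTask class and the assumption that __init__ collects keyword arguments into a plain dict and forwards them are illustrative only, not code from the repository:

from typing import Any

from torchvision import models


class ToyRegressionTask:
    """Illustrative stand-in for the trainers touched by this commit (not repository code)."""

    def __init__(self, **kwargs: Any) -> None:
        # Keyword arguments arrive as a plain dict and are forwarded as-is.
        self.config_task(kwargs)

    def config_task(self, kwargs: Any) -> None:
        """Configure the task based on kwargs parameters."""
        # Values are only looked up by key, so a loose `Any` annotation suffices here.
        if kwargs["model"] == "resnet18":
            self.model = models.resnet18(pretrained=False, num_classes=1)
        else:
            raise ValueError(f"Model type '{kwargs['model']}' is not valid.")


task = ToyRegressionTask(model="resnet18")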
