Commit: Some docstrings

calebrob6 committed Sep 9, 2021
1 parent aac2199 commit cb3d7e1
Showing 1 changed file with 77 additions and 22 deletions.

torchgeo/trainers/chesapeake.py: 99 changes (77 additions, 22 deletions)
@@ -40,35 +40,35 @@ class ChesapeakeCVPRSegmentationTask(LightningModule):
     ``pytorch_segmentation_models`` package.
     """
 
-    def config_task(self, kwargs: Any) -> None:
-        """Configures the task based on kwargs parameters."""
-        if kwargs["segmentation_model"] == "unet":
+    def config_task(self) -> None:
+        """Configures the task based on kwargs parameters passed to the constructor."""
+        if self.hparams["segmentation_model"] == "unet":
             self.model = smp.Unet(
-                encoder_name=kwargs["encoder_name"],
-                encoder_weights=kwargs["encoder_weights"],
+                encoder_name=self.hparams["encoder_name"],
+                encoder_weights=self.hparams["encoder_weights"],
                 in_channels=4,
                 classes=7,
             )
-        elif kwargs["segmentation_model"] == "deeplabv3+":
+        elif self.hparams["segmentation_model"] == "deeplabv3+":
             self.model = smp.DeepLabV3Plus(
-                encoder_name=kwargs["encoder_name"],
-                encoder_weights=kwargs["encoder_weights"],
+                encoder_name=self.hparams["encoder_name"],
+                encoder_weights=self.hparams["encoder_weights"],
                 in_channels=4,
                 classes=7,
             )
         else:
             raise ValueError(
-                f"Model type '{kwargs['segmentation_model']}' is not valid."
+                f"Model type '{self.hparams['segmentation_model']}' is not valid."
             )
 
-        if kwargs["loss"] == "ce":
+        if self.hparams["loss"] == "ce":
             self.loss = nn.CrossEntropyLoss(  # type: ignore[attr-defined]
                 ignore_index=7
             )
-        elif kwargs["loss"] == "jaccard":
+        elif self.hparams["loss"] == "jaccard":
             self.loss = smp.losses.JaccardLoss(mode="multiclass")
         else:
-            raise ValueError(f"Loss type '{kwargs['loss']}' is not valid.")
+            raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.")
 
     def __init__(
         self,
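For context, config_task above dispatches to two constructors from the segmentation_models_pytorch package. A minimal sketch of what the "unet" branch builds, with an assumed resnet18 encoder (the encoder choice is a constructor argument, not fixed by this diff):

    import segmentation_models_pytorch as smp

    # the diff fixes in_channels=4 and classes=7; the encoder settings are assumed
    model = smp.Unet(
        encoder_name="resnet18",     # any encoder supported by smp
        encoder_weights="imagenet",  # or None for random initialization
        in_channels=4,               # 4-band input imagery
        classes=7,                   # 7 output classes; the loss ignores label 7
    )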
@@ -82,11 +82,14 @@ def __init__(
             encoder_weights: None or "imagenet" to use imagenet pretrained weights in
                 the encoder model
             loss: Name of the loss function
+
+        Raises:
+            ValueError: if kwargs arguments are invalid
         """
         super().__init__()
         self.save_hyperparameters()  # creates `self.hparams` from kwargs
 
-        self.config_task(kwargs)
+        self.config_task()
 
         self.train_accuracy = Accuracy()
         self.val_accuracy = Accuracy()
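Because self.save_hyperparameters() stores every constructor keyword in self.hparams, config_task no longer needs a kwargs argument. A hedged sketch of constructing the task, using only keyword names that appear in this diff (the full signature is collapsed here, so other keywords may also be required):

    from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask

    task = ChesapeakeCVPRSegmentationTask(
        segmentation_model="unet",   # or "deeplabv3+"
        encoder_name="resnet18",
        encoder_weights="imagenet",
        loss="ce",                   # or "jaccard"
        learning_rate=1e-3,          # read later in configure_optimizers
    )
    assert task.hparams["loss"] == "ce"  # populated by save_hyperparameters()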
@@ -97,13 +100,28 @@
         self.test_iou = IoU(num_classes=7)
 
     def forward(self, x: Tensor) -> Any:  # type: ignore[override]
-        """Forward pass of the model."""
+        """Forward pass of the model.
+
+        Args:
+            x: tensor of data to run through the model
+
+        Returns:
+            output from the model
+        """
         return self.model(x)
 
     def training_step(  # type: ignore[override]
         self, batch: Dict[str, Any], batch_idx: int
     ) -> Tensor:
-        """Training step - reports average accuracy and average IoU."""
+        """Training step - reports average accuracy and average IoU.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+
+        Returns:
+            training loss
+        """
         x = batch["image"]
         y = batch["mask"]
         y_hat = self.forward(x)
@@ -120,7 +138,11 @@ def training_step(  # type: ignore[override]
         return cast(Tensor, loss)
 
     def training_epoch_end(self, outputs: Any) -> None:
-        """Logs epoch level training metrics."""
+        """Logs epoch level training metrics.
+
+        Args:
+            outputs: list of items returned by training_step
+        """
         self.log("train_acc", self.train_accuracy.compute())
         self.log("train_iou", self.train_iou.compute())
         self.train_accuracy.reset()
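training_step only assumes the batch is a dict with "image" and "mask" entries. A minimal sketch of a synthetic batch matching the shapes implied above (4 input channels, 7 classes), reusing the hypothetical task instance from the earlier sketch:

    import torch

    batch = {
        "image": torch.rand(2, 4, 256, 256),         # (N, C, H, W)
        "mask": torch.randint(0, 7, (2, 256, 256)),  # integer labels in [0, 6]
    }
    y_hat = task(batch["image"])            # (2, 7, 256, 256) class scores
    loss = task.loss(y_hat, batch["mask"])  # CrossEntropyLoss with ignore_index=7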
@@ -129,7 +151,15 @@ def training_epoch_end(self, outputs: Any) -> None:
     def validation_step(  # type: ignore[override]
         self, batch: Dict[str, Any], batch_idx: int
     ) -> None:
-        """Validation step - reports average accuracy and average IoU."""
+        """Validation step - reports average accuracy and average IoU.
+
+        Logs the first 10 validation samples to tensorboard as images with 3 subplots
+        showing the image, mask, and predictions.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+        """
         x = batch["image"]
         y = batch["mask"]
         y_hat = self.forward(x)
@@ -168,7 +198,11 @@ def validation_step(  # type: ignore[override]
         plt.close()
 
     def validation_epoch_end(self, outputs: Any) -> None:
-        """Logs epoch level validation metrics."""
+        """Logs epoch level validation metrics.
+
+        Args:
+            outputs: list of items returned by validation_step
+        """
         self.log("val_acc", self.val_accuracy.compute())
         self.log("val_iou", self.val_iou.compute())
         self.val_accuracy.reset()
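The plotting code referenced by the new validation_step docstring is collapsed in this diff. A hedged sketch of the usual pattern for logging a matplotlib figure to tensorboard from a LightningModule (an illustration, not necessarily the collapsed code):

    import matplotlib.pyplot as plt
    import torch

    # stand-ins for one validation sample and its prediction
    img = torch.rand(4, 256, 256)
    mask = torch.randint(0, 7, (256, 256))
    pred = torch.randint(0, 7, (256, 256))

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(img[:3].permute(1, 2, 0))  # first three bands as RGB
    axs[0].set_title("image")
    axs[1].imshow(mask, vmin=0, vmax=6)
    axs[1].set_title("mask")
    axs[2].imshow(pred, vmin=0, vmax=6)
    axs[2].set_title("prediction")
    # inside the LightningModule the figure would be logged with something like:
    # self.logger.experiment.add_figure(f"image/{batch_idx}", fig, self.global_step)
    plt.close(fig)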
@@ -177,7 +211,12 @@ def validation_epoch_end(self, outputs: Any) -> None:
     def test_step(  # type: ignore[override]
         self, batch: Dict[str, Any], batch_idx: int
     ) -> None:
-        """Test step identical to the validation step."""
+        """Test step identical to the validation step.
+
+        Args:
+            batch: Current batch
+            batch_idx: Index of current batch
+        """
         x = batch["image"]
         y = batch["mask"]
         y_hat = self.forward(x)
@@ -191,14 +230,22 @@ def test_step(  # type: ignore[override]
         self.test_iou(y_hat_hard, y)
 
     def test_epoch_end(self, outputs: Any) -> None:
-        """Logs epoch level test metrics."""
+        """Logs epoch level test metrics.
+
+        Args:
+            outputs: list of items returned by test_step
+        """
         self.log("test_acc", self.test_accuracy.compute())
         self.log("test_iou", self.test_iou.compute())
         self.test_accuracy.reset()
         self.test_iou.reset()
 
     def configure_optimizers(self) -> Dict[str, Any]:
-        """Initialize the optimizer and learning rate scheduler."""
+        """Initialize the optimizer and learning rate scheduler.
+
+        Returns:
+            a "lr dict" according to the pytorch lightning documentation
+        """
         optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=self.hparams["learning_rate"],
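The return statement of configure_optimizers falls outside this hunk; the docstring says it returns an "lr dict". A sketch of that structure as pytorch lightning documents it, assuming a ReduceLROnPlateau scheduler and an assumed monitored metric name:

    import torch

    net = torch.nn.Conv2d(4, 7, kernel_size=1)  # stand-in for self.model
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10)
    lr_dict = {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",  # assumed name of the metric to watch
        },
    }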
@@ -256,7 +303,15 @@ def __init__(
     def custom_transform(
         self, sample: Dict[str, Any], patch_size: int = 256
     ) -> Dict[str, Any]:
-        """Transform a single sample from the Dataset."""
+        """Transform a single sample from the Dataset.
+
+        Args:
+            sample: a single sample from the Dataset
+            patch_size: size of center cropped patch to return
+
+        Returns:
+            a transformed sample
+        """
         # Center crop
         num_image_channels, height, width = sample["image"].shape
         num_mask_channels = sample["mask"].shape[0]
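The crop arithmetic itself is truncated here. As an illustration only (hypothetical, not the file's actual implementation), a center crop consistent with the variables shown above:

    from typing import Any, Dict

    import torch

    def center_crop_sketch(sample: Dict[str, Any], patch_size: int = 256) -> Dict[str, Any]:
        # hypothetical center crop over a {"image", "mask"} sample dict
        _, height, width = sample["image"].shape
        y0 = (height - patch_size) // 2
        x0 = (width - patch_size) // 2
        sample["image"] = sample["image"][:, y0 : y0 + patch_size, x0 : x0 + patch_size]
        sample["mask"] = sample["mask"][:, y0 : y0 + patch_size, x0 : x0 + patch_size]
        return sample

    sample = {"image": torch.rand(4, 512, 512), "mask": torch.zeros(1, 512, 512)}
    cropped = center_crop_sketch(sample)
    assert cropped["image"].shape == (4, 256, 256)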
