
Overhaul of train.py and adding Chesapeake CVPR trainer #103

Merged: 15 commits, Sep 9, 2021
24 changes: 15 additions & 9 deletions conf/defaults.yaml
@@ -1,21 +1,27 @@
-config_file: null # The user can pass a filename here on the command line
+config_file: null # This lets the user pass a config filename to load other arguments from
 
-program: # These are the default arguments
-  batch_size: 32
-  num_workers: 4
+program: # These are the arguments that define how the train.py script works
   seed: 1337
-  experiment_name: ??? # This is OmegaConf syntax that makes this a required field
   output_dir: output
   data_dir: data
   log_dir: logs
   overwrite: False
 
-task:
-  name: ??? # this must be defined so we can get the task specific arguments
+experiment: # These are arguments specific to the experiment we are running
+  name: ??? # this is the name given to this experiment run (OmegaConf's ??? syntax makes it a required field)
+  task: ??? # this is the type of task to use for this experiment (e.g. "landcoverai")
+  module: # these will be passed as kwargs to the LightningModule associated with the task
+    learning_rate: 1e-3
+  datamodule: # these will be passed as kwargs to the LightningDataModule associated with the task
+    root_dir: ${program.data_dir}
+    seed: ${program.seed}
+    batch_size: 32
+    num_workers: 4
 
-# Taken from https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init
-trainer:
+# The values here are taken from the defaults at https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init
+# this should probably be made into a schema, e.g. as shown at https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs
+trainer: # These are the parameters passed to the pytorch lightning Trainer object
   logger: True
   checkpoint_callback: True
   callbacks: null
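
For context on how these files fit together: the PR layers repository defaults, an optional user config file, per-task defaults, and command-line dotted overrides. Below is a minimal sketch of that layering; the merge order and file lookup are illustrative assumptions, not the literal train.py code, though the OmegaConf calls themselves are real:

    # Hypothetical sketch of the config layering (not the actual train.py code).
    from omegaconf import OmegaConf

    conf = OmegaConf.load("conf/defaults.yaml")  # repository-wide defaults

    # Command-line dotted overrides, e.g. `experiment.task=cyclone`.
    conf = OmegaConf.merge(conf, OmegaConf.from_cli())

    # If the user passed `config_file=...`, layer that file in as well.
    if conf.config_file is not None:
        conf = OmegaConf.merge(conf, OmegaConf.load(conf.config_file))

    # Per-task defaults, keyed by the new experiment.task field; user-supplied
    # values take precedence over the task defaults.
    task_conf = OmegaConf.load(f"conf/task_defaults/{conf.experiment.task}.yaml")
    conf = OmegaConf.merge(task_conf, conf)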
13 changes: 13 additions & 0 deletions conf/task_defaults/chesapeake_cvpr.yaml
@@ -0,0 +1,13 @@
+experiment:
+  task: "chesapeake_cvpr"
+  module:
+    loss: "ce"
+    segmentation_model: "unet"
+    encoder_name: "resnet18"
+    encoder_weights: "imagenet"
+    encoder_output_stride: 16
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
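
defaults.yaml documents the experiment.module and experiment.datamodule blocks as constructor kwargs. A hedged sketch of that pattern using this new file (the dict-unpacking wiring and the root_dir argument are assumptions; the class names match this PR's exports):

    from omegaconf import OmegaConf

    from torchgeo.trainers import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask

    conf = OmegaConf.load("conf/task_defaults/chesapeake_cvpr.yaml")

    # Each section maps one-to-one onto constructor keyword arguments.
    task = ChesapeakeCVPRSegmentationTask(**OmegaConf.to_container(conf.experiment.module))
    datamodule = ChesapeakeCVPRDataModule(
        root_dir="data",  # assumed here; defaults.yaml interpolates ${program.data_dir}
        **OmegaConf.to_container(conf.experiment.datamodule),
    )

The remaining task_defaults files below follow the same experiment/module/datamodule layout.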
14 changes: 9 additions & 5 deletions conf/task_defaults/cyclone.yaml
@@ -1,5 +1,9 @@
-task:
-  name: "cyclone"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
-  model: "resnet18"
+experiment:
+  task: "cyclone"
+  module:
+    model: "resnet18"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
24 changes: 14 additions & 10 deletions conf/task_defaults/landcoverai.yaml
@@ -1,10 +1,14 @@
-task:
-  name: "landcoverai"
-  optimizer: "adamw"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
-  loss: "ce"
-  segmentation_model: "deeplabv3+"
-  encoder_name: "resnet34"
-  encoder_weights: "imagenet"
-  encoder_output_stride: 16
+experiment:
+  task: "landcoverai"
+  module:
+    loss: "ce"
+    segmentation_model: "deeplabv3+"
+    encoder_name: "resnet34"
+    encoder_weights: "imagenet"
+    encoder_output_stride: 16
+    optimizer: "adamw"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
24 changes: 14 additions & 10 deletions conf/task_defaults/naipchesapeake.yaml
@@ -1,10 +1,14 @@
-task:
-  name: "naipchesapeake"
-  optimizer: "adamw"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
-  loss: "ce"
-  segmentation_model: "deeplabv3+"
-  encoder_name: "resnet34"
-  encoder_weights: "imagenet"
-  encoder_output_stride: 16
+experiment:
+  task: "naipchesapeake"
+  module:
+    loss: "ce"
+    segmentation_model: "deeplabv3+"
+    encoder_name: "resnet34"
+    encoder_weights: "imagenet"
+    encoder_output_stride: 16
+    optimizer: "adamw"
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
21 changes: 13 additions & 8 deletions conf/task_defaults/sen12ms.yaml
@@ -1,8 +1,13 @@
-task:
-  name: "sen12ms"
-  learning_rate: 1e-3
-  learning_rate_schedule_patience: 2
-  loss: "ce"
-  segmentation_model: "unet"
-  encoder_name: "resnet18"
-  encoder_weights: "imagenet"
+experiment:
+  task: "sen12ms"
+  module:
+    loss: "ce"
+    segmentation_model: "unet"
+    encoder_name: "resnet18"
+    encoder_weights: "imagenet"
+    encoder_output_stride: 16
+    learning_rate: 1e-3
+    learning_rate_schedule_patience: 2
+  datamodule:
+    batch_size: 32
+    num_workers: 4
30 changes: 15 additions & 15 deletions tests/test_train.py
@@ -25,9 +25,9 @@ def test_output_file(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_file),
-        "task.name=test",
+        "experiment.task=test",
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
@@ -43,9 +43,9 @@ def test_experiment_dir_not_empty(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
-        "task.name=test",
+        "experiment.task=test",
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
@@ -64,11 +64,11 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
         "program.data_dir=" + data_dir,
         "program.log_dir=" + str(log_dir),
-        "task.name=cyclone",
+        "experiment.task=cyclone",
         "program.overwrite=True",
         "trainer.fast_dev_run=1",
     ]
@@ -87,9 +87,9 @@ def test_invalid_task(task: str, tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
-        "task.name=" + task,
+        "experiment.task=" + task,
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     assert ps.returncode != 0
@@ -102,9 +102,9 @@ def test_missing_config_file(tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
        "program.output_dir=" + str(output_dir),
-        "task.name=test",
+        "experiment.task=test",
         "config_file=" + str(config_file),
     ]
     ps = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -120,12 +120,12 @@ def test_config_file(tmp_path: Path) -> None:
     config_file.write_text(
         f"""
 program:
-  experiment_name: test
   output_dir: {output_dir}
   data_dir: {data_dir}
   log_dir: {log_dir}
-task:
-  name: cyclone
+experiment:
+  name: test
+  task: cyclone
 trainer:
   fast_dev_run: true
 """
@@ -146,12 +146,12 @@ def test_tasks(task: str, tmp_path: Path) -> None:
     args = [
         sys.executable,
         "train.py",
-        "program.experiment_name=test",
+        "experiment.name=test",
         "program.output_dir=" + str(output_dir),
         "program.data_dir=" + data_dir,
         "program.log_dir=" + str(log_dir),
         "trainer.fast_dev_run=1",
-        "task.name=" + task,
+        "experiment.task=" + task,
         "program.overwrite=True",
     ]
     subprocess.run(args, check=True)
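
Outside of pytest, the same invocation the happy-path test performs can be run by hand; the output locations below are placeholders:

    # Mirrors test_tasks above: a one-batch smoke test of the new override paths.
    import subprocess
    import sys

    subprocess.run(
        [
            sys.executable,
            "train.py",
            "experiment.name=test",
            "experiment.task=cyclone",
            "program.output_dir=output",
            "program.data_dir=data",
            "program.log_dir=logs",
            "program.overwrite=True",
            "trainer.fast_dev_run=1",
        ],
        check=True,
    )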
30 changes: 28 additions & 2 deletions torchgeo/datasets/chesapeake.py
@@ -9,11 +9,13 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import fiona
+import numpy as np
 import pyproj
 import rasterio
 import rasterio.mask
 import shapely.geometry
 import shapely.ops
+import torch
 from rasterio.crs import CRS
 
 from .geo import GeoDataset, RasterDataset
@@ -291,6 +293,9 @@ class ChesapeakeCVPR(GeoDataset):
     filename = "cvpr_chesapeake_landcover.zip"
     md5 = "0ea5e7cb861be3fb8a06fedaaaf91af9"
 
+    crs = CRS.from_epsg(3857)
+    res = 1
+
     valid_layers = [
         "naip-new",
         "naip-old",
@@ -402,6 +407,8 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
         filepaths = [hit.object for hit in hits]
 
         sample = {
+            "image": [],
+            "mask": [],
             "crs": self.crs,
             "bbox": query,
         }
@@ -436,11 +443,30 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
                     f, [query_geom_transformed], crop=True, all_touched=True
                 )
 
-                sample[layer] = data.squeeze()
-
+                if layer in [
+                    "naip-new",
+                    "naip-old",
+                    "landsat-leaf-on",
+                    "landsat-leaf-off",
+                ]:
+                    sample["image"].append(data)
+                elif layer in ["lc", "nlcd", "buildings"]:
+                    sample["mask"].append(data)
         else:
             raise IndexError(f"query: {query} spans multiple tiles which is not valid")
 
+        sample["image"] = np.concatenate(  # type: ignore[no-untyped-call]
+            sample["image"], axis=0
+        )
+        sample["mask"] = np.concatenate(  # type: ignore[no-untyped-call]
+            sample["mask"], axis=0
+        )
+
+        sample["image"] = torch.from_numpy(  # type: ignore[attr-defined]
+            sample["image"]
+        )
+        sample["mask"] = torch.from_numpy(sample["mask"])  # type: ignore[attr-defined]
+
         if self.transforms is not None:
            sample = self.transforms(sample)
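
The net effect of this change is that __getitem__ now routes imagery layers into a stacked "image" tensor and label layers into a stacked "mask" tensor, rather than storing one array per layer name. A usage sketch (the constructor arguments and query coordinates here are illustrative assumptions, not values from this PR):

    from torchgeo.datasets import BoundingBox, ChesapeakeCVPR

    # Assumed constructor arguments, for illustration only.
    ds = ChesapeakeCVPR(root="data", layers=["naip-new", "lc"])

    # Any box that falls within a single tile (placeholder coordinates).
    query = BoundingBox(minx=0, maxx=256, miny=0, maxy=256, mint=0, maxt=1)
    sample = ds[query]

    sample["image"]  # torch.Tensor: imagery layers concatenated along axis 0
    sample["mask"]   # torch.Tensor: label layers concatenated along axis 0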
3 changes: 3 additions & 0 deletions torchgeo/trainers/__init__.py
@@ -3,12 +3,15 @@
 
 """TorchGeo trainers."""
 
+from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask
 from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
 from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
 from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
 from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
 
 __all__ = (
+    "ChesapeakeCVPRSegmentationTask",
+    "ChesapeakeCVPRDataModule",
     "CycloneDataModule",
     "CycloneSimpleRegressionTask",
     "LandcoverAIDataModule",
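
With these exports in place, the new task and datamodule are importable from the package root. A wiring sketch to close the loop (the module kwargs are copied from conf/task_defaults/chesapeake_cvpr.yaml; the root_dir argument and the overall wiring are assumptions, not code from this PR):

    import pytorch_lightning as pl

    from torchgeo.trainers import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask

    task = ChesapeakeCVPRSegmentationTask(
        loss="ce",
        segmentation_model="unet",
        encoder_name="resnet18",
        encoder_weights="imagenet",
        encoder_output_stride=16,
        learning_rate=1e-3,
        learning_rate_schedule_patience=2,
    )
    datamodule = ChesapeakeCVPRDataModule(root_dir="data", batch_size=32, num_workers=4)

    # One-batch sanity check, as the tests above do via trainer.fast_dev_run=1.
    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(model=task, datamodule=datamodule)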