Overhaul of train.py and adding Chesapeake CVPR trainer #103

Merged
merged 15 commits on Sep 9, 2021
30 changes: 28 additions & 2 deletions in torchgeo/datasets/chesapeake.py
@@ -9,11 +9,13 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import fiona
+import numpy as np
 import pyproj
 import rasterio
 import rasterio.mask
 import shapely.geometry
 import shapely.ops
+import torch
 from rasterio.crs import CRS
 
 from .geo import GeoDataset, RasterDataset
@@ -291,6 +293,9 @@ class ChesapeakeCVPR(GeoDataset):
     filename = "cvpr_chesapeake_landcover.zip"
     md5 = "0ea5e7cb861be3fb8a06fedaaaf91af9"
 
+    crs = CRS.from_epsg(3857)
+    res = 1
+
     valid_layers = [
         "naip-new",
         "naip-old",
@@ -402,6 +407,8 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
         filepaths = [hit.object for hit in hits]
 
         sample = {
+            "image": [],
+            "mask": [],
             "crs": self.crs,
             "bbox": query,
         }
@@ -436,11 +443,30 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
                         f, [query_geom_transformed], crop=True, all_touched=True
                     )
 
-                sample[layer] = data.squeeze()
-
+                if layer in [
+                    "naip-new",
+                    "naip-old",
+                    "landsat-leaf-on",
+                    "landsat-leaf-off",
+                ]:
+                    sample["image"].append(data)
+                elif layer in ["lc", "nlcd", "buildings"]:
+                    sample["mask"].append(data)
         else:
             raise IndexError(f"query: {query} spans multiple tiles which is not valid")
 
+        sample["image"] = np.concatenate(  # type: ignore[no-untyped-call]
+            sample["image"], axis=0
+        )
+        sample["mask"] = np.concatenate(  # type: ignore[no-untyped-call]
+            sample["mask"], axis=0
+        )
+
+        sample["image"] = torch.from_numpy(  # type: ignore[attr-defined]
+            sample["image"]
+        )
+        sample["mask"] = torch.from_numpy(sample["mask"])  # type: ignore[attr-defined]
+
         if self.transforms is not None:
             sample = self.transforms(sample)
 
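For readers skimming the diff, here is a minimal standalone sketch of the accumulate-then-concatenate pattern the new `__getitem__` body follows. The band counts, shapes, and dtype below are illustrative placeholders, not values taken from the dataset.

```python
import numpy as np
import torch

# Each layer is read as a (bands, height, width) array; image layers and
# mask layers are collected in separate lists.
sample = {"image": [], "mask": []}
sample["image"].append(np.zeros((4, 64, 64), dtype=np.uint8))  # e.g. a NAIP layer (placeholder shape)
sample["image"].append(np.zeros((9, 64, 64), dtype=np.uint8))  # e.g. a Landsat layer (placeholder shape)
sample["mask"].append(np.zeros((1, 64, 64), dtype=np.uint8))   # e.g. the "lc" layer (placeholder shape)

# Stack the per-layer arrays along the band axis and convert to tensors,
# mirroring the np.concatenate / torch.from_numpy calls in the diff.
sample["image"] = torch.from_numpy(np.concatenate(sample["image"], axis=0))
sample["mask"] = torch.from_numpy(np.concatenate(sample["mask"], axis=0))

print(sample["image"].shape)  # torch.Size([13, 64, 64])
print(sample["mask"].shape)   # torch.Size([1, 64, 64])
```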
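And a hedged usage sketch of the reworked dataset: the constructor call, the 256 m window size, and the use of `ds.bounds` are assumptions for illustration (none appear in this diff), and the imports assume `BoundingBox` and `ChesapeakeCVPR` are exported from `torchgeo.datasets`. The `BoundingBox` query and the returned `"image"`/`"mask"` tensors follow the `__getitem__` shown above; a real query must still fall within a single tile, as the `IndexError` branch enforces.

```python
from torchgeo.datasets import BoundingBox, ChesapeakeCVPR

# Constructor arguments are illustrative; the exact signature is not part of this diff.
ds = ChesapeakeCVPR(root="data/cvpr_chesapeake_landcover")

# Build a 256 x 256 m query in the dataset CRS (EPSG:3857, res = 1),
# anchored at the lower-left corner of the dataset's spatial index.
b = ds.bounds
query = BoundingBox(
    minx=b.minx, maxx=b.minx + 256,
    miny=b.miny, maxy=b.miny + 256,
    mint=b.mint, maxt=b.maxt,
)

sample = ds[query]
print(sample["image"].shape, sample["image"].dtype)  # image layers stacked along the band axis
print(sample["mask"].shape, sample["mask"].dtype)    # mask layers stacked along the band axis
```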