Change DEMs from mask to image (is_image=True) #1776

Open
wants to merge 12 commits into
base: main
4 changes: 2 additions & 2 deletions docs/api/geo_datasets.csv
@@ -1,14 +1,14 @@
Dataset,Type,Source,License,Size (px),Resolution (m)
`Aboveground Woody Biomass`_,Masks,"Landsat, LiDAR","CC-BY-4.0","40,000x40,000",30
`Aster Global DEM`_,Masks,Aster,"public domain","3,601x3,601",30
`Aster Global DEM`_,Digital Elevation Model,Aster,"public domain","3,601x3,601",30
adamjstewart marked this conversation as resolved.
`Canadian Building Footprints`_,Geometries,Bing Imagery,"ODbL-1.0",-,-
`Chesapeake Land Cover`_,"Imagery, Masks",NAIP,"CC-BY-4.0",-,1
`Global Mangrove Distribution`_,Masks,"Remote Sensing, In Situ Measurements","public domain",-,3
`Cropland Data Layer`_,Masks,Landsat,"public domain",-,30
`EDDMapS`_,Points,Citizen Scientists,-,-,-
`EnviroAtlas`_,"Imagery, Masks","NAIP, NLCD, OpenStreetMap","CC-BY-4.0",-,1
`Esri2020`_,Masks,Sentinel-2,"CC-BY-4.0",-,10
`EU-DEM`_,Masks,"Aster, SRTM, Russian Topomaps","CSCDA-ESA",-,25
`EU-DEM`_,Digital Elevation Model,"Aster, SRTM, Russian Topomaps","CSCDA-ESA",-,25
`GBIF`_,Points,Citizen Scientists,"CC0-1.0 OR CC-BY-4.0 OR CC-BY-NC-4.0",-,-
`GlobBiomass`_,Masks,Landsat,"CC-BY-4.0","45,000x45,000",100
`iNaturalist`_,Points,Citizen Scientists,-,-,-
6 changes: 5 additions & 1 deletion docs/tutorials/custom_raster_dataset.ipynb
@@ -329,7 +329,11 @@
"\n",
"### `is_image`\n",
"\n",
"If your data only contains image files, as is the case with Sentinel-2, use `is_image = True`. If your data only contains segmentation masks, use `is_image = False` instead.\n",
"If your dataset only contains source data, such as image files, like Sentinel-2, or a digital surface, like a Digital Elevation Model, Digital Surface Model, Digital Terrain Model, or a raster of temperature values, use `is_image = True`. If your dataset only contains target data, such as a segmentation mask, like land use or land cover classification, use `is_image = False` instead.\n",
"\n",
"### `dtype`\n",
"\n",
"Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what is usually wanted for 99% of datasets but can be overridden for integer images (like some Digital Elevation Models) or pixel-wise regression masks (where the target should be float32). Uint16 and uint32 are automatically cast to int32 and int64, respectively, because numpy supports the former but torch does not.\n",
"\n",
"### `separate_files`\n",
"\n",
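The `is_image` and `dtype` defaults described in the tutorial text above can be sketched without torchgeo installed. This is a minimal, standard-library-only illustration, not the library's implementation: `SketchRasterDataset`, `SketchMask`, `SketchIntegerDEM`, and `cast_unsupported` are hypothetical names, and dtypes are modelled as plain strings rather than `torch.dtype` objects.

```python
# Assumption: dtypes are plain strings here; real code would use torch/numpy dtypes.
UNSIGNED_CASTS = {"uint16": "int32", "uint32": "int64"}


def cast_unsupported(file_dtype: str) -> str:
    """Map uint16/uint32 to the signed dtype named in the text; pass others through."""
    return UNSIGNED_CASTS.get(file_dtype, file_dtype)


class SketchRasterDataset:
    """Hypothetical stand-in for a raster dataset (not the torchgeo class)."""

    is_image = True

    @property
    def dtype(self) -> str:
        # Default from the text: float32 for images, long for masks.
        return "float32" if self.is_image else "long"


class SketchMask(SketchRasterDataset):
    # Target data, e.g. a land cover classification.
    is_image = False


class SketchIntegerDEM(SketchRasterDataset):
    # Source data whose values are integers, so the float32 default is overridden.
    is_image = True

    @property
    def dtype(self) -> str:
        return "long"
```

The override in `SketchIntegerDEM` mirrors the pattern this PR uses for `AsterGDEM`, where a `dtype` property returns `torch.long` even though `is_image` is `True`.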
2 changes: 1 addition & 1 deletion tests/data/astergdem/data.py
@@ -46,7 +46,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
# remove old data
if os.path.exists(path):
os.remove(path)
# Create mask file
# Create image file
create_file(path, dtype="int32", num_channels=1)
files_to_zip.append(path)

2 changes: 1 addition & 1 deletion tests/data/eudem/data.py
@@ -43,7 +43,7 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
# remove old data
if os.path.exists(path):
os.remove(path)
# Create mask file
# Create image file
create_file(path, dtype="int32", num_channels=1)
files_to_zip.append(path)

9 changes: 1 addition & 8 deletions tests/datasets/test_astergdem.py
@@ -39,7 +39,7 @@ def test_getitem(self, dataset: AsterGDEM) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
assert isinstance(x["image"], torch.Tensor)

def test_and(self, dataset: AsterGDEM) -> None:
ds = dataset & dataset
@@ -55,13 +55,6 @@ def test_plot(self, dataset: AsterGDEM) -> None:
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: AsterGDEM) -> None:
query = dataset.bounds
x = dataset[query]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()

def test_invalid_query(self, dataset: AsterGDEM) -> None:
query = BoundingBox(100, 100, 100, 100, 0, 0)
with pytest.raises(
9 changes: 1 addition & 8 deletions tests/datasets/test_eudem.py
@@ -36,7 +36,7 @@ def test_getitem(self, dataset: EUDEM) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
assert isinstance(x["image"], torch.Tensor)

def test_extracted_already(self, dataset: EUDEM) -> None:
assert isinstance(dataset.paths, str)
@@ -70,13 +70,6 @@ def test_plot(self, dataset: EUDEM) -> None:
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: EUDEM) -> None:
query = dataset.bounds
x = dataset[query]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()

def test_invalid_query(self, dataset: EUDEM) -> None:
query = BoundingBox(100, 100, 100, 100, 0, 0)
with pytest.raises(
54 changes: 29 additions & 25 deletions torchgeo/datasets/astergdem.py
@@ -6,6 +6,7 @@
from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS

@@ -36,7 +37,8 @@ class AsterGDEM(RasterDataset):
.. versionadded:: 0.3
"""

is_image = False
is_image = True
all_bands = ["elevation"]
filename_glob = "ASTGTMV003_*_dem*"
filename_regex = r"""
(?P<name>[ASTGTMV003]{10})
@@ -74,8 +76,11 @@ def __init__(
self.paths = paths

self._verify()
bands = self.all_bands

super().__init__(paths, crs, res, transforms=transforms, cache=cache)
super().__init__(
paths, crs, res, bands=bands, transforms=transforms, cache=cache
)

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
@@ -85,6 +90,17 @@ def _verify(self) -> None:

raise DatasetNotFoundError(self)

@property
def dtype(self) -> torch.dtype:
"""The dtype of the dataset.

Overrides the dtype property of RasterDataset to return `torch.long`.

Returns:
the dtype of the dataset cast to a torch dtype
"""
return torch.long

def plot(
self,
sample: dict[str, Any],
@@ -101,29 +117,17 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
"""
mask = sample["mask"].squeeze()
ncols = 1

showing_predictions = "prediction" in sample
if showing_predictions:
prediction = sample["prediction"].squeeze()
ncols = 2

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4))

if showing_predictions:
axs[0].imshow(mask)
axs[0].axis("off")
axs[1].imshow(prediction)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Mask")
axs[1].set_title("Prediction")
else:
axs.imshow(mask)
axs.axis("off")
if show_titles:
axs.set_title("Mask")
image = sample["image"][0]

image = torch.clamp(image, min=0, max=1)

fig, ax = plt.subplots(1, 1, figsize=(4, 4))

ax.imshow(image)
ax.axis("off")

if show_titles:
ax.set_title("Elevation")

if suptitle is not None:
plt.suptitle(suptitle)
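The new `plot` method clamps the elevation band to [0, 1] before display. Clamping leaves values inside the range untouched and pins everything else to the bounds, so any elevation of 1 or more ends up at 1. One possible alternative, shown here as a standard-library sketch rather than a proposed patch, is min-max scaling. `clamp01`, `min_max_scale`, and the sample elevations are hypothetical and operate on a flat list instead of a tensor.

```python
def clamp01(values: list[float]) -> list[float]:
    """Pin each value to the range [0, 1], as a clamp with min=0, max=1 does."""
    return [min(max(v, 0.0), 1.0) for v in values]


def min_max_scale(values: list[float]) -> list[float]:
    """Rescale values linearly so the smallest maps to 0 and the largest to 1."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # A constant tile has no contrast; map everything to 0.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


# Made-up elevation values, for illustration only.
elevations = [0.0, 250.0, 500.0, 1000.0]

clamped = clamp01(elevations)       # [0.0, 1.0, 1.0, 1.0]
scaled = min_max_scale(elevations)  # [0.0, 0.25, 0.5, 1.0]
```

With clamping, three of the four sample values collapse to the same brightness; with min-max scaling, the relative differences survive.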
43 changes: 18 additions & 25 deletions torchgeo/datasets/eudem.py
@@ -9,6 +9,7 @@
from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS

@@ -46,7 +47,8 @@ class EUDEM(RasterDataset):
.. versionadded:: 0.3
"""

is_image = False
is_image = True
all_bands = ["elevation"]
filename_glob = "eu_dem_v11_*.TIF"
zipfile_glob = "eu_dem_v11_*[A-Z0-9].zip"
filename_regex = "(?P<name>[eudem_v11]{10})_(?P<id>[A-Z0-9]{6})"
@@ -114,8 +116,11 @@ def __init__(
self.checksum = checksum

self._verify()
bands = self.all_bands

super().__init__(paths, crs, res, transforms=transforms, cache=cache)
super().__init__(
paths, crs, res, bands=bands, transforms=transforms, cache=cache
)

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
@@ -152,29 +157,17 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
"""
mask = sample["mask"].squeeze()
ncols = 1

showing_predictions = "prediction" in sample
if showing_predictions:
pred = sample["prediction"].squeeze()
ncols = 2

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4))

if showing_predictions:
axs[0].imshow(mask)
axs[0].axis("off")
axs[1].imshow(pred)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Mask")
axs[1].set_title("Prediction")
else:
axs.imshow(mask)
axs.axis("off")
if show_titles:
axs.set_title("Mask")
image = sample["image"][0]

image = torch.clamp(image, min=0, max=1)

fig, ax = plt.subplots(1, 1, figsize=(4, 4))

ax.imshow(image)
ax.axis("off")

if show_titles:
ax.set_title("Elevation")

if suptitle is not None:
plt.suptitle(suptitle)
27 changes: 26 additions & 1 deletion torchgeo/datasets/geo.py
@@ -57,6 +57,10 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
(e.g. Landsat and CDL)
* Combine datasets for multiple image sources for multimodal learning or data fusion
(e.g. Landsat and Sentinel)
* Combine an image with a digital surface (e.g. elevation, temperature, or
pressure) and sample from both simultaneously (e.g. Sentinel-2 and an Aster
Global DEM tile)

These combinations require that all queries are present in *both* datasets,
and can be combined using an :class:`IntersectionDataset`:
@@ -342,7 +346,21 @@ class RasterDataset(GeoDataset):
#: ``start`` and ``stop`` groups.
date_format = "%Y%m%d"

#: True if dataset contains imagery, False if dataset contains mask
#: True if the dataset contains source data, such as imagery or a digital surface
#: (e.g., a Digital Elevation Model). False if the dataset contains a mask, that
#: is, target data, normally categorical data like a land cover classification.
#:
#: The value of ``is_image`` controls two things.
#:
#: The first is the default ``dtype``. See the ``dtype`` property below for details.
#:
#: The second is the key used in the sample returned by ``__getitem__``:
#: ``True`` uses "image"; ``False`` uses "mask". These names match the ones used
#: by Kornia. When datasets are combined and several of them use the same key,
#: for example two "image" datasets and one "mask" dataset, their channels are
#: stacked so that there is still a single value for that key.
Collaborator: I wonder if we can make this more succinct like the other attributes.

Contributor Author: Does this work?

is_image = True

#: True if data is stored in a separate file for each band, else False.
@@ -361,6 +379,13 @@ class RasterDataset(GeoDataset):
def dtype(self) -> torch.dtype:
"""The dtype of the dataset (overrides the dtype of the data file via a cast).

Defaults to float32 for is_image = True and long for is_image = False. This is
what most datasets need, but it can be overridden for integer-valued rasters
(like some imagery or Digital Elevation Models) or pixel-wise regression masks
(where it should be float32). Uint16 and uint32 are automatically cast to int32
and int64, respectively, because numpy supports these unsigned types but torch
does not.

Returns:
the dtype of the dataset
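
The `is_image` documentation in this file's diff says the flag selects the sample key ("image" or "mask") and that channels are stacked when combined datasets share a key. The sketch below illustrates that behaviour with the standard library only. `sample_key` and `merge_samples` are hypothetical helpers, and the band names are string placeholders standing in for tensor channels; this is not torchgeo's actual merge code.

```python
from typing import Any


def sample_key(is_image: bool) -> str:
    """Return the sample key described in the docs for a given is_image value."""
    return "image" if is_image else "mask"


def merge_samples(samples: list[dict[str, list[Any]]]) -> dict[str, list[Any]]:
    """Combine samples; channels stored under the same key are concatenated."""
    merged: dict[str, list[Any]] = {}
    for sample in samples:
        for key, channels in sample.items():
            merged.setdefault(key, []).extend(channels)
    return merged


# Two source datasets (is_image=True) and one target dataset (is_image=False).
optical = {sample_key(True): ["B02", "B03", "B04"]}
dem = {sample_key(True): ["elevation"]}
land_cover = {sample_key(False): ["classes"]}

merged = merge_samples([optical, dem, land_cover])
# merged["image"] holds four channels; merged["mask"] holds one.
```

Under this reading, switching the DEM datasets to `is_image = True` means their elevation band joins the "image" channels of any imagery they are combined with, instead of landing under "mask".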

6 changes: 4 additions & 2 deletions torchgeo/datasets/vhr10.py
@@ -473,7 +473,7 @@ def plot(
# Add masks
if show_feats in {"masks", "both"} and "masks" in sample:
mask = masks[i]
contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
contours = find_contours(mask, 0.5) # _ type: ignore[no-untyped-call]
Collaborator: What is the purpose of this change?

Contributor Author: I was just trying to get it to pass validation.

Collaborator: It should pass fine in CI without modification. It may not pass locally if you don't have pyvista installed.

Contributor Author: Got it. I'll revert the change and install pyvista locally.

Contributor Author: For what it's worth, I still get this error from mypy when I run the suggested validation checks: torchgeo/datasets/vhr10.py:476: error: Unused "type: ignore" comment [unused-ignore]

Collaborator: That suggests you still don't have pyvista installed in the same environment where mypy is installed.

Contributor Author: In fact, this is blocking my commit.
mypy.....................................................................Failed

  • hook id: mypy
  • exit code: 1

torchgeo/datasets/vhr10.py:476: error: Unused "type: ignore" comment [unused-ignore]
torchgeo/datasets/vhr10.py:528: error: Unused "type: ignore" comment [unused-ignore]
Found 2 errors in 1 file (checked 292 source files)

Collaborator: This is why I don't use precommit 😄

Collaborator: I think we can fix this by adding pyvista to precommit.

Collaborator: Fixed by #1781. You can undo this change now.

for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(
@@ -525,7 +525,9 @@ def plot(
# Add masks
if show_pred_masks:
mask = prediction_masks[i]
contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
contours = find_contours(
mask, 0.5
) # _ type: ignore[no-untyped-call]
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(