Skip to content

Commit

Permalink
RasterDataset: add control over resampling algorithm (#2015)
Browse files Browse the repository at this point in the history
* RasterDataset: add control over resampling algorithm

* Fix type hints

* cubic -> bilinear

* Ruff: single quotes
  • Loading branch information
adamjstewart committed May 13, 2024
1 parent 5976bd1 commit 25fb9cc
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/tutorials/custom_raster_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@
"\n",
"Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).\n",
"\n",
"### `resampling`\n",
"\n",
"Defaults to bilinear for float Tensors and nearest for int Tensors. Can be overridden for custom resampling algorithms.\n",
"\n",
"### `separate_files`\n",
"\n",
"If your data comes with each spectral band in a separate files, as is the case with Sentinel-2, use `separate_files = True`. If all spectral bands are stored in a single file, use `separate_files = False` instead.\n",
Expand Down
28 changes: 28 additions & 0 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Any

import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from rasterio.enums import Resampling
from torch.utils.data import ConcatDataset

from torchgeo.datasets import (
Expand Down Expand Up @@ -49,6 +51,16 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
return {'index': bounds}


class CustomRasterDataset(RasterDataset):
def __init__(self, dtype: torch.dtype, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._dtype = dtype

@property
def dtype(self) -> torch.dtype:
return self._dtype


class CustomVectorDataset(VectorDataset):
filename_glob = '*.geojson'
date_format = '%Y'
Expand Down Expand Up @@ -274,6 +286,22 @@ def test_getitem_uint_dtype(self, dtype: str) -> None:
assert isinstance(x['image'], torch.Tensor)
assert x['image'].dtype == torch.float32

@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_resampling_float_dtype(self, dtype: torch.dtype) -> None:
paths = os.path.join('tests', 'data', 'raster', 'uint16')
ds = CustomRasterDataset(dtype, paths)
x = ds[ds.bounds]
assert x['image'].dtype == dtype
assert ds.resampling == Resampling.bilinear

@pytest.mark.parametrize('dtype', [torch.long, torch.bool])
def test_resampling_int_dtype(self, dtype: torch.dtype) -> None:
paths = os.path.join('tests', 'data', 'raster', 'uint16')
ds = CustomRasterDataset(dtype, paths)
x = ds[ds.bounds]
assert x['image'].dtype == dtype
assert ds.resampling == Resampling.nearest

def test_invalid_query(self, sentinel: Sentinel2) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
Expand Down
24 changes: 22 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import shapely
import torch
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.io import DatasetReader
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
Expand Down Expand Up @@ -309,7 +310,7 @@ def files(self) -> list[str]:
files |= set(glob.iglob(pathname, recursive=True))
elif os.path.isfile(path) or path_is_vsi(path):
files.add(path)
elif not hasattr(self, "download"):
elif not hasattr(self, 'download'):
warnings.warn(
f"Could not find any relevant files for provided path '{path}'. "
f'Path was ignored.',
Expand Down Expand Up @@ -384,6 +385,23 @@ def dtype(self) -> torch.dtype:
else:
return torch.long

@property
def resampling(self) -> Resampling:
"""Resampling algorithm used when reading input files.
Defaults to bilinear for float dtypes and nearest for int dtypes.
Returns:
The resampling method to use.
.. versionadded:: 0.6
"""
# Based on torch.is_floating_point
if self.dtype in [torch.float64, torch.float32, torch.float16, torch.bfloat16]:
return Resampling.bilinear
else:
return Resampling.nearest

def __init__(
self,
paths: str | Iterable[str] = 'data',
Expand Down Expand Up @@ -555,7 +573,9 @@ def _merge_files(
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]

bounds = (query.minx, query.miny, query.maxx, query.maxy)
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res, indexes=band_indexes)
dest, _ = rasterio.merge.merge(
vrt_fhs, bounds, self.res, indexes=band_indexes, resampling=self.resampling
)
# Use array_to_tensor since merge may return uint16/uint32 arrays.
tensor = array_to_tensor(dest)
return tensor
Expand Down

0 comments on commit 25fb9cc

Please sign in to comment.