Notes on incorporating voxel shift maps and per-volume transforms #184

effigies opened this issue Oct 11, 2023 · 0 comments

Voxel shift maps

As mentioned in today's call, I think that adding full fieldmap support to nitransforms would be difficult. But thinking a bit more about voxel shifts, I think we can handle those fairly straightforwardly as an argument to apply():

class TransformBase:

    def apply(
        self,
        spatialimage,
        reference=None,
        voxel_shift_map=None,
        order=3,
        mode="constant",
        cval=0.0,
        prefilter=True,
        output_dtype=None,
    ):
        ...
        # Ignoring transposes and homogeneous coordinates for brevity
        data = np.asanyarray(spatialimage.dataobj)
        rascoords = self.map(reference.ndcoords)
        # Map RAS coordinates into voxel indices of the moving image (inverse affine)
        voxcoords = Affine(np.linalg.inv(spatialimage.affine)).map(rascoords).reshape(
            (reference.ndim, *reference.shape)
        )
        if voxel_shift_map is not None:
            # voxel_shift_map must have shape (reference.ndim, *reference.shape)
            # Alternately, we could accept it in (*reference.shape, reference.ndim) and roll axes
            voxcoords += voxel_shift_map

        resampled = ndi.map_coordinates(
            data,
            voxcoords,
            output=output_dtype,
            order=order,
            mode=mode,
            cval=cval,
            prefilter=prefilter,
        )
        return resampled
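
For concreteness, something like the following is how the argument might be used to unwarp a single EPI volume. The file names, readout time and phase-encoding axis are placeholders, and voxel_shift_map is the proposed keyword, not existing API:

import numpy as np
import nibabel as nb
import nitransforms as nt

# Placeholder inputs: a single EPI volume, a reference grid, and a B0 fieldmap in Hz
epi = nb.load("sub-01_task-rest_sbref.nii.gz")
boldref = nb.load("sub-01_task-rest_boldref.nii.gz")
fmap_hz = nb.load("sub-01_fieldmap.nii.gz")  # assumed already resampled onto boldref

# Shift in voxels along the phase-encoding axis (assumed j here):
# shift_vox = fieldmap [Hz] * total readout time [s]
total_readout_time = 0.05  # illustrative value
vsm = np.zeros((3, *boldref.shape), dtype="float32")
vsm[1] = np.asanyarray(fmap_hz.dataobj) * total_readout_time

xfm = nt.linear.Affine(np.eye(4), reference=boldref)
moved = xfm.apply(epi, voxel_shift_map=vsm)  # voxel_shift_map: proposed argument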

Because map() operates on RAS coordinates and not voxel indices, the voxel shift map cannot be applied in that context, so we probably do not want to include it as part of the transform itself.

We specifically do not want to describe voxel shift maps in the world space of the target image. While it may be possible to fit one at the end of the chain, after the motion-correction transforms, any such solution would be more complicated than the approach above.

Per-volume transformations

The above discussion works for an individual volume. In order to correctly handle VSMs in a motion-corrected frame, we need TransformChains to become aware that they are involved in a per-volume transform. Unfortunately, right now, TransformChains are iterable over their component transforms, while LinearTransformsMappings are iterable over volumes, which at the very least means that straightforward API composition isn't going to work.
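
To make the mismatch concrete, a toy sketch with identity transforms throughout; the constructors are current nitransforms, but this is only meant to illustrate the iteration semantics described above:

import numpy as np
from nitransforms.linear import Affine, LinearTransformsMapping
from nitransforms.manip import TransformChain

# A chain composes heterogeneous transforms that apply to *every* volume alike
chain = TransformChain([Affine(np.eye(4)), Affine(np.eye(4))])
# A mapping stores one affine *per volume* of a 4D series (e.g., head-motion estimates)
hmc = LinearTransformsMapping(np.stack([np.eye(4)] * 3))

list(chain.transforms)  # -> 2 component transforms (composition order)
list(hmc)               # -> 3 Affines, one per timepoint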

Currently, LinearTransformsMapping handles per-volume resampling inside apply():

def apply(
    self,
    spatialimage,
    reference=None,
    order=3,
    mode="constant",
    cval=0.0,
    prefilter=True,
    output_dtype=None,
):
    """
    Apply a transformation to an image, resampling on the reference spatial object.

    Parameters
    ----------
    spatialimage : `spatialimage`
        The image object containing the data to be resampled in reference
        space
    reference : spatial object, optional
        The image, surface, or combination thereof containing the coordinates
        of samples that will be sampled.
    order : int, optional
        The order of the spline interpolation, default is 3.
        The order has to be in the range 0-5.
    mode : {"constant", "reflect", "nearest", "mirror", "wrap"}, optional
        Determines how the input image is extended when the resampling overflows
        a border. Default is "constant".
    cval : float, optional
        Constant value for ``mode="constant"``. Default is 0.0.
    prefilter : bool, optional
        Determines if the image's data array is prefiltered with
        a spline filter before interpolation. The default is ``True``,
        which will create a temporary *float64* array of filtered values
        if *order > 1*. If setting this to ``False``, the output will be
        slightly blurred if *order > 1*, unless the input is prefiltered,
        i.e. it is the result of calling the spline filter on the original
        input.

    Returns
    -------
    resampled : `spatialimage` or ndarray
        The data imaged after resampling to reference space.

    """
    if reference is not None and isinstance(reference, (str, Path)):
        reference = _nbload(str(reference))

    _ref = (
        self.reference if reference is None else SpatialReference.factory(reference)
    )

    if isinstance(spatialimage, (str, Path)):
        spatialimage = _nbload(str(spatialimage))

    data = np.squeeze(np.asanyarray(spatialimage.dataobj))
    output_dtype = output_dtype or data.dtype

    ycoords = self.map(_ref.ndcoords.T)
    targets = ImageGrid(spatialimage).index(  # data should be an image
        _as_homogeneous(np.vstack(ycoords), dim=_ref.ndim)
    )

    if data.ndim == 4:
        if len(self) != data.shape[-1]:
            raise ValueError(
                "Attempting to apply %d transforms on a file with "
                "%d timepoints" % (len(self), data.shape[-1])
            )
        targets = targets.reshape((len(self), -1, targets.shape[-1]))
        resampled = np.stack(
            [
                ndi.map_coordinates(
                    data[..., t],
                    targets[t, ..., : _ref.ndim].T,
                    output=output_dtype,
                    order=order,
                    mode=mode,
                    cval=cval,
                    prefilter=prefilter,
                )
                for t in range(data.shape[-1])
            ],
            axis=0,
        )
    elif data.ndim in (2, 3):
        resampled = ndi.map_coordinates(
            data,
            targets[..., : _ref.ndim].T,
            output=output_dtype,
            order=order,
            mode=mode,
            cval=cval,
            prefilter=prefilter,
        )

    if isinstance(_ref, ImageGrid):  # If reference is grid, reshape
        newdata = resampled.reshape((len(self), *_ref.shape))
        moved = spatialimage.__class__(
            np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header
        )
        moved.header.set_data_dtype(output_dtype)
        return moved

    return resampled

A VSM- and multivolume-aware TransformChain could do what we want in apply(). Another thought is that we could treat transforms as data objects rather than actors. The interface could be:

def apply_transform(
    source: SpatialImage,
    target: Pointset,
    transform: TransformBase,
    shift_map: np.ndarray | None = None,
    **kwargs,  # map_coordinates args (order, mode, cval, prefilter, ...)
) -> np.ndarray:
    ...
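
Fleshing that out a little, here is a minimal single-volume sketch under the assumptions above (shift map in source-voxel units, sampled on the target grid, with shape (ndim, *target.shape)). Everything here is illustrative rather than a concrete proposal:

import numpy as np
import nibabel as nb
from scipy import ndimage as ndi


def apply_transform(
    source,            # nibabel spatial image holding the data to resample
    target,            # image defining the output grid
    transform,         # any TransformBase: maps target RAS -> source RAS
    shift_map=None,    # optional (ndim, *target.shape) shifts in source-voxel units
    order=3,
    mode="constant",
    cval=0.0,
    prefilter=True,
    output_dtype=None,
):
    """Illustrative sketch: transforms as data, resampling as a free function."""
    data = np.asanyarray(source.dataobj)
    output_dtype = output_dtype or data.dtype

    # Target voxel indices -> target RAS -> source RAS -> source voxel indices
    ijk = np.array(np.meshgrid(*(np.arange(s) for s in target.shape), indexing="ij"))
    ras = nb.affines.apply_affine(target.affine, np.moveaxis(ijk, 0, -1))
    mapped = transform.map(ras.reshape(-1, 3)).reshape(ras.shape)
    voxcoords = np.moveaxis(
        nb.affines.apply_affine(np.linalg.inv(source.affine), mapped), -1, 0
    )

    if shift_map is not None:
        voxcoords += shift_map  # shifts expressed in source-voxel units

    # A real implementation would also wrap the result into an image and handle 4D data
    return ndi.map_coordinates(
        data, voxcoords, output=output_dtype,
        order=order, mode=mode, cval=cval, prefilter=prefilter,
    )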

If we give up on defining apply() correctly for each transform, and leave transforms to focus on composing and mapping, it might make things cleaner. Just imagining how we might approach chains that include per-volume transforms:

import itertools
from typing import Iterator


class TransformBase:
    n_transforms: int = 1

    def iter_transforms(self) -> Iterator["TransformBase"]:
        """Repeat the current transform as often as required."""
        return itertools.repeat(self)


class AffineSeries(TransformBase):
    @property
    def n_transforms(self) -> int:
        return len(self.series)

    def iter_transforms(self) -> Iterator[TransformBase]:
        """Iterate over the defined series."""
        return iter(self.series)


class TransformChain(TransformBase):
    @property
    def n_transforms(self) -> int:
        lengths = [xfm.n_transforms for xfm in self.chain if xfm.n_transforms != 1]
        return min(lengths) if lengths else 1

    def iter_transforms(self) -> Iterator["TransformChain"]:
        """Iterate over all transforms in the chain simultaneously, stopping with the first to stop."""
        return map(TransformChain, zip(*(xfm.iter_transforms() for xfm in self.chain)))
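
With those pieces, a per-volume-aware consumer might look something like the following sketch; resample_volume is a hypothetical stand-in for the single-volume resampling above, and itertools.islice caps the infinite iterator returned by the base class:

import itertools

import numpy as np


def resample_series(data4d, chain, resample_volume):
    """Illustrative: apply one (possibly composite) transform per volume.

    ``resample_volume(volume, transform)`` stands in for single-volume
    resampling; ``chain`` is any TransformBase from the sketch above.
    """
    n_vols = data4d.shape[-1]
    if chain.n_transforms not in (1, n_vols):
        raise ValueError(
            f"Cannot apply {chain.n_transforms} transforms to {n_vols} volumes"
        )
    # islice bounds the base-class itertools.repeat(self) iterator at n_vols
    per_volume = itertools.islice(chain.iter_transforms(), n_vols)
    return np.stack(
        [resample_volume(data4d[..., t], xfm) for t, xfm in enumerate(per_volume)],
        axis=-1,
    )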