
stared/pytorch-named-dims


pytorch-named-dims

PyTorch tensor dimension names for all nn.Modules.

Extends PyTorch Named Tensors (introduced in PyTorch 1.3.0 and still experimental as of PyTorch 1.5.0). Works with Python 3.6+.
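For readers new to named tensors, here is a minimal illustration of the underlying PyTorch API that this library builds on. It is plain PyTorch, not pytorch_named_dims, and the shapes and names are arbitrary choices. It creates a named tensor, reorders and reduces dimensions by name, then drops and reattaches the names.

```python
import torch

# Attach a name to every dimension at construction time (experimental API).
x = torch.rand((2, 3, 4, 5), names=('N', 'C', 'H', 'W'))
print(x.names)  # ('N', 'C', 'H', 'W')

# Reorder dimensions by name rather than by position.
x_nhwc = x.align_to('N', 'H', 'W', 'C')
print(x_nhwc.names, tuple(x_nhwc.shape))  # ('N', 'H', 'W', 'C') (2, 4, 5, 3)

# Reduce over a dimension by name; the remaining names are kept.
row_sums = x.sum('W')
print(row_sums.names)  # ('N', 'C', 'H')

# Drop all names, e.g. before calling an op without named-tensor support...
plain = x.rename(None)
print(plain.names)  # (None, None, None, None)

# ...and reattach them afterwards.
restored = plain.refine_names('N', 'C', 'H', 'W')
print(restored.names)  # ('N', 'C', 'H', 'W')
```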

Inspired by:

Installation

Not yet on PyPI. Install it directly from GitHub:

pip install git+https://github.com/stared/pytorch-named-dims.git

Example

import torch
from torch import nn
from pytorch_named_dims import nm

convs = nn.Sequential(
    nm.Conv2d(3, 5, kernel_size=3, padding=1),
    nn.ReLU(),  # preserves dims on its own
    nm.MaxPool2d(2, 2),
    nm.Conv2d(5, 2, kernel_size=3, padding=1)
)

x_input_1 = torch.rand((4, 3, 2, 2), names=('N', 'C', 'H', 'W'))  # good
x_input_2 = torch.rand((4, 3, 2, 2), names=('N', 'C', 'W', 'H'))  # bad

convs(x_input_1)  # returns a tensor with names ('N', 'C', 'H', 'W')
convs(x_input_2)  # raises:
# Layer Conv2d requires dimensions ['N', 'C', 'H', 'W'] but got ('N', 'C', 'W', 'H') instead.
  • TODO: Colab
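This README does not describe how the nm layers work internally. As a rough idea of how such a wrapper can behave, the sketch below uses a hypothetical NamedDims class (not part of pytorch_named_dims). It checks the input names, strips them with rename(None) so that layers without named-tensor support run on a plain tensor, and reattaches the declared output names with refine_names. The error message imitates the one shown in the example above.

```python
import torch
from torch import nn


class NamedDims(nn.Module):
    """Hypothetical wrapper (not the pytorch_named_dims API).

    Checks the names of the input, runs the wrapped module on an unnamed
    tensor, and attaches the declared names to the output.
    """

    def __init__(self, module: nn.Module, in_names, out_names):
        super().__init__()
        self.module = module
        self.in_names = tuple(in_names)
        self.out_names = tuple(out_names)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reject inputs whose dimension names (or their order) do not match.
        if x.names != self.in_names:
            raise ValueError(
                f"Layer {type(self.module).__name__} requires dimensions "
                f"{list(self.in_names)} but got {x.names} instead."
            )
        # Strip names so layers without named-tensor support get a plain tensor.
        out = self.module(x.rename(None))
        # Name the output dimensions of the (unnamed) result.
        return out.refine_names(*self.out_names)


NCHW = ('N', 'C', 'H', 'W')

convs = nn.Sequential(
    NamedDims(nn.Conv2d(3, 5, kernel_size=3, padding=1), NCHW, NCHW),
    nn.ReLU(),  # relu propagates names by itself
    NamedDims(nn.MaxPool2d(2, 2), NCHW, NCHW),
    NamedDims(nn.Conv2d(5, 2, kernel_size=3, padding=1), NCHW, NCHW),
)

out = convs(torch.rand((4, 3, 2, 2), names=NCHW))
print(out.names, tuple(out.shape))  # ('N', 'C', 'H', 'W') (4, 2, 1, 1)
```

Wrapping any module with explicit input and output names keeps the check generic; a real library would more likely infer the expected names per layer type instead of passing them by hand.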

Funding

The project is supported by a Program Operacyjny Inteligentny Rozwój (Smart Growth Operational Programme) grant to ECC Games for the GearShift project.
