Indices Operator #1735
base: main
Conversation
Codecov Report

Attention: Patch coverage is

```
@@           Coverage Diff           @@
##             main    #1735   +/-  ##
=======================================
  Coverage   86.54%   86.54%
=======================================
  Files         699      700      +1
  Lines       83223    83255     +32
=======================================
+ Hits        72025    72056     +31
- Misses      11198    11199      +1
```
Thanks for working on this 🙂
Implementation looks good, some minor comments on form.
@@ -285,6 +285,7 @@ Those operations are only available for `Int` tensors.

| ------------------------------------------------ | ------------------------------------------------------- |
| `tensor.arange(5..10, device)`                   | `tensor.arange(start=5, end=10, device=device)`          |
| `tensor.arange_step(5..10, 2, device)`           | `tensor.arange(start=5, end=10, step=2, device=device)`  |
| `tensor.indices(shape, device)`                  | `torch.meshgrid(tensors)`                                |
Two notes:

- Usage should be `Tensor::indices(shape, device)`, like `Tensor::cat` or `Tensor::empty`.
- Given the implementation here, there is no 1-to-1 equivalent in torch, since `meshgrid` takes tensors as input rather than the shape of the desired grid. In this case it is much closer to `numpy.indices`, but with the indexing in cartesian space. The actual torch equivalent for the 2D example in your unit tests would be:

```python
yv, xv = torch.meshgrid([torch.arange(2), torch.arange(2)], indexing='xy')
grid = torch.stack((xv, yv), 2)
```

So in this case, I'm not sure we need to provide the comparison in the table.
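To make the ordering difference concrete, here is a dependency-free Rust sketch of the semantics the reviewer describes (the function name and the assumption that the 2D unit test stores `(x, y) = (col, row)` per cell are illustrative, matching the torch `indexing='xy'` snippet above, not burn's actual API):

```rust
// Sketch of a cartesian-space index grid: position (row, col) stores its
// coordinates in (x, y) = (col, row) order, unlike numpy.indices-style
// (i, j) = (row, col) ordering.
fn cartesian_grid_2d(rows: usize, cols: usize) -> Vec<Vec<[usize; 2]>> {
    (0..rows)
        .map(|r| (0..cols).map(|c| [c, r]).collect())
        .collect()
}

fn main() {
    let grid = cartesian_grid_2d(2, 2);
    // Same values as stacking torch.meshgrid(..., indexing='xy') on dim 2.
    assert_eq!(grid, vec![vec![[0, 0], [1, 0]], vec![[0, 1], [1, 1]]]);
    println!("{:?}", grid);
}
```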
```rust
/// Produces an indices tensor for the given shape and device.
/// The resulting tensor contains coordinates corresponding to each element in the shape at dimension D.
///
/// # Arguments
///
/// * `shape` - The shape specifying the dimensions of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Panics
///
/// Panics if `D2` is not equal to `D+1`.
///
/// # Examples
///
/// ```rust
/// use burn_tensor::Int;
/// use burn_tensor::{backend::Backend, Shape, Tensor};
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let result = Tensor::<B, 2, Int>::indices::<3>(Shape { dims: [2, 3] }, &device);
///     println!("{}", result);
/// }
/// ```
pub fn indices<const D2: usize>(shape: Shape<D>, device: &B::Device) -> Tensor<B, D2, Int> {
```
Usually indices are produced to be used with a matrix/tensor (i.e., indexing in `i, j` like `np.indices`), but I think the intention with this method is to produce a grid in cartesian space. With that in mind, I would make this a bit more explicit in the method's doc and rename the method to something like `nd_grid`, `meshgrid`, or the even more explicit `cartesian_grid` (also open to suggestions).
`cartesian_grid` seems to be a good name!
Instead of forcing users to pass a `Shape`, we could have the API like this:

```rust
pub fn indices<S: Into<Shape<D>>, const D2: usize>(shape: S, device: &B::Device) -> Tensor<B, D2, Int>
```

That way an array could be provided by the user and it would just work.
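A minimal self-contained sketch of the `Into<Shape>` pattern being suggested (toy `Shape` and `indices` stand-ins, not burn's real types):

```rust
// Toy Shape type standing in for burn's Shape<D>.
#[derive(Debug, PartialEq)]
struct Shape<const D: usize> {
    dims: [usize; D],
}

// Letting arrays convert into Shape is what makes the ergonomic API work.
impl<const D: usize> From<[usize; D]> for Shape<D> {
    fn from(dims: [usize; D]) -> Self {
        Shape { dims }
    }
}

// Hypothetical op signature: accepts anything convertible into a Shape.
fn indices<S: Into<Shape<2>>>(shape: S) -> Shape<2> {
    shape.into() // a real op would build the index tensor from this shape
}

fn main() {
    // Callers can now pass a plain array instead of constructing a Shape.
    assert_eq!(indices([2, 3]), Shape { dims: [2, 3] });
    // An explicit Shape still works via the reflexive Into impl.
    assert_eq!(indices(Shape { dims: [4, 5] }), Shape { dims: [4, 5] });
}
```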
@nathanielsimard what do you think about the naming for this method? (see my comments for possible suggestions)
```rust
pub fn indices<const D2: usize>(shape: Shape<D>, device: &B::Device) -> Tensor<B, D2, Int> {
```
I think we should move this method to the backend API so that backends can optimize it, since calling `arange` and `repeat` many times can be very expensive for big matrices. We should keep a default implementation in the backend definition.
@McArthur-Alford if you're not sure what that means, see for example the `narrow` op: it is defined with a default implementation but also overridden by some backends.
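The default-implementation-plus-override pattern referred to here can be sketched in a few lines (trait and function names are illustrative, not burn's real backend trait):

```rust
// A backend trait ships a default implementation of the op...
trait Backend {
    fn grid_elem_count(rows: usize, cols: usize) -> usize {
        // Default: naive per-cell counting, analogous to composing
        // many small ops like arange + repeat.
        let mut n = 0;
        for _ in 0..rows {
            for _ in 0..cols {
                n += 2; // two coordinates (x, y) per cell
            }
        }
        n
    }
}

// ...a generic backend can simply inherit it...
struct Generic;
impl Backend for Generic {}

// ...while a specific backend overrides it, e.g. with a closed form
// standing in for a fused kernel.
struct Optimized;
impl Backend for Optimized {
    fn grid_elem_count(rows: usize, cols: usize) -> usize {
        rows * cols * 2
    }
}

fn main() {
    // Both implementations must agree on the result.
    assert_eq!(Generic::grid_elem_count(2, 3), Optimized::grid_elem_count(2, 3));
    println!("both return {}", Generic::grid_elem_count(2, 3));
}
```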
`cartesian_grid` sounds good to me. I'll get started on moving it to a backend op.
Indices Operator for Int Tensors
Checklist
The `run-checks all` script has been executed.

Related Issues/PRs
None
Changes
Added an `indices` function for `Int` tensors. This is similar to PyTorch's `meshgrid` or NumPy's `indices` functions, though with a slightly different arrangement: for example, `Tensor::<B, 2, Int>::indices::<3>(Shape { dims: [2, 3] }, &device)` returns a rank-3 tensor of coordinates.

Testing
Added a basic but functional test to make sure `indices` produces the expected outputs for some typical inputs.