
How do I implement the following pytorch operations in burn? #1368

Answered by nathanielsimard
zemelLeong asked this question in Q&A

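```rust
use burn::tensor::Tensor;

let tensor = Tensor::random(..); // arguments elided
let [b, d1, _d2] = tensor.dims();
// One range per dimension; `0..1` keeps only index 0 of the last dim, so the result has shape [b, d1, 1].
let tensor_partial = tensor.slice([0..b, 0..d1, 0..1]);
```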
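To make the semantics concrete without pulling in Burn, here is a rough sketch using a hypothetical toy `Tensor3` type (it is not Burn's `Tensor`; `from_fn`, `get` and the row-major layout are assumptions made only for illustration). Its `slice` takes one `Range<usize>` per dimension and never drops a dimension, so `[0..b, 0..d1, 0..1]` yields shape `[b, d1, 1]`, the same as PyTorch's `tensor[:, :, 0:1]`.

```rust
use std::ops::Range;

/// Hypothetical toy 3-D tensor stored row-major; NOT Burn's `Tensor`.
#[derive(Debug, Clone, PartialEq)]
struct Tensor3 {
    dims: [usize; 3],
    data: Vec<f32>,
}

impl Tensor3 {
    /// Builds a tensor by evaluating `f` at every (i, j, k) in row-major order.
    fn from_fn(dims: [usize; 3], f: impl Fn(usize, usize, usize) -> f32) -> Self {
        let mut data = Vec::with_capacity(dims.iter().product());
        for i in 0..dims[0] {
            for j in 0..dims[1] {
                for k in 0..dims[2] {
                    data.push(f(i, j, k));
                }
            }
        }
        Self { dims, data }
    }

    fn dims(&self) -> [usize; 3] {
        self.dims
    }

    fn get(&self, i: usize, j: usize, k: usize) -> f32 {
        self.data[(i * self.dims[1] + j) * self.dims[2] + k]
    }

    /// Copies out the sub-tensor selected by one range per dimension.
    /// No dimension is dropped: a range like `0..1` leaves a dimension of size 1.
    fn slice(&self, ranges: [Range<usize>; 3]) -> Self {
        for (r, &d) in ranges.iter().zip(self.dims.iter()) {
            assert!(r.start <= r.end && r.end <= d, "range {r:?} out of bounds for dim of size {d}");
        }
        let [r0, r1, r2] = ranges;
        let new_dims = [r0.len(), r1.len(), r2.len()];
        Self::from_fn(new_dims, |i, j, k| {
            self.get(r0.start + i, r1.start + j, r2.start + k)
        })
    }
}

fn main() {
    // Each value encodes its own position: 100*i + 10*j + k.
    let tensor = Tensor3::from_fn([2, 3, 4], |i, j, k| (i * 100 + j * 10 + k) as f32);
    let [b, d1, _d2] = tensor.dims();
    let tensor_partial = tensor.slice([0..b, 0..d1, 0..1]);
    println!("{:?}", tensor_partial.dims()); // [2, 3, 1]
    println!("{:?}", tensor_partial.data); // [0.0, 10.0, 20.0, 100.0, 110.0, 120.0]
}
```

Because the range form keeps every dimension, getting PyTorch's `tensor[:, :, 0]` (shape `[b, d1]`) would need an extra reshape or squeeze step after slicing.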

You have to specify a range for every dimension. At some point we might add syntax sugar so you don't have to spell out the start and end of the dimensions you want to keep whole, perhaps something like this:

```rust
let tensor_partial = tensor.slice(index![.., .., 0]);
```
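The `index!` macro above is only a proposal. One possible way such sugar could work is sketched here with a hypothetical `ToRange` trait and a hypothetical `index!` macro (neither exists in Burn): the macro expands to a closure that receives the tensor's dims and resolves each entry into the `[Range<usize>; D]` array that `slice` already accepts, turning `..` into the full range and a bare index `i` into `i..i + 1`.

```rust
use std::ops::{Range, RangeFull};

/// Hypothetical helper: turns one index entry into a concrete range,
/// given the size of the dimension it applies to.
trait ToRange {
    fn to_range(self, size: usize) -> Range<usize>;
}

impl ToRange for RangeFull {
    // `..` keeps the whole dimension.
    fn to_range(self, size: usize) -> Range<usize> {
        0..size
    }
}

impl ToRange for Range<usize> {
    // An explicit range is passed through unchanged.
    fn to_range(self, _size: usize) -> Range<usize> {
        self
    }
}

impl ToRange for usize {
    // A single index becomes a one-element range, so the rank is preserved.
    fn to_range(self, _size: usize) -> Range<usize> {
        self..self + 1
    }
}

/// Hypothetical `index!` macro: expands to a closure that, given the tensor's
/// dims, yields one `Range<usize>` per dimension.
macro_rules! index {
    ($($entry:expr),+ $(,)?) => {
        |dims: &[usize]| {
            let mut axis = 0;
            [$({
                let range = ToRange::to_range($entry, dims[axis]);
                axis += 1;
                range
            }),+]
        }
    };
}

fn main() {
    let dims: [usize; 3] = [2, 3, 4];
    let to_ranges = index![.., .., 0];
    let ranges = to_ranges(&dims[..]);
    println!("{:?}", ranges); // [0..2, 0..3, 0..1]
}
```

Mapping a bare index to `i..i + 1` keeps the rank, which matches how `slice` behaves; PyTorch's `tensor[:, :, 0]` drops that dimension instead, so a real design would have to pick one of the two behaviors.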

We might also implement the `Index` trait (https://doc.rust-lang.org/std/ops/trait.Index.html) for several index types, dispatching to `slice` when given a `Range` and to `select` when given an int tensor.
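A rough sketch of the dispatch-by-index-type idea, using hypothetical toy types (`Mat`, `IntTensor`) and a hypothetical `TensorIndex` trait rather than `std::ops::Index` itself. `Index::index` returns `&Self::Output`, a reference, which makes it awkward to hand back a freshly computed tensor, so this sketch returns an owned value instead. Which operation runs is chosen purely by the index type: a pair of ranges slices, while an int tensor of row indices selects (here always along dimension 0, an assumption of the sketch).

```rust
use std::ops::Range;

/// Hypothetical toy 2-D tensor stored row-major; NOT Burn's `Tensor`.
#[derive(Debug, PartialEq)]
struct Mat {
    rows: usize,
    cols: usize,
    data: Vec<i32>,
}

/// Hypothetical stand-in for an int tensor holding row indices.
struct IntTensor(Vec<usize>);

/// Hypothetical indexing trait with an owned output; the operation that
/// runs is chosen by the index type `I`.
trait TensorIndex<I> {
    type Output;
    fn at(&self, idx: I) -> Self::Output;
}

impl Mat {
    fn get(&self, r: usize, c: usize) -> i32 {
        self.data[r * self.cols + c]
    }
}

// Ranges dispatch to slice semantics: the rank is preserved.
impl TensorIndex<[Range<usize>; 2]> for Mat {
    type Output = Mat;
    fn at(&self, [rs, cs]: [Range<usize>; 2]) -> Mat {
        assert!(rs.end <= self.rows && cs.end <= self.cols, "slice out of bounds");
        let data = rs
            .clone()
            .flat_map(|r| cs.clone().map(move |c| self.get(r, c)))
            .collect();
        Mat { rows: rs.len(), cols: cs.len(), data }
    }
}

// An int tensor dispatches to select semantics: whole rows (dim 0) are
// gathered in the order given, and indices may repeat.
impl TensorIndex<IntTensor> for Mat {
    type Output = Mat;
    fn at(&self, IntTensor(idx): IntTensor) -> Mat {
        let data = idx
            .iter()
            .flat_map(|&r| (0..self.cols).map(move |c| self.get(r, c)))
            .collect();
        Mat { rows: idx.len(), cols: self.cols, data }
    }
}

fn main() {
    let m = Mat { rows: 3, cols: 2, data: vec![1, 2, 3, 4, 5, 6] };
    let sliced = m.at([0..3, 0..1]); // first column, shape [3, 1]
    let selected = m.at(IntTensor(vec![2, 0])); // rows 2 and 0, shape [2, 2]
    println!("{:?}", sliced.data); // [1, 3, 5]
    println!("{:?}", selected.data); // [5, 6, 1, 2]
}
```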

Replies: 1 comment 1 reply


1 reply
@zemelLeong

Answer selected by zemelLeong
Category: Q&A · Labels: none yet · 2 participants