
How can I use boolean indexing in Burn? #1585

Answered by laggui
zemelLeong asked this question in Q&A

On `main`, we added `tensor.nonzero()` and `tensor.argwhere()` for boolean tensors, so for your original example you could do the following:

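```rust
use burn::tensor::{backend::Backend, Int, Tensor};

// `B` is any Burn backend; the function name is only for illustration.
fn masked_indices<B: Backend>(device: &B::Device) -> Vec<Tensor<B, 1, Int>> {
    let t = Tensor::<B, 3>::from_floats(
        [[
            [0.5213, 0.6049, 0.2158, 0.7163],
            [0.4655, 0.7438, 0.6514, 0.6525],
            [0.5567, 0.9781, 0.9310, 0.3846],
            [0.8512, 0.7049, 0.5219, 0.9497],
            [0.9796, 0.7220, 0.7281, 0.3046],
            [0.9927, 0.6197, 0.5130, 0.1818],
            [0.7108, 0.9334, 0.4279, 0.8117],
            [0.7960, 0.3307, 0.8622, 0.3465],
            [0.5505, 0.5056, 0.0849, 0.0585],
            [0.9278, 0.5415, 0.5889, 0.7620],
        ]],
        device,
    );
    // Keep only column 3 (shape [1, 10, 1]) and build a Bool mask of values > 0.5.
    let mask = t.slice([0..1, 0..10, 3..4]).greater_elem(0.5);
    // One 1-D Int tensor of indices per dimension of `mask`;
    // `mask.argwhere()` would instead return a [num_matches, 3] Int tensor.
    mask.nonzero()
}
```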
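To make the two return layouts concrete, here is a plain-Rust sketch (standard library only, no Burn) of what `argwhere()` and `nonzero()` report for the mask above. It hard-codes column 3 of the example tensor; `argwhere_3d` and `nonzero_3d` are hypothetical helpers that only mimic the shape of the output, not Burn's implementation.

```rust
/// Coordinates of every `true` entry of a 3-D mask, in row-major order.
/// Mimics the layout of `argwhere()`: one `[d0, d1, d2]` entry per match.
fn argwhere_3d(mask: &[Vec<Vec<bool>>]) -> Vec<[usize; 3]> {
    let mut coords = Vec::new();
    for (i, plane) in mask.iter().enumerate() {
        for (j, row) in plane.iter().enumerate() {
            for (k, &m) in row.iter().enumerate() {
                if m {
                    coords.push([i, j, k]);
                }
            }
        }
    }
    coords
}

/// The same matches split per axis. Mimics the layout of `nonzero()`:
/// one index list per dimension, all of equal length.
fn nonzero_3d(mask: &[Vec<Vec<bool>>]) -> [Vec<usize>; 3] {
    let mut out: [Vec<usize>; 3] = [Vec::new(), Vec::new(), Vec::new()];
    for c in argwhere_3d(mask) {
        for (axis, &idx) in out.iter_mut().zip(c.iter()) {
            axis.push(idx);
        }
    }
    out
}

fn main() {
    // Column 3 of the example tensor, i.e. what t.slice([0..1, 0..10, 3..4]) keeps.
    let last_col: [f64; 10] = [
        0.7163, 0.6525, 0.3846, 0.9497, 0.3046, 0.1818, 0.8117, 0.3465, 0.0585, 0.7620,
    ];
    // A [1, 10, 1] mask of "value > 0.5", like greater_elem(0.5).
    let mask: Vec<Vec<Vec<bool>>> = vec![last_col.iter().map(|&v| vec![v > 0.5]).collect()];

    println!("argwhere: {:?}", argwhere_3d(&mask));
    println!("nonzero:  {:?}", nonzero_3d(&mask));
}
```

Running it prints `argwhere: [[0, 0, 0], [0, 1, 0], [0, 3, 0], [0, 6, 0], [0, 9, 0]]` and `nonzero:  [[0, 0, 0, 0, 0], [0, 1, 3, 6, 9], [0, 0, 0, 0, 0]]`. The last coordinate is 0, not 3, because indices are relative to the sliced mask. If the goal is to keep the rows of `t` whose last value exceeds 0.5, the second entry of `nonzero()` (the row indices) is the one you would pass to Burn's `Tensor::select` along dimension 1; that use is an assumption, since the original question isn't shown here.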

Answer selected by laggui