Hi, there seems to be a problem with keeping track of the number of dimensions when combining max_dim and gather. The following code panics with a complaint about the number of dimensions, but runs fine if the max_dim line is removed. This doesn't seem to be tied to any specific backend: I originally hit it with the tch backend.
To Reproduce
#[test]
fn whatthe() {
    let a: Vec<f32> = vec![0.0, 0.0];
    let b = [0, 0];
    let b: Tensor<Autodiff<NdArray>, 2, Int> =
        Tensor::from_data(Data::from(b.as_slice()), &NdArrayDevice::default()).reshape([2, 1]);
    let a = Tensor::from_data(Data::from(a.as_slice()), &NdArrayDevice::default())
        .reshape([2, 1])
        .require_grad();
    let loss = a.gather(1, b);
    let loss = loss.clone().max_dim(0) + loss; // No panic if this line is commented out
    let loss = loss.sum();
    let g = loss.backward();
}
When run, produces the following output:
---- tests::whatthe stdout ----
thread '<unnamed>' panicked at /home/michael/.cargo/git/checkouts/burn-178c6829f420dae1/886a1de/crates/burn-ndarray/src/ops/base.rs:442:17:
Unsupported dimension, only the last dimension can differ: Tensor [2, 1] Index [1, 1]
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
thread 'tests::whatthe' panicked at /home/michael/.cargo/git/checkouts/burn-178c6829f420dae1/886a1de/crates/burn-autodiff/src/runtime/mspc.rs:91:25:
Error during backward RecvError
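Here's my reading of that panic (a hedged shape walk-through, not taken from the burn internals), reusing the bindings from the test above:

let loss = a.gather(1, b);             // a: [2, 1], b: [2, 1]  ->  loss: [2, 1]
let reduced = loss.clone().max_dim(0); // dim 0 reduced but kept with size 1  ->  [1, 1]
let loss = reduced + loss;             // [1, 1] + [2, 1] broadcasts to [2, 1]
let g = loss.sum().backward();         // panics: the max_dim gradient seemingly gets routed
                                       // back into the [2, 1] tensor with a [1, 1] index,
                                       // and the ndarray kernel only allows the last
                                       // dimension to differ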
I was only able to reproduce the bug on the ndarray backend; it seems to work on the tch backend. You can see the test on the branch: fix/max_dim_gather
Hmm, odd that it didn't reproduce with tch! I had whittled down the example to be minimal, and indeed, that one doesn't cause a panic on tch for me either. Perhaps this is two bugs in a trenchcoat pretending to be one!
Here's a specific snippet which does crash for me:
let a: Vec<f32> = vec![-0.35060948, -0.6759874, -1.2398422, -0.55234957];
let b = [2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 2];
let b: Tensor<Autodiff<LibTorch>, 2, Int> =
    Tensor::from_data(Data::from(b.as_slice()), &LibTorchDevice::default()).reshape([5, 4]);
let a = Tensor::from_data(Data::from(a.as_slice()), &LibTorchDevice::default())
    .reshape([1, 4])
    .require_grad();
let grammar: Tensor<_, 2> = a.clone().repeat(0, 5);
let loss = grammar.gather(1, b);
let loss = loss.clone().max_dim(0) + loss;
let loss = loss.sum();
let g = loss.backward();
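For completeness, the shapes in that snippet (my annotations, derived from the reshape/repeat calls): unlike the minimal repro, this one also goes through repeat and a broadcasted add, which is perhaps what makes it trip tch when the smaller example doesn't.

let grammar: Tensor<_, 2> = a.clone().repeat(0, 5); // a: [1, 4] repeated along dim 0 -> [5, 4]
let loss = grammar.gather(1, b);                    // b: [5, 4], gather along dim 1  -> [5, 4]
let loss = loss.clone().max_dim(0) + loss;          // [1, 4] + [5, 4] broadcasts to [5, 4]
let g = loss.sum().backward();                      // panics on tch here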