
Panic w/ backwards pass when combining gather and max_dim #1687

Open
MichaelGoodale opened this issue Apr 24, 2024 · 2 comments
Labels: bug (Something isn't working)

Comments


MichaelGoodale commented Apr 24, 2024

Hi, there seems to be a problem with tracking the number of dimensions when combining max_dim and gather in some way. The following code panics with a complaint about the number of dimensions, but runs without issue if the max_dim line is removed. This also doesn't seem to be specific to any backend: I initially found it when using the tch backend.

To Reproduce

// Imports assume the `burn` crate with the `ndarray` and `autodiff` features enabled.
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::{Autodiff, NdArray};
use burn::tensor::{Data, Int, Tensor};

#[test]
fn whatthe() {
    let a: Vec<f32> = vec![0.0, 0.0];
    let b = [0, 0];
    // Index tensor of shape [2, 1].
    let b: Tensor<Autodiff<NdArray>, 2, Int> =
        Tensor::from_data(Data::from(b.as_slice()), &NdArrayDevice::default()).reshape([2, 1]);
    // Input tensor of shape [2, 1] that requires gradients.
    let a = Tensor::from_data(Data::from(a.as_slice()), &NdArrayDevice::default())
        .reshape([2, 1])
        .require_grad();

    let loss = a.gather(1, b);
    let loss = loss.clone().max_dim(0) + loss; // No panic if this line is commented out.
    let loss = loss.sum();
    let _grads = loss.backward();
}

When run, it produces the following output:

---- tests::whatthe stdout ----
thread '<unnamed>' panicked at /home/michael/.cargo/git/checkouts/burn-178c6829f420dae1/886a1de/crates/burn-ndarray/src/ops/base.rs:442:17:
Unsupported dimension, only the last dimension can differ: Tensor [2, 1] Index [1, 1]
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
thread 'tests::whatthe' panicked at /home/michael/.cargo/git/checkouts/burn-178c6829f420dae1/886a1de/crates/burn-autodiff/src/runtime/mspc.rs:91:25:
Error during backward RecvError
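
Reading the panic message, here is a shape trace of the test above (my own inference, not something stated in the thread):

// Shape trace (inferred from the panic message, not confirmed in the thread):
let gathered = a.gather(1, b);           // a: [2, 1], b: [2, 1] -> [2, 1]
let maxed = gathered.clone().max_dim(0); // reduce dim 0 -> [1, 1]
let loss = (maxed + gathered).sum();     // broadcast add -> [2, 1], then scalar
// During the backward pass, a gather/scatter base op in burn-ndarray appears
// to receive a [2, 1] tensor with a [1, 1] index ("Tensor [2, 1] Index [1, 1]"),
// and that op only lets the *last* dimension of tensor and index differ, so
// the mismatch in dim 0 (2 vs. 1) trips the check in ops/base.rs.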

@antimora added the bug (Something isn't working) label on Apr 24, 2024
nathanielsimard (Member) commented Apr 25, 2024

I was only able to reproduce the bug on the ndarray backend; it seems to work on the tch backend. You can see the test on the branch: fix/max_dim_gather

MichaelGoodale (Author) commented

Hmm, odd that it didn't reproduce with tch! I had whittled down the example to be minimal, and indeed, that one doesn't cause a panic on tch for me either. Perhaps this is two bugs in a trenchcoat pretending to be one!

Here's a specific snippet which does crash for me:

    // Same setup as the first test, but on the LibTorch (tch) backend
    // (using the LibTorch/LibTorchDevice equivalents of the imports above),
    // with a repeat() to broadcast `a` before the gather.
    let a: Vec<f32> = vec![-0.35060948, -0.6759874, -1.2398422, -0.55234957];
    let b = [2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 2];
    // Index tensor of shape [5, 4].
    let b: Tensor<Autodiff<LibTorch>, 2, Int> =
        Tensor::from_data(Data::from(b.as_slice()), &LibTorchDevice::default()).reshape([5, 4]);
    // Input tensor of shape [1, 4] that requires gradients.
    let a = Tensor::from_data(Data::from(a.as_slice()), &LibTorchDevice::default())
        .reshape([1, 4])
        .require_grad();
    // Broadcast to [5, 4] before gathering along dim 1.
    let grammar: Tensor<_, 2> = a.clone().repeat(0, 5);
    let loss = grammar.gather(1, b);
    let loss = loss.clone().max_dim(0) + loss;
    let loss = loss.sum();
    let _grads = loss.backward();
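
Not from the thread, but as a possible interim workaround while this is open: if gradients through the max branch itself are not needed, detaching that branch should skip the failing backward path. An untested sketch:

    // Hypothetical, untested workaround: detach the max_dim branch so no
    // gradient flows back through max_dim; gradients then reach `a` only
    // through the plain `loss` term of the sum.
    let loss = loss.clone().max_dim(0).detach() + loss;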
