Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Squeeze Onnx Import #1753

Merged
merged 17 commits into from May 17, 2024
Merged

Conversation

agelas
Copy link
Contributor

@agelas agelas commented May 11, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#1714

Changes

Added the squeeze operation for to burn-import

Testing

  • onnx_tests
  • codegen_nodes test

Copy link

codecov bot commented May 11, 2024

Codecov Report

Attention: Patch coverage is 88.31169% with 18 lines in your changes are missing coverage. Please review.

Project coverage is 86.41%. Comparing base (1073752) to head (3e60613).
Report is 7 commits behind head on main.

Files Patch % Lines
crates/burn-import/src/onnx/dim_inference.rs 65.51% 10 Missing ⚠️
crates/burn-import/src/onnx/op_configuration.rs 77.14% 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1753      +/-   ##
==========================================
- Coverage   86.61%   86.41%   -0.21%     
==========================================
  Files         700      735      +35     
  Lines       83423    85729    +2306     
==========================================
+ Hits        72258    74083    +1825     
- Misses      11165    11646     +481     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@agelas agelas marked this pull request as ready for review May 13, 2024 02:08
Copy link
Collaborator

@antimora antimora left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for completing a missing ONNX op in Burn. One down!

You've got almost perfect. We just need a couple fixes:

  1. Add test case to verify axes by adding ONNX OpSet 13 (please see my inlined comments)
  2. Account for negative values in axes.

Comment on lines 977 to 979
match key.as_str() {
"axes" => return value.clone().into_i64s(),
_ => {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! This will support ONNX OPset 13

Can we, to be sure it works, add a unit test for OPset 13? This is similar to unsqueeze with opset 16 and 13.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsqueeze uses OPset 16 and 11. I think my original one used 16, so I added one for 13. Is that what you meant?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, both should be good. OpSet 13 has axes attribute so it should work.

}
}
_ => panic!("Arg for squeeze must be tensor or scalar"),
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axes can contain negative values, which means counting dimensions from the back (see squeeze spec), and Burn squeeze only supports positive values (see doc). So we should account for this.

We are already doing this for one dimension (see code) for gather_config, so you can see the logic. We need to do this for all items in axes.

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution! I have the same comments as already mentioned, and one question/comment regarding the squeeze op spec support.

This PR adds support for squeeze op on a single dim, so support is not complete but as long as it is explicit I have no issues adding full support if needed in another PR.

crates/burn-import/src/onnx/dim_inference.rs Outdated Show resolved Hide resolved
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes! I have two minor comments and then it should be good to go! 🙂

crates/burn-import/onnx-tests/tests/squeeze/squeeze.onnx Outdated Show resolved Hide resolved
Comment on lines 272 to 278
if let Some(Data::Int64s(axes)) = &node.inputs[1].value {
if axes.len() != 1 {
panic!("Squeeze: Only one axis should be specified for squeezing.");
}
} else {
panic!("Squeeze: Axes input must be an integer list.");
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're checking for axes in the attributes in squeeze_config to support another opset we need to check here too to make sure the output is properly adjusted.

See the unsqueeze_update_output function which is doing something similar (though it captures the axes to support unsqueeze on multiple axes).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, should be a little closer to what you had in mind now!

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Thanks for making the changes. I only have one minor comment left (we lost a check with the last committed change and I think we should keep it).

Should be good to go after that!

crates/burn-import/src/onnx/dim_inference.rs Show resolved Hide resolved
Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for going through the changes!

LGTM 🚀

@nathanielsimard nathanielsimard merged commit 9c5b07c into tracel-ai:main May 17, 2024
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants