-
Notifications
You must be signed in to change notification settings - Fork 434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Torch] Fix bugs for Torch::AtenOneHotOp
#3350
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we able to write LIT tests for these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@qedawkins can you take another look and merge if everything's addressed? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Looks like the CI is failing for some reason. I can help merge after that's fixed. |
automerge enabled. Fix the CI and we're good to go! |
d831391
to
7415983
Compare
This commit fixes bugs for the onnx.OneHot operator by: 1) Converting negative indices to non-negative indices 2) Handling both int and float types for off and on values 3) Using the correct result type It also includes a new unit test.
This commit fixes the bugs for Torch::AtenOneHotOp by: 1. Using Torch::kUnknownSize as the default value for numClasses in the pattern matching stage for DecomposeAtenOneHotOp 2. Adding AtenIntScalarOp to the patterns in TorchToArith 3. Handling both int and float types for off and on values in ONNX to Torch conversion It also includes: 1. A new test in TorchToArith/basic.mlir, for torch.aten.Int.Scalar, and 2. A new test in decompose-complex-ops.mlir, for torch.aten.one_hot
7415983
to
4c58e11
Compare
This PR fixes the bugs for `Torch::AtenOneHotOp` by: 1) Using `Torch::kUnknownSize` as the default value for `numClasses` in the pattern matching stage in `DecomposeAtenOneHotOp` 2) Adding `AtenIntScalarOp` to the patterns in `TorchToArith` 3) Handling both `int` and `float` types for `off` and `on` values in `TorchOnnxToTorch` conversion It also includes: 1) A new test in `TorchToArith/basic.mlir`, for `torch.aten.Int.Scalar`, and 2) A new test in `decompose-complex-ops.mlir`, for `torch.aten.one_hot` **Dependencies** This PR is dependent on llvm#3334.
This PR fixes the bugs for
Torch::AtenOneHotOp
by:Torch::kUnknownSize
as the default value fornumClasses
inthe pattern matching stage in
DecomposeAtenOneHotOp
AtenIntScalarOp
to the patterns inTorchToArith
int
andfloat
types foroff
andon
values inTorchOnnxToTorch
conversionIt also includes:
TorchToArith/basic.mlir
, fortorch.aten.Int.Scalar
, anddecompose-complex-ops.mlir
, fortorch.aten.one_hot
Dependencies
This PR is dependent on #3334.