diff --git a/include/Analysis/RotationAnalysis/RotationAnalysis.h b/include/Analysis/RotationAnalysis/RotationAnalysis.h index 3ddb0bb1d..66900acd8 100644 --- a/include/Analysis/RotationAnalysis/RotationAnalysis.h +++ b/include/Analysis/RotationAnalysis/RotationAnalysis.h @@ -175,10 +175,10 @@ class RotationSets { private: /// The accessed indices of a single SSA value of tensor type. Value tensor; - // FIXME: there is likely a data structure that can more efficiently - // represent a set of intervals of integers, which properly merges adjacent - // intervals as values are added. Java/Guava has RangeSet, and boost has - // interval_set. Otherwise might want to roll our own. + + // There is likely a data structure that can more efficiently represent a set + // of intervals of integers, which properly merges adjacent intervals as + // values are added. Java/Guava has RangeSet, and boost has interval_set. std::unordered_set accessedIndices; Status status = Status::Uninitialized; }; diff --git a/tests/tensor_ext/rotate_and_reduce.mlir b/tests/tensor_ext/rotate_and_reduce.mlir index 0694e6f43..3cf00c190 100644 --- a/tests/tensor_ext/rotate_and_reduce.mlir +++ b/tests/tensor_ext/rotate_and_reduce.mlir @@ -369,3 +369,80 @@ func.func @sum_of_linear_rotates(%arg0: !secret.secret>) -> !secr } -> !secret.secret return %0 : !secret.secret } + +// CHECK-LABEL: @rotate_not_applied_because_rotation_missing +// CHECK-COUNT-3: tensor_ext.rotate +func.func @rotate_not_applied_because_rotation_missing(%arg0: !secret.secret>) -> !secret.secret { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = secret.generic ins(%arg0 : !secret.secret>) { + ^bb0(%arg1: tensor<4xi16>): + %1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index + %2 = arith.addi %1, %arg1 : tensor<4xi16> + %3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index + %4 = arith.addi %2, %3 : tensor<4xi16> + // To make the rotation apply, replace %5 with this line + // %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index + %5 = tensor_ext.rotate %3, %c2 : tensor<4xi16>, index + %6 = arith.addi %4, %5 : tensor<4xi16> + %extracted = tensor.extract %6[%c0] : tensor<4xi16> + secret.yield %extracted : i16 + } -> !secret.secret + return %0 : !secret.secret +} + +// CHECK-LABEL: @rotate_not_applied_because_rotation_duplicated +// CHECK-COUNT-3: tensor_ext.rotate +func.func @rotate_not_applied_because_rotation_duplicated(%arg0: !secret.secret>) -> !secret.secret { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = secret.generic ins(%arg0 : !secret.secret>) { + ^bb0(%arg1: tensor<4xi16>): + %1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index + %2 = arith.addi %1, %arg1 : tensor<4xi16> + %3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index + %4 = arith.addi %2, %3 : tensor<4xi16> + // To return to normal, replace %v4_2 with %4 + %v4_2 = arith.addi %4, %3 : tensor<4xi16> + %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index + %6 = arith.addi %v4_2, %5 : tensor<4xi16> + %extracted = tensor.extract %6[%c1] : tensor<4xi16> + secret.yield %extracted : i16 + } -> !secret.secret + return %0 : !secret.secret +} + +// CHECK-LABEL: @rotate_not_applied_because_multiple_tensors +// CHECK-COUNT-3: tensor_ext.rotate +func.func @rotate_not_applied_because_multiple_tensors( + %arg0 : tensor<4xi16>, %arg1 : tensor<4xi16>) -> i16 { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index + %2 = arith.addi %1, %arg1 : tensor<4xi16> + %3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index + %4 = arith.addi %2, %3 : tensor<4xi16> + // To return to normal, replace %v4_2 with %4 + %v4_2 = arith.addi %4, %arg0 : tensor<4xi16> + %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index + %6 = arith.addi %v4_2, %5 : tensor<4xi16> + %extracted = tensor.extract %6[%c1] : tensor<4xi16> + return %extracted : i16 +} + +// CHECK-LABEL: @rotate_not_applied_because_mixed_ops +// CHECK-COUNT-3: tensor_ext.rotate +func.func @rotate_not_applied_because_mixed_ops(%arg1 : tensor<4xi16>) -> i16 { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index + %2 = arith.addi %1, %arg1 : tensor<4xi16> + %3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index + // To return to normal, replace muli with addi + %4 = arith.muli %2, %3 : tensor<4xi16> + %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index + %6 = arith.addi %4, %5 : tensor<4xi16> + %extracted = tensor.extract %6[%c1] : tensor<4xi16> + return %extracted : i16 +}