Skip to content

Commit

Permalink
improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Mar 18, 2024
1 parent 980f205 commit 52f14c7
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 4 deletions.
8 changes: 4 additions & 4 deletions include/Analysis/RotationAnalysis/RotationAnalysis.h
Expand Up @@ -175,10 +175,10 @@ class RotationSets {
private:
/// The accessed indices of a single SSA value of tensor type.
Value tensor;
// FIXME: there is likely a data structure that can more efficiently
// represent a set of intervals of integers, which properly merges adjacent
// intervals as values are added. Java/Guava has RangeSet, and boost has
// interval_set. Otherwise might want to roll our own.

// There is likely a data structure that can more efficiently represent a set
// of intervals of integers, which properly merges adjacent intervals as
// values are added. Java/Guava has RangeSet, and boost has interval_set.
std::unordered_set<int64_t> accessedIndices;
Status status = Status::Uninitialized;
};
Expand Down
77 changes: 77 additions & 0 deletions tests/tensor_ext/rotate_and_reduce.mlir
Expand Up @@ -369,3 +369,80 @@ func.func @sum_of_linear_rotates(%arg0: !secret.secret<tensor<32xi16>>) -> !secr
} -> !secret.secret<i16>
return %0 : !secret.secret<i16>
}

// CHECK-LABEL: @rotate_not_applied_because_rotation_missing
// CHECK-COUNT-3: tensor_ext.rotate
func.func @rotate_not_applied_because_rotation_missing(%arg0: !secret.secret<tensor<4xi16>>) -> !secret.secret<i16> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%0 = secret.generic ins(%arg0 : !secret.secret<tensor<4xi16>>) {
^bb0(%arg1: tensor<4xi16>):
%1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index
%2 = arith.addi %1, %arg1 : tensor<4xi16>
%3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index
%4 = arith.addi %2, %3 : tensor<4xi16>
// To make the rotation apply, replace %5 with this line
// %5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index
%5 = tensor_ext.rotate %3, %c2 : tensor<4xi16>, index
%6 = arith.addi %4, %5 : tensor<4xi16>
%extracted = tensor.extract %6[%c0] : tensor<4xi16>
secret.yield %extracted : i16
} -> !secret.secret<i16>
return %0 : !secret.secret<i16>
}

// CHECK-LABEL: @rotate_not_applied_because_rotation_duplicated
// CHECK-COUNT-3: tensor_ext.rotate
func.func @rotate_not_applied_because_rotation_duplicated(%arg0: !secret.secret<tensor<4xi16>>) -> !secret.secret<i16> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%0 = secret.generic ins(%arg0 : !secret.secret<tensor<4xi16>>) {
^bb0(%arg1: tensor<4xi16>):
%1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index
%2 = arith.addi %1, %arg1 : tensor<4xi16>
%3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index
%4 = arith.addi %2, %3 : tensor<4xi16>
// To return to normal, replace %v4_2 with %4
%v4_2 = arith.addi %4, %3 : tensor<4xi16>
%5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index
%6 = arith.addi %v4_2, %5 : tensor<4xi16>
%extracted = tensor.extract %6[%c1] : tensor<4xi16>
secret.yield %extracted : i16
} -> !secret.secret<i16>
return %0 : !secret.secret<i16>
}

// CHECK-LABEL: @rotate_not_applied_because_multiple_tensors
// CHECK-COUNT-3: tensor_ext.rotate
func.func @rotate_not_applied_because_multiple_tensors(
%arg0 : tensor<4xi16>, %arg1 : tensor<4xi16>) -> i16 {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index
%2 = arith.addi %1, %arg1 : tensor<4xi16>
%3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index
%4 = arith.addi %2, %3 : tensor<4xi16>
// To return to normal, replace %v4_2 with %4
%v4_2 = arith.addi %4, %arg0 : tensor<4xi16>
%5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index
%6 = arith.addi %v4_2, %5 : tensor<4xi16>
%extracted = tensor.extract %6[%c1] : tensor<4xi16>
return %extracted : i16
}

// CHECK-LABEL: @rotate_not_applied_because_mixed_ops
// CHECK-COUNT-3: tensor_ext.rotate
func.func @rotate_not_applied_because_mixed_ops(%arg1 : tensor<4xi16>) -> i16 {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%1 = tensor_ext.rotate %arg1, %c1 : tensor<4xi16>, index
%2 = arith.addi %1, %arg1 : tensor<4xi16>
%3 = tensor_ext.rotate %1, %c1 : tensor<4xi16>, index
// To return to normal, replace muli with addi
%4 = arith.muli %2, %3 : tensor<4xi16>
%5 = tensor_ext.rotate %3, %c1 : tensor<4xi16>, index
%6 = arith.addi %4, %5 : tensor<4xi16>
%extracted = tensor.extract %6[%c1] : tensor<4xi16>
return %extracted : i16
}

0 comments on commit 52f14c7

Please sign in to comment.