// TensorExtOps.td - Operation definitions for the TensorExt dialect.
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
include "include/Dialect/TensorExt/IR/TensorExtDialect.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Common base class for operations of the TensorExt dialect. It forwards the
// mnemonic and trait list to the generic `Op` class, bound to
// `TensorExt_Dialect`, and sets the C++ namespace of derived ops to
// `::mlir::heir::tensor_ext`.
class TensorExt_Op<string mnemonic, list<Trait> traits = []>
    : Op<TensorExt_Dialect, mnemonic, traits> {
  let cppNamespace = "::mlir::heir::tensor_ext";
}
// `tensor_ext.rotate`: rotate a tensor by a shift amount supplied as an SSA
// operand.
//
// Traits:
//   * Pure: the op is declared to be free of side effects.
//   * AllTypesMatch<["tensor", "output"]>: the result type must equal the type
//     of the input tensor, which is why the assembly format below only spells
//     out the tensor type and the shift type.
def TensorExt_RotateOp
    : TensorExt_Op<"rotate", [
        Pure,
        AllTypesMatch<["tensor", "output"]>
      ]> {
  let summary = "Rotate a tensor some number of indices left.";

  // The description body and its closing delimiter are intentionally kept
  // flush-left so the text of the string literal stays exactly as it was.
  let description = [{
This op represents a left-rotation of a tensor by given number of indices.
Negative shift values are interpreted as right-rotations.
This corresponds to the `rotate` operation in arithmetic FHE schemes like
BGV.
Examples:
```mlir
%0 = ... : tensor<16xi32>
%c7 = arith.constant 7 : i32
%1 = tensor_ext.rotate %0, %c7 : tensor<16xi32>, i32
```
}];

  // Operands: the tensor to rotate and the rotation amount.
  let arguments = (ins
    AnyTensor:$tensor,
    SignlessIntegerLike:$shift
  );

  // Single result: the rotated tensor.
  let results = (outs AnyTensor:$output);

  // Textual form: `tensor_ext.rotate %t, %s : <tensor type>, <shift type>`.
  let assemblyFormat = "operands attr-dict `:` qualified(type($tensor)) `,` type($shift)";
}
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_