// Forked from google/heir: straight_line_vectorizer.mlir
// RUN: heir-opt --straight-line-vectorize %s | FileCheck %s
// Each function below is run through the straight-line vectorizer and the
// printed result is matched against the FileCheck directives in that function.
// Shared aliases: all ciphertexts and plaintexts in this file use an
// unspecified bit-field encoding with a 3-bit cleartext width.
#encoding = #lwe.unspecified_bit_field_encoding<cleartext_bitwidth = 3>
!ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
!pt_ty = !lwe.lwe_plaintext<encoding = #encoding>
// CHECK-LABEL: add_one
// Eight lut3 results (%5, %11, %13, %14, %15, %16, %17 and %18) are consumed
// only by the final tensor.from_elements. Six of them use lookup_table = 105;
// %5 uses 150 and %14 uses 6. A vectorized lut3 carries a single
// lookup_table attribute, so an 8-wide op with table 105 would have to absorb
// %5 and %14 and change their results. Expect only the six table-105 ops to
// be packed together.
// NOTE(review): the expected width assumes the pass places these terminal
// ops in the same level; confirm against the pass implementation.
// CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 105 : ui8} : tensor<6x!lwe.lwe_ciphertext
func.func @add_one(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> {
  %true = arith.constant true
  %false = arith.constant false
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %c7 = arith.constant 7 : index
  // Unpack the eight input ciphertexts.
  %extracted = tensor.extract %arg0[%c0] : tensor<8x!ct_ty>
  %extracted_0 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty>
  %extracted_1 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty>
  %extracted_2 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty>
  %extracted_3 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty>
  %extracted_4 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty>
  %extracted_5 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty>
  %extracted_6 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty>
  // Trivially encrypted constants: %1 holds true, %3 holds false.
  %0 = lwe.encode %true {encoding = #encoding} : i1 to !pt_ty
  %1 = lwe.trivial_encrypt %0 : !pt_ty to !ct_ty
  %2 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty
  %3 = lwe.trivial_encrypt %2 : !pt_ty to !ct_ty
  // Chain of lut3 ops. %4, %6, %7, %8, %9, %10 and %12 feed later lut3s;
  // the rest are consumed only by the returned tensor.
  %4 = cggi.lut3(%extracted, %1, %3) {lookup_table = 8 : ui8} : !ct_ty
  %5 = cggi.lut3(%4, %extracted_0, %3) {lookup_table = 150 : ui8} : !ct_ty
  %6 = cggi.lut3(%4, %extracted_0, %3) {lookup_table = 23 : ui8} : !ct_ty
  %7 = cggi.lut3(%6, %extracted_1, %3) {lookup_table = 43 : ui8} : !ct_ty
  %8 = cggi.lut3(%7, %extracted_2, %3) {lookup_table = 43 : ui8} : !ct_ty
  %9 = cggi.lut3(%8, %extracted_3, %3) {lookup_table = 43 : ui8} : !ct_ty
  %10 = cggi.lut3(%9, %extracted_4, %3) {lookup_table = 43 : ui8} : !ct_ty
  %11 = cggi.lut3(%10, %extracted_5, %3) {lookup_table = 105 : ui8} : !ct_ty
  %12 = cggi.lut3(%10, %extracted_5, %3) {lookup_table = 43 : ui8} : !ct_ty
  %13 = cggi.lut3(%12, %extracted_6, %3) {lookup_table = 105 : ui8} : !ct_ty
  %14 = cggi.lut3(%extracted, %1, %3) {lookup_table = 6 : ui8} : !ct_ty
  %15 = cggi.lut3(%6, %extracted_1, %3) {lookup_table = 105 : ui8} : !ct_ty
  %16 = cggi.lut3(%7, %extracted_2, %3) {lookup_table = 105 : ui8} : !ct_ty
  %17 = cggi.lut3(%8, %extracted_3, %3) {lookup_table = 105 : ui8} : !ct_ty
  %18 = cggi.lut3(%9, %extracted_4, %3) {lookup_table = 105 : ui8} : !ct_ty
  %from_elements = tensor.from_elements %13, %11, %18, %17, %16, %15, %5, %14 : tensor<8x!ct_ty>
  return %from_elements : tensor<8x!ct_ty>
}
// CHECK-LABEL: require_post_pass_toposort
func.func @require_post_pass_toposort(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> {
  %i0 = arith.constant 0 : index
  %i1 = arith.constant 1 : index
  %i2 = arith.constant 2 : index
  %i3 = arith.constant 3 : index
  %i4 = arith.constant 4 : index
  %i5 = arith.constant 5 : index
  %i6 = arith.constant 6 : index
  %i7 = arith.constant 7 : index
  // Split the input tensor into its eight scalar ciphertexts.
  %b0 = tensor.extract %arg0[%i0] : tensor<8x!ct_ty>
  %b1 = tensor.extract %arg0[%i1] : tensor<8x!ct_ty>
  %b2 = tensor.extract %arg0[%i2] : tensor<8x!ct_ty>
  %b3 = tensor.extract %arg0[%i3] : tensor<8x!ct_ty>
  %b4 = tensor.extract %arg0[%i4] : tensor<8x!ct_ty>
  %b5 = tensor.extract %arg0[%i5] : tensor<8x!ct_ty>
  %b6 = tensor.extract %arg0[%i6] : tensor<8x!ct_ty>
  %b7 = tensor.extract %arg0[%i7] : tensor<8x!ct_ty>
  // First batch: four lut3s reading only extracted inputs, all with table 8.
  %lut_a = cggi.lut3(%b0, %b1, %b2) {lookup_table = 8 : ui8} : !ct_ty
  %lut_b = cggi.lut3(%b3, %b4, %b5) {lookup_table = 8 : ui8} : !ct_ty
  %lut_c = cggi.lut3(%b4, %b5, %b6) {lookup_table = 8 : ui8} : !ct_ty
  %lut_d = cggi.lut3(%b5, %b6, %b7) {lookup_table = 8 : ui8} : !ct_ty
  // Scalar consumer of %lut_d, placed textually between the two batches.
  %neg_d = cggi.not %lut_d : !ct_ty
  // Second batch: four more table-8 lut3s, again reading only extracted
  // inputs. %lut_h is computed but not part of the returned tensor.
  %lut_e = cggi.lut3(%b0, %b3, %b1) {lookup_table = 8 : ui8} : !ct_ty
  %lut_f = cggi.lut3(%b2, %b5, %b6) {lookup_table = 8 : ui8} : !ct_ty
  %lut_g = cggi.lut3(%b7, %b1, %b6) {lookup_table = 8 : ui8} : !ct_ty
  %lut_h = cggi.lut3(%b3, %b6, %b0) {lookup_table = 8 : ui8} : !ct_ty
  // Once all eight lut3s are packed into one op, that op has to be ordered
  // ahead of the cggi.not that reads one of its lanes.
  // CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 8 : ui8} : tensor<8x!lwe.lwe_ciphertext
  // CHECK: cggi.not
  %packed = tensor.from_elements %lut_a, %lut_b, %lut_c, %lut_d, %lut_e, %lut_f, %lut_g, %neg_d : tensor<8x!ct_ty>
  return %packed : tensor<8x!ct_ty>
}
// CHECK-LABEL: transitive_dep_splits_level
func.func @transitive_dep_splits_level(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %c7 = arith.constant 7 : index
  %0 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty>
  %1 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty>
  %2 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty>
  %3 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty>
  %4 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty>
  %5 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty>
  %6 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty>
  %7 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty>
  // Four ops that can be vectorized
  %r1 = cggi.lut3(%0, %1, %2) {lookup_table = 8 : ui8} : !ct_ty
  %r2 = cggi.lut3(%3, %4, %5) {lookup_table = 8 : ui8} : !ct_ty
  %r3 = cggi.lut3(%4, %5, %6) {lookup_table = 8 : ui8} : !ct_ty
  %r4 = cggi.lut3(%5, %6, %7) {lookup_table = 8 : ui8} : !ct_ty
  // Four cggi.not ops, each consuming a different result of the first batch
  // (%n1 <- %r1, ..., %n4 <- %r4). They are expected to stay scalar.
  %n1 = cggi.not %r1 : !ct_ty
  %n2 = cggi.not %r2 : !ct_ty
  %n3 = cggi.not %r3 : !ct_ty
  %n4 = cggi.not %r4 : !ct_ty
  // Four more ops that can be vectorized. Each reads one of the not results,
  // so it depends on the first batch only transitively, through a cggi.not.
  %r5 = cggi.lut3(%0, %n1, %1) {lookup_table = 8 : ui8} : !ct_ty
  %r6 = cggi.lut3(%2, %n2, %6) {lookup_table = 8 : ui8} : !ct_ty
  %r7 = cggi.lut3(%7, %n3, %6) {lookup_table = 8 : ui8} : !ct_ty
  %r8 = cggi.lut3(%3, %n4, %0) {lookup_table = 8 : ui8} : !ct_ty
  // The slice analysis must see those transitive dependencies and split the
  // lut3s into two levels of 4 ops each, with the four scalar cggi.not ops
  // between the two vectorized ops.
  // CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 8 : ui8} : tensor<4x!lwe.lwe_ciphertext
  // CHECK-COUNT-4: cggi.not
  // CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 8 : ui8} : tensor<4x!lwe.lwe_ciphertext
  %from_elements = tensor.from_elements %r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8 : tensor<8x!ct_ty>
  return %from_elements : tensor<8x!ct_ty>
}