From 312d56795dac6c033ee72b10796de13ea3d2f5bf Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 8 Mar 2024 14:56:35 -0800 Subject: [PATCH 01/16] implement first version of rotate-and-reduce --- include/Dialect/TensorExt/Transforms/BUILD | 5 +- include/Dialect/TensorExt/Transforms/Passes.h | 1 + .../Dialect/TensorExt/Transforms/Passes.td | 43 +++ .../TensorExt/Transforms/RotateAndReduce.h | 17 + lib/Dialect/TensorExt/Transforms/BUILD | 20 ++ .../TensorExt/Transforms/RotateAndReduce.cpp | 195 +++++++++++ tests/tensor_ext/rotate_and_reduce.mlir | 307 ++++++++++++++++++ 7 files changed, 586 insertions(+), 2 deletions(-) create mode 100644 include/Dialect/TensorExt/Transforms/RotateAndReduce.h create mode 100644 lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp create mode 100644 tests/tensor_ext/rotate_and_reduce.mlir diff --git a/include/Dialect/TensorExt/Transforms/BUILD b/include/Dialect/TensorExt/Transforms/BUILD index 574ad102f..5f2a908e8 100644 --- a/include/Dialect/TensorExt/Transforms/BUILD +++ b/include/Dialect/TensorExt/Transforms/BUILD @@ -1,4 +1,4 @@ -# InsertRotate tablegen and headers. +# TensorExt pass tablegen and headers. 
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") @@ -49,7 +49,8 @@ gentbl_cc_library( ) exports_files([ - "Passes.h", "CollapseInsertionChains.h", "InsertRotate.h", + "Passes.h", + "RotateAndReduce.h", ]) diff --git a/include/Dialect/TensorExt/Transforms/Passes.h b/include/Dialect/TensorExt/Transforms/Passes.h index 01eafff25..6e101454b 100644 --- a/include/Dialect/TensorExt/Transforms/Passes.h +++ b/include/Dialect/TensorExt/Transforms/Passes.h @@ -4,6 +4,7 @@ #include "include/Dialect/TensorExt/IR/TensorExtDialect.h" #include "include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h" #include "include/Dialect/TensorExt/Transforms/InsertRotate.h" +#include "include/Dialect/TensorExt/Transforms/RotateAndReduce.h" namespace mlir { namespace heir { diff --git a/include/Dialect/TensorExt/Transforms/Passes.td b/include/Dialect/TensorExt/Transforms/Passes.td index 88d396611..1b9f2b487 100644 --- a/include/Dialect/TensorExt/Transforms/Passes.td +++ b/include/Dialect/TensorExt/Transforms/Passes.td @@ -60,4 +60,47 @@ def CollapseInsertionChains : Pass<"collapse-insertion-chains"> { let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"]; } +def RotateAndReduce : Pass<"rotate-and-reduce"> { + let summary = "Use a logarithmic number of rotations to reduce a tensor."; + let description = [{ + This pass identifies when a commutative, associative binary operation is used + to reduce all of the entries of a tensor to a single value, and optimizes the + operations by using a logarithmic number of reduction operations. 
+ + In particular, this pass identifies an unrolled set of operations of the form + (the binary ops may come in any order): + + ```mlir + %0 = tensor.extract %t[0] : tensor<8xi32> + %1 = tensor.extract %t[1] : tensor<8xi32> + %2 = tensor.extract %t[2] : tensor<8xi32> + %3 = tensor.extract %t[3] : tensor<8xi32> + %4 = tensor.extract %t[4] : tensor<8xi32> + %5 = tensor.extract %t[5] : tensor<8xi32> + %6 = tensor.extract %t[6] : tensor<8xi32> + %7 = tensor.extract %t[7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + ``` + + and replaces it with a logarithmic number of `rotate` and `addi` operations: + + ```mlir + %0 = tensor_ext.rotate %t, 4 : tensor<8xi32> + %1 = arith.addi %t, %0 : tensor<8xi32> + %2 = tensor_ext.rotate %1, 2 : tensor<8xi32> + %3 = arith.addi %1, %2 : tensor<8xi32> + %4 = tensor_ext.rotate %3, 1 : tensor<8xi32> + %5 = arith.addi %3, %4 : tensor<8xi32> + ``` + }]; + let dependentDialects = ["mlir::heir::tensor_ext::TensorExtDialect"]; +} + + #endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_TD_ diff --git a/include/Dialect/TensorExt/Transforms/RotateAndReduce.h b/include/Dialect/TensorExt/Transforms/RotateAndReduce.h new file mode 100644 index 000000000..51605ec1e --- /dev/null +++ b/include/Dialect/TensorExt/Transforms/RotateAndReduce.h @@ -0,0 +1,17 @@ +#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ +#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace tensor_ext { + +#define GEN_PASS_DECL_ROTATEANDREDUCE +#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" + +} // namespace tensor_ext +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_ROTATEANDREDUCE_H_ diff 
--git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD index f06b7becd..7bc540508 100644 --- a/lib/Dialect/TensorExt/Transforms/BUILD +++ b/lib/Dialect/TensorExt/Transforms/BUILD @@ -11,6 +11,7 @@ cc_library( deps = [ ":CollapseInsertionChains", ":InsertRotate", + ":RotateAndReduce", "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//mlir:IR", @@ -54,3 +55,22 @@ cc_library( "@llvm-project//mlir:Transforms", ], ) + +cc_library( + name = "RotateAndReduce", + srcs = ["RotateAndReduce.cpp"], + hdrs = [ + "@heir//include/Dialect/TensorExt/Transforms:RotateAndReduce.h", + ], + deps = [ + "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp new file mode 100644 index 000000000..a747ef5ee --- /dev/null +++ b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp @@ -0,0 +1,195 @@ +#include "include/Dialect/TensorExt/Transforms/RotateAndReduce.h" + +#include "include/Dialect/TensorExt/IR/TensorExtOps.h" +#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/SliceAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/Iterators.h" // from 
@llvm-project
+#include "mlir/include/mlir/Support/LogicalResult.h"  // from @llvm-project
+
+#define DEBUG_TYPE "rotate-and-reduce"
+
+namespace mlir {
+namespace heir {
+namespace tensor_ext {
+
+#define GEN_PASS_DEF_ROTATEANDREDUCE
+#include "include/Dialect/TensorExt/Transforms/Passes.h.inc"
+
+/// A pass that searches for a length N sequence of binary operations that
+/// reduces a length N vector to a single scalar, and replaces it with a
+/// logarithmic number of rotations and binary operations.
+///
+/// The rewrite is only applied when N is a power of two (and at least 2):
+/// the shift sizes N/2, N/4, ..., 1 cover every slot exactly once only in
+/// that case.
+struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
+  using RotateAndReduceBase::RotateAndReduceBase;
+
+  /// Attempts to rewrite the reduction whose final (root) operation is `op`.
+  ///
+  /// The backward slice of `op` must consist solely of `ArithOp` operations,
+  /// `arith.constant` ops, and `tensor.extract` ops that read every index of
+  /// a single 1-D tensor exactly once via constant indices. On success the
+  /// uses of `op` are replaced by an extraction from the rotate-and-reduce
+  /// result, and every reduction op found in the slice is recorded in
+  /// `visited` so that the caller does not try to rewrite it again. On any
+  /// mismatch the IR is left untouched.
+  template <typename ArithOp>
+  void tryReplace(ArithOp op, DenseSet<Operation *> &visited) {
+    LLVM_DEBUG(llvm::dbgs() << "Trying to replace " << *op << "\n");
+    SetVector<Operation *> backwardSlice;
+    BackwardSliceOptions options;
+    // asserts that the parent op has a single region with a single block.
+    options.omitBlockArguments = false;
+
+    DenseSet<Value> inputTensors;
+    DenseSet<Operation *> visitedReductionOps;
+    DenseSet<int64_t> accessIndices;
+    DenseMap<llvm::StringRef, int> opCounts;
+    opCounts[op->getName().getStringRef()]++;
+
+    // Returns the first operand of `reductionOp` that is produced directly by
+    // an arith.constant, or a null Value if there is none.
+    auto findConstantOperand = [](Operation *reductionOp) -> Value {
+      for (Value operand : reductionOp->getOperands()) {
+        if (operand.getDefiningOp<arith::ConstantOp>()) {
+          return operand;
+        }
+      }
+      return Value();
+    };
+
+    // The root op is not a member of its own backward slice, so it has to be
+    // checked separately; otherwise a constant operand of the root would be
+    // silently dropped by the rewrite.
+    // TODO(#522): support these non-tensor-extract operands by saving the
+    // values, and applying them again to the final result.
+    if (Value rootConstant = findConstantOperand(op.getOperation())) {
+      LLVM_DEBUG(llvm::dbgs() << "Not replacing op because reduction "
+                                 "includes non-tensor value operands "
+                              << rootConstant << "\n");
+      return;
+    }
+
+    // TODO(#523): replace backward slice with a dataflow analysis
+    getBackwardSlice(op.getOperation(), &backwardSlice, options);
+    for (Operation *upstreamOpPtr : backwardSlice) {
+      auto result =
+          llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr)
+              .Case<arith::ConstantOp>(
+                  [&](auto upstreamOp) { return success(); })
+              .template Case<ArithOp>([&](auto upstreamOp) {
+                opCounts[upstreamOp->getName().getStringRef()]++;
+                // More than one reduction op is mixed in the reduction.
+                if (opCounts.size() > 1) {
+                  LLVM_DEBUG(llvm::dbgs()
+                             << "Not replacing op because reduction "
+                                "contains multiple incompatible ops "
+                             << op->getName() << " and "
+                             << upstreamOp->getName() << "\n");
+                  return failure();
+                }
+
+                // TODO(#522): support these non-tensor-extract operands by
+                // saving the values, and applying them again to the final
+                // result.
+                if (Value operand =
+                        findConstantOperand(upstreamOp.getOperation())) {
+                  LLVM_DEBUG(llvm::dbgs()
+                             << "Not replacing op because reduction "
+                                "includes non-tensor value operands "
+                             << operand << "\n");
+                  return failure();
+                }
+                visitedReductionOps.insert(upstreamOp);
+                return success();
+              })
+              .template Case<tensor::ExtractOp>([&](auto tensorOp) {
+                inputTensors.insert(tensorOp.getTensor());
+                if (inputTensors.size() > 1) {
+                  LLVM_DEBUG(
+                      llvm::dbgs()
+                      << "Not replacing op due to multiple input tensors\n");
+                  return failure();
+                }
+
+                // If the tensor is not 1D, we can't replace it with a rotate.
+                if (tensorOp.getIndices().size() != 1) {
+                  LLVM_DEBUG(llvm::dbgs()
+                             << "Not replacing op due to >1D input tensor\n");
+                  return failure();
+                }
+
+                // If the access index is not constant, we can't tell if we are
+                // reducing the entire vector (each index occurs exactly once in
+                // the reduction).
+                arith::ConstantOp indexConstant =
+                    tensorOp.getIndices()
+                        .front()
+                        .template getDefiningOp<arith::ConstantOp>();
+                if (!indexConstant) {
+                  LLVM_DEBUG(
+                      llvm::dbgs()
+                      << "Not replacing op due to non constant index access;"
+                      << " (do you need to run --canonicalize or --sccp?)\n");
+                  return failure();
+                }
+                int64_t accessIndex =
+                    indexConstant.getValue().cast<IntegerAttr>().getInt();
+
+                // If the access index was already seen, then fail because some
+                // tensor element contributes more than once to the reduction.
+                if (accessIndices.count(accessIndex)) {
+                  LLVM_DEBUG(
+                      llvm::dbgs()
+                      << "Not replacing op because input tensor was accessed "
+                         "multiple times in at same index\n");
+                  return failure();
+                }
+                LLVM_DEBUG(llvm::dbgs()
+                           << "Adding valid index " << accessIndex << "\n");
+                accessIndices.insert(accessIndex);
+                return success();
+              })
+              .Default([&](Operation *) { return failure(); });
+
+      if (failed(result)) {
+        return;
+      }
+    }
+
+    // A reduction whose slice never reads from a tensor (e.g. an op whose
+    // operands are all block arguments) is not a tensor reduction at all, and
+    // there is no input tensor to inspect below.
+    if (inputTensors.empty()) {
+      LLVM_DEBUG(llvm::dbgs() << "Not replacing op because no tensor.extract "
+                                 "feeds the reduction\n");
+      return;
+    }
+    Value inputTensor = *inputTensors.begin();
+
+    // The test for a match is now: does the number of accessed indices exactly
+    // match the size of the tensor? I.e., does each tensor element show up
+    // exactly once in the reduction?
+    auto tensorShape =
+        inputTensor.getType().cast<RankedTensorType>().getShape();
+    if (tensorShape.size() != 1 ||
+        tensorShape[0] != static_cast<int64_t>(accessIndices.size())) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Not replacing op because tensor shape ("
+                 << inputTensors.begin()->getType()
+                 << ") is not fully reduced. Only " << accessIndices.size()
+                 << " of " << tensorShape[0] << " indices were accessed\n");
+      return;
+    }
+
+    // Halving the shift size each round only accumulates every element into
+    // every slot when the length is a power of two; for any other length some
+    // elements would be missing from the result. A length below 2 would emit
+    // no reduction op at all.
+    int64_t numElements = tensorShape[0];
+    if (numElements < 2 || (numElements & (numElements - 1)) != 0) {
+      LLVM_DEBUG(llvm::dbgs() << "Not replacing op because tensor size "
+                              << numElements << " is not a power of two\n");
+      return;
+    }
+
+    // From here we know we will succeed.
+    auto b = ImplicitLocOpBuilder(op->getLoc(), op);
+    for (int64_t shiftSize = numElements / 2; shiftSize > 0; shiftSize /= 2) {
+      auto rotatedTensor = b.create<tensor_ext::RotateOp>(
+          inputTensor, b.create<arith::ConstantOp>(b.getIndexAttr(shiftSize)));
+      auto reducedTensor = b.create<ArithOp>(inputTensor, rotatedTensor);
+      // The partially reduced tensor is the input of the next round, and after
+      // the last round it holds the full reduction in every slot.
+      inputTensor = reducedTensor->getResult(0);
+    }
+
+    auto *parentOp = op->getParentOp();
+    // We can extract at any index; every index contains the same reduced value.
+    auto extractOp = b.create<tensor::ExtractOp>(
+        inputTensor, b.create<arith::ConstantIndexOp>(0).getResult());
+    op->replaceAllUsesWith(extractOp);
+    LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");
+
+    // Mark all ops in the reduction as visited so we don't try to replace them
+    // twice.
+    for (Operation *visitedOp : visitedReductionOps) {
+      visited.insert(visitedOp);
+    }
+  }
+
+  /// Walks the operation in reverse order and tries to rewrite every
+  /// arith.addi / arith.muli that has not already been consumed by an earlier
+  /// successful rewrite.
+  void runOnOperation() override {
+    DenseSet<Operation *> visited;
+    // Traverse the IR in reverse order so that we can eagerly compute backward
+    // slices for each operation.
+    getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
+        [&](Operation *op) {
+          if (visited.count(op)) {
+            return;
+          }
+          llvm::TypeSwitch<Operation &>(*op)
+              .Case<arith::AddIOp>([&](auto arithOp) {
+                tryReplace<arith::AddIOp>(arithOp, visited);
+              })
+              .Case<arith::MulIOp>([&](auto arithOp) {
+                tryReplace<arith::MulIOp>(arithOp, visited);
+              });
+        });
+  }
+};
+
+}  // namespace tensor_ext
+}  // namespace heir
+}  // namespace mlir
diff --git a/tests/tensor_ext/rotate_and_reduce.mlir b/tests/tensor_ext/rotate_and_reduce.mlir
new file mode 100644
index 000000000..126f70577
--- /dev/null
+++ b/tests/tensor_ext/rotate_and_reduce.mlir
@@ -0,0 +1,307 @@
+// RUN: heir-opt --rotate-and-reduce --canonicalize %s | FileCheck %s
+
+
+// Sum all entries of a tensor into a single scalar
+// CHECK-LABEL: @simple_sum
+// CHECK-SAME: (%[[arg0:.*]]: tensor<8xi32>
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[c1:.*]] = arith.constant 1
+// CHECK-NEXT: %[[c2:.*]] = arith.constant 2
+// CHECK-NEXT: %[[c4:.*]] = arith.constant 4
+// CHECK-NEXT: %[[v0:.*]] = tensor_ext.rotate %[[arg0]], %[[c4]]
+// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[arg0]], %[[v0]]
+// CHECK-NEXT: %[[v2:.*]] = tensor_ext.rotate %[[v1]], %[[c2]]
+// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v1]], %[[v2]]
+// CHECK-NEXT: %[[v4:.*]] = tensor_ext.rotate %[[v3]], %[[c1]]
+// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]]
+// CHECK-NEXT: %[[v6:.*]] = tensor.extract %[[v5]][%[[c0]]]
+// CHECK-NEXT: return %[[v6]]
+func.func @simple_sum(%arg0: tensor<8xi32>) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %0 = tensor.extract %arg0[%c0] : tensor<8xi32>
+  %1 = tensor.extract %arg0[%c1] : tensor<8xi32>
+  %2 = tensor.extract %arg0[%c2] : tensor<8xi32>
+  %3 = tensor.extract %arg0[%c3] : tensor<8xi32>
+  %4 = tensor.extract %arg0[%c4] : tensor<8xi32>
+  %5 = tensor.extract %arg0[%c5] :
tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + return %14 : i32 +} + +// CHECK-LABEL: @not_supported_mixed_ops +// CHECK-NOT: tensor_ext.rotate +func.func @not_supported_mixed_ops(%arg0: tensor<8xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + %2 = tensor.extract %arg0[%c2] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + %4 = tensor.extract %arg0[%c4] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.muli %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + return %14 : i32 +} + +// CHECK-LABEL: @not_supported_missing_indices +// CHECK-NOT: tensor_ext.rotate +func.func @not_supported_missing_indices(%arg0: tensor<16xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c9 = arith.constant 9 : index + %c10 = arith.constant 10 : index + %c11 = arith.constant 11 : index + %c12 = arith.constant 12 : index 
+ %c13 = arith.constant 13 : index + %c14 = arith.constant 14 : index + %0 = tensor.extract %arg0[%c0] : tensor<16xi32> + %1 = tensor.extract %arg0[%c1] : tensor<16xi32> + %2 = tensor.extract %arg0[%c2] : tensor<16xi32> + %3 = tensor.extract %arg0[%c3] : tensor<16xi32> + %4 = tensor.extract %arg0[%c4] : tensor<16xi32> + %5 = tensor.extract %arg0[%c5] : tensor<16xi32> + %6 = tensor.extract %arg0[%c6] : tensor<16xi32> + %7 = tensor.extract %arg0[%c7] : tensor<16xi32> + %8 = tensor.extract %arg0[%c8] : tensor<16xi32> + %9 = tensor.extract %arg0[%c9] : tensor<16xi32> + %10 = tensor.extract %arg0[%c10] : tensor<16xi32> + %11 = tensor.extract %arg0[%c11] : tensor<16xi32> + %12 = tensor.extract %arg0[%c12] : tensor<16xi32> + %13 = tensor.extract %arg0[%c13] : tensor<16xi32> + %14 = tensor.extract %arg0[%c14] : tensor<16xi32> + // missing element 15 + %v1 = arith.addi %0, %1 : i32 + %v2 = arith.addi %v1, %2 : i32 + %v3 = arith.addi %v2, %3 : i32 + %v4 = arith.addi %v3, %4 : i32 + %v5 = arith.addi %v4, %5 : i32 + %v6 = arith.addi %v5, %6 : i32 + %v7 = arith.addi %v6, %7 : i32 + %v8 = arith.addi %v7, %8 : i32 + %v9 = arith.addi %v8, %9 : i32 + %v10 = arith.addi %v9, %10 : i32 + %v11 = arith.addi %v10, %11 : i32 + %v12 = arith.addi %v11, %12 : i32 + %v13 = arith.addi %v12, %13 : i32 + %v14 = arith.addi %v13, %14 : i32 + return %v14 : i32 +} + +// CHECK-LABEL: @not_supported_repeated_indices +// CHECK-NOT: tensor_ext.rotate +func.func @not_supported_repeated_indices(%arg0: tensor<8xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + %2 = tensor.extract %arg0[%c2] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + // repeats element 3 + %4 = 
tensor.extract %arg0[%c3] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + return %14 : i32 +} + +// CHECK-LABEL: @not_supported_unsupported_op +// CHECK-NOT: tensor_ext.rotate +func.func @not_supported_unsupported_op(%arg0: tensor<8xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + %2 = tensor.extract %arg0[%c2] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + %4 = tensor.extract %arg0[%c4] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.subi %0, %1 : i32 + %9 = arith.subi %8, %2 : i32 + %10 = arith.subi %9, %3 : i32 + %11 = arith.subi %10, %4 : i32 + %12 = arith.subi %11, %5 : i32 + %13 = arith.subi %12, %6 : i32 + %14 = arith.subi %13, %7 : i32 + return %14 : i32 +} + +// 2D tensor not supported +// CHECK-LABEL: @not_supported_bad_tensor_shape +// CHECK-NOT: tensor_ext.rotate +func.func @not_supported_bad_tensor_shape(%arg0: tensor<1x8xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %0 = tensor.extract %arg0[%c1, %c0] : 
tensor<1x8xi32> + %1 = tensor.extract %arg0[%c1, %c1] : tensor<1x8xi32> + %2 = tensor.extract %arg0[%c1, %c2] : tensor<1x8xi32> + %3 = tensor.extract %arg0[%c1, %c3] : tensor<1x8xi32> + %4 = tensor.extract %arg0[%c1, %c4] : tensor<1x8xi32> + %5 = tensor.extract %arg0[%c1, %c5] : tensor<1x8xi32> + %6 = tensor.extract %arg0[%c1, %c6] : tensor<1x8xi32> + %7 = tensor.extract %arg0[%c1, %c7] : tensor<1x8xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + return %14 : i32 +} + +// reducing from multiple input tensors +// CHECK-LABEL: @not_supported_multiple_tensors +// CHECK-NOT: tensor_ext.rotate +func.func @not_supported_multiple_tensors(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + // uses %arg1 + %2 = tensor.extract %arg1[%c2] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + %4 = tensor.extract %arg0[%c4] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + return %14 : i32 +} + +// CHECK-LABEL: @not_supported_non_constant_index_access +// CHECK-NOT: tensor_ext.rotate +func.func @not_supported_non_constant_index_access(%arg0: tensor<8xi32>, %arg1: index) -> i32 { + %c0 = 
arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + // uses non-constant index + %2 = tensor.extract %arg0[%arg1] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + %4 = tensor.extract %arg0[%c4] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + return %14 : i32 +} + +// CHECK-LABEL: @not_supported_non_tensor_operands +// CHECK-NOT: tensor_ext.rotate +// TODO(#522): support this +func.func @not_supported_non_tensor_operands(%arg0: tensor<8xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %c2_i32 = arith.constant 2 : i32 + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + %2 = tensor.extract %arg0[%c2] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + %4 = tensor.extract %arg0[%c4] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + // next op uses non-tensor operand + %9 = arith.addi %8, %c2_i32 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi 
%11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + %15 = arith.addi %14, %2 : i32 + return %15 : i32 +} From 84e47ad87e035bcc8cf75a8c066576e05a2a70c2 Mon Sep 17 00:00:00 2001 From: HEIR Team Date: Tue, 12 Mar 2024 09:47:57 -0700 Subject: [PATCH 02/16] Integrate LLVM at llvm/llvm-project@335883844642 Updates LLVM usage to match [335883844642](https://github.com/llvm/llvm-project/commit/335883844642) PiperOrigin-RevId: 615075804 --- bazel/import_llvm.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/import_llvm.bzl b/bazel/import_llvm.bzl index 6fe18055c..62aa615a7 100644 --- a/bazel/import_llvm.bzl +++ b/bazel/import_llvm.bzl @@ -7,7 +7,7 @@ load( def import_llvm(name): """Imports LLVM.""" - LLVM_COMMIT = "a83f8e0314fcdda162e54cbba1c9dcf230dff093" + LLVM_COMMIT = "3358838446428976a41390fde98fe5b04b08a132" new_git_repository( name = name, From bfb155c05fc7f90b0a47f4b52244414b0949a962 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 12 Mar 2024 15:02:12 -0700 Subject: [PATCH 03/16] disable cggi-straight-line-vectorizer test for now --- tests/cggi/straight_line_vectorizer.mlir | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cggi/straight_line_vectorizer.mlir b/tests/cggi/straight_line_vectorizer.mlir index 490d6dd94..f466b64c0 100644 --- a/tests/cggi/straight_line_vectorizer.mlir +++ b/tests/cggi/straight_line_vectorizer.mlir @@ -1,4 +1,5 @@ -// RUN: heir-opt --straight-line-vectorize %s | FileCheck %s +// TODO(#519): disable FileChecks until nondeterminism issues are resolved +// RUN: heir-opt --straight-line-vectorize %s #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext From 2b1887fd1faa9508a8a90609137e8db8edf5b796 Mon Sep 17 00:00:00 2001 From: HEIR Team Date: Tue, 12 Mar 2024 16:08:55 -0700 Subject: [PATCH 04/16] Integrate LLVM at llvm/llvm-project@08dd645c15a0 Updates LLVM usage to match 
[08dd645c15a0](https://github.com/llvm/llvm-project/commit/08dd645c15a0) PiperOrigin-RevId: 615204135 --- bazel/import_llvm.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/import_llvm.bzl b/bazel/import_llvm.bzl index 62aa615a7..ca9d5a7c1 100644 --- a/bazel/import_llvm.bzl +++ b/bazel/import_llvm.bzl @@ -7,7 +7,7 @@ load( def import_llvm(name): """Imports LLVM.""" - LLVM_COMMIT = "3358838446428976a41390fde98fe5b04b08a132" + LLVM_COMMIT = "08dd645c15a091a53313e278d8f3c090e7c385d1" new_git_repository( name = name, From fdb42b1bcd6304fa4a33c896f51ebde10ae26ff8 Mon Sep 17 00:00:00 2001 From: HEIR Team Date: Tue, 12 Mar 2024 22:09:00 -0700 Subject: [PATCH 05/16] Integrate LLVM at llvm/llvm-project@f1ca2a09671e Updates LLVM usage to match [f1ca2a09671e](https://github.com/llvm/llvm-project/commit/f1ca2a09671e) PiperOrigin-RevId: 615284962 --- bazel/import_llvm.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/import_llvm.bzl b/bazel/import_llvm.bzl index ca9d5a7c1..cf8d87fff 100644 --- a/bazel/import_llvm.bzl +++ b/bazel/import_llvm.bzl @@ -7,7 +7,7 @@ load( def import_llvm(name): """Imports LLVM.""" - LLVM_COMMIT = "08dd645c15a091a53313e278d8f3c090e7c385d1" + LLVM_COMMIT = "f1ca2a09671e4d4acc2bea362b39268ed7883b6d" new_git_repository( name = name, From 45db59d0708ad7b8eb1e245e8fb848eaf200cf86 Mon Sep 17 00:00:00 2001 From: HEIR Team Date: Wed, 13 Mar 2024 02:53:13 -0700 Subject: [PATCH 06/16] Integrate LLVM at llvm/llvm-project@a38b7a432d3c Updates LLVM usage to match [a38b7a432d3c](https://github.com/llvm/llvm-project/commit/a38b7a432d3c) PiperOrigin-RevId: 615342960 --- bazel/import_llvm.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/import_llvm.bzl b/bazel/import_llvm.bzl index cf8d87fff..4700c80a5 100644 --- a/bazel/import_llvm.bzl +++ b/bazel/import_llvm.bzl @@ -7,7 +7,7 @@ load( def import_llvm(name): """Imports LLVM.""" - LLVM_COMMIT = "f1ca2a09671e4d4acc2bea362b39268ed7883b6d" + 
LLVM_COMMIT = "a38b7a432d3cbb093af9310eba5b4982dc0a0243" new_git_repository( name = name, From a96a3129bd9ebea6ea657ef471cdceba164ea8e1 Mon Sep 17 00:00:00 2001 From: HEIR Team Date: Wed, 13 Mar 2024 05:48:06 -0700 Subject: [PATCH 07/16] Integrate LLVM at llvm/llvm-project@e371ada409b2 Updates LLVM usage to match [e371ada409b2](https://github.com/llvm/llvm-project/commit/e371ada409b2) PiperOrigin-RevId: 615382976 --- bazel/import_llvm.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/import_llvm.bzl b/bazel/import_llvm.bzl index 4700c80a5..118539968 100644 --- a/bazel/import_llvm.bzl +++ b/bazel/import_llvm.bzl @@ -7,7 +7,7 @@ load( def import_llvm(name): """Imports LLVM.""" - LLVM_COMMIT = "a38b7a432d3cbb093af9310eba5b4982dc0a0243" + LLVM_COMMIT = "e371ada409b225ea990b5ac0d5cafea26a6046e1" new_git_repository( name = name, From 348c2252cf44e0378a64545dbeec682a51d83dd1 Mon Sep 17 00:00:00 2001 From: Alexander Viand Date: Wed, 13 Mar 2024 08:09:29 -0700 Subject: [PATCH 08/16] Copybara import of the project: -- 5c5b7e3569a507b03e19b1d4f65455a159af8f7e by Alexander : add "tensor of tensor" dialect conversion helpers Adds a utility function `addTensorOfTensorConversionPatterns`` designed to be used as part of dialect conversion. These support lowerings that would create "tensor of tensor" types, such as what happens to `tensor<..xpoly>` in `--polynomial-to-standard`. This commit also enables this for `--polynomial-to-standard`, which is a first step towards supporting polynomial operations over tensors in that lowering. 
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/heir/pull/508 from AlexanderViand-Intel:poly_tensor_support 5c5b7e3569a507b03e19b1d4f65455a159af8f7e PiperOrigin-RevId: 615421370 --- lib/Conversion/BUILD | 5 + lib/Conversion/PolynomialToStandard/BUILD | 3 + .../PolynomialToStandard.cpp | 34 ++- lib/Conversion/Utils.cpp | 244 ++++++++++++++++++ lib/Conversion/Utils.h | 7 + tests/polynomial/lower_add.mlir | 52 ++++ 6 files changed, 342 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/BUILD b/lib/Conversion/BUILD index 56f0fcf03..f3ecd6420 100644 --- a/lib/Conversion/BUILD +++ b/lib/Conversion/BUILD @@ -8,9 +8,14 @@ cc_library( srcs = ["Utils.cpp"], hdrs = ["Utils.h"], deps = [ + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", + "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], ) diff --git a/lib/Conversion/PolynomialToStandard/BUILD b/lib/Conversion/PolynomialToStandard/BUILD index 41234d13c..ab7d20000 100644 --- a/lib/Conversion/PolynomialToStandard/BUILD +++ b/lib/Conversion/PolynomialToStandard/BUILD @@ -12,6 +12,8 @@ cc_library( deps = [ "@heir//include/Conversion/PolynomialToStandard:pass_inc_gen", "@heir//lib/Conversion:Utils", + "@heir//lib/Dialect/Polynomial/IR:Polynomial", + "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", "@heir//lib/Dialect/Polynomial/IR:PolynomialOps", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -23,6 +25,7 @@ cc_library( "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], diff --git a/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp b/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp index 
d493dbb7c..b6e376463 100644 --- a/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp +++ b/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp @@ -1,18 +1,45 @@ #include "include/Conversion/PolynomialToStandard/PolynomialToStandard.h" +#include +#include +#include +#include +#include +#include + +#include "include/Dialect/Polynomial/IR/Polynomial.h" +#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" +#include "include/Dialect/Polynomial/IR/PolynomialDialect.h" #include "include/Dialect/Polynomial/IR/PolynomialOps.h" #include "include/Dialect/Polynomial/IR/PolynomialTypes.h" #include "lib/Conversion/Utils.h" +#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h" // from @llvm-project #include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project #include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/include/mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from 
@llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/Location.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -474,11 +501,12 @@ struct ConvertMul : public OpConversionPattern { // 2N - 1 sized result tensor -> reduce modulo ideal to get a N sized tensor func::FuncOp divMod = getFuncOpCallback(funcType, polyTy.getRing()); - if (!divMod) + if (!divMod) { return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { diag << "Missing software implementation for polynomial mod op of type" << funcType << " and for ring " << polyTy.getRing(); }); + } rewriter.replaceOpWithNewOp(op, divMod, polyMul.getResult(0)); return success(); @@ -739,8 +767,8 @@ void PolynomialToStandard::runOnOperation() { ConvertConstant, ConvertMulScalar>(typeConverter, context); patterns.add(typeConverter, patterns.getContext(), getDivmodOp); addStructuralConversionPatterns(typeConverter, patterns, target); + addTensorOfTensorConversionPatterns(typeConverter, patterns, target); - // TODO(#143): Handle tensor of polys. 
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); } diff --git a/lib/Conversion/Utils.cpp b/lib/Conversion/Utils.cpp index 2bf992392..f34a1fabf 100644 --- a/lib/Conversion/Utils.cpp +++ b/lib/Conversion/Utils.cpp @@ -1,8 +1,23 @@ #include "lib/Conversion/Utils.h" +#include +#include +#include + +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project #include "mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Region.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { @@ -12,6 +27,235 @@ using ::mlir::func::CallOp; using ::mlir::func::FuncOp; using ::mlir::func::ReturnOp; +struct ConvertAny : public ConversionPattern { + ConvertAny(const TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(), + /*benefit=*/1, context) { + setDebugName("ConvertAny"); + setHasBoundedRewriteRecursion(true); + } + + // generate a new op where all operands have been replaced with their + // materialized/typeconverted 
versions + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector newOperandTypes; + if (failed(getTypeConverter()->convertTypes(op->getOperandTypes(), + newOperandTypes))) + return failure(); + + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + + SmallVector, 1> regions; + IRMapping mapping; + for (auto &r : op->getRegions()) { + Region *newRegion = new Region(); + rewriter.cloneRegionBefore(r, *newRegion, newRegion->end(), mapping); + if (failed(rewriter.convertRegionTypes(newRegion, *this->typeConverter))) + return failure(); + regions.emplace_back(newRegion); + } + + Operation *newOp = rewriter.create(OperationState( + op->getLoc(), op->getName().getStringRef(), operands, newResultTypes, + op->getAttrs(), op->getSuccessors(), regions)); + + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct ConvertExtract : public OpConversionPattern { + ConvertExtract(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + // Convert a tensor.extract that would type-convert to extracting a tensor to + // a tensor.extract_slice operation instead. Specifically, this targets + // extracting SourceType from tensor<...xSourceType> when SourceType would be + // type converted to tensor<...>. 
+ LogicalResult matchAndRewrite( + tensor::ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // replace tensor.extract %t[%i] from tensor + // with an equivalent tensor.slice from tensor + auto shape = op.getTensor().getType().getShape(); + auto resultType = getTypeConverter() + ->convertType(op.getResult().getType()) + .cast(); + auto resultShape = resultType.getShape(); + + // expand op's list of indices by appending as many zeros as there are + // dimension in resultShape + SmallVector offsets; + offsets.append(op.getIndices().begin(), op.getIndices().end()); + for (size_t i = 0; i < resultShape.size(); ++i) { + offsets.push_back(rewriter.getIndexAttr(0)); + } + + // expand resultShape by prepending as many ones as there are dimensions in + // shape + SmallVector sizes; + for (size_t i = 0; i < shape.size(); ++i) { + sizes.push_back(rewriter.getIndexAttr(1)); + } + for (int64_t i : resultShape) { + sizes.push_back(rewriter.getIndexAttr(i)); + } + + // strides are all 1, and we need as many as there are dimensions in + // both shape and resultShape together + SmallVector strides; + for (size_t i = 0; i < shape.size() + resultShape.size(); ++i) { + strides.push_back(rewriter.getIndexAttr(1)); + } + + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getTensor(), offsets, sizes, strides); + + return success(); + } +}; + +struct ConvertInsert : public OpConversionPattern { + ConvertInsert(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + // Convert a tensor.insert that would type-convert to inserting a tensor to + // a tensor.insert_slice operation instead. Specifically, this targets + // inserting SourceType into tensor<...xSourceType> when SourceType would be + // type converted to tensor<...>. 
+ LogicalResult matchAndRewrite( + tensor::InsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // replace tensor.insert %s into %t[%i] with tensor + // with an equivalent tensor.insert_slice with tensor + auto shape = op.getDest().getType().getShape(); + auto resultType = getTypeConverter() + ->convertType(op.getScalar().getType()) + .cast(); + auto resultShape = resultType.getShape(); + + // expand op's list of indices by appending as many zeros as there are + // dimension in resultShape + SmallVector offsets; + offsets.append(op.getIndices().begin(), op.getIndices().end()); + for (size_t i = 0; i < resultShape.size(); ++i) { + offsets.push_back(rewriter.getIndexAttr(0)); + } + + // expand resultShape by prepending as many ones as there are dimensions in + // shape + SmallVector sizes; + for (size_t i = 0; i < shape.size(); ++i) { + sizes.push_back(rewriter.getIndexAttr(1)); + } + for (int64_t i : resultShape) { + sizes.push_back(rewriter.getIndexAttr(i)); + } + + // strides are all 1, and we need as many as there are dimensions in + // both shape and resultShape together + SmallVector strides; + for (size_t i = 0; i < shape.size() + resultShape.size(); ++i) { + strides.push_back(rewriter.getIndexAttr(1)); + } + + rewriter.replaceOpWithNewOp( + op, adaptor.getScalar(), adaptor.getDest(), offsets, sizes, strides); + + return success(); + } +}; + +struct ConvertFromElements + : public OpConversionPattern { + ConvertFromElements(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + // Converts a tensor.from_elements %s0, %s1, ... 
: tensor<...xSourceType> + // where SourceType would be type-converted to tensor<...> to + // a concatenation of the converted operands (with appropriate reshape) + LogicalResult matchAndRewrite( + tensor::FromElementsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expand each of the (converted) operands: + SmallVector newOperands; + for (auto o : adaptor.getElements()) { + // extend tensor<...xT> to tensor<1x...xT> + if (auto tensorType = o.getType().dyn_cast()) { + auto shape = tensorType.getShape(); + SmallVector newShape(1, 1); + newShape.append(shape.begin(), shape.end()); + + // Create a dense constant for targetShape + auto shapeOp = rewriter.create( + op.getLoc(), + RankedTensorType::get(newShape.size(), rewriter.getIndexType()), + rewriter.getIndexTensorAttr(newShape)); + + auto reshapeOp = rewriter.create( + op.getLoc(), + RankedTensorType::get(newShape, tensorType.getElementType()), o, + shapeOp); + newOperands.push_back(reshapeOp); + } else { + newOperands.push_back(o); + } + } + // Create the final tensor.concat operation + rewriter.replaceOpWithNewOp(op, 0, newOperands); + + return success(); + } +}; + +void addTensorOfTensorConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target) { + target.addDynamicallyLegalDialect([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()); + }); + + typeConverter.addConversion([&](TensorType type) -> Type { + if (!typeConverter.isLegal(type.getElementType())) { + typeConverter.convertType(type.getElementType()).dump(); + if (auto convertedType = + typeConverter.convertType(type.getElementType())) { + if (auto castConvertedType = + convertedType.dyn_cast()) { + // Create the combined shape + auto polyShape = castConvertedType.getShape(); + auto tensorShape = type.getShape(); + SmallVector combinedShape(tensorShape.begin(), + tensorShape.end()); + combinedShape.append(polyShape.begin(), polyShape.end()); + auto 
combinedType = RankedTensorType::get( + combinedShape, castConvertedType.getElementType()); + return combinedType; + } + } + } + return type; + }); + + target.addDynamicallyLegalDialect([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()); + }); + + patterns.add( + typeConverter, patterns.getContext()); +} + void addStructuralConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { diff --git a/lib/Conversion/Utils.h b/lib/Conversion/Utils.h index e2f953dc6..48054f295 100644 --- a/lib/Conversion/Utils.h +++ b/lib/Conversion/Utils.h @@ -1,11 +1,18 @@ #ifndef LIB_CONVERSION_UTILS_H_ #define LIB_CONVERSION_UTILS_H_ +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { namespace heir { +// Adds conversion patterns that deal with tensor<..xsource_type> +// when source_type will be type converted to tensor<...>, too +void addTensorOfTensorConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target); + // Adds the standard set of conversion patterns for // converting types involved in func, cf, etc., which // don't depend on the logic of the dialect beyond the diff --git a/tests/polynomial/lower_add.mlir b/tests/polynomial/lower_add.mlir index 0d9eb74d9..2495ba14b 100644 --- a/tests/polynomial/lower_add.mlir +++ b/tests/polynomial/lower_add.mlir @@ -4,6 +4,8 @@ #ring = #polynomial.ring #ring_prime = #polynomial.ring + +// CHECK-LABEL: @test_lower_add_power_of_two_cmod func.func @test_lower_add_power_of_two_cmod() -> !polynomial.polynomial<#ring> { // 2 + 2x + 2x^2 + ... 
+ 2x^{1023} // CHECK: [[X:%.+]] = arith.constant dense<2> : [[T:tensor<1024xi32>]] @@ -19,6 +21,7 @@ func.func @test_lower_add_power_of_two_cmod() -> !polynomial.polynomial<#ring> { return %poly2 : !polynomial.polynomial<#ring> } +// CHECK-LABEL: @test_lower_add_prime_cmod func.func @test_lower_add_prime_cmod() -> !polynomial.polynomial<#ring_prime> { // CHECK: [[X:%.+]] = arith.constant dense<2> : [[TCOEFF:tensor<1024xi31>]] %coeffs1 = arith.constant dense<2> : tensor<1024xi31> @@ -41,3 +44,52 @@ func.func @test_lower_add_prime_cmod() -> !polynomial.polynomial<#ring_prime> { // CHECK: return [[TRUNC_RESULT]] : [[T]] return %poly2 : !polynomial.polynomial<#ring_prime> } + +// CHECK-LABEL: @test_lower_add_tensor +func.func @test_lower_add_tensor() -> tensor<2x!polynomial.polynomial<#ring>> { + // 2 + 2x + 2x^2 + ... + 2x^{1023} + // CHECK-DAG: [[A:%.+]] = arith.constant dense<2> : [[T:tensor<1024xi32>]] + %coeffsA = arith.constant dense<2> : tensor<1024xi32> + // CHECK-DAG: [[B:%.+]] = arith.constant dense<3> : [[T]] + %coeffsB = arith.constant dense<3> : tensor<1024xi32> + // CHECK-DAG: [[C:%.+]] = arith.constant dense<4> : [[T]] + %coeffsC = arith.constant dense<4> : tensor<1024xi32> + // CHECK-DAG: [[D:%.+]] = arith.constant dense<5> : [[T]] + %coeffsD = arith.constant dense<5> : tensor<1024xi32> + %polyA = polynomial.from_tensor %coeffsA : tensor<1024xi32> -> !polynomial.polynomial<#ring> + %polyB = polynomial.from_tensor %coeffsB : tensor<1024xi32> -> !polynomial.polynomial<#ring> + %polyC = polynomial.from_tensor %coeffsC : tensor<1024xi32> -> !polynomial.polynomial<#ring> + %polyD = polynomial.from_tensor %coeffsD : tensor<1024xi32> -> !polynomial.polynomial<#ring> + %tensor1 = tensor.from_elements %polyA, %polyB : tensor<2x!polynomial.polynomial<#ring>> + %tensor2 = tensor.from_elements %polyC, %polyD : tensor<2x!polynomial.polynomial<#ring>> + // CHECK: [[S1:%.+]] = arith.constant dense<[1, 1024]> : [[TI:tensor<2xindex>]] + // CHECK: [[T1:%.+]] = 
tensor.reshape [[A]]([[S1]]) : ([[T]], [[TI]]) -> [[TEX:tensor<1x1024xi32>]] + // CHECK: [[S2:%.+]] = arith.constant dense<[1, 1024]> : [[TI]] + // CHECK: [[T2:%.+]] = tensor.reshape [[B]]([[S2]]) : ([[T]], [[TI]]) -> [[TEX]] + // CHECK: [[C1:%.+]] = tensor.concat dim(0) [[T1]], [[T2]] : ([[TEX]], [[TEX]]) -> [[TT:tensor<2x1024xi32>]] + // CHECK: [[S3:%.+]] = arith.constant dense<[1, 1024]> : [[TI]] + // CHECK: [[T3:%.+]] = tensor.reshape [[C]]([[S3]]) : ([[T]], [[TI]]) -> [[TEX]] + // CHECK: [[S4:%.+]] = arith.constant dense<[1, 1024]> : [[TI]] + // CHECK: [[T4:%.+]] = tensor.reshape [[D]]([[S4]]) : ([[T]], [[TI]]) -> [[TEX]] + // CHECK: [[C2:%.+]] = tensor.concat dim(0) [[T3]], [[T4]] : ([[TEX]], [[TEX]]) -> [[TT:tensor<2x1024xi32>]] + // CHECK-NOT: polynomial.from_tensor + // CHECK-NOT: tensor.from_elements + %tensor3 = affine.for %i = 0 to 2 iter_args(%t0 = %tensor1) -> tensor<2x!polynomial.polynomial<#ring>> { + // CHECK: [[FOR:%.]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[C1]]) -> ([[TT]]) { + %a = tensor.extract %tensor1[%i] : tensor<2x!polynomial.polynomial<#ring>> + %b = tensor.extract %tensor2[%i] : tensor<2x!polynomial.polynomial<#ring>> + // CHECK: [[AA:%.+]] = tensor.extract_slice [[C1]][[[I]], 0] [1, 1024] [1, 1] : [[TT]] + // CHECK: [[BB:%.+]] = tensor.extract_slice [[C2]][[[I]], 0] [1, 1024] [1, 1] : [[TT]] + // CHECK-NOT: tensor.extract % + %s = polynomial.add(%a, %b) : !polynomial.polynomial<#ring> + // CHECK: [[SUM:%.+]] = arith.addi [[AA]], [[BB]] : [[T]] + // CHECK-NOT: polynomial.add + %t = tensor.insert %s into %t0[%i] : tensor<2x!polynomial.polynomial<#ring>> + // CHECK: [[INS:%.+]] = tensor.insert_slice [[SUM]] into [[T0]][[[I]], 0] [1, 1024] [1, 1] : [[T]] into [[TT]] + // CHECK-NOT: tensor.insert % + affine.yield %t : tensor<2x!polynomial.polynomial<#ring>> + // CHECK: affine.yield [[INS]] : [[TT]] + } + return %tensor3 : tensor<2x!polynomial.polynomial<#ring>> + // CHECK: return [[FOR]] : [[TT]] +} From 
73213279e0325ce294a0a936f5577e3220a9ce24 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 13 Mar 2024 20:55:21 +0000 Subject: [PATCH 09/16] bugfix for addTensorOfTensorConversionPatterns --- lib/Conversion/Utils.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/Utils.cpp b/lib/Conversion/Utils.cpp index f34a1fabf..32cb51e09 100644 --- a/lib/Conversion/Utils.cpp +++ b/lib/Conversion/Utils.cpp @@ -222,9 +222,8 @@ struct ConvertFromElements void addTensorOfTensorConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { - target.addDynamicallyLegalDialect([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()); - }); + target.addDynamicallyLegalDialect( + [&](Operation *op) { return typeConverter.isLegal(op); }); typeConverter.addConversion([&](TensorType type) -> Type { if (!typeConverter.isLegal(type.getElementType())) { @@ -248,9 +247,8 @@ void addTensorOfTensorConversionPatterns(TypeConverter &typeConverter, return type; }); - target.addDynamicallyLegalDialect([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()); - }); + target.addDynamicallyLegalDialect( + [&](Operation *op) { return typeConverter.isLegal(op); }); patterns.add( typeConverter, patterns.getContext()); From 3b920082e3ff15bad4639e876198db2c06adb953 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 15 Mar 2024 00:21:17 +0000 Subject: [PATCH 10/16] Add -convert-elementwise-to-affine pass This works similarly to -convert-elementwise-to-linalg, but instead of lowering to a `linalg.generic` over `memref`s, this pass lowers to a nest of `affine.for` loops still over `tensor`s. This is used in HEIR for lowering `polynomial` ops over tensors (e.g., in `-polynomial-to-standard`).
--- .../Dialect/Polynomial/IR/PolynomialTypes.td | 2 +- include/Transforms/ElementwiseToAffine/BUILD | 35 +++++ .../ElementwiseToAffine/ElementwiseToAffine.h | 18 +++ .../ElementwiseToAffine.td | 18 +++ lib/Transforms/ElementwiseToAffine/BUILD | 22 +++ .../ElementwiseToAffine.cpp | 136 ++++++++++++++++++ tests/polynomial/elementwise.mlir | 10 -- tests/polynomial/elementwise_to_affine.mlir | 90 ++++++++++++ tools/BUILD | 2 + tools/heir-opt.cpp | 20 ++- 10 files changed, 331 insertions(+), 22 deletions(-) create mode 100644 include/Transforms/ElementwiseToAffine/BUILD create mode 100644 include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h create mode 100644 include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td create mode 100644 lib/Transforms/ElementwiseToAffine/BUILD create mode 100644 lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.cpp delete mode 100644 tests/polynomial/elementwise.mlir create mode 100644 tests/polynomial/elementwise_to_affine.mlir diff --git a/include/Dialect/Polynomial/IR/PolynomialTypes.td b/include/Dialect/Polynomial/IR/PolynomialTypes.td index 2637131da..14d282464 100644 --- a/include/Dialect/Polynomial/IR/PolynomialTypes.td +++ b/include/Dialect/Polynomial/IR/PolynomialTypes.td @@ -14,7 +14,7 @@ class Polynomial_Type traits = []> let mnemonic = typeMnemonic; } -def Polynomial : Polynomial_Type<"Polynomial", "polynomial", [MemRefElementTypeInterface]> { +def Polynomial : Polynomial_Type<"Polynomial", "polynomial"> { let summary = "An element of a polynomial quotient ring"; let description = [{ diff --git a/include/Transforms/ElementwiseToAffine/BUILD b/include/Transforms/ElementwiseToAffine/BUILD new file mode 100644 index 000000000..07fd7c045 --- /dev/null +++ b/include/Transforms/ElementwiseToAffine/BUILD @@ -0,0 +1,35 @@ +# ElementwiseToAffine tablegen and headers. 
+ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files([ + "ElementwiseToAffine.h", +]) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ElementwiseToAffine", + ], + "ElementwiseToAffine.h.inc", + ), + ( + ["-gen-pass-doc"], + "ElementwiseToAffinePasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ElementwiseToAffine.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h b/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h new file mode 100644 index 000000000..48b4a1b49 --- /dev/null +++ b/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ +#define INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_H_ diff --git a/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td b/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td new file mode 100644 index 000000000..cd83ee554 --- /dev/null +++ b/include/Transforms/ElementwiseToAffine/ElementwiseToAffine.td @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_ +#define INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_ + +include "mlir/Pass/PassBase.td" + +def ElementwiseToAffine : 
Pass<"convert-elementwise-to-affine"> { + let summary = "This pass lowers ElementwiseMappable operations to Affine loops."; + let description = [{ + This pass lowers ElementwiseMappable operations over tensors + to affine loop nests that instead apply the operation to the underlying scalar values. + }]; + let dependentDialects = [ + "mlir::affine::AffineDialect", + "mlir::tensor::TensorDialect" + ]; +} + +#endif // INCLUDE_TRANSFORMS_ELEMENTWISETOAFFINE_ELEMENTWISETOAFFINE_TD_ diff --git a/lib/Transforms/ElementwiseToAffine/BUILD b/lib/Transforms/ElementwiseToAffine/BUILD new file mode 100644 index 000000000..c0a142594 --- /dev/null +++ b/lib/Transforms/ElementwiseToAffine/BUILD @@ -0,0 +1,22 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "ElementwiseToAffine", + srcs = ["ElementwiseToAffine.cpp"], + hdrs = [ + "@heir//include/Transforms/ElementwiseToAffine:ElementwiseToAffine.h", + ], + deps = [ + "@heir//include/Transforms/ElementwiseToAffine:pass_inc_gen", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.cpp b/lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.cpp new file mode 100644 index 000000000..9e08662a9 --- /dev/null +++ b/lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.cpp @@ -0,0 +1,136 @@ +#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h" + +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_ELEMENTWISETOAFFINE +#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h.inc" + +// All of this is based on the ElementwiseToLinalg Pass in +// mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp + +static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { + if (!OpTrait::hasElementwiseMappableTraits(op)) return false; + + // TODO(#534): Test ElementwiseToAffine with `any_of` constraints + // as the pass should (in theory) support scalar operands, too + return llvm::all_of(op->getOperandTypes(), + [](Type type) { return isa(type); }); +} + +namespace { + +struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { + ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + if (!isElementwiseMappableOpOnRankedTensors(op)) + return rewriter.notifyMatchFailure( + op, "requires elementwise op on ranked tensors"); + + auto resultType = cast(op->getResult(0).getType()); + auto elementType = resultType.getElementType(); + auto shape = resultType.getShape(); + auto rank = resultType.getRank(); + + // Save insertion point prior to entering loop nest + auto ip = rewriter.saveInsertionPoint(); + + // Create an empty tensor as initial value of the iter_args + Value target = + rewriter.create(op->getLoc(), shape, elementType); + + llvm::SmallVector indices; + + // Create an affine.for loop nest of depth rank + for (size_t i = 0; i < rank; ++i) { + auto loop = + rewriter.create(op->getLoc(), /* lowerBound*/ 0, + /* upperBound*/ shape[i], + /* step*/ 1, + /* iterArgs*/ target); + + // Update target & indices + target = loop.getRegionIterArgs().front(); + indices.push_back(loop.getInductionVar()); + + // If first loop: replace
scalar op + if (i == 0) { + rewriter.replaceOp(op, loop); + } else { // yield the result of this loop + rewriter.create(op->getLoc(), + loop->getResults()); + } + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Create the innermost body + auto resultTypes = + llvm::to_vector<6>(llvm::map_range(op->getResultTypes(), [](Type type) { + return cast(type).getElementType(); + })); + + // Generate a `tensor.extract` for each tensor operand + SmallVector newOperands; + for (auto operand : op->getOperands()) { + if (operand.getType().isa()) { + // We don't need to check the shape, as ElementwiseMappable + // requires all tensor operands to have compatible shapes + auto extractOp = rewriter.create(operand.getLoc(), + operand, indices); + newOperands.push_back(extractOp); + } else { + // scalar (technically, "non-tensor") operands can be reused as-is + newOperands.push_back(operand); + } + } + + // "lowered" operation is the same operation, but over non-tensor + // operands + auto *scalarOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), + newOperands, resultTypes, op->getAttrs()); + + // insert scalarOp into the tensor at right index + Value inserted = rewriter.create( + op->getLoc(), scalarOp->getResult(0), target, indices); + + // replace linalg.yield scalarOp with affine.yield insertedOp + rewriter.create(op->getLoc(), inserted); + + // reset insertion point + rewriter.restoreInsertionPoint(ip); + + return success(); + } +}; +} // namespace + +struct ElementwiseToAffine + : impl::ElementwiseToAffineBase { + using ElementwiseToAffineBase::ElementwiseToAffineBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + RewritePatternSet patterns(context); + + patterns.add(context); + target.markUnknownOpDynamicallyLegal([](Operation *op) { + return !isElementwiseMappableOpOnRankedTensors(op); + }); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns))))
+ signalPassFailure(); + } +}; + +} // namespace heir +} // namespace mlir diff --git a/tests/polynomial/elementwise.mlir b/tests/polynomial/elementwise.mlir deleted file mode 100644 index 0720672b8..000000000 --- a/tests/polynomial/elementwise.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: heir-opt --convert-elementwise-to-linalg --one-shot-bufferize --convert-linalg-to-affine-loops %s | FileCheck %s - -!poly = !polynomial.polynomial<>> - -// CHECK-LABEL: @test_bin_ops -// CHECK: affine.for -func.func @test_bin_ops(%arg0: tensor<2x!poly>, %arg1: tensor<2x!poly>) -> tensor<2x!poly> { - %0 = polynomial.add(%arg0, %arg1) : tensor<2x!poly> - return %0 : tensor<2x!poly> -} diff --git a/tests/polynomial/elementwise_to_affine.mlir b/tests/polynomial/elementwise_to_affine.mlir new file mode 100644 index 000000000..27780257d --- /dev/null +++ b/tests/polynomial/elementwise_to_affine.mlir @@ -0,0 +1,90 @@ +// RUN: heir-opt --convert-elementwise-to-affine %s | FileCheck --enable-var-scope %s + +!poly = !polynomial.polynomial<>> + +// CHECK-LABEL: @test_elementwise +// CHECK: {{.*}} -> [[T:tensor<2x!polynomial.*33538049.*]] { +func.func @test_elementwise(%arg0: tensor<2x!poly>, %arg1: tensor<2x!poly>) -> tensor<2x!poly> { + // CHECK: [[EMPTY:%.+]] = tensor.empty() : [[T:tensor<2x!polynomial.*33538049.*]] + // CHECK: [[LOOP:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[EMPTY]]) -> ([[T]]) { + // CHECK: [[A:%.+]] = tensor.extract %arg0[[[I]]] : [[T]] + // CHECK: [[B:%.+]] = tensor.extract %arg1[[[I]]] : [[T]] + // CHECK: [[S:%.+]] = polynomial.add([[A]], [[B]]) : [[P:!polynomial.*33538049.*]] + // CHECK-NOT: polynomial.add{{.*}} : [[T]] + // CHECK: [[R:%.+]] = tensor.insert [[S]] into [[T0]][[[I]]] : [[T]] + // CHECK: affine.yield [[R]] : [[T]] + %0 = polynomial.add(%arg0, %arg1) : tensor<2x!poly> + // CHECK: return [[LOOP]] : [[T]] + return %0 : tensor<2x!poly> +} + +// This is just here to make sure the FileCheck commands above work as expected +// CHECK-LABEL: 
@lowered_elementwise +// CHECK: {{.*}} -> [[T:tensor<2x!polynomial.*33538049.*]] { +func.func @lowered_elementwise(%arg0: tensor<2x!poly>, %arg1: tensor<2x!poly>) -> tensor<2x!poly> { + // CHECK: [[EMPTY:%.+]] = tensor.empty() : [[T:tensor<2x!polynomial.*33538049.*]] + %empty = tensor.empty() : tensor<2x!poly> + // CHECK: [[LOOP:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[EMPTY]]) -> ([[T]]) { + %0 = affine.for %i = 0 to 2 iter_args(%t0 = %empty) -> (tensor<2x!poly>) { + // CHECK: [[A:%.+]] = tensor.extract %arg0[[[I]]] : [[T]] + %a = tensor.extract %arg0[%i] : tensor<2x!poly> + // CHECK: [[B:%.+]] = tensor.extract %arg1[[[I]]] : [[T]] + %b = tensor.extract %arg1[%i] : tensor<2x!poly> + // CHECK: [[S:%.+]] = polynomial.add([[A]], [[B]]) : [[P:!polynomial.*33538049.*]] + // CHECK-NOT: polynomial.add{{.*}} : [[T]] + %s = polynomial.add(%a, %b) : !poly + // CHECK: [[R:%.+]] = tensor.insert [[S]] into [[T0]][[[I]]] : [[T]] + %r = tensor.insert %s into %t0[%i] : tensor<2x!poly> + // CHECK: affine.yield [[R]] : [[T]] + affine.yield %r : tensor<2x!poly> + } + // CHECK: return [[LOOP]] : [[T]] + return %0 : tensor<2x!poly> +} + +// CHECK-LABEL: @test_elementwise_multidim +// CHECK: {{.*}} -> [[T:tensor<2x3x!polynomial.*33538049.*]] { +func.func @test_elementwise_multidim(%arg0: tensor<2x3x!poly>, %arg1: tensor<2x3x!poly>) -> tensor<2x3x!poly> { + %0 = polynomial.add(%arg0, %arg1) : tensor<2x3x!poly> + return %0 : tensor<2x3x!poly> + // CHECK: [[EMPTY:%.+]] = tensor.empty() : [[T:tensor<2x3x!polynomial.*33538049.*]] + // CHECK: [[LOOP:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[EMPTY]]) -> ([[T]]) { + // CHECK: [[INNERLOOP:%.+]] = affine.for [[J:%.+]] = 0 to 3 iter_args([[T1:%.+]] = [[T0]]) -> ([[T]]) { + // CHECK: [[A:%.+]] = tensor.extract %arg0[[[I]], [[J]]] : [[T]] + // CHECK: [[B:%.+]] = tensor.extract %arg1[[[I]], [[J]]] : [[T]] + // CHECK: [[S:%.+]] = polynomial.add([[A]], [[B]]) : [[P:!polynomial.*33538049.*]] + // CHECK-NOT: 
polynomial.add{{.*}} : [[T]] + // CHECK: [[R:%.+]] = tensor.insert [[S]] into [[T1]][[[I]], [[J]]] : [[T]] + // CHECK: affine.yield [[R]] : [[T]] + // CHECK: affine.yield [[INNERLOOP]] : [[T]] + // CHECK: return [[LOOP]] : [[T]] +} + +// This is just here to make sure the FileCheck commands above work as expected +// CHECK-LABEL: @lowered_elementwise_multidim +// CHECK: {{.*}} -> [[T:tensor<2x3x!polynomial.*33538049.*]] { +func.func @lowered_elementwise_multidim(%arg0: tensor<2x3x!poly>, %arg1: tensor<2x3x!poly>) -> tensor<2x3x!poly> { + // CHECK: [[EMPTY:%.+]] = tensor.empty() : [[T:tensor<2x3x!polynomial.*33538049.*]] + %empty = tensor.empty() : tensor<2x3x!poly> + // CHECK: [[LOOP:%.+]] = affine.for [[I:%.+]] = 0 to 2 iter_args([[T0:%.+]] = [[EMPTY]]) -> ([[T]]) { + %0 = affine.for %i = 0 to 2 iter_args(%t0 = %empty) -> (tensor<2x3x!poly>) { + // CHECK: [[INNERLOOP:%.+]] = affine.for [[J:%.+]] = 0 to 3 iter_args([[T1:%.+]] = [[T0]]) -> ([[T]]) { + %1 = affine.for %j = 0 to 3 iter_args(%t1 = %t0) -> (tensor<2x3x!poly>) { + // CHECK: [[A:%.+]] = tensor.extract %arg0[[[I]], [[J]]] : [[T]] + %a = tensor.extract %arg0[%i, %j] : tensor<2x3x!poly> + // CHECK: [[B:%.+]] = tensor.extract %arg1[[[I]], [[J]]] : [[T]] + %b = tensor.extract %arg1[%i, %j] : tensor<2x3x!poly> + // CHECK: [[S:%.+]] = polynomial.add([[A]], [[B]]) : [[P:!polynomial.*33538049.*]] + // CHECK-NOT: polynomial.add{{.*}} : [[T]] + %s = polynomial.add(%a, %b) : !poly + // CHECK: [[R:%.+]] = tensor.insert [[S]] into [[T1]][[[I]], [[J]]] : [[T]] + %r = tensor.insert %s into %t1[%i, %j] : tensor<2x3x!poly> + // CHECK: affine.yield [[R]] : [[T]] + affine.yield %r : tensor<2x3x!poly> + } + // CHECK: affine.yield [[INNERLOOP]] : [[T]] + affine.yield %1 : tensor<2x3x!poly> + } + // CHECK: return [[LOOP]] : [[T]] + return %0 : tensor<2x3x!poly> +} diff --git a/tools/BUILD b/tools/BUILD index a2446bd13..a0d5db130 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -57,6 +57,7 @@ cc_binary( 
"@heir//lib/Dialect/TensorExt/Transforms", "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Dialect/TfheRustBool/IR:Dialect", + "@heir//lib/Transforms/ElementwiseToAffine", "@heir//lib/Transforms/ForwardStoreToLoad", "@heir//lib/Transforms/FullLoopUnroll", "@heir//lib/Transforms/Secretize", @@ -65,6 +66,7 @@ cc_binary( "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:AffineTransforms", + "@llvm-project//mlir:AllExtensions", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 14dade8cd..bd6fa84c2 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -24,6 +24,7 @@ #include "include/Dialect/TensorExt/Transforms/Passes.h" #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" +#include "include/Transforms/ElementwiseToAffine/ElementwiseToAffine.h" #include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" #include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" #include "include/Transforms/Secretize/Passes.h" @@ -32,12 +33,9 @@ #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project -#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project -#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project +#include "mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" // from @llvm-project #include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project #include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project -#include 
"mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project -#include "mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project #include "mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h" // from @llvm-project #include "mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h" // from @llvm-project @@ -57,6 +55,7 @@ #include "mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/include/mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/include/mlir/InitAllExtensions.h" // from @llvm-project #include "mlir/include/mlir/InitAllPasses.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project @@ -132,6 +131,7 @@ void tosaPipelineBuilder(OpPassManager &manager) { void polynomialToLLVMPipelineBuilder(OpPassManager &manager) { // Poly + manager.addPass(createElementwiseToAffine()); manager.addPass(polynomial::createPolynomialToStandard()); manager.addPass(createCanonicalizerPass()); @@ -140,6 +140,7 @@ void polynomialToLLVMPipelineBuilder(OpPassManager &manager) { // Needed to lower affine.map and affine.apply manager.addNestedPass(affine::createAffineExpandIndexOpsPass()); manager.addNestedPass(affine::createSimplifyAffineStructuresPass()); + manager.addPass(createLowerAffinePass()); manager.addNestedPass(memref::createExpandOpsPass()); manager.addNestedPass(memref::createExpandStridedMetadataPass()); @@ -160,14 +161,9 @@ void polynomialToLLVMPipelineBuilder(OpPassManager &manager) { // ToLLVM manager.addPass(arith::createArithExpandOpsPass()); manager.addPass(createConvertSCFToCFPass()); - manager.addPass(createConvertControlFlowToLLVMPass()); - 
manager.addPass(createConvertIndexToLLVMPass()); manager.addNestedPass(memref::createExpandStridedMetadataPass()); - manager.addPass(createCanonicalizerPass()); - manager.addPass(createConvertFuncToLLVMPass()); - manager.addPass(createArithToLLVMConversionPass()); - manager.addPass(createFinalizeMemRefToLLVMConversionPass()); - manager.addPass(createReconcileUnrealizedCastsPass()); + manager.addPass(createConvertToLLVMPass()); + // Cleanup manager.addPass(createCanonicalizerPass()); manager.addPass(createSCCPPass()); @@ -285,6 +281,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registerAllDialects(registry); + registerAllExtensions(registry); // Register MLIR core passes to build pipeline. registerAllPasses(); @@ -294,6 +291,7 @@ int main(int argc, char **argv) { lwe::registerLWEPasses(); secret::registerSecretPasses(); tensor_ext::registerTensorExtPasses(); + registerElementwiseToAffinePasses(); registerSecretizePasses(); registerFullLoopUnrollPasses(); registerForwardStoreToLoadPasses(); From 008ec0e703be2e6254e1be887b296b83b877aa32 Mon Sep 17 00:00:00 2001 From: HEIR Team Date: Thu, 14 Mar 2024 21:31:59 -0700 Subject: [PATCH 11/16] Automated Code Change PiperOrigin-RevId: 616004344 --- tests/secretize/BUILD | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/secretize/BUILD b/tests/secretize/BUILD index 6c9032391..c571e6fc6 100644 --- a/tests/secretize/BUILD +++ b/tests/secretize/BUILD @@ -1,9 +1,6 @@ load("//bazel:lit.bzl", "glob_lit_tests") -package( - default_applicable_licenses = ["@heir//:license"], - default_visibility = ["//visibility:public"], -) +package(default_applicable_licenses = ["@heir//:license"]) glob_lit_tests( name = "all_tests", From 3ddf09d7782d0940286162d8443dd33d576a4575 Mon Sep 17 00:00:00 2001 From: HEIR Team Date: Fri, 15 Mar 2024 08:50:47 -0700 Subject: [PATCH 12/16] Integrate LLVM at llvm/llvm-project@3b5e7c83a6e2 Updates LLVM usage to match 
[3b5e7c83a6e2](https://github.com/llvm/llvm-project/commit/3b5e7c83a6e2) PiperOrigin-RevId: 616143706 --- bazel/import_llvm.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/import_llvm.bzl b/bazel/import_llvm.bzl index 118539968..37ac5499e 100644 --- a/bazel/import_llvm.bzl +++ b/bazel/import_llvm.bzl @@ -7,7 +7,7 @@ load( def import_llvm(name): """Imports LLVM.""" - LLVM_COMMIT = "e371ada409b225ea990b5ac0d5cafea26a6046e1" + LLVM_COMMIT = "3b5e7c83a6e226d5bd7ed2e9b67449b64812074c" new_git_repository( name = name, From 989363f51ed115319b0f4d456e7462a24a5e72c8 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 12 Mar 2024 10:40:24 -0700 Subject: [PATCH 13/16] Add target slot analysis --- include/Analysis/TargetSlotAnalysis/BUILD | 9 ++ .../TargetSlotAnalysis/TargetSlotAnalysis.h | 134 ++++++++++++++++++ include/Dialect/BUILD | 16 +++ lib/Analysis/TargetSlotAnalysis/BUILD | 18 +++ .../TargetSlotAnalysis/TargetSlotAnalysis.cpp | 55 +++++++ lib/Dialect/BUILD | 12 ++ lib/Dialect/TensorExt/Transforms/BUILD | 2 + .../Transforms/CollapseInsertionChains.cpp | 20 +-- .../TensorExt/Transforms/InsertRotate.cpp | 42 +++++- lib/Dialect/Utils.h | 37 +++++ 10 files changed, 321 insertions(+), 24 deletions(-) create mode 100644 include/Analysis/TargetSlotAnalysis/BUILD create mode 100644 include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h create mode 100644 lib/Analysis/TargetSlotAnalysis/BUILD create mode 100644 lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp create mode 100644 lib/Dialect/Utils.h diff --git a/include/Analysis/TargetSlotAnalysis/BUILD b/include/Analysis/TargetSlotAnalysis/BUILD new file mode 100644 index 000000000..085e170b2 --- /dev/null +++ b/include/Analysis/TargetSlotAnalysis/BUILD @@ -0,0 +1,9 @@ +# TargetSlotAnalysis analysis pass +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + ["TargetSlotAnalysis.h"], +) diff --git 
a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h new file mode 100644 index 000000000..984fb338c --- /dev/null +++ b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h @@ -0,0 +1,134 @@ +#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ +#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ + +#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace target_slot_analysis { + +/// A target slot is an identification of a downstream tensor index at which an +/// SSA value will be used. To make the previous sentence even mildly +/// comprehensible, consider it in the following example. +/// +/// %c3 = arith.constant 3 : index +/// %c4 = arith.constant 4 : index +/// %c11 = arith.constant 11 : index +/// %c15 = arith.constant 15 : index +/// %v11 = tensor.extract %arg1[%c11] : tensor<16xi32> +/// %v15 = tensor.extract %arg1[%c15] : tensor<16xi32> +/// %1 = arith.addi %v11, %v15: i32 +/// %v3 = tensor.extract %arg1[%c3] : tensor<16xi32> +/// %2 = arith.addi %v3, %1 : i32 +/// %inserted = tensor.insert %2 into %output[%c4] : tensor<16xi32> +/// +/// In vectorized FHE schemes like BGV, the computation model does not +/// efficiently support extracting values at particular indices; instead, it +/// supports SIMD additions of entire vectors, and cyclic rotations of vectors +/// by constant shifts. To optimize the above computation, we want to convert +/// the extractions to rotations, and minimize rotations as much as possible. +/// +/// A naive conversion converts tensor.extract %arg1[Z] to arith.rotate %arg1, +/// Z, always placing the needed values in the zero-th slot.
However, the last +/// line above indicates that the downstream dependencies of these computations +/// are ultimately needed in slot 4 of the %output tensor. So one could reduce +/// the number of rotations by rotating instead to slot 4, so that the final +/// rotation is not needed. +/// +/// This analysis identifies that downstream insertion index, and propagates it +/// backward through the IR to attach it to each SSA value, enabling later +/// optimization passes to access it easily. +/// +/// As it turns out, if the IR is well-structured, such as an unrolled affine +/// for loop with simple iteration strides, then aligning to target slots in +/// this way leads to many common sub-expressions that can be eliminated. Cf. +/// the insert-rotate pass for more on that. + +class TargetSlot { + public: + TargetSlot() : value(std::nullopt) {} + TargetSlot(int64_t value) : value(value) {} + ~TargetSlot() = default; + + /// Whether the slot target is initialized. It can be uninitialized when the + /// state hasn't been set during the analysis. + bool isInitialized() const { return value.has_value(); } + + /// Get a known slot target. + const int64_t &getValue() const { + assert(isInitialized()); + return *value; + } + + bool operator==(const TargetSlot &rhs) const { return value == rhs.value; } + + /// Join two target slots. + static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) { + if (!lhs.isInitialized()) return rhs; + if (!rhs.isInitialized()) return lhs; + // If they are both initialized, use an arbitrary deterministic rule to + // select one. A more sophisticated analysis could try to determine which + // slot is more likely to lead to beneficial optimizations. + return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue() + : rhs.getValue()}; + } + + void print(raw_ostream &os) const { os << value; } + + private: + /// The target slot, if known. 
+ std::optional value; + + friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic, + const TargetSlot &foo) { + if (foo.isInitialized()) { + return diagnostic << foo.getValue(); + } + return diagnostic << "uninitialized"; + } +}; + +inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &v) { + v.print(os); + return os; +} + +class TargetSlotLattice : public dataflow::Lattice { + public: + using Lattice::Lattice; +}; + +/// An analysis that identifies a target slot for an SSA value in a program. +/// This is used by downstream passes to determine how to align rotations in +/// vectorized FHE schemes. +/// +/// We use a backward dataflow analysis because the target slot propagates +/// backward from its final use to the arithmetic operations at which rotations +/// can be optimized. +class TargetSlotAnalysis + : public dataflow::SparseBackwardDataFlowAnalysis { + public: + explicit TargetSlotAnalysis(DataFlowSolver &solver, + SymbolTableCollection &symbolTable) + : SparseBackwardDataFlowAnalysis(solver, symbolTable) {} + ~TargetSlotAnalysis() override = default; + using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + + // Given the computed results of the operation, update its operand lattice + // values. 
+ void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; + + void visitBranchOperand(OpOperand &operand) override{}; + void visitCallOperand(OpOperand &operand) override{}; + void setToExitState(TargetSlotLattice *lattice) override{}; +}; + +} // namespace target_slot_analysis +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_ diff --git a/include/Dialect/BUILD b/include/Dialect/BUILD index bebb71ecf..3eb3ce231 100644 --- a/include/Dialect/BUILD +++ b/include/Dialect/BUILD @@ -10,6 +10,7 @@ package( exports_files( [ "HEIRInterfaces.h", + "Utils.h", ], ) @@ -43,3 +44,18 @@ gentbl_cc_library( ":td_files", ], ) + +cc_library( + name = "Utils", + srcs = [ + "Utils.h", + ], + hdrs = [ + "Utils.h", + ], + deps = [ + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Analysis/TargetSlotAnalysis/BUILD b/lib/Analysis/TargetSlotAnalysis/BUILD new file mode 100644 index 000000000..cf7438e35 --- /dev/null +++ b/lib/Analysis/TargetSlotAnalysis/BUILD @@ -0,0 +1,18 @@ +# TargetSlotAnalysis analysis pass +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "TargetSlotAnalysis", + srcs = ["TargetSlotAnalysis.cpp"], + hdrs = ["@heir//include/Analysis/TargetSlotAnalysis:TargetSlotAnalysis.h"], + deps = [ + "@heir//lib/Dialect:Utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorDialect", + ], +) diff --git a/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp b/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp new file mode 100644 index 000000000..cdb4b41eb --- /dev/null +++ b/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp @@ -0,0 +1,55 @@ +#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h" + +#include "lib/Dialect/Utils.h" 
+#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project + +#define DEBUG_TYPE "target-slot-analysis" + +namespace mlir { +namespace heir { +namespace target_slot_analysis { + +void TargetSlotAnalysis::visitOperation( + Operation *op, ArrayRef operands, + ArrayRef results) { + llvm::TypeSwitch(*op) + .Case([&](auto insertOp) { + LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; }); + auto insertIndexRes = get1DExtractionIndex(insertOp); + // If the target slot can't be statically determined, we can't + // propagate anything through the IR. + if (failed(insertIndexRes)) return; + + // The target slot propagates to the value inserted, which is the first + // positional argument + TargetSlotLattice *lattice = operands[0]; + TargetSlot newSlot = TargetSlot{insertIndexRes.value()}; + LLVM_DEBUG({ + llvm::dbgs() << "Joining " << lattice->getValue() << " and " + << newSlot << " --> " + << TargetSlot::join(lattice->getValue(), newSlot) + << "\n"; + }); + ChangeResult changed = lattice->join(newSlot); + propagateIfChanged(lattice, changed); + }) + .Default([&](Operation &op) { + // By default, an op propagates its result target slots to all its + // operands. 
+ for (const TargetSlotLattice *r : results) { + for (TargetSlotLattice *operand : operands) { + ChangeResult result = operand->join(*r); + propagateIfChanged(operand, result); + } + } + }); +} + +} // namespace target_slot_analysis +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/BUILD b/lib/Dialect/BUILD index fb5356f75..9d031b48e 100644 --- a/lib/Dialect/BUILD +++ b/lib/Dialect/BUILD @@ -18,3 +18,15 @@ cc_library( "@llvm-project//mlir:IR", ], ) + +cc_library( + name = "Utils", + srcs = [ + "Utils.h", + ], + deps = [ + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD index 4258b08ff..bfcc3fad2 100644 --- a/lib/Dialect/TensorExt/Transforms/BUILD +++ b/lib/Dialect/TensorExt/Transforms/BUILD @@ -28,6 +28,7 @@ cc_library( "@heir//include/Dialect/TensorExt/IR:canonicalize_inc_gen", "@heir//include/Dialect/TensorExt/Transforms:insert_rotate_inc_gen", "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Analysis/TargetSlotAnalysis", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", @@ -45,6 +46,7 @@ cc_library( ], deps = [ "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Dialect:Utils", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", diff --git a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp index 63c2a9983..f64337780 100644 --- a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp +++ b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp @@ -4,6 +4,7 @@ #include #include "include/Dialect/TensorExt/IR/TensorExtOps.h" +#include "lib/Dialect/Utils.h" #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include 
"llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project @@ -27,25 +28,6 @@ namespace tensor_ext { #define GEN_PASS_DEF_COLLAPSEINSERTIONCHAINS #include "include/Dialect/TensorExt/Transforms/Passes.h.inc" -template -FailureOr get1DExtractionIndex(Op op) { - auto insertIndices = op.getIndices(); - if (insertIndices.size() != 1) return failure(); - - // Each index must be constant; this may require running --canonicalize or - // -sccp before this pass to apply folding rules (use -sccp if you need to - // fold constants through control flow). - Value insertIndex = *insertIndices.begin(); - auto insertIndexConstOp = insertIndex.getDefiningOp(); - if (!insertIndexConstOp) return failure(); - - auto insertOffsetAttr = - llvm::dyn_cast(insertIndexConstOp.getValue()); - if (!insertOffsetAttr) return failure(); - - return insertOffsetAttr.getInt(); -} - /// A pattern that searches for sequences of extract + insert, where the /// indices extracted and inserted have the same offset, and replaced them with /// a single rotate operation. 
diff --git a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp index c829ba34d..2641792dd 100644 --- a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp +++ b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp @@ -2,14 +2,21 @@ #include +#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h" #include "include/Dialect/TensorExt/IR/TensorExtOps.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#define DEBUG_TYPE "insert-rotate" + namespace mlir { namespace heir { namespace tensor_ext { @@ -33,6 +40,31 @@ struct InsertRotate : impl::InsertRotateBase { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); + SymbolTableCollection symbolTable; + DataFlowSolver solver; + // These two upstream analyses are required dependencies for any sparse + // dataflow analysis, or else the analysis will be a no-op. 
Cf. + // https://github.com/llvm/llvm-project/issues/58922 + solver.load(); + solver.load(); + solver.load(symbolTable); + if (failed(solver.initializeAndRun(getOperation()))) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + return; + } + + LLVM_DEBUG({ + getOperation()->walk([&](Operation *op) { + if (op->getNumResults() == 0) return; + auto *targetSlotLattice = + solver.lookupState( + op->getResult(0)); + llvm::dbgs() << "Target slot for op " << *op << ": " + << targetSlotLattice->getValue() << "\n"; + }); + }); + alignment::populateWithGenerated(patterns); canonicalization::populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/lib/Dialect/Utils.h b/lib/Dialect/Utils.h new file mode 100644 index 000000000..4c25a5e04 --- /dev/null +++ b/lib/Dialect/Utils.h @@ -0,0 +1,37 @@ +#ifndef INCLUDE_DIALECT_UTILS_H_ +#define INCLUDE_DIALECT_UTILS_H_ + +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace heir { + +/// Given a tensor::InsertOp or tensor::ExtractOp, and assuming the shape +/// of the input tensor is 1-dimensional and the input index is constant, +/// return the constant index value. If any of these conditions are not +/// met, return a failure. +template +FailureOr get1DExtractionIndex(Op op) { + auto insertIndices = op.getIndices(); + if (insertIndices.size() != 1) return failure(); + + // Each index must be constant; this may require running --canonicalize or + // -sccp before this pass to apply folding rules (use -sccp if you need to + // fold constants through control flow). 
+ Value insertIndex = *insertIndices.begin(); + auto insertIndexConstOp = insertIndex.getDefiningOp(); + if (!insertIndexConstOp) return failure(); + + auto insertOffsetAttr = + llvm::dyn_cast(insertIndexConstOp.getValue()); + if (!insertOffsetAttr) return failure(); + + return insertOffsetAttr.getInt(); +} + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_DIALECT_UTILS_H_ From b0a1a9a909a5b6b406e2ff9da59efe6d630d53c9 Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Fri, 15 Mar 2024 13:56:58 -0700 Subject: [PATCH 14/16] yosys: add boolean gate yosys optimization for arith -> comb Fixes https://github.com/google/heir/issues/281 PiperOrigin-RevId: 616234946 --- include/Dialect/Comb/IR/Combinational.td | 12 + .../YosysOptimizer/YosysOptimizer.h | 16 +- .../YosysOptimizer/YosysOptimizer.td | 1 + lib/Dialect/Comb/IR/CombOps.cpp | 6 + lib/Transforms/YosysOptimizer/BUILD | 17 ++ .../YosysOptimizer/BooleanGateImporter.cpp | 58 ++++ .../YosysOptimizer/BooleanGateImporter.h | 37 +++ lib/Transforms/YosysOptimizer/LUTImporter.cpp | 12 +- lib/Transforms/YosysOptimizer/LUTImporter.h | 10 +- .../YosysOptimizer/RTLILImporter.cpp | 2 +- lib/Transforms/YosysOptimizer/RTLILImporter.h | 14 +- .../YosysOptimizer/YosysOptimizer.cpp | 70 +++-- lib/Transforms/YosysOptimizer/yosys/BUILD | 1 + .../yosys/tfhe-rs_cells.liberty | 260 ++++++++++++++++++ tests/yosys_optimizer/add_one.mlir | 4 +- 15 files changed, 484 insertions(+), 36 deletions(-) create mode 100644 lib/Transforms/YosysOptimizer/BooleanGateImporter.cpp create mode 100644 lib/Transforms/YosysOptimizer/BooleanGateImporter.h create mode 100644 lib/Transforms/YosysOptimizer/yosys/tfhe-rs_cells.liberty diff --git a/include/Dialect/Comb/IR/Combinational.td b/include/Dialect/Comb/IR/Combinational.td index faa8bd2a0..639da10a0 100644 --- a/include/Dialect/Comb/IR/Combinational.td +++ b/include/Dialect/Comb/IR/Combinational.td @@ -80,6 +80,9 @@ def XorOp : UTVariadicOp<"xor", [Commutative]> { bool isBinaryNot(); }]; } +def 
XNorOp : UTVariadicOp<"xnor">; +def NandOp : UTVariadicOp<"nand">; +def NorOp : UTVariadicOp<"nor">; //===----------------------------------------------------------------------===// // Comparisons @@ -154,6 +157,15 @@ def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> { // Unary Operations //===----------------------------------------------------------------------===// +class UnaryOp traits = []> : + CombOp { + let arguments = (ins HWIntegerType:$input, UnitAttr:$twoState); + let results = (outs HWIntegerType:$result); + + let assemblyFormat = "(`bin` $twoState^)? $input attr-dict `:` qualified(type($input))"; +} +def InvOp : UnaryOp<"inv">; + // Base class for unary reduction operations that produce an i1. class UnaryI1ReductionOp traits = []> : CombOp { diff --git a/include/Transforms/YosysOptimizer/YosysOptimizer.h b/include/Transforms/YosysOptimizer/YosysOptimizer.h index 17b9939b2..c7339eea0 100644 --- a/include/Transforms/YosysOptimizer/YosysOptimizer.h +++ b/include/Transforms/YosysOptimizer/YosysOptimizer.h @@ -1,14 +1,19 @@ #ifndef INCLUDE_TRANSFORMS_YOSYSOPTIMIZER_YOSYSOPTIMIZER_H_ #define INCLUDE_TRANSFORMS_YOSYSOPTIMIZER_YOSYSOPTIMIZER_H_ -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project +#include + +#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project namespace mlir { namespace heir { +enum Mode { Boolean, LUT }; + std::unique_ptr createYosysOptimizer( const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast, - int unrollFactor = 0, bool printStats = false); + int unrollFactor = 0, Mode mode = LUT, bool printStats = false); #define GEN_PASS_DECL #include "include/Transforms/YosysOptimizer/YosysOptimizer.h.inc" @@ -25,6 +30,13 @@ struct YosysOptimizerPipelineOptions "value of zero (default) prevents unrolling."), llvm::cl::init(0)}; + PassOptions::Option mode{ + *this, "mode", + llvm::cl::desc("Map gates to boolean gates or lookup 
table gates."), + llvm::cl::init(LUT), + llvm::cl::values(clEnumVal(Boolean, "use boolean gates"), + clEnumVal(LUT, "use lookup tables"))}; + PassOptions::Option printStats{ *this, "print-stats", llvm::cl::desc("Prints statistics about the optimized circuit"), diff --git a/include/Transforms/YosysOptimizer/YosysOptimizer.td b/include/Transforms/YosysOptimizer/YosysOptimizer.td index a4e52f67d..9bfcbb4de 100644 --- a/include/Transforms/YosysOptimizer/YosysOptimizer.td +++ b/include/Transforms/YosysOptimizer/YosysOptimizer.td @@ -24,6 +24,7 @@ def YosysOptimizer : Pass<"yosys-optimizer"> { - `unroll-factor`: Before optimizing the circuit, unroll loops by a given factor. If unset, this pass will not unroll any loops. - `print-stats`: Prints statistics about the optimized circuits. + - `mode={Boolean,LUT}`: Map gates to boolean gates or lookup table gates. }]; // TODO(#257): add option for the pass to select the unroll factor // automatically. diff --git a/lib/Dialect/Comb/IR/CombOps.cpp b/lib/Dialect/Comb/IR/CombOps.cpp index d0024c25f..0f77c8485 100644 --- a/lib/Dialect/Comb/IR/CombOps.cpp +++ b/lib/Dialect/Comb/IR/CombOps.cpp @@ -160,6 +160,12 @@ LogicalResult OrOp::verify() { return verifyUTBinOp(*this); } LogicalResult XorOp::verify() { return verifyUTBinOp(*this); } +LogicalResult XNorOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult NandOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult NorOp::verify() { return verifyUTBinOp(*this); } + //===----------------------------------------------------------------------===// // ConcatOp //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/YosysOptimizer/BUILD b/lib/Transforms/YosysOptimizer/BUILD index dc46f338d..7b9926124 100644 --- a/lib/Transforms/YosysOptimizer/BUILD +++ b/lib/Transforms/YosysOptimizer/BUILD @@ -62,6 +62,22 @@ cc_test( ], ) +cc_library( + name = "BooleanGateImporter", + srcs = ["BooleanGateImporter.cpp"], + hdrs = 
["BooleanGateImporter.h"], + deps = [ + ":RTLILImporter", + "@at_clifford_yosys//:kernel", + "@heir//lib/Dialect/Comb/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "YosysOptimizer", srcs = ["YosysOptimizer.cpp"], @@ -73,6 +89,7 @@ cc_library( "@heir//lib/Transforms/YosysOptimizer/yosys:share_files", ], deps = [ + ":BooleanGateImporter", ":LUTImporter", ":RTLILImporter", "@at_clifford_yosys//:kernel", diff --git a/lib/Transforms/YosysOptimizer/BooleanGateImporter.cpp b/lib/Transforms/YosysOptimizer/BooleanGateImporter.cpp new file mode 100644 index 000000000..6d2613645 --- /dev/null +++ b/lib/Transforms/YosysOptimizer/BooleanGateImporter.cpp @@ -0,0 +1,58 @@ +#include "lib/Transforms/YosysOptimizer/BooleanGateImporter.h" + +#include + +#include "include/Dialect/Comb/IR/CombOps.h" +#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project + +// Block clang-format from reordering +// clang-format off +#include "kernel/rtlil.h" // from @at_clifford_yosys +// clang-format on + +namespace mlir { +namespace heir { + +mlir::Operation *BooleanGateImporter::createOp(Yosys::RTLIL::Cell *cell, + SmallVector &inputs, + ImplicitLocOpBuilder &b) const { + auto op = llvm::StringSwitch(cell->type.substr(1)) + .Case("inv", b.create(inputs[0], false)) + .Case("xnor2", b.create(inputs, false)) + .Case("and2", b.create(inputs, false)) + .Case("xor2", b.create(inputs, false)) + .Case("nand2", b.create(inputs, false)) + .Case("nor2", b.create(inputs, false)) + .Case("or2", b.create(inputs, false)) + .Default(nullptr); + if (op == nullptr) 
{ + llvm_unreachable("unexpected cell type"); + } + return op; +} + +SmallVector BooleanGateImporter::getInputs( + Yosys::RTLIL::Cell *cell) const { + // Return all non-Y named attributes. + SmallVector inputs; + for (auto &conn : cell->connections()) { + if (conn.first.contains("Y")) { + continue; + } + inputs.push_back(conn.second); + } + + return inputs; +} + +Yosys::RTLIL::SigSpec BooleanGateImporter::getOutput( + Yosys::RTLIL::Cell *cell) const { + return cell->getPort(Yosys::RTLIL::IdString("\\Y")); +} + +} // namespace heir +} // namespace mlir diff --git a/lib/Transforms/YosysOptimizer/BooleanGateImporter.h b/lib/Transforms/YosysOptimizer/BooleanGateImporter.h new file mode 100644 index 000000000..b527a397b --- /dev/null +++ b/lib/Transforms/YosysOptimizer/BooleanGateImporter.h @@ -0,0 +1,37 @@ +#ifndef THIRD_PARTY_HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_BOOLEANGATEIMPORTER_H_ +#define THIRD_PARTY_HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_BOOLEANGATEIMPORTER_H_ + +#include "lib/Transforms/YosysOptimizer/RTLILImporter.h" +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project + +// Block clang-format from reordering +// clang-format off +#include "kernel/rtlil.h" // from @at_clifford_yosys +// clang-format on +namespace mlir { +namespace heir { + +// BooleanGateImporter implements the RTLILConfig for importing RTLIL that uses +// boolean gates. 
+class BooleanGateImporter : public RTLILImporter { + public: + BooleanGateImporter(MLIRContext *context) : RTLILImporter(context) {} + + protected: + Operation *createOp(Yosys::RTLIL::Cell *cell, SmallVector &inputs, + ImplicitLocOpBuilder &b) const override; + + SmallVector getInputs( + Yosys::RTLIL::Cell *cell) const override; + + Yosys::RTLIL::SigSpec getOutput(Yosys::RTLIL::Cell *cell) const override; +}; + +} // namespace heir +} // namespace mlir + +#endif // THIRD_PARTY_HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_BOOLEANGATEIMPORTER_H_ diff --git a/lib/Transforms/YosysOptimizer/LUTImporter.cpp b/lib/Transforms/YosysOptimizer/LUTImporter.cpp index b4e9fb572..511e457b2 100644 --- a/lib/Transforms/YosysOptimizer/LUTImporter.cpp +++ b/lib/Transforms/YosysOptimizer/LUTImporter.cpp @@ -3,19 +3,23 @@ #include #include "include/Dialect/Comb/IR/CombOps.h" -#include "kernel/rtlil.h" // from @at_clifford_yosys -#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project +#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +// Block clang-format from reordering +// clang-format off +#include "kernel/rtlil.h" // from @at_clifford_yosys +// clang-format on + namespace mlir { namespace heir { mlir::Operation *LUTImporter::createOp(Yosys::RTLIL::Cell *cell, - SmallVector &inputs, + SmallVector &inputs, ImplicitLocOpBuilder &b) const { assert(cell->type.begins_with("\\lut")); @@ -36,7 +40,7 @@ mlir::Operation *LUTImporter::createOp(Yosys::RTLIL::Cell *cell, return b.create(inputs, lookupTable); } -SmallVector LUTImporter::getInputs( +SmallVector LUTImporter::getInputs( Yosys::RTLIL::Cell *cell) const { 
assert(cell->type.begins_with("\\lut") && "expected lut cells"); diff --git a/lib/Transforms/YosysOptimizer/LUTImporter.h b/lib/Transforms/YosysOptimizer/LUTImporter.h index bf5eef253..b829bf8fa 100644 --- a/lib/Transforms/YosysOptimizer/LUTImporter.h +++ b/lib/Transforms/YosysOptimizer/LUTImporter.h @@ -1,7 +1,6 @@ #ifndef HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_LUTIMPORTER_H_ #define HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_LUTIMPORTER_H_ -#include "kernel/rtlil.h" // from @at_clifford_yosys #include "lib/Transforms/YosysOptimizer/RTLILImporter.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project @@ -9,6 +8,11 @@ #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +// Block clang-format from reordering +// clang-format off +#include "kernel/rtlil.h" // from @at_clifford_yosys +// clang-format on + namespace mlir { namespace heir { @@ -18,10 +22,10 @@ class LUTImporter : public RTLILImporter { LUTImporter(MLIRContext *context) : RTLILImporter(context) {} protected: - Operation *createOp(Yosys::RTLIL::Cell *cell, SmallVector &inputs, + Operation *createOp(Yosys::RTLIL::Cell *cell, SmallVector &inputs, ImplicitLocOpBuilder &b) const override; - SmallVector getInputs( + SmallVector getInputs( Yosys::RTLIL::Cell *cell) const override; Yosys::RTLIL::SigSpec getOutput(Yosys::RTLIL::Cell *cell) const override; diff --git a/lib/Transforms/YosysOptimizer/RTLILImporter.cpp b/lib/Transforms/YosysOptimizer/RTLILImporter.cpp index 57dd1a0c8..9705a4662 100644 --- a/lib/Transforms/YosysOptimizer/RTLILImporter.cpp +++ b/lib/Transforms/YosysOptimizer/RTLILImporter.cpp @@ -185,7 +185,7 @@ func::FuncOp RTLILImporter::importModule( "expected cell in RTLIL design"); auto *cell = module->cells_["\\" + cellName]; - SmallVector inputValues; + SmallVector inputValues; for (const auto &conn : getInputs(cell)) { 
inputValues.push_back(getBit(conn, b, retBitValues)); } diff --git a/lib/Transforms/YosysOptimizer/RTLILImporter.h b/lib/Transforms/YosysOptimizer/RTLILImporter.h index 12a1a9c85..349439642 100644 --- a/lib/Transforms/YosysOptimizer/RTLILImporter.h +++ b/lib/Transforms/YosysOptimizer/RTLILImporter.h @@ -1,9 +1,8 @@ #ifndef HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_RTLILIMPORTER_H_ #define HEIR_LIB_TRANSFORMS_YOSYSOPTIMIZER_RTLILIMPORTER_H_ -#include "kernel/rtlil.h" // from @at_clifford_yosys -#include "llvm/include/llvm/ADT/MapVector.h" // from @llvm-project -#include "llvm/include/llvm/ADT/StringMap.h" // from @llvm-project +#include "llvm/include/llvm/ADT/MapVector.h" // from @llvm-project +#include "llvm/include/llvm/ADT/StringMap.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project @@ -11,6 +10,11 @@ #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +// Block clang-format from reordering +// clang-format off +#include "kernel/rtlil.h" // from @at_clifford_yosys +// clang-format on + namespace mlir { namespace heir { @@ -41,11 +45,11 @@ class RTLILImporter { protected: // cellToOp converts an RTLIL cell to an MLIR operation. virtual Operation *createOp(Yosys::RTLIL::Cell *cell, - SmallVector &inputs, + SmallVector &inputs, ImplicitLocOpBuilder &b) const = 0; // Returns a list of RTLIL cell inputs. - virtual SmallVector getInputs( + virtual SmallVector getInputs( Yosys::RTLIL::Cell *cell) const = 0; // Returns an RTLIL cell output. 
diff --git a/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp b/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp index 1addfcead..0a92b50ae 100644 --- a/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp +++ b/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp @@ -16,6 +16,7 @@ #include "include/Dialect/Secret/IR/SecretPatterns.h" #include "include/Dialect/Secret/IR/SecretTypes.h" #include "include/Target/Verilog/VerilogEmitter.h" +#include "lib/Transforms/YosysOptimizer/BooleanGateImporter.h" #include "lib/Transforms/YosysOptimizer/LUTImporter.h" #include "lib/Transforms/YosysOptimizer/RTLILImporter.h" #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project @@ -66,7 +67,7 @@ using std::string; // $2: yosys runfiles // $3: abc path // $4: abc fast option -fast -constexpr std::string_view kYosysTemplate = R"( +constexpr std::string_view kYosysLutTemplate = R"( read_verilog {0}; hierarchy -check -top \{1}; proc; memory; stat; @@ -81,6 +82,24 @@ clean; stat; )"; +// $0: verilog filename +// $1: function name +// $2: abc path +// $3: yosys runfiles path +// $4: abc fast option -fast +constexpr std::string_view kYosysBooleanTemplate = R"( +read_verilog {0}; +hierarchy -check -top \{1}; +proc; memory; stat; +techmap -map {3}/techmap.v; opt; stat; +abc -exe {2} -liberty {3}/tfhe-rs_cells.liberty {4}; stat; +opt_clean -purge; stat; +rename -hide */c:*; rename -enumerate */c:*; +hierarchy -generate * o:Y i:*; opt; opt_clean -purge; +clean; +stat; +)"; + struct RelativeOptimizationStatistics { std::string originalOp; int64_t numArithOps; @@ -91,17 +110,13 @@ struct YosysOptimizer : public impl::YosysOptimizerBase { using YosysOptimizerBase::YosysOptimizerBase; YosysOptimizer(std::string yosysFilesPath, std::string abcPath, bool abcFast, - int unrollFactor, bool printStats) + int unrollFactor, Mode mode, bool printStats) : yosysFilesPath(std::move(yosysFilesPath)), abcPath(std::move(abcPath)), abcFast(abcFast), printStats(printStats), - unrollFactor(unrollFactor) {} 
- - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } + unrollFactor(unrollFactor), + mode(mode) {} void runOnOperation() override; @@ -116,6 +131,7 @@ struct YosysOptimizer : public impl::YosysOptimizerBase { bool abcFast; bool printStats; int unrollFactor; + Mode mode; llvm::SmallVector optStatistics; }; @@ -183,7 +199,6 @@ LogicalResult convertOpOperands(secret::GenericOp op, func::FuncOp func, /// Convert a secret.generic's results from secret.secret> /// to secret.secret. -// genericOp has the original, func op has the memref's yosys optimized LogicalResult convertOpResults(secret::GenericOp op, SmallVector originalResultTy, DenseSet &castOps, @@ -195,7 +210,6 @@ LogicalResult convertOpResults(secret::GenericOp op, opResult.getType().cast(); IntegerType elementType; - int numElements = 1; if (MemRefType convertedType = dyn_cast(secretType.getValueType())) { if (!convertedType.getElementType().isa() || @@ -206,7 +220,6 @@ LogicalResult convertOpResults(secret::GenericOp op, return failure(); } elementType = convertedType.getElementType().cast(); - numElements = convertedType.getNumElements(); } else { elementType = secretType.getValueType().cast(); } @@ -388,9 +401,22 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) { // Invoke Yosys to translate to a combinational circuit and optimize. Yosys::log_error_stderr = true; LLVM_DEBUG(Yosys::log_streams.push_back(&std::cout)); - Yosys::run_pass(llvm::formatv(kYosysTemplate.data(), filename, moduleName, - yosysFilesPath, abcPath, - abcFast ? "-fast" : "")); + + LLVM_DEBUG( + llvm::dbgs() << "Using " + << (mode == Mode::LUT ? "LUT cells" : "boolean gates")); + auto yosysTemplate = + llvm::formatv(kYosysLutTemplate.data(), filename, moduleName, + yosysFilesPath, abcPath, abcFast ? 
"-fast" : "") + .str(); + if (mode == Mode::Boolean) { + std::cout << yosysFilesPath << std::endl; + yosysTemplate = + llvm::formatv(kYosysBooleanTemplate.data(), filename, moduleName, + abcPath, yosysFilesPath, abcFast ? "-fast" : "") + .str(); + } + Yosys::run_pass(yosysTemplate); // Translate Yosys result back to MLIR and insert into the func LLVM_DEBUG(Yosys::run_pass("dump;")); @@ -399,7 +425,6 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) { Yosys::run_pass("torder -stop * P*;"); Yosys::log_streams.clear(); auto topologicalOrder = getTopologicalOrder(cellOrder); - LUTImporter lutImporter = LUTImporter(context); Yosys::RTLIL::Design *design = Yosys::yosys_get_design(); auto numCells = design->top_module()->cells().size(); totalCircuitSize += numCells; @@ -408,9 +433,14 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) { } LLVM_DEBUG(llvm::dbgs() << "Importing RTLIL module\n"); - + std::unique_ptr importer; + if (mode == Mode::LUT) { + importer = std::make_unique(context); + } else { + importer = std::make_unique(context); + } func::FuncOp func = - lutImporter.importModule(design->top_module(), topologicalOrder); + importer->importModule(design->top_module(), topologicalOrder); Yosys::run_pass("delete;"); LLVM_DEBUG(llvm::dbgs() << "Done importing RTLIL, now type-coverting ops\n"); @@ -554,9 +584,9 @@ void YosysOptimizer::runOnOperation() { std::unique_ptr createYosysOptimizer( const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast, - int unrollFactor, bool printStats) { + int unrollFactor, Mode mode, bool printStats) { return std::make_unique(yosysFilesPath, abcPath, abcFast, - unrollFactor, printStats); + unrollFactor, mode, printStats); } void registerYosysOptimizerPipeline(const std::string &yosysFilesPath, @@ -567,7 +597,7 @@ void registerYosysOptimizerPipeline(const std::string &yosysFilesPath, const YosysOptimizerPipelineOptions &options) { pm.addPass(createYosysOptimizer(yosysFilesPath, 
abcPath, options.abcFast, options.unrollFactor, - options.printStats)); + options.mode, options.printStats)); pm.addPass(mlir::createCSEPass()); }); } diff --git a/lib/Transforms/YosysOptimizer/yosys/BUILD b/lib/Transforms/YosysOptimizer/yosys/BUILD index e89c5930c..f12d811d9 100644 --- a/lib/Transforms/YosysOptimizer/yosys/BUILD +++ b/lib/Transforms/YosysOptimizer/yosys/BUILD @@ -9,5 +9,6 @@ filegroup( name = "share_files", srcs = glob([ "*.v", + "*.liberty", ]), ) diff --git a/lib/Transforms/YosysOptimizer/yosys/tfhe-rs_cells.liberty b/lib/Transforms/YosysOptimizer/yosys/tfhe-rs_cells.liberty new file mode 100644 index 000000000..4fa43fca3 --- /dev/null +++ b/lib/Transforms/YosysOptimizer/yosys/tfhe-rs_cells.liberty @@ -0,0 +1,260 @@ +/********************************************/ +/* */ +/* Supergate cell library for Bench marking */ +/* */ +/* Symbiotic EDA GmbH / Moseley Instruments */ +/* Niels A. Moseley */ +/* */ +/* Process: none */ +/* */ +/* Date : 02-11-2018 */ +/* Version: 1.0 */ +/* */ +/********************************************/ + +library(supergate) { + delay_model : table_lookup; + time_unit : "1ns"; + + /* Inverter */ + cell(inv) { + area : 0; + pin(A) { + direction : input; + } + + pin(Y) { + direction : output; + function : "A'"; + timing() { + related_pin : "A"; + timing_sense : negative_unate; + cell_rise(scalar) { + values("0.0"); + } + cell_fall(scalar) { + values("0.0"); + } + rise_transition(scalar) { + values("0.0"); + } + fall_transition(scalar) { + values("0.0"); + } + } + } + } + + cell(buffer) { + area : 0; + pin(A) { + direction : input; + } + pin(Y) { + direction : output; + function : "A"; + timing() { + related_pin : "A"; + timing_sense : positive_unate; + cell_rise(scalar) { + values("0.0"); + } + cell_fall(scalar) { + values("0.0"); + } + rise_transition(scalar) { + values("0.0"); + } + fall_transition(scalar) { + values("0.0"); + } + } + } + } + + /* 2-input AND gate */ + cell(and2) { + area : 100; + pin(A) { + direction : 
input; + } + pin(B) { + direction : input; + } + pin(Y) { + direction: output; + function : "(A * B)"; + timing() { + related_pin : "A B"; + timing_sense : positive_unate; + cell_rise(scalar) { + values("1000.0"); + } + cell_fall(scalar) { + values("1000.0"); + } + rise_transition(scalar) { + values("1000.0"); + } + fall_transition(scalar) { + values("1000.0"); + } + } + } + } + + /* 2-input NAND gate */ + cell(nand2) { + area : 100; + pin(A) { + direction : input; + } + pin(B) { + direction : input; + } + pin(Y) { + direction: output; + function : "(A * B)'"; + timing() { + related_pin : "A B"; + timing_sense : negative_unate; + cell_rise(scalar) { + values("1000.0"); + } + cell_fall(scalar) { + values("1000.0"); + } + rise_transition(scalar) { + values("1000.0"); + } + fall_transition(scalar) { + values("1000.0"); + } + } + } + } + + /* 2-input OR gate */ + cell(or2) { + area : 100; + pin(A) { + direction : input; + } + pin(B) { + direction : input; + } + pin(Y) { + direction: output; + function : "(A + B)"; + timing() { + related_pin : "A B"; + timing_sense : positive_unate; + cell_rise(scalar) { + values("1000.0"); + } + cell_fall(scalar) { + values("1000.0"); + } + rise_transition(scalar) { + values("1000.0"); + } + fall_transition(scalar) { + values("1000.0"); + } + } + } + } + + /* 2-input NOR gate */ + cell(nor2) { + area : 100; + pin(A) { + direction : input; + } + pin(B) { + direction : input; + } + pin(Y) { + direction: output; + function : "(A + B)'"; + timing() { + related_pin : "A B"; + timing_sense : negative_unate; + cell_rise(scalar) { + values("1000.0"); + } + cell_fall(scalar) { + values("1000.0"); + } + rise_transition(scalar) { + values("1000.0"); + } + fall_transition(scalar) { + values("1000.0"); + } + } + } + } + + /* 2-input XOR */ + cell(xor2) { + area : 100; + pin(A) { + direction : input; + } + pin(B) { + direction : input; + } + pin(Y) { + direction: output; + function : "(A * (B')) + ((A') * B)"; + timing() { + related_pin : "A B"; + 
timing_sense : non_unate; + cell_rise(scalar) { + values("1000.0"); + } + cell_fall(scalar) { + values("1000.0"); + } + rise_transition(scalar) { + values("1000.0"); + } + fall_transition(scalar) { + values("1000.0"); + } + } + } + } + + /* 2-input XNOR */ + cell(xnor2) { + area : 100; + pin(A) { + direction : input; + } + pin(B) { + direction : input; + } + pin(Y) { + direction: output; + function : "((A * (B')) + ((A') * B))'"; + timing() { + related_pin : "A B"; + timing_sense : non_unate; + cell_rise(scalar) { + values("1000.0"); + } + cell_fall(scalar) { + values("1000.0"); + } + rise_transition(scalar) { + values("1000.0"); + } + fall_transition(scalar) { + values("1000.0"); + } + } + } + } +} /* end */ diff --git a/tests/yosys_optimizer/add_one.mlir b/tests/yosys_optimizer/add_one.mlir index 73bb8a60b..cd0cf3522 100644 --- a/tests/yosys_optimizer/add_one.mlir +++ b/tests/yosys_optimizer/add_one.mlir @@ -1,4 +1,5 @@ -// RUN: heir-opt --yosys-optimizer %s | FileCheck %s +// RUN: heir-opt --yosys-optimizer --canonicalize --cse %s | FileCheck %s +// RUN: heir-opt --yosys-optimizer="mode=Boolean" --canonicalize --cse %s | FileCheck --check-prefix=CHECK --check-prefix=BOOL %s module { // CHECK-LABEL: @add_one @@ -13,6 +14,7 @@ module { ins(%in, %one: !secret.secret, i8) { ^bb0(%IN: i8, %ONE: i8) : // CHECK-NOT: arith.addi + // BOOL-COUNT-7: comb.inv %2 = arith.addi %IN, %ONE : i8 secret.yield %2 : i8 } -> (!secret.secret) From bfa4527ad2e9d41208cf9bf62adf4c397a2d7a55 Mon Sep 17 00:00:00 2001 From: HEIR Team Date: Sat, 16 Mar 2024 10:21:21 -0700 Subject: [PATCH 15/16] Integrate LLVM at llvm/llvm-project@a4ca07f13b56 Updates LLVM usage to match [a4ca07f13b56](https://github.com/llvm/llvm-project/commit/a4ca07f13b56) PiperOrigin-RevId: 616433576 --- bazel/import_llvm.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bazel/import_llvm.bzl b/bazel/import_llvm.bzl index 37ac5499e..9503b1690 100644 --- a/bazel/import_llvm.bzl +++ 
b/bazel/import_llvm.bzl @@ -7,7 +7,7 @@ load( def import_llvm(name): """Imports LLVM.""" - LLVM_COMMIT = "3b5e7c83a6e226d5bd7ed2e9b67449b64812074c" + LLVM_COMMIT = "a4ca07f13b560b4f6fa5459eef7159e4f9ee9a6b" new_git_repository( name = name, From 3c818ba8907716b4df55ed06d1f6d38934651e39 Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Tue, 12 Mar 2024 17:05:11 +0000 Subject: [PATCH 16/16] Updating the tests for e2e tfhe-rs-bool and starting with the tfhe-rs-bool fpga tests --- .gitignore | 2 +- .../TargetSlotAnalysis/TargetSlotAnalysis.h | 6 +- .../TfheRustBool/IR/TfheRustBoolOps.td | 11 +++ tests/cggi_to_tfhe_rust_bool/add_bool.mlir | 2 - .../cggi_to_tfhe_rust_bool/add_one_bool.mlir | 2 - tests/tfhe_rust_bool/end_to_end/BUILD | 1 + tests/tfhe_rust_bool/end_to_end/Cargo.toml | 4 + .../end_to_end/src/main_bool_add.rs | 51 ++++++++++++ .../end_to_end/test_bool_add.mlir | 69 +++++++++++++++ tests/tfhe_rust_bool/end_to_end_fpga/BUILD | 23 +++++ .../tfhe_rust_bool/end_to_end_fpga/Cargo.toml | 21 +++++ .../tfhe_rust_bool/end_to_end_fpga/README.md | 44 ++++++++++ .../end_to_end_fpga/src/main.rs | 83 +++++++++++++++++++ .../end_to_end_fpga/test_add_one_bool.mlir | 73 ++++++++++++++++ .../end_to_end_fpga/test_packed_and.mlir | 13 +++ 15 files changed, 397 insertions(+), 8 deletions(-) create mode 100644 tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs create mode 100644 tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/BUILD create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/README.md create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir create mode 100644 tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir diff --git a/.gitignore b/.gitignore index 5cee65586..318187095 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,7 @@ venv # for rust 
codegen tests **/Cargo.lock tests/**/**/target/ -tests/tfhe_rust_bool/end_to_end_fpga/ +tests/tfhe_rust_bool/end_to_end_fpga/tfhe-rs # vscode .vscode/** diff --git a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h index cc7487a98..984fb338c 100644 --- a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h +++ b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h @@ -122,9 +122,9 @@ class TargetSlotAnalysis void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override; - void visitBranchOperand(OpOperand &operand) override {}; - void visitCallOperand(OpOperand &operand) override {}; - void setToExitState(TargetSlotLattice *lattice) override {}; + void visitBranchOperand(OpOperand &operand) override{}; + void visitCallOperand(OpOperand &operand) override{}; + void setToExitState(TargetSlotLattice *lattice) override{}; }; } // namespace target_slot_analysis diff --git a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td index c711f3d93..92dd1033f 100644 --- a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td +++ b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td @@ -56,6 +56,17 @@ def AndPackedOp : TfheRustBool_Op<"and_packed", [ let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); } +def XorPackedOp : TfheRustBool_Op<"xor_packed", [ + Pure, + AllTypesMatch<["lhs", "rhs", "output"]> +]> { + let arguments = (ins + TfheRustBool_ServerKey:$serverKey, + TensorOf<[TfheRustBool_Encrypted]>:$lhs, + TensorOf<[TfheRustBool_Encrypted]>:$rhs + ); + let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); +} def NotOp : TfheRustBool_Op<"not", [ Pure, diff --git a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir index 646c68e81..034d04e8c 100644 --- a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir +++ b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir @@ -1,11 +1,9 @@ // RUN: 
heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext - // CHECK-LABEL: add_bool // CHECK-NOT: cggi // CHECK-NOT: lwe diff --git a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir index 8a6e62cd8..6d4cc5fb0 100644 --- a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir +++ b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir @@ -1,11 +1,9 @@ // RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext - // CHECK-LABEL: add_one_bool // CHECK-NOT: cggi // CHECK-NOT: lwe diff --git a/tests/tfhe_rust_bool/end_to_end/BUILD b/tests/tfhe_rust_bool/end_to_end/BUILD index a189be648..3c47de4e2 100644 --- a/tests/tfhe_rust_bool/end_to_end/BUILD +++ b/tests/tfhe_rust_bool/end_to_end/BUILD @@ -12,6 +12,7 @@ glob_lit_tests( data = [ "Cargo.toml", "src/main.rs", + "src/main_bool_add.rs", "@heir//tests:test_utilities", ], default_tags = [ diff --git a/tests/tfhe_rust_bool/end_to_end/Cargo.toml b/tests/tfhe_rust_bool/end_to_end/Cargo.toml index 975466a91..6e6fa0a3c 100644 --- a/tests/tfhe_rust_bool/end_to_end/Cargo.toml +++ b/tests/tfhe_rust_bool/end_to_end/Cargo.toml @@ -12,3 +12,7 @@ tfhe = { version = "0.4.1", features = ["boolean", "x86_64-unix"] } [[bin]] name = "main" path = "src/main.rs" + +[[bin]] +name = "main_bool_add" +path = "src/main_bool_add.rs" diff --git a/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs new file mode 100644 index 000000000..02f61b2ef --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end/src/main_bool_add.rs @@ -0,0 +1,51 @@ +use clap::Parser; +use tfhe::boolean::prelude::*; + +mod fn_under_test; + +// TODO(https://github.com/google/heir/issues/235): improve generality +#[derive(Parser, Debug)] 
+struct Args { + /// arguments to forward to function under test + #[arg(id = "input_1", index = 1, action)] + input1: u8, + + #[arg(id = "input_2", index = 2, action)] + input2: u8, +} + +// Encrypt a u8 +pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec { + let arr: [u8; 8] = core::array::from_fn(|shift| (value >> shift) & 1 ); + + let res: Vec = arr.iter() + .map(|bit| client_key.encrypt(if *bit != 0u8 { true } else { false })) + .collect(); + res +} + +// Decrypt a u8 +pub fn decrypt(ciphertexts: &Vec, client_key: &ClientKey) -> u8 { + let mut accum = 0u8; + for (i, ct) in ciphertexts.iter().enumerate() { + let bit = client_key.decrypt(ct); + accum |= (bit as u8) << i; + } + accum.reverse_bits() + +} + +fn main() { + let flags = Args::parse(); + let (client_key, server_key) = tfhe::boolean::gen_keys(); + + let ct_1 = encrypt(flags.input1.into(), &client_key); + let ct_2 = encrypt(flags.input2.into(), &client_key); + + + let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); + + let output = decrypt(&result, &client_key); + + println!("{:08b}", output); +} diff --git a/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir b/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir new file mode 100644 index 000000000..3fe76bfd0 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end/test_bool_add.mlir @@ -0,0 +1,69 @@ +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_bool_add -- 15 3 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK: 00010010 +func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = 
tensor.extract %arg0[%c0] : tensor<8x!eb> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!eb> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb> + %ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb + %fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb + %fa2_1 = tfhe_rust_bool.xor 
%bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_s = tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb + %fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb + %fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb + %fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb + %fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb + %fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb> + return %from_elements : tensor<8x!eb> +} diff 
--git a/tests/tfhe_rust_bool/end_to_end_fpga/BUILD b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD new file mode 100644 index 000000000..a189be648 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/BUILD @@ -0,0 +1,23 @@ +# See README.md for setup required to run these tests + +load("//bazel:lit.bzl", "glob_lit_tests") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +glob_lit_tests( + name = "all_tests", + data = [ + "Cargo.toml", + "src/main.rs", + "@heir//tests:test_utilities", + ], + default_tags = [ + "manual", + "notap", + ], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml b/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml new file mode 100644 index 000000000..07a449700 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "heir-tfhe-rust-integration-test" +version = "0.1.0" +edition = "2021" +default-run = "main" + +[dependencies] +clap = { version = "4.1.8", features = ["derive"] } +rayon = "1.6.1" +serde = { version = "1.0.152", features = ["derive"] } +tfhe = { path = "tfhe-rs/tfhe", features = [ + "boolean", + "x86_64-unix", +] } + +[features] +fpga = ["tfhe/fpga"] + +[[bin]] +name = "main" +path = "src/main.rs" diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md new file mode 100644 index 000000000..42be8b572 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md @@ -0,0 +1,44 @@ +# End to end Rust codegen tests - Boolean FPGA + +These tests exercise Rust codegen for the +[tfhe-rs](https://github.com/zama-ai/tfhe-rs) backend library, including +compiling the generated Rust source and running the resulting binary. 
This set of +tests specifically targets boolean plaintexts, the accompanying COSIC-KU Leuven version of the library, +and the [FPT-accelerator](https://eprint.iacr.org/2022/1635). + +> :warning: These tests cannot be run without the COSIC extension of TFHE-rs and the FPT-accelerator. + +To avoid introducing these large dependencies into the entire project, these +tests are manual, and require the system they're running on to have +[Cargo](https://doc.rust-lang.org/cargo/index.html) installed. During the test, +cargo will fetch and build the required dependencies, and `Cargo.toml` in this +directory effectively pins the version of `tfhe` supported. + +Use the following command to run the tests in this directory, where the default +Cargo home `$HOME/.cargo` may need to be replaced by your custom `$CARGO_HOME`, +if you overrode the default option when installing Cargo. + +```bash +bazel query "filter('.mlir.test$', //tests/tfhe_rust_bool/end_to_end_fpga/...)" \ + | xargs bazel test --sandbox_writable_path=$HOME/.cargo "$@" +``` + +The `manual` tag is added to the targets in this directory to ensure that they +are not run when someone runs a glob test like `bazel test //...`. + +If the Cargo home is not writable from the sandbox (for example, the flag above is missing or points to the wrong path), you will see an error like this: + +``` +# .---command stderr------------ +# | Updating crates.io index +# | Downloading crates ... 
+# | Downloaded memoffset v0.9.0 +# | error: failed to download replaced source registry `crates-io` +# | +# | Caused by: +# | failed to open `/home/you/.cargo/registry/cache/index.crates.io-6f17d22bba15001f/memoffset-0.9.0.crate` +# | +# | Caused by: +# | Read-only file system (os error 30) +# `----------------------------- +``` diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs new file mode 100644 index 000000000..47392eaa9 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/src/main.rs @@ -0,0 +1,83 @@ +use clap::Parser; +use tfhe::boolean::prelude::*; + +use tfhe::boolean::engine::BooleanEngine; +use tfhe::boolean::prelude::*; +use std::time::Instant; + +#[cfg(feature = "fpga")] +use tfhe::boolean::server_key::FpgaGates; + + +mod fn_under_test; + +// TODO(https://github.com/google/heir/issues/235): improve generality +#[derive(Parser, Debug)] +struct Args { + /// arguments to forward to function under test + #[arg(id = "input_1", index = 1, action)] + input1: u8, + + #[arg(id = "input_2", index = 2, action)] + input2: u8, +} + +/// Encrypts `value` one bit at a time, least-significant bit first, returning one ciphertext per bit. +pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec<Ciphertext> { + let arr: [u8; 8] = core::array::from_fn(|shift| (value >> shift) & 1 ); + + let res: Vec<Ciphertext> = arr.iter() + .map(|bit| client_key.encrypt(*bit != 0u8)) + .collect(); + res +} + +/// Decrypts bit ciphertexts (least-significant bit first) and packs them back into a u8. +pub fn decrypt(ciphertexts: &[Ciphertext], client_key: &ClientKey) -> u8 { + let mut accum = 0u8; + for (i, ct) in ciphertexts.iter().enumerate() { + let bit = client_key.decrypt(ct); + accum |= (bit as u8) << i; + } + accum + +} + +fn main() { + let flags = Args::parse(); + + let params; + let client_key; + + let mut boolean_engine = BooleanEngine::new(); + + #[cfg(feature = "fpga")] + { + params = tfhe::boolean::engine::fpga::parameters::DEFAULT_PARAMETERS_KS_PBS; + client_key = boolean_engine.create_client_key(*params); + } + + #[cfg(not(feature = "fpga"))] + { + params =
tfhe::boolean::parameters::DEFAULT_PARAMETERS_KS_PBS; + client_key = boolean_engine.create_client_key(params); + } + + // generate the server key, only the SW needs this + let server_key = boolean_engine.create_server_key(&client_key); + + #[cfg(feature = "fpga")] + server_key.enable_fpga(params); + + let ct_1 = encrypt(flags.input1, &client_key); + let ct_2 = encrypt(flags.input2, &client_key); + + let ct_1= ct_1.iter().collect(); + let ct_2= ct_2.iter().collect(); + + let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); + + let output = decrypt(&result, &client_key); + + println!("{:08b}", output); +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir new file mode 100644 index 000000000..d70eb52f3 --- /dev/null +++ b/tests/tfhe_rust_bool/end_to_end_fpga/test_add_one_bool.mlir @@ -0,0 +1,73 @@ +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 1 1 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK-LABEL: pub fn fn_under_test( +// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, +// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, +// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec, +// CHECK-NEXT: ) -> Vec { +func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!eb> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb> + %extracted_04 = tensor.extract
%arg0[%c4] : tensor<8x!eb> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb> + %ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb + %fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb + %fa2_1 = tfhe_rust_bool.xor %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_s = 
tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb + %fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb + %fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb + %fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb + %fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb + %fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb> + return %from_elements : tensor<8x!eb> +} diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir new file mode 100644 index 000000000..8ecca7cad --- /dev/null +++ 
b/tests/tfhe_rust_bool/end_to_end_fpga/test_packed_and.mlir @@ -0,0 +1,13 @@ +// This test ensures the testing harness is working properly with minimal codegen: the harness binary is invoked with the inputs 1 and 1, and its stdout must contain a 1. + +// RUN: heir-translate %s --emit-tfhe-rust-bool > %S/src/fn_under_test.rs +// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 1 1 | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK: 1 +func.func @fn_under_test(%bsks : !bsks, %a: tensor<8x!eb>, %b: tensor<8x!eb>) -> tensor<8x!eb> { + %res = tfhe_rust_bool.and_packed %bsks, %a, %b: (!bsks, tensor<8x!eb>, tensor<8x!eb>) -> tensor<8x!eb> + return %res : tensor<8x!eb> +}