diff --git a/include/Dialect/BGV/IR/BGVOps.td b/include/Dialect/BGV/IR/BGVOps.td index fabf54ef2..8c1b01103 100644 --- a/include/Dialect/BGV/IR/BGVOps.td +++ b/include/Dialect/BGV/IR/BGVOps.td @@ -102,12 +102,12 @@ def BGV_MulPlainOp : BGV_CiphertextPlaintextOp<"mul_plain"> { let summary = "Multiplication operation between ciphertext-plaintext."; } -def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> { +def BGV_Rotate : BGV_Op<"rotate", [AllTypesMatch<["x", "output"]>]> { let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism."; let arguments = (ins RLWECiphertext:$x, - I64Attr:$offset + SignlessIntegerLike:$offset ); let results = (outs @@ -117,6 +117,27 @@ def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> { let hasVerifier = 1; } +def BGV_ExtractOp : BGV_Op<"extract", [AllTypesMatch<["x", "output"]>]> { + let summary = "Extract the i-th element of a ciphertext."; + + let description = [{ + While this operation is costly to compute in FHE, we represent it so we can + implement efficient lowerings and folders. + + This op can be implemented as a plaintext multiplication with a one-hot + vector and a rotate. + }]; + + let arguments = (ins + RLWECiphertext:$x, + SignlessIntegerLike:$offset + ); + + let results = (outs + RLWECiphertext:$output + ); +} + def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> { let summary = "Negate the coefficients of the ciphertext."; diff --git a/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp b/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp index 4857ec746..4a5e473c8 100644 --- a/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp +++ b/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp @@ -12,6 +12,7 @@ #include "include/Dialect/Openfhe/IR/OpenfheTypes.h" #include "lib/Conversion/Utils.h" #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -148,12 +149,29 @@ struct ConvertRotateOp : public OpConversionPattern { if (failed(result)) return result; Value cryptoContext = result.value(); - auto offsetValue = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adaptor.getOffset())); + Value castOffset = + llvm::TypeSwitch(adaptor.getOffset().getType()) + .Case([&](auto ty) { + return rewriter + .create( + op.getLoc(), rewriter.getI64Type(), adaptor.getOffset()) + .getResult(); + }) + .Case([&](IntegerType ty) { + if (ty.getWidth() < 64) { + return rewriter + .create(op.getLoc(), rewriter.getI64Type(), + adaptor.getOffset()) + .getResult(); + } + return rewriter + .create(op.getLoc(), rewriter.getI64Type(), + adaptor.getOffset()) + .getResult(); + }); rewriter.replaceOp( op, rewriter.create(op.getLoc(), cryptoContext, - adaptor.getX(), offsetValue)); + adaptor.getX(), castOffset)); return success(); } }; diff --git a/lib/Conversion/SecretToBGV/BUILD b/lib/Conversion/SecretToBGV/BUILD index e8631e248..e4b10bed6 100644 --- a/lib/Conversion/SecretToBGV/BUILD +++ b/lib/Conversion/SecretToBGV/BUILD @@ -18,11 +18,13 @@ cc_library( "@heir//lib/Dialect/Polynomial/IR:Polynomial", "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", "@heir//lib/Dialect/Secret/IR:Dialect", + "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], diff --git a/lib/Conversion/SecretToBGV/SecretToBGV.cpp b/lib/Conversion/SecretToBGV/SecretToBGV.cpp index 39c95e26f..067bad444 100644 --- a/lib/Conversion/SecretToBGV/SecretToBGV.cpp +++ b/lib/Conversion/SecretToBGV/SecretToBGV.cpp @@ -14,16 +14,18 @@ #include "include/Dialect/Secret/IR/SecretDialect.h" #include "include/Dialect/Secret/IR/SecretOps.h" #include "include/Dialect/Secret/IR/SecretTypes.h" +#include "include/Dialect/TensorExt/IR/TensorExtOps.h" #include "lib/Conversion/Utils.h" -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::heir { @@ -64,12 +66,14 @@ class SecretToBGVTypeConverter : public TypeConverter { // Convert secret types to BGV ciphertext types addConversion([ctx, this](secret::SecretType type) -> Type { - RankedTensorType tensorTy = cast(type.getValueType()); + int bitWidth = + llvm::TypeSwitch(type.getValueType()) + .Case( + [&](auto ty) -> int { return ty.getElementTypeBitWidth(); }) + .Case([&](auto ty) -> int { return ty.getWidth(); }); return lwe::RLWECiphertextType::get( ctx, - lwe::PolynomialEvaluationEncodingAttr::get( - ctx, tensorTy.getElementTypeBitWidth(), - tensorTy.getElementTypeBitWidth()), + lwe::PolynomialEvaluationEncodingAttr::get(ctx, bitWidth, bitWidth), lwe::RLWEParamsAttr::get(ctx, 2, ring_)); }); @@ -108,9 +112,7 @@ class SecretGenericOpConversion inputs.push_back( adaptor.getODSOperands(0)[secretArg->getOperandNumber()]); } else { - return rewriter.notifyMatchFailure( - op->getLoc(), - "Plaintext-ciphertext operations are not yet supported."); + inputs.push_back(operand.get()); } } @@ -158,7 +160,7 @@ struct SecretToBGV : public impl::SecretToBGVBase { for (auto value : op->getOperands()) { if (auto secretTy = dyn_cast(value.getType())) { auto tensorTy = dyn_cast(secretTy.getValueType()); - if (!tensorTy || + if (tensorTy && tensorTy.getShape() != ArrayRef{rlweRing.value().getIdeal().getDegree()}) { return WalkResult::interrupt(); @@ -169,7 +171,7 @@ struct SecretToBGV : public impl::SecretToBGVBase { }); if (compatibleTensors.wasInterrupted()) { module->emitError( - "expected secret types to be tensors with dimension " + "expected batched secret types to be tensors with dimension " "matching ring parameter"); return signalPassFailure(); } @@ -183,6 +185,9 @@ struct SecretToBGV : public impl::SecretToBGVBase { addStructuralConversionPatterns(typeConverter, patterns, target); patterns.add, + SecretGenericOpConversion, + SecretGenericOpConversion, + SecretGenericOpConversion, SecretGenericOpMulConversion>(typeConverter, context); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { diff --git a/tests/bgv/ops.mlir b/tests/bgv/ops.mlir index cd5c3693c..f4c34b8ce 100644 --- a/tests/bgv/ops.mlir +++ b/tests/bgv/ops.mlir @@ -22,6 +22,7 @@ // CHECK: module module { + // CHECK-LABEL: @test_multiply func.func @test_multiply(%arg0 : !ct, %arg1: !ct) -> !ct { %add = bgv.add(%arg0, %arg1) : !ct %sub = bgv.sub(%arg0, %arg1) : !ct @@ -34,11 +35,22 @@ module { return %arg0 : !ct } + // CHECK-LABEL: @test_ciphertext_plaintext func.func @test_ciphertext_plaintext(%arg0: !pt, %arg1: !pt, %arg2: !pt, %arg3: !ct) -> !ct { %add = bgv.add_plain(%arg3, %arg0) : !ct %sub = bgv.sub_plain(%add, %arg1) : !ct %mul = bgv.mul_plain(%sub, %arg2) : !ct - // CHECK: rlwe_params = >> + // CHECK: rlwe_params = >> return %mul : !ct } + + // CHECK-LABEL: @test_rotate_extract + func.func @test_rotate_extract(%arg3: !ct) -> !ct { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %add = bgv.rotate(%arg3, %c1) : (!ct, index) -> !ct + %sub = bgv.extract(%add, %c0) : (!ct, index) -> !ct + // CHECK: rlwe_params = >> + return %sub : !ct + } } diff --git a/tests/bgv/to_openfhe.mlir b/tests/bgv/to_openfhe.mlir index fbd173d6f..a8c5feff6 100644 --- a/tests/bgv/to_openfhe.mlir +++ b/tests/bgv/to_openfhe.mlir @@ -35,9 +35,11 @@ module { %sub = bgv.sub(%x, %y) : !ct // CHECK: %[[v4:.*]] = openfhe.mul_no_relin [[C]], %[[x4:.*]], %[[y4:.*]]: ([[S]], [[T]], [[T]]) -> [[T]] %mul = bgv.mul(%x, %y) : !ct -> !ct_level3 - // CHECK: %[[c5:.*]] = arith.constant 4 : i64 + // CHECK: %[[c5:.*]] = arith.index_cast + // CHECK-SAME: to i64 // CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]], %[[c5:.*]]: ([[S]], [[T]], i64) -> [[T]] - %rot = bgv.rotate(%x) {offset = 4}: (!ct) -> !ct + %c4 = arith.constant 4 : index + %rot = bgv.rotate(%x, %c4): (!ct, index) -> !ct return } diff --git a/tests/bgv/verifier.mlir b/tests/bgv/verifier.mlir index 7e1d9d7a3..e67bfc714 100644 --- a/tests/bgv/verifier.mlir +++ b/tests/bgv/verifier.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --verify-diagnostics %s +// RUN: heir-opt --verify-diagnostics --split-input-file %s #encoding = #lwe.polynomial_evaluation_encoding @@ -12,8 +12,28 @@ !ct1 = !lwe.rlwe_ciphertext func.func @test_input_dimension_error(%input: !ct) { + %offset = arith.constant 4 : index // expected-error@+1 {{x.dim == 2 does not hold}} - %out = bgv.rotate (%input) {offset = 4} : (!ct) -> !ct1 + %out = bgv.rotate (%input, %offset) : (!ct, index) -> !ct + return +} + +// ----- + +#encoding = #lwe.polynomial_evaluation_encoding + +#my_poly = #polynomial.polynomial<1 + x**1024> +#ring = #polynomial.ring + +#params = #lwe.rlwe_params +#params1 = #lwe.rlwe_params + +!ct = !lwe.rlwe_ciphertext +!ct1 = !lwe.rlwe_ciphertext +func.func @test_input_output_type_match(%input: !ct1) { + %offset = arith.constant 4 : index + // expected-error@+1 {{failed to verify that all of {x, output} have same type}} + %out = bgv.rotate (%input, %offset) : (!ct1, index) -> !ct return } diff --git a/tests/secret_to_bgv/hamming_distance_1024.mlir b/tests/secret_to_bgv/hamming_distance_1024.mlir new file mode 100644 index 000000000..37a611b15 --- /dev/null +++ b/tests/secret_to_bgv/hamming_distance_1024.mlir @@ -0,0 +1,31 @@ +// RUN: heir-opt --secretize=entry-function=hamming --wrap-generic \ +// RUN: --canonicalize --cse --heir-simd-vectorizer \ +// RUN: --secret-distribute-generic --secret-to-bgv \ +// RUN: %s | FileCheck %s + +// CHECK-LABEL: @hamming +// CHECK: bgv.sub +// CHECK-NEXT: bgv.mul +// CHECK-NEXT: bgv.relinearize + +// TODO(#521): After rotate-and-reduce works, only check for 10 bg.rotate +// CHECK-COUNT-1023: bgv.rotate +// CHECK: bgv.extract +// CHECK-NEXT: return + +func.func @hamming(%arg0: tensor<1024xi16>, %arg1: tensor<1024xi16> {secret.secret}) -> i16 { + %c0 = arith.constant 0 : index + %c0_si16 = arith.constant 0 : i16 + %0 = affine.for %arg2 = 0 to 1024 iter_args(%arg6 = %c0_si16) -> i16 { + %1 = tensor.extract %arg0[%arg2] : tensor<1024xi16> + %2 = tensor.extract %arg1[%arg2] : tensor<1024xi16> + %3 = arith.subi %1, %2 : i16 + %4 = tensor.extract %arg0[%arg2] : tensor<1024xi16> + %5 = tensor.extract %arg1[%arg2] : tensor<1024xi16> + %6 = arith.subi %4, %5 : i16 + %7 = arith.muli %3, %6 : i16 + %8 = arith.addi %arg6, %7 : i16 + affine.yield %8 : i16 + } + return %0 : i16 +} diff --git a/tests/secret_to_bgv/invalid.mlir b/tests/secret_to_bgv/invalid.mlir index bcd52e69b..6fb3f136a 100644 --- a/tests/secret_to_bgv/invalid.mlir +++ b/tests/secret_to_bgv/invalid.mlir @@ -2,16 +2,7 @@ // Tests invalid secret types -// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}} -module { - func.func @test_not_tensor(%arg0 : !secret.secret) -> (!secret.secret) { - return %arg0 : !secret.secret - } -} - -// ----- - -// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}} +// expected-error@below {{expected batched secret types to be tensors with dimension matching ring parameter}} module { func.func @test_invalid_dimension(%arg0 : !secret.secret>) -> (!secret.secret>) { return %arg0 : !secret.secret>