From aa3ae948d1fd4fd413be4d5b2274371d01a9f9fd Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Wed, 13 Mar 2024 12:13:45 -0700 Subject: [PATCH] feat: Add secret-to-bgv conversion for ciphertext arith ops Adds parameter setting options. All secret inputs are required to be a uniform tensor shape that matches the ring dimension in the parameters specified. PiperOrigin-RevId: 615502361 --- include/Conversion/SecretToBGV/BUILD | 37 ++++ include/Conversion/SecretToBGV/SecretToBGV.h | 16 ++ include/Conversion/SecretToBGV/SecretToBGV.td | 39 ++++ include/Dialect/BGV/IR/BGVOps.td | 2 +- include/Dialect/LWE/IR/LWEAttributes.td | 2 + include/Dialect/LWE/IR/LWEOps.td | 33 +++- include/Dialect/LWE/IR/LWETypes.td | 5 +- lib/Conversion/SecretToBGV/BUILD | 28 +++ lib/Conversion/SecretToBGV/SecretToBGV.cpp | 184 ++++++++++++++++++ lib/Dialect/LWE/IR/LWEDialect.cpp | 12 ++ tests/secret_to_bgv/BUILD | 10 + tests/secret_to_bgv/invalid.mlir | 26 +++ tests/secret_to_bgv/ops.mlir | 25 +++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 + 15 files changed, 418 insertions(+), 4 deletions(-) create mode 100644 include/Conversion/SecretToBGV/BUILD create mode 100644 include/Conversion/SecretToBGV/SecretToBGV.h create mode 100644 include/Conversion/SecretToBGV/SecretToBGV.td create mode 100644 lib/Conversion/SecretToBGV/BUILD create mode 100644 lib/Conversion/SecretToBGV/SecretToBGV.cpp create mode 100644 tests/secret_to_bgv/BUILD create mode 100644 tests/secret_to_bgv/invalid.mlir create mode 100644 tests/secret_to_bgv/ops.mlir diff --git a/include/Conversion/SecretToBGV/BUILD b/include/Conversion/SecretToBGV/BUILD new file mode 100644 index 000000000..1cdd13a0e --- /dev/null +++ b/include/Conversion/SecretToBGV/BUILD @@ -0,0 +1,37 @@ +# SecretToBGV tablegen and headers. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + [ + "SecretToBGV.h", + ], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=SecretToBGV", + ], + "SecretToBGV.h.inc", + ), + ( + ["-gen-pass-doc"], + "SecretToBGV.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "SecretToBGV.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Conversion/SecretToBGV/SecretToBGV.h b/include/Conversion/SecretToBGV/SecretToBGV.h new file mode 100644 index 000000000..286b70b4f --- /dev/null +++ b/include/Conversion/SecretToBGV/SecretToBGV.h @@ -0,0 +1,16 @@ +#ifndef INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ +#define INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::heir { + +#define GEN_PASS_DECL +#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" + +} // namespace mlir::heir + +#endif // INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ \ No newline at end of file diff --git a/include/Conversion/SecretToBGV/SecretToBGV.td b/include/Conversion/SecretToBGV/SecretToBGV.td new file mode 100644 index 000000000..cfd6657cb --- /dev/null +++ b/include/Conversion/SecretToBGV/SecretToBGV.td @@ -0,0 +1,39 @@ +#ifndef INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ +#define INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ + +include "mlir/Pass/PassBase.td" + +def SecretToBGV : Pass<"secret-to-bgv"> { + let summary = "Lower `secret` to `bgv` dialect."; + + let description = [{ + This pass lowers an IR with `secret.generic` blocks containing arithmetic + operations to operations on ciphertexts with the BGV dialect. + + The pass assumes that the `secret.generic` regions have been distributed + through arithmetic operations so that only one ciphertext operation appears + per generic block. It also requires that `canonicalize` was run so that + non-secret values used are removed from the `secret.generic`'s block + arguments. + + The pass requires that all types are tensors of a uniform shape matching the + dimension of the ciphertext space specified my `poly-mod-degree`. + }]; + + let dependentDialects = [ + "mlir::heir::polynomial::PolynomialDialect", + "mlir::heir::bgv::BGVDialect", + "mlir::heir::lwe::LWEDialect", + ]; + + let options = [ + Option<"polyModDegree", "poly-mod-degree", "int", + /*default=*/"1024", "Default degree of the cyclotomic polynomial " + "modulus to use for ciphertext space.">, + Option<"coefficientModBits", "coefficient-mod-bits", "int", + /*default=*/"29", "Default number of bits of the prime " + "coefficient modulus to use " "for the ciphertext space."> + ]; +} + +#endif // INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ \ No newline at end of file diff --git a/include/Dialect/BGV/IR/BGVOps.td b/include/Dialect/BGV/IR/BGVOps.td index c7f32fa3f..48f3c23c7 100644 --- a/include/Dialect/BGV/IR/BGVOps.td +++ b/include/Dialect/BGV/IR/BGVOps.td @@ -27,7 +27,7 @@ class BGV_CiphertextPlaintextOp traits = [AllTypesMatch<["x", "output"]>, TypesMatchWith<"type of 'y' matches encoding type of 'x'", "output", "y", - "lwe::RLWEPlaintextType::get($_ctxt, ::llvm::cast($_self).getEncoding())">]> : + "lwe::RLWEPlaintextType::get($_ctxt, ::llvm::cast($_self).getEncoding(), ::llvm::cast($_self).getRlweParams().getRing())">]> : BGV_Op { let arguments = (ins RLWECiphertext:$x, diff --git a/include/Dialect/LWE/IR/LWEAttributes.td b/include/Dialect/LWE/IR/LWEAttributes.td index 24c55933f..b1ce34307 100644 --- a/include/Dialect/LWE/IR/LWEAttributes.td +++ b/include/Dialect/LWE/IR/LWEAttributes.td @@ -296,6 +296,8 @@ def RLWE_InverseCanonicalEmbeddingEncoding }]; } +def AnyRLWEEncodingAttr : AnyAttrOf<[RLWE_PolynomialCoefficientEncoding, RLWE_PolynomialEvaluationEncoding, RLWE_InverseCanonicalEmbeddingEncoding]>; + def LWE_LWEParams : AttrDef { let mnemonic = "lwe_params"; diff --git a/include/Dialect/LWE/IR/LWEOps.td b/include/Dialect/LWE/IR/LWEOps.td index aa8018981..02e9744d8 100644 --- a/include/Dialect/LWE/IR/LWEOps.td +++ b/include/Dialect/LWE/IR/LWEOps.td @@ -3,6 +3,8 @@ include "include/Dialect/LWE/IR/LWEDialect.td" include "include/Dialect/LWE/IR/LWETypes.td" +include "include/Dialect/Polynomial/IR/PolynomialAttributes.td" + include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -24,7 +26,7 @@ def LWE_EncodeOp : LWE_Op<"encode", [Pure]> { Examples: ``` - %Y = lwe.encode %value {encoding = #enc}: i1 to !lwe.lwe_plaintext + %Y = lwe.encode %value {encoding = #enc}: i1 to !lwe.lwe_plaintext ``` }]; @@ -59,4 +61,33 @@ def LWE_TrivialEncryptOp: LWE_Op<"trivial_encrypt", [Pure]> { let hasVerifier = 1; } +def LWE_RLWEEncodeOp : LWE_Op<"rlwe_encode", [Pure]> { + let summary = "Encode an integer to yield an RLWE plaintext"; + let description = [{ + Encode an integer to yield an RLWE plaintext. + + This op uses a an encoding attribute to encode the bits of the integer into + an RLWE plaintext value that can then be encrypted. + + Examples: + + ``` + %Y = lwe.rlwe_encode %value {encoding = #enc, ring = #ring}: i1 to !lwe.rlwe_plaintext + ``` + }]; + + let arguments = (ins + SignlessIntegerLike:$plaintext, + AnyRLWEEncodingAttr:$encoding, + Ring_Attr:$ring + ); + + let results = (outs RLWEPlaintext:$output); + let assemblyFormat = "$plaintext attr-dict `:` qualified(type($plaintext)) `to` qualified(type($output))"; + + // Verify that the encoding and ring parameter matches the output plaintext attribute. + let hasVerifier = 1; +} + + #endif // HEIR_INCLUDE_DIALECT_LWE_IR_LWEOPS_TD_ diff --git a/include/Dialect/LWE/IR/LWETypes.td b/include/Dialect/LWE/IR/LWETypes.td index d3ddf1c44..0a678fa03 100644 --- a/include/Dialect/LWE/IR/LWETypes.td +++ b/include/Dialect/LWE/IR/LWETypes.td @@ -43,7 +43,7 @@ def RLWECiphertext : LWE_Type<"RLWECiphertext", "rlwe_ciphertext"> { let parameters = (ins "::mlir::Attribute":$encoding, - OptionalParameter<"RLWEParamsAttr">:$rlwe_params + "RLWEParamsAttr":$rlwe_params ); let assemblyFormat = "`<` struct(params) `>`"; @@ -70,7 +70,8 @@ def RLWEPlaintext : LWE_Type<"RLWEPlaintext", "rlwe_plaintext"> { let summary = "A type for RLWE plaintexts"; let parameters = (ins - "::mlir::Attribute":$encoding + "::mlir::Attribute":$encoding, + "::mlir::heir::polynomial::RingAttr":$ring ); let assemblyFormat = "`<` struct(params) `>`"; diff --git a/lib/Conversion/SecretToBGV/BUILD b/lib/Conversion/SecretToBGV/BUILD new file mode 100644 index 000000000..3f0494ee0 --- /dev/null +++ b/lib/Conversion/SecretToBGV/BUILD @@ -0,0 +1,28 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "SecretToBGV", + srcs = ["SecretToBGV.cpp"], + hdrs = [ + "@heir//include/Conversion/SecretToBGV:SecretToBGV.h", + ], + deps = [ + "@heir//include/Conversion/SecretToBGV:pass_inc_gen", + "@heir//lib/Conversion:Utils", + "@heir//lib/Dialect/BGV/IR:Dialect", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:Polynomial", + "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", + "@heir//lib/Dialect/Secret/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Conversion/SecretToBGV/SecretToBGV.cpp b/lib/Conversion/SecretToBGV/SecretToBGV.cpp new file mode 100644 index 000000000..d7c987c95 --- /dev/null +++ b/lib/Conversion/SecretToBGV/SecretToBGV.cpp @@ -0,0 +1,184 @@ +#include "include/Conversion/SecretToBGV/SecretToBGV.h" + +#include +#include + +#include "include/Dialect/BGV/IR/BGVDialect.h" +#include "include/Dialect/BGV/IR/BGVOps.h" +#include "include/Dialect/LWE/IR/LWEAttributes.h" +#include "include/Dialect/LWE/IR/LWEOps.h" +#include "include/Dialect/LWE/IR/LWETypes.h" +#include "include/Dialect/Polynomial/IR/Polynomial.h" +#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" +#include "include/Dialect/Polynomial/IR/PolynomialDialect.h" +#include "include/Dialect/Secret/IR/SecretDialect.h" +#include "include/Dialect/Secret/IR/SecretOps.h" +#include "include/Dialect/Secret/IR/SecretTypes.h" +#include "lib/Conversion/Utils.h" +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::heir { + +#define GEN_PASS_DEF_SECRETTOBGV +#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" + +namespace { + +FailureOr getRlweRing(MLIRContext *ctx, + int coefficientModBits, + int polyModDegree) { + std::vector monomials; + monomials.emplace_back(1, polyModDegree); + monomials.emplace_back(1, 0); + polynomial::Polynomial xnPlusOne = + polynomial::Polynomial::fromMonomials(monomials, ctx); + switch (coefficientModBits) { + case 29: + return polynomial::RingAttr::get( + APInt(polynomial::APINT_BIT_WIDTH, 463187969), xnPlusOne); + default: + return failure(); + } +} + +// Templatized function for replacing an arithmetic operation T with a BGV +// equivalent operation Y. +template +void replaceOp(ConversionPatternRewriter &rewriter, secret::GenericOp op, + ValueRange inputs) { + rewriter.replaceOpWithNewOp(op, inputs); +} + +template <> +void replaceOp(ConversionPatternRewriter &rewriter, + secret::GenericOp op, + ValueRange inputs) { + rewriter.replaceOpWithNewOp( + op, rewriter.create(op.getLoc(), inputs), + rewriter.getDenseI32ArrayAttr({0, 1, 2}), + rewriter.getDenseI32ArrayAttr({0, 1})); +} + +} // namespace + +// Remove this class if no type conversions are necessary +class SecretToBGVTypeConverter : public TypeConverter { + public: + SecretToBGVTypeConverter(MLIRContext *ctx, polynomial::RingAttr rlweRing) { + addConversion([](Type type) { return type; }); + + // Convert secret types to BGV ciphertext types + addConversion([ctx, this](secret::SecretType type) -> Type { + RankedTensorType tensorTy = cast(type.getValueType()); + return lwe::RLWECiphertextType::get( + ctx, + lwe::PolynomialEvaluationEncodingAttr::get( + ctx, tensorTy.getElementTypeBitWidth(), + tensorTy.getElementTypeBitWidth()), + lwe::RLWEParamsAttr::get(ctx, 2, ring_)); + }); + + ring_ = rlweRing; + } + + polynomial::RingAttr ring_; +}; + +template +class SecretGenericOpConversion + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + secret::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + if (op.getBody()->getOperations().size() > 2) { + // Each secret.generic should contain at most one instruction - + // secret-distribute-generic can be used to distribute through the + // arithmetic ops. + return failure(); + } + + auto &innerOp = op.getBody()->getOperations().front(); + if (!isa(innerOp)) { + return failure(); + } + + // Assemble the arguments for the BGV operation. + SmallVector inputs; + for (OpOperand &operand : innerOp.getOpOperands()) { + if (auto *secretArg = op.getOpOperandForBlockArgument(operand.get())) { + inputs.push_back( + adaptor.getODSOperands(0)[secretArg->getOperandNumber()]); + } else { + // Plaintext-ciphertext operations are not handled. + return failure(); + } + } + + // Directly convert the op if all operands are ciphertext. + replaceOp(rewriter, op, inputs); + return success(); + } +}; + +struct SecretToBGV : public impl::SecretToBGVBase { + using SecretToBGVBase::SecretToBGVBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + + auto rlweRing = getRlweRing(context, coefficientModBits, polyModDegree); + if (failed(rlweRing)) { + return signalPassFailure(); + } + // Ensure that all secret types are uniform and matching the ring + // parameter size. + WalkResult compatibleTensors = module->walk([&](Operation *op) { + for (auto value : op->getOperands()) { + if (auto secretTy = dyn_cast(value.getType())) { + auto tensorTy = dyn_cast(secretTy.getValueType()); + if (!tensorTy || + tensorTy.getShape() != + ArrayRef{rlweRing.value().getIdeal().getDegree()}) { + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); + if (compatibleTensors.wasInterrupted()) { + module->emitError( + "expected secret types to be tensors with dimension " + "matching ring parameter"); + return signalPassFailure(); + } + + SecretToBGVTypeConverter typeConverter(context, rlweRing.value()); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addIllegalOp(); + + addStructuralConversionPatterns(typeConverter, patterns, target); + patterns.add, + SecretGenericOpConversion>( + typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace mlir::heir \ No newline at end of file diff --git a/lib/Dialect/LWE/IR/LWEDialect.cpp b/lib/Dialect/LWE/IR/LWEDialect.cpp index 416130a67..8e272640e 100644 --- a/lib/Dialect/LWE/IR/LWEDialect.cpp +++ b/lib/Dialect/LWE/IR/LWEDialect.cpp @@ -158,6 +158,18 @@ LogicalResult EncodeOp::verify() { return success(); } +LogicalResult RLWEEncodeOp::verify() { + auto encodingAttr = this->getEncodingAttr(); + auto outEncoding = this->getOutput().getType().getEncoding(); + + if (encodingAttr != outEncoding) { + return this->emitOpError() + << "encoding attr must match output LWE plaintext encoding"; + } + + return success(); +} + LogicalResult TrivialEncryptOp::verify() { auto paramsAttr = this->getParamsAttr(); auto outParamsAttr = this->getOutput().getType().getLweParams(); diff --git a/tests/secret_to_bgv/BUILD b/tests/secret_to_bgv/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/secret_to_bgv/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/secret_to_bgv/invalid.mlir b/tests/secret_to_bgv/invalid.mlir new file mode 100644 index 000000000..948175a0c --- /dev/null +++ b/tests/secret_to_bgv/invalid.mlir @@ -0,0 +1,26 @@ +// RUN: heir-opt --split-input-file --secret-to-bgv --verify-diagnostics %s | FileCheck %s + +// Tests invalid secret types + +// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}} +module { + func.func @test_not_tensor(%arg0 : !secret.secret) -> (!secret.secret) { + return %arg0 : !secret.secret + } +} + +// ----- + +// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}} +module { + func.func @test_invalid_dimension(%arg0 : !secret.secret>) -> (!secret.secret>) { + return %arg0 : !secret.secret> + } +} + +// ----- + +// CHECK: test_valid_dimension +func.func @test_valid_dimension(%arg0 : !secret.secret>) -> (!secret.secret>) { + return %arg0 : !secret.secret> +} diff --git a/tests/secret_to_bgv/ops.mlir b/tests/secret_to_bgv/ops.mlir new file mode 100644 index 000000000..901e8f38b --- /dev/null +++ b/tests/secret_to_bgv/ops.mlir @@ -0,0 +1,25 @@ +// RUN: heir-opt --canonicalize --secret-to-bgv %s | FileCheck %s + +!eui1 = !secret.secret> + +module { + // CHECK-LABEL: func @test_arith_ops + func.func @test_arith_ops(%arg0 : !eui1, %arg1 : !eui1, %arg2 : !eui1) -> (!eui1) { + %0 = secret.generic ins(%arg0, %arg1 : !eui1, !eui1) { + // CHECK: bgv.add + ^bb0(%ARG0 : tensor<1024xi1>, %ARG1 : tensor<1024xi1>): + %1 = arith.addi %ARG0, %ARG1 : tensor<1024xi1> + secret.yield %1 : tensor<1024xi1> + } -> !eui1 + // CHECK: bgv.mul + // CHECK-NEXT: bgv.relinearize + %1 = secret.generic ins(%0, %arg2 : !eui1, !eui1) { + ^bb0(%ARG0 : tensor<1024xi1>, %ARG1 : tensor<1024xi1>): + %1 = arith.muli %ARG0, %ARG1 : tensor<1024xi1> + secret.yield %1 : tensor<1024xi1> + } -> !eui1 + // CHECK: return + // CHECK-SAME: cmod=463187969, ideal=#polynomial.polynomial<1 + x**1024> + return %1 : !eui1 + } +} diff --git a/tools/BUILD b/tools/BUILD index a2446bd13..90df9629c 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -41,6 +41,7 @@ cc_binary( "@heir//lib/Conversion/MemrefToArith:ExpandCopy", "@heir//lib/Conversion/MemrefToArith:MemrefToArithRegistration", "@heir//lib/Conversion/PolynomialToStandard", + "@heir//lib/Conversion/SecretToBGV", "@heir//lib/Dialect/BGV/IR:Dialect", "@heir//lib/Dialect/CGGI/IR:Dialect", "@heir//lib/Dialect/CGGI/Transforms", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 14dade8cd..9f7c0f138 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -8,6 +8,7 @@ #include "include/Conversion/CombToCGGI/CombToCGGI.h" #include "include/Conversion/MemrefToArith/MemrefToArith.h" #include "include/Conversion/PolynomialToStandard/PolynomialToStandard.h" +#include "include/Conversion/SecretToBGV/SecretToBGV.h" #include "include/Dialect/BGV/IR/BGVDialect.h" #include "include/Dialect/CGGI/IR/CGGIDialect.h" #include "include/Dialect/CGGI/Transforms/Passes.h" @@ -319,6 +320,7 @@ int main(int argc, char **argv) { polynomial::registerPolynomialToStandardPasses(); registerCGGIToTfheRustPasses(); registerCGGIToTfheRustBoolPasses(); + registerSecretToBGVPasses(); PassPipelineRegistration<>( "heir-tosa-to-arith",