From f7c5ffd56c0114455eff56d07274b91394cc1298 Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Mon, 11 Mar 2024 09:12:28 -0700 Subject: [PATCH] bgv: use LWE dialect's RLWE ciphertext type in the BGV dialect Fixes https://github.com/google/heir/issues/487 PiperOrigin-RevId: 614686524 --- include/Dialect/BGV/IR/BGVAttributes.h | 12 ---- include/Dialect/BGV/IR/BGVAttributes.td | 17 ----- include/Dialect/BGV/IR/BGVDialect.td | 3 - include/Dialect/BGV/IR/BGVOps.h | 2 - include/Dialect/BGV/IR/BGVOps.td | 54 ++++++++-------- include/Dialect/BGV/IR/BGVTraits.h | 20 +++--- include/Dialect/BGV/IR/BGVTypes.h | 13 ---- include/Dialect/BGV/IR/BGVTypes.td | 47 -------------- include/Dialect/BGV/IR/BUILD | 64 +------------------ include/Dialect/LWE/IR/LWEAttributes.h | 3 + include/Dialect/LWE/IR/LWEAttributes.td | 6 +- include/Dialect/Openfhe/IR/OpenfheOps.td | 15 ++++- lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp | 36 +---------- .../BGVToPolynomial/BGVToPolynomial.cpp | 14 ++-- lib/Conversion/BGVToPolynomial/BUILD | 1 + lib/Dialect/BGV/IR/BGVDialect.cpp | 50 ++++++--------- lib/Dialect/BGV/IR/BUILD | 6 +- lib/Dialect/CGGI/Transforms/BUILD | 2 + .../CGGI/Transforms/SetDefaultParameters.cpp | 15 +++-- lib/Dialect/LWE/IR/BUILD | 1 + lib/Target/OpenFhePke/BUILD | 2 + lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp | 7 +- .../OpenFhePke/OpenFhePkeHeaderEmitter.cpp | 3 +- tests/bgv/ops.mlir | 35 ++++++---- tests/bgv/to_openfhe.mlir | 42 ++++++------ tests/bgv/to_openfhe_invalid.mlir | 46 ++++++++----- tests/bgv/to_polynomial.mlir | 13 ++-- tests/bgv/verifier.mlir | 16 +++-- tests/lwe/types.mlir | 5 +- tests/openfhe/emit_openfhe_pke.mlir | 5 +- tests/openfhe/end_to_end/binops.mlir | 3 +- tests/openfhe/ops.mlir | 5 +- 32 files changed, 222 insertions(+), 341 deletions(-) delete mode 100644 include/Dialect/BGV/IR/BGVAttributes.h delete mode 100644 include/Dialect/BGV/IR/BGVAttributes.td delete mode 100644 include/Dialect/BGV/IR/BGVTypes.h delete mode 100644 include/Dialect/BGV/IR/BGVTypes.td diff --git a/include/Dialect/BGV/IR/BGVAttributes.h b/include/Dialect/BGV/IR/BGVAttributes.h deleted file mode 100644 index dafe2dbe5..000000000 --- a/include/Dialect/BGV/IR/BGVAttributes.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_H_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_H_ - -#include "include/Dialect/BGV/IR/BGVDialect.h" - -// Required to pull in poly's Ring_Attr -#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" - -#define GET_ATTRDEF_CLASSES -#include "include/Dialect/BGV/IR/BGVAttributes.h.inc" - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_H_ diff --git a/include/Dialect/BGV/IR/BGVAttributes.td b/include/Dialect/BGV/IR/BGVAttributes.td deleted file mode 100644 index 678b8cc14..000000000 --- a/include/Dialect/BGV/IR/BGVAttributes.td +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_TD_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_TD_ - -include "BGVDialect.td" - -include "mlir/IR/DialectBase.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/IR/OpBase.td" - -def BGVRingArrayAttr : AttrDef { - let mnemonic = "rings"; - let parameters = (ins ArrayRefParameter<"::mlir::heir::polynomial::RingAttr">:$rings); - let assemblyFormat = "`<` $rings `>`"; -} - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_TD_ diff --git a/include/Dialect/BGV/IR/BGVDialect.td b/include/Dialect/BGV/IR/BGVDialect.td index de598e4bf..af818bca1 100644 --- a/include/Dialect/BGV/IR/BGVDialect.td +++ b/include/Dialect/BGV/IR/BGVDialect.td @@ -14,9 +14,6 @@ def BGV_Dialect : Dialect { }]; let cppNamespace = "::mlir::heir::bgv"; - - let useDefaultTypePrinterParser = 1; - let useDefaultAttributePrinterParser = 1; } #endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVDIALECT_TD_ diff --git a/include/Dialect/BGV/IR/BGVOps.h b/include/Dialect/BGV/IR/BGVOps.h index 225580651..73d668ad6 100644 --- a/include/Dialect/BGV/IR/BGVOps.h +++ b/include/Dialect/BGV/IR/BGVOps.h @@ -1,10 +1,8 @@ #ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_H_ #define HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_H_ -#include "include/Dialect/BGV/IR/BGVAttributes.h" #include "include/Dialect/BGV/IR/BGVDialect.h" #include "include/Dialect/BGV/IR/BGVTraits.h" -#include "include/Dialect/BGV/IR/BGVTypes.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project diff --git a/include/Dialect/BGV/IR/BGVOps.td b/include/Dialect/BGV/IR/BGVOps.td index 0ff1a9a1f..0d885aa62 100644 --- a/include/Dialect/BGV/IR/BGVOps.td +++ b/include/Dialect/BGV/IR/BGVOps.td @@ -2,11 +2,13 @@ #define HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_TD_ include "BGVDialect.td" -include "BGVTypes.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "include/Dialect/LWE/IR/LWETypes.td" +include "include/Dialect/Polynomial/IR/PolynomialAttributes.td" + def SameOperandsAndResultRings: NativeOpTrait<"SameOperandsAndResultRings"> { let cppNamespace = "::mlir::heir::bgv"; } @@ -15,7 +17,7 @@ class BGV_Op traits = []> : Op { let assemblyFormat = [{ - `(` operands `)` attr-dict `:` `(` type(operands) `)` `->` type(results) + `(` operands `)` attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results)) }]; let cppNamespace = "::mlir::heir::bgv"; } @@ -26,45 +28,45 @@ def BGV_AddOp : BGV_Op<"add", [Commutative, SameOperandsAndResultType]> { let summary = "Addition operation between ciphertexts."; let arguments = (ins - Ciphertext:$x, - Ciphertext:$y + RLWECiphertext:$x, + RLWECiphertext:$y ); let results = (outs - Ciphertext:$output + RLWECiphertext:$output ); - let assemblyFormat = "`(` operands `)` attr-dict `:` type($output)" ; + let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ; } def BGV_SubOp : BGV_Op<"sub", [SameOperandsAndResultType]> { let summary = "Subtraction operation between ciphertexts."; let arguments = (ins - Ciphertext:$x, - Ciphertext:$y + RLWECiphertext:$x, + RLWECiphertext:$y ); let results = (outs - Ciphertext:$output + RLWECiphertext:$output ); - let assemblyFormat = "`(` operands `)` attr-dict `:` type($output)" ; + let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ; } def BGV_MulOp : BGV_Op<"mul", [Commutative, SameOperandsAndResultRings, SameTypeOperands]> { let summary = "Multiplication operation between ciphertexts."; let arguments = (ins - Ciphertext:$x, - Ciphertext:$y + RLWECiphertext:$x, + RLWECiphertext:$y ); let results = (outs - Ciphertext:$output + RLWECiphertext:$output ); - let assemblyFormat = "`(` operands `)` attr-dict `:` type($x) `->` type($output)" ; + let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($x)) `->` qualified(type($output))" ; let hasVerifier = 1; } @@ -73,12 +75,12 @@ def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> { let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism."; let arguments = (ins - Ciphertext:$x, + RLWECiphertext:$x, I64Attr:$offset ); let results = (outs - Ciphertext:$output + RLWECiphertext:$output ); let hasVerifier = 1; @@ -88,14 +90,14 @@ def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> { let summary = "Negate the coefficients of the ciphertext."; let arguments = (ins - Ciphertext:$x + RLWECiphertext:$x ); let results = (outs - Ciphertext:$output + RLWECiphertext:$output ); - let assemblyFormat = "`(` operands `)` attr-dict `:` type($output)" ; + let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ; } def BGV_Relinearize : BGV_Op<"relinearize", [SameOperandsAndResultRings]> { @@ -112,30 +114,28 @@ def BGV_Relinearize : BGV_Op<"relinearize", [SameOperandsAndResultRings]> { }]; let arguments = (ins - Ciphertext:$x, + RLWECiphertext:$x, DenseI32ArrayAttr:$from_basis, DenseI32ArrayAttr:$to_basis ); let results = (outs - Ciphertext:$output + RLWECiphertext:$output ); let hasVerifier = 1; } -def BGV_ModulusSwitch : BGV_Op<"modulus_switch", [SameOperandsAndResultRings]> { - // This must be validated against the BGV ring parameter. +def BGV_ModulusSwitch : BGV_Op<"modulus_switch"> { let summary = "Lower the modulus level of the ciphertext."; let arguments = (ins - Ciphertext:$x, - I64Attr:$from_level, - I64Attr:$to_level + RLWECiphertext:$x, + Ring_Attr:$to_ring ); let results = (outs - Ciphertext:$output + RLWECiphertext:$output ); let hasVerifier = 1; diff --git a/include/Dialect/BGV/IR/BGVTraits.h b/include/Dialect/BGV/IR/BGVTraits.h index 16c75fffb..5eee3ce69 100644 --- a/include/Dialect/BGV/IR/BGVTraits.h +++ b/include/Dialect/BGV/IR/BGVTraits.h @@ -1,12 +1,14 @@ #ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_ #define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_ -#include "include/Dialect/BGV/IR/BGVAttributes.h" -#include "include/Dialect/BGV/IR/BGVTypes.h" +#include "include/Dialect/LWE/IR/LWETypes.h" +#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project namespace mlir::heir::bgv { +// TODO(#212): Move to LWE dialect/namespace // Trait that ensures that all operands and results ciphertext have the same set // of rings. template @@ -14,30 +16,30 @@ class SameOperandsAndResultRings : public OpTrait::TraitBase { public: static LogicalResult verifyTrait(Operation *op) { - BGVRingsAttr rings = nullptr; + polynomial::RingAttr rings = nullptr; for (auto rTy : op->getResultTypes()) { - auto ct = dyn_cast(rTy); + auto ct = dyn_cast(rTy); if (!ct) continue; if (rings == nullptr) { - rings = ct.getRings(); + rings = ct.getRlweParams().getRing(); continue; } - if (rings != ct.getRings()) { + if (rings != ct.getRlweParams().getRing()) { return op->emitOpError() << "requires all operands and results to have the same rings"; } } for (auto oTy : op->getOperandTypes()) { - auto ct = dyn_cast(oTy); + auto ct = dyn_cast(oTy); if (!ct) continue; // Check only ciphertexts if (rings == nullptr) { - rings = ct.getRings(); + rings = ct.getRlweParams().getRing(); continue; } - if (rings != ct.getRings()) { + if (rings != ct.getRlweParams().getRing()) { return op->emitOpError() << "requires all operands and results to have the same rings"; } diff --git a/include/Dialect/BGV/IR/BGVTypes.h b/include/Dialect/BGV/IR/BGVTypes.h deleted file mode 100644 index b4bb43b15..000000000 --- a/include/Dialect/BGV/IR/BGVTypes.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_H_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_H_ - -#include "include/Dialect/BGV/IR/BGVAttributes.h" -#include "include/Dialect/BGV/IR/BGVDialect.h" - -// Required to pull in poly's Ring_Attr -#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" - -#define GET_TYPEDEF_CLASSES -#include "include/Dialect/BGV/IR/BGVTypes.h.inc" - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_H_ diff --git a/include/Dialect/BGV/IR/BGVTypes.td b/include/Dialect/BGV/IR/BGVTypes.td deleted file mode 100644 index ae4aff93f..000000000 --- a/include/Dialect/BGV/IR/BGVTypes.td +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_TD_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_TD_ - -include "BGVDialect.td" -include "BGVAttributes.td" - -include "mlir/IR/AttrTypeBase.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/IR/DialectBase.td" -include "mlir/IR/OpBase.td" - -// TODO(#100): Add a plaintext type. - -// A base class for all types in this dialect -class BGV_Type - : TypeDef { - let mnemonic = typeMnemonic; -} - -def Ciphertext : BGV_Type<"Ciphertext", "ciphertext"> { - let summary = "a BGV ciphertext"; - - let description = [{ - This type tracks the BGV ciphertext parameters, including the ciphertext - dimension (number of polynomials) and the set of rings that were used for - the particular BGV scheme instance. The default dimension is 2, representing - a ciphertext that is canonically encrypted against the key basis $(1, s)$. - - The type also includes a ring parameter specification. - - For example, `bgv.ciphertext` is a ciphertext with 3 - polynomials $(c_0, c_1, c_2)$. - - The optional attribute `level` specifies the "current ring". - }]; - - // TODO(#99): Add # of plaintext bits. - let parameters = (ins - BGVRingArrayAttr:$rings, - DefaultValuedParameter<"unsigned", "2">:$dim, - OptionalParameter<"std::optional">:$level - ); - - let assemblyFormat = "`<` `rings` `=` $rings (`,` `dim` `=` $dim^ )? (`,` `level` `=` $level^ )? `>`"; -} - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_TD_ diff --git a/include/Dialect/BGV/IR/BUILD b/include/Dialect/BGV/IR/BUILD index e2b899efc..99d304734 100644 --- a/include/Dialect/BGV/IR/BUILD +++ b/include/Dialect/BGV/IR/BUILD @@ -11,8 +11,6 @@ exports_files( [ "BGVDialect.h", "BGVOps.h", - "BGVTypes.h", - "BGVAttributes.h", "BGVTraits.h", ], ) @@ -20,10 +18,8 @@ exports_files( td_library( name = "td_files", srcs = [ - "BGVAttributes.td", "BGVDialect.td", "BGVOps.td", - "BGVTypes.td", ], deps = [ "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", @@ -60,35 +56,6 @@ gentbl_cc_library( ], ) -gentbl_cc_library( - name = "types_inc_gen", - tbl_outs = [ - ( - [ - "-gen-typedef-decls", - ], - "BGVTypes.h.inc", - ), - ( - [ - "-gen-typedef-defs", - ], - "BGVTypes.cpp.inc", - ), - ( - ["-gen-typedef-doc"], - "BGVTypes.md", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "BGVTypes.td", - deps = [ - ":attributes_inc_gen", - ":dialect_inc_gen", - ":td_files", - ], -) - gentbl_cc_library( name = "ops_inc_gen", tbl_outs = [ @@ -110,34 +77,7 @@ gentbl_cc_library( deps = [ ":dialect_inc_gen", ":td_files", - ":types_inc_gen", - ], -) - -gentbl_cc_library( - name = "attributes_inc_gen", - tbl_outs = [ - ( - [ - "-gen-attrdef-decls", - ], - "BGVAttributes.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - ], - "BGVAttributes.cpp.inc", - ), - ( - ["-gen-attrdef-doc"], - "BGVAttributes.md", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "BGVAttributes.td", - deps = [ - ":dialect_inc_gen", - ":td_files", + "@heir//include/Dialect/LWE/IR:td_files", + "@heir//include/Dialect/Polynomial/IR:td_files", ], ) diff --git a/include/Dialect/LWE/IR/LWEAttributes.h b/include/Dialect/LWE/IR/LWEAttributes.h index c29eb0b99..092e1b568 100644 --- a/include/Dialect/LWE/IR/LWEAttributes.h +++ b/include/Dialect/LWE/IR/LWEAttributes.h @@ -4,6 +4,9 @@ #include "include/Dialect/LWE/IR/LWEDialect.h" #include "mlir/include/mlir/IR/TensorEncoding.h" // from @llvm-project +// Required to pull in poly's Ring_Attr +#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" + #define GET_ATTRDEF_CLASSES #include "include/Dialect/LWE/IR/LWEAttributes.h.inc" diff --git a/include/Dialect/LWE/IR/LWEAttributes.td b/include/Dialect/LWE/IR/LWEAttributes.td index f8a102cea..24c55933f 100644 --- a/include/Dialect/LWE/IR/LWEAttributes.td +++ b/include/Dialect/LWE/IR/LWEAttributes.td @@ -309,14 +309,16 @@ def LWE_RLWEParams : AttrDef { let description = [{ An attribute describing classical RLWE parameters: - - `cmod`: the coefficient modulus for the polynomials. - `dimension`: the number of polynomials used in an RLWE sample, analogous to LWEParams.dimension. - `polyDegree`: the degree $N$ of the negacyclic polynomial modulus $x^N + 1$. }]; - let parameters = (ins "IntegerAttr": $cmod, "unsigned":$dimension, "unsigned": $polyDegree); + let parameters = (ins + DefaultValuedParameter<"unsigned", "2">:$dimension, + "::mlir::heir::polynomial::RingAttr":$ring + ); let assemblyFormat = "`<` struct(params) `>`"; } diff --git a/include/Dialect/Openfhe/IR/OpenfheOps.td b/include/Dialect/Openfhe/IR/OpenfheOps.td index b9c47c01f..4f238163a 100644 --- a/include/Dialect/Openfhe/IR/OpenfheOps.td +++ b/include/Dialect/Openfhe/IR/OpenfheOps.td @@ -30,6 +30,17 @@ class Openfhe_UnaryOp traits = []> let results = (outs RLWECiphertext:$output); } +class Openfhe_UnaryTypeSwitchOp traits = []> + : Openfhe_Op{ + let arguments = (ins + Openfhe_CryptoContext:$cryptoContext, + RLWECiphertext:$ciphertext + ); + let results = (outs RLWECiphertext:$output); +} + class Openfhe_BinaryOp traits = []> : Openfhe_Op { let summary = "OpenFHE negate operati def SquareOp : Openfhe_UnaryOp<"square"> { let summary = "OpenFHE square operation of a ciphertext."; } def RelinOp : Openfhe_UnaryOp<"relin"> { let summary = "OpenFHE relinearize operation of a ciphertext."; } -def ModReduceOp : Openfhe_UnaryOp<"mod_reduce"> { let summary = "OpenFHE mod_reduce operation of a ciphertext. (used only for BGV/CKKS)"; } -def LevelReduceOp : Openfhe_UnaryOp<"level_reduce"> { let summary = "OpenFHE level_reduce operation of a ciphertext."; } +def ModReduceOp : Openfhe_UnaryTypeSwitchOp<"mod_reduce"> { let summary = "OpenFHE mod_reduce operation of a ciphertext. (used only for BGV/CKKS)"; } +def LevelReduceOp : Openfhe_UnaryTypeSwitchOp<"level_reduce"> { let summary = "OpenFHE level_reduce operation of a ciphertext."; } def RotOp : Openfhe_Op<"rot",[ Pure, diff --git a/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp b/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp index e5e3f44f3..4857ec746 100644 --- a/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp +++ b/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp @@ -5,7 +5,6 @@ #include "include/Dialect/BGV/IR/BGVDialect.h" #include "include/Dialect/BGV/IR/BGVOps.h" -#include "include/Dialect/BGV/IR/BGVTypes.h" #include "include/Dialect/LWE/IR/LWEAttributes.h" #include "include/Dialect/LWE/IR/LWETypes.h" #include "include/Dialect/Openfhe/IR/OpenfheDialect.h" @@ -29,24 +28,8 @@ namespace mlir::heir::bgv { class ToLWECiphertextTypeConverter : public TypeConverter { public: - // Convert ciphertext to RLWE ciphertext ToLWECiphertextTypeConverter(MLIRContext *ctx) { addConversion([](Type type) { return type; }); - addConversion([ctx](CiphertextType type) -> Type { - assert(type.getLevel().has_value()); - auto level = type.getLevel().value(); - assert(level < type.getRings().getRings().size()); - - auto ring = type.getRings().getRings()[level]; - auto cmod = ring.getCmod(); - auto polyDegree = ring.getIdeal().getDegree(); - auto dim = type.getDim(); - return lwe::RLWECiphertextType::get( - // TODO(#99): Set a default encoding when the adding # of plaintext - // bits on BGV Ciphertext is done. - ctx, lwe::PolynomialEvaluationEncodingAttr::get(ctx, 0, 0), - lwe::RLWEParamsAttr::get(ctx, cmod, dim, polyDegree)); - }); } }; @@ -208,9 +191,6 @@ struct ConvertRelinOp : public OpConversionPattern { } }; -bool checkModulusSwitchLevels(unsigned long fromLevel, unsigned long toLevel) { - return fromLevel == toLevel + 1; -} struct ConvertModulusSwitchOp : public OpConversionPattern { ConvertModulusSwitchOp(mlir::MLIRContext *context) : OpConversionPattern(context) {} @@ -223,22 +203,10 @@ struct ConvertModulusSwitchOp : public OpConversionPattern { FailureOr result = getContextualCryptoContext(op.getOperation()); if (failed(result)) return result; - auto fromLevel = adaptor.getFromLevel(); - auto toLevel = adaptor.getToLevel(); - - // Since the `ModReduce()` function in OpenFHE decreases the level of a - // ciphertext by 1, `fromLevel == toLevel + 1` holds for - // `bgv.ModulusSwitch`. - if (!checkModulusSwitchLevels(fromLevel, toLevel)) { - op.emitError() << "fromLevel must be toLevel + 1, got " - << "fromLevel: " << fromLevel << " and " - << "toLevel: " << toLevel; - return failure(); - } - Value cryptoContext = result.value(); rewriter.replaceOp(op, rewriter.create( - op.getLoc(), cryptoContext, adaptor.getX())); + op.getLoc(), op.getOutput().getType(), + cryptoContext, adaptor.getX())); return success(); } }; diff --git a/lib/Conversion/BGVToPolynomial/BGVToPolynomial.cpp b/lib/Conversion/BGVToPolynomial/BGVToPolynomial.cpp index 77a28022a..c50d7330b 100644 --- a/lib/Conversion/BGVToPolynomial/BGVToPolynomial.cpp +++ b/lib/Conversion/BGVToPolynomial/BGVToPolynomial.cpp @@ -5,7 +5,7 @@ #include "include/Dialect/BGV/IR/BGVDialect.h" #include "include/Dialect/BGV/IR/BGVOps.h" -#include "include/Dialect/BGV/IR/BGVTypes.h" +#include "include/Dialect/LWE/IR/LWETypes.h" #include "include/Dialect/Polynomial/IR/Polynomial.h" #include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "include/Dialect/Polynomial/IR/PolynomialOps.h" @@ -13,6 +13,7 @@ #include "lib/Conversion/Utils.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project @@ -26,15 +27,12 @@ class CiphertextTypeConverter : public TypeConverter { // Convert ciphertext to tensor<#dim x !poly.poly<#rings[#level]>> CiphertextTypeConverter(MLIRContext *ctx) { addConversion([](Type type) { return type; }); - addConversion([ctx](CiphertextType type) -> Type { - assert(type.getLevel().has_value()); - auto level = type.getLevel().value(); - assert(level < type.getRings().getRings().size()); - - auto ring = type.getRings().getRings()[level]; + addConversion([ctx](lwe::RLWECiphertextType type) -> Type { + auto rlweParams = type.getRlweParams(); + auto ring = rlweParams.getRing(); auto polyTy = polynomial::PolynomialType::get(ctx, ring); - return RankedTensorType::get({type.getDim()}, polyTy); + return RankedTensorType::get({rlweParams.getDimension()}, polyTy); }); } // We don't include any custom materialization ops because this lowering is diff --git a/lib/Conversion/BGVToPolynomial/BUILD b/lib/Conversion/BGVToPolynomial/BUILD index 2909e3b2a..ddd6b5447 100644 --- a/lib/Conversion/BGVToPolynomial/BUILD +++ b/lib/Conversion/BGVToPolynomial/BUILD @@ -13,6 +13,7 @@ cc_library( "@heir//include/Conversion/BGVToPolynomial:pass_inc_gen", "@heir//lib/Conversion:Utils", "@heir//lib/Dialect/BGV/IR:Dialect", + "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Polynomial/IR:Dialect", "@heir//lib/Dialect/Polynomial/IR:Polynomial", "@llvm-project//mlir:ArithDialect", diff --git a/lib/Dialect/BGV/IR/BGVDialect.cpp b/lib/Dialect/BGV/IR/BGVDialect.cpp index ba1dcb7c8..7e89b7906 100644 --- a/lib/Dialect/BGV/IR/BGVDialect.cpp +++ b/lib/Dialect/BGV/IR/BGVDialect.cpp @@ -1,18 +1,12 @@ #include "include/Dialect/BGV/IR/BGVDialect.h" -#include "include/Dialect/BGV/IR/BGVAttributes.h" #include "include/Dialect/BGV/IR/BGVOps.h" -#include "include/Dialect/BGV/IR/BGVTypes.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project // Generated definitions #include "include/Dialect/BGV/IR/BGVDialect.cpp.inc" -#define GET_ATTRDEF_CLASSES -#include "include/Dialect/BGV/IR/BGVAttributes.cpp.inc" -#define GET_TYPEDEF_CLASSES -#include "include/Dialect/BGV/IR/BGVTypes.cpp.inc" #define GET_OP_CLASSES #include "include/Dialect/BGV/IR/BGVOps.cpp.inc" @@ -27,14 +21,6 @@ namespace bgv { // Dialect construction: there is one instance per context and it registers its // operations, types, and interfaces here. void BGVDialect::initialize() { - addAttributes< -#define GET_ATTRDEF_LIST -#include "include/Dialect/BGV/IR/BGVAttributes.cpp.inc" - >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "include/Dialect/BGV/IR/BGVTypes.cpp.inc" - >(); addOperations< #define GET_OP_LIST #include "include/Dialect/BGV/IR/BGVOps.cpp.inc" @@ -44,22 +30,23 @@ void BGVDialect::initialize() { LogicalResult MulOp::verify() { auto x = getX().getType(); auto y = getY().getType(); - if (x.getDim() != y.getDim()) { + if (x.getRlweParams().getDimension() != y.getRlweParams().getDimension()) { return emitOpError() << "input dimensions do not match"; } auto out = getOutput().getType(); - if (out.getDim() != 1 + x.getDim()) { + if (out.getRlweParams().getDimension() != + 1 + x.getRlweParams().getDimension()) { return emitOpError() << "output.dim == x.dim + 1 does not hold"; } return success(); } LogicalResult Rotate::verify() { auto x = getX().getType(); - if (x.getDim() != 2) { + if (x.getRlweParams().getDimension() != 2) { return emitOpError() << "x.dim == 2 does not hold"; } auto out = getOutput().getType(); - if (out.getDim() != 2) { + if (out.getRlweParams().getDimension() != 2) { return emitOpError() << "output.dim == 2 does not hold"; } return success(); @@ -68,10 +55,10 @@ LogicalResult Rotate::verify() { LogicalResult Relinearize::verify() { auto x = getX().getType(); auto out = getOutput().getType(); - if (x.getDim() != getFromBasis().size()) { + if (x.getRlweParams().getDimension() != getFromBasis().size()) { return emitOpError() << "input dimension does not match from_basis"; } - if (out.getDim() != getToBasis().size()) { + if (out.getRlweParams().getDimension() != getToBasis().size()) { return emitOpError() << "output dimension does not match to_basis"; } return success(); @@ -79,21 +66,22 @@ LogicalResult Relinearize::verify() { LogicalResult ModulusSwitch::verify() { auto x = getX().getType(); - auto rings = x.getRings().getRings().size(); - auto to = getToLevel(); - auto from = getFromLevel(); - if (to < 0 || to >= from || from >= rings) { - return emitOpError() << "invalid levels, should be true: 0 <= " << to - << " < " << from << " < " << rings; + auto xRing = x.getRlweParams().getRing(); + + auto out = getOutput().getType(); + auto outRing = out.getRlweParams().getRing(); + if (outRing != getToRing()) { + return emitOpError() << "output ring should match to_ring"; } - if (x.getLevel().has_value() && x.getLevel().value() != from) { - return emitOpError() << "input level does not match from_level"; + if (xRing.getCmod().getValue().ule(outRing.getCmod().getValue())) { + return emitOpError() + << "output ring modulus should be less than the input ring modulus"; } - auto outLvl = getOutput().getType().getLevel(); - if (!outLvl.has_value() || outLvl.value() != to) { + if (!xRing.getCmod().getValue().urem(outRing.getCmod().getValue()).isZero()) { return emitOpError() - << "output level should be specified and match to_level"; + << "output ring modulus should divide the input ring modulus"; } + return success(); } diff --git a/lib/Dialect/BGV/IR/BUILD b/lib/Dialect/BGV/IR/BUILD index 1554e86c8..4ba153eef 100644 --- a/lib/Dialect/BGV/IR/BUILD +++ b/lib/Dialect/BGV/IR/BUILD @@ -11,21 +11,19 @@ cc_library( "BGVDialect.cpp", ], hdrs = [ - "@heir//include/Dialect/BGV/IR:BGVAttributes.h", "@heir//include/Dialect/BGV/IR:BGVDialect.h", "@heir//include/Dialect/BGV/IR:BGVOps.h", "@heir//include/Dialect/BGV/IR:BGVTraits.h", - "@heir//include/Dialect/BGV/IR:BGVTypes.h", ], deps = [ - "@heir//include/Dialect/BGV/IR:attributes_inc_gen", "@heir//include/Dialect/BGV/IR:dialect_inc_gen", "@heir//include/Dialect/BGV/IR:ops_inc_gen", - "@heir//include/Dialect/BGV/IR:types_inc_gen", "@heir//include/Dialect/Polynomial/IR:attributes_inc_gen", + "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", ], ) diff --git a/lib/Dialect/CGGI/Transforms/BUILD b/lib/Dialect/CGGI/Transforms/BUILD index c4f9d0bed..8a52a45d0 100644 --- a/lib/Dialect/CGGI/Transforms/BUILD +++ b/lib/Dialect/CGGI/Transforms/BUILD @@ -26,6 +26,8 @@ cc_library( "@heir//include/Dialect/CGGI/Transforms:pass_inc_gen", "@heir//lib/Dialect/CGGI/IR:Dialect", "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:Polynomial", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp b/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp index 90902bd13..ba6f340c9 100644 --- a/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp +++ b/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp @@ -1,8 +1,12 @@ #include "include/Dialect/CGGI/Transforms/SetDefaultParameters.h" +#include + #include "include/Dialect/CGGI/IR/CGGIAttributes.h" #include "include/Dialect/CGGI/IR/CGGIOps.h" #include "include/Dialect/LWE/IR/LWEAttributes.h" +#include "include/Dialect/Polynomial/IR/Polynomial.h" +#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project @@ -22,10 +26,12 @@ struct SetDefaultParameters auto *op = getOperation(); MLIRContext &context = getContext(); unsigned defaultRlweDimension = 1; - unsigned defaultPolyDegree = 1024; APInt defaultCmod = APInt::getOneBitSet(64, 32); - IntegerAttr defaultCmodAttr = - IntegerAttr::get(IntegerType::get(&context, 64), defaultCmod); + std::vector monomials; + monomials.push_back(polynomial::Monomial(1, 1024)); + monomials.push_back(polynomial::Monomial(1, 0)); + polynomial::Polynomial defaultPolyIdeal = + polynomial::Polynomial::fromMonomials(monomials, &context); // https://github.com/google/jaxite/blob/main/jaxite/jaxite_bool/bool_params.py unsigned defaultBskNoiseVariance = 65536; // stdev = 2**8, var = 2**16 @@ -36,7 +42,8 @@ struct SetDefaultParameters unsigned defaultKskGadgetNumLevels = 5; lwe::RLWEParamsAttr defaultRlweParams = lwe::RLWEParamsAttr::get( - &context, defaultCmodAttr, defaultRlweDimension, defaultPolyDegree); + &context, defaultRlweDimension, + polynomial::RingAttr::get(defaultCmod, defaultPolyIdeal)); CGGIParamsAttr defaultParams = CGGIParamsAttr::get(&context, defaultRlweParams, defaultBskNoiseVariance, defaultBskGadgetBaseLog, diff --git a/lib/Dialect/LWE/IR/BUILD b/lib/Dialect/LWE/IR/BUILD index 03b2b9617..7c315ee1a 100644 --- a/lib/Dialect/LWE/IR/BUILD +++ b/lib/Dialect/LWE/IR/BUILD @@ -19,6 +19,7 @@ cc_library( "@heir//include/Dialect/LWE/IR:dialect_inc_gen", "@heir//include/Dialect/LWE/IR:ops_inc_gen", "@heir//include/Dialect/LWE/IR:types_inc_gen", + "@heir//include/Dialect/Polynomial/IR:attributes_inc_gen", "@heir//lib/Dialect/Polynomial/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/lib/Target/OpenFhePke/BUILD b/lib/Target/OpenFhePke/BUILD index ca04e11b4..4f265ac58 100644 --- a/lib/Target/OpenFhePke/BUILD +++ b/lib/Target/OpenFhePke/BUILD @@ -32,6 +32,7 @@ cc_library( "@heir//lib/Analysis/SelectVariableNames", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Openfhe/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:Dialect", "@heir//lib/Target:Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -55,6 +56,7 @@ cc_library( "@heir//lib/Analysis/SelectVariableNames", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Openfhe/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:Dialect", "@heir//lib/Target:Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index 33422b127..b89e5d481 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -7,6 +7,7 @@ #include "include/Dialect/LWE/IR/LWEDialect.h" #include "include/Dialect/Openfhe/IR/OpenfheDialect.h" #include "include/Dialect/Openfhe/IR/OpenfheOps.h" +#include "include/Dialect/Polynomial/IR/PolynomialDialect.h" #include "include/Target/OpenFhePke/OpenFheUtils.h" #include "lib/Target/OpenFhePke/OpenFhePkeTemplates.h" #include "lib/Target/Utils.h" @@ -39,9 +40,9 @@ void registerToOpenFhePkeTranslation() { return translateToOpenFhePke(op, output); }, [](DialectRegistry ®istry) { - registry - .insert(); + registry.insert(); }); } diff --git a/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp index 0b6cca506..0d2195eb0 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp @@ -3,6 +3,7 @@ #include "include/Analysis/SelectVariableNames/SelectVariableNames.h" #include "include/Dialect/LWE/IR/LWEDialect.h" #include "include/Dialect/Openfhe/IR/OpenfheDialect.h" +#include "include/Dialect/Polynomial/IR/PolynomialDialect.h" #include "include/Target/OpenFhePke/OpenFheUtils.h" #include "lib/Target/OpenFhePke/OpenFhePkeTemplates.h" #include "lib/Target/Utils.h" @@ -35,7 +36,7 @@ void registerToOpenFhePkeHeaderTranslation() { }, [](DialectRegistry ®istry) { registry.insert(); + lwe::LWEDialect, polynomial::PolynomialDialect>(); }); } diff --git a/tests/bgv/ops.mlir b/tests/bgv/ops.mlir index f2fe983d2..6f4cb3bfc 100644 --- a/tests/bgv/ops.mlir +++ b/tests/bgv/ops.mlir @@ -3,23 +3,32 @@ // This simply tests for syntax. +#encoding = #lwe.polynomial_evaluation_encoding + #my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> -#otherrings = #bgv.rings<#ring1> +// cmod is 64153 * 2521 +#ring1 = #polynomial.ring +#ring2 = #polynomial.ring + +#params = #lwe.rlwe_params +#params1 = #lwe.rlwe_params +#params2 = #lwe.rlwe_params + +!ct = !lwe.rlwe_ciphertext +!ct1 = !lwe.rlwe_ciphertext +!ct2 = !lwe.rlwe_ciphertext // CHECK: module module { - func.func @test_multiply(%arg0 : !bgv.ciphertext, %arg1: !bgv.ciphertext) -> !bgv.ciphertext { - %add = bgv.add(%arg0, %arg1) : !bgv.ciphertext - %sub = bgv.sub(%arg0, %arg1) : !bgv.ciphertext - %neg = bgv.negate(%arg0) : !bgv.ciphertext + func.func @test_multiply(%arg0 : !ct, %arg1: !ct) -> !ct { + %add = bgv.add(%arg0, %arg1) : !ct + %sub = bgv.sub(%arg0, %arg1) : !ct + %neg = bgv.negate(%arg0) : !ct - %0 = bgv.mul(%arg0, %arg1) : !bgv.ciphertext -> !bgv.ciphertext - %1 = bgv.relinearize(%0) {from_basis = array, to_basis = array } : (!bgv.ciphertext) -> !bgv.ciphertext - %2 = bgv.modulus_switch(%1) {from_level = 1, to_level=0} : (!bgv.ciphertext) -> !bgv.ciphertext - // CHECK: <>, >> - return %arg0 : !bgv.ciphertext + %0 = bgv.mul(%arg0, %arg1) : !ct -> !ct1 + %1 = bgv.relinearize(%0) {from_basis = array, to_basis = array } : (!ct1) -> !ct + %2 = bgv.modulus_switch(%1) {to_ring = #ring2} : (!ct) -> !ct2 + // CHECK: rlwe_params = >> + return %arg0 : !ct } } diff --git a/tests/bgv/to_openfhe.mlir b/tests/bgv/to_openfhe.mlir index 8e9eb752e..fbd173d6f 100644 --- a/tests/bgv/to_openfhe.mlir +++ b/tests/bgv/to_openfhe.mlir @@ -1,23 +1,31 @@ // RUN: heir-opt --bgv-to-openfhe %s | FileCheck %s +#encoding = #lwe.polynomial_evaluation_encoding + #my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> -#rings2 = #bgv.rings<#ring1, #ring2, #ring2> -!ct = !bgv.ciphertext -!ct_dim = !bgv.ciphertext -!ct_level = !bgv.ciphertext +// cmod is 64153 * 2521 +#ring1 = #polynomial.ring +#ring2 = #polynomial.ring + +#params1 = #lwe.rlwe_params +#params2 = #lwe.rlwe_params +#params3 = #lwe.rlwe_params +#params4 = #lwe.rlwe_params + +!ct = !lwe.rlwe_ciphertext +!ct_dim = !lwe.rlwe_ciphertext +!ct_level = !lwe.rlwe_ciphertext +!ct_level3 = !lwe.rlwe_ciphertext // CHECK: module module { - // CHECK: func.func @test_fn([[X:%.+]]: [[T:.*33538049.*]]) -> [[T]] + // CHECK: func.func @test_fn([[X:%.+]]: [[T:.*161729713.*]]) -> [[T]] func.func @test_fn(%x : !ct) -> !ct { // CHECK: return [[X]] : [[T]] return %x : !ct } - // CHECK: func.func @test_ops([[C:%.+]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*33538049.*]], [[Y:%.+]]: [[T]]) + // CHECK: func.func @test_ops([[C:%.+]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*161729713.*]], [[Y:%.+]]: [[T]]) func.func @test_ops(%x : !ct, %y : !ct) { // CHECK: %[[v1:.*]] = openfhe.negate [[C]], %[[x1:.*]] : ([[S]], [[T]]) -> [[T]] %negate = bgv.negate(%x) : !ct @@ -26,14 +34,14 @@ module { // CHECK: %[[v3:.*]] = openfhe.sub [[C]], %[[x3:.*]], %[[y3:.*]]: ([[S]], [[T]], [[T]]) -> [[T]] %sub = bgv.sub(%x, %y) : !ct // CHECK: %[[v4:.*]] = openfhe.mul_no_relin [[C]], %[[x4:.*]], %[[y4:.*]]: ([[S]], [[T]], [[T]]) -> [[T]] - %mul = bgv.mul(%x, %y) : !ct -> !bgv.ciphertext + %mul = bgv.mul(%x, %y) : !ct -> !ct_level3 // CHECK: %[[c5:.*]] = arith.constant 4 : i64 // CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]], %[[c5:.*]]: ([[S]], [[T]], i64) -> [[T]] %rot = bgv.rotate(%x) {offset = 4}: (!ct) -> !ct return } - // CHECK: func.func @test_relin([[C]]: [[S]], [[X:%.+]]: [[T:.*33538049.*]]) + // CHECK: func.func @test_relin([[C]]: [[S]], [[X:%.+]]: [[T:.*161729713.*]]) func.func @test_relin(%x : !ct_dim) { // CHECK: %[[v6:.*]] = openfhe.relin [[C]], %[[x6:.*]]: ([[S]], [[T]]) -> [[T]] %relin = bgv.relinearize(%x) { @@ -42,12 +50,10 @@ module { return } - // CHECK: func.func @test_modswitch([[C]]: [[S]], [[X:%.+]]: [[T:.*33538049.*]]) - func.func @test_modswitch(%x : !ct_level) { - // CHECK: %[[v7:.*]] = openfhe.mod_reduce [[C]], %[[x7:.*]]: ([[S]], [[T]]) -> [[T]] - %mod_switch = bgv.modulus_switch(%x) { - from_level = 2, to_level = 1 - }: (!ct_level) -> !bgv.ciphertext - return + // CHECK: func.func @test_modswitch([[C]]: [[S]], [[X:%.+]]: [[T:.*161729713.*]]) -> [[T1:.*2521.*]] { + func.func @test_modswitch(%x : !ct) -> !ct_level { + // CHECK: %[[v7:.*]] = openfhe.mod_reduce [[C]], %[[x7:.*]] : ([[S]], [[T]]) -> [[T1]] + %mod_switch = bgv.modulus_switch(%x) { to_ring=#ring2 }: (!ct) -> !ct_level + return %mod_switch : !ct_level } } diff --git a/tests/bgv/to_openfhe_invalid.mlir b/tests/bgv/to_openfhe_invalid.mlir index bc9f73bc0..a81ea78d8 100644 --- a/tests/bgv/to_openfhe_invalid.mlir +++ b/tests/bgv/to_openfhe_invalid.mlir @@ -1,27 +1,43 @@ -// RUN: not heir-opt --bgv-to-openfhe --split-input-file %s 2>&1 +// RUN: heir-opt --bgv-to-openfhe --split-input-file --verify-diagnostics %s 2>&1 + +#encoding = #lwe.polynomial_evaluation_encoding #my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> -#rings2 = #bgv.rings<#ring1, #ring2, #ring2> -!ct = !bgv.ciphertext +// cmod is 64153 * 2521 +#ring1 = #polynomial.ring +#ring2 = #polynomial.ring + +#params = #lwe.rlwe_params +#params1 = #lwe.rlwe_params +#params2 = #lwe.rlwe_params + +!ct = !lwe.rlwe_ciphertext +!ct1 = !lwe.rlwe_ciphertext +!ct2 = !lwe.rlwe_ciphertext -func.func @test_relin_to_basis_error(%x: !bgv.ciphertext) { - // expected-error@+1 {{toBasis must be [0, 1], got [0, 2]}} - %relin_error = bgv.relinearize(%x) { from_basis = array, to_basis = array }: (!bgv.ciphertext) -> !ct +func.func @test_relin_to_basis_error(%x: !ct1) { + // expected-error@+2 {{toBasis must be [0, 1], got [0, 2]}} + // expected-error@+1 {{failed to legalize operation 'bgv.relinearize' that was explicitly marked illegal}} + %relin_error = bgv.relinearize(%x) { from_basis = array, to_basis = array }: (!ct1) -> !ct return } // ----- +#encoding = #lwe.polynomial_evaluation_encoding + #my_poly = #polynomial.polynomial<1 + x**1024> #ring1 = #polynomial.ring #ring2 = #polynomial.ring -#rings2 = #bgv.rings<#ring1, #ring2, #ring2> -func.func @test_modswitch_level_error(%x: !bgv.ciphertext) { - // expected-error@+1 {{fromLevel must be toLevel + 1, got fromLevel: 2 and toLevel: 0}} - %relin_error = bgv.modulus_switch(%x) { - from_level = 2, to_level = 0 - }: (!bgv.ciphertext) -> !bgv.ciphertext + +#params = #lwe.rlwe_params +#params1 = #lwe.rlwe_params +#params2 = #lwe.rlwe_params + +!ct1 = !lwe.rlwe_ciphertext +!ct2 = !lwe.rlwe_ciphertext + +func.func @test_modswitch_level_error(%x: !ct2) { + // expected-error@+1 {{output ring should match to_ring}} + %relin_error = bgv.modulus_switch(%x) {to_ring=#ring2}: (!ct2) -> !ct1 return } diff --git a/tests/bgv/to_polynomial.mlir b/tests/bgv/to_polynomial.mlir index 2635c7e0e..dea7e6dd2 100644 --- a/tests/bgv/to_polynomial.mlir +++ b/tests/bgv/to_polynomial.mlir @@ -3,11 +3,14 @@ // This simply tests for syntax. +#encoding = #lwe.polynomial_evaluation_encoding + #my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> -!ct1 = !bgv.ciphertext +#ring = #polynomial.ring +#params = #lwe.rlwe_params +#params1 = #lwe.rlwe_params + +!ct1 = !lwe.rlwe_ciphertext // CHECK: module module { @@ -40,7 +43,7 @@ module { // CHECK: [[Z1:%.+]] = polynomial.add([[X0Y1]], [[X1Y0]]) : [[P]] // CHECK: [[Z2:%.+]] = polynomial.mul([[X1]], [[Y1]]) : [[P]] // CHECK: [[Z:%.+]] = tensor.from_elements [[Z0]], [[Z1]], [[Z2]] : tensor<3x[[P]]> - %mul = bgv.mul(%x, %y) : !ct1 -> !bgv.ciphertext + %mul = bgv.mul(%x, %y) : !ct1 -> !lwe.rlwe_ciphertext return } } diff --git a/tests/bgv/verifier.mlir b/tests/bgv/verifier.mlir index e990b5239..7e1d9d7a3 100644 --- a/tests/bgv/verifier.mlir +++ b/tests/bgv/verifier.mlir @@ -1,13 +1,19 @@ // RUN: heir-opt --verify-diagnostics %s +#encoding = #lwe.polynomial_evaluation_encoding + #my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> +#ring = #polynomial.ring + +#params = #lwe.rlwe_params +#params1 = #lwe.rlwe_params + +!ct = !lwe.rlwe_ciphertext +!ct1 = !lwe.rlwe_ciphertext -func.func @test_input_dimension_error(%input: !bgv.ciphertext) { +func.func @test_input_dimension_error(%input: !ct) { // expected-error@+1 {{x.dim == 2 does not hold}} - %out = bgv.rotate (%input) {offset = 4} : (!bgv.ciphertext) -> !bgv.ciphertext + %out = bgv.rotate (%input) {offset = 4} : (!ct) -> !ct1 return } diff --git a/tests/lwe/types.mlir b/tests/lwe/types.mlir index e3a5af99a..a0a78ac64 100644 --- a/tests/lwe/types.mlir +++ b/tests/lwe/types.mlir @@ -23,7 +23,10 @@ func.func @test_valid_lwe_ciphertext_unspecified(%arg0 : !ciphertext_noparams) - return %arg0 : !ciphertext_noparams } -#rlwe_params = #lwe.rlwe_params + +#my_poly = #polynomial.polynomial<1 + x**1024> +#ring = #polynomial.ring +#rlwe_params = #lwe.rlwe_params !ciphertext_rlwe = !lwe.rlwe_ciphertext // CHECK-LABEL: test_valid_rlwe_ciphertext diff --git a/tests/openfhe/emit_openfhe_pke.mlir b/tests/openfhe/emit_openfhe_pke.mlir index 9dfe5aa9d..74bb054c7 100644 --- a/tests/openfhe/emit_openfhe_pke.mlir +++ b/tests/openfhe/emit_openfhe_pke.mlir @@ -1,7 +1,10 @@ // RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s #encoding = #lwe.polynomial_evaluation_encoding -#params = #lwe.rlwe_params + +#my_poly = #polynomial.polynomial<1 + x**16384> +#ring= #polynomial.ring +#params = #lwe.rlwe_params !cc = !openfhe.crypto_context !ek = !openfhe.eval_key !pt = !lwe.rlwe_plaintext diff --git a/tests/openfhe/end_to_end/binops.mlir b/tests/openfhe/end_to_end/binops.mlir index 2cb857c0e..1d0e4e30d 100644 --- a/tests/openfhe/end_to_end/binops.mlir +++ b/tests/openfhe/end_to_end/binops.mlir @@ -1,7 +1,8 @@ // RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s #encoding = #lwe.polynomial_evaluation_encoding -#params = #lwe.rlwe_params +#ring = #polynomial.ring> +#params = #lwe.rlwe_params !cc = !openfhe.crypto_context !ct = !lwe.rlwe_ciphertext diff --git a/tests/openfhe/ops.mlir b/tests/openfhe/ops.mlir index eb2e746f0..265be8619 100644 --- a/tests/openfhe/ops.mlir +++ b/tests/openfhe/ops.mlir @@ -2,7 +2,10 @@ // This simply tests for syntax. #encoding = #lwe.polynomial_evaluation_encoding -#params = #lwe.rlwe_params +#my_poly = #polynomial.polynomial<1 + x**16384> +#ring= #polynomial.ring +#params = #lwe.rlwe_params + !pk = !openfhe.public_key !ek = !openfhe.eval_key !cc = !openfhe.crypto_context