diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 8e102df34..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "files.associations": { - "*.inc": "cpp", - "type_traits": "cpp", - "vector": "cpp", - "__bit_reference": "cpp", - "bitset": "cpp", - "deque": "cpp", - "limits": "cpp", - "ratio": "cpp", - "tuple": "cpp", - "istream": "cpp", - "ostream": "cpp", - "array": "cpp", - "functional": "cpp", - "utility": "cpp", - "variant": "cpp", - "__node_handle": "cpp", - "__split_buffer": "cpp", - "hash_map": "cpp", - "hash_set": "cpp", - "forward_list": "cpp", - "list": "cpp", - "map": "cpp", - "queue": "cpp", - "regex": "cpp", - "set": "cpp", - "span": "cpp", - "stack": "cpp", - "string": "cpp", - "string_view": "cpp", - "unordered_map": "cpp", - "unordered_set": "cpp", - "any": "cpp" - } -} diff --git a/include/Dialect/BGV/IR/BGVAttributes.h b/include/Dialect/BGV/IR/BGVAttributes.h deleted file mode 100644 index dafe2dbe5..000000000 --- a/include/Dialect/BGV/IR/BGVAttributes.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_H_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_H_ - -#include "include/Dialect/BGV/IR/BGVDialect.h" - -// Required to pull in poly's Ring_Attr -#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" - -#define GET_ATTRDEF_CLASSES -#include "include/Dialect/BGV/IR/BGVAttributes.h.inc" - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_H_ diff --git a/include/Dialect/BGV/IR/BGVAttributes.td b/include/Dialect/BGV/IR/BGVAttributes.td deleted file mode 100644 index 678b8cc14..000000000 --- a/include/Dialect/BGV/IR/BGVAttributes.td +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_TD_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_TD_ - -include "BGVDialect.td" - -include "mlir/IR/DialectBase.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/IR/OpBase.td" - -def BGVRingArrayAttr : AttrDef { - let mnemonic = "rings"; - let parameters = (ins ArrayRefParameter<"::mlir::heir::polynomial::RingAttr">:$rings); - let assemblyFormat = "`<` $rings `>`"; -} - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVATTRIBUTES_TD_ diff --git a/include/Dialect/BGV/IR/BGVTraits.h b/include/Dialect/BGV/IR/BGVTraits.h deleted file mode 100644 index 16c75fffb..000000000 --- a/include/Dialect/BGV/IR/BGVTraits.h +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_ - -#include "include/Dialect/BGV/IR/BGVAttributes.h" -#include "include/Dialect/BGV/IR/BGVTypes.h" -#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project - -namespace mlir::heir::bgv { - -// Trait that ensures that all operands and results ciphertext have the same set -// of rings. -template -class SameOperandsAndResultRings - : public OpTrait::TraitBase { - public: - static LogicalResult verifyTrait(Operation *op) { - BGVRingsAttr rings = nullptr; - for (auto rTy : op->getResultTypes()) { - auto ct = dyn_cast(rTy); - if (!ct) continue; - if (rings == nullptr) { - rings = ct.getRings(); - continue; - } - if (rings != ct.getRings()) { - return op->emitOpError() - << "requires all operands and results to have the same rings"; - } - } - - for (auto oTy : op->getOperandTypes()) { - auto ct = dyn_cast(oTy); - if (!ct) continue; // Check only ciphertexts - - if (rings == nullptr) { - rings = ct.getRings(); - continue; - } - - if (rings != ct.getRings()) { - return op->emitOpError() - << "requires all operands and results to have the same rings"; - } - } - return success(); - } -}; - -} // namespace mlir::heir::bgv - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_ diff --git a/include/Dialect/BGV/IR/BGVTypes.h b/include/Dialect/BGV/IR/BGVTypes.h deleted file mode 100644 index b4bb43b15..000000000 --- a/include/Dialect/BGV/IR/BGVTypes.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_H_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_H_ - -#include "include/Dialect/BGV/IR/BGVAttributes.h" -#include "include/Dialect/BGV/IR/BGVDialect.h" - -// Required to pull in poly's Ring_Attr -#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" - -#define GET_TYPEDEF_CLASSES -#include "include/Dialect/BGV/IR/BGVTypes.h.inc" - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_H_ diff --git a/include/Dialect/BGV/IR/BGVTypes.td b/include/Dialect/BGV/IR/BGVTypes.td deleted file mode 100644 index ae4aff93f..000000000 --- a/include/Dialect/BGV/IR/BGVTypes.td +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_TD_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_TD_ - -include "BGVDialect.td" -include "BGVAttributes.td" - -include "mlir/IR/AttrTypeBase.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/IR/DialectBase.td" -include "mlir/IR/OpBase.td" - -// TODO(#100): Add a plaintext type. - -// A base class for all types in this dialect -class BGV_Type - : TypeDef { - let mnemonic = typeMnemonic; -} - -def Ciphertext : BGV_Type<"Ciphertext", "ciphertext"> { - let summary = "a BGV ciphertext"; - - let description = [{ - This type tracks the BGV ciphertext parameters, including the ciphertext - dimension (number of polynomials) and the set of rings that were used for - the particular BGV scheme instance. The default dimension is 2, representing - a ciphertext that is canonically encrypted against the key basis $(1, s)$. - - The type also includes a ring parameter specification. - - For example, `bgv.ciphertext` is a ciphertext with 3 - polynomials $(c_0, c_1, c_2)$. - - The optional attribute `level` specifies the "current ring". - }]; - - // TODO(#99): Add # of plaintext bits. - let parameters = (ins - BGVRingArrayAttr:$rings, - DefaultValuedParameter<"unsigned", "2">:$dim, - OptionalParameter<"std::optional">:$level - ); - - let assemblyFormat = "`<` `rings` `=` $rings (`,` `dim` `=` $dim^ )? (`,` `level` `=` $level^ )? `>`"; -} - -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVTYPES_TD_ diff --git a/include/Dialect/CGGI/Transforms/Passes.td b/include/Dialect/CGGI/Transforms/Passes.td index dad88d8ab..8787e08f4 100644 --- a/include/Dialect/CGGI/Transforms/Passes.td +++ b/include/Dialect/CGGI/Transforms/Passes.td @@ -19,16 +19,4 @@ def SetDefaultParameters : Pass<"cggi-set-default-parameters"> { let dependentDialects = ["mlir::heir::cggi::CGGIDialect"]; } -<<<<<<< HEAD -def StraightLineVectorizer : Pass<"cggi-straight-line-vectorizer"> { - let summary = "A straight-line vectorizer for CGGI bootstrapping ops."; - let description = [{ - This pass vectorizes CGGI ops. It ignores control flow and only vectorizes - straight-line programs within a given region. - }]; - let dependentDialects = ["mlir::heir::cggi::CGGIDialect"]; -} - -======= ->>>>>>> main #endif // INCLUDE_DIALECT_CGGI_TRANSFORMS_PASSES_TD_ diff --git a/include/Dialect/CGGI/Transforms/StraightLineVectorizer.h b/include/Dialect/CGGI/Transforms/StraightLineVectorizer.h deleted file mode 100644 index e97a4b3ef..000000000 --- a/include/Dialect/CGGI/Transforms/StraightLineVectorizer.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef INCLUDE_DIALECT_CGGI_TRANSFORMS_STRAIGHTLINEVECTORIZER_H_ -#define INCLUDE_DIALECT_CGGI_TRANSFORMS_STRAIGHTLINEVECTORIZER_H_ - -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace heir { -namespace cggi { - -#define GEN_PASS_DECL_STRAIGHTLINEVECTORIZER -#include "include/Dialect/CGGI/Transforms/Passes.h.inc" - -} // namespace cggi -} // namespace heir -} // namespace mlir - -#endif // INCLUDE_DIALECT_CGGI_TRANSFORMS_STRAIGHTLINEVECTORIZER_H_ diff --git a/include/Dialect/LWE/IR/LWEAttributes.h b/include/Dialect/LWE/IR/LWEAttributes.h index 0087f73e3..092e1b568 100644 --- a/include/Dialect/LWE/IR/LWEAttributes.h +++ b/include/Dialect/LWE/IR/LWEAttributes.h @@ -4,12 +4,9 @@ #include "include/Dialect/LWE/IR/LWEDialect.h" #include "mlir/include/mlir/IR/TensorEncoding.h" // from @llvm-project -<<<<<<< HEAD -======= // Required to pull in poly's Ring_Attr #include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" ->>>>>>> main #define GET_ATTRDEF_CLASSES #include "include/Dialect/LWE/IR/LWEAttributes.h.inc" diff --git a/include/Dialect/LWE/IR/LWEAttributes.td b/include/Dialect/LWE/IR/LWEAttributes.td index 08b6b0679..24c55933f 100644 --- a/include/Dialect/LWE/IR/LWEAttributes.td +++ b/include/Dialect/LWE/IR/LWEAttributes.td @@ -309,24 +309,16 @@ def LWE_RLWEParams : AttrDef { let description = [{ An attribute describing classical RLWE parameters: -<<<<<<< HEAD - - `cmod`: the coefficient modulus for the polynomials. -======= ->>>>>>> main - `dimension`: the number of polynomials used in an RLWE sample, analogous to LWEParams.dimension. - `polyDegree`: the degree $N$ of the negacyclic polynomial modulus $x^N + 1$. }]; -<<<<<<< HEAD - let parameters = (ins "IntegerAttr": $cmod, "unsigned":$dimension, "unsigned": $polyDegree); -======= let parameters = (ins DefaultValuedParameter<"unsigned", "2">:$dimension, "::mlir::heir::polynomial::RingAttr":$ring ); ->>>>>>> main let assemblyFormat = "`<` struct(params) `>`"; } diff --git a/include/Dialect/Openfhe/IR/OpenfheOps.td b/include/Dialect/Openfhe/IR/OpenfheOps.td index e6c4fcfec..4f238163a 100644 --- a/include/Dialect/Openfhe/IR/OpenfheOps.td +++ b/include/Dialect/Openfhe/IR/OpenfheOps.td @@ -30,8 +30,6 @@ class Openfhe_UnaryOp traits = []> let results = (outs RLWECiphertext:$output); } -<<<<<<< HEAD -======= class Openfhe_UnaryTypeSwitchOp traits = []> : Openfhe_Op traits = []> let results = (outs RLWECiphertext:$output); } ->>>>>>> main class Openfhe_BinaryOp traits = []> : Openfhe_Op { let summary = "OpenFHE negate operati def SquareOp : Openfhe_UnaryOp<"square"> { let summary = "OpenFHE square operation of a ciphertext."; } def RelinOp : Openfhe_UnaryOp<"relin"> { let summary = "OpenFHE relinearize operation of a ciphertext."; } -<<<<<<< HEAD -def ModReduceOp : Openfhe_UnaryOp<"mod_reduce"> { let summary = "OpenFHE mod_reduce operation of a ciphertext. (used only for BGV/CKKS)"; } -def LevelReduceOp : Openfhe_UnaryOp<"level_reduce"> { let summary = "OpenFHE level_reduce operation of a ciphertext."; } -======= def ModReduceOp : Openfhe_UnaryTypeSwitchOp<"mod_reduce"> { let summary = "OpenFHE mod_reduce operation of a ciphertext. (used only for BGV/CKKS)"; } def LevelReduceOp : Openfhe_UnaryTypeSwitchOp<"level_reduce"> { let summary = "OpenFHE level_reduce operation of a ciphertext."; } ->>>>>>> main def RotOp : Openfhe_Op<"rot",[ Pure, diff --git a/include/Dialect/Polynomial/IR/PolynomialTypes.td b/include/Dialect/Polynomial/IR/PolynomialTypes.td index d2c42f628..2637131da 100644 --- a/include/Dialect/Polynomial/IR/PolynomialTypes.td +++ b/include/Dialect/Polynomial/IR/PolynomialTypes.td @@ -4,18 +4,6 @@ include "include/Dialect/Polynomial/IR/PolynomialDialect.td" include "include/Dialect/Polynomial/IR/PolynomialAttributes.td" -<<<<<<< HEAD -include "mlir/IR/DialectBase.td" -include "mlir/IR/AttrTypeBase.td" - -// A base class for all types in this dialect -class Polynomial_Type - : TypeDef { - let mnemonic = typeMnemonic; -} - -def Polynomial : Polynomial_Type<"Polynomial", "polynomial"> { -======= include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinTypes.td" include "mlir/IR/DialectBase.td" @@ -27,7 +15,6 @@ class Polynomial_Type traits = []> } def Polynomial : Polynomial_Type<"Polynomial", "polynomial", [MemRefElementTypeInterface]> { ->>>>>>> main let summary = "An element of a polynomial quotient ring"; let description = [{ diff --git a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td index ea4b51a50..92dd1033f 100644 --- a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td +++ b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td @@ -68,7 +68,6 @@ def XorPackedOp : TfheRustBool_Op<"xor_packed", [ let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); } - def NotOp : TfheRustBool_Op<"not", [ Pure, AllTypesMatch<["input", "output"]> diff --git a/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp b/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp index caaa03947..4857ec746 100644 --- a/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp +++ b/lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp @@ -5,11 +5,6 @@ #include "include/Dialect/BGV/IR/BGVDialect.h" #include "include/Dialect/BGV/IR/BGVOps.h" -<<<<<<< HEAD -#include "include/Dialect/BGV/IR/BGVTypes.h" - == == == - = ->>>>>>> main #include "include/Dialect/LWE/IR/LWEAttributes.h" #include "include/Dialect/LWE/IR/LWETypes.h" #include "include/Dialect/Openfhe/IR/OpenfheDialect.h" @@ -26,275 +21,227 @@ #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project - namespace mlir::heir::bgv { +namespace mlir::heir::bgv { #define GEN_PASS_DEF_BGVTOOPENFHE #include "include/Conversion/BGVToOpenfhe/BGVToOpenfhe.h.inc" - class ToLWECiphertextTypeConverter : public TypeConverter { - public: -<<<<<<< HEAD - // Convert ciphertext to RLWE ciphertext - ToLWECiphertextTypeConverter(MLIRContext *ctx) { - addConversion([](Type type) { return type; }); - addConversion([ctx](CiphertextType type) -> Type { - assert(type.getLevel().has_value()); - auto level = type.getLevel().value(); - assert(level < type.getRings().getRings().size()); - - auto ring = type.getRings().getRings()[level]; - auto cmod = ring.getCmod(); - auto polyDegree = ring.getIdeal().getDegree(); - auto dim = type.getDim(); - return lwe::RLWECiphertextType::get( - // TODO(#99): Set a default encoding when the adding # of plaintext - // bits on BGV Ciphertext is done. - ctx, lwe::PolynomialEvaluationEncodingAttr::get(ctx, 0, 0), - lwe::RLWEParamsAttr::get(ctx, cmod, dim, polyDegree)); - }); -======= - ToLWECiphertextTypeConverter(MLIRContext *ctx) { - addConversion([](Type type) { return type; }); ->>>>>>> main - } - }; - - bool containsBGVOps(func::FuncOp func) { - auto walkResult = func.walk([&](Operation *op) { - if (llvm::isa(op->getDialect())) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - return walkResult.wasInterrupted(); +class ToLWECiphertextTypeConverter : public TypeConverter { + public: + ToLWECiphertextTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); } - - FailureOr getContextualCryptoContext(Operation * op) { - Value cryptoContext = op->getParentOfType() - .getBody() - .getBlocks() - .front() - .getArguments() - .front(); - if (!cryptoContext.getType().isa()) { - return op->emitOpError() << "Found BGV op in a function without a public " - "key argument. Did the AddCryptoContextArg " - "pattern fail to run?"; - } - return cryptoContext; +}; + +bool containsBGVOps(func::FuncOp func) { + auto walkResult = func.walk([&](Operation *op) { + if (llvm::isa(op->getDialect())) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return walkResult.wasInterrupted(); +} + +FailureOr getContextualCryptoContext(Operation *op) { + Value cryptoContext = op->getParentOfType() + .getBody() + .getBlocks() + .front() + .getArguments() + .front(); + if (!cryptoContext.getType().isa()) { + return op->emitOpError() + << "Found BGV op in a function without a public " + "key argument. Did the AddCryptoContextArg pattern fail to run?"; } + return cryptoContext; +} - struct AddCryptoContextArg : public OpConversionPattern { - AddCryptoContextArg(mlir::MLIRContext *context) - : OpConversionPattern(context, /* benefit= */ 2) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - func::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!containsBGVOps(op)) { - return failure(); - } - - auto cryptoContextType = openfhe::CryptoContextType::get(getContext()); - FunctionType originalType = op.getFunctionType(); - llvm::SmallVector newTypes; - newTypes.reserve(originalType.getNumInputs() + 1); - newTypes.push_back(cryptoContextType); - for (auto t : originalType.getInputs()) { - newTypes.push_back(t); - } - auto newFuncType = - FunctionType::get(getContext(), newTypes, originalType.getResults()); - rewriter.modifyOpInPlace(op, [&] { - op.setType(newFuncType); - - Block &block = op.getBody().getBlocks().front(); - block.insertArgument(&block.getArguments().front(), cryptoContextType, - op.getLoc()); - }); - - return success(); - } - }; - - template - struct ConvertUnaryOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - UnaryOp op, typename UnaryOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr result = getContextualCryptoContext(op.getOperation()); - if (failed(result)) return result; - - Value cryptoContext = result.value(); - rewriter.replaceOp( - op, rewriter.create(op.getLoc(), cryptoContext, - adaptor.getOperands()[0])); - return success(); - } - }; - - template - struct ConvertBinOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - BinOp op, typename BinOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr result = getContextualCryptoContext(op.getOperation()); - if (failed(result)) return result; - - Value cryptoContext = result.value(); - rewriter.replaceOp( - op, rewriter.create(op.getLoc(), cryptoContext, - adaptor.getOperands()[0], - adaptor.getOperands()[1])); - return success(); +struct AddCryptoContextArg : public OpConversionPattern { + AddCryptoContextArg(mlir::MLIRContext *context) + : OpConversionPattern(context, /* benefit= */ 2) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!containsBGVOps(op)) { + return failure(); } - }; - - using ConvertNegateOp = ConvertUnaryOp; - - using ConvertAddOp = ConvertBinOp; - using ConvertSubOp = ConvertBinOp; - using ConvertMulOp = ConvertBinOp; - - struct ConvertRotateOp : public OpConversionPattern { - ConvertRotateOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - Rotate op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr result = getContextualCryptoContext(op.getOperation()); - if (failed(result)) return result; - - Value cryptoContext = result.value(); - auto offsetValue = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adaptor.getOffset())); - rewriter.replaceOp( - op, rewriter.create(op.getLoc(), cryptoContext, - adaptor.getX(), offsetValue)); - return success(); + + auto cryptoContextType = openfhe::CryptoContextType::get(getContext()); + FunctionType originalType = op.getFunctionType(); + llvm::SmallVector newTypes; + newTypes.reserve(originalType.getNumInputs() + 1); + newTypes.push_back(cryptoContextType); + for (auto t : originalType.getInputs()) { + newTypes.push_back(t); } - }; + auto newFuncType = + FunctionType::get(getContext(), newTypes, originalType.getResults()); + rewriter.modifyOpInPlace(op, [&] { + op.setType(newFuncType); + + Block &block = op.getBody().getBlocks().front(); + block.insertArgument(&block.getArguments().front(), cryptoContextType, + op.getLoc()); + }); - bool checkRelinToBasis(llvm::ArrayRef toBasis) { - if (toBasis.size() != 2) return false; - return toBasis[0] == 0 && toBasis[1] == 1; + return success(); + } +}; + +template +struct ConvertUnaryOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + UnaryOp op, typename UnaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + + Value cryptoContext = result.value(); + rewriter.replaceOp( + op, rewriter.create(op.getLoc(), cryptoContext, + adaptor.getOperands()[0])); + return success(); } - struct ConvertRelinOp : public OpConversionPattern { - ConvertRelinOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - Relinearize op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr result = getContextualCryptoContext(op.getOperation()); - if (failed(result)) return result; - - auto toBasis = adaptor.getToBasis(); - - // Since the `Relinearize()` function in OpenFHE relinearizes a ciphertext - // to the lowest level (for (1,s)), the `to_basis` of `bgv.RelinOp` must - // be [0,1]. - if (!checkRelinToBasis(toBasis)) { - op.emitError() << "toBasis must be [0, 1], got [" << toBasis << "]"; - return failure(); - } - - Value cryptoContext = result.value(); - rewriter.replaceOp(op, rewriter.create( - op.getLoc(), cryptoContext, adaptor.getX())); - return success(); +}; + +template +struct ConvertBinOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + BinOp op, typename BinOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + + Value cryptoContext = result.value(); + rewriter.replaceOp(op, + rewriter.create(op.getLoc(), cryptoContext, + adaptor.getOperands()[0], + adaptor.getOperands()[1])); + return success(); + } +}; + +using ConvertNegateOp = ConvertUnaryOp; + +using ConvertAddOp = ConvertBinOp; +using ConvertSubOp = ConvertBinOp; +using ConvertMulOp = ConvertBinOp; + +struct ConvertRotateOp : public OpConversionPattern { + ConvertRotateOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + Rotate op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + + Value cryptoContext = result.value(); + auto offsetValue = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adaptor.getOffset())); + rewriter.replaceOp( + op, rewriter.create(op.getLoc(), cryptoContext, + adaptor.getX(), offsetValue)); + return success(); + } +}; + +bool checkRelinToBasis(llvm::ArrayRef toBasis) { + if (toBasis.size() != 2) return false; + return toBasis[0] == 0 && toBasis[1] == 1; +} +struct ConvertRelinOp : public OpConversionPattern { + ConvertRelinOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + Relinearize op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + + auto toBasis = adaptor.getToBasis(); + + // Since the `Relinearize()` function in OpenFHE relinearizes a ciphertext + // to the lowest level (for (1,s)), the `to_basis` of `bgv.RelinOp` must be + // [0,1]. + if (!checkRelinToBasis(toBasis)) { + op.emitError() << "toBasis must be [0, 1], got [" << toBasis << "]"; + return failure(); } - }; -<<<<<<< HEAD - bool checkModulusSwitchLevels(unsigned long fromLevel, - unsigned long toLevel) { - return fromLevel == toLevel + 1; + Value cryptoContext = result.value(); + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), cryptoContext, adaptor.getX())); + return success(); } -======= ->>>>>>> main - struct ConvertModulusSwitchOp : public OpConversionPattern { - ConvertModulusSwitchOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - ModulusSwitch op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr result = getContextualCryptoContext(op.getOperation()); - if (failed(result)) return result; - -<<<<<<< HEAD - auto fromLevel = adaptor.getFromLevel(); - auto toLevel = adaptor.getToLevel(); - - // Since the `ModReduce()` function in OpenFHE decreases the level of a - // ciphertext by 1, `fromLevel == toLevel + 1` holds for - // `bgv.ModulusSwitch`. - if (!checkModulusSwitchLevels(fromLevel, toLevel)) { - op.emitError() << "fromLevel must be toLevel + 1, got " - << "fromLevel: " << fromLevel << " and " - << "toLevel: " << toLevel; - return failure(); - } - - Value cryptoContext = result.value(); - rewriter.replaceOp(op, rewriter.create( - op.getLoc(), cryptoContext, adaptor.getX())); -======= - Value cryptoContext = result.value(); - rewriter.replaceOp(op, rewriter.create( - op.getLoc(), op.getOutput().getType(), - cryptoContext, adaptor.getX())); ->>>>>>> main - return success(); - } - }; - - struct BGVToOpenfhe : public impl::BGVToOpenfheBase { - void runOnOperation() override { - MLIRContext *context = &getContext(); - auto *module = getOperation(); - ToLWECiphertextTypeConverter typeConverter(context); - - ConversionTarget target(*context); - target.addLegalDialect(); - target.addIllegalDialect(); - - RewritePatternSet patterns(context); - addStructuralConversionPatterns(typeConverter, patterns, target); - - target.addDynamicallyLegalOp([&](func::FuncOp op) { - bool hasCryptoContextArg = op.getFunctionType().getNumInputs() > 0 && - op.getFunctionType() - .getInputs() - .begin() - ->isa(); - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()) && - (!containsBGVOps(op) || hasCryptoContextArg); - }); - patterns.add(typeConverter, - context); - - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - return signalPassFailure(); - } +}; + +struct ConvertModulusSwitchOp : public OpConversionPattern { + ConvertModulusSwitchOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ModulusSwitch op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + + Value cryptoContext = result.value(); + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), op.getOutput().getType(), + cryptoContext, adaptor.getX())); + return success(); + } +}; + +struct BGVToOpenfhe : public impl::BGVToOpenfheBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + ToLWECiphertextTypeConverter typeConverter(context); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(context); + addStructuralConversionPatterns(typeConverter, patterns, target); + + target.addDynamicallyLegalOp([&](func::FuncOp op) { + bool hasCryptoContextArg = op.getFunctionType().getNumInputs() > 0 && + op.getFunctionType() + .getInputs() + .begin() + ->isa(); + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()) && + (!containsBGVOps(op) || hasCryptoContextArg); + }); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); } - }; + } +}; } // namespace mlir::heir::bgv diff --git a/lib/Conversion/BGVToPolynomial/BGVToPolynomial.cpp b/lib/Conversion/BGVToPolynomial/BGVToPolynomial.cpp index 6e3c4ac00..c50d7330b 100644 --- a/lib/Conversion/BGVToPolynomial/BGVToPolynomial.cpp +++ b/lib/Conversion/BGVToPolynomial/BGVToPolynomial.cpp @@ -5,12 +5,7 @@ #include "include/Dialect/BGV/IR/BGVDialect.h" #include "include/Dialect/BGV/IR/BGVOps.h" -<<<<<<< HEAD -#include "include/Dialect/BGV/IR/BGVTypes.h" - == == == - = #include "include/Dialect/LWE/IR/LWETypes.h" - >>>>>>> main #include "include/Dialect/Polynomial/IR/Polynomial.h" #include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "include/Dialect/Polynomial/IR/PolynomialOps.h" @@ -18,179 +13,162 @@ #include "lib/Conversion/Utils.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project - <<<<<<>>>>>> - main -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project - namespace mlir::heir::bgv { +namespace mlir::heir::bgv { #define GEN_PASS_DEF_BGVTOPOLYNOMIAL #include "include/Conversion/BGVToPolynomial/BGVToPolynomial.h.inc" - class CiphertextTypeConverter : public TypeConverter { - public: - // Convert ciphertext to tensor<#dim x !poly.poly<#rings[#level]>> - CiphertextTypeConverter(MLIRContext *ctx) { - addConversion([](Type type) { return type; }); -<<<<<<< HEAD - addConversion([ctx](CiphertextType type) -> Type { - assert(type.getLevel().has_value()); - auto level = type.getLevel().value(); - assert(level < type.getRings().getRings().size()); - - auto ring = type.getRings().getRings()[level]; - auto polyTy = polynomial::PolynomialType::get(ctx, ring); - - return RankedTensorType::get({type.getDim()}, polyTy); -======= - addConversion([ctx](lwe::RLWECiphertextType type) -> Type { - auto rlweParams = type.getRlweParams(); - auto ring = rlweParams.getRing(); - auto polyTy = polynomial::PolynomialType::get(ctx, ring); - - return RankedTensorType::get({rlweParams.getDimension()}, polyTy); ->>>>>>> main - }); +class CiphertextTypeConverter : public TypeConverter { + public: + // Convert ciphertext to tensor<#dim x !poly.poly<#rings[#level]>> + CiphertextTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); + addConversion([ctx](lwe::RLWECiphertextType type) -> Type { + auto rlweParams = type.getRlweParams(); + auto ring = rlweParams.getRing(); + auto polyTy = polynomial::PolynomialType::get(ctx, ring); + + return RankedTensorType::get({rlweParams.getDimension()}, polyTy); + }); + } + // We don't include any custom materialization ops because this lowering is + // all done in a single pass. The dialect conversion framework works by + // resolving intermediate (mid-pass) type conflicts by inserting + // unrealized_conversion_cast ops, and only converting those to custom + // materializations if they persist at the end of the pass. In our case, + // we'd only need to use custom materializations if we split this lowering + // across multiple passes. +}; + +struct ConvertAdd : public OpConversionPattern { + ConvertAdd(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + AddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), adaptor.getOperands()[0], + adaptor.getOperands()[1])); + return success(); + } +}; + +struct ConvertSub : public OpConversionPattern { + ConvertSub(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), adaptor.getOperands()[0], + adaptor.getOperands()[1])); + return success(); + } +}; + +struct ConvertNegate : public OpConversionPattern { + ConvertNegate(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + Negate op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto arg = adaptor.getOperands()[0]; + auto neg = rewriter.create(loc, -1, /*width=*/8); + rewriter.replaceOp(op, rewriter.create( + loc, arg.getType(), arg, neg)); + return success(); + } +}; + +struct ConvertMul : public OpConversionPattern { + ConvertMul(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + MulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto x = adaptor.getX(); + auto xT = cast(x.getType()); + auto y = adaptor.getY(); + auto yT = cast(y.getType()); + + if (xT.getNumElements() != 2 || yT.getNumElements() != 2) { + op.emitError() << "`bgv.mul` expects ciphertext as two polynomials, got " + << xT.getNumElements() << " and " << yT.getNumElements(); + return failure(); } - // We don't include any custom materialization ops because this lowering is - // all done in a single pass. The dialect conversion framework works by - // resolving intermediate (mid-pass) type conflicts by inserting - // unrealized_conversion_cast ops, and only converting those to custom - // materializations if they persist at the end of the pass. In our case, - // we'd only need to use custom materializations if we split this lowering - // across multiple passes. - }; - - struct ConvertAdd : public OpConversionPattern { - ConvertAdd(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - AddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp(op, rewriter.create( - op.getLoc(), adaptor.getOperands()[0], - adaptor.getOperands()[1])); - return success(); + + if (xT.getElementType() != yT.getElementType()) { + op->emitOpError() << "`bgv.mul` expects operands of the same type"; + return failure(); } - }; - struct ConvertSub : public OpConversionPattern { - ConvertSub(mlir::MLIRContext *context) - : OpConversionPattern(context) {} + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + // z = mul([x0, x1], [y0, y1]) := [x0.y0, x0.y1 + x1.y0, x1.y1] + auto i0 = b.create(0); + auto i1 = b.create(1); - using OpConversionPattern::OpConversionPattern; + auto x0 = + b.create(xT.getElementType(), x, ValueRange{i0}); + auto x1 = + b.create(xT.getElementType(), x, ValueRange{i1}); - LogicalResult matchAndRewrite( - SubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp(op, rewriter.create( - op.getLoc(), adaptor.getOperands()[0], - adaptor.getOperands()[1])); - return success(); - } - }; - - struct ConvertNegate : public OpConversionPattern { - ConvertNegate(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - Negate op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto arg = adaptor.getOperands()[0]; - auto neg = rewriter.create(loc, -1, /*width=*/8); - rewriter.replaceOp(op, rewriter.create( - loc, arg.getType(), arg, neg)); - return success(); - } - }; - - struct ConvertMul : public OpConversionPattern { - ConvertMul(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - MulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto x = adaptor.getX(); - auto xT = cast(x.getType()); - auto y = adaptor.getY(); - auto yT = cast(y.getType()); - - if (xT.getNumElements() != 2 || yT.getNumElements() != 2) { - op.emitError() - << "`bgv.mul` expects ciphertext as two polynomials, got " - << xT.getNumElements() << " and " << yT.getNumElements(); - return failure(); - } - - if (xT.getElementType() != yT.getElementType()) { - op->emitOpError() << "`bgv.mul` expects operands of the same type"; - return failure(); - } - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - // z = mul([x0, x1], [y0, y1]) := [x0.y0, x0.y1 + x1.y0, x1.y1] - auto i0 = b.create(0); - auto i1 = b.create(1); - - auto x0 = - b.create(xT.getElementType(), x, ValueRange{i0}); - auto x1 = - b.create(xT.getElementType(), x, ValueRange{i1}); - - auto y0 = - b.create(yT.getElementType(), y, ValueRange{i0}); - auto y1 = - b.create(yT.getElementType(), y, ValueRange{i1}); - - auto z0 = b.create(x0, y0); - auto x0y1 = b.create(x0, y1); - auto x1y0 = b.create(x1, y0); - auto z1 = b.create(x0y1, x1y0); - auto z2 = b.create(x1, y1); - - auto z = b.create(ArrayRef({z0, z1, z2})); - - rewriter.replaceOp(op, z); - return success(); - } - }; + auto y0 = + b.create(yT.getElementType(), y, ValueRange{i0}); + auto y1 = + b.create(yT.getElementType(), y, ValueRange{i1}); + + auto z0 = b.create(x0, y0); + auto x0y1 = b.create(x0, y1); + auto x1y0 = b.create(x1, y0); + auto z1 = b.create(x0y1, x1y0); + auto z2 = b.create(x1, y1); + + auto z = b.create(ArrayRef({z0, z1, z2})); + + rewriter.replaceOp(op, z); + return success(); + } +}; - struct BGVToPolynomial : public impl::BGVToPolynomialBase { - void runOnOperation() override { - MLIRContext *context = &getContext(); - auto *module = getOperation(); - CiphertextTypeConverter typeConverter(context); +struct BGVToPolynomial : public impl::BGVToPolynomialBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + CiphertextTypeConverter typeConverter(context); - ConversionTarget target(*context); - target.addLegalOp(); + ConversionTarget target(*context); + target.addLegalOp(); - RewritePatternSet patterns(context); + RewritePatternSet patterns(context); - patterns.add( - typeConverter, context); - target.addIllegalOp(); + patterns.add( + typeConverter, context); + target.addIllegalOp(); - addStructuralConversionPatterns(typeConverter, patterns, target); + addStructuralConversionPatterns(typeConverter, patterns, target); - // Run full conversion, if any BGV ops were missed out the pass will fail. - if (failed(applyFullConversion(module, target, std::move(patterns)))) { - return signalPassFailure(); - } + // Run full conversion, if any BGV ops were missed out the pass will fail. + if (failed(applyFullConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); } - }; + } +}; } // namespace mlir::heir::bgv diff --git a/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp b/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp index 6b83bed6f..c2e20944b 100644 --- a/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp +++ b/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp @@ -4,11 +4,6 @@ #include "include/Dialect/CGGI/IR/CGGIDialect.h" #include "include/Dialect/CGGI/IR/CGGIOps.h" -<<<<<<< HEAD -#include "include/Dialect/LWE/IR/LWEAttributes.h" - == == == - = ->>>>>>> main #include "include/Dialect/LWE/IR/LWEDialect.h" #include "include/Dialect/LWE/IR/LWEOps.h" #include "include/Dialect/LWE/IR/LWETypes.h" @@ -16,284 +11,266 @@ #include "include/Dialect/TfheRustBool/IR/TfheRustBoolOps.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolTypes.h" #include "lib/Conversion/Utils.h" -#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project - <<<<<<>>>>>> - main -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project - <<<<<<>>>>>> - main -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project - namespace mlir::heir { +namespace mlir::heir { #define GEN_PASS_DEF_CGGITOTFHERUSTBOOL #include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h.inc" - class CGGIToTfheRustBoolTypeConverter : public TypeConverter { - public: - CGGIToTfheRustBoolTypeConverter(MLIRContext *ctx) { - addConversion([](Type type) { return type; }); - addConversion([ctx](lwe::LWECiphertextType type) -> Type { - return tfhe_rust_bool::EncryptedBoolType::get(ctx); - }); - addConversion([this](ShapedType type) -> Type { - return type.cloneWith(type.getShape(), - this->convertType(type.getElementType())); - }); - } - }; - - // /// Returns true if the func's body contains any CGGI ops. - bool containsCGGIOpsBool(func::FuncOp func) { - auto walkResult = func.walk([&](Operation *op) { - if (llvm::isa(op->getDialect())) - return WalkResult::interrupt(); - return WalkResult::advance(); +class CGGIToTfheRustBoolTypeConverter : public TypeConverter { + public: + CGGIToTfheRustBoolTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); + addConversion([ctx](lwe::LWECiphertextType type) -> Type { + return tfhe_rust_bool::EncryptedBoolType::get(ctx); + }); + addConversion([this](ShapedType type) -> Type { + return type.cloneWith(type.getShape(), + this->convertType(type.getElementType())); }); - return walkResult.wasInterrupted(); } - - /// Returns the Value corresponding to a server key in the FuncOp containing - /// this op. - FailureOr getContextualBoolServerKey(Operation * op) { - Value serverKey = op->getParentOfType() - .getBody() - .getBlocks() - .front() - .getArguments() - .front(); - if (!serverKey.getType().isa()) { - return op->emitOpError() - << "Found CGGI op in a function without a server " - "key argument. Did the AddBoolServerKeyArg pattern fail to " - "run?"; - } - return serverKey; +}; + +// /// Returns true if the func's body contains any CGGI ops. +bool containsCGGIOpsBool(func::FuncOp func) { + auto walkResult = func.walk([&](Operation *op) { + if (llvm::isa(op->getDialect())) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return walkResult.wasInterrupted(); +} + +/// Returns the Value corresponding to a server key in the FuncOp containing +/// this op. +FailureOr getContextualBoolServerKey(Operation *op) { + Value serverKey = op->getParentOfType() + .getBody() + .getBlocks() + .front() + .getArguments() + .front(); + if (!serverKey.getType().isa()) { + return op->emitOpError() + << "Found CGGI op in a function without a server " + "key argument. Did the AddBoolServerKeyArg pattern fail to run?"; } - - template - struct GenericOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector retTypes; - if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), - retTypes))) - return failure(); - rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), - op->getAttrs()); - - return success(); + return serverKey; +} + +template +struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +/// Convert a func by adding a server key argument. Converted ops in other +/// patterns need a server key SSA value available, so this pattern needs a +/// higher benefit. +struct AddBoolServerKeyArg : public OpConversionPattern { + AddBoolServerKeyArg(mlir::MLIRContext *context) + : OpConversionPattern(context, /* benefit= */ 2) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!containsCGGIOpsBool(op)) { + return failure(); } - }; - - /// Convert a func by adding a server key argument. Converted ops in other - /// patterns need a server key SSA value available, so this pattern needs a - /// higher benefit. - struct AddBoolServerKeyArg : public OpConversionPattern { - AddBoolServerKeyArg(mlir::MLIRContext *context) - : OpConversionPattern(context, /* benefit= */ 2) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - func::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!containsCGGIOpsBool(op)) { - return failure(); - } - - auto serverKeyType = tfhe_rust_bool::ServerKeyType::get(getContext()); - FunctionType originalType = op.getFunctionType(); - llvm::SmallVector newTypes; - newTypes.reserve(originalType.getNumInputs() + 1); - newTypes.push_back(serverKeyType); - for (auto t : originalType.getInputs()) { - newTypes.push_back(t); - } - auto newFuncType = - FunctionType::get(getContext(), newTypes, originalType.getResults()); - rewriter.modifyOpInPlace(op, [&] { - op.setType(newFuncType); - - // In addition to updating the type signature, we need to update the - // entry block's arguments to match the type signature - Block &block = op.getBody().getBlocks().front(); - block.insertArgument(&block.getArguments().front(), serverKeyType, - op.getLoc()); - }); - - return success(); + + auto serverKeyType = tfhe_rust_bool::ServerKeyType::get(getContext()); + FunctionType originalType = op.getFunctionType(); + llvm::SmallVector newTypes; + newTypes.reserve(originalType.getNumInputs() + 1); + newTypes.push_back(serverKeyType); + for (auto t : originalType.getInputs()) { + newTypes.push_back(t); } - }; + auto newFuncType = + FunctionType::get(getContext(), newTypes, originalType.getResults()); + rewriter.modifyOpInPlace(op, [&] { + op.setType(newFuncType); + + // In addition to updating the type signature, we need to update the + // entry block's arguments to match the type signature + Block &block = op.getBody().getBlocks().front(); + block.insertArgument(&block.getArguments().front(), serverKeyType, + op.getLoc()); + }); - template - struct ConvertBinOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + return success(); + } +}; - LogicalResult matchAndRewrite( - BinOp op, typename BinOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ImplicitLocOpBuilder b(op->getLoc(), rewriter); - FailureOr result = getContextualBoolServerKey(op); - if (failed(result)) return result; +template +struct ConvertBinOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - Value serverKey = result.value(); + LogicalResult matchAndRewrite( + BinOp op, typename BinOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + FailureOr result = getContextualBoolServerKey(op); + if (failed(result)) return result; - rewriter.replaceOp( - op, b.create(serverKey, adaptor.getLhs(), - adaptor.getRhs())); - return success(); - } - }; + Value serverKey = result.value(); + + rewriter.replaceOp(op, b.create( + serverKey, adaptor.getLhs(), adaptor.getRhs())); + return success(); + } +}; - using ConvertBoolAndOp = ConvertBinOp; - using ConvertBoolOrOp = ConvertBinOp; - using ConvertBoolXorOp = ConvertBinOp; +using ConvertBoolAndOp = ConvertBinOp; +using ConvertBoolOrOp = ConvertBinOp; +using ConvertBoolXorOp = ConvertBinOp; - struct ConvertBoolNotOp : public OpConversionPattern { - ConvertBoolNotOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} +struct ConvertBoolNotOp : public OpConversionPattern { + ConvertBoolNotOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - cggi::NotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ImplicitLocOpBuilder b(op->getLoc(), rewriter); - FailureOr result = getContextualBoolServerKey(op); - if (failed(result)) return result; + LogicalResult matchAndRewrite( + cggi::NotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + FailureOr result = getContextualBoolServerKey(op); + if (failed(result)) return result; - Value serverKey = result.value(); + Value serverKey = result.value(); - rewriter.replaceOp( - op, b.create(serverKey, adaptor.getInput())); - return success(); - } - }; - - struct ConvertBoolTrivialEncryptOp - : public OpConversionPattern { - ConvertBoolTrivialEncryptOp(mlir::MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/1) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - lwe::TrivialEncryptOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr result = getContextualBoolServerKey(op.getOperation()); - if (failed(result)) return result; - - Value serverKey = result.value(); - lwe::EncodeOp encodeOp = op.getInput().getDefiningOp(); - if (!encodeOp) { - return op.emitError() << "Expected input to TrivialEncrypt to be the " - "result of an EncodeOp, but it was " - << op.getInput().getDefiningOp()->getName(); - } - auto outputType = tfhe_rust_bool::EncryptedBoolType::get(getContext()); - - auto createTrivialOp = rewriter.create( - op.getLoc(), outputType, serverKey, encodeOp.getPlaintext()); - rewriter.replaceOp(op, createTrivialOp); - return success(); + rewriter.replaceOp( + op, b.create(serverKey, adaptor.getInput())); + return success(); + } +}; + +struct ConvertBoolTrivialEncryptOp + : public OpConversionPattern { + ConvertBoolTrivialEncryptOp(mlir::MLIRContext *context) + : OpConversionPattern(context, /*benefit=*/1) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::TrivialEncryptOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualBoolServerKey(op.getOperation()); + if (failed(result)) return result; + + Value serverKey = result.value(); + lwe::EncodeOp encodeOp = op.getInput().getDefiningOp(); + if (!encodeOp) { + return op.emitError() << "Expected input to TrivialEncrypt to be the " + "result of an EncodeOp, but it was " + << op.getInput().getDefiningOp()->getName(); } - }; + auto outputType = tfhe_rust_bool::EncryptedBoolType::get(getContext()); - struct ConvertBoolEncodeOp : public OpConversionPattern { - ConvertBoolEncodeOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} + auto createTrivialOp = rewriter.create( + op.getLoc(), outputType, serverKey, encodeOp.getPlaintext()); + rewriter.replaceOp(op, createTrivialOp); + return success(); + } +}; - using OpConversionPattern::OpConversionPattern; +struct ConvertBoolEncodeOp : public OpConversionPattern { + ConvertBoolEncodeOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} - LogicalResult matchAndRewrite( - lwe::EncodeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.eraseOp(op); - return success(); - } - }; - - class CGGIToTfheRustBool - : public impl::CGGIToTfheRustBoolBase { - void runOnOperation() override { - MLIRContext *context = &getContext(); - auto *op = getOperation(); - - CGGIToTfheRustBoolTypeConverter typeConverter(context); - RewritePatternSet patterns(context); - ConversionTarget target(*context); - addStructuralConversionPatterns(typeConverter, patterns, target); - - target.addLegalDialect(); - target.addIllegalDialect(); - target.addIllegalDialect(); - - // FuncOp is marked legal by the default structural conversion patterns - // helper, just based on type conversion. We need more, but because the - // addDynamicallyLegalOp is a set-based method, we can add this after - // calling addStructuralConversionPatterns and it will overwrite the - // legality condition set in that function. - target.addDynamicallyLegalOp([&](func::FuncOp op) { - bool hasServerKeyArg = op.getFunctionType().getNumInputs() > 0 && - op.getFunctionType() - .getInputs() - .begin() - ->isa(); - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()) && - (!containsCGGIOpsBool(op) || hasServerKeyArg); - }); - target.addDynamicallyLegalOp( - [&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); - - // FIXME: still need to update callers to insert the new server key arg, - // if needed and possible. - patterns.add< - AddBoolServerKeyArg, ConvertBoolAndOp, ConvertBoolEncodeOp, - ConvertBoolOrOp, ConvertBoolTrivialEncryptOp, ConvertBoolXorOp, - ConvertBoolNotOp, GenericOpPattern, - GenericOpPattern, - GenericOpPattern, GenericOpPattern, - GenericOpPattern, GenericOpPattern, - GenericOpPattern, - GenericOpPattern>(typeConverter, context); - - if (failed(applyPartialConversion(op, target, std::move(patterns)))) { - return signalPassFailure(); - } + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::EncodeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +class CGGIToTfheRustBool + : public impl::CGGIToTfheRustBoolBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *op = getOperation(); + + CGGIToTfheRustBoolTypeConverter typeConverter(context); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + addStructuralConversionPatterns(typeConverter, patterns, target); + + target.addLegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + + // FuncOp is marked legal by the default structural conversion patterns + // helper, just based on type conversion. We need more, but because the + // addDynamicallyLegalOp is a set-based method, we can add this after + // calling addStructuralConversionPatterns and it will overwrite the + // legality condition set in that function. + target.addDynamicallyLegalOp([&](func::FuncOp op) { + bool hasServerKeyArg = op.getFunctionType().getNumInputs() > 0 && + op.getFunctionType() + .getInputs() + .begin() + ->isa(); + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()) && + (!containsCGGIOpsBool(op) || hasServerKeyArg); + }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + // FIXME: still need to update callers to insert the new server key arg, if + // needed and possible. + patterns.add< + AddBoolServerKeyArg, ConvertBoolAndOp, ConvertBoolEncodeOp, + ConvertBoolOrOp, ConvertBoolTrivialEncryptOp, ConvertBoolXorOp, + ConvertBoolNotOp, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern >(typeConverter, context); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + return signalPassFailure(); } - }; + } +}; } // namespace mlir::heir diff --git a/lib/Dialect/BGV/IR/BGVDialect.cpp b/lib/Dialect/BGV/IR/BGVDialect.cpp index cda9432c3..183e5b45c 100644 --- a/lib/Dialect/BGV/IR/BGVDialect.cpp +++ b/lib/Dialect/BGV/IR/BGVDialect.cpp @@ -1,21 +1,5 @@ #include "include/Dialect/BGV/IR/BGVDialect.h" -<<<<<<< HEAD -#include "include/Dialect/BGV/IR/BGVAttributes.h" -#include "include/Dialect/BGV/IR/BGVOps.h" -#include "include/Dialect/BGV/IR/BGVTypes.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project - -// Generated definitions -#include "include/Dialect/BGV/IR/BGVDialect.cpp.inc" -#define GET_ATTRDEF_CLASSES -#include "include/Dialect/BGV/IR/BGVAttributes.cpp.inc" -#define GET_TYPEDEF_CLASSES -#include "include/Dialect/BGV/IR/BGVTypes.cpp.inc" - == == == - = #include #include "include/Dialect/BGV/IR/BGVOps.h" @@ -29,163 +13,107 @@ // Generated definitions #include "include/Dialect/BGV/IR/BGVDialect.cpp.inc" - >>>>>>> main #define GET_OP_CLASSES #include "include/Dialect/BGV/IR/BGVOps.cpp.inc" - namespace mlir { - namespace heir { - namespace bgv { +namespace mlir { +namespace heir { +namespace bgv { - //===----------------------------------------------------------------------===// - // BGV dialect. - //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// BGV dialect. +//===----------------------------------------------------------------------===// - // Dialect construction: there is one instance per context and it registers - // its operations, types, and interfaces here. - void BGVDialect::initialize() { -<<<<<<< HEAD - addAttributes< -#define GET_ATTRDEF_LIST -#include "include/Dialect/BGV/IR/BGVAttributes.cpp.inc" - >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "include/Dialect/BGV/IR/BGVTypes.cpp.inc" - >(); -======= ->>>>>>> main - addOperations< +// Dialect construction: there is one instance per context and it registers its +// operations, types, and interfaces here. +void BGVDialect::initialize() { + addOperations< #define GET_OP_LIST #include "include/Dialect/BGV/IR/BGVOps.cpp.inc" - >(); - } + >(); +} - LogicalResult MulOp::verify() { - auto x = getX().getType(); - auto y = getY().getType(); -<<<<<<< HEAD - if (x.getDim() != y.getDim()) { - return emitOpError() << "input dimensions do not match"; - } - auto out = getOutput().getType(); - if (out.getDim() != 1 + x.getDim()) { -======= - if (x.getRlweParams().getDimension() != y.getRlweParams().getDimension()) { - return emitOpError() << "input dimensions do not match"; - } - auto out = getOutput().getType(); - if (out.getRlweParams().getDimension() != - 1 + x.getRlweParams().getDimension()) { ->>>>>>> main - return emitOpError() << "output.dim == x.dim + 1 does not hold"; - } - return success(); +LogicalResult MulOp::verify() { + auto x = getX().getType(); + auto y = getY().getType(); + if (x.getRlweParams().getDimension() != y.getRlweParams().getDimension()) { + return emitOpError() << "input dimensions do not match"; } - LogicalResult Rotate::verify() { - auto x = getX().getType(); -<<<<<<< HEAD - if (x.getDim() != 2) { - return emitOpError() << "x.dim == 2 does not hold"; - } - auto out = getOutput().getType(); - if (out.getDim() != 2) { -======= - if (x.getRlweParams().getDimension() != 2) { - return emitOpError() << "x.dim == 2 does not hold"; - } - auto out = getOutput().getType(); - if (out.getRlweParams().getDimension() != 2) { ->>>>>>> main - return emitOpError() << "output.dim == 2 does not hold"; - } - return success(); + auto out = getOutput().getType(); + if (out.getRlweParams().getDimension() != + 1 + x.getRlweParams().getDimension()) { + return emitOpError() << "output.dim == x.dim + 1 does not hold"; } - - LogicalResult Relinearize::verify() { - auto x = getX().getType(); - auto out = getOutput().getType(); -<<<<<<< HEAD - if (x.getDim() != getFromBasis().size()) { - return emitOpError() << "input dimension does not match from_basis"; - } - if (out.getDim() != getToBasis().size()) { -======= - if (x.getRlweParams().getDimension() != getFromBasis().size()) { - return emitOpError() << "input dimension does not match from_basis"; - } - if (out.getRlweParams().getDimension() != getToBasis().size()) { ->>>>>>> main - return emitOpError() << "output dimension does not match to_basis"; - } - return success(); + return success(); +} +LogicalResult Rotate::verify() { + auto x = getX().getType(); + if (x.getRlweParams().getDimension() != 2) { + return emitOpError() << "x.dim == 2 does not hold"; } + auto out = getOutput().getType(); + if (out.getRlweParams().getDimension() != 2) { + return emitOpError() << "output.dim == 2 does not hold"; + } + return success(); +} - LogicalResult ModulusSwitch::verify() { - auto x = getX().getType(); -<<<<<<< HEAD - auto rings = x.getRings().getRings().size(); - auto to = getToLevel(); - auto from = getFromLevel(); - if (to < 0 || to >= from || from >= rings) { - return emitOpError() << "invalid levels, should be true: 0 <= " << to - << " < " << from << " < " << rings; - } - if (x.getLevel().has_value() && x.getLevel().value() != from) { - return emitOpError() << "input level does not match from_level"; - } - auto outLvl = getOutput().getType().getLevel(); - if (!outLvl.has_value() || outLvl.value() != to) { - return emitOpError() - << "output level should be specified and match to_level"; - } -======= - auto xRing = x.getRlweParams().getRing(); +LogicalResult Relinearize::verify() { + auto x = getX().getType(); + auto out = getOutput().getType(); + if (x.getRlweParams().getDimension() != getFromBasis().size()) { + return emitOpError() << "input dimension does not match from_basis"; + } + if (out.getRlweParams().getDimension() != getToBasis().size()) { + return emitOpError() << "output dimension does not match to_basis"; + } + return success(); +} - auto out = getOutput().getType(); - auto outRing = out.getRlweParams().getRing(); - if (outRing != getToRing()) { - return emitOpError() << "output ring should match to_ring"; - } - if (xRing.getCmod().getValue().ule(outRing.getCmod().getValue())) { - return emitOpError() << "output ring modulus should be less than the " - "input ring modulus"; - } - if (!xRing.getCmod() - .getValue() - .urem(outRing.getCmod().getValue()) - .isZero()) { - return emitOpError() - << "output ring modulus should divide the input ring modulus"; - } +LogicalResult ModulusSwitch::verify() { + auto x = getX().getType(); + auto xRing = x.getRlweParams().getRing(); - return success(); + auto out = getOutput().getType(); + auto outRing = out.getRlweParams().getRing(); + if (outRing != getToRing()) { + return emitOpError() << "output ring should match to_ring"; } - - LogicalResult MulOp::inferReturnTypes( - MLIRContext *ctx, std::optional, MulOp::Adaptor adaptor, - SmallVectorImpl &inferredReturnTypes) { - auto x = cast(adaptor.getX().getType()); - auto y = cast(adaptor.getY().getType()); - auto newDim = - x.getRlweParams().getDimension() + y.getRlweParams().getDimension() - 1; - inferredReturnTypes.push_back(lwe::RLWECiphertextType::get( - ctx, x.getEncoding(), - lwe::RLWEParamsAttr::get(ctx, newDim, x.getRlweParams().getRing()))); - return success(); + if (xRing.getCmod().getValue().ule(outRing.getCmod().getValue())) { + return emitOpError() + << "output ring modulus should be less than the input ring modulus"; } - - LogicalResult Relinearize::inferReturnTypes( - MLIRContext *ctx, std::optional, Relinearize::Adaptor adaptor, - SmallVectorImpl &inferredReturnTypes) { - auto x = cast(adaptor.getX().getType()); - inferredReturnTypes.push_back(lwe::RLWECiphertextType::get( - ctx, x.getEncoding(), - lwe::RLWEParamsAttr::get(ctx, 2, x.getRlweParams().getRing()))); ->>>>>>> main - return success(); + if (!xRing.getCmod().getValue().urem(outRing.getCmod().getValue()).isZero()) { + return emitOpError() + << "output ring modulus should divide the input ring modulus"; } - } // namespace bgv - } // namespace heir + return success(); +} + +LogicalResult MulOp::inferReturnTypes( + MLIRContext *ctx, std::optional, MulOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + auto x = cast(adaptor.getX().getType()); + auto y = cast(adaptor.getY().getType()); + auto newDim = + x.getRlweParams().getDimension() + y.getRlweParams().getDimension() - 1; + inferredReturnTypes.push_back(lwe::RLWECiphertextType::get( + ctx, x.getEncoding(), + lwe::RLWEParamsAttr::get(ctx, newDim, x.getRlweParams().getRing()))); + return success(); +} + +LogicalResult Relinearize::inferReturnTypes( + MLIRContext *ctx, std::optional, Relinearize::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + auto x = cast(adaptor.getX().getType()); + inferredReturnTypes.push_back(lwe::RLWECiphertextType::get( + ctx, x.getEncoding(), + lwe::RLWEParamsAttr::get(ctx, 2, x.getRlweParams().getRing()))); + return success(); +} + +} // namespace bgv +} // namespace heir } // namespace mlir diff --git a/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp b/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp index f852d7555..ba6f340c9 100644 --- a/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp +++ b/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp @@ -1,11 +1,5 @@ #include "include/Dialect/CGGI/Transforms/SetDefaultParameters.h" -<<<<<<< HEAD -#include "include/Dialect/CGGI/IR/CGGIAttributes.h" -#include "include/Dialect/CGGI/IR/CGGIOps.h" -#include "include/Dialect/LWE/IR/LWEAttributes.h" - == == == - = #include #include "include/Dialect/CGGI/IR/CGGIAttributes.h" @@ -13,86 +7,73 @@ #include "include/Dialect/LWE/IR/LWEAttributes.h" #include "include/Dialect/Polynomial/IR/Polynomial.h" #include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" - >>>>>>> main #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project - namespace mlir { - namespace heir { - namespace cggi { +namespace mlir { +namespace heir { +namespace cggi { #define GEN_PASS_DEF_SETDEFAULTPARAMETERS #include "include/Dialect/CGGI/Transforms/Passes.h.inc" - struct SetDefaultParameters - : impl::SetDefaultParametersBase { - using SetDefaultParametersBase::SetDefaultParametersBase; +struct SetDefaultParameters + : impl::SetDefaultParametersBase { + using SetDefaultParametersBase::SetDefaultParametersBase; - void runOnOperation() override { - auto *op = getOperation(); - MLIRContext &context = getContext(); - unsigned defaultRlweDimension = 1; -<<<<<<< HEAD - unsigned defaultPolyDegree = 1024; - APInt defaultCmod = APInt::getOneBitSet(64, 32); - IntegerAttr defaultCmodAttr = - IntegerAttr::get(IntegerType::get(&context, 64), defaultCmod); -======= - APInt defaultCmod = APInt::getOneBitSet(64, 32); - std::vector monomials; - monomials.push_back(polynomial::Monomial(1, 1024)); - monomials.push_back(polynomial::Monomial(1, 0)); - polynomial::Polynomial defaultPolyIdeal = - polynomial::Polynomial::fromMonomials(monomials, &context); ->>>>>>> main + void runOnOperation() override { + auto *op = getOperation(); + MLIRContext &context = getContext(); + unsigned defaultRlweDimension = 1; + APInt defaultCmod = APInt::getOneBitSet(64, 32); + std::vector monomials; + monomials.push_back(polynomial::Monomial(1, 1024)); + monomials.push_back(polynomial::Monomial(1, 0)); + polynomial::Polynomial defaultPolyIdeal = + polynomial::Polynomial::fromMonomials(monomials, &context); - // https://github.com/google/jaxite/blob/main/jaxite/jaxite_bool/bool_params.py - unsigned defaultBskNoiseVariance = 65536; // stdev = 2**8, var = 2**16 - unsigned defaultBskGadgetBaseLog = 4; - unsigned defaultBskGadgetNumLevels = 6; - unsigned defaultKskNoiseVariance = - 268435456; // stdev = 2**14, var = 2**28 - unsigned defaultKskGadgetBaseLog = 4; - unsigned defaultKskGadgetNumLevels = 5; + // https://github.com/google/jaxite/blob/main/jaxite/jaxite_bool/bool_params.py + unsigned defaultBskNoiseVariance = 65536; // stdev = 2**8, var = 2**16 + unsigned defaultBskGadgetBaseLog = 4; + unsigned defaultBskGadgetNumLevels = 6; + unsigned defaultKskNoiseVariance = 268435456; // stdev = 2**14, var = 2**28 + unsigned defaultKskGadgetBaseLog = 4; + unsigned defaultKskGadgetNumLevels = 5; - lwe::RLWEParamsAttr defaultRlweParams = lwe::RLWEParamsAttr::get( -<<<<<<< HEAD - &context, defaultCmodAttr, defaultRlweDimension, defaultPolyDegree); -======= - &context, defaultRlweDimension, - polynomial::RingAttr::get(defaultCmod, defaultPolyIdeal)); ->>>>>>> main - CGGIParamsAttr defaultParams = CGGIParamsAttr::get( - &context, defaultRlweParams, defaultBskNoiseVariance, - defaultBskGadgetBaseLog, defaultBskGadgetNumLevels, - defaultKskNoiseVariance, defaultKskGadgetBaseLog, - defaultKskGadgetNumLevels); + lwe::RLWEParamsAttr defaultRlweParams = lwe::RLWEParamsAttr::get( + &context, defaultRlweDimension, + polynomial::RingAttr::get(defaultCmod, defaultPolyIdeal)); + CGGIParamsAttr defaultParams = + CGGIParamsAttr::get(&context, defaultRlweParams, + defaultBskNoiseVariance, defaultBskGadgetBaseLog, + defaultBskGadgetNumLevels, defaultKskNoiseVariance, + defaultKskGadgetBaseLog, defaultKskGadgetNumLevels); - auto walkResult = op->walk([&](Operation *op) { - return llvm::TypeSwitch(*op) - .Case([&](auto op) { - op.getOperation()->setAttr("cggi_params", defaultParams); - return WalkResult::advance(); - }) - .Default([&](Operation &op) { - if (llvm::isa(op.getDialect())) { - op.emitOpError() << "Found an unsupported cggi op"; - return WalkResult::interrupt(); - } - // An unsupported op doesn't get any parameters set on it, and - // that's OK. - return WalkResult::advance(); - }); - }); + auto walkResult = op->walk([&](Operation *op) { + return llvm::TypeSwitch(*op) + .Case([&](auto op) { + op.getOperation()->setAttr("cggi_params", defaultParams); + return WalkResult::advance(); + }) + .Default([&](Operation &op) { + if (llvm::isa(op.getDialect())) { + op.emitOpError() << "Found an unsupported cggi op"; + return WalkResult::interrupt(); + } + // An unsupported op doesn't get any parameters set on it, and + // that's OK. + return WalkResult::advance(); + }); + }); - if (walkResult.wasInterrupted()) { - signalPassFailure(); - } + if (walkResult.wasInterrupted()) { + signalPassFailure(); } - }; + } +}; - } // namespace cggi - } // namespace heir +} // namespace cggi +} // namespace heir } // namespace mlir diff --git a/lib/Dialect/CGGI/Transforms/StraightLineVectorizer.cpp b/lib/Dialect/CGGI/Transforms/StraightLineVectorizer.cpp deleted file mode 100644 index 0dd2a57ad..000000000 --- a/lib/Dialect/CGGI/Transforms/StraightLineVectorizer.cpp +++ /dev/null @@ -1,179 +0,0 @@ -#include "include/Dialect/CGGI/Transforms/StraightLineVectorizer.h" - -#include "include/Dialect/CGGI/IR/CGGIAttributes.h" -#include "include/Dialect/CGGI/IR/CGGIOps.h" -#include "include/Dialect/LWE/IR/LWEAttributes.h" -#include "include/Graph/Graph.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project -#include "mlir/include/mlir/Analysis/SliceAnalysis.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/TopologicalSortUtils.h" // from @llvm-project - -#define DEBUG_TYPE "straight-line-vectorizer" - -namespace mlir { -namespace heir { -namespace cggi { - -#define GEN_PASS_DEF_STRAIGHTLINEVECTORIZER -#include "include/Dialect/CGGI/Transforms/Passes.h.inc" - -/// Returns true if the two operations can be combined into a single vectorized -/// operation. -bool areCompatible(Operation *lhs, Operation *rhs) { - if (lhs->getName() != rhs->getName() || - lhs->getDialect() != rhs->getDialect() || - lhs->getResultTypes() != rhs->getResultTypes()) { - return false; - } - - return llvm::TypeSwitch(lhs) - .Case([&](auto op) { return true; }) - .Case([&](auto op) { - return cast(rhs).getLookupTable() == op.getLookupTable(); - }) - .Case([&](auto op) { - return cast(rhs).getLookupTable() == op.getLookupTable(); - }) - .Default([&](Operation *) { - lhs->emitOpError("Unsupported check for vectorizable compatibility."); - return false; - }); -} - -bool tryVectorizeBlock(Block *block) { - graph::Graph graph; - for (auto &op : block->getOperations()) { - if (!isa(op.getDialect()) || - !op.hasTrait()) { - continue; - } - - graph.addVertex(&op); - SetVector backwardSlice; - BackwardSliceOptions options; - options.omitBlockArguments = true; - getBackwardSlice(&op, &backwardSlice, options); - for (auto *upstreamDep : backwardSlice) { - // An edge from upstreamDep to `op` means that upstreamDep must be - // computed before `op`. - graph.addEdge(upstreamDep, &op); - } - } - - if (graph.empty()) { - return false; - } - - auto result = graph.sortGraphByLevels(); - assert(succeeded(result) && - "Only possible failure is a cycle in the SSA graph!"); - auto levels = result.value(); - - LLVM_DEBUG({ - llvm::dbgs() - << "Found operations to vectorize. In topo-sorted level order:\n"; - int level_num = 0; - for (const auto &level : levels) { - llvm::dbgs() << "\nLevel " << level_num++ << ":\n"; - for (auto op : level) { - llvm::dbgs() << " - " << *op << "\n"; - } - } - }); - - bool madeReplacement = false; - for (const auto &level : levels) { - DenseMap> compatibleOps; - for (auto *op : level) { - bool foundCompatible = false; - for (auto &[key, bucket] : compatibleOps) { - if (areCompatible(key, op)) { - compatibleOps[key].push_back(op); - foundCompatible = true; - } - } - if (!foundCompatible) { - compatibleOps[op].push_back(op); - } - } - LLVM_DEBUG(llvm::dbgs() - << "Partitioned level of size " << level.size() << " into " - << compatibleOps.size() << " groups of compatible ops\n"); - - for (auto &[key, bucket] : compatibleOps) { - if (bucket.size() < 2) { - continue; - } - - LLVM_DEBUG({ - llvm::dbgs() << "Vectorizing ops:\n"; - for (auto op : bucket) { - llvm::dbgs() << " - " << *op << "\n"; - } - }); - - OpBuilder builder(bucket.back()); - // relies on CGGI ops having a single result type - Type elementType = key->getResultTypes()[0]; - RankedTensorType tensorType = RankedTensorType::get( - {static_cast(bucket.size())}, elementType); - - SmallVector vectorizedOperands; - for (int operandIndex = 0; operandIndex < key->getNumOperands(); - ++operandIndex) { - SmallVector operands; - operands.reserve(bucket.size()); - for (auto *op : bucket) { - operands.push_back(op->getOperand(operandIndex)); - } - auto fromElementsOp = builder.create( - key->getLoc(), tensorType, operands); - vectorizedOperands.push_back(fromElementsOp.getResult()); - } - - Operation *vectorizedOp = builder.clone(*key); - vectorizedOp->setOperands(vectorizedOperands); - vectorizedOp->getResult(0).setType(tensorType); - - int bucketIndex = 0; - for (auto *op : bucket) { - auto extractionIndex = builder.create( - op->getLoc(), builder.getIndexAttr(bucketIndex)); - auto extractOp = builder.create( - op->getLoc(), elementType, vectorizedOp->getResult(0), - extractionIndex.getResult()); - op->replaceAllUsesWith(ValueRange{extractOp.getResult()}); - bucketIndex++; - } - - for (auto *op : bucket) { - op->erase(); - } - madeReplacement = true; - } - } - - return madeReplacement; -} - -struct StraightLineVectorizer - : impl::StraightLineVectorizerBase { - using StraightLineVectorizerBase::StraightLineVectorizerBase; - - void runOnOperation() override { - getOperation()->walk([&](Block *block) { - if (tryVectorizeBlock(block)) { - sortTopologically(block); - } - }); - } -}; - -} // namespace cggi -} // namespace heir -} // namespace mlir diff --git a/lib/Dialect/Secret/IR/SecretPatterns.cpp b/lib/Dialect/Secret/IR/SecretPatterns.cpp index 8ca2b0268..00a262b5f 100644 --- a/lib/Dialect/Secret/IR/SecretPatterns.cpp +++ b/lib/Dialect/Secret/IR/SecretPatterns.cpp @@ -567,18 +567,6 @@ LogicalResult HoistPlaintextOps::matchAndRewrite( return true; }; -<<<<<<< HEAD - auto it = std::find_if(opRange.begin(), opRange.end(), - [&](Operation &op) { return canHoist(op); }); - if (it == opRange.end()) { - return failure(); - } - - Operation *opToHoist = &*it; - LLVM_DEBUG(llvm::dbgs() << "Hoisting " << *opToHoist << "\n"); - genericOp.extractOpBeforeGeneric(opToHoist, rewriter); - return success(); -======= LLVM_DEBUG( llvm::dbgs() << "Scanning generic body looking for ops to hoist...\n"); @@ -603,7 +591,6 @@ LogicalResult HoistPlaintextOps::matchAndRewrite( LLVM_DEBUG(llvm::dbgs() << "Done hoisting\n"); return hoistedAny ? success() : failure(); ->>>>>>> main } void genericAbsorbConstants(secret::GenericOp genericOp, @@ -623,11 +610,7 @@ void genericAbsorbConstants(secret::GenericOp genericOp, // inside the region. Region *operandRegion = definingOp->getParentRegion(); if (operandRegion && !genericOp.getRegion().isAncestor(operandRegion)) { -<<<<<<< HEAD - auto copiedOp = rewriter.clone(*definingOp); -======= auto *copiedOp = rewriter.clone(*definingOp); ->>>>>>> main rewriter.replaceAllUsesWith(operand, copiedOp->getResults()); // If this was a block argument, additionally remove the block // argument. diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index 9c171eacb..b89e5d481 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -7,10 +7,7 @@ #include "include/Dialect/LWE/IR/LWEDialect.h" #include "include/Dialect/Openfhe/IR/OpenfheDialect.h" #include "include/Dialect/Openfhe/IR/OpenfheOps.h" -<<<<<<< HEAD -======= #include "include/Dialect/Polynomial/IR/PolynomialDialect.h" - >>>>>>> main #include "include/Target/OpenFhePke/OpenFheUtils.h" #include "lib/Target/OpenFhePke/OpenFhePkeTemplates.h" #include "lib/Target/Utils.h" @@ -31,266 +28,258 @@ #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project - namespace mlir { - namespace heir { - namespace openfhe { - - void registerToOpenFhePkeTranslation() { - TranslateFromMLIRRegistration reg( - "emit-openfhe-pke", - "translate the openfhe dialect to C++ code against the OpenFHE pke API", - [](Operation *op, llvm::raw_ostream &output) { - return translateToOpenFhePke(op, output); - }, - [](DialectRegistry ®istry) { -<<<<<<< HEAD - registry.insert(); -======= - registry.insert(); ->>>>>>> main - }); +namespace mlir { +namespace heir { +namespace openfhe { + +void registerToOpenFhePkeTranslation() { + TranslateFromMLIRRegistration reg( + "emit-openfhe-pke", + "translate the openfhe dialect to C++ code against the OpenFHE pke API", + [](Operation *op, llvm::raw_ostream &output) { + return translateToOpenFhePke(op, output); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); +} + +LogicalResult translateToOpenFhePke(Operation *op, llvm::raw_ostream &os) { + SelectVariableNames variableNames(op); + OpenFhePkeEmitter emitter(os, &variableNames); + LogicalResult result = emitter.translate(*op); + return result; +} + +LogicalResult OpenFhePkeEmitter::translate(Operation &op) { + LogicalResult status = + llvm::TypeSwitch(op) + // Builtin ops + .Case([&](auto op) { return printOperation(op); }) + // Func ops + .Case( + [&](auto op) { return printOperation(op); }) + // Arith ops + .Case([&](auto op) { return printOperation(op); }) + // OpenFHE ops + .Case([&](auto op) { return printOperation(op); }) + .Default([&](Operation &) { + return op.emitOpError("unable to find printer for op"); + }); + + if (failed(status)) { + op.emitOpError(llvm::formatv("Failed to translate op {0}", op.getName())); + return failure(); } + return success(); +} - LogicalResult translateToOpenFhePke(Operation *op, llvm::raw_ostream &os) { - SelectVariableNames variableNames(op); - OpenFhePkeEmitter emitter(os, &variableNames); - LogicalResult result = emitter.translate(*op); - return result; - } - - LogicalResult OpenFhePkeEmitter::translate(Operation &op) { - LogicalResult status = - llvm::TypeSwitch(op) - // Builtin ops - .Case([&](auto op) { return printOperation(op); }) - // Func ops - .Case( - [&](auto op) { return printOperation(op); }) - // Arith ops - .Case( - [&](auto op) { return printOperation(op); }) - // OpenFHE ops - .Case( - [&](auto op) { return printOperation(op); }) - .Default([&](Operation &) { - return op.emitOpError("unable to find printer for op"); - }); - - if (failed(status)) { - op.emitOpError(llvm::formatv("Failed to translate op {0}", op.getName())); +LogicalResult OpenFhePkeEmitter::printOperation(ModuleOp moduleOp) { + os << kModulePrelude << "\n"; + for (Operation &op : moduleOp) { + if (failed(translate(op))) { return failure(); } - return success(); } - LogicalResult OpenFhePkeEmitter::printOperation(ModuleOp moduleOp) { - os << kModulePrelude << "\n"; - for (Operation &op : moduleOp) { - if (failed(translate(op))) { - return failure(); - } - } + return success(); +} - return success(); +LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) { + if (funcOp.getNumResults() != 1) { + return funcOp.emitOpError() << "Only functions with a single return type " + "are supported, but this function has " + << funcOp.getNumResults(); + return failure(); } - LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) { - if (funcOp.getNumResults() != 1) { - return funcOp.emitOpError() << "Only functions with a single return type " - "are supported, but this function has " - << funcOp.getNumResults(); - return failure(); - } - - Type result = funcOp.getResultTypes()[0]; - if (failed(emitType(result))) { - return funcOp.emitOpError() << "Failed to emit type " << result; - } - - os << " " << funcOp.getName() << "("; - os.indent(); - - // Check the types without printing to enable failure outside of - // commaSeparatedValues; maybe consider making commaSeparatedValues combine - // the results into a FailureOr, like commaSeparatedTypes in tfhe_rust - // emitter. - for (Value arg : funcOp.getArguments()) { - if (failed(convertType(arg.getType()))) { - return funcOp.emitOpError() << "Failed to emit type " << arg.getType(); - } - } - - os << commaSeparatedValues(funcOp.getArguments(), [&](Value value) { - auto res = convertType(value.getType()); - return res.value() + " " + variableNames->getNameForValue(value); - }); - os.unindent(); - os << ") {\n"; - os.indent(); - - for (Block &block : funcOp.getBlocks()) { - for (Operation &op : block.getOperations()) { - if (failed(translate(op))) { - return failure(); - } - } - } - - os.unindent(); - os << "}\n"; - return success(); + Type result = funcOp.getResultTypes()[0]; + if (failed(emitType(result))) { + return funcOp.emitOpError() << "Failed to emit type " << result; } - LogicalResult OpenFhePkeEmitter::printOperation(func::ReturnOp op) { - if (op.getNumOperands() != 1) { - op.emitError() << "Only one return value supported"; - return failure(); - } - os << "return " << variableNames->getNameForValue(op.getOperands()[0]) - << ";\n"; - return success(); - } - - void OpenFhePkeEmitter::emitAssignPrefix(Value result) { - os << "auto " << variableNames->getNameForValue(result) << " = "; - } - - LogicalResult OpenFhePkeEmitter::printEvalMethod( - ::mlir::Value result, ::mlir::Value cryptoContext, - ::mlir::ValueRange nonEvalOperands, std::string_view op) { - emitAssignPrefix(result); - - os << variableNames->getNameForValue(cryptoContext) << "->" << op << "("; - os << commaSeparatedValues(nonEvalOperands, [&](Value value) { - return variableNames->getNameForValue(value); - }); - os << ");\n"; - return success(); - } - - LogicalResult OpenFhePkeEmitter::printOperation(AddOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getLhs(), op.getRhs()}, "EvalAdd"); - } + os << " " << funcOp.getName() << "("; + os.indent(); - LogicalResult OpenFhePkeEmitter::printOperation(SubOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getLhs(), op.getRhs()}, "EvalSub"); - } - - LogicalResult OpenFhePkeEmitter::printOperation(MulOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getLhs(), op.getRhs()}, "EvalMult"); - } - - LogicalResult OpenFhePkeEmitter::printOperation(MulPlainOp op) { - // OpenFHE defines an overload for EvalMult to work on both plaintext and - // ciphertext inputs. - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext(), op.getPlaintext()}, "EvalMult"); - } - - LogicalResult OpenFhePkeEmitter::printOperation(MulConstOp op) { - // OpenFHE defines an overload for EvalMult to work on constant inputs, - // but only for some schemes. - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext(), op.getConstant()}, "EvalMult"); - } - - LogicalResult OpenFhePkeEmitter::printOperation(NegateOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext()}, "EvalNegate"); - } - - LogicalResult OpenFhePkeEmitter::printOperation(SquareOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext()}, "EvalSquare"); - } - - LogicalResult OpenFhePkeEmitter::printOperation(RelinOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext()}, "Relinearize"); - } - - LogicalResult OpenFhePkeEmitter::printOperation(ModReduceOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext()}, "ModReduce"); + // Check the types without printing to enable failure outside of + // commaSeparatedValues; maybe consider making commaSeparatedValues combine + // the results into a FailureOr, like commaSeparatedTypes in tfhe_rust + // emitter. + for (Value arg : funcOp.getArguments()) { + if (failed(convertType(arg.getType()))) { + return funcOp.emitOpError() << "Failed to emit type " << arg.getType(); + } } - LogicalResult OpenFhePkeEmitter::printOperation(LevelReduceOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext()}, "LevelReduce"); - } + os << commaSeparatedValues(funcOp.getArguments(), [&](Value value) { + auto res = convertType(value.getType()); + return res.value() + " " + variableNames->getNameForValue(value); + }); + os.unindent(); + os << ") {\n"; + os.indent(); - LogicalResult OpenFhePkeEmitter::printOperation(RotOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext(), op.getIndex()}, "EvalRotate"); + for (Block &block : funcOp.getBlocks()) { + for (Operation &op : block.getOperations()) { + if (failed(translate(op))) { + return failure(); + } + } } - LogicalResult OpenFhePkeEmitter::printOperation(AutomorphOp op) { - // EvalAutomorphism has a bit of a strange function signature in OpenFHE: - // - // EvalAutomorphism( - // ConstCiphertext ciphertext, - // int32_t i, - // const std::map>& evalKeyMap - // ) - // - // Here i is an index to evalKeyMap, but no other data from evalKeyMap is - // used. To match the API, we emit code that just creates a single-entry map - // locally before calling EvalAutomorphism. - // - // This would probably be an easy upstream fix to add a specialized function - // call if it becomes necessary. - std::string mapName = - variableNames->getNameForValue(op.getResult()) + "evalkeymap"; - auto result = convertType(op.getEvalKey().getType()); - os << "std::map " << mapName << " = {{0, " - << variableNames->getNameForValue(op.getEvalKey()) << "}};\n"; - - emitAssignPrefix(op.getResult()); - os << variableNames->getNameForValue(op.getCryptoContext()) - << "->EvalAutomorphism("; - os << variableNames->getNameForValue(op.getCiphertext()) << ", 0, " - << mapName << ");\n"; - return success(); - } + os.unindent(); + os << "}\n"; + return success(); +} - LogicalResult OpenFhePkeEmitter::printOperation(KeySwitchOp op) { - return printEvalMethod(op.getResult(), op.getCryptoContext(), - {op.getCiphertext(), op.getEvalKey()}, "KeySwitch"); +LogicalResult OpenFhePkeEmitter::printOperation(func::ReturnOp op) { + if (op.getNumOperands() != 1) { + op.emitError() << "Only one return value supported"; + return failure(); } - - LogicalResult OpenFhePkeEmitter::printOperation(arith::ConstantOp op) { - auto valueAttr = op.getValue(); - emitAssignPrefix(op.getResult()); - if (auto intAttr = dyn_cast(valueAttr)) { - os << intAttr.getValue() << ";\n"; - } else { - return op.emitError() - << "Unsupported constant type " << valueAttr.getType(); - } - return success(); + os << "return " << variableNames->getNameForValue(op.getOperands()[0]) + << ";\n"; + return success(); +} + +void OpenFhePkeEmitter::emitAssignPrefix(Value result) { + os << "auto " << variableNames->getNameForValue(result) << " = "; +} + +LogicalResult OpenFhePkeEmitter::printEvalMethod( + ::mlir::Value result, ::mlir::Value cryptoContext, + ::mlir::ValueRange nonEvalOperands, std::string_view op) { + emitAssignPrefix(result); + + os << variableNames->getNameForValue(cryptoContext) << "->" << op << "("; + os << commaSeparatedValues(nonEvalOperands, [&](Value value) { + return variableNames->getNameForValue(value); + }); + os << ");\n"; + return success(); +} + +LogicalResult OpenFhePkeEmitter::printOperation(AddOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getLhs(), op.getRhs()}, "EvalAdd"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(SubOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getLhs(), op.getRhs()}, "EvalSub"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(MulOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getLhs(), op.getRhs()}, "EvalMult"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(MulPlainOp op) { + // OpenFHE defines an overload for EvalMult to work on both plaintext and + // ciphertext inputs. + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext(), op.getPlaintext()}, "EvalMult"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(MulConstOp op) { + // OpenFHE defines an overload for EvalMult to work on constant inputs, + // but only for some schemes. + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext(), op.getConstant()}, "EvalMult"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(NegateOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext()}, "EvalNegate"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(SquareOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext()}, "EvalSquare"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(RelinOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext()}, "Relinearize"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(ModReduceOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext()}, "ModReduce"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(LevelReduceOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext()}, "LevelReduce"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(RotOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext(), op.getIndex()}, "EvalRotate"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(AutomorphOp op) { + // EvalAutomorphism has a bit of a strange function signature in OpenFHE: + // + // EvalAutomorphism( + // ConstCiphertext ciphertext, + // int32_t i, + // const std::map>& evalKeyMap + // ) + // + // Here i is an index to evalKeyMap, but no other data from evalKeyMap is + // used. To match the API, we emit code that just creates a single-entry map + // locally before calling EvalAutomorphism. + // + // This would probably be an easy upstream fix to add a specialized function + // call if it becomes necessary. + std::string mapName = + variableNames->getNameForValue(op.getResult()) + "evalkeymap"; + auto result = convertType(op.getEvalKey().getType()); + os << "std::map " << mapName << " = {{0, " + << variableNames->getNameForValue(op.getEvalKey()) << "}};\n"; + + emitAssignPrefix(op.getResult()); + os << variableNames->getNameForValue(op.getCryptoContext()) + << "->EvalAutomorphism("; + os << variableNames->getNameForValue(op.getCiphertext()) << ", 0, " << mapName + << ");\n"; + return success(); +} + +LogicalResult OpenFhePkeEmitter::printOperation(KeySwitchOp op) { + return printEvalMethod(op.getResult(), op.getCryptoContext(), + {op.getCiphertext(), op.getEvalKey()}, "KeySwitch"); +} + +LogicalResult OpenFhePkeEmitter::printOperation(arith::ConstantOp op) { + auto valueAttr = op.getValue(); + emitAssignPrefix(op.getResult()); + if (auto intAttr = dyn_cast(valueAttr)) { + os << intAttr.getValue() << ";\n"; + } else { + return op.emitError() << "Unsupported constant type " + << valueAttr.getType(); } + return success(); +} - LogicalResult OpenFhePkeEmitter::emitType(Type type) { - auto result = convertType(type); - if (failed(result)) { - return failure(); - } - os << result; - return success(); +LogicalResult OpenFhePkeEmitter::emitType(Type type) { + auto result = convertType(type); + if (failed(result)) { + return failure(); } - - OpenFhePkeEmitter::OpenFhePkeEmitter(raw_ostream &os, - SelectVariableNames *variableNames) - : os(os), variableNames(variableNames) {} - } // namespace openfhe - } // namespace heir + os << result; + return success(); +} + +OpenFhePkeEmitter::OpenFhePkeEmitter(raw_ostream &os, + SelectVariableNames *variableNames) + : os(os), variableNames(variableNames) {} +} // namespace openfhe +} // namespace heir } // namespace mlir diff --git a/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp index 38a4401ca..0d2195eb0 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp @@ -3,10 +3,7 @@ #include "include/Analysis/SelectVariableNames/SelectVariableNames.h" #include "include/Dialect/LWE/IR/LWEDialect.h" #include "include/Dialect/Openfhe/IR/OpenfheDialect.h" -<<<<<<< HEAD -======= #include "include/Dialect/Polynomial/IR/PolynomialDialect.h" - >>>>>>> main #include "include/Target/OpenFhePke/OpenFheUtils.h" #include "lib/Target/OpenFhePke/OpenFhePkeTemplates.h" #include "lib/Target/Utils.h" @@ -25,108 +22,104 @@ #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project - namespace mlir { - namespace heir { - namespace openfhe { +namespace mlir { +namespace heir { +namespace openfhe { - void registerToOpenFhePkeHeaderTranslation() { - TranslateFromMLIRRegistration reg( - "emit-openfhe-pke-header", - "Emit a header corresponding to the C++ file generated by " - "--emit-openfhe-pke", - [](Operation *op, llvm::raw_ostream &output) { - return translateToOpenFhePkeHeader(op, output); - }, - [](DialectRegistry ®istry) { - registry.insert(); -======= - lwe::LWEDialect, polynomial::PolynomialDialect>(); ->>>>>>> main - }); - } +void registerToOpenFhePkeHeaderTranslation() { + TranslateFromMLIRRegistration reg( + "emit-openfhe-pke-header", + "Emit a header corresponding to the C++ file generated by " + "--emit-openfhe-pke", + [](Operation *op, llvm::raw_ostream &output) { + return translateToOpenFhePkeHeader(op, output); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); +} - LogicalResult translateToOpenFhePkeHeader(Operation *op, - llvm::raw_ostream &os) { - SelectVariableNames variableNames(op); - OpenFhePkeHeaderEmitter emitter(os, &variableNames); - return emitter.translate(*op); - } +LogicalResult translateToOpenFhePkeHeader(Operation *op, + llvm::raw_ostream &os) { + SelectVariableNames variableNames(op); + OpenFhePkeHeaderEmitter emitter(os, &variableNames); + return emitter.translate(*op); +} - LogicalResult OpenFhePkeHeaderEmitter::translate(Operation &op) { - LogicalResult status = - llvm::TypeSwitch(op) - .Case([&](auto op) { return printOperation(op); }) - .Case([&](auto op) { return printOperation(op); }) - .Default([&](Operation &) { - return op.emitOpError("unable to find printer for op"); - }); +LogicalResult OpenFhePkeHeaderEmitter::translate(Operation &op) { + LogicalResult status = + llvm::TypeSwitch(op) + .Case([&](auto op) { return printOperation(op); }) + .Case([&](auto op) { return printOperation(op); }) + .Default([&](Operation &) { + return op.emitOpError("unable to find printer for op"); + }); - if (failed(status)) { - op.emitOpError(llvm::formatv("Failed to translate op {0}", op.getName())); - return failure(); - } - return success(); + if (failed(status)) { + op.emitOpError(llvm::formatv("Failed to translate op {0}", op.getName())); + return failure(); } + return success(); +} - LogicalResult OpenFhePkeHeaderEmitter::printOperation(ModuleOp moduleOp) { - os << kModulePrelude << "\n"; - for (Operation &op : moduleOp) { - if (failed(translate(op))) { - return failure(); - } +LogicalResult OpenFhePkeHeaderEmitter::printOperation(ModuleOp moduleOp) { + os << kModulePrelude << "\n"; + for (Operation &op : moduleOp) { + if (failed(translate(op))) { + return failure(); } - return success(); } + return success(); +} - LogicalResult OpenFhePkeHeaderEmitter::printOperation(func::FuncOp funcOp) { - // If keeping this consistent alongside OpenFheEmitter gets annoying, - // extract to a shared function in a base class. - if (funcOp.getNumResults() != 1) { - return funcOp.emitOpError() << "Only functions with a single return type " - "are supported, but this function has " - << funcOp.getNumResults(); - return failure(); - } +LogicalResult OpenFhePkeHeaderEmitter::printOperation(func::FuncOp funcOp) { + // If keeping this consistent alongside OpenFheEmitter gets annoying, + // extract to a shared function in a base class. + if (funcOp.getNumResults() != 1) { + return funcOp.emitOpError() << "Only functions with a single return type " + "are supported, but this function has " + << funcOp.getNumResults(); + return failure(); + } - Type result = funcOp.getResultTypes()[0]; - if (failed(emitType(result))) { - return funcOp.emitOpError() << "Failed to emit type " << result; - } + Type result = funcOp.getResultTypes()[0]; + if (failed(emitType(result))) { + return funcOp.emitOpError() << "Failed to emit type " << result; + } - os << " " << funcOp.getName() << "("; - os.indent(); + os << " " << funcOp.getName() << "("; + os.indent(); - for (Value arg : funcOp.getArguments()) { - if (failed(convertType(arg.getType()))) { - return funcOp.emitOpError() << "Failed to emit type " << arg.getType(); - } + for (Value arg : funcOp.getArguments()) { + if (failed(convertType(arg.getType()))) { + return funcOp.emitOpError() << "Failed to emit type " << arg.getType(); } + } - os << commaSeparatedValues(funcOp.getArguments(), [&](Value value) { - auto res = convertType(value.getType()); - return res.value() + " " + variableNames->getNameForValue(value); - }); - os.unindent(); - os << ");\n"; - os.indent(); + os << commaSeparatedValues(funcOp.getArguments(), [&](Value value) { + auto res = convertType(value.getType()); + return res.value() + " " + variableNames->getNameForValue(value); + }); + os.unindent(); + os << ");\n"; + os.indent(); - return success(); - } + return success(); +} - LogicalResult OpenFhePkeHeaderEmitter::emitType(Type type) { - auto result = convertType(type); - if (failed(result)) { - return failure(); - } - os << result; - return success(); +LogicalResult OpenFhePkeHeaderEmitter::emitType(Type type) { + auto result = convertType(type); + if (failed(result)) { + return failure(); } + os << result; + return success(); +} - OpenFhePkeHeaderEmitter::OpenFhePkeHeaderEmitter( - raw_ostream &os, SelectVariableNames *variableNames) - : os(os), variableNames(variableNames) {} - } // namespace openfhe - } // namespace heir +OpenFhePkeHeaderEmitter::OpenFhePkeHeaderEmitter( + raw_ostream &os, SelectVariableNames *variableNames) + : os(os), variableNames(variableNames) {} +} // namespace openfhe +} // namespace heir } // namespace mlir diff --git a/lib/Target/TfheRust/TfheRustEmitter.cpp b/lib/Target/TfheRust/TfheRustEmitter.cpp index 3715e8394..206b764a4 100644 --- a/lib/Target/TfheRust/TfheRustEmitter.cpp +++ b/lib/Target/TfheRust/TfheRustEmitter.cpp @@ -211,35 +211,13 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) { } } os.unindent(); -<<<<<<< HEAD - os << ") -> "; -======= os << ")"; ->>>>>>> main if (serverKeyArg_.empty()) { return funcOp.emitWarning() << "expected server key function argument to " "create default ciphertexts"; } -<<<<<<< HEAD - if (funcOp.getNumResults() == 1) { - Type result = funcOp.getResultTypes()[0]; - if (failed(emitType(result))) { - return funcOp.emitOpError() << "Failed to emit tfhe-rs type " << result; - } - } else { - auto result = commaSeparatedTypes( - funcOp.getResultTypes(), [&](Type type) -> FailureOr { - auto result = convertType(type); - if (failed(result)) { - return funcOp.emitOpError() - << "Failed to emit tfhe-rs type " << type; - } - return result; - }); - os << "(" << result.value() << ")"; -======= if (funcOp.getNumResults() > 0) { os << " -> "; if (funcOp.getNumResults() == 1) { @@ -259,25 +237,18 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) { }); os << "(" << result.value() << ")"; } ->>>>>>> main } os << " {\n"; os.indent(); // Create a global temp_nodes hashmap for any created SSA values. -<<<<<<< HEAD - // TODO(#462): Insert block argument that are encrypted ints into temp_nodes. - os << "let mut temp_nodes : HashMap = HashMap::new();\n"; - os << "let mut luts : HashMap<&str, LookupTableOwned> = HashMap::new();\n"; -======= // TODO(#462): Insert block argument that are encrypted ints into // temp_nodes. os << "let mut temp_nodes : HashMap = " "HashMap::new();\n"; os << "let mut luts : HashMap<&str, LookupTableOwned> = " "HashMap::new();\n"; ->>>>>>> main os << kRunLevelDefn << "\n"; @@ -320,13 +291,10 @@ LogicalResult TfheRustEmitter::printOperation(func::ReturnOp op) { return variableNames->getNameForValue(value); }; -<<<<<<< HEAD -======= if (op.getNumOperands() == 0) { return success(); } ->>>>>>> main if (op.getNumOperands() == 1) { os << valueOrClonedValue(op.getOperands()[0]) << "\n"; return success(); @@ -457,15 +425,9 @@ LogicalResult TfheRustEmitter::printOperation(affine::AffineForOp forOp) { variableNames->getIntForValue(op->getResult(0)), commaSeparatedValues( getCiphertextOperands(op->getOperands()), [&](Value value) { -<<<<<<< HEAD - // TODO(#462): This assumes that all ciphertexts are loaded - // into temp_nodes. Currently, block arguments are not - // supported. -======= // TODO(#462): This assumes that all ciphertexts are // loaded into temp_nodes. Currently, block arguments are // not supported. ->>>>>>> main return "Tv(" + std::to_string(variableNames->getIntForValue(value)) + ")"; @@ -539,27 +501,6 @@ LogicalResult TfheRustEmitter::printOperation(CreateTrivialOp op) { } LogicalResult TfheRustEmitter::printOperation(arith::ConstantOp op) { -<<<<<<< HEAD - // auto valueAttr = op.getValue(); - // if (isa(op.getType()) && - // op.getType().getIntOrFloatBitWidth() == 1) { - // os << "let " << variableNames->getNameForValue(op.getResult()) - // << " : bool = "; - // os << (cast(valueAttr).getValue().isZero() ? "false" : - // "true") - // << ";\n"; - // return success(); - // } - - os << "let cost: bool = true;"; - // emitAssignPrefix(op.getResult()); - // if (auto intAttr = dyn_cast(valueAttr)) { - // os << intAttr.getValue() << ";\n"; - // } else { - // return op.emitError() << "Unknown constant type " << - // valueAttr.getType(); - // } -======= auto valueAttr = op.getValue(); if (isa(op.getType()) && op.getType().getIntOrFloatBitWidth() == 1) { @@ -576,7 +517,6 @@ LogicalResult TfheRustEmitter::printOperation(arith::ConstantOp op) { } else { return op.emitError() << "Unknown constant type " << valueAttr.getType(); } ->>>>>>> main return success(); } @@ -622,12 +562,8 @@ LogicalResult TfheRustEmitter::printOperation(::mlir::arith::TruncIOp op) { } LogicalResult TfheRustEmitter::printOperation(tensor::ExtractOp op) { -<<<<<<< HEAD - // We assume here that the indices are SSA values (not integer attributes). -======= // We assume here that the indices are SSA values (not integer // attributes). ->>>>>>> main emitAssignPrefix(op.getResult()); os << "&" << variableNames->getNameForValue(op.getTensor()) << "[" << commaSeparatedValues( @@ -768,8 +704,6 @@ FailureOr TfheRustEmitter::convertType(Type type) { // they will need to chance to the right values once we try to compile it // against a specific API version. return llvm::TypeSwitch>(type) -<<<<<<< HEAD -======= .Case( [&](RankedTensorType type) -> FailureOr { // Tensor types are emitted as vectors @@ -797,7 +731,6 @@ FailureOr TfheRustEmitter::convertType(Type type) { std::to_string(width.value()); }) // TODO(#474): Generalize to any encrypted uint. ->>>>>>> main .Case( [&](auto type) { return std::string("Ciphertext"); }) .Case([&](auto type) { return std::string("ServerKey"); }) diff --git a/lib/Target/TfheRust/TfheRustTemplates.h b/lib/Target/TfheRust/TfheRustTemplates.h index 22e97862f..3d6e5443e 100644 --- a/lib/Target/TfheRust/TfheRustTemplates.h +++ b/lib/Target/TfheRust/TfheRustTemplates.h @@ -70,10 +70,6 @@ let mut run_level = | }; )rust"; -constexpr std::string_view kBoolModulePrelude = R"rust( -use tfhe::boolean::prelude::*; -)rust"; - } // namespace tfhe_rust } // namespace heir } // namespace mlir diff --git a/templates/Conversion/lib/ConversionPass.cpp.jinja b/templates/Conversion/lib/ConversionPass.cpp.jinja index b1bef33ca..3ebf0f162 100644 --- a/templates/Conversion/lib/ConversionPass.cpp.jinja +++ b/templates/Conversion/lib/ConversionPass.cpp.jinja @@ -11,15 +11,9 @@ namespace mlir::heir { #include "include/Conversion/{{ pass_name }}/{{ pass_name }}.h.inc" // Remove this class if no type conversions are necessary -<<<<<<< HEAD -class PassTypeConverter : public TypeConverter { - public: - PassTypeConverter(MLIRContext *ctx) { -======= class {{ pass_name }}TypeConverter : public TypeConverter { public: {{ pass_name }}TypeConverter(MLIRContext *ctx) { ->>>>>>> main addConversion([](Type type) { return type; }); // FIXME: implement, replace FooType with the type that needs // to be converted or remove this class @@ -48,11 +42,7 @@ struct {{ pass_name }} : public impl::{{ pass_name }}Base<{{ pass_name }}> { void runOnOperation() override { MLIRContext *context = &getContext(); auto *module = getOperation(); -<<<<<<< HEAD - PassTypeConverter typeConverter(context); -======= {{ pass_name }}TypeConverter typeConverter(context); ->>>>>>> main RewritePatternSet patterns(context); ConversionTarget target(*context); diff --git a/tests/bgv/ops.mlir b/tests/bgv/ops.mlir index 985f3a1aa..a2524b3bf 100644 --- a/tests/bgv/ops.mlir +++ b/tests/bgv/ops.mlir @@ -3,26 +3,6 @@ // This simply tests for syntax. -<<<<<<< HEAD -#my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> -#otherrings = #bgv.rings<#ring1> - -// CHECK: module -module { - func.func @test_multiply(%arg0 : !bgv.ciphertext, %arg1: !bgv.ciphertext) -> !bgv.ciphertext { - %add = bgv.add(%arg0, %arg1) : !bgv.ciphertext - %sub = bgv.sub(%arg0, %arg1) : !bgv.ciphertext - %neg = bgv.negate(%arg0) : !bgv.ciphertext - - %0 = bgv.mul(%arg0, %arg1) : !bgv.ciphertext -> !bgv.ciphertext - %1 = bgv.relinearize(%0) {from_basis = array, to_basis = array } : (!bgv.ciphertext) -> !bgv.ciphertext - %2 = bgv.modulus_switch(%1) {from_level = 1, to_level=0} : (!bgv.ciphertext) -> !bgv.ciphertext - // CHECK: <>, >> - return %arg0 : !bgv.ciphertext -======= #encoding = #lwe.polynomial_evaluation_encoding #my_poly = #polynomial.polynomial<1 + x**1024> @@ -60,6 +40,5 @@ module { %mul = bgv.mul_plain(%sub, %arg2) : !ct // CHECK: rlwe_params = >> return %mul : !ct ->>>>>>> main } } diff --git a/tests/bgv/to_openfhe.mlir b/tests/bgv/to_openfhe.mlir index 5508e88e7..fbd173d6f 100644 --- a/tests/bgv/to_openfhe.mlir +++ b/tests/bgv/to_openfhe.mlir @@ -1,19 +1,5 @@ // RUN: heir-opt --bgv-to-openfhe %s | FileCheck %s -<<<<<<< HEAD -#my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> -#rings2 = #bgv.rings<#ring1, #ring2, #ring2> -!ct = !bgv.ciphertext -!ct_dim = !bgv.ciphertext -!ct_level = !bgv.ciphertext - -// CHECK: module -module { - // CHECK: func.func @test_fn([[X:%.+]]: [[T:.*33538049.*]]) -> [[T]] -======= #encoding = #lwe.polynomial_evaluation_encoding #my_poly = #polynomial.polynomial<1 + x**1024> @@ -34,17 +20,12 @@ module { // CHECK: module module { // CHECK: func.func @test_fn([[X:%.+]]: [[T:.*161729713.*]]) -> [[T]] ->>>>>>> main func.func @test_fn(%x : !ct) -> !ct { // CHECK: return [[X]] : [[T]] return %x : !ct } -<<<<<<< HEAD - // CHECK: func.func @test_ops([[C:%.+]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*33538049.*]], [[Y:%.+]]: [[T]]) -======= // CHECK: func.func @test_ops([[C:%.+]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*161729713.*]], [[Y:%.+]]: [[T]]) ->>>>>>> main func.func @test_ops(%x : !ct, %y : !ct) { // CHECK: %[[v1:.*]] = openfhe.negate [[C]], %[[x1:.*]] : ([[S]], [[T]]) -> [[T]] %negate = bgv.negate(%x) : !ct @@ -53,22 +34,14 @@ module { // CHECK: %[[v3:.*]] = openfhe.sub [[C]], %[[x3:.*]], %[[y3:.*]]: ([[S]], [[T]], [[T]]) -> [[T]] %sub = bgv.sub(%x, %y) : !ct // CHECK: %[[v4:.*]] = openfhe.mul_no_relin [[C]], %[[x4:.*]], %[[y4:.*]]: ([[S]], [[T]], [[T]]) -> [[T]] -<<<<<<< HEAD - %mul = bgv.mul(%x, %y) : !ct -> !bgv.ciphertext -======= %mul = bgv.mul(%x, %y) : !ct -> !ct_level3 ->>>>>>> main // CHECK: %[[c5:.*]] = arith.constant 4 : i64 // CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]], %[[c5:.*]]: ([[S]], [[T]], i64) -> [[T]] %rot = bgv.rotate(%x) {offset = 4}: (!ct) -> !ct return } -<<<<<<< HEAD - // CHECK: func.func @test_relin([[C]]: [[S]], [[X:%.+]]: [[T:.*33538049.*]]) -======= // CHECK: func.func @test_relin([[C]]: [[S]], [[X:%.+]]: [[T:.*161729713.*]]) ->>>>>>> main func.func @test_relin(%x : !ct_dim) { // CHECK: %[[v6:.*]] = openfhe.relin [[C]], %[[x6:.*]]: ([[S]], [[T]]) -> [[T]] %relin = bgv.relinearize(%x) { @@ -77,20 +50,10 @@ module { return } -<<<<<<< HEAD - // CHECK: func.func @test_modswitch([[C]]: [[S]], [[X:%.+]]: [[T:.*33538049.*]]) - func.func @test_modswitch(%x : !ct_level) { - // CHECK: %[[v7:.*]] = openfhe.mod_reduce [[C]], %[[x7:.*]]: ([[S]], [[T]]) -> [[T]] - %mod_switch = bgv.modulus_switch(%x) { - from_level = 2, to_level = 1 - }: (!ct_level) -> !bgv.ciphertext - return -======= // CHECK: func.func @test_modswitch([[C]]: [[S]], [[X:%.+]]: [[T:.*161729713.*]]) -> [[T1:.*2521.*]] { func.func @test_modswitch(%x : !ct) -> !ct_level { // CHECK: %[[v7:.*]] = openfhe.mod_reduce [[C]], %[[x7:.*]] : ([[S]], [[T]]) -> [[T1]] %mod_switch = bgv.modulus_switch(%x) { to_ring=#ring2 }: (!ct) -> !ct_level return %mod_switch : !ct_level ->>>>>>> main } } diff --git a/tests/bgv/to_openfhe_invalid.mlir b/tests/bgv/to_openfhe_invalid.mlir index ce2c0c722..a81ea78d8 100644 --- a/tests/bgv/to_openfhe_invalid.mlir +++ b/tests/bgv/to_openfhe_invalid.mlir @@ -1,17 +1,3 @@ -<<<<<<< HEAD -// RUN: not heir-opt --bgv-to-openfhe --split-input-file %s 2>&1 - -#my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> -#rings2 = #bgv.rings<#ring1, #ring2, #ring2> -!ct = !bgv.ciphertext - -func.func @test_relin_to_basis_error(%x: !bgv.ciphertext) { - // expected-error@+1 {{toBasis must be [0, 1], got [0, 2]}} - %relin_error = bgv.relinearize(%x) { from_basis = array, to_basis = array }: (!bgv.ciphertext) -> !ct -======= // RUN: heir-opt --bgv-to-openfhe --split-input-file --verify-diagnostics %s 2>&1 #encoding = #lwe.polynomial_evaluation_encoding @@ -33,22 +19,10 @@ func.func @test_relin_to_basis_error(%x: !ct1) { // expected-error@+2 {{toBasis must be [0, 1], got [0, 2]}} // expected-error@+1 {{failed to legalize operation 'bgv.relinearize' that was explicitly marked illegal}} %relin_error = bgv.relinearize(%x) { from_basis = array, to_basis = array }: (!ct1) -> !ct ->>>>>>> main return } // ----- -<<<<<<< HEAD -#my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings2 = #bgv.rings<#ring1, #ring2, #ring2> -func.func @test_modswitch_level_error(%x: !bgv.ciphertext) { - // expected-error@+1 {{fromLevel must be toLevel + 1, got fromLevel: 2 and toLevel: 0}} - %relin_error = bgv.modulus_switch(%x) { - from_level = 2, to_level = 0 - }: (!bgv.ciphertext) -> !bgv.ciphertext -======= #encoding = #lwe.polynomial_evaluation_encoding #my_poly = #polynomial.polynomial<1 + x**1024> @@ -65,6 +39,5 @@ func.func @test_modswitch_level_error(%x: !bgv.ciphertext !ct1 ->>>>>>> main return } diff --git a/tests/bgv/to_polynomial.mlir b/tests/bgv/to_polynomial.mlir index cd2151160..dea7e6dd2 100644 --- a/tests/bgv/to_polynomial.mlir +++ b/tests/bgv/to_polynomial.mlir @@ -3,13 +3,6 @@ // This simply tests for syntax. -<<<<<<< HEAD -#my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> -!ct1 = !bgv.ciphertext -======= #encoding = #lwe.polynomial_evaluation_encoding #my_poly = #polynomial.polynomial<1 + x**1024> @@ -18,7 +11,6 @@ #params1 = #lwe.rlwe_params !ct1 = !lwe.rlwe_ciphertext ->>>>>>> main // CHECK: module module { @@ -51,11 +43,7 @@ module { // CHECK: [[Z1:%.+]] = polynomial.add([[X0Y1]], [[X1Y0]]) : [[P]] // CHECK: [[Z2:%.+]] = polynomial.mul([[X1]], [[Y1]]) : [[P]] // CHECK: [[Z:%.+]] = tensor.from_elements [[Z0]], [[Z1]], [[Z2]] : tensor<3x[[P]]> -<<<<<<< HEAD - %mul = bgv.mul(%x, %y) : !ct1 -> !bgv.ciphertext -======= %mul = bgv.mul(%x, %y) : !ct1 -> !lwe.rlwe_ciphertext ->>>>>>> main return } } diff --git a/tests/bgv/verifier.mlir b/tests/bgv/verifier.mlir index 5dcb4e601..7e1d9d7a3 100644 --- a/tests/bgv/verifier.mlir +++ b/tests/bgv/verifier.mlir @@ -1,15 +1,5 @@ // RUN: heir-opt --verify-diagnostics %s -<<<<<<< HEAD -#my_poly = #polynomial.polynomial<1 + x**1024> -#ring1 = #polynomial.ring -#ring2 = #polynomial.ring -#rings = #bgv.rings<#ring1, #ring2> - -func.func @test_input_dimension_error(%input: !bgv.ciphertext) { - // expected-error@+1 {{x.dim == 2 does not hold}} - %out = bgv.rotate (%input) {offset = 4} : (!bgv.ciphertext) -> !bgv.ciphertext -======= #encoding = #lwe.polynomial_evaluation_encoding #my_poly = #polynomial.polynomial<1 + x**1024> @@ -24,7 +14,6 @@ func.func @test_input_dimension_error(%input: !bgv.ciphertext !ct1 ->>>>>>> main return } diff --git a/tests/cggi/straight_line_vectorizer.mlir b/tests/cggi/straight_line_vectorizer.mlir index 510f9ef8c..490d6dd94 100644 --- a/tests/cggi/straight_line_vectorizer.mlir +++ b/tests/cggi/straight_line_vectorizer.mlir @@ -1,18 +1,11 @@ -<<<<<<< HEAD -// RUN: heir-opt --cggi-straight-line-vectorizer %s | FileCheck %s -======= // RUN: heir-opt --straight-line-vectorize %s | FileCheck %s ->>>>>>> main #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext // CHECK-LABEL: add_one -<<<<<<< HEAD -======= // CHECK-COUNT-9: cggi.lut3 ->>>>>>> main // CHECK: cggi.lut3(%[[arg1:.*]], %[[arg2:.*]], %[[arg3:.*]]) {lookup_table = 105 : ui8} : tensor<6x!lwe.lwe_ciphertext func.func @add_one(%arg0: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> { %true = arith.constant true diff --git a/tests/cggi_to_tfhe_rust/add_one.mlir b/tests/cggi_to_tfhe_rust/add_one.mlir index 40d182322..3e35c5e7a 100644 --- a/tests/cggi_to_tfhe_rust/add_one.mlir +++ b/tests/cggi_to_tfhe_rust/add_one.mlir @@ -6,37 +6,6 @@ // CHECK-NOT: cggi // CHECK-NOT: lwe // CHECK-COUNT-11: tfhe_rust.apply_lookup_table -<<<<<<< HEAD -#encoding = #lwe.unspecified_bit_field_encoding -!ct_ty = !lwe.lwe_ciphertext -!pt_ty = !lwe.lwe_plaintext -func.func @add_one(%arg0: !ct_ty, %arg1: !ct_ty, %arg2: !ct_ty, %arg3: !ct_ty, %arg4: !ct_ty, %arg5: !ct_ty, %arg6: !ct_ty, %arg7: !ct_ty) -> (!ct_ty, !ct_ty, !ct_ty, !ct_ty, !ct_ty, !ct_ty, !ct_ty, !ct_ty) { - %false = arith.constant false - %0 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty - %1 = lwe.trivial_encrypt %0 : !pt_ty to !ct_ty - %2 = cggi.lut3(%arg0, %arg1, %1) {lookup_table = 6 : ui8} : !ct_ty - %3 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty - %4 = lwe.trivial_encrypt %3 : !pt_ty to !ct_ty - %5 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty - %6 = lwe.trivial_encrypt %5 : !pt_ty to !ct_ty - %7 = cggi.lut3(%arg0, %4, %6) {lookup_table = 1 : ui8} : !ct_ty - %8 = cggi.lut3(%arg0, %arg1, %arg2) {lookup_table = 120 : ui8} : !ct_ty - %9 = cggi.lut3(%arg0, %arg1, %arg2) {lookup_table = 128 : ui8} : !ct_ty - %10 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty - %11 = lwe.trivial_encrypt %10 : !pt_ty to !ct_ty - %12 = cggi.lut3(%9, %arg3, %11) {lookup_table = 6 : ui8} : !ct_ty - %13 = cggi.lut3(%9, %arg3, %arg4) {lookup_table = 120 : ui8} : !ct_ty - %14 = cggi.lut3(%9, %arg3, %arg4) {lookup_table = 128 : ui8} : !ct_ty - %15 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty - %16 = lwe.trivial_encrypt %15 : !pt_ty to !ct_ty - %17 = cggi.lut3(%14, %arg5, %16) {lookup_table = 6 : ui8} : !ct_ty - %18 = cggi.lut3(%14, %arg5, %arg6) {lookup_table = 120 : ui8} : !ct_ty - %19 = cggi.lut3(%14, %arg5, %arg6) {lookup_table = 128 : ui8} : !ct_ty - %20 = lwe.encode %false {encoding = #encoding} : i1 to !pt_ty - %21 = lwe.trivial_encrypt %20 : !pt_ty to !ct_ty - %22 = cggi.lut3(%19, %arg7, %21) {lookup_table = 6 : ui8} : !ct_ty - return %22, %18, %17, %13, %12, %8, %2, %7 : !ct_ty, !ct_ty, !ct_ty, !ct_ty, !ct_ty, !ct_ty, !ct_ty, !ct_ty -======= // CHECK: [[ALLOC:%.*]] = memref.alloc // CHECK: return [[ALLOC]] : memref<8x!tfhe_rust.eui3> @@ -141,5 +110,4 @@ module { memref.store %52, %alloc_0[%c7] : memref<8x!ct_ty> return %alloc_0 : memref<8x!ct_ty> } ->>>>>>> main } diff --git a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir index 646c68e81..034d04e8c 100644 --- a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir +++ b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir @@ -1,11 +1,9 @@ // RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext - // CHECK-LABEL: add_bool // CHECK-NOT: cggi // CHECK-NOT: lwe diff --git a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir index 8a6e62cd8..6d4cc5fb0 100644 --- a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir +++ b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir @@ -1,11 +1,9 @@ // RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s - #encoding = #lwe.unspecified_bit_field_encoding !ct_ty = !lwe.lwe_ciphertext !pt_ty = !lwe.lwe_plaintext - // CHECK-LABEL: add_one_bool // CHECK-NOT: cggi // CHECK-NOT: lwe diff --git a/tests/lwe/types.mlir b/tests/lwe/types.mlir index 3483a565b..a0a78ac64 100644 --- a/tests/lwe/types.mlir +++ b/tests/lwe/types.mlir @@ -23,14 +23,10 @@ func.func @test_valid_lwe_ciphertext_unspecified(%arg0 : !ciphertext_noparams) - return %arg0 : !ciphertext_noparams } -<<<<<<< HEAD -#rlwe_params = #lwe.rlwe_params -======= #my_poly = #polynomial.polynomial<1 + x**1024> #ring = #polynomial.ring #rlwe_params = #lwe.rlwe_params ->>>>>>> main !ciphertext_rlwe = !lwe.rlwe_ciphertext // CHECK-LABEL: test_valid_rlwe_ciphertext diff --git a/tests/openfhe/emit_openfhe_pke.mlir b/tests/openfhe/emit_openfhe_pke.mlir index 0bc3609cf..74bb054c7 100644 --- a/tests/openfhe/emit_openfhe_pke.mlir +++ b/tests/openfhe/emit_openfhe_pke.mlir @@ -1,14 +1,10 @@ // RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s #encoding = #lwe.polynomial_evaluation_encoding -<<<<<<< HEAD -#params = #lwe.rlwe_params -======= #my_poly = #polynomial.polynomial<1 + x**16384> #ring= #polynomial.ring #params = #lwe.rlwe_params ->>>>>>> main !cc = !openfhe.crypto_context !ek = !openfhe.eval_key !pt = !lwe.rlwe_plaintext diff --git a/tests/openfhe/end_to_end/binops.mlir b/tests/openfhe/end_to_end/binops.mlir index dc6485995..1d0e4e30d 100644 --- a/tests/openfhe/end_to_end/binops.mlir +++ b/tests/openfhe/end_to_end/binops.mlir @@ -1,12 +1,8 @@ // RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s #encoding = #lwe.polynomial_evaluation_encoding -<<<<<<< HEAD -#params = #lwe.rlwe_params -======= #ring = #polynomial.ring> #params = #lwe.rlwe_params ->>>>>>> main !cc = !openfhe.crypto_context !ct = !lwe.rlwe_ciphertext diff --git a/tests/openfhe/ops.mlir b/tests/openfhe/ops.mlir index 24746a607..265be8619 100644 --- a/tests/openfhe/ops.mlir +++ b/tests/openfhe/ops.mlir @@ -2,14 +2,10 @@ // This simply tests for syntax. #encoding = #lwe.polynomial_evaluation_encoding -<<<<<<< HEAD -#params = #lwe.rlwe_params -======= #my_poly = #polynomial.polynomial<1 + x**16384> #ring= #polynomial.ring #params = #lwe.rlwe_params ->>>>>>> main !pk = !openfhe.public_key !ek = !openfhe.eval_key !cc = !openfhe.crypto_context diff --git a/tests/secret/distribute_generic.mlir b/tests/secret/distribute_generic.mlir index 4219cb393..b0433ef5d 100644 --- a/tests/secret/distribute_generic.mlir +++ b/tests/secret/distribute_generic.mlir @@ -19,15 +19,9 @@ func.func @test_distribute_generic(%value: !secret.secret, %cond: i1) -> !s // CHECK-NEXT: secret.yield %[[g1_op]] : i32 // CHECK-NEXT: } -> !secret.secret -<<<<<<< HEAD - // CHECK-NEXT: %[[g2:.*]] = secret.generic ins(%[[g1]], %[[g1]] : !secret.secret, !secret.secret) { - // CHECK-NEXT: ^[[bb2:.*]](%[[clear_g2_in0:.*]]: i32, %[[clear_g2_in1:.*]]: i32): - // CHECK-NEXT: %[[g2_op:.*]] = arith.muli %[[clear_g2_in0]], %[[clear_g2_in1]] : i32 -======= // CHECK-NEXT: %[[g2:.*]] = secret.generic ins(%[[g1]] : !secret.secret) { // CHECK-NEXT: ^[[bb2:.*]](%[[clear_g2_in0:.*]]: i32): // CHECK-NEXT: %[[g2_op:.*]] = arith.muli %[[clear_g2_in0]], %[[clear_g2_in0]] : i32 ->>>>>>> main // CHECK-NEXT: secret.yield %[[g2_op]] : i32 // CHECK-NEXT: } -> !secret.secret diff --git a/tests/secret/distribute_generic_flags.mlir b/tests/secret/distribute_generic_flags.mlir index 2dbf85b46..796a51100 100644 --- a/tests/secret/distribute_generic_flags.mlir +++ b/tests/secret/distribute_generic_flags.mlir @@ -27,8 +27,6 @@ func.func @test_affine_for( } -> () func.return %data : !secret.secret> } -<<<<<<< HEAD -======= // CHECK-LABEL: test_affine_for_split_end // CHECK-SAME: %[[value:.*]]: !secret.secret @@ -187,4 +185,3 @@ func.func @affine_for_hello_world_reproducer(%arg0: !secret.secret !secret.secret> return %0 : !secret.secret> } ->>>>>>> main diff --git a/tests/secretize/wrap_generic.mlir b/tests/secretize/wrap_generic.mlir index 3f2e5e8b8..ae46c9e15 100644 --- a/tests/secretize/wrap_generic.mlir +++ b/tests/secretize/wrap_generic.mlir @@ -46,8 +46,6 @@ module { func.return %1 : i32 } } -<<<<<<< HEAD -======= // ----- @@ -86,4 +84,3 @@ module { return %alloc : memref<1x80xi8> } } ->>>>>>> main diff --git a/tests/tfhe_rust/end_to_end/src/boolean/main.rs b/tests/tfhe_rust/end_to_end/src/boolean/main.rs deleted file mode 100644 index 8497e27f5..000000000 --- a/tests/tfhe_rust/end_to_end/src/boolean/main.rs +++ /dev/null @@ -1,26 +0,0 @@ -use clap::Parser; - -mod fn_under_test; - -// TODO(https://github.com/google/heir/issues/235): improve generality -#[derive(Parser, Debug)] -struct Args { - /// arguments to forward to function under test - #[arg(id = "input_1", index = 1, action)] - input1: u8, - - #[arg(id = "input_2", index = 2, action)] - input2: u8, -} - -fn main() { - let flags = Args::parse(); - let (client_key, server_key) = tfhe::boolean::gen_keys(); - - let ct_1 = client_key.encrypt(flags.input1 == 0u8); - let ct_2 = client_key.encrypt(flags.input2 == 0u8); - - let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); - let output: bool = client_key.decrypt(&result); - println!("{:?}", output as u8); -} diff --git a/tests/tfhe_rust/end_to_end/src/shortint/main.rs b/tests/tfhe_rust/end_to_end/src/shortint/main.rs deleted file mode 100644 index 2deda149b..000000000 --- a/tests/tfhe_rust/end_to_end/src/shortint/main.rs +++ /dev/null @@ -1,36 +0,0 @@ -use clap::Parser; -use tfhe::shortint::parameters::get_parameters_from_message_and_carry; - -mod fn_under_test; - -// TODO(#235): improve generality -#[derive(Parser, Debug)] -struct Args { - #[arg(id = "message_bits", long)] - message_bits: usize, - - #[arg(id = "carry_bits", long, default_value = "2")] - carry_bits: usize, - - /// arguments to forward to function under test - #[arg(id = "input_1", index = 1)] - input1: u8, - - #[arg(id = "input_2", index = 2)] - input2: u8, -} - -fn main() { - let flags = Args::parse(); - - let parameters = get_parameters_from_message_and_carry((1 << flags.message_bits) - 1, flags.carry_bits); - let (client_key, server_key) = tfhe::shortint::gen_keys(parameters); - - let ct_1 = client_key.encrypt(flags.input1.into()); - let ct_2 = client_key.encrypt(flags.input2.into()); - - let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); - let output = client_key.decrypt(&result); - println!("{:?}", output); - -} diff --git a/tests/tfhe_rust/end_to_end/test_b_bitand.mlir b/tests/tfhe_rust/end_to_end/test_b_bitand.mlir deleted file mode 100644 index ce77db90d..000000000 --- a/tests/tfhe_rust/end_to_end/test_b_bitand.mlir +++ /dev/null @@ -1,13 +0,0 @@ -// This test ensures the testing harness is working properly with minimal codegen. - -// RUN: heir-translate %s --emit-tfhe-rust > %S/src/boolean/fn_under_test.rs -// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_bool -- 1 0 | FileCheck %s - -!bsks = !tfhe_rust.bool_server_key -!eb = !tfhe_rust.eb - -// CHECK: 0 -func.func @fn_under_test(%bsks : !bsks, %a: !eb, %b: !eb) -> !eb { - %res = tfhe_rust.bitand %bsks, %a, %b: (!bsks, !eb, !eb) -> !eb - return %res : !eb -} diff --git a/tests/tfhe_rust/end_to_end/test_si_add.mlir b/tests/tfhe_rust/end_to_end/test_si_add.mlir deleted file mode 100644 index 55064bd7a..000000000 --- a/tests/tfhe_rust/end_to_end/test_si_add.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// This test ensures the testing harness is working properly with minimal codegen. - -// RUN: heir-translate %s --emit-tfhe-rust > %S/src/shortint/fn_under_test.rs -// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 2 3 --message_bits=3 | FileCheck %s - -!sks = !tfhe_rust.server_key -!lut = !tfhe_rust.lookup_table -!eui3 = !tfhe_rust.eui3 - -// CHECK: 5 -func.func @fn_under_test(%sks : !sks, %a: !eui3, %b: !eui3) -> !eui3 { - %res = tfhe_rust.add %sks, %a, %b: (!sks, !eui3, !eui3) -> !eui3 - return %res : !eui3 -} diff --git a/tests/tfhe_rust/end_to_end/test_si_simple_lut.mlir b/tests/tfhe_rust/end_to_end/test_si_simple_lut.mlir deleted file mode 100644 index 09a5fe33e..000000000 --- a/tests/tfhe_rust/end_to_end/test_si_simple_lut.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// This test ensures the testing harness is working properly with minimal codegen. - -// RUN: heir-translate %s --emit-tfhe-rust > %S/src/shortint/fn_under_test.rs -// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main -- 1 0 --message_bits=3 | FileCheck %s - -!sks = !tfhe_rust.server_key -!lut = !tfhe_rust.lookup_table -!eui3 = !tfhe_rust.eui3 - -// We're computing, effectively (0b00000111 >> (1 << 1)) & 1, i.e., 0b111 >> 2 -// CHECK: 1 -func.func @fn_under_test(%sks : !sks, %a: !eui3, %b: !eui3) -> !eui3 { - %lut = tfhe_rust.generate_lookup_table %sks {truthTable = 7 : ui8} : (!sks) -> !lut - %c1 = arith.constant 1 : i8 - %0 = tfhe_rust.scalar_left_shift %sks, %a, %c1 : (!sks, !eui3, i8) -> !eui3 - %1 = tfhe_rust.add %sks, %0, %b : (!sks, !eui3, !eui3) -> !eui3 - %2 = tfhe_rust.apply_lookup_table %sks, %1, %lut : (!sks, !eui3, !lut) -> !eui3 - return %2 : !eui3 -} diff --git a/tests/tfhe_rust/ops.mlir b/tests/tfhe_rust/ops.mlir index 5f35388f2..a7284621b 100644 --- a/tests/tfhe_rust/ops.mlir +++ b/tests/tfhe_rust/ops.mlir @@ -1,19 +1,12 @@ -<<<<<<< HEAD -// RUN: heir-translate %s | FileCheck %s -======= // RUN: heir-opt %s | FileCheck %s // RUN: heir-translate --emit-tfhe-rust %s | FileCheck --check-prefix=RS %s ->>>>>>> main // This simply tests for syntax. !sks = !tfhe_rust.server_key module { // CHECK-LABEL: func @test_create_trivial -<<<<<<< HEAD -======= // RS-LABEL: pub fn test_create_trivial ->>>>>>> main func.func @test_create_trivial(%sks : !sks) { %0 = arith.constant 1 : i8 %1 = arith.constant 1 : i3 @@ -25,10 +18,7 @@ module { } // CHECK-LABEL: func @test_bitand -<<<<<<< HEAD -======= // RS-LABEL: pub fn test_bitand ->>>>>>> main func.func @test_bitand(%sks : !sks) { %0 = arith.constant 1 : i1 %1 = arith.constant 1 : i1 @@ -41,10 +31,7 @@ module { // CHECK-LABEL: func @test_apply_lookup_table -<<<<<<< HEAD -======= // RS-LABEL: pub fn test_apply_lookup_table ->>>>>>> main func.func @test_apply_lookup_table(%sks : !sks, %lut: !tfhe_rust.lookup_table) { %0 = arith.constant 1 : i3 %1 = arith.constant 2 : i3 diff --git a/tests/tfhe_rust_bool/end_to_end/Cargo.toml b/tests/tfhe_rust_bool/end_to_end/Cargo.toml index d85ad6362..6e6fa0a3c 100644 --- a/tests/tfhe_rust_bool/end_to_end/Cargo.toml +++ b/tests/tfhe_rust_bool/end_to_end/Cargo.toml @@ -12,10 +12,7 @@ tfhe = { version = "0.4.1", features = ["boolean", "x86_64-unix"] } [[bin]] name = "main" path = "src/main.rs" -<<<<<<< HEAD [[bin]] name = "main_bool_add" path = "src/main_bool_add.rs" -======= ->>>>>>> main diff --git a/tests/tfhe_rust_bool/end_to_end_fpga/README.md b/tests/tfhe_rust_bool/end_to_end_fpga/README.md index 280f60ac5..4eb3d1103 100644 --- a/tests/tfhe_rust_bool/end_to_end_fpga/README.md +++ b/tests/tfhe_rust_bool/end_to_end_fpga/README.md @@ -1,4 +1,4 @@ -# End to end Rust codegen tests - Boolean +# End to end Rust codegen tests - Boolean FPGA These tests exercise Rust codegen for the [tfhe-rs](https://github.com/zama-ai/tfhe-rs) backend library, including diff --git a/tests/tfhe_rust_bool/ops.mlir b/tests/tfhe_rust_bool/ops.mlir index 4dbc42729..1eb7f7353 100644 --- a/tests/tfhe_rust_bool/ops.mlir +++ b/tests/tfhe_rust_bool/ops.mlir @@ -11,13 +11,8 @@ module { %0 = arith.constant 1 : i1 %1 = arith.constant 0 : i1 -<<<<<<< HEAD - // %e1 = tfhe_rust_bool.create_trivial %bsks, %0 : (!bsks, i1) -> !eb - // %e2 = tfhe_rust_bool.create_trivial %bsks, %1 : (!bsks, i1) -> !eb -======= %e1 = tfhe_rust_bool.create_trivial %bsks, %0 : (!bsks, i1) -> !tfhe_rust_bool.eb %e2 = tfhe_rust_bool.create_trivial %bsks, %1 : (!bsks, i1) -> !tfhe_rust_bool.eb ->>>>>>> main return } @@ -26,15 +21,9 @@ module { %0 = arith.constant 1 : i1 %1 = arith.constant 1 : i1 -<<<<<<< HEAD - %e1 = tfhe_rust_bool.create_trivial %bsks, %0 : (!bsks, i1) -> !eb - %e2 = tfhe_rust_bool.create_trivial %bsks, %1 : (!bsks, i1) -> !eb - %out = tfhe_rust_bool.and %bsks, %e1, %e2: (!bsks, !eb, !eb) -> !eb -======= %e1 = tfhe_rust_bool.create_trivial %bsks, %0 : (!bsks, i1) -> !tfhe_rust_bool.eb %e2 = tfhe_rust_bool.create_trivial %bsks, %1 : (!bsks, i1) -> !tfhe_rust_bool.eb %out = tfhe_rust_bool.and %bsks, %e1, %e2: (!bsks, !tfhe_rust_bool.eb, !tfhe_rust_bool.eb) -> !tfhe_rust_bool.eb ->>>>>>> main return } diff --git a/tests/yosys_optimizer/unroll_and_optimize.mlir b/tests/yosys_optimizer/unroll_and_optimize.mlir index 997719ea1..343fd17d5 100644 --- a/tests/yosys_optimizer/unroll_and_optimize.mlir +++ b/tests/yosys_optimizer/unroll_and_optimize.mlir @@ -198,11 +198,8 @@ func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) { // // Extracted plaintext arith op // CHECK-NEXT: %[[index_minus_one:.*]] = arith.subi %[[index]], %[[c1]] -<<<<<<< HEAD -======= // Same deal, but for second unwrapped loop iteration marked by SECOND_SUB // CHECK-NEXT: arith.subi ->>>>>>> main // // Extracted load that can only be extracted because the previous // arith op was extracted. @@ -213,12 +210,7 @@ func.func @cumulative_sums(%arg0: !in_ty) -> (!out_ty) { // CHECK-NEXT: secret.yield // CHECK-NEXT: } // -<<<<<<< HEAD -// Same deal, but for second unwrapped loop iteration -// CHECK-NEXT: arith.subi -======= // mark: SECOND_SUB ->>>>>>> main // CHECK-NEXT: secret.generic // CHECK-NEXT: bb // CHECK-NEXT: memref.load diff --git a/tools/heir-lsp.cpp b/tools/heir-lsp.cpp index ef383e333..a50cfc864 100644 --- a/tools/heir-lsp.cpp +++ b/tools/heir-lsp.cpp @@ -6,10 +6,7 @@ #include "include/Dialect/PolyExt/IR/PolyExtDialect.h" #include "include/Dialect/Polynomial/IR/PolynomialDialect.h" #include "include/Dialect/Secret/IR/SecretDialect.h" -<<<<<<< HEAD -======= #include "include/Dialect/TensorExt/IR/TensorExtDialect.h" - >>>>>>> main #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project @@ -23,7 +20,7 @@ #include "mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project - using namespace mlir; +using namespace mlir; using namespace heir; int main(int argc, char **argv) { @@ -39,10 +36,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); -<<<<<<< HEAD -======= registry.insert(); ->>>>>>> main // Add expected MLIR dialects to the registry. registerAllDialects(registry); diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 1a8093776..14dade8cd 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -20,21 +20,14 @@ #include "include/Dialect/Secret/IR/SecretDialect.h" #include "include/Dialect/Secret/Transforms/DistributeGeneric.h" #include "include/Dialect/Secret/Transforms/Passes.h" -<<<<<<< HEAD -======= #include "include/Dialect/TensorExt/IR/TensorExtDialect.h" #include "include/Dialect/TensorExt/Transforms/Passes.h" - >>>>>>> main #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" #include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" #include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" #include "include/Transforms/Secretize/Passes.h" - <<<<<<>>>>>> - main #include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project @@ -75,7 +68,7 @@ #include "include/Transforms/YosysOptimizer/YosysOptimizer.h" #endif - using namespace mlir; +using namespace mlir; using namespace tosa; using namespace heir; using mlir::func::FuncOp; @@ -280,10 +273,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); -<<<<<<< HEAD -======= registry.insert(); ->>>>>>> main // Add expected MLIR dialects to the registry. registry.insert(); @@ -303,17 +293,11 @@ int main(int argc, char **argv) { cggi::registerCGGIPasses(); lwe::registerLWEPasses(); secret::registerSecretPasses(); -<<<<<<< HEAD - registerSecretizePasses(); - registerFullLoopUnrollPasses(); - registerForwardStoreToLoadPasses(); -======= tensor_ext::registerTensorExtPasses(); registerSecretizePasses(); registerFullLoopUnrollPasses(); registerForwardStoreToLoadPasses(); registerStraightLineVectorizerPasses(); ->>>>>>> main // Register yosys optimizer pipeline if configured. #ifndef HEIR_NO_YOSYS const char *abcEnvPath = std::getenv("HEIR_ABC_BINARY");