Skip to content

Commit

Permalink
Add secret-to-bgv conversion
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615502361
  • Loading branch information
asraa authored and Copybara-Service committed Mar 14, 2024
1 parent ca5387e commit 7f15c71
Show file tree
Hide file tree
Showing 13 changed files with 361 additions and 6 deletions.
37 changes: 37 additions & 0 deletions include/Conversion/SecretToBGV/BUILD
@@ -0,0 +1,37 @@
# SecretToBGV tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files(
[
"SecretToBGV.h",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=SecretToBGV",
],
"SecretToBGV.h.inc",
),
(
["-gen-pass-doc"],
"SecretToBGV.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "SecretToBGV.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
16 changes: 16 additions & 0 deletions include/Conversion/SecretToBGV/SecretToBGV.h
@@ -0,0 +1,16 @@
#ifndef INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_
#define INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir::heir {

#define GEN_PASS_DECL
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc"

#define GEN_PASS_REGISTRATION
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc"

} // namespace mlir::heir

#endif // INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_
15 changes: 15 additions & 0 deletions include/Conversion/SecretToBGV/SecretToBGV.td
@@ -0,0 +1,15 @@
#ifndef INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_
#define INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_

include "mlir/Pass/PassBase.td"

def SecretToBGV : Pass<"secret-to-bgv"> {
let summary = "Lower `secret` to `bgv` dialect.";
let dependentDialects = [
"mlir::heir::secret::SecretDialect",
"mlir::heir::bgv::BGVDialect",
"mlir::heir::lwe::LWEDialect",
];
}

#endif // INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_
2 changes: 1 addition & 1 deletion include/Dialect/BGV/IR/BGVOps.td
Expand Up @@ -27,7 +27,7 @@ class BGV_CiphertextPlaintextOp<string mnemonic, list<Trait> traits =
[AllTypesMatch<["x", "output"]>,
TypesMatchWith<"type of 'y' matches encoding type of 'x'",
"output", "y",
"lwe::RLWEPlaintextType::get($_ctxt, ::llvm::cast<lwe::RLWECiphertextType>($_self).getEncoding())">]> :
"lwe::RLWEPlaintextType::get($_ctxt, ::llvm::cast<lwe::RLWECiphertextType>($_self).getEncoding(), ::llvm::cast<lwe::RLWECiphertextType>($_self).getRlweParams().getRing())">]> :
BGV_Op<mnemonic, traits> {
let arguments = (ins
RLWECiphertext:$x,
Expand Down
4 changes: 3 additions & 1 deletion include/Dialect/LWE/IR/LWEAttributes.td
Expand Up @@ -296,6 +296,8 @@ def RLWE_InverseCanonicalEmbeddingEncoding
}];
}

def AnyRLWEEncodingAttr : AnyAttrOf<[RLWE_PolynomialCoefficientEncoding, RLWE_PolynomialEvaluationEncoding, RLWE_InverseCanonicalEmbeddingEncoding]>;

def LWE_LWEParams : AttrDef<LWE_Dialect, "LWEParams"> {
let mnemonic = "lwe_params";

Expand All @@ -317,7 +319,7 @@ def LWE_RLWEParams : AttrDef<LWE_Dialect, "RLWEParams"> {

let parameters = (ins
DefaultValuedParameter<"unsigned", "2">:$dimension,
"::mlir::heir::polynomial::RingAttr":$ring
OptionalParameter<"::mlir::heir::polynomial::RingAttr">:$ring
);

let assemblyFormat = "`<` struct(params) `>`";
Expand Down
33 changes: 32 additions & 1 deletion include/Dialect/LWE/IR/LWEOps.td
Expand Up @@ -3,6 +3,8 @@

include "include/Dialect/LWE/IR/LWEDialect.td"
include "include/Dialect/LWE/IR/LWETypes.td"
include "include/Dialect/Polynomial/IR/PolynomialAttributes.td"

include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand All @@ -24,7 +26,7 @@ def LWE_EncodeOp : LWE_Op<"encode", [Pure]> {
Examples:

```
%Y = lwe.encode %value {encoding = #enc}: i1 to !lwe.lwe_plaintext<encoding = #enc, ring = #ring>
%Y = lwe.encode %value {encoding = #enc}: i1 to !lwe.lwe_plaintext<encoding = #enc>
```
}];

Expand Down Expand Up @@ -59,4 +61,33 @@ def LWE_TrivialEncryptOp: LWE_Op<"trivial_encrypt", [Pure]> {
let hasVerifier = 1;
}

def LWE_RLWEEncodeOp : LWE_Op<"rlwe_encode", [Pure]> {
let summary = "Encode an integer to yield an RLWE plaintext";
let description = [{
Encode an integer to yield an RLWE plaintext.

This op uses a an encoding attribute to encode the bits of the integer into
an RLWE plaintext value that can then be encrypted.

Examples:

```
%Y = lwe.rlwe_encode %value {encoding = #enc, ring = #ring}: i1 to !lwe.rlwe_plaintext<encoding = #enc, ring = #ring>
```
}];

let arguments = (ins
SignlessIntegerLike:$plaintext,
OptionalAttr<AnyRLWEEncodingAttr>:$encoding,
OptionalAttr<Ring_Attr>:$ring
);

let results = (outs RLWEPlaintext:$output);
let assemblyFormat = "$plaintext attr-dict `:` qualified(type($plaintext)) `to` qualified(type($output))";

// Verify that the encoding and ring parameter matches the output plaintext attribute.
let hasVerifier = 1;
}


#endif // HEIR_INCLUDE_DIALECT_LWE_IR_LWEOPS_TD_
7 changes: 4 additions & 3 deletions include/Dialect/LWE/IR/LWETypes.td
Expand Up @@ -42,8 +42,8 @@ def RLWECiphertext : LWE_Type<"RLWECiphertext", "rlwe_ciphertext"> {
let summary = "A type for RLWE ciphertexts";

let parameters = (ins
"::mlir::Attribute":$encoding,
OptionalParameter<"RLWEParamsAttr">:$rlwe_params
OptionalParameter<"::mlir::Attribute">:$encoding,
"RLWEParamsAttr":$rlwe_params
);

let assemblyFormat = "`<` struct(params) `>`";
Expand All @@ -70,7 +70,8 @@ def RLWEPlaintext : LWE_Type<"RLWEPlaintext", "rlwe_plaintext"> {
let summary = "A type for RLWE plaintexts";

let parameters = (ins
"::mlir::Attribute":$encoding
OptionalParameter<"::mlir::Attribute">:$encoding,
OptionalParameter<"::mlir::heir::polynomial::RingAttr">:$ring
);

let assemblyFormat = "`<` struct(params) `>`";
Expand Down
26 changes: 26 additions & 0 deletions lib/Conversion/SecretToBGV/BUILD
@@ -0,0 +1,26 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "SecretToBGV",
srcs = ["SecretToBGV.cpp"],
hdrs = [
"@heir//include/Conversion/SecretToBGV:SecretToBGV.h",
],
deps = [
"@heir//include/Conversion/SecretToBGV:pass_inc_gen",
"@heir//lib/Conversion:Utils",
"@heir//lib/Dialect/BGV/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
177 changes: 177 additions & 0 deletions lib/Conversion/SecretToBGV/SecretToBGV.cpp
@@ -0,0 +1,177 @@
#include "include/Conversion/SecretToBGV/SecretToBGV.h"

#include <cassert>
#include <utility>

#include "include/Dialect/BGV/IR/BGVDialect.h"
#include "include/Dialect/BGV/IR/BGVOps.h"
#include "include/Dialect/LWE/IR/LWEAttributes.h"
#include "include/Dialect/LWE/IR/LWEOps.h"
#include "include/Dialect/LWE/IR/LWETypes.h"
#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Conversion/Utils.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir::heir {

#define GEN_PASS_DEF_SECRETTOBGV
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc"

namespace {

// Templatized function for replacing an arithmetic operation T with a BGV
// equivalent operation Y.
template <typename T, typename Y>
void replaceOp(ConversionPatternRewriter &rewriter, secret::GenericOp op,
ValueRange inputs) {
rewriter.replaceOpWithNewOp<Y>(op, inputs);
}

template <>
void replaceOp<arith::MulIOp, bgv::MulOp>(ConversionPatternRewriter &rewriter,
secret::GenericOp op,
ValueRange inputs) {
rewriter.replaceOpWithNewOp<bgv::Relinearize>(
op, rewriter.create<bgv::MulOp>(op.getLoc(), inputs),
rewriter.getDenseI32ArrayAttr({0, 1, 2}),
rewriter.getDenseI32ArrayAttr({0, 1}));
}

template <>
void replaceOp<arith::AddIOp, bgv::AddPlainOp>(
ConversionPatternRewriter &rewriter, secret::GenericOp op,
ValueRange inputs) {
// Order inputs ciphertext then plaintext.
if (isa<lwe::RLWEPlaintextType>(inputs[0].getType())) {
inputs = {inputs.end(), inputs.begin()};
}
rewriter.replaceOpWithNewOp<bgv::AddPlainOp>(op, inputs);
}

template <>
void replaceOp<arith::MulIOp, bgv::MulPlainOp>(
ConversionPatternRewriter &rewriter, secret::GenericOp op,
ValueRange inputs) {
// Order input ciphertext then plaintext.
SmallVector<Value> orderedInputs;
for (Value input : inputs) {
if (isa<lwe::RLWECiphertextType>(input.getType())) {
orderedInputs.push_back(input);
}
}
for (Value input : inputs) {
if (isa<lwe::RLWEPlaintextType>(input.getType())) {
orderedInputs.push_back(input);
}
}
rewriter.replaceOpWithNewOp<bgv::MulPlainOp>(op, orderedInputs);
}

// Default RLWE Parameters

} // namespace

// Remove this class if no type conversions are necessary
class SecretToBGVTypeConverter : public TypeConverter {
public:
SecretToBGVTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });

// Convert secret types to BGV ciphertext types
addConversion([ctx](secret::SecretType type) -> Type {
return lwe::RLWECiphertextType::get(ctx, {},
lwe::RLWEParamsAttr::get(ctx, 2, {}));
});
}
};

template <typename T, typename Y, typename Z>
class SecretGenericOpConversion
: public OpConversionPattern<secret::GenericOp> {
public:
using OpConversionPattern<secret::GenericOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
secret::GenericOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (op.getBody()->getOperations().size() > 2) {
// Each secret.generic should contain at most one instruction -
// secret-distribute-generic can be used to distribute through the
// arithmetic ops.
return failure();
}

auto &innerOp = op.getBody()->getOperations().front();
if (!isa<T>(innerOp)) {
return failure();
}

// Assemble the arguments for the BGV operation.
SmallVector<Value> inputs;
bool ciphertextOnly = true;
for (OpOperand &operand : innerOp.getOpOperands()) {
if (auto *secretArg = op.getOpOperandForBlockArgument(operand.get())) {
inputs.push_back(
adaptor.getODSOperands(0)[secretArg->getOperandNumber()]);
} else {
ciphertextOnly = false;
// This isn't a block argument - this must be a non-secret value used
// from the ambient scope.
assert(operand.get().getType().isSignlessInteger() &&
"expected signless integer like argument to ops");
auto encodeOp = rewriter.create<lwe::RLWEEncodeOp>(
op->getLoc(),
lwe::RLWEPlaintextType::get(rewriter.getContext(), {}, {}),
operand.get());
encodeOp.dump();
inputs.push_back(encodeOp.getResult());
}
}

if (ciphertextOnly) {
// Directly convert the op if all operands are ciphertext.
replaceOp<T, Y>(rewriter, op, ValueRange(inputs));
return success();
}

// One of the arguments must be a plaintext. Trivially encrypt and apply the
// binary operation.
replaceOp<T, Z>(rewriter, op, ValueRange(inputs));
return success();
}
};

struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto *module = getOperation();
SecretToBGVTypeConverter typeConverter(context);

RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<bgv::BGVDialect>();
target.addIllegalDialect<secret::SecretDialect>();
target.addIllegalOp<secret::GenericOp>();

addStructuralConversionPatterns(typeConverter, patterns, target);
patterns.add<
SecretGenericOpConversion<arith::AddIOp, bgv::AddOp, bgv::AddPlainOp>,
SecretGenericOpConversion<arith::MulIOp, bgv::MulOp, bgv::MulPlainOp>>(
typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace mlir::heir
12 changes: 12 additions & 0 deletions lib/Dialect/LWE/IR/LWEDialect.cpp
Expand Up @@ -158,6 +158,18 @@ LogicalResult EncodeOp::verify() {
return success();
}

LogicalResult RLWEEncodeOp::verify() {
auto encodingAttr = this->getEncodingAttr();
auto outEncoding = this->getOutput().getType().getEncoding();

if (encodingAttr != outEncoding) {
return this->emitOpError()
<< "encoding attr must match output LWE plaintext encoding";
}

return success();
}

LogicalResult TrivialEncryptOp::verify() {
auto paramsAttr = this->getParamsAttr();
auto outParamsAttr = this->getOutput().getType().getLweParams();
Expand Down

0 comments on commit 7f15c71

Please sign in to comment.