Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 615502361
- Loading branch information
Showing
13 changed files
with
361 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# SecretToBGV tablegen and headers. | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files( | ||
[ | ||
"SecretToBGV.h", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=SecretToBGV", | ||
], | ||
"SecretToBGV.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"SecretToBGV.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "SecretToBGV.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ | ||
#define INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir::heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" | ||
|
||
} // namespace mlir::heir | ||
|
||
#endif // INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ | ||
#define INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def SecretToBGV : Pass<"secret-to-bgv"> { | ||
let summary = "Lower `secret` to `bgv` dialect."; | ||
let dependentDialects = [ | ||
"mlir::heir::secret::SecretDialect", | ||
"mlir::heir::bgv::BGVDialect", | ||
"mlir::heir::lwe::LWEDialect", | ||
]; | ||
} | ||
|
||
#endif // INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "SecretToBGV", | ||
srcs = ["SecretToBGV.cpp"], | ||
hdrs = [ | ||
"@heir//include/Conversion/SecretToBGV:SecretToBGV.h", | ||
], | ||
deps = [ | ||
"@heir//include/Conversion/SecretToBGV:pass_inc_gen", | ||
"@heir//lib/Conversion:Utils", | ||
"@heir//lib/Dialect/BGV/IR:Dialect", | ||
"@heir//lib/Dialect/LWE/IR:Dialect", | ||
"@heir//lib/Dialect/Secret/IR:Dialect", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:FuncDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
#include "include/Conversion/SecretToBGV/SecretToBGV.h" | ||
|
||
#include <cassert> | ||
#include <utility> | ||
|
||
#include "include/Dialect/BGV/IR/BGVDialect.h" | ||
#include "include/Dialect/BGV/IR/BGVOps.h" | ||
#include "include/Dialect/LWE/IR/LWEAttributes.h" | ||
#include "include/Dialect/LWE/IR/LWEOps.h" | ||
#include "include/Dialect/LWE/IR/LWETypes.h" | ||
#include "include/Dialect/Secret/IR/SecretOps.h" | ||
#include "include/Dialect/Secret/IR/SecretTypes.h" | ||
#include "lib/Conversion/Utils.h" | ||
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project | ||
|
||
namespace mlir::heir { | ||
|
||
#define GEN_PASS_DEF_SECRETTOBGV | ||
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" | ||
|
||
namespace { | ||
|
||
// Templatized function for replacing an arithmetic operation T with a BGV | ||
// equivalent operation Y. | ||
template <typename T, typename Y> | ||
void replaceOp(ConversionPatternRewriter &rewriter, secret::GenericOp op, | ||
ValueRange inputs) { | ||
rewriter.replaceOpWithNewOp<Y>(op, inputs); | ||
} | ||
|
||
template <> | ||
void replaceOp<arith::MulIOp, bgv::MulOp>(ConversionPatternRewriter &rewriter, | ||
secret::GenericOp op, | ||
ValueRange inputs) { | ||
rewriter.replaceOpWithNewOp<bgv::Relinearize>( | ||
op, rewriter.create<bgv::MulOp>(op.getLoc(), inputs), | ||
rewriter.getDenseI32ArrayAttr({0, 1, 2}), | ||
rewriter.getDenseI32ArrayAttr({0, 1})); | ||
} | ||
|
||
template <> | ||
void replaceOp<arith::AddIOp, bgv::AddPlainOp>( | ||
ConversionPatternRewriter &rewriter, secret::GenericOp op, | ||
ValueRange inputs) { | ||
// Order inputs ciphertext then plaintext. | ||
if (isa<lwe::RLWEPlaintextType>(inputs[0].getType())) { | ||
inputs = {inputs.end(), inputs.begin()}; | ||
} | ||
rewriter.replaceOpWithNewOp<bgv::AddPlainOp>(op, inputs); | ||
} | ||
|
||
template <> | ||
void replaceOp<arith::MulIOp, bgv::MulPlainOp>( | ||
ConversionPatternRewriter &rewriter, secret::GenericOp op, | ||
ValueRange inputs) { | ||
// Order input ciphertext then plaintext. | ||
SmallVector<Value> orderedInputs; | ||
for (Value input : inputs) { | ||
if (isa<lwe::RLWECiphertextType>(input.getType())) { | ||
orderedInputs.push_back(input); | ||
} | ||
} | ||
for (Value input : inputs) { | ||
if (isa<lwe::RLWEPlaintextType>(input.getType())) { | ||
orderedInputs.push_back(input); | ||
} | ||
} | ||
rewriter.replaceOpWithNewOp<bgv::MulPlainOp>(op, orderedInputs); | ||
} | ||
|
||
// Default RLWE Parameters | ||
|
||
} // namespace | ||
|
||
// Remove this class if no type conversions are necessary | ||
class SecretToBGVTypeConverter : public TypeConverter { | ||
public: | ||
SecretToBGVTypeConverter(MLIRContext *ctx) { | ||
addConversion([](Type type) { return type; }); | ||
|
||
// Convert secret types to BGV ciphertext types | ||
addConversion([ctx](secret::SecretType type) -> Type { | ||
return lwe::RLWECiphertextType::get(ctx, {}, | ||
lwe::RLWEParamsAttr::get(ctx, 2, {})); | ||
}); | ||
} | ||
}; | ||
|
||
template <typename T, typename Y, typename Z> | ||
class SecretGenericOpConversion | ||
: public OpConversionPattern<secret::GenericOp> { | ||
public: | ||
using OpConversionPattern<secret::GenericOp>::OpConversionPattern; | ||
|
||
LogicalResult matchAndRewrite( | ||
secret::GenericOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const final { | ||
if (op.getBody()->getOperations().size() > 2) { | ||
// Each secret.generic should contain at most one instruction - | ||
// secret-distribute-generic can be used to distribute through the | ||
// arithmetic ops. | ||
return failure(); | ||
} | ||
|
||
auto &innerOp = op.getBody()->getOperations().front(); | ||
if (!isa<T>(innerOp)) { | ||
return failure(); | ||
} | ||
|
||
// Assemble the arguments for the BGV operation. | ||
SmallVector<Value> inputs; | ||
bool ciphertextOnly = true; | ||
for (OpOperand &operand : innerOp.getOpOperands()) { | ||
if (auto *secretArg = op.getOpOperandForBlockArgument(operand.get())) { | ||
inputs.push_back( | ||
adaptor.getODSOperands(0)[secretArg->getOperandNumber()]); | ||
} else { | ||
ciphertextOnly = false; | ||
// This isn't a block argument - this must be a non-secret value used | ||
// from the ambient scope. | ||
assert(operand.get().getType().isSignlessInteger() && | ||
"expected signless integer like argument to ops"); | ||
auto encodeOp = rewriter.create<lwe::RLWEEncodeOp>( | ||
op->getLoc(), | ||
lwe::RLWEPlaintextType::get(rewriter.getContext(), {}, {}), | ||
operand.get()); | ||
encodeOp.dump(); | ||
inputs.push_back(encodeOp.getResult()); | ||
} | ||
} | ||
|
||
if (ciphertextOnly) { | ||
// Directly convert the op if all operands are ciphertext. | ||
replaceOp<T, Y>(rewriter, op, ValueRange(inputs)); | ||
return success(); | ||
} | ||
|
||
// One of the arguments must be a plaintext. Trivially encrypt and apply the | ||
// binary operation. | ||
replaceOp<T, Z>(rewriter, op, ValueRange(inputs)); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> { | ||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
auto *module = getOperation(); | ||
SecretToBGVTypeConverter typeConverter(context); | ||
|
||
RewritePatternSet patterns(context); | ||
ConversionTarget target(*context); | ||
target.addLegalDialect<bgv::BGVDialect>(); | ||
target.addIllegalDialect<secret::SecretDialect>(); | ||
target.addIllegalOp<secret::GenericOp>(); | ||
|
||
addStructuralConversionPatterns(typeConverter, patterns, target); | ||
patterns.add< | ||
SecretGenericOpConversion<arith::AddIOp, bgv::AddOp, bgv::AddPlainOp>, | ||
SecretGenericOpConversion<arith::MulIOp, bgv::MulOp, bgv::MulPlainOp>>( | ||
typeConverter, context); | ||
|
||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace mlir::heir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.