Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add secret-to-bgv conversion for ciphertext arith ops
Adds parameter setting options. All secret inputs are required to be a uniform tensor shape that matches the ring dimension in the parameters specified. PiperOrigin-RevId: 615502361
- Loading branch information
Showing
18 changed files
with
421 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# SecretToBGV tablegen and headers. | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files( | ||
[ | ||
"SecretToBGV.h", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=SecretToBGV", | ||
], | ||
"SecretToBGV.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"SecretToBGV.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "SecretToBGV.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#ifndef INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ | ||
#define INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir::heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" | ||
|
||
} // namespace mlir::heir | ||
|
||
#endif // INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#ifndef INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ | ||
#define INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def SecretToBGV : Pass<"secret-to-bgv"> { | ||
let summary = "Lower `secret` to `bgv` dialect."; | ||
|
||
let description = [{ | ||
This pass lowers an IR with `secret.generic` blocks containing arithmetic | ||
operations to operations on ciphertexts with the BGV dialect. | ||
|
||
The pass assumes that the `secret.generic` regions have been distributed | ||
through arithmetic operations so that only one ciphertext operation appears | ||
per generic block. It also requires that `canonicalize` was run so that | ||
non-secret values used are removed from the `secret.generic`'s block | ||
arguments. | ||
|
||
The pass requires that all types are tensors of a uniform shape matching the | ||
dimension of the ciphertext space specified my `poly-mod-degree`. | ||
}]; | ||
|
||
let dependentDialects = [ | ||
"mlir::heir::polynomial::PolynomialDialect", | ||
"mlir::heir::bgv::BGVDialect", | ||
"mlir::heir::lwe::LWEDialect", | ||
]; | ||
|
||
let options = [ | ||
Option<"polyModDegree", "poly-mod-degree", "int", | ||
/*default=*/"1024", "Default degree of the cyclotomic polynomial " | ||
"modulus to use for ciphertext space.">, | ||
Option<"coefficientModBits", "coefficient-mod-bits", "int", | ||
/*default=*/"29", "Default number of bits of the prime " | ||
"coefficient modulus to use " "for the ciphertext space."> | ||
]; | ||
} | ||
|
||
#endif // INCLUDE_CONVERSION_SECRETTOBGV_SECRETTOBGV_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "SecretToBGV", | ||
srcs = ["SecretToBGV.cpp"], | ||
hdrs = [ | ||
"@heir//include/Conversion/SecretToBGV:SecretToBGV.h", | ||
], | ||
deps = [ | ||
"@heir//include/Conversion/SecretToBGV:pass_inc_gen", | ||
"@heir//lib/Conversion:Utils", | ||
"@heir//lib/Dialect/BGV/IR:Dialect", | ||
"@heir//lib/Dialect/LWE/IR:Dialect", | ||
"@heir//lib/Dialect/Polynomial/IR:Dialect", | ||
"@heir//lib/Dialect/Polynomial/IR:Polynomial", | ||
"@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", | ||
"@heir//lib/Dialect/Secret/IR:Dialect", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
#include "include/Conversion/SecretToBGV/SecretToBGV.h" | ||
|
||
#include <cassert> | ||
#include <utility> | ||
|
||
#include "include/Dialect/BGV/IR/BGVDialect.h" | ||
#include "include/Dialect/BGV/IR/BGVOps.h" | ||
#include "include/Dialect/LWE/IR/LWEAttributes.h" | ||
#include "include/Dialect/LWE/IR/LWEOps.h" | ||
#include "include/Dialect/LWE/IR/LWETypes.h" | ||
#include "include/Dialect/Polynomial/IR/Polynomial.h" | ||
#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" | ||
#include "include/Dialect/Polynomial/IR/PolynomialDialect.h" | ||
#include "include/Dialect/Secret/IR/SecretDialect.h" | ||
#include "include/Dialect/Secret/IR/SecretOps.h" | ||
#include "include/Dialect/Secret/IR/SecretTypes.h" | ||
#include "lib/Conversion/Utils.h" | ||
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project | ||
|
||
namespace mlir::heir { | ||
|
||
#define GEN_PASS_DEF_SECRETTOBGV | ||
#include "include/Conversion/SecretToBGV/SecretToBGV.h.inc" | ||
|
||
namespace { | ||
|
||
FailureOr<polynomial::RingAttr> getRlweRing(MLIRContext *ctx, | ||
int coefficientModBits, | ||
int polyModDegree) { | ||
std::vector<polynomial::Monomial> monomials; | ||
monomials.emplace_back(1, polyModDegree); | ||
monomials.emplace_back(1, 0); | ||
polynomial::Polynomial xnPlusOne = | ||
polynomial::Polynomial::fromMonomials(monomials, ctx); | ||
switch (coefficientModBits) { | ||
case 29: | ||
return polynomial::RingAttr::get( | ||
APInt(polynomial::APINT_BIT_WIDTH, 463187969), xnPlusOne); | ||
default: | ||
return failure(); | ||
} | ||
} | ||
|
||
// Templatized function for replacing an arithmetic operation T with a BGV | ||
// equivalent operation Y. | ||
template <typename T, typename Y> | ||
void replaceOp(ConversionPatternRewriter &rewriter, secret::GenericOp op, | ||
ValueRange inputs) { | ||
rewriter.replaceOpWithNewOp<Y>(op, inputs); | ||
} | ||
|
||
template <> | ||
void replaceOp<arith::MulIOp, bgv::MulOp>(ConversionPatternRewriter &rewriter, | ||
secret::GenericOp op, | ||
ValueRange inputs) { | ||
rewriter.replaceOpWithNewOp<bgv::Relinearize>( | ||
op, rewriter.create<bgv::MulOp>(op.getLoc(), inputs), | ||
rewriter.getDenseI32ArrayAttr({0, 1, 2}), | ||
rewriter.getDenseI32ArrayAttr({0, 1})); | ||
} | ||
|
||
} // namespace | ||
|
||
// Remove this class if no type conversions are necessary | ||
class SecretToBGVTypeConverter : public TypeConverter { | ||
public: | ||
SecretToBGVTypeConverter(MLIRContext *ctx, polynomial::RingAttr rlweRing) { | ||
addConversion([](Type type) { return type; }); | ||
|
||
// Convert secret types to BGV ciphertext types | ||
addConversion([ctx, this](secret::SecretType type) -> Type { | ||
RankedTensorType tensorTy = cast<RankedTensorType>(type.getValueType()); | ||
return lwe::RLWECiphertextType::get( | ||
ctx, | ||
lwe::PolynomialEvaluationEncodingAttr::get( | ||
ctx, tensorTy.getElementTypeBitWidth(), | ||
tensorTy.getElementTypeBitWidth()), | ||
lwe::RLWEParamsAttr::get(ctx, 2, ring_)); | ||
}); | ||
|
||
ring_ = rlweRing; | ||
} | ||
|
||
polynomial::RingAttr ring_; | ||
}; | ||
|
||
template <typename T, typename Y> | ||
class SecretGenericOpConversion | ||
: public OpConversionPattern<secret::GenericOp> { | ||
public: | ||
using OpConversionPattern<secret::GenericOp>::OpConversionPattern; | ||
|
||
LogicalResult matchAndRewrite( | ||
secret::GenericOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const final { | ||
if (op.getBody()->getOperations().size() > 2) { | ||
// Each secret.generic should contain at most one instruction - | ||
// secret-distribute-generic can be used to distribute through the | ||
// arithmetic ops. | ||
return failure(); | ||
} | ||
|
||
auto &innerOp = op.getBody()->getOperations().front(); | ||
if (!isa<T>(innerOp)) { | ||
return failure(); | ||
} | ||
|
||
// Assemble the arguments for the BGV operation. | ||
SmallVector<Value> inputs; | ||
for (OpOperand &operand : innerOp.getOpOperands()) { | ||
if (auto *secretArg = op.getOpOperandForBlockArgument(operand.get())) { | ||
inputs.push_back( | ||
adaptor.getODSOperands(0)[secretArg->getOperandNumber()]); | ||
} else { | ||
// Plaintext-ciphertext operations are not handled. | ||
return failure(); | ||
} | ||
} | ||
|
||
// Directly convert the op if all operands are ciphertext. | ||
replaceOp<T, Y>(rewriter, op, inputs); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> { | ||
using SecretToBGVBase::SecretToBGVBase; | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
auto *module = getOperation(); | ||
|
||
auto rlweRing = getRlweRing(context, coefficientModBits, polyModDegree); | ||
if (failed(rlweRing)) { | ||
return signalPassFailure(); | ||
} | ||
// Ensure that all secret types are uniform and matching the ring | ||
// parameter size. | ||
WalkResult compatibleTensors = module->walk([&](Operation *op) { | ||
for (auto value : op->getOperands()) { | ||
if (auto secretTy = dyn_cast<secret::SecretType>(value.getType())) { | ||
auto tensorTy = dyn_cast<RankedTensorType>(secretTy.getValueType()); | ||
if (!tensorTy || | ||
tensorTy.getShape() != | ||
ArrayRef<int64_t>{rlweRing.value().getIdeal().getDegree()}) { | ||
return WalkResult::interrupt(); | ||
} | ||
} | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
if (compatibleTensors.wasInterrupted()) { | ||
module->emitError( | ||
"expected secret types to be tensors with dimension " | ||
"matching ring parameter"); | ||
return signalPassFailure(); | ||
} | ||
|
||
SecretToBGVTypeConverter typeConverter(context, rlweRing.value()); | ||
RewritePatternSet patterns(context); | ||
ConversionTarget target(*context); | ||
target.addLegalDialect<bgv::BGVDialect>(); | ||
target.addIllegalDialect<secret::SecretDialect>(); | ||
target.addIllegalOp<secret::GenericOp>(); | ||
|
||
addStructuralConversionPatterns(typeConverter, patterns, target); | ||
patterns.add<SecretGenericOpConversion<arith::AddIOp, bgv::AddOp>, | ||
SecretGenericOpConversion<arith::MulIOp, bgv::MulOp>>( | ||
typeConverter, context); | ||
|
||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace mlir::heir |
Oops, something went wrong.