Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

secret-to-bgv: add lowering patterns for rotate and add hamming distance example #565

Merged
merged 1 commit into from Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 23 additions & 2 deletions include/Dialect/BGV/IR/BGVOps.td
Expand Up @@ -102,12 +102,12 @@ def BGV_MulPlainOp : BGV_CiphertextPlaintextOp<"mul_plain"> {
let summary = "Multiplication operation between ciphertext-plaintext.";
}

def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
def BGV_Rotate : BGV_Op<"rotate", [AllTypesMatch<["x", "output"]>]> {
let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism.";

let arguments = (ins
RLWECiphertext:$x,
I64Attr:$offset
SignlessIntegerLike:$offset
);

let results = (outs
Expand All @@ -117,6 +117,27 @@ def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
let hasVerifier = 1;
}

def BGV_ExtractOp : BGV_Op<"extract", [AllTypesMatch<["x", "output"]>]> {
let summary = "Extract the i-th element of a ciphertext.";

let description = [{
While this operation is costly to compute in FHE, we represent it so we can
implement efficient lowerings and folders.

This op can be implemented as a plaintext multiplication with a one-hot
vector and a rotate.
}];

let arguments = (ins
RLWECiphertext:$x,
SignlessIntegerLike:$offset
);

let results = (outs
RLWECiphertext:$output
);
}

def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> {
let summary = "Negate the coefficients of the ciphertext.";

Expand Down
26 changes: 22 additions & 4 deletions lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp
Expand Up @@ -12,6 +12,7 @@
#include "include/Dialect/Openfhe/IR/OpenfheTypes.h"
#include "lib/Conversion/Utils.h"
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
Expand Down Expand Up @@ -148,12 +149,29 @@ struct ConvertRotateOp : public OpConversionPattern<Rotate> {
if (failed(result)) return result;

Value cryptoContext = result.value();
auto offsetValue = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(64),
adaptor.getOffset()));
Value castOffset =
llvm::TypeSwitch<Type, Value>(adaptor.getOffset().getType())
.Case<IndexType>([&](auto ty) {
return rewriter
.create<arith::IndexCastOp>(
op.getLoc(), rewriter.getI64Type(), adaptor.getOffset())
.getResult();
})
.Case<IntegerType>([&](IntegerType ty) {
if (ty.getWidth() < 64) {
return rewriter
.create<arith::ExtUIOp>(op.getLoc(), rewriter.getI64Type(),
adaptor.getOffset())
.getResult();
}
return rewriter
.create<arith::TruncIOp>(op.getLoc(), rewriter.getI64Type(),
adaptor.getOffset())
.getResult();
});
rewriter.replaceOp(
op, rewriter.create<openfhe::RotOp>(op.getLoc(), cryptoContext,
adaptor.getX(), offsetValue));
adaptor.getX(), castOffset));
return success();
}
};
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/SecretToBGV/BUILD
Expand Up @@ -18,11 +18,13 @@ cc_library(
"@heir//lib/Dialect/Polynomial/IR:Polynomial",
"@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
Expand Down
41 changes: 23 additions & 18 deletions lib/Conversion/SecretToBGV/SecretToBGV.cpp
Expand Up @@ -14,16 +14,18 @@
#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/Secret/IR/SecretTypes.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "lib/Conversion/Utils.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir::heir {
Expand Down Expand Up @@ -64,12 +66,14 @@ class SecretToBGVTypeConverter : public TypeConverter {

// Convert secret types to BGV ciphertext types
addConversion([ctx, this](secret::SecretType type) -> Type {
RankedTensorType tensorTy = cast<RankedTensorType>(type.getValueType());
int bitWidth =
llvm::TypeSwitch<Type, int>(type.getValueType())
.Case<RankedTensorType>(
[&](auto ty) -> int { return ty.getElementTypeBitWidth(); })
.Case<IntegerType>([&](auto ty) -> int { return ty.getWidth(); });
return lwe::RLWECiphertextType::get(
ctx,
lwe::PolynomialEvaluationEncodingAttr::get(
ctx, tensorTy.getElementTypeBitWidth(),
tensorTy.getElementTypeBitWidth()),
lwe::PolynomialEvaluationEncodingAttr::get(ctx, bitWidth, bitWidth),
lwe::RLWEParamsAttr::get(ctx, 2, ring_));
});

Expand Down Expand Up @@ -108,9 +112,7 @@ class SecretGenericOpConversion
inputs.push_back(
adaptor.getODSOperands(0)[secretArg->getOperandNumber()]);
} else {
return rewriter.notifyMatchFailure(
op->getLoc(),
"Plaintext-ciphertext operations are not yet supported.");
inputs.push_back(operand.get());
}
}

Expand Down Expand Up @@ -158,7 +160,7 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
for (auto value : op->getOperands()) {
if (auto secretTy = dyn_cast<secret::SecretType>(value.getType())) {
auto tensorTy = dyn_cast<RankedTensorType>(secretTy.getValueType());
if (!tensorTy ||
if (tensorTy &&
tensorTy.getShape() !=
ArrayRef<int64_t>{rlweRing.value().getIdeal().getDegree()}) {
return WalkResult::interrupt();
Expand All @@ -169,7 +171,7 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
});
if (compatibleTensors.wasInterrupted()) {
module->emitError(
"expected secret types to be tensors with dimension "
"expected batched secret types to be tensors with dimension "
"matching ring parameter");
return signalPassFailure();
}
Expand All @@ -183,6 +185,9 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {

addStructuralConversionPatterns(typeConverter, patterns, target);
patterns.add<SecretGenericOpConversion<arith::AddIOp, bgv::AddOp>,
SecretGenericOpConversion<arith::SubIOp, bgv::SubOp>,
SecretGenericOpConversion<tensor::ExtractOp, bgv::ExtractOp>,
SecretGenericOpConversion<tensor_ext::RotateOp, bgv::Rotate>,
SecretGenericOpMulConversion>(typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
Expand Down
14 changes: 13 additions & 1 deletion tests/bgv/ops.mlir
Expand Up @@ -22,6 +22,7 @@

// CHECK: module
module {
// CHECK-LABEL: @test_multiply
func.func @test_multiply(%arg0 : !ct, %arg1: !ct) -> !ct {
%add = bgv.add(%arg0, %arg1) : !ct
%sub = bgv.sub(%arg0, %arg1) : !ct
Expand All @@ -34,11 +35,22 @@ module {
return %arg0 : !ct
}

// CHECK-LABEL: @test_ciphertext_plaintext
func.func @test_ciphertext_plaintext(%arg0: !pt, %arg1: !pt, %arg2: !pt, %arg3: !ct) -> !ct {
%add = bgv.add_plain(%arg3, %arg0) : !ct
%sub = bgv.sub_plain(%add, %arg1) : !ct
%mul = bgv.mul_plain(%sub, %arg2) : !ct
// CHECK: rlwe_params = <dimension = 3, ring = <cmod=161729713, ideal=#polynomial.polynomial<1 + x**1024>>>
// CHECK: rlwe_params = <ring = <cmod=161729713, ideal=#polynomial.polynomial<1 + x**1024>>>
return %mul : !ct
}

// CHECK-LABEL: @test_rotate_extract
func.func @test_rotate_extract(%arg3: !ct) -> !ct {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%add = bgv.rotate(%arg3, %c1) : (!ct, index) -> !ct
%sub = bgv.extract(%add, %c0) : (!ct, index) -> !ct
// CHECK: rlwe_params = <ring = <cmod=161729713, ideal=#polynomial.polynomial<1 + x**1024>>>
return %sub : !ct
}
}
6 changes: 4 additions & 2 deletions tests/bgv/to_openfhe.mlir
Expand Up @@ -35,9 +35,11 @@ module {
%sub = bgv.sub(%x, %y) : !ct
// CHECK: %[[v4:.*]] = openfhe.mul_no_relin [[C]], %[[x4:.*]], %[[y4:.*]]: ([[S]], [[T]], [[T]]) -> [[T]]
%mul = bgv.mul(%x, %y) : !ct -> !ct_level3
// CHECK: %[[c5:.*]] = arith.constant 4 : i64
// CHECK: %[[c5:.*]] = arith.index_cast
// CHECK-SAME: to i64
// CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]], %[[c5:.*]]: ([[S]], [[T]], i64) -> [[T]]
%rot = bgv.rotate(%x) {offset = 4}: (!ct) -> !ct
%c4 = arith.constant 4 : index
%rot = bgv.rotate(%x, %c4): (!ct, index) -> !ct
return
}

Expand Down
24 changes: 22 additions & 2 deletions tests/bgv/verifier.mlir
@@ -1,4 +1,4 @@
// RUN: heir-opt --verify-diagnostics %s
// RUN: heir-opt --verify-diagnostics --split-input-file %s

#encoding = #lwe.polynomial_evaluation_encoding<cleartext_start=30, cleartext_bitwidth=3>

Expand All @@ -12,8 +12,28 @@
!ct1 = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params1>

func.func @test_input_dimension_error(%input: !ct) {
%offset = arith.constant 4 : index
// expected-error@+1 {{x.dim == 2 does not hold}}
%out = bgv.rotate (%input) {offset = 4} : (!ct) -> !ct1
%out = bgv.rotate (%input, %offset) : (!ct, index) -> !ct
return
}

// -----

#encoding = #lwe.polynomial_evaluation_encoding<cleartext_start=30, cleartext_bitwidth=3>

#my_poly = #polynomial.polynomial<1 + x**1024>
#ring = #polynomial.ring<cmod=463187969, ideal=#my_poly>

#params = #lwe.rlwe_params<dimension=3, ring=#ring>
#params1 = #lwe.rlwe_params<dimension=2, ring=#ring>

!ct = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params>
!ct1 = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params1>

func.func @test_input_output_type_match(%input: !ct1) {
%offset = arith.constant 4 : index
// expected-error@+1 {{failed to verify that all of {x, output} have same type}}
%out = bgv.rotate (%input, %offset) : (!ct1, index) -> !ct
return
}
31 changes: 31 additions & 0 deletions tests/secret_to_bgv/hamming_distance_1024.mlir
@@ -0,0 +1,31 @@
// RUN: heir-opt --secretize=entry-function=hamming --wrap-generic \
// RUN: --canonicalize --cse --heir-simd-vectorizer \
// RUN: --secret-distribute-generic --secret-to-bgv \
// RUN: %s | FileCheck %s

// CHECK-LABEL: @hamming
// CHECK: bgv.sub
// CHECK-NEXT: bgv.mul
// CHECK-NEXT: bgv.relinearize

// TODO(#521): After rotate-and-reduce works, only check for 10 bg.rotate
// CHECK-COUNT-1023: bgv.rotate
// CHECK: bgv.extract
// CHECK-NEXT: return

func.func @hamming(%arg0: tensor<1024xi16>, %arg1: tensor<1024xi16> {secret.secret}) -> i16 {
%c0 = arith.constant 0 : index
%c0_si16 = arith.constant 0 : i16
%0 = affine.for %arg2 = 0 to 1024 iter_args(%arg6 = %c0_si16) -> i16 {
%1 = tensor.extract %arg0[%arg2] : tensor<1024xi16>
%2 = tensor.extract %arg1[%arg2] : tensor<1024xi16>
%3 = arith.subi %1, %2 : i16
%4 = tensor.extract %arg0[%arg2] : tensor<1024xi16>
%5 = tensor.extract %arg1[%arg2] : tensor<1024xi16>
%6 = arith.subi %4, %5 : i16
%7 = arith.muli %3, %6 : i16
%8 = arith.addi %arg6, %7 : i16
affine.yield %8 : i16
}
return %0 : i16
}
11 changes: 1 addition & 10 deletions tests/secret_to_bgv/invalid.mlir
Expand Up @@ -2,16 +2,7 @@

// Tests invalid secret types

// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}}
module {
func.func @test_not_tensor(%arg0 : !secret.secret<i1>) -> (!secret.secret<i1>) {
return %arg0 : !secret.secret<i1>
}
}

// -----

// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}}
// expected-error@below {{expected batched secret types to be tensors with dimension matching ring parameter}}
module {
func.func @test_invalid_dimension(%arg0 : !secret.secret<tensor<1000xi1>>) -> (!secret.secret<tensor<1000xi1>>) {
return %arg0 : !secret.secret<tensor<1000xi1>>
Expand Down