Skip to content

Commit

Permalink
secret-to-bgv: add lowering patterns for rotate and add hamming dista…
Browse files Browse the repository at this point in the history
…nce example

PiperOrigin-RevId: 617242491
  • Loading branch information
asraa authored and Copybara-Service committed Mar 29, 2024
1 parent 33d8492 commit cdb0e02
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 39 deletions.
25 changes: 23 additions & 2 deletions include/Dialect/BGV/IR/BGVOps.td
Expand Up @@ -102,12 +102,12 @@ def BGV_MulPlainOp : BGV_CiphertextPlaintextOp<"mul_plain"> {
let summary = "Multiplication operation between ciphertext-plaintext.";
}

def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
def BGV_Rotate : BGV_Op<"rotate", [AllTypesMatch<["x", "output"]>]> {
let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism.";

let arguments = (ins
RLWECiphertext:$x,
I64Attr:$offset
SignlessIntegerLike:$offset
);

let results = (outs
Expand All @@ -117,6 +117,27 @@ def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
let hasVerifier = 1;
}

def BGV_ExtractOp : BGV_Op<"extract", [AllTypesMatch<["x", "output"]>]> {
let summary = "Extract the i-th element of a ciphertext.";

let description = [{
While this operation is costly to compute in FHE, we represent it so we can
implement efficient lowerings and folders.

This op can be implemented as a plaintext multiplication with a one-hot
vector and a rotate.
}];

let arguments = (ins
RLWECiphertext:$x,
SignlessIntegerLike:$offset
);

let results = (outs
RLWECiphertext:$output
);
}

def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> {
let summary = "Negate the coefficients of the ciphertext.";

Expand Down
26 changes: 22 additions & 4 deletions lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp
Expand Up @@ -12,6 +12,7 @@
#include "include/Dialect/Openfhe/IR/OpenfheTypes.h"
#include "lib/Conversion/Utils.h"
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
Expand Down Expand Up @@ -148,12 +149,29 @@ struct ConvertRotateOp : public OpConversionPattern<Rotate> {
if (failed(result)) return result;

Value cryptoContext = result.value();
auto offsetValue = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(64),
adaptor.getOffset()));
Value castOffset =
llvm::TypeSwitch<Type, Value>(adaptor.getOffset().getType())
.Case<IndexType>([&](auto ty) {
return rewriter
.create<arith::IndexCastOp>(
op.getLoc(), rewriter.getI64Type(), adaptor.getOffset())
.getResult();
})
.Case<IntegerType>([&](IntegerType ty) {
if (ty.getWidth() < 64) {
return rewriter
.create<arith::ExtUIOp>(op.getLoc(), rewriter.getI64Type(),
adaptor.getOffset())
.getResult();
}
return rewriter
.create<arith::TruncIOp>(op.getLoc(), rewriter.getI64Type(),
adaptor.getOffset())
.getResult();
});
rewriter.replaceOp(
op, rewriter.create<openfhe::RotOp>(op.getLoc(), cryptoContext,
adaptor.getX(), offsetValue));
adaptor.getX(), castOffset));
return success();
}
};
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/SecretToBGV/BUILD
Expand Up @@ -18,11 +18,13 @@ cc_library(
"@heir//lib/Dialect/Polynomial/IR:Polynomial",
"@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
Expand Down
41 changes: 23 additions & 18 deletions lib/Conversion/SecretToBGV/SecretToBGV.cpp
Expand Up @@ -14,16 +14,18 @@
#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/Secret/IR/SecretTypes.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "lib/Conversion/Utils.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir::heir {
Expand Down Expand Up @@ -64,12 +66,14 @@ class SecretToBGVTypeConverter : public TypeConverter {

// Convert secret types to BGV ciphertext types
addConversion([ctx, this](secret::SecretType type) -> Type {
RankedTensorType tensorTy = cast<RankedTensorType>(type.getValueType());
int bitWidth =
llvm::TypeSwitch<Type, int>(type.getValueType())
.Case<RankedTensorType>(
[&](auto ty) -> int { return ty.getElementTypeBitWidth(); })
.Case<IntegerType>([&](auto ty) -> int { return ty.getWidth(); });
return lwe::RLWECiphertextType::get(
ctx,
lwe::PolynomialEvaluationEncodingAttr::get(
ctx, tensorTy.getElementTypeBitWidth(),
tensorTy.getElementTypeBitWidth()),
lwe::PolynomialEvaluationEncodingAttr::get(ctx, bitWidth, bitWidth),
lwe::RLWEParamsAttr::get(ctx, 2, ring_));
});

Expand Down Expand Up @@ -108,9 +112,7 @@ class SecretGenericOpConversion
inputs.push_back(
adaptor.getODSOperands(0)[secretArg->getOperandNumber()]);
} else {
return rewriter.notifyMatchFailure(
op->getLoc(),
"Plaintext-ciphertext operations are not yet supported.");
inputs.push_back(operand.get());
}
}

Expand Down Expand Up @@ -158,7 +160,7 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
for (auto value : op->getOperands()) {
if (auto secretTy = dyn_cast<secret::SecretType>(value.getType())) {
auto tensorTy = dyn_cast<RankedTensorType>(secretTy.getValueType());
if (!tensorTy ||
if (tensorTy &&
tensorTy.getShape() !=
ArrayRef<int64_t>{rlweRing.value().getIdeal().getDegree()}) {
return WalkResult::interrupt();
Expand All @@ -169,7 +171,7 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
});
if (compatibleTensors.wasInterrupted()) {
module->emitError(
"expected secret types to be tensors with dimension "
"expected batched secret types to be tensors with dimension "
"matching ring parameter");
return signalPassFailure();
}
Expand All @@ -183,6 +185,9 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {

addStructuralConversionPatterns(typeConverter, patterns, target);
patterns.add<SecretGenericOpConversion<arith::AddIOp, bgv::AddOp>,
SecretGenericOpConversion<arith::SubIOp, bgv::SubOp>,
SecretGenericOpConversion<tensor::ExtractOp, bgv::ExtractOp>,
SecretGenericOpConversion<tensor_ext::RotateOp, bgv::Rotate>,
SecretGenericOpMulConversion>(typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
Expand Down
14 changes: 13 additions & 1 deletion tests/bgv/ops.mlir
Expand Up @@ -22,6 +22,7 @@

// CHECK: module
module {
// CHECK-LABEL: @test_multiply
func.func @test_multiply(%arg0 : !ct, %arg1: !ct) -> !ct {
%add = bgv.add(%arg0, %arg1) : !ct
%sub = bgv.sub(%arg0, %arg1) : !ct
Expand All @@ -34,11 +35,22 @@ module {
return %arg0 : !ct
}

// CHECK-LABEL: @test_ciphertext_plaintext
func.func @test_ciphertext_plaintext(%arg0: !pt, %arg1: !pt, %arg2: !pt, %arg3: !ct) -> !ct {
%add = bgv.add_plain(%arg3, %arg0) : !ct
%sub = bgv.sub_plain(%add, %arg1) : !ct
%mul = bgv.mul_plain(%sub, %arg2) : !ct
// CHECK: rlwe_params = <dimension = 3, ring = <cmod=161729713, ideal=#polynomial.polynomial<1 + x**1024>>>
// CHECK: rlwe_params = <ring = <cmod=161729713, ideal=#polynomial.polynomial<1 + x**1024>>>
return %mul : !ct
}

// CHECK-LABEL: @test_rotate_extract
func.func @test_rotate_extract(%arg3: !ct) -> !ct {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%add = bgv.rotate(%arg3, %c1) : (!ct, index) -> !ct
%sub = bgv.extract(%add, %c0) : (!ct, index) -> !ct
// CHECK: rlwe_params = <ring = <cmod=161729713, ideal=#polynomial.polynomial<1 + x**1024>>>
return %sub : !ct
}
}
6 changes: 4 additions & 2 deletions tests/bgv/to_openfhe.mlir
Expand Up @@ -35,9 +35,11 @@ module {
%sub = bgv.sub(%x, %y) : !ct
// CHECK: %[[v4:.*]] = openfhe.mul_no_relin [[C]], %[[x4:.*]], %[[y4:.*]]: ([[S]], [[T]], [[T]]) -> [[T]]
%mul = bgv.mul(%x, %y) : !ct -> !ct_level3
// CHECK: %[[c5:.*]] = arith.constant 4 : i64
// CHECK: %[[c5:.*]] = arith.index_cast
// CHECK-SAME: to i64
// CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]], %[[c5:.*]]: ([[S]], [[T]], i64) -> [[T]]
%rot = bgv.rotate(%x) {offset = 4}: (!ct) -> !ct
%c4 = arith.constant 4 : index
%rot = bgv.rotate(%x, %c4): (!ct, index) -> !ct
return
}

Expand Down
24 changes: 22 additions & 2 deletions tests/bgv/verifier.mlir
@@ -1,4 +1,4 @@
// RUN: heir-opt --verify-diagnostics %s
// RUN: heir-opt --verify-diagnostics --split-input-file %s

#encoding = #lwe.polynomial_evaluation_encoding<cleartext_start=30, cleartext_bitwidth=3>

Expand All @@ -12,8 +12,28 @@
!ct1 = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params1>

func.func @test_input_dimension_error(%input: !ct) {
%offset = arith.constant 4 : index
// expected-error@+1 {{x.dim == 2 does not hold}}
%out = bgv.rotate (%input) {offset = 4} : (!ct) -> !ct1
%out = bgv.rotate (%input, %offset) : (!ct, index) -> !ct
return
}

// -----

#encoding = #lwe.polynomial_evaluation_encoding<cleartext_start=30, cleartext_bitwidth=3>

#my_poly = #polynomial.polynomial<1 + x**1024>
#ring = #polynomial.ring<cmod=463187969, ideal=#my_poly>

#params = #lwe.rlwe_params<dimension=3, ring=#ring>
#params1 = #lwe.rlwe_params<dimension=2, ring=#ring>

!ct = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params>
!ct1 = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params1>

func.func @test_input_output_type_match(%input: !ct1) {
%offset = arith.constant 4 : index
// expected-error@+1 {{failed to verify that all of {x, output} have same type}}
%out = bgv.rotate (%input, %offset) : (!ct1, index) -> !ct
return
}
31 changes: 31 additions & 0 deletions tests/secret_to_bgv/hamming_distance_1024.mlir
@@ -0,0 +1,31 @@
// RUN: heir-opt --secretize=entry-function=hamming --wrap-generic \
// RUN: --canonicalize --cse --heir-simd-vectorizer \
// RUN: --secret-distribute-generic --secret-to-bgv \
// RUN: %s | FileCheck %s

// CHECK-LABEL: @hamming
// CHECK: bgv.sub
// CHECK-NEXT: bgv.mul
// CHECK-NEXT: bgv.relinearize

// TODO(#521): After rotate-and-reduce works, only check for 10 bg.rotate
// CHECK-COUNT-1023: bgv.rotate
// CHECK: bgv.extract
// CHECK-NEXT: return

func.func @hamming(%arg0: tensor<1024xi16>, %arg1: tensor<1024xi16> {secret.secret}) -> i16 {
%c0 = arith.constant 0 : index
%c0_si16 = arith.constant 0 : i16
%0 = affine.for %arg2 = 0 to 1024 iter_args(%arg6 = %c0_si16) -> i16 {
%1 = tensor.extract %arg0[%arg2] : tensor<1024xi16>
%2 = tensor.extract %arg1[%arg2] : tensor<1024xi16>
%3 = arith.subi %1, %2 : i16
%4 = tensor.extract %arg0[%arg2] : tensor<1024xi16>
%5 = tensor.extract %arg1[%arg2] : tensor<1024xi16>
%6 = arith.subi %4, %5 : i16
%7 = arith.muli %3, %6 : i16
%8 = arith.addi %arg6, %7 : i16
affine.yield %8 : i16
}
return %0 : i16
}
11 changes: 1 addition & 10 deletions tests/secret_to_bgv/invalid.mlir
Expand Up @@ -2,16 +2,7 @@

// Tests invalid secret types

// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}}
module {
func.func @test_not_tensor(%arg0 : !secret.secret<i1>) -> (!secret.secret<i1>) {
return %arg0 : !secret.secret<i1>
}
}

// -----

// expected-error@below {{expected secret types to be tensors with dimension matching ring parameter}}
// expected-error@below {{expected batched secret types to be tensors with dimension matching ring parameter}}
module {
func.func @test_invalid_dimension(%arg0 : !secret.secret<tensor<1000xi1>>) -> (!secret.secret<tensor<1000xi1>>) {
return %arg0 : !secret.secret<tensor<1000xi1>>
Expand Down

0 comments on commit cdb0e02

Please sign in to comment.