Skip to content

Commit

Permalink
bgv: add BGV ciphertext-plaintext ops and cleanup traits
Browse files Browse the repository at this point in the history
Fixes #100
Fixes #212

PiperOrigin-RevId: 614740812
  • Loading branch information
asraa authored and Copybara-Service committed Mar 11, 2024
1 parent bf1f9ca commit a00de65
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 17 deletions.
2 changes: 1 addition & 1 deletion include/Dialect/BGV/IR/BGVOps.h
Expand Up @@ -2,7 +2,7 @@
#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_H_

#include "include/Dialect/BGV/IR/BGVDialect.h"
#include "include/Dialect/BGV/IR/BGVTraits.h"
#include "include/Dialect/LWE/IR/LWETraits.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
Expand Down
39 changes: 34 additions & 5 deletions include/Dialect/BGV/IR/BGVOps.td
Expand Up @@ -10,19 +10,36 @@ include "include/Dialect/LWE/IR/LWETypes.td"
include "include/Dialect/Polynomial/IR/PolynomialAttributes.td"

def SameOperandsAndResultRings: NativeOpTrait<"SameOperandsAndResultRings"> {
let cppNamespace = "::mlir::heir::bgv";
let cppNamespace = "::mlir::heir::lwe";
}

class BGV_Op<string mnemonic, list<Trait> traits = []> :
Op<BGV_Dialect, mnemonic, traits> {

let cppNamespace = "::mlir::heir::bgv";

let assemblyFormat = [{
`(` operands `)` attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results))
}];
let cppNamespace = "::mlir::heir::bgv";
}

// TODO(#100): Add plaintext-ciphertext operations.
class BGV_CiphertextPlaintextOp<string mnemonic, list<Trait> traits =
[AllTypesMatch<["x", "output"]>,
TypesMatchWith<"type of 'y' matches encoding type of 'x'",
"output", "y",
"lwe::RLWEPlaintextType::get($_ctxt, ::llvm::cast<lwe::RLWECiphertextType>($_self).getEncoding())">]> :
BGV_Op<mnemonic, traits> {
let arguments = (ins
RLWECiphertext:$x,
RLWEPlaintext:$y
);

let results = (outs
RLWECiphertext:$output
);

let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ;
}

def BGV_AddOp : BGV_Op<"add", [Commutative, SameOperandsAndResultType]> {
let summary = "Addition operation between ciphertexts.";
Expand All @@ -39,6 +56,10 @@ def BGV_AddOp : BGV_Op<"add", [Commutative, SameOperandsAndResultType]> {
let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ;
}

def BGV_AddPlainOp : BGV_CiphertextPlaintextOp<"add_plain"> {
let summary = "Addition operation between ciphertext-plaintext.";
}

def BGV_SubOp : BGV_Op<"sub", [SameOperandsAndResultType]> {
let summary = "Subtraction operation between ciphertexts.";

Expand All @@ -54,7 +75,11 @@ def BGV_SubOp : BGV_Op<"sub", [SameOperandsAndResultType]> {
let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ;
}

def BGV_MulOp : BGV_Op<"mul", [Commutative, SameOperandsAndResultRings, SameTypeOperands]> {
def BGV_SubPlainOp : BGV_CiphertextPlaintextOp<"sub_plain"> {
let summary = "Subtraction operation between ciphertext-plaintext.";
}

def BGV_MulOp : BGV_Op<"mul", [Commutative, SameOperandsAndResultRings, SameTypeOperands, InferTypeOpAdaptor]> {
let summary = "Multiplication operation between ciphertexts.";

let arguments = (ins
Expand All @@ -71,6 +96,10 @@ def BGV_MulOp : BGV_Op<"mul", [Commutative, SameOperandsAndResultRings, SameType
let hasVerifier = 1;
}

def BGV_MulPlainOp : BGV_CiphertextPlaintextOp<"mul_plain"> {
let summary = "Multiplication operation between ciphertext-plaintext.";
}

def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism.";

Expand Down Expand Up @@ -100,7 +129,7 @@ def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> {
let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ;
}

def BGV_Relinearize : BGV_Op<"relinearize", [SameOperandsAndResultRings]> {
def BGV_Relinearize : BGV_Op<"relinearize", [SameOperandsAndResultRings, InferTypeOpAdaptor]> {
let summary = "Relinearize the ciphertext.";

let description = [{
Expand Down
1 change: 0 additions & 1 deletion include/Dialect/BGV/IR/BUILD
Expand Up @@ -11,7 +11,6 @@ exports_files(
[
"BGVDialect.h",
"BGVOps.h",
"BGVTraits.h",
],
)

Expand Down
1 change: 1 addition & 0 deletions include/Dialect/LWE/IR/BUILD
Expand Up @@ -13,6 +13,7 @@ exports_files(
"LWEAttributes.h",
"LWETypes.h",
"LWEOps.h",
"LWETraits.h",
],
)

Expand Down
@@ -1,14 +1,15 @@
#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_
#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_
#ifndef HEIR_INCLUDE_DIALECT_LWE_IR_LWETRAITS_H_
#define HEIR_INCLUDE_DIALECT_LWE_IR_LWETRAITS_H_

#include "include/Dialect/LWE/IR/LWETypes.h"
#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

namespace mlir::heir::bgv {
namespace mlir::heir::lwe {

// TODO(#212): Move to LWE dialect/namespace
// Trait that ensures that all operands and results ciphertext have the same set
// of rings.
template <typename ConcreteType>
Expand Down Expand Up @@ -48,6 +49,6 @@ class SameOperandsAndResultRings
}
};

} // namespace mlir::heir::bgv
} // namespace mlir::heir::lwe

#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_
#endif // HEIR_INCLUDE_DIALECT_LWE_IR_LWETRAITS_H_
31 changes: 30 additions & 1 deletion lib/Dialect/BGV/IR/BGVDialect.cpp
@@ -1,9 +1,15 @@
#include "include/Dialect/BGV/IR/BGVDialect.h"

#include <optional>

#include "include/Dialect/BGV/IR/BGVOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "include/Dialect/LWE/IR/LWEAttributes.h"
#include "include/Dialect/LWE/IR/LWETypes.h"
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

// Generated definitions
#include "include/Dialect/BGV/IR/BGVDialect.cpp.inc"
Expand Down Expand Up @@ -85,6 +91,29 @@ LogicalResult ModulusSwitch::verify() {
return success();
}

LogicalResult MulOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, MulOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto x = cast<lwe::RLWECiphertextType>(adaptor.getX().getType());
auto y = cast<lwe::RLWECiphertextType>(adaptor.getY().getType());
auto newDim =
x.getRlweParams().getDimension() + y.getRlweParams().getDimension() - 1;
inferredReturnTypes.push_back(lwe::RLWECiphertextType::get(
ctx, x.getEncoding(),
lwe::RLWEParamsAttr::get(ctx, newDim, x.getRlweParams().getRing())));
return success();
}

LogicalResult Relinearize::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, Relinearize::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto x = cast<lwe::RLWECiphertextType>(adaptor.getX().getType());
inferredReturnTypes.push_back(lwe::RLWECiphertextType::get(
ctx, x.getEncoding(),
lwe::RLWEParamsAttr::get(ctx, 2, x.getRlweParams().getRing())));
return success();
}

} // namespace bgv
} // namespace heir
} // namespace mlir
1 change: 0 additions & 1 deletion lib/Dialect/BGV/IR/BUILD
Expand Up @@ -13,7 +13,6 @@ cc_library(
hdrs = [
"@heir//include/Dialect/BGV/IR:BGVDialect.h",
"@heir//include/Dialect/BGV/IR:BGVOps.h",
"@heir//include/Dialect/BGV/IR:BGVTraits.h",
],
deps = [
"@heir//include/Dialect/BGV/IR:dialect_inc_gen",
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/LWE/IR/BUILD
Expand Up @@ -12,6 +12,7 @@ cc_library(
"@heir//include/Dialect/LWE/IR:LWEAttributes.h",
"@heir//include/Dialect/LWE/IR:LWEDialect.h",
"@heir//include/Dialect/LWE/IR:LWEOps.h",
"@heir//include/Dialect/LWE/IR:LWETraits.h",
"@heir//include/Dialect/LWE/IR:LWETypes.h",
],
deps = [
Expand All @@ -21,6 +22,7 @@ cc_library(
"@heir//include/Dialect/LWE/IR:types_inc_gen",
"@heir//include/Dialect/Polynomial/IR:attributes_inc_gen",
"@heir//lib/Dialect/Polynomial/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:Polynomial",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
Expand Down
10 changes: 10 additions & 0 deletions tests/bgv/ops.mlir
Expand Up @@ -14,6 +14,8 @@
#params1 = #lwe.rlwe_params<dimension=3, ring=#ring1>
#params2 = #lwe.rlwe_params<dimension=2, ring=#ring2>

!pt = !lwe.rlwe_plaintext<encoding=#encoding>

!ct = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params>
!ct1 = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params1>
!ct2 = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params2>
Expand All @@ -31,4 +33,12 @@ module {
// CHECK: rlwe_params = <dimension = 3, ring = <cmod=161729713, ideal=#polynomial.polynomial<1 + x**1024>>>
return %arg0 : !ct
}

func.func @test_ciphertext_plaintext(%arg0: !pt, %arg1: !pt, %arg2: !pt, %arg3: !ct) -> !ct {
%add = bgv.add_plain(%arg3, %arg0) : !ct
%sub = bgv.sub_plain(%add, %arg1) : !ct
%mul = bgv.mul_plain(%sub, %arg2) : !ct
// CHECK: rlwe_params = <dimension = 3, ring = <cmod=161729713, ideal=#polynomial.polynomial<1 + x**1024>>>
return %mul : !ct
}
}

0 comments on commit a00de65

Please sign in to comment.