Skip to content

Commit

Permalink
bgv: use LWE dialect's RLWE ciphertext type in the BGV dialect
Browse files Browse the repository at this point in the history
Fixes #487

PiperOrigin-RevId: 614686524
  • Loading branch information
asraa authored and Copybara-Service committed Mar 11, 2024
1 parent 27875c3 commit f7c5ffd
Show file tree
Hide file tree
Showing 32 changed files with 222 additions and 341 deletions.
12 changes: 0 additions & 12 deletions include/Dialect/BGV/IR/BGVAttributes.h

This file was deleted.

17 changes: 0 additions & 17 deletions include/Dialect/BGV/IR/BGVAttributes.td

This file was deleted.

3 changes: 0 additions & 3 deletions include/Dialect/BGV/IR/BGVDialect.td
Expand Up @@ -14,9 +14,6 @@ def BGV_Dialect : Dialect {
}];

let cppNamespace = "::mlir::heir::bgv";

let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
}

#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVDIALECT_TD_
2 changes: 0 additions & 2 deletions include/Dialect/BGV/IR/BGVOps.h
@@ -1,10 +1,8 @@
#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_H_
#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_H_

#include "include/Dialect/BGV/IR/BGVAttributes.h"
#include "include/Dialect/BGV/IR/BGVDialect.h"
#include "include/Dialect/BGV/IR/BGVTraits.h"
#include "include/Dialect/BGV/IR/BGVTypes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
Expand Down
54 changes: 27 additions & 27 deletions include/Dialect/BGV/IR/BGVOps.td
Expand Up @@ -2,11 +2,13 @@
#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_TD_

include "BGVDialect.td"
include "BGVTypes.td"

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

include "include/Dialect/LWE/IR/LWETypes.td"
include "include/Dialect/Polynomial/IR/PolynomialAttributes.td"

def SameOperandsAndResultRings: NativeOpTrait<"SameOperandsAndResultRings"> {
let cppNamespace = "::mlir::heir::bgv";
}
Expand All @@ -15,7 +17,7 @@ class BGV_Op<string mnemonic, list<Trait> traits = []> :
Op<BGV_Dialect, mnemonic, traits> {

let assemblyFormat = [{
`(` operands `)` attr-dict `:` `(` type(operands) `)` `->` type(results)
`(` operands `)` attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results))
}];
let cppNamespace = "::mlir::heir::bgv";
}
Expand All @@ -26,45 +28,45 @@ def BGV_AddOp : BGV_Op<"add", [Commutative, SameOperandsAndResultType]> {
let summary = "Addition operation between ciphertexts.";

let arguments = (ins
Ciphertext:$x,
Ciphertext:$y
RLWECiphertext:$x,
RLWECiphertext:$y
);

let results = (outs
Ciphertext:$output
RLWECiphertext:$output
);

let assemblyFormat = "`(` operands `)` attr-dict `:` type($output)" ;
let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ;
}

def BGV_SubOp : BGV_Op<"sub", [SameOperandsAndResultType]> {
let summary = "Subtraction operation between ciphertexts.";

let arguments = (ins
Ciphertext:$x,
Ciphertext:$y
RLWECiphertext:$x,
RLWECiphertext:$y
);

let results = (outs
Ciphertext:$output
RLWECiphertext:$output
);

let assemblyFormat = "`(` operands `)` attr-dict `:` type($output)" ;
let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ;
}

def BGV_MulOp : BGV_Op<"mul", [Commutative, SameOperandsAndResultRings, SameTypeOperands]> {
let summary = "Multiplication operation between ciphertexts.";

let arguments = (ins
Ciphertext:$x,
Ciphertext:$y
RLWECiphertext:$x,
RLWECiphertext:$y
);

let results = (outs
Ciphertext:$output
RLWECiphertext:$output
);

let assemblyFormat = "`(` operands `)` attr-dict `:` type($x) `->` type($output)" ;
let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($x)) `->` qualified(type($output))" ;

let hasVerifier = 1;
}
Expand All @@ -73,12 +75,12 @@ def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> {
let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism.";

let arguments = (ins
Ciphertext:$x,
RLWECiphertext:$x,
I64Attr:$offset
);

let results = (outs
Ciphertext:$output
RLWECiphertext:$output
);

let hasVerifier = 1;
Expand All @@ -88,14 +90,14 @@ def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> {
let summary = "Negate the coefficients of the ciphertext.";

let arguments = (ins
Ciphertext:$x
RLWECiphertext:$x
);

let results = (outs
Ciphertext:$output
RLWECiphertext:$output
);

let assemblyFormat = "`(` operands `)` attr-dict `:` type($output)" ;
let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ;
}

def BGV_Relinearize : BGV_Op<"relinearize", [SameOperandsAndResultRings]> {
Expand All @@ -112,30 +114,28 @@ def BGV_Relinearize : BGV_Op<"relinearize", [SameOperandsAndResultRings]> {
}];

let arguments = (ins
Ciphertext:$x,
RLWECiphertext:$x,
DenseI32ArrayAttr:$from_basis,
DenseI32ArrayAttr:$to_basis
);

let results = (outs
Ciphertext:$output
RLWECiphertext:$output
);

let hasVerifier = 1;
}

def BGV_ModulusSwitch : BGV_Op<"modulus_switch", [SameOperandsAndResultRings]> {
// This must be validated against the BGV ring parameter.
def BGV_ModulusSwitch : BGV_Op<"modulus_switch"> {
let summary = "Lower the modulus level of the ciphertext.";

let arguments = (ins
Ciphertext:$x,
I64Attr:$from_level,
I64Attr:$to_level
RLWECiphertext:$x,
Ring_Attr:$to_ring
);

let results = (outs
Ciphertext:$output
RLWECiphertext:$output
);

let hasVerifier = 1;
Expand Down
20 changes: 11 additions & 9 deletions include/Dialect/BGV/IR/BGVTraits.h
@@ -1,43 +1,45 @@
#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_
#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_

#include "include/Dialect/BGV/IR/BGVAttributes.h"
#include "include/Dialect/BGV/IR/BGVTypes.h"
#include "include/Dialect/LWE/IR/LWETypes.h"
#include "include/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir::heir::bgv {

// TODO(#212): Move to LWE dialect/namespace
// Trait that ensures that all operands and results ciphertext have the same set
// of rings.
template <typename ConcreteType>
class SameOperandsAndResultRings
: public OpTrait::TraitBase<ConcreteType, SameOperandsAndResultRings> {
public:
static LogicalResult verifyTrait(Operation *op) {
BGVRingsAttr rings = nullptr;
polynomial::RingAttr rings = nullptr;
for (auto rTy : op->getResultTypes()) {
auto ct = dyn_cast<CiphertextType>(rTy);
auto ct = dyn_cast<lwe::RLWECiphertextType>(rTy);
if (!ct) continue;
if (rings == nullptr) {
rings = ct.getRings();
rings = ct.getRlweParams().getRing();
continue;
}
if (rings != ct.getRings()) {
if (rings != ct.getRlweParams().getRing()) {
return op->emitOpError()
<< "requires all operands and results to have the same rings";
}
}

for (auto oTy : op->getOperandTypes()) {
auto ct = dyn_cast<CiphertextType>(oTy);
auto ct = dyn_cast<lwe::RLWECiphertextType>(oTy);
if (!ct) continue; // Check only ciphertexts

if (rings == nullptr) {
rings = ct.getRings();
rings = ct.getRlweParams().getRing();
continue;
}

if (rings != ct.getRings()) {
if (rings != ct.getRlweParams().getRing()) {
return op->emitOpError()
<< "requires all operands and results to have the same rings";
}
Expand Down
13 changes: 0 additions & 13 deletions include/Dialect/BGV/IR/BGVTypes.h

This file was deleted.

47 changes: 0 additions & 47 deletions include/Dialect/BGV/IR/BGVTypes.td

This file was deleted.

0 comments on commit f7c5ffd

Please sign in to comment.