Skip to content

Commit

Permalink
PR Review
Browse files Browse the repository at this point in the history
  • Loading branch information
Wouter Legiest committed Mar 7, 2024
1 parent c8c0d50 commit df3856e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 72 deletions.
6 changes: 3 additions & 3 deletions lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp
Expand Up @@ -84,9 +84,9 @@ int widthFromEncodingAttr(Attribute encoding) {
});
}

class PassTypeConverter : public TypeConverter {
class CGGIToTfheRustTypeConverter : public TypeConverter {
public:
PassTypeConverter(MLIRContext *ctx) {
CGGIToTfheRustTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion([ctx](lwe::LWECiphertextType type) -> Type {
int width = widthFromEncodingAttr(type.getEncoding());
Expand Down Expand Up @@ -402,7 +402,7 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase<CGGIToTfheRust> {
MLIRContext *context = &getContext();
auto *op = getOperation();

PassTypeConverter typeConverter(context);
CGGIToTfheRustTypeConverter typeConverter(context);
RewritePatternSet patterns(context);
ConversionTarget target(*context);
addStructuralConversionPatterns(typeConverter, patterns, target);
Expand Down
68 changes: 14 additions & 54 deletions lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp
Expand Up @@ -32,9 +32,9 @@ namespace mlir::heir {
#define GEN_PASS_DEF_CGGITOTFHERUSTBOOL
#include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h.inc"

class BoolPassTypeConverter : public TypeConverter {
class CGGIToTfheRustBoolTypeConverter : public TypeConverter {
public:
BoolPassTypeConverter(MLIRContext *ctx) {
CGGIToTfheRustBoolTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion([ctx](lwe::LWECiphertextType type) -> Type {
return tfhe_rust_bool::EncryptedBoolType::get(ctx);
Expand Down Expand Up @@ -131,68 +131,28 @@ struct AddBoolServerKeyArg : public OpConversionPattern<func::FuncOp> {
}
};

struct ConvertBoolAndOp : public OpConversionPattern<cggi::AndOp> {
ConvertBoolAndOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::AndOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::AndOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
FailureOr<Value> result = getContextualBoolServerKey(op);
if (failed(result)) return result;

Value serverKey = result.value();

rewriter.replaceOp(op, b.create<tfhe_rust_bool::AndOp>(
serverKey, adaptor.getLhs(), adaptor.getRhs()));
return success();
}
};

struct ConvertBoolOrOp : public OpConversionPattern<cggi::OrOp> {
ConvertBoolOrOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::OrOp>(context) {}

using OpConversionPattern::OpConversionPattern;
template <typename BinOp, typename TfheRustBoolBinOp>
struct ConvertBinOp : public OpConversionPattern<BinOp> {
using OpConversionPattern<BinOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::OrOp op, OpAdaptor adaptor,
BinOp op, typename BinOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
FailureOr<Value> result = getContextualBoolServerKey(op);
if (failed(result)) return result;

Value serverKey = result.value();

rewriter.replaceOp(op, b.create<tfhe_rust_bool::OrOp>(
rewriter.replaceOp(op, b.create<TfheRustBoolBinOp>(
serverKey, adaptor.getLhs(), adaptor.getRhs()));
return success();
}
};

struct ConvertBoolXorOp : public OpConversionPattern<cggi::XorOp> {
ConvertBoolXorOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::XorOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::XorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
FailureOr<Value> result = getContextualBoolServerKey(op);
if (failed(result)) return result;

Value serverKey = result.value();

rewriter.replaceOp(op, b.create<tfhe_rust_bool::XorOp>(
serverKey, adaptor.getLhs(), adaptor.getRhs()));
return success();
}
};
using ConvertBoolAndOp = ConvertBinOp<cggi::AndOp, tfhe_rust_bool::AndOp>;
using ConvertBoolOrOp = ConvertBinOp<cggi::OrOp, tfhe_rust_bool::OrOp>;
using ConvertBoolXorOp = ConvertBinOp<cggi::XorOp, tfhe_rust_bool::XorOp>;

struct ConvertBoolNotOp : public OpConversionPattern<cggi::NotOp> {
ConvertBoolNotOp(mlir::MLIRContext *context)
Expand All @@ -218,7 +178,7 @@ struct ConvertBoolNotOp : public OpConversionPattern<cggi::NotOp> {
struct ConvertBoolTrivialEncryptOp
: public OpConversionPattern<lwe::TrivialEncryptOp> {
ConvertBoolTrivialEncryptOp(mlir::MLIRContext *context)
: OpConversionPattern<lwe::TrivialEncryptOp>(context, /*benefit=*/2) {}
: OpConversionPattern<lwe::TrivialEncryptOp>(context, /*benefit=*/1) {}

using OpConversionPattern::OpConversionPattern;

Expand All @@ -236,7 +196,7 @@ struct ConvertBoolTrivialEncryptOp
<< op.getInput().getDefiningOp()->getName();
}
auto outputType = tfhe_rust_bool::EncryptedBoolType::get(getContext());
;

auto createTrivialOp = rewriter.create<tfhe_rust_bool::CreateTrivialOp>(
op.getLoc(), outputType, serverKey, encodeOp.getPlaintext());
rewriter.replaceOp(op, createTrivialOp);
Expand Down Expand Up @@ -264,7 +224,7 @@ class CGGIToTfheRustBool
MLIRContext *context = &getContext();
auto *op = getOperation();

BoolPassTypeConverter typeConverter(context);
CGGIToTfheRustBoolTypeConverter typeConverter(context);
RewritePatternSet patterns(context);
ConversionTarget target(*context);
addStructuralConversionPatterns(typeConverter, patterns, target);
Expand Down Expand Up @@ -307,7 +267,7 @@ class CGGIToTfheRustBool
GenericOpPattern<memref::LoadOp>, GenericOpPattern<memref::SubViewOp>,
GenericOpPattern<memref::CopyOp>,
GenericOpPattern<tensor::FromElementsOp>,
GenericOpPattern<tensor::ExtractOp>>(typeConverter, context);
GenericOpPattern<tensor::ExtractOp> >(typeConverter, context);

if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
return signalPassFailure();
Expand Down
15 changes: 0 additions & 15 deletions tests/tfhe_rust_bool/ops.mlir
Expand Up @@ -29,21 +29,6 @@ module {

// CHECK-LABEL: func @test_packed_and
func.func @test_packed_and(%bsks : !bsks, %lhs : tensor<4x!eb>, %rhs : tensor<4x!eb>) {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%4 = arith.constant 4 : index

%c0 = arith.constant 0 : i1
%c1 = arith.constant 1 : i1

scf.for %i = %0 to %4 step %1 {
%tmp1 = tfhe_rust_bool.create_trivial %bsks, %c0 : (!bsks, i1) -> !eb
%tmp2 = tfhe_rust_bool.create_trivial %bsks, %c1 : (!bsks, i1) -> !eb

tensor.insert %tmp1 into %lhs[%i] : tensor<4x!eb>
tensor.insert %tmp2 into %rhs[%i] : tensor<4x!eb>
}

%out = tfhe_rust_bool.and_packed %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb>
return
}
Expand Down

0 comments on commit df3856e

Please sign in to comment.