Skip to content

Commit

Permalink
poly: add ntt and intt ops
Browse files Browse the repository at this point in the history
  - Intoduce ntt/intt ops that convert them to/from tensors with their
    ideal ring attributes encoded
  - These tensors allow for the polynomial to be in the corresponding
    point-value representation for more efficient polynomial
    multiplicaiton

Resolves #182
  • Loading branch information
inbelic committed Mar 27, 2024
1 parent 32353fe commit 2e996c8
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 8 deletions.
12 changes: 4 additions & 8 deletions include/Dialect/LWE/IR/LWEAttributes.td
Expand Up @@ -192,10 +192,8 @@ def RLWE_PolynomialEvaluationEncoding
#lwe_encoding = #lwe.polynomial_evaluation_encoding<cleartext_start=30, cleartext_bitwidth=3>

%evals = arith.constant <[1, 2, 4, 5]> : tensor<4xi16>
// TODO(#182): fix docs
// Note no `intt` operation exists in poly yet.
%poly1 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding>
%poly2 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding>
%poly1 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding>
%poly2 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding>
%rlwe_ciphertext = tensor.from_elements %poly1, %poly2 : tensor<2x!poly.poly<#ring, #eval_encoding>>
```

Expand Down Expand Up @@ -284,10 +282,8 @@ def RLWE_InverseCanonicalEmbeddingEncoding
#lwe_encoding = #lwe.polynomial_evaluation_encoding<cleartext_start=30, cleartext_bitwidth=3>

%evals = arith.constant <[1, 2, 4, 5]> : tensor<4xi16>
// TODO(#182): fix docs
// Note no `intt` operation exists in poly yet.
%poly1 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding>
%poly2 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding>
%poly1 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding>
%poly2 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding>
%rlwe_ciphertext = tensor.from_elements %poly1, %poly2 : tensor<2x!poly.poly<#ring, #eval_encoding>>
```

Expand Down
35 changes: 35 additions & 0 deletions include/Dialect/Polynomial/IR/PolynomialOps.td
Expand Up @@ -134,4 +134,39 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let assemblyFormat = "$input attr-dict `:` qualified(type($output))";
}

def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
let summary = "Computes point-value tensor representation of a polynomial.";
let description = [{
`polynomial.ntt` creates a tensor value containing the point-value
representation of the input polynomial. The polynomial's RingAttr is
embedded as the encoding attribute.

The output tensor has shape equal to the degree of the ring's ideal
generator polynomial, including zeroes.
}];

let arguments = (ins Polynomial:$input);

let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";

let hasVerifier = 1;
}

def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
let summary = "Computes the polynomial of a point-value representation tensor";
let description = [{
`polynomial.intt` creates a polynomial value from the input tensor
containing the point-value representation. The ring of the polynomial is
taken from the required encoding attribute of the tensor.
}];

let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);

let results = (outs Polynomial:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";

let hasVerifier = 1;
}

#endif // INCLUDE_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_TD_
45 changes: 45 additions & 0 deletions lib/Dialect/Polynomial/IR/PolynomialOps.cpp
Expand Up @@ -83,6 +83,51 @@ LogicalResult MonomialMulOp::verify() {
"must be of the form (x^n - 1) for some n";
}

template <typename Op>
static LogicalResult verifyNTTOp(Op *op, RingAttr ring,
RankedTensorType tensorType) {
auto encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
<< "a ring encoding was not provided to the tensor output";
}
auto encodedRing = dyn_cast<RingAttr>(encoding);
if (!encodedRing) {
return op->emitOpError()
<< "the provided tensor output encoding is not a ring attribute";
}

if (encodedRing != ring) {
return op->emitOpError()
<< "encoded ring type " << encodedRing
<< " is not equivalent to the polynomial ring " << ring << ".";
}

auto polyDegree = ring.getIdeal().getDegree();
auto tensorShape = tensorType.getShape();
bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
if (!compatible) {
return op->emitOpError()
<< "tensor type " << tensorType
<< " must be a tensor of shape [d] where d "
<< "is exactly the degree of the ring's ideal " << ring;
}

return success();
}

LogicalResult NTTOp::verify() {
auto ring = getInput().getType().getRing();
auto tensorType = getOutput().getType();
return verifyNTTOp<NTTOp>(this, ring, tensorType);
}

LogicalResult INTTOp::verify() {
auto tensorType = getInput().getType();
auto ring = getOutput().getType().getRing();
return verifyNTTOp<INTTOp>(this, ring, tensorType);
}

void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
populateWithGenerated(results);
Expand Down
15 changes: 15 additions & 0 deletions tests/polynomial/ops.mlir
Expand Up @@ -8,6 +8,11 @@
#my_poly_4 = #polynomial.polynomial<t**3 + 4t + 2>
#ring1 = #polynomial.ring<cmod=2837465, ideal=#my_poly>
#one_plus_x_squared = #polynomial.polynomial<1 + x**2>

#ideal = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<cmod=18, ideal=#ideal>
!poly_ty = !polynomial.polynomial<#ring>

module {
func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -68,4 +73,14 @@ module {
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
return
}

func.func @test_ntt(%0 : !poly_ty) {
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
return
}

func.func @test_intt(%0 : tensor<1024xi32, #ring>) {
%1 = polynomial.intt %0 : tensor<1024xi32, #ring> -> !poly_ty
return
}
}

0 comments on commit 2e996c8

Please sign in to comment.