Skip to content

Commit

Permalink
poly: add ntt and intt ops
Browse files Browse the repository at this point in the history
  - Intoduce ntt/intt ops that convert them to/from tensors with their
    ideal ring attributes encoded
  - These tensors allow for the polynomial to be in the corresponding
    point-value representation for more efficient polynomial
    multiplicaiton

Resolves #182
  • Loading branch information
inbelic committed Mar 29, 2024
1 parent 32353fe commit cb5c66a
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 8 deletions.
12 changes: 4 additions & 8 deletions include/Dialect/LWE/IR/LWEAttributes.td
Expand Up @@ -192,10 +192,8 @@ def RLWE_PolynomialEvaluationEncoding
#lwe_encoding = #lwe.polynomial_evaluation_encoding<cleartext_start=30, cleartext_bitwidth=3>

%evals = arith.constant <[1, 2, 4, 5]> : tensor<4xi16>
// TODO(#182): fix docs
// Note no `intt` operation exists in poly yet.
%poly1 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding>
%poly2 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding>
%poly1 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding>
%poly2 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding>
%rlwe_ciphertext = tensor.from_elements %poly1, %poly2 : tensor<2x!poly.poly<#ring, #eval_encoding>>
```

Expand Down Expand Up @@ -284,10 +282,8 @@ def RLWE_InverseCanonicalEmbeddingEncoding
#lwe_encoding = #lwe.polynomial_evaluation_encoding<cleartext_start=30, cleartext_bitwidth=3>

%evals = arith.constant <[1, 2, 4, 5]> : tensor<4xi16>
// TODO(#182): fix docs
// Note no `intt` operation exists in poly yet.
%poly1 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding>
%poly2 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding>
%poly1 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding>
%poly2 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding>
%rlwe_ciphertext = tensor.from_elements %poly1, %poly2 : tensor<2x!poly.poly<#ring, #eval_encoding>>
```

Expand Down
46 changes: 46 additions & 0 deletions include/Dialect/Polynomial/IR/PolynomialOps.td
Expand Up @@ -134,4 +134,50 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let assemblyFormat = "$input attr-dict `:` qualified(type($output))";
}

def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
let summary = "Computes point-value tensor representation of a polynomial.";
let description = [{
`polynomial.ntt` computes the forward integer Number Theoretic Transform
(NTT) on the input polynomial. It returns a tensor containing a point-value
representation of the input polynomial. The output tensor has shape equal to
the degree of the ring's ideal generation polynomial. The polynomial's
RingAttr is embedded as the encoding attribute of the output tensor.

Given an input polynomial $F(x)$ (over a ring with degree $n$) and a
primitive $n$-th root of unity $\omega_n$, the output is the list of $n$
evaluations

$f_k = F(\omega_n^k) ; k \in [0, n)$
The choice of primitive root is determined by subsequent lowerings.
}];

let arguments = (ins Polynomial:$input);

let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";

let hasVerifier = 1;
}

def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
let summary = "Computes the reverse integer Number Theoretic Transform (NTT).";
let description = [{
`polynomial.intt` computes the reverse integer Number Theoretic Transform
(INTT) on the input tensor. This is the inverse operation of the
`polynomial.ntt` operation.

The input tensor is interpreted as a point-value representation of the
output polynomial at powers of a primitive $n$-th root of unity (see
`polynomial.ntt`). The ring of the polynomial is taken from the required
encoding attribute of the tensor.
}];

let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);

let results = (outs Polynomial:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";

let hasVerifier = 1;
}

#endif // INCLUDE_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_TD_
44 changes: 44 additions & 0 deletions lib/Dialect/Polynomial/IR/PolynomialOps.cpp
Expand Up @@ -83,6 +83,50 @@ LogicalResult MonomialMulOp::verify() {
"must be of the form (x^n - 1) for some n";
}

static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
RankedTensorType tensorType) {
auto encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
<< "a ring encoding was not provided to the tensor.";
}
auto encodedRing = dyn_cast<RingAttr>(encoding);
if (!encodedRing) {
return op->emitOpError()
<< "the provided tensor encoding is not a ring attribute.";
}

if (encodedRing != ring) {
return op->emitOpError()
<< "encoded ring type " << encodedRing
<< " is not equivalent to the polynomial ring " << ring << ".";
}

auto polyDegree = ring.getIdeal().getDegree();
auto tensorShape = tensorType.getShape();
bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
if (!compatible) {
return op->emitOpError()
<< "tensor type " << tensorType
<< " must be a tensor of shape [d] where d "
<< "is exactly the degree of the ring's ideal " << ring;
}

return success();
}

LogicalResult NTTOp::verify() {
auto ring = getInput().getType().getRing();
auto tensorType = getOutput().getType();
return verifyNTTOp(this->getOperation(), ring, tensorType);
}

LogicalResult INTTOp::verify() {
auto tensorType = getInput().getType();
auto ring = getOutput().getType().getRing();
return verifyNTTOp(this->getOperation(), ring, tensorType);
}

void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
populateWithGenerated(results);
Expand Down
15 changes: 15 additions & 0 deletions tests/polynomial/ops.mlir
Expand Up @@ -8,6 +8,11 @@
#my_poly_4 = #polynomial.polynomial<t**3 + 4t + 2>
#ring1 = #polynomial.ring<cmod=2837465, ideal=#my_poly>
#one_plus_x_squared = #polynomial.polynomial<1 + x**2>

#ideal = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<cmod=18, ideal=#ideal>
!poly_ty = !polynomial.polynomial<#ring>

module {
func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -68,4 +73,14 @@ module {
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
return
}

func.func @test_ntt(%0 : !poly_ty) {
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
return
}

func.func @test_intt(%0 : tensor<1024xi32, #ring>) {
%1 = polynomial.intt %0 : tensor<1024xi32, #ring> -> !poly_ty
return
}
}

0 comments on commit cb5c66a

Please sign in to comment.