diff --git a/include/Dialect/LWE/IR/LWEAttributes.td b/include/Dialect/LWE/IR/LWEAttributes.td index b1ce34307..8c33659dd 100644 --- a/include/Dialect/LWE/IR/LWEAttributes.td +++ b/include/Dialect/LWE/IR/LWEAttributes.td @@ -192,10 +192,8 @@ def RLWE_PolynomialEvaluationEncoding #lwe_encoding = #lwe.polynomial_evaluation_encoding %evals = arith.constant <[1, 2, 4, 5]> : tensor<4xi16> - // TODO(#182): fix docs - // Note no `intt` operation exists in poly yet. - %poly1 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding> - %poly2 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding> + %poly1 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding> + %poly2 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding> %rlwe_ciphertext = tensor.from_elements %poly1, %poly2 : tensor<2x!poly.poly<#ring, #eval_encoding>> ``` @@ -284,10 +282,8 @@ def RLWE_InverseCanonicalEmbeddingEncoding #lwe_encoding = #lwe.polynomial_evaluation_encoding %evals = arith.constant <[1, 2, 4, 5]> : tensor<4xi16> - // TODO(#182): fix docs - // Note no `intt` operation exists in poly yet. - %poly1 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding> - %poly2 = poly.intt %evals : tensor<4xi16> -> !poly.poly<#ring, #eval_encoding> + %poly1 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding> + %poly2 = poly.intt %evals : tensor<4xi16, #ring> -> !poly.poly<#ring, #eval_encoding> %rlwe_ciphertext = tensor.from_elements %poly1, %poly2 : tensor<2x!poly.poly<#ring, #eval_encoding>> ``` diff --git a/include/Dialect/Polynomial/IR/PolynomialOps.td b/include/Dialect/Polynomial/IR/PolynomialOps.td index 0feaa7083..f97db505b 100644 --- a/include/Dialect/Polynomial/IR/PolynomialOps.td +++ b/include/Dialect/Polynomial/IR/PolynomialOps.td @@ -134,4 +134,50 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> { let assemblyFormat = "$input attr-dict `:` qualified(type($output))"; } +def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> { + let summary = "Computes point-value tensor representation of a polynomial."; + let description = [{ + `polynomial.ntt` computes the forward integer Number Theoretic Transform + (NTT) on the input polynomial. It returns a tensor containing a point-value + representation of the input polynomial. The output tensor has shape equal to + the degree of the ring's ideal generation polynomial. The polynomial's + RingAttr is embedded as the encoding attribute of the output tensor. + + Given an input polynomial $F(x)$ (over a ring with degree $n$) and a + primitive $n$-th root of unity $\omega_n$, the output is the list of $n$ + evaluations + + $f_k = F(\omega_n^k) ; k \in [0, n)$ + The choice of primitive root is determined by subsequent lowerings. + }]; + + let arguments = (ins Polynomial:$input); + + let results = (outs RankedTensorOf<[AnyInteger]>:$output); + let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; + + let hasVerifier = 1; +} + +def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> { + let summary = "Computes the reverse integer Number Theoretic Transform (NTT)."; + let description = [{ + `polynomial.intt` computes the reverse integer Number Theoretic Transform + (INTT) on the input tensor. This is the inverse operation of the + `polynomial.ntt` operation. + + The input tensor is interpreted as a point-value representation of the + output polynomial at powers of a primitive $n$-th root of unity (see + `polynomial.ntt`). The ring of the polynomial is taken from the required + encoding attribute of the tensor. + }]; + + let arguments = (ins RankedTensorOf<[AnyInteger]>:$input); + + let results = (outs Polynomial:$output); + let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; + + let hasVerifier = 1; +} + #endif // INCLUDE_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_TD_ diff --git a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 61c984f1b..a124499d9 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -83,6 +83,50 @@ LogicalResult MonomialMulOp::verify() { "must be of the form (x^n - 1) for some n"; } +static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, + RankedTensorType tensorType) { + auto encoding = tensorType.getEncoding(); + if (!encoding) { + return op->emitOpError() + << "a ring encoding was not provided to the tensor."; + } + auto encodedRing = dyn_cast(encoding); + if (!encodedRing) { + return op->emitOpError() + << "the provided tensor encoding is not a ring attribute."; + } + + if (encodedRing != ring) { + return op->emitOpError() + << "encoded ring type " << encodedRing + << " is not equivalent to the polynomial ring " << ring << "."; + } + + auto polyDegree = ring.getIdeal().getDegree(); + auto tensorShape = tensorType.getShape(); + bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree; + if (!compatible) { + return op->emitOpError() + << "tensor type " << tensorType + << " must be a tensor of shape [d] where d " + << "is exactly the degree of the ring's ideal " << ring; + } + + return success(); +} + +LogicalResult NTTOp::verify() { + auto ring = getInput().getType().getRing(); + auto tensorType = getOutput().getType(); + return verifyNTTOp(this->getOperation(), ring, tensorType); +} + +LogicalResult INTTOp::verify() { + auto tensorType = getInput().getType(); + auto ring = getOutput().getType().getRing(); + return verifyNTTOp(this->getOperation(), ring, tensorType); +} + void SubOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { populateWithGenerated(results); diff --git a/tests/polynomial/ops.mlir b/tests/polynomial/ops.mlir index 3527b77e5..7d20256f1 100644 --- a/tests/polynomial/ops.mlir +++ b/tests/polynomial/ops.mlir @@ -8,6 +8,11 @@ #my_poly_4 = #polynomial.polynomial #ring1 = #polynomial.ring #one_plus_x_squared = #polynomial.polynomial<1 + x**2> + +#ideal = #polynomial.polynomial<-1 + x**1024> +#ring = #polynomial.ring +!poly_ty = !polynomial.polynomial<#ring> + module { func.func @test_multiply() -> !polynomial.polynomial<#ring1> { %c0 = arith.constant 0 : index @@ -68,4 +73,14 @@ module { %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1> return } + + func.func @test_ntt(%0 : !poly_ty) { + %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring> + return + } + + func.func @test_intt(%0 : tensor<1024xi32, #ring>) { + %1 = polynomial.intt %0 : tensor<1024xi32, #ring> -> !poly_ty + return + } }