From 5bd3bdc65e22262d03ce8f2a27d2f3a97b8fb68e Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Fri, 29 Mar 2024 08:48:44 -0700 Subject: [PATCH] review comments --- .../Dialect/Polynomial/IR/PolynomialOps.td | 31 +++++++++++++------ lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 11 +++---- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/include/Dialect/Polynomial/IR/PolynomialOps.td b/include/Dialect/Polynomial/IR/PolynomialOps.td index 9c6bd6f6f..f97db505b 100644 --- a/include/Dialect/Polynomial/IR/PolynomialOps.td +++ b/include/Dialect/Polynomial/IR/PolynomialOps.td @@ -137,12 +137,18 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> { def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> { let summary = "Computes point-value tensor representation of a polynomial."; let description = [{ - `polynomial.ntt` creates a tensor value containing the point-value - representation of the input polynomial. The polynomial's RingAttr is - embedded as the encoding attribute. - - The output tensor has shape equal to the degree of the ring's ideal - generator polynomial, including zeroes. + `polynomial.ntt` computes the forward integer Number Theoretic Transform + (NTT) on the input polynomial. It returns a tensor containing a point-value + representation of the input polynomial. The output tensor has shape equal to + the degree of the ring's ideal generation polynomial. The polynomial's + RingAttr is embedded as the encoding attribute of the output tensor. + + Given an input polynomial $F(x)$ (over a ring with degree $n$) and a + primitive $n$-th root of unity $\omega_n$, the output is the list of $n$ + evaluations + + $f_k = F(\omega_n^k) ; k \in [0, n)$ + The choice of primitive root is determined by subsequent lowerings. }]; let arguments = (ins Polynomial:$input); @@ -154,11 +160,16 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> { } def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> { - let summary = "Computes the polynomial of a point-value representation tensor"; + let summary = "Computes the reverse integer Number Theoretic Transform (NTT)."; let description = [{ - `polynomial.intt` creates a polynomial value from the input tensor - containing the point-value representation. The ring of the polynomial is - taken from the required encoding attribute of the tensor. + `polynomial.intt` computes the reverse integer Number Theoretic Transform + (INTT) on the input tensor. This is the inverse operation of the + `polynomial.ntt` operation. + + The input tensor is interpreted as a point-value representation of the + output polynomial at powers of a primitive $n$-th root of unity (see + `polynomial.ntt`). The ring of the polynomial is taken from the required + encoding attribute of the tensor. }]; let arguments = (ins RankedTensorOf<[AnyInteger]>:$input); diff --git a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index a514fe5ae..a124499d9 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -83,18 +83,17 @@ LogicalResult MonomialMulOp::verify() { "must be of the form (x^n - 1) for some n"; } -template -static LogicalResult verifyNTTOp(Op *op, RingAttr ring, +static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, RankedTensorType tensorType) { auto encoding = tensorType.getEncoding(); if (!encoding) { return op->emitOpError() - << "a ring encoding was not provided to the tensor output"; + << "a ring encoding was not provided to the tensor."; } auto encodedRing = dyn_cast(encoding); if (!encodedRing) { return op->emitOpError() - << "the provided tensor output encoding is not a ring attribute"; + << "the provided tensor encoding is not a ring attribute."; } if (encodedRing != ring) { @@ -119,13 +118,13 @@ static LogicalResult verifyNTTOp(Op *op, RingAttr ring, LogicalResult NTTOp::verify() { auto ring = getInput().getType().getRing(); auto tensorType = getOutput().getType(); - return verifyNTTOp(this, ring, tensorType); + return verifyNTTOp(this->getOperation(), ring, tensorType); } LogicalResult INTTOp::verify() { auto tensorType = getInput().getType(); auto ring = getOutput().getType().getRing(); - return verifyNTTOp(this, ring, tensorType); + return verifyNTTOp(this->getOperation(), ring, tensorType); } void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,