Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][polynomial] implement add for polynomial data structure #92169

Merged
merged 2 commits into from
May 17, 2024

Conversation

j2kun
Copy link
Contributor

@j2kun j2kun commented May 14, 2024

A change extracted from #91655, where I'm still trying to get the attributes working for elementwise constant folding of polynomial ops. This piece is self-contained.

  • use CRTP for base classes
  • Add unit test

- use CRTP for base classes
- Add unit test
@j2kun j2kun requested a review from ftynse May 14, 2024 20:09
@llvmbot llvmbot added the mlir label May 14, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 14, 2024

@llvm/pr-subscribers-mlir

Author: Jeremy Kun (j2kun)

Changes

A change extracted from #91655, where I'm still trying to get the attributes working for elementwise constant folding of polynomial ops. This piece is self-contained.

  • use CRTP for base classes
  • Add unit test

Full diff: https://github.com/llvm/llvm-project/pull/92169.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h (+66-19)
  • (modified) mlir/unittests/Dialect/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/Polynomial/CMakeLists.txt (+8)
  • (added) mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp (+44)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 7f44c29a98707..45823275ebb33 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -30,7 +30,7 @@ namespace polynomial {
 /// would want to specify 128-bit polynomials statically in the source code.
 constexpr unsigned apintBitWidth = 64;
 
-template <typename CoefficientType>
+template <class Derived, typename CoefficientType>
 class MonomialBase {
 public:
   MonomialBase(const CoefficientType &coeff, const APInt &expo)
@@ -55,12 +55,21 @@ class MonomialBase {
     return (exponent.ult(other.exponent));
   }
 
+  Derived add(const Derived &other) {
+    assert(exponent == other.exponent);
+    CoefficientType newCoeff = coefficient + other.coefficient;
+    Derived result;
+    result.setCoefficient(newCoeff);
+    result.setExponent(exponent);
+    return result;
+  }
+
   virtual bool isMonic() const = 0;
   virtual void
   coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
 
-  template <typename T>
-  friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
+  template <class D, typename T>
+  friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
 
 protected:
   CoefficientType coefficient;
@@ -69,7 +78,7 @@ class MonomialBase {
 
 /// A class representing a monomial of a single-variable polynomial with integer
 /// coefficients.
-class IntMonomial : public MonomialBase<APInt> {
+class IntMonomial : public MonomialBase<IntMonomial, APInt> {
 public:
   IntMonomial(int64_t coeff, uint64_t expo)
       : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
@@ -77,7 +86,7 @@ class IntMonomial : public MonomialBase<APInt> {
   IntMonomial()
       : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
 
-  ~IntMonomial() = default;
+  ~IntMonomial() override = default;
 
   bool isMonic() const override { return coefficient == 1; }
 
@@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase<APInt> {
 
 /// A class representing a monomial of a single-variable polynomial with integer
 /// coefficients.
-class FloatMonomial : public MonomialBase<APFloat> {
+class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
 public:
   FloatMonomial(double coeff, uint64_t expo)
       : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
 
   FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
 
-  ~FloatMonomial() = default;
+  ~FloatMonomial() override = default;
 
   bool isMonic() const override { return coefficient == APFloat(1.0); }
 
@@ -104,12 +113,12 @@ class FloatMonomial : public MonomialBase<APFloat> {
   }
 };
 
-template <typename Monomial>
+template <class Derived, typename Monomial>
 class PolynomialBase {
 public:
   PolynomialBase() = delete;
 
-  explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms) {};
+  explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
 
   explicit operator bool() const { return !terms.empty(); }
   bool operator==(const PolynomialBase &other) const {
@@ -149,6 +158,44 @@ class PolynomialBase {
     }
   }
 
+  Derived add(const Derived &other) {
+    SmallVector<Monomial> newTerms;
+    auto it1 = terms.begin();
+    auto it2 = other.terms.begin();
+    while (it1 != terms.end() || it2 != other.terms.end()) {
+      if (it1 == terms.end()) {
+        newTerms.emplace_back(*it2);
+        it2++;
+        continue;
+      }
+
+      if (it2 == other.terms.end()) {
+        newTerms.emplace_back(*it1);
+        it1++;
+        continue;
+      }
+
+      while (it1->getExponent().ult(it2->getExponent())) {
+        newTerms.emplace_back(*it1);
+        it1++;
+        if (it1 == terms.end())
+          break;
+      }
+
+      while (it2->getExponent().ult(it1->getExponent())) {
+        newTerms.emplace_back(*it2);
+        it2++;
+        if (it2 == terms.end())
+          break;
+      }
+
+      newTerms.emplace_back(it1->add(*it2));
+      it1++;
+      it2++;
+    }
+    return Derived(newTerms);
+  }
+
   // Prints polynomial to 'os'.
   void print(raw_ostream &os) const { print(os, " + ", "**"); }
 
@@ -168,8 +215,8 @@ class PolynomialBase {
 
   ArrayRef<Monomial> getTerms() const { return terms; }
 
-  template <typename T>
-  friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
+  template <class D, typename T>
+  friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
 
 private:
   // The monomial terms for this polynomial.
@@ -179,7 +226,7 @@ class PolynomialBase {
 /// A single-variable polynomial with integer coefficients.
 ///
 /// Eg: x^1024 + x + 1
-class IntPolynomial : public PolynomialBase<IntMonomial> {
+class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
 public:
   explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
 
@@ -196,7 +243,7 @@ class IntPolynomial : public PolynomialBase<IntMonomial> {
 /// A single-variable polynomial with double coefficients.
 ///
 /// Eg: 1.0 x^1024 + 3.5 x + 1e-05
-class FloatPolynomial : public PolynomialBase<FloatMonomial> {
+class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
 public:
   explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
       : PolynomialBase(terms) {}
@@ -212,20 +259,20 @@ class FloatPolynomial : public PolynomialBase<FloatMonomial> {
 };
 
 // Make Polynomials hashable.
-template <typename T>
-inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
   return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
 }
 
-template <typename T>
-inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
   return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
                             ::llvm::hash_value(arg.exponent));
 }
 
-template <typename T>
+template <class D, typename T>
 inline raw_ostream &operator<<(raw_ostream &os,
-                               const PolynomialBase<T> &polynomial) {
+                               const PolynomialBase<D, T> &polynomial) {
   polynomial.print(os);
   return os;
 }
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 13393569f36fe..90a75d5a46ad9 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(Index)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
 add_subdirectory(OpenACC)
+add_subdirectory(Polynomial)
 add_subdirectory(SCF)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/Polynomial/CMakeLists.txt b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
new file mode 100644
index 0000000000000..807deeca41c06
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIRPolynomialTests
+  PolynomialMathTest.cpp
+)
+target_link_libraries(MLIRPolynomialTests
+  PRIVATE
+  MLIRIR
+  MLIRPolynomialDialect
+)
diff --git a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
new file mode 100644
index 0000000000000..95906ad42588e
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
@@ -0,0 +1,44 @@
+//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
+  IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3});
+  IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7});
+  EXPECT_EQ(expected, x.add(y));
+}
+
+TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
+  IntMonomial term2t = IntMonomial(2, 1);
+  IntPolynomial x = IntPolynomial::fromMonomials({term2t}).value();
+  IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({2, 5, 4});
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}
+
+TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5});
+  FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8});
+  EXPECT_EQ(expected, x.add(y));
+}
+
+TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5});
+  FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5});
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

Copy link

github-actions bot commented May 14, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@j2kun j2kun merged commit e368675 into llvm:main May 17, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants