[libc] Refactor BigInt #86137

Merged
merged 19 commits into from Apr 4, 2024

Contributor

@gchatelet gchatelet commented Mar 21, 2024

This patch moves most of the multiprecision logic to the multiword
namespace and simplifies some logic in BigInt. It also fully
implements the mask and count functions and increases test coverage.

math_extras.h is also reworked to make it more concise.
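
As an illustration of the count functions mentioned above, here is a minimal sketch of a multiword `countl_zero`, assuming 32-bit words stored least significant first. The names `clz_word`, `countl_zero_words`, `SAMPLE` and `ALL_ZERO` are hypothetical and are not part of the patch; the loop only mirrors the idea of summing per-word counts and stopping at the first word that is not entirely zero.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper: number of leading zero bits in one 32-bit word.
// Returns 32 when the word is zero.
constexpr int clz_word(std::uint32_t w) {
  int count = 0;
  for (std::uint32_t mask = 0x80000000u; mask != 0 && (w & mask) == 0;
       mask >>= 1)
    ++count;
  return count;
}

// Hypothetical multiword version: words[0] is the least significant word,
// so we walk from the last word down to the first one.
template <std::size_t N>
constexpr int countl_zero_words(const std::uint32_t (&words)[N]) {
  int bit_count = 0;
  for (std::size_t i = 0; i < N; ++i) {
    const int word_count = clz_word(words[N - i - 1]);
    bit_count += word_count;
    if (word_count != 32) // this word contains a set bit: stop here.
      break;
  }
  return bit_count;
}

// Top word is zero (32 bits), middle word is 1 (31 bits): 63 in total.
constexpr std::uint32_t SAMPLE[3] = {0xFFu, 0x1u, 0x0u};
static_assert(countl_zero_words(SAMPLE) == 63, "unexpected count");

// An all-zero value reports its full bit width.
constexpr std::uint32_t ALL_ZERO[2] = {0x0u, 0x0u};
static_assert(countl_zero_words(ALL_ZERO) == 64, "unexpected count");
```

The trailing-zero and counting-ones variants would follow the same shape, iterating forward or backward and comparing against the full word width.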

github-actions bot commented Mar 21, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@gchatelet gchatelet force-pushed the add_multiprecision_facilities branch 3 times, most recently from 9521ad8 to 71117a9 Compare March 21, 2024 19:43
@gchatelet gchatelet changed the title [libc] Add a Multiprecision type to math_extras [libc] Cleanup BigInt Mar 21, 2024
@gchatelet gchatelet force-pushed the add_multiprecision_facilities branch from 71117a9 to 85dafaa Compare March 21, 2024 20:14
@gchatelet gchatelet changed the title [libc] Cleanup BigInt [libc] Refactor BigInt Mar 21, 2024
@gchatelet gchatelet force-pushed the add_multiprecision_facilities branch 5 times, most recently from ea454d9 to eef78be Compare March 22, 2024 09:04
@gchatelet gchatelet marked this pull request as ready for review March 22, 2024 09:09
@llvmbot llvmbot added libc bazel "Peripheral" support tier build system: utils/bazel labels Mar 22, 2024
@llvmbot
Collaborator

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-libc

Author: Guillaume Chatelet (gchatelet)

Changes

This patch moves most of the multiprecision logic to the multiword
namespace and simplifies some logic in BigInt. It also fully
implements the mask and count functions and increases test coverage.

math_extras.h is also reworked to make it more concise.


Patch is 72.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86137.diff

9 Files Affected:

  • (modified) libc/src/__support/FPUtil/dyadic_float.h (+3-3)
  • (modified) libc/src/__support/UInt.h (+511-545)
  • (modified) libc/src/__support/float_to_string.h (+4-3)
  • (modified) libc/src/__support/integer_literals.h (+9)
  • (modified) libc/src/__support/math_extras.h (+66-174)
  • (modified) libc/src/__support/number_pair.h (-11)
  • (modified) libc/test/src/__support/math_extras_test.cpp (+57)
  • (modified) libc/test/src/__support/uint_test.cpp (+186-1)
  • (modified) utils/bazel/llvm-project-overlay/libc/test/src/__support/BUILD.bazel (+1)
diff --git a/libc/src/__support/FPUtil/dyadic_float.h b/libc/src/__support/FPUtil/dyadic_float.h
index 73fd7381c3c838..e0c205f52383ba 100644
--- a/libc/src/__support/FPUtil/dyadic_float.h
+++ b/libc/src/__support/FPUtil/dyadic_float.h
@@ -58,9 +58,9 @@ template <size_t Bits> struct DyadicFloat {
   // significant bit.
   LIBC_INLINE constexpr DyadicFloat &normalize() {
     if (!mantissa.is_zero()) {
-      int shift_length = static_cast<int>(mantissa.clz());
+      int shift_length = cpp::countl_zero(mantissa);
       exponent -= shift_length;
-      mantissa.shift_left(static_cast<size_t>(shift_length));
+      mantissa <<= static_cast<size_t>(shift_length);
     }
     return *this;
   }
@@ -233,7 +233,7 @@ LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
     result.sign = a.sign;
     result.exponent = a.exponent;
     result.mantissa = a.mantissa;
-    if (result.mantissa.add(b.mantissa)) {
+    if (result.mantissa.add_overflow(b.mantissa)) {
       // Mantissa addition overflow.
       result.shift_right(1);
       result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] |=
diff --git a/libc/src/__support/UInt.h b/libc/src/__support/UInt.h
index df01e081e3c19e..96e346915721e7 100644
--- a/libc/src/__support/UInt.h
+++ b/libc/src/__support/UInt.h
@@ -17,7 +17,7 @@
 #include "src/__support/macros/attributes.h"   // LIBC_INLINE
 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
 #include "src/__support/macros/properties/types.h" // LIBC_TYPES_HAS_INT128, LIBC_TYPES_HAS_INT64
-#include "src/__support/math_extras.h" // SumCarry, DiffBorrow
+#include "src/__support/math_extras.h" // add_with_carry, sub_with_borrow
 #include "src/__support/number_pair.h"
 
 #include <stddef.h> // For size_t
@@ -25,64 +25,275 @@
 
 namespace LIBC_NAMESPACE {
 
-namespace internal {
-template <typename T> struct half_width;
+namespace multiword {
 
-template <> struct half_width<uint64_t> : cpp::type_identity<uint32_t> {};
-template <> struct half_width<uint32_t> : cpp::type_identity<uint16_t> {};
+// A type trait mapping unsigned integers to their half-width unsigned
+// counterparts.
+template <typename T> struct half_width;
 template <> struct half_width<uint16_t> : cpp::type_identity<uint8_t> {};
+template <> struct half_width<uint32_t> : cpp::type_identity<uint16_t> {};
+#ifdef LIBC_TYPES_HAS_INT64
+template <> struct half_width<uint64_t> : cpp::type_identity<uint32_t> {};
 #ifdef LIBC_TYPES_HAS_INT128
 template <> struct half_width<__uint128_t> : cpp::type_identity<uint64_t> {};
 #endif // LIBC_TYPES_HAS_INT128
-
+#endif // LIBC_TYPES_HAS_INT64
 template <typename T> using half_width_t = typename half_width<T>::type;
 
-template <typename T> constexpr NumberPair<T> full_mul(T a, T b) {
-  NumberPair<T> pa = split(a);
-  NumberPair<T> pb = split(b);
-  NumberPair<T> prod;
+// An array of two elements that can be used in multiword operations.
+template <typename T> struct Double final : cpp::array<T, 2> {
+  using UP = cpp::array<T, 2>;
+  using UP::UP;
+  LIBC_INLINE constexpr Double(T lo, T hi) : UP({lo, hi}) {}
+};
 
-  prod.lo = pa.lo * pb.lo;                    // exact
-  prod.hi = pa.hi * pb.hi;                    // exact
-  NumberPair<T> lo_hi = split(pa.lo * pb.hi); // exact
-  NumberPair<T> hi_lo = split(pa.hi * pb.lo); // exact
+// Converts an unsigned value into a Double<half_width_t<T>>.
+template <typename T> LIBC_INLINE constexpr auto split(T value) {
+  static_assert(cpp::is_unsigned_v<T>);
+  return cpp::bit_cast<Double<half_width_t<T>>>(value);
+}
 
-  constexpr size_t HALF_BIT_WIDTH = sizeof(T) * CHAR_BIT / 2;
+// The low part of a Double value.
+template <typename T> LIBC_INLINE constexpr T lo(const Double<T> &value) {
+  return value[0];
+}
+// The high part of a Double value.
+template <typename T> LIBC_INLINE constexpr T hi(const Double<T> &value) {
+  return value[1];
+}
+// The low part of an unsigned value.
+template <typename T> LIBC_INLINE constexpr half_width_t<T> lo(T value) {
+  return lo(split(value));
+}
+// The high part of an unsigned value.
+template <typename T> LIBC_INLINE constexpr half_width_t<T> hi(T value) {
+  return hi(split(value));
+}
 
-  auto r1 = add_with_carry(prod.lo, lo_hi.lo << HALF_BIT_WIDTH, T(0));
-  prod.lo = r1.sum;
-  prod.hi = add_with_carry(prod.hi, lo_hi.hi, r1.carry).sum;
+// Returns 'a' times 'b' in a Double<word>. Cannot overflow by definition.
+template <typename word>
+LIBC_INLINE constexpr Double<word> mul2(word a, word b) {
+  if constexpr (cpp::is_same_v<word, uint8_t>) {
+    return split<uint16_t>(uint16_t(a) * uint16_t(b));
+  } else if constexpr (cpp::is_same_v<word, uint16_t>) {
+    return split<uint32_t>(uint32_t(a) * uint32_t(b));
+  }
+#ifdef LIBC_TYPES_HAS_INT64
+  else if constexpr (cpp::is_same_v<word, uint32_t>) {
+    return split<uint64_t>(uint64_t(a) * uint64_t(b));
+  }
+#endif
+#ifdef LIBC_TYPES_HAS_INT128
+  else if constexpr (cpp::is_same_v<word, uint64_t>) {
+    return split<__uint128_t>(__uint128_t(a) * __uint128_t(b));
+  }
+#endif
+  else {
+    using half_word = half_width_t<word>;
+    const auto shiftl = [](word value) -> word {
+      return value << cpp::numeric_limits<half_word>::digits;
+    };
+    const auto shiftr = [](word value) -> word {
+      return value >> cpp::numeric_limits<half_word>::digits;
+    };
+    // Here we do a one digit multiplication where 'a' and 'b' are of type
+    // word. We split 'a' and 'b' into half words and perform the classic long
+    // multiplication with 'a' and 'b' being two-digit numbers.
+
+    //    a      a_hi a_lo
+    //  x b => x b_hi b_lo
+    // ----    -----------
+    //    c         result
+    // We convert 'lo' and 'hi' from 'half_word' to 'word' so multiplication
+    // doesn't overflow.
+    const word a_lo = lo(a);
+    const word b_lo = lo(b);
+    const word a_hi = hi(a);
+    const word b_hi = hi(b);
+    const word step1 = b_lo * a_lo; // no overflow;
+    const word step2 = b_lo * a_hi; // no overflow;
+    const word step3 = b_hi * a_lo; // no overflow;
+    const word step4 = b_hi * a_hi; // no overflow;
+    word lo_digit = step1;
+    word hi_digit = step4;
+    const word zero_carry = 0;
+    word carry;
+    const auto add_with_carry = LIBC_NAMESPACE::add_with_carry<word>;
+    lo_digit = add_with_carry(lo_digit, shiftl(step2), zero_carry, &carry);
+    hi_digit = add_with_carry(hi_digit, shiftr(step2), carry, nullptr);
+    lo_digit = add_with_carry(lo_digit, shiftl(step3), zero_carry, &carry);
+    hi_digit = add_with_carry(hi_digit, shiftr(step3), carry, nullptr);
+    return Double<word>(lo_digit, hi_digit);
+  }
+}
 
-  auto r2 = add_with_carry(prod.lo, hi_lo.lo << HALF_BIT_WIDTH, T(0));
-  prod.lo = r2.sum;
-  prod.hi = add_with_carry(prod.hi, hi_lo.hi, r2.carry).sum;
+// Inplace binary operation with carry propagation. Returns carry.
+template <typename Function, typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word inplace_binop(Function op_with_carry,
+                                         cpp::array<word, N> &dst,
+                                         const cpp::array<word, M> &rhs) {
+  static_assert(N >= M);
+  word carry_out = 0;
+  for (size_t i = 0; i < N; ++i) {
+    const bool has_rhs_value = i < M;
+    const word rhs_value = has_rhs_value ? rhs[i] : 0;
+    const word carry_in = carry_out;
+    dst[i] = op_with_carry(dst[i], rhs_value, carry_in, &carry_out);
+    // stop early when rhs is over and no carry is to be propagated.
+    if (!has_rhs_value && carry_out == 0)
+      break;
+  }
+  return carry_out;
+}
 
-  return prod;
+// Inplace addition. Returns carry.
+template <typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word add_with_carry(cpp::array<word, N> &dst,
+                                          const cpp::array<word, M> &rhs) {
+  return inplace_binop(LIBC_NAMESPACE::add_with_carry<word>, dst, rhs);
 }
 
-template <>
-LIBC_INLINE constexpr NumberPair<uint32_t> full_mul<uint32_t>(uint32_t a,
-                                                              uint32_t b) {
-  uint64_t prod = uint64_t(a) * uint64_t(b);
-  NumberPair<uint32_t> result;
-  result.lo = uint32_t(prod);
-  result.hi = uint32_t(prod >> 32);
-  return result;
+// Inplace subtraction. Returns borrow.
+template <typename word, size_t N, size_t M>
+LIBC_INLINE constexpr word sub_with_borrow(cpp::array<word, N> &dst,
+                                           const cpp::array<word, M> &rhs) {
+  return inplace_binop(LIBC_NAMESPACE::sub_with_borrow<word>, dst, rhs);
 }
 
-#ifdef LIBC_TYPES_HAS_INT128
-template <>
-LIBC_INLINE constexpr NumberPair<uint64_t> full_mul<uint64_t>(uint64_t a,
-                                                              uint64_t b) {
-  __uint128_t prod = __uint128_t(a) * __uint128_t(b);
-  NumberPair<uint64_t> result;
-  result.lo = uint64_t(prod);
-  result.hi = uint64_t(prod >> 64);
-  return result;
+// Inplace multiply-add. Returns carry.
+// i.e., 'dst += b x c'
+template <typename word, size_t N>
+LIBC_INLINE constexpr word mad_with_carry(cpp::array<word, N> &dst, word b,
+                                          word c) {
+  return add_with_carry(dst, mul2(b, c));
 }
-#endif // LIBC_TYPES_HAS_INT128
 
-} // namespace internal
+// An array of two elements serving as an accumulator during multiword
+// computations.
+template <typename T> struct Accumulator final : cpp::array<T, 2> {
+  using UP = cpp::array<T, 2>;
+  LIBC_INLINE constexpr Accumulator() : UP({0, 0}) {}
+  LIBC_INLINE constexpr T advance(T carry_in) {
+    auto result = UP::front();
+    UP::front() = UP::back();
+    UP::back() = carry_in;
+    return result;
+  }
+  LIBC_INLINE constexpr T sum() const { return UP::front(); }
+  LIBC_INLINE constexpr T carry() const { return UP::back(); }
+};
+
+// Inplace multiplication by a single word. Returns carry.
+template <typename word, size_t N>
+LIBC_INLINE constexpr word scalar_multiply_with_carry(cpp::array<word, N> &dst,
+                                                      word x) {
+  Accumulator<word> acc;
+  for (auto &val : dst) {
+    const word carry = mad_with_carry(acc, val, x);
+    val = acc.advance(carry);
+  }
+  return acc.carry();
+}
+
+// Multiplication of 'lhs' by 'rhs' into 'dst'. Returns carry.
+// This function is safe to use for signed numbers.
+// https://stackoverflow.com/a/20793834
+// https://pages.cs.wisc.edu/%7Emarkhill/cs354/Fall2008/beyond354/int.mult.html
+template <typename word, size_t O, size_t M, size_t N>
+LIBC_INLINE constexpr word multiply_with_carry(cpp::array<word, O> &dst,
+                                               const cpp::array<word, M> &lhs,
+                                               const cpp::array<word, N> &rhs) {
+  static_assert(O >= M + N);
+  Accumulator<word> acc;
+  for (size_t i = 0; i < O; ++i) {
+    const size_t lower_idx = i < N ? 0 : i - N + 1;
+    const size_t upper_idx = i < M ? i : M - 1;
+    word carry = 0;
+    for (size_t j = lower_idx; j <= upper_idx; ++j)
+      carry += mad_with_carry(acc, lhs[j], rhs[i - j]);
+    dst[i] = acc.advance(carry);
+  }
+  return acc.carry();
+}
+
+template <typename word, size_t N>
+LIBC_INLINE constexpr void quick_mul_hi(cpp::array<word, N> &dst,
+                                        const cpp::array<word, N> &lhs,
+                                        const cpp::array<word, N> &rhs) {
+  Accumulator<word> acc;
+  word carry = 0;
+  // First round of accumulation for those at N - 1 in the full product.
+  for (size_t i = 0; i < N; ++i)
+    carry += mad_with_carry(acc, lhs[i], rhs[N - 1 - i]);
+  for (size_t i = N; i < 2 * N - 1; ++i) {
+    acc.advance(carry);
+    carry = 0;
+    for (size_t j = i - N + 1; j < N; ++j)
+      carry += mad_with_carry(acc, lhs[j], rhs[i - j]);
+    dst[i - N] = acc.sum();
+  }
+  dst.back() = acc.carry();
+}
+
+// An enum for the shift function below.
+enum Direction { LEFT, RIGHT };
+
+// A bitwise shift on an array of elements.
+template <Direction direction, typename word, size_t N>
+LIBC_INLINE constexpr void shift(cpp::array<word, N> &array, size_t offset) {
+  constexpr size_t WORD_BITS = cpp::numeric_limits<word>::digits;
+  constexpr size_t TOTAL_BITS = N * WORD_BITS;
+  if (offset == 0)
+    return;
+  if (offset >= TOTAL_BITS) {
+    array = {};
+    return;
+  }
+  const auto at = [&](size_t index) -> int {
+    // reverse iteration when direction == LEFT.
+    if constexpr (direction == LEFT)
+      return int(N) - int(index) - 1;
+    return int(index);
+  };
+  const auto safe_get_at = [&](size_t index) -> word {
+    // return 0 when accessing out of bound elements.
+    const int i = at(index);
+    return i >= 0 && i < int(N) ? array[i] : 0;
+  };
+  const size_t index_offset = offset / WORD_BITS;
+  const size_t bit_offset = offset % WORD_BITS;
+  for (size_t index = 0; index < N; ++index) {
+    const word part1 = safe_get_at(index + index_offset);
+    const word part2 = safe_get_at(index + index_offset + 1);
+    word &dst = array[at(index)];
+    if (bit_offset == 0)
+      dst = part1; // no crosstalk between parts.
+    else if constexpr (direction == RIGHT)
+      dst = (part1 >> bit_offset) | (part2 << (WORD_BITS - bit_offset));
+    else if constexpr (direction == LEFT)
+      dst = (part1 << bit_offset) | (part2 >> (WORD_BITS - bit_offset));
+  }
+}
+
+#define DECLARE_COUNTBIT(NAME, INDEX_EXPR)                                     \
+  template <typename word, size_t N>                                           \
+  LIBC_INLINE constexpr int NAME(const cpp::array<word, N> &val) {             \
+    int bit_count = 0;                                                         \
+    for (size_t i = 0; i < N; ++i) {                                           \
+      const int word_count = cpp::NAME<word>(val[INDEX_EXPR]);                 \
+      bit_count += word_count;                                                 \
+      if (word_count != cpp::numeric_limits<word>::digits)                     \
+        break;                                                                 \
+    }                                                                          \
+    return bit_count;                                                          \
+  }
+
+DECLARE_COUNTBIT(countr_zero, i)         // iterating forward
+DECLARE_COUNTBIT(countr_one, i)          // iterating forward
+DECLARE_COUNTBIT(countl_zero, N - i - 1) // iterating backward
+DECLARE_COUNTBIT(countl_one, N - i - 1)  // iterating backward
+
+} // namespace multiword
 
 template <size_t Bits, bool Signed, typename WordType = uint64_t>
 struct BigInt {
@@ -90,6 +301,9 @@ struct BigInt {
                 "WordType must be unsigned integer.");
 
   using word_type = WordType;
+  using unsigned_type = BigInt<Bits, false, word_type>;
+  using signed_type = BigInt<Bits, true, word_type>;
+
   LIBC_INLINE_VAR static constexpr bool SIGNED = Signed;
   LIBC_INLINE_VAR static constexpr size_t BITS = Bits;
   LIBC_INLINE_VAR
@@ -100,10 +314,7 @@ struct BigInt {
 
   LIBC_INLINE_VAR static constexpr size_t WORD_COUNT = Bits / WORD_SIZE;
 
-  using unsigned_type = BigInt<BITS, false, word_type>;
-  using signed_type = BigInt<BITS, true, word_type>;
-
-  cpp::array<WordType, WORD_COUNT> val{};
+  cpp::array<WordType, WORD_COUNT> val{}; // zero initialized.
 
   LIBC_INLINE constexpr BigInt() = default;
 
@@ -112,76 +323,67 @@ struct BigInt {
   template <size_t OtherBits, bool OtherSigned>
   LIBC_INLINE constexpr BigInt(
       const BigInt<OtherBits, OtherSigned, WordType> &other) {
-    if (OtherBits >= Bits) {
+    if (OtherBits >= Bits) { // truncate
       for (size_t i = 0; i < WORD_COUNT; ++i)
         val[i] = other[i];
-    } else {
+    } else { // zero or sign extend
       size_t i = 0;
       for (; i < OtherBits / WORD_SIZE; ++i)
         val[i] = other[i];
-      WordType sign = 0;
-      if constexpr (Signed && OtherSigned) {
-        sign = static_cast<WordType>(
-            -static_cast<cpp::make_signed_t<WordType>>(other.is_neg()));
-      }
-      for (; i < WORD_COUNT; ++i)
-        val[i] = sign;
+      extend(i, Signed && other.is_neg());
     }
   }
 
   // Construct a BigInt from a C array.
-  template <size_t N, cpp::enable_if_t<N <= WORD_COUNT, int> = 0>
-  LIBC_INLINE constexpr BigInt(const WordType (&nums)[N]) {
-    size_t min_wordcount = N < WORD_COUNT ? N : WORD_COUNT;
-    size_t i = 0;
-    for (; i < min_wordcount; ++i)
+  template <size_t N> LIBC_INLINE constexpr BigInt(const WordType (&nums)[N]) {
+    static_assert(N == WORD_COUNT);
+    for (size_t i = 0; i < WORD_COUNT; ++i)
       val[i] = nums[i];
+  }
 
-    // If nums doesn't completely fill val, then fill the rest with zeroes.
-    for (; i < WORD_COUNT; ++i)
-      val[i] = 0;
+  LIBC_INLINE constexpr explicit BigInt(
+      const cpp::array<WordType, WORD_COUNT> &words) {
+    val = words;
   }
 
   // Initialize the first word to |v| and the rest to 0.
   template <typename T, typename = cpp::enable_if_t<cpp::is_integral_v<T>>>
   LIBC_INLINE constexpr BigInt(T v) {
-    val[0] = static_cast<WordType>(v);
-
-    if constexpr (WORD_COUNT == 1)
-      return;
-
-    if constexpr (Bits < sizeof(T) * CHAR_BIT) {
-      for (int i = 1; i < WORD_COUNT; ++i) {
-        v >>= WORD_SIZE;
-        val[i] = static_cast<WordType>(v);
+    constexpr size_t T_SIZE = sizeof(T) * CHAR_BIT;
+    const bool is_neg = Signed && (v < 0);
+    for (size_t i = 0; i < WORD_COUNT; ++i) {
+      if (v == 0) {
+        extend(i, is_neg);
+        return;
       }
-      return;
-    }
-
-    size_t i = 1;
-
-    if constexpr (WORD_SIZE < sizeof(T) * CHAR_BIT)
-      for (; i < sizeof(T) * CHAR_BIT / WORD_SIZE; ++i) {
+      val[i] = static_cast<WordType>(v);
+      if constexpr (T_SIZE > WORD_SIZE)
         v >>= WORD_SIZE;
-        val[i] = static_cast<WordType>(v);
-      }
-
-    WordType sign = (Signed && (v < 0)) ? ~WordType(0) : WordType(0);
-    for (; i < WORD_COUNT; ++i) {
-      val[i] = sign;
+      else
+        v = 0;
     }
   }
+  LIBC_INLINE constexpr BigInt &operator=(const BigInt &other) = default;
 
-  LIBC_INLINE constexpr explicit BigInt(
-      const cpp::array<WordType, WORD_COUNT> &words) {
-    for (size_t i = 0; i < WORD_COUNT; ++i)
-      val[i] = words[i];
+  // constants
+  LIBC_INLINE static constexpr BigInt zero() { return BigInt(); }
+  LIBC_INLINE static constexpr BigInt one() { return BigInt(1); }
+  LIBC_INLINE static constexpr BigInt all_ones() { return ~zero(); }
+  LIBC_INLINE static constexpr BigInt min() {
+    BigInt out;
+    if constexpr (SIGNED)
+      out.set_msb();
+    return out;
+  }
+  LIBC_INLINE static constexpr BigInt max() {
+    BigInt out = all_ones();
+    if constexpr (SIGNED)
+      out.clear_msb();
+    return out;
   }
 
   // TODO: Reuse the Sign type.
-  LIBC_INLINE constexpr bool is_neg() const {
-    return val.back() >> (WORD_SIZE - 1);
-  }
+  LIBC_INLINE constexpr bool is_neg() const { return SIGNED && get_msb(); }
 
   template <typename T> LIBC_INLINE constexpr explicit operator T() const {
     return to<T>();
@@ -191,200 +393,100 @@ struct BigInt {
   LIBC_INLINE constexpr cpp::enable_if_t<
       cpp::is_integral_v<T> && !cpp::is_same_v<T, bool>, T>
   to() const {
+    constexpr size_t T_SIZE = sizeof(T) * CHAR_BIT;
     T lo = static_cast<T>(val[0]);
-
-    constexpr size_t T_BITS = sizeof(T) * CHAR_BIT;
-
-    if constexpr (T_BITS <= WORD_SIZE)
+    if constexpr (T_SIZE <= WORD_SIZE)
       return lo;
-
     constexpr size_t MAX_COUNT =
-        T_BITS > Bits ? WORD_COUNT : T_BITS / WORD_SIZE;
+        T_SIZE > Bits ? WORD_COUNT : T_SIZE / WORD_SIZE;
     for (size_t i = 1; i < MAX_COUNT; ++i)
       lo += static_cast<T>(val[i]) << (WORD_SIZE * i);
-
-    if constexpr (Signed && (T_BITS > Bits)) {
+    if constexpr (Signed && (T_SIZE > Bits)) {
       // Extend sign for negative numbers.
       constexpr T MASK = (~T(0) << Bits);
       if (is_neg())
         lo |= MASK;
     }
-
     return lo;
   }
 
   LIBC_INLINE constexpr explicit operator bool() const { return !is_zero(); }
 
-  LIBC_INLINE constexpr BigInt &operator=(const BigInt &other) = default;
-
   LIBC_INLINE constexpr bool is_zero() const {
-    for (size_t i = 0; i < WORD_COUNT; ++i) {
-      if (val[i] != 0)
+    for (auto part : val)
+      if (part != 0)
         return false;
-    }
     return true;
   }
 
-  // Add x to this number and store the result in this number.
+  // Ad...
[truncated]

Contributor

@legrosbuffle legrosbuffle left a comment

First pass of comments. Can you run benchmarks to check that the generated code is still good?

libc/src/__support/UInt.h
prod.lo = r1.sum;
prod.hi = add_with_carry(prod.hi, lo_hi.hi, r1.carry).sum;
// Returns 'a' times 'b' in a Double<word>. Cannot overflow by definition.
template <typename word>
Contributor

s/word/T/ ? This does not have to be a word.

Contributor Author

I prefer sticking to word instead of T here because we're in the multiword namespace and this is classic terminology for this kind of operation. I can change if you feel strongly about it.

Member

Should it be words (plural) then? So you have mul2<4>(a, b); implying a and b are 4 words? Do we capitalize template parameters?

Contributor Author

This is really a type not a number. But I prefer to keep word instead of T because the code refers to half_word or double_wide<word>.
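
For readers following the naming discussion, here is a minimal sketch of the half-word long multiplication that the generic branch of `mul2` describes, specialized to a 32-bit word split into 16-bit halves. `Pair32`, `add_carry`, `mul2_sketch` and `to_u64` are hypothetical names chosen for this illustration; the patch itself uses `Double<word>` and `add_with_carry`.

```cpp
#include <cstdint>

// Hypothetical stand-in for Double<word>: low word first, then high word.
struct Pair32 {
  std::uint32_t lo;
  std::uint32_t hi;
};

// Computes a + b + carry_in without a wider type and reports the carry.
constexpr std::uint32_t add_carry(std::uint32_t a, std::uint32_t b,
                                  std::uint32_t carry_in,
                                  std::uint32_t &carry_out) {
  const std::uint32_t partial = a + b;
  const std::uint32_t sum = partial + carry_in;
  carry_out = (partial < a || sum < partial) ? 1u : 0u;
  return sum;
}

// Full 32x32 -> 64 bit product using only 32-bit arithmetic.
constexpr Pair32 mul2_sketch(std::uint32_t a, std::uint32_t b) {
  constexpr int HALF_BITS = 16;
  // Halves are kept in full-width variables so the partial products
  // below cannot overflow.
  const std::uint32_t a_lo = a & 0xFFFFu;
  const std::uint32_t a_hi = a >> HALF_BITS;
  const std::uint32_t b_lo = b & 0xFFFFu;
  const std::uint32_t b_hi = b >> HALF_BITS;
  const std::uint32_t step1 = b_lo * a_lo;
  const std::uint32_t step2 = b_lo * a_hi;
  const std::uint32_t step3 = b_hi * a_lo;
  const std::uint32_t step4 = b_hi * a_hi;
  std::uint32_t lo = step1;
  std::uint32_t hi = step4;
  std::uint32_t carry = 0;
  std::uint32_t ignored = 0; // the high word cannot overflow.
  lo = add_carry(lo, step2 << HALF_BITS, 0u, carry);
  hi = add_carry(hi, step2 >> HALF_BITS, carry, ignored);
  lo = add_carry(lo, step3 << HALF_BITS, 0u, carry);
  hi = add_carry(hi, step3 >> HALF_BITS, carry, ignored);
  return Pair32{lo, hi};
}

// Checker only: reassembles the pair with a wider type.
constexpr std::uint64_t to_u64(Pair32 p) {
  return (std::uint64_t(p.hi) << 32) | p.lo;
}

static_assert(to_u64(mul2_sketch(0xFFFFFFFFu, 0xFFFFFFFFu)) ==
                  0xFFFFFFFE00000001ull,
              "largest product must fit in two words");
```

In this sketch `word` is clearly a type parameter rather than a count, which is the point made in the reply above.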

✅ With the latest revision this PR passed the Python code formatter.

@gchatelet
Contributor Author

gchatelet commented Mar 25, 2024

> First pass of comments. Can you run benchmarks to check that the generated code is still good?

I met with @lntue who helped me benchmark the new implementation. I used the CORE-MATH test suite and made sure to disable all fast paths in the implementations. The following numbers are reciprocal throughput for 20 trials with a count of N=10000 (lowered tenfold to keep benchmarking time under control). The results are very reproducible and show a slowdown of roughly 20%, so something is wrong. I'll report back as soon as I've pinpointed the issue.

Before - c5f839bd58e7f888acc4cb39a18e9e5bbaa9fb0a
exp    Ntrial = 20 ; Min = 668.971 + 10.352 clc/call; Median-Min = 3.276 clc/call; Max = 693.620 clc/call;
exp2   Ntrial = 20 ; Min = 591.189 + 15.880 clc/call; Median-Min = 18.427 clc/call; Max = 628.805 clc/call;
exp10  Ntrial = 20 ; Min = 686.535 + 5.688 clc/call; Median-Min = 2.962 clc/call; Max = 697.168 clc/call;
expm1  Ntrial = 20 ; Min = 845.390 + 9.590 clc/call; Median-Min = 10.446 clc/call; Max = 876.777 clc/call;
log    Ntrial = 20 ; Min = 419.107 + 4.236 clc/call; Median-Min = 3.245 clc/call; Max = 428.442 clc/call;
log2   Ntrial = 20 ; Min = 400.071 + 1.709 clc/call; Median-Min = 0.920 clc/call; Max = 465.417 clc/call;
log10  Ntrial = 20 ; Min = 409.805 + 2.859 clc/call; Median-Min = 2.379 clc/call; Max = 415.131 clc/call;
log1p  Ntrial = 20 ; Min = 657.658 + 9.524 clc/call; Median-Min = 12.002 clc/call; Max = 672.413 clc/call;
After - 60b37004bfeab35af89075269fad8e9bdeaa483b
exp    Ntrial = 20 ; Min = 820.570 + 12.091 clc/call; Median-Min = 12.470 clc/call; Max = 865.002 clc/call;
exp2   Ntrial = 20 ; Min = 697.665 + 9.523 clc/call; Median-Min = 5.665 clc/call; Max = 717.686 clc/call;
exp10  Ntrial = 20 ; Min = 810.953 + 8.040 clc/call; Median-Min = 2.217 clc/call; Max = 829.657 clc/call;
expm1  Ntrial = 20 ; Min = 1006.518 + 17.347 clc/call; Median-Min = 20.243 clc/call; Max = 1039.603 clc/call;
log    Ntrial = 20 ; Min = 515.516 + 11.423 clc/call; Median-Min = 6.194 clc/call; Max = 584.475 clc/call;
log2   Ntrial = 20 ; Min = 496.955 + 6.045 clc/call; Median-Min = 1.590 clc/call; Max = 508.410 clc/call;
log10  Ntrial = 20 ; Min = 505.182 + 11.358 clc/call; Median-Min = 13.359 clc/call; Max = 519.998 clc/call;
log1p  Ntrial = 20 ; Min = 843.250 + 7.807 clc/call; Median-Min = 6.602 clc/call; Max = 858.043 clc/call;

edit: I've identified the culprit. The slowdown comes from an optimization for shifts that I did not port, namely:

#ifdef LIBC_TYPES_HAS_INT128
    if constexpr ((Bits == 128) && (WORD_SIZE == 64)) {
      // Use builtin 128 bits if available.
      if (s >= 128) {
        val[0] = 0;
        val[1] = 0;
        return;
      }
      __uint128_t tmp = __uint128_t(val[0]) + (__uint128_t(val[1]) << 64);
      tmp <<= s;
      val[0] = uint64_t(tmp);
      val[1] = uint64_t(tmp >> 64);
      return;
    }
#endif // LIBC_TYPES_HAS_INT128

With the optimization restored, the new version is 20 to 40% faster than the baseline.

exp    Ntrial = 20 ; Min = 519.567 + 23.705 clc/call; Median-Min = 31.757 clc/call; Max = 555.568 clc/call;
exp2   Ntrial = 20 ; Min = 481.905 + 18.692 clc/call; Median-Min = 3.772 clc/call; Max = 542.986 clc/call;
exp10  Ntrial = 20 ; Min = 529.096 + 17.568 clc/call; Median-Min = 20.023 clc/call; Max = 566.512 clc/call;
expm1  Ntrial = 20 ; Min = 655.869 + 13.370 clc/call; Median-Min = 3.861 clc/call; Max = 681.159 clc/call;
log    Ntrial = 20 ; Min = 289.876 + 1.324 clc/call; Median-Min = 0.924 clc/call; Max = 312.580 clc/call;
log2   Ntrial = 20 ; Min = 306.178 + 2.070 clc/call; Median-Min = 0.928 clc/call; Max = 331.935 clc/call;
log10  Ntrial = 20 ; Min = 314.928 + 1.400 clc/call; Median-Min = 1.058 clc/call; Max = 317.728 clc/call;
log1p  Ntrial = 20 ; Min = 400.284 + 1.718 clc/call; Median-Min = 1.535 clc/call; Max = 408.017 clc/call;

This shows that it's probably better to have DyadicFloat work with __uint128_t directly rather than BigInt<128, false, uint64_t> (this patch is a step in that direction).
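
To make the two code paths concrete, here is a minimal sketch of a 128-bit left shift done word by word next to the same shift done through the builtin type. `U128`, `shl_words`, `shl_builtin` and `same` are hypothetical names for this illustration only, and the builtin path assumes a compiler that defines `__SIZEOF_INT128__` (GCC and Clang on 64-bit targets do); otherwise the sketch falls back to the word version.

```cpp
#include <cstdint>

// Hypothetical two-word value: lo holds bits 0..63, hi holds bits 64..127.
struct U128 {
  std::uint64_t lo;
  std::uint64_t hi;
};

constexpr bool same(U128 a, U128 b) { return a.lo == b.lo && a.hi == b.hi; }

// Word-by-word left shift. Every case needs its own branch so that no
// native shift is ever performed with an amount of 64 or more.
constexpr U128 shl_words(U128 v, unsigned s) {
  if (s >= 128)
    return U128{0, 0};
  if (s == 0)
    return v;
  if (s >= 64)
    return U128{0, v.lo << (s - 64)};
  return U128{v.lo << s, (v.hi << s) | (v.lo >> (64 - s))};
}

#ifdef __SIZEOF_INT128__
// Same operation through the builtin 128-bit type: one range check, then
// the compiler picks the instruction sequence.
constexpr U128 shl_builtin(U128 v, unsigned s) {
  if (s >= 128)
    return U128{0, 0};
  __uint128_t tmp = __uint128_t(v.lo) | (__uint128_t(v.hi) << 64);
  tmp <<= s;
  return U128{std::uint64_t(tmp), std::uint64_t(tmp >> 64)};
}
#else
// Assumption for portability of this sketch: no builtin type, reuse the
// word version.
constexpr U128 shl_builtin(U128 v, unsigned s) { return shl_words(v, s); }
#endif

static_assert(same(shl_words(U128{1, 0}, 64), U128{0, 1}),
              "bit 0 moves to bit 64");
```

Both functions are meant to return identical results; any speed difference depends on the compiler and target, so the numbers above remain the reference.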

@@ -244,18 +244,30 @@ LIBC_INLINE constexpr bool is_negative(cpp::array<word, N> &array) {
enum Direction { LEFT, RIGHT };

// A bitwise shift on an array of elements.
// TODO: Make the result UB when 'offset' is greater or equal to the number of
Contributor

Or similar to other ones, we can have both safe and unsafe shifts.

Member

Or make the case of offset >= TOTAL_BITS into a LIBC_ASSERT?

Contributor Author

I'd like to do this as a separate PR because some tests and code seem to depend on this behavior.
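
The two options raised in this thread could look roughly like the sketch below: an unchecked shift with an asserted precondition, and a checked wrapper that keeps the current zero-fill behavior. This is only an illustration under stated assumptions: `U128`, `shift_right_unchecked` and `shift_right_checked` are hypothetical names, and the standard `assert` stands in for `LIBC_ASSERT`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical two-word value: lo holds bits 0..63, hi holds bits 64..127.
struct U128 {
  std::uint64_t lo;
  std::uint64_t hi;
};

constexpr std::size_t TOTAL_BITS = 128;

// Unchecked variant: the caller promises offset < TOTAL_BITS, as with the
// builtin shift operators. The assert documents that precondition.
inline U128 shift_right_unchecked(U128 v, std::size_t offset) {
  assert(offset < TOTAL_BITS && "offset must be smaller than the bit width");
  if (offset == 0)
    return v;
  if (offset >= 64)
    return U128{v.hi >> (offset - 64), 0};
  return U128{(v.lo >> offset) | (v.hi << (64 - offset)), v.hi >> offset};
}

// Checked variant: oversized offsets produce zero, at the cost of one
// extra comparison on every call.
inline U128 shift_right_checked(U128 v, std::size_t offset) {
  if (offset >= TOTAL_BITS)
    return U128{0, 0};
  return shift_right_unchecked(v, offset);
}
```

Callers on a hot path that already know the offset is in range could use the unchecked form, while existing code that relies on the zero result would keep calling the checked one.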

@gchatelet gchatelet force-pushed the add_multiprecision_facilities branch from bd8ca98 to 6aef60e Compare April 4, 2024 08:17
@gchatelet gchatelet merged commit a2306b6 into llvm:main Apr 4, 2024
4 checks passed
@gchatelet gchatelet deleted the add_multiprecision_facilities branch April 4, 2024 08:27
gchatelet added a commit that referenced this pull request Apr 4, 2024
gchatelet added a commit that referenced this pull request Apr 4, 2024
Reverts #86137

Some aarch64 compilers seem to consider that `uint128_t` is not
`is_trivially_constructible` which prevents `bit_cast`-ing.
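
One way to sidestep this kind of type-trait requirement is to split a value with shifts and truncation instead of a bit cast. The sketch below shows the idea on a 64-bit value; it is an assumption for illustration and not necessarily what the reland does. `Halves64` and `split_by_shift` are hypothetical names.

```cpp
#include <cstdint>

// Hypothetical result type: low half first, then high half.
struct Halves64 {
  std::uint32_t lo;
  std::uint32_t hi;
};

// Splits with arithmetic only. Nothing here depends on the value type
// being trivially constructible or on the byte order of the target.
constexpr Halves64 split_by_shift(std::uint64_t value) {
  return Halves64{std::uint32_t(value), std::uint32_t(value >> 32)};
}

static_assert(split_by_shift(0x1122334455667788ull).lo == 0x55667788u,
              "low half");
static_assert(split_by_shift(0x1122334455667788ull).hi == 0x11223344u,
              "high half");
```

The same pattern would apply to a 128-bit value split into two 64-bit halves on compilers that provide such a type.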
gchatelet added a commit that referenced this pull request Apr 4, 2024
This is a reland of #86137 with a fix for platforms / compilers that do
not support trivially constructible int128 types.
gchatelet added a commit to gchatelet/llvm-project that referenced this pull request Apr 5, 2024
…emantics.

This patch removes the test for cases where the shift operand is greater or equal to the bit width of the number. This is done for two reasons, first it makes `BigInt` consistent with regular integral bitwise shift semantics, and second it makes the shift operation faster. The shift operation is on the critical path for `exp` and `log` operations, see llvm#86137 (comment).
gchatelet added a commit that referenced this pull request Apr 6, 2024
…emantics. (#87874)

This patch removes the test for cases where the shift operand is greater
or equal to the bit width of the number. This is done for two reasons,
first it makes `BigInt` consistent with regular integral bitwise shift
semantics, and second it makes the shift operation faster. The shift
operation is on the critical path for `exp` and `log` operations, see
#86137 (comment).