Skip to content

Commit

Permalink
[libc] Add max length argument to decimal to float (#84091)
Browse files Browse the repository at this point in the history
The from_chars implementation in libcxx may end up using our
decimal-to-float utilities, but for that to work we need to support
limiting the length of the string being parsed. This patch adds that
length limit to decimal_exp_to_float, as well as to the functions it
calls (high precision decimal, str to integer).
  • Loading branch information
michaelrj-google committed Mar 6, 2024
1 parent f8c5a68 commit d34b3c9
Show file tree
Hide file tree
Showing 7 changed files with 418 additions and 108 deletions.
116 changes: 63 additions & 53 deletions libc/src/__support/high_precision_decimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H
#define LLVM_LIBC_SRC___SUPPORT_HIGH_PRECISION_DECIMAL_H

#include "src/__support/CPP/limits.h"
#include "src/__support/ctype_utils.h"
#include "src/__support/str_to_integer.h"
#include <stdint.h>
Expand Down Expand Up @@ -115,9 +116,10 @@ class HighPrecisionDecimal {
uint8_t digits[MAX_NUM_DIGITS];

private:
bool should_round_up(int32_t roundToDigit, RoundDirection round) {
if (roundToDigit < 0 ||
static_cast<uint32_t>(roundToDigit) >= this->num_digits) {
LIBC_INLINE bool should_round_up(int32_t round_to_digit,
RoundDirection round) {
if (round_to_digit < 0 ||
static_cast<uint32_t>(round_to_digit) >= this->num_digits) {
return false;
}

Expand All @@ -133,8 +135,8 @@ class HighPrecisionDecimal {
// Else round to nearest.

// If we're right in the middle and there are no extra digits
if (this->digits[roundToDigit] == 5 &&
static_cast<uint32_t>(roundToDigit + 1) == this->num_digits) {
if (this->digits[round_to_digit] == 5 &&
static_cast<uint32_t>(round_to_digit + 1) == this->num_digits) {

// Round up if we've truncated (since that means the result is slightly
// higher than what's represented.)
Expand All @@ -143,22 +145,22 @@ class HighPrecisionDecimal {
}

// If this is exactly halfway, round to even.
if (roundToDigit == 0)
if (round_to_digit == 0)
// When the input is ".5".
return false;
return this->digits[roundToDigit - 1] % 2 != 0;
return this->digits[round_to_digit - 1] % 2 != 0;
}
// If there are digits after roundToDigit, they must be non-zero since we
// If there are digits after round_to_digit, they must be non-zero since we
// trim trailing zeroes after all operations that change digits.
return this->digits[roundToDigit] >= 5;
return this->digits[round_to_digit] >= 5;
}

// Takes an amount to left shift and returns the number of new digits needed
// to store the result based on LEFT_SHIFT_DIGIT_TABLE.
uint32_t get_num_new_digits(uint32_t lShiftAmount) {
LIBC_INLINE uint32_t get_num_new_digits(uint32_t lshift_amount) {
const char *power_of_five =
LEFT_SHIFT_DIGIT_TABLE[lShiftAmount].power_of_five;
uint32_t new_digits = LEFT_SHIFT_DIGIT_TABLE[lShiftAmount].new_digits;
LEFT_SHIFT_DIGIT_TABLE[lshift_amount].power_of_five;
uint32_t new_digits = LEFT_SHIFT_DIGIT_TABLE[lshift_amount].new_digits;
uint32_t digit_index = 0;
while (power_of_five[digit_index] != 0) {
if (digit_index >= this->num_digits) {
Expand All @@ -176,7 +178,7 @@ class HighPrecisionDecimal {
}

// Trim all trailing 0s
void trim_trailing_zeroes() {
LIBC_INLINE void trim_trailing_zeroes() {
while (this->num_digits > 0 && this->digits[this->num_digits - 1] == 0) {
--this->num_digits;
}
Expand All @@ -186,19 +188,19 @@ class HighPrecisionDecimal {
}

// Perform a digitwise binary non-rounding right shift on this value by
// shiftAmount. The shiftAmount can't be more than MAX_SHIFT_AMOUNT to prevent
// overflow.
void right_shift(uint32_t shiftAmount) {
// shift_amount. The shift_amount can't be more than MAX_SHIFT_AMOUNT to
// prevent overflow.
LIBC_INLINE void right_shift(uint32_t shift_amount) {
uint32_t read_index = 0;
uint32_t write_index = 0;

uint64_t accumulator = 0;

const uint64_t shift_mask = (uint64_t(1) << shiftAmount) - 1;
const uint64_t shift_mask = (uint64_t(1) << shift_amount) - 1;

// Warm Up phase: we don't have enough digits to start writing, so just
// read them into the accumulator.
while (accumulator >> shiftAmount == 0) {
while (accumulator >> shift_amount == 0) {
uint64_t read_digit = 0;
// If there are still digits to read, read the next one, else the digit is
// assumed to be 0.
Expand All @@ -217,7 +219,7 @@ class HighPrecisionDecimal {
// read. Keep reading until we run out of digits.
while (read_index < this->num_digits) {
uint64_t read_digit = this->digits[read_index];
uint64_t write_digit = accumulator >> shiftAmount;
uint64_t write_digit = accumulator >> shift_amount;
accumulator &= shift_mask;
this->digits[write_index] = static_cast<uint8_t>(write_digit);
accumulator = accumulator * 10 + read_digit;
Expand All @@ -228,7 +230,7 @@ class HighPrecisionDecimal {
// Cool Down phase: All of the readable digits have been read, so just write
// the remainder, while treating any more digits as 0.
while (accumulator > 0) {
uint64_t write_digit = accumulator >> shiftAmount;
uint64_t write_digit = accumulator >> shift_amount;
accumulator &= shift_mask;
if (write_index < MAX_NUM_DIGITS) {
this->digits[write_index] = static_cast<uint8_t>(write_digit);
Expand All @@ -243,10 +245,10 @@ class HighPrecisionDecimal {
}

// Perform a digitwise binary non-rounding left shift on this value by
// shiftAmount. The shiftAmount can't be more than MAX_SHIFT_AMOUNT to prevent
// overflow.
void left_shift(uint32_t shiftAmount) {
uint32_t new_digits = this->get_num_new_digits(shiftAmount);
// shift_amount. The shift_amount can't be more than MAX_SHIFT_AMOUNT to
// prevent overflow.
LIBC_INLINE void left_shift(uint32_t shift_amount) {
uint32_t new_digits = this->get_num_new_digits(shift_amount);

int32_t read_index = this->num_digits - 1;
uint32_t write_index = this->num_digits + new_digits;
Expand All @@ -260,7 +262,7 @@ class HighPrecisionDecimal {
// writing.
while (read_index >= 0) {
accumulator += static_cast<uint64_t>(this->digits[read_index])
<< shiftAmount;
<< shift_amount;
uint64_t next_accumulator = accumulator / 10;
uint64_t write_digit = accumulator - (10 * next_accumulator);
--write_index;
Expand Down Expand Up @@ -296,45 +298,52 @@ class HighPrecisionDecimal {
}

public:
// numString is assumed to be a string of numeric characters. It doesn't
// num_string is assumed to be a string of numeric characters. It doesn't
// handle leading spaces.
HighPrecisionDecimal(const char *__restrict numString) {
LIBC_INLINE
HighPrecisionDecimal(
const char *__restrict num_string,
const size_t num_len = cpp::numeric_limits<size_t>::max()) {
bool saw_dot = false;
size_t num_cur = 0;
// This counts the digits in the number, even if there isn't space to store
// them all.
uint32_t total_digits = 0;
while (isdigit(*numString) || *numString == '.') {
if (*numString == '.') {
while (num_cur < num_len &&
(isdigit(num_string[num_cur]) || num_string[num_cur] == '.')) {
if (num_string[num_cur] == '.') {
if (saw_dot) {
break;
}
this->decimal_point = total_digits;
saw_dot = true;
} else {
if (*numString == '0' && this->num_digits == 0) {
if (num_string[num_cur] == '0' && this->num_digits == 0) {
--this->decimal_point;
++numString;
++num_cur;
continue;
}
++total_digits;
if (this->num_digits < MAX_NUM_DIGITS) {
this->digits[this->num_digits] =
static_cast<uint8_t>(*numString - '0');
static_cast<uint8_t>(num_string[num_cur] - '0');
++this->num_digits;
} else if (*numString != '0') {
} else if (num_string[num_cur] != '0') {
this->truncated = true;
}
}
++numString;
++num_cur;
}

if (!saw_dot)
this->decimal_point = total_digits;

if ((*numString | 32) == 'e') {
++numString;
if (isdigit(*numString) || *numString == '+' || *numString == '-') {
auto result = strtointeger<int32_t>(numString, 10);
if (num_cur < num_len && ((num_string[num_cur] | 32) == 'e')) {
++num_cur;
if (isdigit(num_string[num_cur]) || num_string[num_cur] == '+' ||
num_string[num_cur] == '-') {
auto result =
strtointeger<int32_t>(num_string + num_cur, 10, num_len - num_cur);
if (result.has_error()) {
// TODO: handle error
}
Expand All @@ -358,33 +367,34 @@ class HighPrecisionDecimal {
this->trim_trailing_zeroes();
}

// Binary shift left (shiftAmount > 0) or right (shiftAmount < 0)
void shift(int shiftAmount) {
if (shiftAmount == 0) {
// Binary shift left (shift_amount > 0) or right (shift_amount < 0)
// A shift of 0 returns immediately without touching the value. Larger shifts
// are applied in steps: left_shift/right_shift are called with
// MAX_SHIFT_AMOUNT repeatedly (their own comments state that a single call
// can't exceed MAX_SHIFT_AMOUNT, to prevent overflow), and one final call
// handles whatever amount is left over.
LIBC_INLINE void shift(int shift_amount) {
if (shift_amount == 0) {
return;
}
// Left
else if (shiftAmount > 0) {
while (static_cast<uint32_t>(shiftAmount) > MAX_SHIFT_AMOUNT) {
else if (shift_amount > 0) {
while (static_cast<uint32_t>(shift_amount) > MAX_SHIFT_AMOUNT) {
this->left_shift(MAX_SHIFT_AMOUNT);
shiftAmount -= MAX_SHIFT_AMOUNT;
shift_amount -= MAX_SHIFT_AMOUNT;
}
this->left_shift(shiftAmount);
this->left_shift(shift_amount);
}
// Right
else {
// NOTE(review): the shift amount is negative in this branch, so the cast to
// uint32_t wraps it around. The comparison against -MAX_SHIFT_AMOUNT
// presumably relies on MAX_SHIFT_AMOUNT being an unsigned 32-bit constant so
// that both sides wrap the same way -- confirm against its declaration.
while (static_cast<uint32_t>(shiftAmount) < -MAX_SHIFT_AMOUNT) {
while (static_cast<uint32_t>(shift_amount) < -MAX_SHIFT_AMOUNT) {
this->right_shift(MAX_SHIFT_AMOUNT);
shiftAmount += MAX_SHIFT_AMOUNT;
shift_amount += MAX_SHIFT_AMOUNT;
}
// The remaining amount is still negative (or zero) here, so it is negated
// before being passed as the unsigned right-shift count.
this->right_shift(-shiftAmount);
this->right_shift(-shift_amount);
}
}

// Round the number represented to the closest value of unsigned int type T.
// This is done ignoring overflow.
template <class T>
T round_to_integer_type(RoundDirection round = RoundDirection::Nearest) {
LIBC_INLINE T
round_to_integer_type(RoundDirection round = RoundDirection::Nearest) {
T result = 0;
uint32_t cur_digit = 0;

Expand All @@ -404,10 +414,10 @@ class HighPrecisionDecimal {

// Extra functions for testing.
// These expose or overwrite the object's internal state directly:
//   get_digits()        - pointer to the internal `digits` array (not a copy).
//   get_num_digits()    - current value of `num_digits`.
//   get_decimal_point() - current value of `decimal_point`.
//   set_truncated()     - overwrites the `truncated` flag with `trunc`.

uint8_t *get_digits() { return this->digits; }
uint32_t get_num_digits() { return this->num_digits; }
int32_t get_decimal_point() { return this->decimal_point; }
void set_truncated(bool trunc) { this->truncated = trunc; }
LIBC_INLINE uint8_t *get_digits() { return this->digits; }
LIBC_INLINE uint32_t get_num_digits() { return this->num_digits; }
LIBC_INLINE int32_t get_decimal_point() { return this->decimal_point; }
LIBC_INLINE void set_truncated(bool trunc) { this->truncated = trunc; }
};

} // namespace internal
Expand Down
40 changes: 25 additions & 15 deletions libc/src/__support/str_to_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,15 @@ constexpr int32_t NUM_POWERS_OF_TWO =
// on the Simple Decimal Conversion algorithm by Nigel Tao, described at this
// link: https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html
template <class T>
LIBC_INLINE FloatConvertReturn<T>
simple_decimal_conversion(const char *__restrict numStart,
RoundDirection round = RoundDirection::Nearest) {
LIBC_INLINE FloatConvertReturn<T> simple_decimal_conversion(
const char *__restrict numStart,
const size_t num_len = cpp::numeric_limits<size_t>::max(),
RoundDirection round = RoundDirection::Nearest) {
using FPBits = typename fputil::FPBits<T>;
using StorageType = typename FPBits::StorageType;

int32_t exp2 = 0;
HighPrecisionDecimal hpd = HighPrecisionDecimal(numStart);
HighPrecisionDecimal hpd = HighPrecisionDecimal(numStart, num_len);

FloatConvertReturn<T> output;

Expand Down Expand Up @@ -600,13 +601,17 @@ clinger_fast_path(ExpandedFloat<T> init_num,
// non-inf result for this size of float. The value is
// log10(2^(exponent bias)).
// The generic approximation uses the fact that log10(2^x) ~= x/3
// Generic version: uses the x/3 approximation on the type's exponent bias
// (integer division, so the result is truncated).
template <typename T> constexpr int32_t get_upper_bound() {
template <typename T> LIBC_INLINE constexpr int32_t get_upper_bound() {
return fputil::FPBits<T>::EXP_BIAS / 3;
}

// float uses a fixed constant instead of the EXP_BIAS / 3 approximation.
// NOTE(review): 39 is presumably the smallest base-10 exponent that always
// overflows a float -- confirm.
template <> constexpr int32_t get_upper_bound<float>() { return 39; }
template <> LIBC_INLINE constexpr int32_t get_upper_bound<float>() {
return 39;
}

// double likewise uses a fixed constant.
// NOTE(review): 309 is presumably the smallest base-10 exponent that always
// overflows a double -- confirm.
template <> constexpr int32_t get_upper_bound<double>() { return 309; }
template <> LIBC_INLINE constexpr int32_t get_upper_bound<double>() {
return 309;
}

// The lower bound is the largest negative base-10 exponent that could possibly
// give a non-zero result for this size of float. The value is
Expand All @@ -616,18 +621,18 @@ template <> constexpr int32_t get_upper_bound<double>() { return 309; }
// low base 10 exponent with a very high intermediate mantissa can cancel each
// other out, and subnormal numbers allow for the result to be at the very low
// end of the final mantissa.
// Generic version: sums EXP_BIAS, FRACTION_LEN and STORAGE_LEN, divides by 3
// (the same /3 scaling the generic upper bound uses), and negates the result.
template <typename T> constexpr int32_t get_lower_bound() {
template <typename T> LIBC_INLINE constexpr int32_t get_lower_bound() {
using FPBits = typename fputil::FPBits<T>;
return -((FPBits::EXP_BIAS +
static_cast<int32_t>(FPBits::FRACTION_LEN + FPBits::STORAGE_LEN)) /
3);
}

// float uses fixed constants. The leading 39 equals get_upper_bound<float>().
// NOTE(review): the 6 and 10 terms presumably account for the decimal digits
// of the final and intermediate mantissas -- confirm.
template <> constexpr int32_t get_lower_bound<float>() {
template <> LIBC_INLINE constexpr int32_t get_lower_bound<float>() {
return -(39 + 6 + 10);
}

// double uses fixed constants. The leading 309 equals
// get_upper_bound<double>().
// NOTE(review): the 15 and 20 terms presumably account for the decimal digits
// of the final and intermediate mantissas -- confirm.
template <> constexpr int32_t get_lower_bound<double>() {
template <> LIBC_INLINE constexpr int32_t get_lower_bound<double>() {
return -(309 + 15 + 20);
}

Expand All @@ -637,9 +642,10 @@ template <> constexpr int32_t get_lower_bound<double>() {
// accuracy. The resulting mantissa and exponent are returned in the
// FloatConvertReturn<T> result.
template <class T>
LIBC_INLINE FloatConvertReturn<T>
decimal_exp_to_float(ExpandedFloat<T> init_num, const char *__restrict numStart,
bool truncated, RoundDirection round) {
LIBC_INLINE FloatConvertReturn<T> decimal_exp_to_float(
ExpandedFloat<T> init_num, bool truncated, RoundDirection round,
const char *__restrict numStart,
const size_t num_len = cpp::numeric_limits<size_t>::max()) {
using FPBits = typename fputil::FPBits<T>;
using StorageType = typename FPBits::StorageType;

Expand Down Expand Up @@ -701,7 +707,7 @@ decimal_exp_to_float(ExpandedFloat<T> init_num, const char *__restrict numStart,
#endif // LIBC_COPT_STRTOFLOAT_DISABLE_EISEL_LEMIRE

#ifndef LIBC_COPT_STRTOFLOAT_DISABLE_SIMPLE_DECIMAL_CONVERSION
output = simple_decimal_conversion<T>(numStart, round);
output = simple_decimal_conversion<T>(numStart, num_len, round);
#else
#warning "Simple decimal conversion is disabled, result may not be correct."
#endif // LIBC_COPT_STRTOFLOAT_DISABLE_SIMPLE_DECIMAL_CONVERSION
Expand Down Expand Up @@ -894,6 +900,8 @@ decimal_string_to_float(const char *__restrict src, const char DECIMAL_POINT,
if (!seen_digit)
return output;

// TODO: When adding max length argument, handle the case of a trailing
// EXPONENT MARKER, see scanf for more details.
if (tolower(src[index]) == EXPONENT_MARKER) {
bool has_sign = false;
if (src[index + 1] == '+' || src[index + 1] == '-') {
Expand Down Expand Up @@ -928,7 +936,7 @@ decimal_string_to_float(const char *__restrict src, const char DECIMAL_POINT,
output.value = {0, 0};
} else {
auto temp =
decimal_exp_to_float<T>({mantissa, exponent}, src, truncated, round);
decimal_exp_to_float<T>({mantissa, exponent}, truncated, round, src);
output.value = temp.num;
output.error = temp.error;
}
Expand Down Expand Up @@ -1071,6 +1079,8 @@ nan_mantissa_from_ncharseq(const cpp::string_view ncharseq) {

// Takes a pointer to a string and a pointer to a string pointer. This function
// is used as the backend for all of the string to float functions.
// TODO: Add src_len member to match strtointeger.
// TODO: Next, move from char* and length to string_view
template <class T>
LIBC_INLINE StrToNumResult<T> strtofloatingpoint(const char *__restrict src) {
using FPBits = typename fputil::FPBits<T>;
Expand Down

0 comments on commit d34b3c9

Please sign in to comment.