Optimize 64x64 extended multiplication implementation
Now we use intrinsics when possible, and fall back to an optimized
implementation in portable C++. The difference is about 4x when
we can use intrinsics, and about 2x when we cannot.

This should nicely speed up our implementation of Lemire's algorithm.
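
For context, Lemire's technique produces an unbiased integer in [0, range) by multiplying a random 64-bit value with the range and keeping the upper 64 bits of the 128-bit product; a small set of low-word values has to be rejected to avoid bias. A minimal self-contained sketch of the idea (the names Mult128, extendedMult128, and lemireBounded are illustrative, not Catch2's actual identifiers, and the sketch assumes a compiler with __uint128_t):

    #include <cstdint>
    #include <random>

    struct Mult128 { std::uint64_t upper, lower; };

    // 64x64 -> 128 bit multiply via the compiler-provided 128-bit type
    Mult128 extendedMult128( std::uint64_t lhs, std::uint64_t rhs ) {
        auto result = __uint128_t( lhs ) * __uint128_t( rhs );
        return { static_cast<std::uint64_t>( result >> 64 ),
                 static_cast<std::uint64_t>( result ) };
    }

    // Returns an unbiased value in [0, range); requires range > 0
    std::uint64_t lemireBounded( std::mt19937_64& rng, std::uint64_t range ) {
        Mult128 mult = extendedMult128( rng(), range );
        if ( mult.lower < range ) {
            // 0 - range wraps around to 2^64 - range in unsigned arithmetic
            const std::uint64_t cutoff = ( 0 - range ) % range;
            // Reject the few values that would bias the distribution
            while ( mult.lower < cutoff ) {
                mult = extendedMult128( rng(), range );
            }
        }
        return mult.upper;
    }

Every call goes through one 64x64 extended multiplication (plus one more per rejection), which is why a faster extendedMult pays off directly.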
horenmar committed Apr 2, 2024
1 parent 9271083 commit ae4baf8
Showing 2 changed files with 138 additions and 45 deletions.
109 changes: 64 additions & 45 deletions src/catch2/internal/catch_random_integer_helpers.hpp
@@ -14,6 +14,32 @@
#include <cstdint>
#include <type_traits>

// Note: We use the usual enable-disable-autodetect dance here even though
// we do not support these in CMake configuration options (yet?).
// It is highly unlikely that we will need to make these actually
// user-configurable, but this will make it simpler if we end up needing
// it, and it provides an escape hatch to the users who need it.
#if defined( __SIZEOF_INT128__ )
# define CATCH_CONFIG_INTERNAL_UINT128
#elif defined( _MSC_VER ) && ( defined( _WIN64 ) || defined( _M_ARM64 ) )
# define CATCH_CONFIG_INTERNAL_MSVC_UMUL128
#endif

#if defined( CATCH_CONFIG_INTERNAL_UINT128 ) && \
!defined( CATCH_CONFIG_NO_UINT128 ) && \
!defined( CATCH_CONFIG_UINT128 )
# define CATCH_CONFIG_UINT128
#endif

#if defined( CATCH_CONFIG_INTERNAL_MSVC_UMUL128 ) && \
!defined( CATCH_CONFIG_NO_MSVC_UMUL128 ) && \
!defined( CATCH_CONFIG_MSVC_UMUL128 )
# define CATCH_CONFIG_MSVC_UMUL128
# include <intrin.h>
# pragma intrinsic( _umul128 )
#endif
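// As a usage illustration (derived from the dispatch in extendedMult()
// below): building with CATCH_CONFIG_NO_UINT128 defined, e.g. via
// -DCATCH_CONFIG_NO_UINT128, prevents CATCH_CONFIG_UINT128 from being
// defined above, so extendedMult() falls back to the portable
// implementation even when __uint128_t is available.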


namespace Catch {
namespace Detail {

@@ -46,59 +72,52 @@ namespace Catch {
}
};

/**
 * Returns 128 bit result of lhs * rhs using portable C++ code
 *
 * This implementation is almost twice as fast as naive long multiplication,
 * and unlike intrinsic-based approach, it supports constexpr evaluation.
 */
constexpr ExtendedMultResult<std::uint64_t>
extendedMultPortable(std::uint64_t lhs, std::uint64_t rhs) {
#define CarryBits( x ) ( x >> 32 )
#define Digits( x ) ( x & 0xFF'FF'FF'FF )

std::uint64_t lhs_low = Digits( lhs );
std::uint64_t rhs_low = Digits( rhs );
std::uint64_t low_low = ( lhs_low * rhs_low );
std::uint64_t high_high = CarryBits( lhs ) * CarryBits( rhs );

// We add in carry bits from low-low already
std::uint64_t high_low =
( CarryBits( lhs ) * rhs_low ) + CarryBits( low_low );
// Note that we can add only low bits from high_low, to avoid
// overflow with large inputs
std::uint64_t low_high =
( lhs_low * CarryBits( rhs ) ) + Digits( high_low );

return { high_high + CarryBits( high_low ) + CarryBits( low_high ),
( low_high << 32 ) | Digits( low_low ) };
#undef CarryBits
#undef Digits
}
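
// For example, multiplying the largest 64-bit value by 2 carries into
// the high word exactly once:
//   extendedMultPortable( 0xFFFF'FFFF'FFFF'FFFF, 2 )
//     == { 0x1, 0xFFFF'FFFF'FFFF'FFFE }
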
//! Returns 128 bit result of lhs * rhs
inline ExtendedMultResult<std::uint64_t>
extendedMult( std::uint64_t lhs, std::uint64_t rhs ) {
#if defined( CATCH_CONFIG_UINT128 )
auto result = __uint128_t( lhs ) * __uint128_t( rhs );
return { static_cast<std::uint64_t>( result >> 64 ),
static_cast<std::uint64_t>( result ) };
#elif defined( CATCH_CONFIG_MSVC_UMUL128 )
std::uint64_t high;
std::uint64_t low = _umul128( lhs, rhs, &high );
return { high, low };
#else
return extendedMultPortable( lhs, rhs );
#endif
}


template <typename UInt>
constexpr ExtendedMultResult<UInt> extendedMult( UInt lhs, UInt rhs ) {
static_assert( std::is_unsigned<UInt>::value,
74 changes: 74 additions & 0 deletions tests/SelfTest/IntrospectiveTests/Integer.tests.cpp
@@ -8,6 +8,7 @@

#include <catch2/catch_test_macros.hpp>
#include <catch2/internal/catch_random_integer_helpers.hpp>
#include <random>

namespace {
template <typename Int>
@@ -20,6 +21,58 @@ namespace {
CHECK( extendedMult( b, a ) ==
ExtendedMultResult<Int>{ upper_result, lower_result } );
}

// Simple (and slow) implementation of extended multiplication for tests
constexpr Catch::Detail::ExtendedMultResult<std::uint64_t>
extendedMultNaive( std::uint64_t lhs, std::uint64_t rhs ) {
// This is a simple long multiplication, where we split lhs and rhs
// into two 32-bit "digits", so that we can do ops with carry in 64-bits.
//
// 32b 32b 32b 32b
// lhs L1 L2
// * rhs R1 R2
// ------------------------
// | R2 * L2 |
// | R2 * L1 |
// | R1 * L2 |
// | R1 * L1 |
// -------------------------
// | a | b | c | d |

#define CarryBits( x ) ( x >> 32 )
#define Digits( x ) ( x & 0xFF'FF'FF'FF )

auto r2l2 = Digits( rhs ) * Digits( lhs );
auto r2l1 = Digits( rhs ) * CarryBits( lhs );
auto r1l2 = CarryBits( rhs ) * Digits( lhs );
auto r1l1 = CarryBits( rhs ) * CarryBits( lhs );

// Sum to columns first
auto d = Digits( r2l2 );
auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 );
auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 );
auto a = CarryBits( r1l1 );

// Propagate carries between columns
c += CarryBits( d );
b += CarryBits( c );
a += CarryBits( b );

// Remove the used carries
c = Digits( c );
b = Digits( b );
a = Digits( a );

#undef CarryBits
#undef Digits

return {
a << 32 | b, // upper 64 bits
c << 32 | d // lower 64 bits
};
}


} // namespace

TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
@@ -62,6 +115,27 @@ TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) {
0xdf44'2d22'ce48'59b9 );
}

TEST_CASE("extendedMult 64x64 - all implementations", "[integer][approvals]") {
using Catch::Detail::extendedMult;
using Catch::Detail::extendedMultPortable;
using Catch::Detail::fillBitsFrom;

std::random_device rng;
for (size_t i = 0; i < 100; ++i) {
auto a = fillBitsFrom<std::uint64_t>( rng );
auto b = fillBitsFrom<std::uint64_t>( rng );
CAPTURE( a, b );

auto naive_ab = extendedMultNaive( a, b );

REQUIRE( naive_ab == extendedMultNaive( b, a ) );
REQUIRE( naive_ab == extendedMultPortable( a, b ) );
REQUIRE( naive_ab == extendedMultPortable( b, a ) );
REQUIRE( naive_ab == extendedMult( a, b ) );
REQUIRE( naive_ab == extendedMult( b, a ) );
}
}

TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) {
using Catch::Detail::SizedUnsignedType_t;
using Catch::Detail::DoubleWidthUnsignedType_t;
