/* * Copyright (c) 2021, Leon Albrecht * * SPDX-License-Identifier: BSD-2-Clause */ #pragma once #include #include #include #include #include #include #include #include namespace AK { template requires(sizeof(T) >= sizeof(u64) && IsUnsigned) class UFixedBigInt; // FIXME: This breaks formatting // template // constexpr inline bool Detail::IsIntegral> = true; template constexpr inline bool IsUnsigned> = true; template constexpr inline bool IsSigned> = false; template struct NumericLimits> { static constexpr UFixedBigInt min() { return 0; } static constexpr UFixedBigInt max() { return { NumericLimits::max(), NumericLimits::max() }; } static constexpr bool is_signed() { return false; } }; template struct UFixedBigIntMultiplicationResult { T low; T high; }; template requires(sizeof(T) >= sizeof(u64) && IsUnsigned) class UFixedBigInt { public: using R = UFixedBigInt; constexpr UFixedBigInt() = default; template requires(sizeof(T) >= sizeof(U)) constexpr UFixedBigInt(U low) : m_low(low) , m_high(0u) { } template requires(sizeof(T) >= sizeof(U) && sizeof(T) >= sizeof(U2)) constexpr UFixedBigInt(U low, U2 high) : m_low(low) , m_high(high) { } constexpr T& low() { return m_low; } constexpr T const& low() const { return m_low; } constexpr T& high() { return m_high; } constexpr T const& high() const { return m_high; } Span bytes() { return Span(reinterpret_cast(this), sizeof(R)); } Span bytes() const { return Span(reinterpret_cast(this), sizeof(R)); } template requires(sizeof(T) >= sizeof(U)) constexpr explicit operator U() const { return static_cast(m_low); } // Utils constexpr size_t clz() const requires(IsSame) { if (m_high) return count_leading_zeroes(m_high); else return sizeof(T) * 8 + count_leading_zeroes(m_low); } constexpr size_t clz() const requires(!IsSame) { if (m_high) return m_high.clz(); else return sizeof(T) * 8 + m_low.clz(); } constexpr size_t ctz() const requires(IsSame) { if (m_low) return count_trailing_zeroes(m_low); else return sizeof(T) * 8 + count_trailing_zeroes(m_high); } constexpr size_t ctz() const requires(!IsSame) { if (m_low) return m_low.ctz(); else return sizeof(T) * 8 + m_high.ctz(); } constexpr size_t popcnt() const requires(IsSame) { return __builtin_popcntll(m_low) + __builtin_popcntll(m_high); } constexpr size_t popcnt() const requires(!IsSame) { return m_low.popcnt() + m_high.popcnt(); } // Comparison Operations constexpr bool operator!() const { return !m_low && !m_high; } constexpr explicit operator bool() const { return m_low || m_high; } template requires(sizeof(T) >= sizeof(U)) constexpr bool operator==(U const& other) const { return !m_high && m_low == other; } template requires(sizeof(T) >= sizeof(U)) constexpr bool operator!=(U const& other) const { return m_high || m_low != other; } template requires(sizeof(T) >= sizeof(U)) constexpr bool operator>(U const& other) const { return m_high || m_low > other; } template requires(sizeof(T) >= sizeof(U)) constexpr bool operator<(U const& other) const { return !m_high && m_low < other; } template requires(sizeof(T) >= sizeof(U)) constexpr bool operator>=(U const& other) const { return *this == other || *this > other; } template requires(sizeof(T) >= sizeof(U)) constexpr bool operator<=(U const& other) const { return *this == other || *this < other; } constexpr bool operator==(R const& other) const { return m_low == other.low() && m_high == other.high(); } constexpr bool operator!=(R const& other) const { return m_low != other.low() || m_high != other.high(); } constexpr bool operator>(R const& other) const { return m_high > other.high() || (m_high == other.high() && m_low > other.low()); } constexpr bool operator<(R const& other) const { return m_high < other.high() || (m_high == other.high() && m_low < other.low()); } constexpr bool operator>=(R const& other) const { return *this == other || *this > other; } constexpr bool operator<=(R const& other) const { return *this == other || *this < other; } // Bitwise operations constexpr R operator~() const { return { ~m_low, ~m_high }; } template requires(sizeof(T) >= sizeof(U)) constexpr U operator&(U const& other) const { return static_cast(m_low) & other; } template requires(sizeof(T) >= sizeof(U)) constexpr R operator|(U const& other) const { return { m_low | other, m_high }; } template requires(sizeof(T) >= sizeof(U)) constexpr R operator^(U const& other) const { return { m_low ^ other, m_high }; } template constexpr R operator<<(U const& shift) const { if (shift >= sizeof(R) * 8u) return 0u; if (shift >= sizeof(T) * 8u) return R { 0u, m_low << (shift - sizeof(T) * 8u) }; if (!shift) return *this; T overflow = m_low >> (sizeof(T) * 8u - shift); return R { m_low << shift, (m_high << shift) | overflow }; } template constexpr R operator>>(U const& shift) const { if (shift >= sizeof(R) * 8u) return 0u; if (shift >= sizeof(T) * 8u) return m_high >> (shift - sizeof(T) * 8u); if (!shift) return *this; T underflow = m_high << (sizeof(T) * 8u - shift); return R { (m_low >> shift) | underflow, m_high >> shift }; } template constexpr R rol(U const& shift) const { return (*this >> sizeof(T) * 8u - shift) | (*this << shift); } template constexpr R ror(U const& shift) const { return (*this << sizeof(T) * 8u - shift) | (*this >> shift); } constexpr R operator&(R const& other) const { return { m_low & other.low(), m_high & other.high() }; } constexpr R operator|(R const& other) const { return { m_low | other.low(), m_high | other.high() }; } constexpr R operator^(R const& other) const { return { m_low ^ other.low(), m_high ^ other.high() }; } // Bitwise assignment template requires(sizeof(T) >= sizeof(U)) constexpr R& operator&=(U const& other) { m_high = 0u; m_low &= other; return *this; } template requires(sizeof(T) >= sizeof(U)) constexpr R& operator|=(U const& other) { m_low |= other; return *this; } template requires(sizeof(T) >= sizeof(U)) constexpr R& operator^=(U const& other) { m_low ^= other; return *this; } template constexpr R& operator>>=(U const& other) { *this = *this >> other; return *this; } template constexpr R& operator<<=(U const& other) { *this = *this << other; return *this; } constexpr R& operator&=(R const& other) { m_high &= other.high(); m_low &= other.low(); return *this; } constexpr R& operator|=(R const& other) { m_high |= other.high(); m_low |= other.low(); return *this; } constexpr R& operator^=(R const& other) { m_high ^= other.high(); m_low ^= other.low(); return *this; } static constexpr size_t my_size() { return sizeof(R); } // Arithmetic // implies size of less than u64, so passing references isn't useful template requires(sizeof(T) >= sizeof(U) && IsSame) constexpr R addc(const U other, bool& carry) const { bool low_carry = Checked::addition_would_overflow(m_low, other); low_carry |= Checked::addition_would_overflow(m_low, carry); bool high_carry = Checked::addition_would_overflow(m_high, low_carry); T lower = m_low + other + carry; T higher = m_high + low_carry; carry = high_carry; return { lower, higher }; } template requires(my_size() > sizeof(U) && sizeof(T) > sizeof(u64)) constexpr R addc(U const& other, bool& carry) const { T lower = m_low.addc(other, carry); T higher = m_high.addc(0u, carry); return { lower, higher }; } template requires(IsSame && IsSame) constexpr R addc(U const& other, bool& carry) const { bool low_carry = Checked::addition_would_overflow(m_low, other.low()); bool high_carry = Checked::addition_would_overflow(m_high, other.high()); T lower = m_low + other.low(); T higher = m_high + other.high(); low_carry |= Checked::addition_would_overflow(lower, carry); high_carry |= Checked::addition_would_overflow(higher, low_carry); lower += carry; higher += low_carry; carry = high_carry; return { lower, higher }; } template requires(IsSame && sizeof(T) > sizeof(u64)) constexpr R addc(U const& other, bool& carry) const { T lower = m_low.addc(other.low(), carry); T higher = m_high.addc(other.high(), carry); return { lower, higher }; } template requires(my_size() < sizeof(U)) constexpr U addc(U const& other, bool& carry) const { return other.addc(*this, carry); } // FIXME: subc for sizeof(T) < sizeof(U) template requires(sizeof(T) >= sizeof(U)) constexpr R subc(U const& other, bool& carry) const { bool low_carry = (!m_low && carry) || (m_low - carry) < other; bool high_carry = !m_high && low_carry; T lower = m_low - other - carry; T higher = m_high - low_carry; carry = high_carry; return { lower, higher }; } constexpr R subc(R const& other, bool& carry) const { bool low_carry = (!m_low && carry) || (m_low - carry) < other.low(); bool high_carry = (!m_high && low_carry) || (m_high - low_carry) < other.high(); T lower = m_low - other.low() - carry; T higher = m_high - other.high() - low_carry; carry = high_carry; return { lower, higher }; } constexpr R operator+(bool const& other) const { bool carry = false; // unused return addc((u8)other, carry); } template constexpr R operator+(U const& other) const { bool carry = false; // unused return addc(other, carry); } constexpr R operator-(bool const& other) const { bool carry = false; // unused return subc((u8)other, carry); } template constexpr R operator-(U const& other) const { bool carry = false; // unused return subc(other, carry); } template constexpr R& operator+=(U const& other) { *this = *this + other; return *this; } template constexpr R& operator-=(U const& other) { *this = *this - other; return *this; } constexpr R operator++() { // x++ auto old = *this; *this += 1; return old; } constexpr R& operator++(int) { // ++x *this += 1; return *this; } constexpr R operator--() { // x-- auto old = *this; *this -= 1; return old; } constexpr R& operator--(int) { // --x *this -= 1; return *this; } // FIXME: no restraints on this template requires(my_size() >= sizeof(U)) constexpr R div_mod(U const& divisor, U& remainder) const { // FIXME: Is there a better way to raise a division by 0? // Maybe as a compiletime warning? #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wdiv-by-zero" if (!divisor) { int volatile x = 1; int volatile y = 0; [[maybe_unused]] int volatile z = x / y; } #pragma GCC diagnostic pop // fastpaths if (*this < divisor) { remainder = static_cast(*this); return 0u; } if (*this == divisor) { remainder = 0u; return 1u; } if (divisor == 1u) { remainder = 0u; return *this; } remainder = 0u; R quotient = 0u; for (ssize_t i = sizeof(R) * 8 - clz() - 1; i >= 0; --i) { remainder <<= 1u; remainder |= (*this >> (size_t)i) & 1u; if (remainder >= divisor) { remainder -= divisor; quotient |= R { 1u } << (size_t)i; } } return quotient; } template constexpr R operator*(U other) const { R res = 0u; R that = *this; for (; other != 0u; other >>= 1u) { if (other & 1u) res += that; that <<= 1u; } return res; } template requires(IsSame && IsSame) constexpr UFixedBigIntMultiplicationResult wide_multiply(U const& other) const { auto mult_64_to_128 = [](u64 a, u64 b) -> UFixedBigIntMultiplicationResult { #ifdef __SIZEOF_INT128__ unsigned __int128 result = (unsigned __int128)a * b; u64 low = result; u64 high = result >> 64; return { low, high }; #else u32 a_low = a; u32 a_high = (a >> 32); u32 b_low = b; u32 b_high = (b >> 32); u64 ll_result = (u64)a_low * b_low; u64 lh_result = (u64)a_low * b_high; u64 hl_result = (u64)a_high * b_low; u64 hh_result = (u64)a_high * b_high; UFixedBigInt ll { ll_result, 0u }; UFixedBigInt lh { lh_result << 32, lh_result >> 32 }; UFixedBigInt hl { hl_result << 32, hl_result >> 32 }; UFixedBigInt hh { 0u, hh_result }; UFixedBigInt result = ll + lh + hl + hh; return { result.low(), result.high() }; #endif }; auto ll_result = mult_64_to_128(m_low, other.low()); auto lh_result = mult_64_to_128(m_low, other.high()); auto hl_result = mult_64_to_128(m_high, other.low()); auto hh_result = mult_64_to_128(m_high, other.high()); UFixedBigInt ll { R { ll_result.low, ll_result.high }, R { 0u, 0u } }; UFixedBigInt lh { R { 0u, lh_result.low }, R { lh_result.high, 0u } }; UFixedBigInt hl { R { 0u, hl_result.low }, R { hl_result.high, 0u } }; UFixedBigInt hh { R { 0u, 0u }, R { hh_result.low, hh_result.high } }; UFixedBigInt result = ll + lh + hl + hh; return { result.low(), result.high() }; } template requires(IsSame && sizeof(T) > sizeof(u64)) constexpr UFixedBigIntMultiplicationResult wide_multiply(U const& other) const { T left_low = m_low; T left_high = m_high; T right_low = other.low(); T right_high = other.high(); auto ll_result = left_low.wide_multiply(right_low); auto lh_result = left_low.wide_multiply(right_high); auto hl_result = left_high.wide_multiply(right_low); auto hh_result = left_high.wide_multiply(right_high); UFixedBigInt ll { R { ll_result.low, ll_result.high }, R { 0u, 0u } }; UFixedBigInt lh { R { 0u, lh_result.low }, R { lh_result.high, 0u } }; UFixedBigInt hl { R { 0u, hl_result.low }, R { hl_result.high, 0u } }; UFixedBigInt hh { R { 0u, 0u }, R { hh_result.low, hh_result.high } }; UFixedBigInt result = ll + lh + hl + hh; return { result.low(), result.high() }; } template constexpr R operator/(U const& other) const { U mod { 0u }; // unused return div_mod(other, mod); } template constexpr U operator%(U const& other) const { R res { 0u }; div_mod(other, res); return res; } template constexpr R& operator*=(U const& other) { *this = *this * other; return *this; } template constexpr R& operator/=(U const& other) { *this = *this / other; return *this; } template constexpr R& operator%=(U const& other) { *this = *this % other; return *this; } constexpr R sqrt() const { // Bitwise method: https://en.wikipedia.org/wiki/Integer_square_root#Using_bitwise_operations // the bitwise method seems to be way faster then Newtons: // https://quick-bench.com/q/eXZwW1DVhZxLE0llumeCXkfOK3Q if (*this == 1u) return 1u; ssize_t shift = (sizeof(R) * 8 - clz()) & ~1ULL; // should be equivalent to: // long shift = 2; // while ((val >> shift) != 0) // shift += 2; R res = 0u; while (shift >= 0) { res = res << 1u; R large_cand = (res | 1u); if (*this >> (size_t)shift >= large_cand * large_cand) res = large_cand; shift -= 2; } return res; } constexpr R pow(u64 exp) { // Montgomery's Ladder Technique // https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Montgomery's_ladder_technique R x1 = *this; R x2 = *this * *this; u64 exp_copy = exp; for (ssize_t i = sizeof(u64) * 8 - count_leading_zeroes(exp) - 2; i >= 0; --i) { if (exp_copy & 1u) { x2 *= x1; x1 *= x1; } else { x1 *= x2; x2 *= x2; } exp_copy >>= 1u; } return x1; } template requires(sizeof(U) > sizeof(u64)) constexpr R pow(U exp) { // Montgomery's Ladder Technique // https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Montgomery's_ladder_technique R x1 = *this; R x2 = *this * *this; U exp_copy = exp; for (ssize_t i = sizeof(U) * 8 - exp().clz() - 2; i >= 0; --i) { if (exp_copy & 1u) { x2 *= x1; x1 *= x1; } else { x1 *= x2; x2 *= x2; } exp_copy >>= 1u; } return x1; } template constexpr U pow_mod(u64 exp, U mod) { // Left to right binary method: // https://en.wikipedia.org/wiki/Modular_exponentiation#Left-to-right_binary_method // FIXME: this is not sidechanel proof if (!mod) return 0u; U res = 1; u64 exp_copy = exp; for (size_t i = sizeof(u64) - count_leading_zeroes(exp) - 1u; i < exp; ++i) { res *= res; res %= mod; if (exp_copy & 1u) { res = (*this * res) % mod; } exp_copy >>= 1u; } return res; } template requires(sizeof(ExpT) > sizeof(u64)) constexpr U pow_mod(ExpT exp, U mod) { // Left to right binary method: // https://en.wikipedia.org/wiki/Modular_exponentiation#Left-to-right_binary_method // FIXME: this is not side channel proof if (!mod) return 0u; U res = 1; ExpT exp_copy = exp; for (size_t i = sizeof(ExpT) - exp.clz() - 1u; i < exp; ++i) { res *= res; res %= mod; if (exp_copy & 1u) { res = (*this * res) % mod; } exp_copy >>= 1u; } return res; } constexpr size_t log2() { // FIXME: proper rounding return sizeof(R) - clz(); } constexpr size_t logn(u64 base) { // FIXME: proper rounding return log2() / (sizeof(u64) - count_leading_zeroes(base)); } template requires(sizeof(U) > sizeof(u64)) constexpr size_t logn(U base) { // FIXME: proper rounding return log2() / base.log2(); } constexpr u64 fold_or() const requires(IsSame) { return m_low | m_high; } constexpr u64 fold_or() const requires(!IsSame) { return m_low.fold_or() | m_high.fold_or(); } constexpr bool is_zero_constant_time() const { return fold_or() == 0; } constexpr u64 fold_xor_pair(R& other) const requires(IsSame) { return (m_low ^ other.low()) | (m_high ^ other.high()); } constexpr u64 fold_xor_pair(R& other) const requires(!IsSame) { return (m_low.fold_xor_pair(other.low())) | (m_high.fold_xor_pair(other.high())); } constexpr bool is_equal_to_constant_time(R& other) { return fold_xor_pair(other) == 0; } private: T m_low; T m_high; }; // reverse operators template requires(sizeof(U) < sizeof(T) * 2) constexpr bool operator<(const U a, UFixedBigInt const& b) { return b >= a; } template requires(sizeof(U) < sizeof(T) * 2) constexpr bool operator>(const U a, UFixedBigInt const& b) { return b <= a; } template requires(sizeof(U) < sizeof(T) * 2) constexpr bool operator<=(const U a, UFixedBigInt const& b) { return b > a; } template requires(sizeof(U) < sizeof(T) * 2) constexpr bool operator>=(const U a, UFixedBigInt const& b) { return b < a; } template struct Formatter> : StandardFormatter { Formatter() = default; explicit Formatter(StandardFormatter formatter) : StandardFormatter(formatter) { } ErrorOr format(FormatBuilder& builder, UFixedBigInt value) { if (m_precision.has_value()) VERIFY_NOT_REACHED(); if (m_mode == Mode::Pointer) { // these are way to big for a pointer VERIFY_NOT_REACHED(); } if (m_mode == Mode::Default) m_mode = Mode::Hexadecimal; if (!value.high()) { Formatter formatter { *this }; return formatter.format(builder, value.low()); } u8 base = 0; if (m_mode == Mode::Binary) { base = 2; } else if (m_mode == Mode::BinaryUppercase) { base = 2; } else if (m_mode == Mode::Octal) { TODO(); } else if (m_mode == Mode::Decimal) { TODO(); } else if (m_mode == Mode::Hexadecimal) { base = 16; } else if (m_mode == Mode::HexadecimalUppercase) { base = 16; } else { VERIFY_NOT_REACHED(); } ssize_t width = m_width.value_or(0); ssize_t lower_length = ceil_div(sizeof(T) * 8, (ssize_t)base); Formatter formatter { *this }; formatter.m_width = max(width - lower_length, (ssize_t)0); TRY(formatter.format(builder, value.high())); TRY(builder.put_literal("'"sv)); formatter.m_zero_pad = true; formatter.m_alternative_form = false; formatter.m_width = lower_length; TRY(formatter.format(builder, value.low())); return {}; } }; } // Nit: Doing these as custom classes might be faster, especially when writing // then in SSE, but this would cause a lot of Code duplication and due to // the nature of constexprs and the intelligence of the compiler they might // be using SSE/MMX either way // these sizes should suffice for most usecases using u128 = AK::UFixedBigInt; using u256 = AK::UFixedBigInt; using u512 = AK::UFixedBigInt; using u1024 = AK::UFixedBigInt; using u2048 = AK::UFixedBigInt; using u4096 = AK::UFixedBigInt;