diff --git a/rational/rational_number.hpp b/rational/rational_number.hpp index cf65aad3..7400a792 100644 --- a/rational/rational_number.hpp +++ b/rational/rational_number.hpp @@ -1,26 +1,29 @@ #pragma once +#include #include // Rational number + {infinity(1 / 0), -infiity(-1 / 0), nan(0 / 0)} (有理数) -// Do not compare any number with nan // Verified: Yandex Cup 2022 Final E https://contest.yandex.com/contest/42710/problems/K -template struct Rational { - Int num, den; +template struct Rational { + Int num, den; // den >= 0 + static constexpr Int my_gcd(Int a, Int b) { // return __gcd(a, b); if (a < 0) a = -a; if (b < 0) b = -b; while (a and b) { - if (a > b) + if (a > b) { a %= b; - else + } else { b %= a; + } } return a + b; } + constexpr Rational(Int num = 0, Int den = 1) : num(num), den(den) { normalize(); } constexpr void normalize() noexcept { - if constexpr (AlwaysReduce) { // reduction + if constexpr (AutoReduce) { // reduction Int g = my_gcd(num, den); if (g) num /= g, den /= g; } else { @@ -31,12 +34,16 @@ template struct Rational { } if (den < 0) num = -num, den = -den; // denominator >= 0 } + + constexpr bool is_finite() const noexcept { return den != 0; } + constexpr bool is_infinite_or_nan() const noexcept { return den == 0; } + constexpr Rational operator+(const Rational &r) const noexcept { - if (!den and !r.den) return Rational(num + r.num, den); + if (is_infinite_or_nan() and r.is_infinite_or_nan()) return Rational(num + r.num, 0); return Rational(num * r.den + den * r.num, den * r.den); } constexpr Rational operator-(const Rational &r) const noexcept { - if (!den and !r.den) return Rational(num - r.num, den); + if (is_infinite_or_nan() and r.is_infinite_or_nan()) return Rational(num - r.num, 0); return Rational(num * r.den - den * r.num, den * r.den); } constexpr Rational operator*(const Rational &r) const noexcept { @@ -51,23 +58,36 @@ template struct Rational { constexpr Rational &operator/=(const Rational &r) noexcept { return *this = *this / r; } constexpr Rational operator-() const noexcept { return Rational(-num, den); } constexpr Rational abs() const noexcept { return Rational(num > 0 ? num : -num, den); } + + constexpr Int floor() const { + assert(is_finite()); + if (num > 0) { + return num / den; + } else { + return -((-num + den - 1) / den); + } + } + constexpr bool operator==(const Rational &r) const noexcept { - if constexpr (AlwaysReduce) { + if (is_infinite_or_nan() or r.is_infinite_or_nan()) { return num == r.num and den == r.den; } else { return num * r.den == r.num * den; } } + constexpr bool operator!=(const Rational &r) const noexcept { return !(*this == r); } + constexpr bool operator<(const Rational &r) const noexcept { - if (den == 0 and r.den == 0) + if (is_infinite_or_nan() and r.is_infinite_or_nan()) return num < r.num; - else if (den == 0) + else if (is_infinite_or_nan()) { return num < 0; - else if (r.den == 0) + } else if (r.is_infinite_or_nan()) { return r.num > 0; - else + } else { return num * r.den < den * r.num; + } } constexpr bool operator<=(const Rational &r) const noexcept { return (*this == r) or (*this < r); @@ -76,6 +96,7 @@ template struct Rational { constexpr bool operator>=(const Rational &r) const noexcept { return (r == *this) or (r < *this); } + constexpr explicit operator double() const noexcept { return (double)num / (double)den; } constexpr explicit operator long double() const noexcept { return (long double)num / (long double)den; diff --git a/utilities/floor_sum.hpp b/utilities/floor_sum.hpp index 451226af..ada8ab05 100644 --- a/utilities/floor_sum.hpp +++ b/utilities/floor_sum.hpp @@ -1,21 +1,23 @@ #pragma once #include -// CUT begin // \sum_{i=0}^{n-1} floor((ai + b) / m) -// 0 <= n < 2e32 -// 1 <= m < 2e32 +// 0 <= n < 2e32 (if Int is long long) +// 1 <= m < 2e32 (if Int is long long) // 0 <= a, b < m // Complexity: O(lg(m)) -long long floor_sum(long long n, long long m, long long a, long long b) { - auto safe_mod = [](long long x, long long m) -> long long { +template Int floor_sum(Int n, Int m, Int a, Int b) { + static_assert(-Int(1) < 0, "Int must be signed"); + static_assert(-Unsigned(1) > 0, "Unsigned must be unsigned"); + static_assert(sizeof(Unsigned) >= sizeof(Int), "Unsigned must be larger than Int"); + + auto safe_mod = [](Int x, Int m) -> Int { x %= m; if (x < 0) x += m; return x; }; - auto floor_sum_unsigned = [](unsigned long long n, unsigned long long m, unsigned long long a, - unsigned long long b) -> unsigned long long { - unsigned long long ans = 0; + auto floor_sum_unsigned = [](Unsigned n, Unsigned m, Unsigned a, Unsigned b) -> Unsigned { + Unsigned ans = 0; while (true) { if (a >= m) { ans += n * (n - 1) / 2 * (a / m); @@ -26,26 +28,26 @@ long long floor_sum(long long n, long long m, long long a, long long b) { b %= m; } - unsigned long long y_max = a * n + b; + Unsigned y_max = a * n + b; if (y_max < m) break; // y_max < m * (n + 1) // floor(y_max / m) <= n - n = (unsigned long long)(y_max / m); - b = (unsigned long long)(y_max % m); + n = (Unsigned)(y_max / m); + b = (Unsigned)(y_max % m); std::swap(m, a); } return ans; }; - unsigned long long ans = 0; + Unsigned ans = 0; if (a < 0) { - unsigned long long a2 = safe_mod(a, m); - ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m); + Unsigned a2 = safe_mod(a, m); + ans -= Unsigned(1) * n * (n - 1) / 2 * ((a2 - a) / m); a = a2; } if (b < 0) { - unsigned long long b2 = safe_mod(b, m); - ans -= 1ULL * n * ((b2 - b) / m); + Unsigned b2 = safe_mod(b, m); + ans -= Unsigned(1) * n * ((b2 - b) / m); b = b2; } return ans + floor_sum_unsigned(n, m, a, b); diff --git a/utilities/test/floor_sum.test.cpp b/utilities/test/floor_sum.test.cpp index 629700ef..b369d8c0 100644 --- a/utilities/test/floor_sum.test.cpp +++ b/utilities/test/floor_sum.test.cpp @@ -10,6 +10,6 @@ int main() { while (T--) { int N, M, A, B; cin >> N >> M >> A >> B; - cout << floor_sum(N, M, A, B) << '\n'; + cout << floor_sum(N, M, A, B) << '\n'; } }