Skip to content

Commit

Permalink
Merge pull request #343 from hitonanode/fix-abc372g
Browse files Browse the repository at this point in the history
ABC372G で使いにくかった点の修正
  • Loading branch information
hitonanode authored Sep 22, 2024
2 parents 879824a + 06ed22b commit 7eaef17
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 30 deletions.
47 changes: 34 additions & 13 deletions rational/rational_number.hpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
#pragma once
#include <cassert>
#include <limits>

// Rational number + {infinity(1 / 0), -infiity(-1 / 0), nan(0 / 0)} (有理数)
// Do not compare any number with nan
// Verified: Yandex Cup 2022 Final E https://contest.yandex.com/contest/42710/problems/K
template <class Int, bool AlwaysReduce = false> struct Rational {
Int num, den;
template <class Int, bool AutoReduce = false> struct Rational {
Int num, den; // den >= 0

static constexpr Int my_gcd(Int a, Int b) {
// return __gcd(a, b);
if (a < 0) a = -a;
if (b < 0) b = -b;
while (a and b) {
if (a > b)
if (a > b) {
a %= b;
else
} else {
b %= a;
}
}
return a + b;
}

constexpr Rational(Int num = 0, Int den = 1) : num(num), den(den) { normalize(); }
constexpr void normalize() noexcept {
if constexpr (AlwaysReduce) { // reduction
if constexpr (AutoReduce) { // reduction
Int g = my_gcd(num, den);
if (g) num /= g, den /= g;
} else {
Expand All @@ -31,12 +34,16 @@ template <class Int, bool AlwaysReduce = false> struct Rational {
}
if (den < 0) num = -num, den = -den; // denominator >= 0
}

constexpr bool is_finite() const noexcept { return den != 0; }
constexpr bool is_infinite_or_nan() const noexcept { return den == 0; }

constexpr Rational operator+(const Rational &r) const noexcept {
if (!den and !r.den) return Rational(num + r.num, den);
if (is_infinite_or_nan() and r.is_infinite_or_nan()) return Rational(num + r.num, 0);
return Rational(num * r.den + den * r.num, den * r.den);
}
constexpr Rational operator-(const Rational &r) const noexcept {
if (!den and !r.den) return Rational(num - r.num, den);
if (is_infinite_or_nan() and r.is_infinite_or_nan()) return Rational(num - r.num, 0);
return Rational(num * r.den - den * r.num, den * r.den);
}
constexpr Rational operator*(const Rational &r) const noexcept {
Expand All @@ -51,23 +58,36 @@ template <class Int, bool AlwaysReduce = false> struct Rational {
constexpr Rational &operator/=(const Rational &r) noexcept { return *this = *this / r; }
constexpr Rational operator-() const noexcept { return Rational(-num, den); }
constexpr Rational abs() const noexcept { return Rational(num > 0 ? num : -num, den); }

constexpr Int floor() const {
assert(is_finite());
if (num > 0) {
return num / den;
} else {
return -((-num + den - 1) / den);
}
}

constexpr bool operator==(const Rational &r) const noexcept {
if constexpr (AlwaysReduce) {
if (is_infinite_or_nan() or r.is_infinite_or_nan()) {
return num == r.num and den == r.den;
} else {
return num * r.den == r.num * den;
}
}

constexpr bool operator!=(const Rational &r) const noexcept { return !(*this == r); }

constexpr bool operator<(const Rational &r) const noexcept {
if (den == 0 and r.den == 0)
if (is_infinite_or_nan() and r.is_infinite_or_nan())
return num < r.num;
else if (den == 0)
else if (is_infinite_or_nan()) {
return num < 0;
else if (r.den == 0)
} else if (r.is_infinite_or_nan()) {
return r.num > 0;
else
} else {
return num * r.den < den * r.num;
}
}
constexpr bool operator<=(const Rational &r) const noexcept {
return (*this == r) or (*this < r);
Expand All @@ -76,6 +96,7 @@ template <class Int, bool AlwaysReduce = false> struct Rational {
constexpr bool operator>=(const Rational &r) const noexcept {
return (r == *this) or (r < *this);
}

constexpr explicit operator double() const noexcept { return (double)num / (double)den; }
constexpr explicit operator long double() const noexcept {
return (long double)num / (long double)den;
Expand Down
34 changes: 18 additions & 16 deletions utilities/floor_sum.hpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
#pragma once
#include <utility>

// CUT begin
// \sum_{i=0}^{n-1} floor((ai + b) / m)
// 0 <= n < 2e32
// 1 <= m < 2e32
// 0 <= n < 2e32 (if Int is long long)
// 1 <= m < 2e32 (if Int is long long)
// 0 <= a, b < m
// Complexity: O(lg(m))
long long floor_sum(long long n, long long m, long long a, long long b) {
auto safe_mod = [](long long x, long long m) -> long long {
template <class Int, class Unsigned> Int floor_sum(Int n, Int m, Int a, Int b) {
static_assert(-Int(1) < 0, "Int must be signed");
static_assert(-Unsigned(1) > 0, "Unsigned must be unsigned");
static_assert(sizeof(Unsigned) >= sizeof(Int), "Unsigned must be larger than Int");

auto safe_mod = [](Int x, Int m) -> Int {
x %= m;
if (x < 0) x += m;
return x;
};
auto floor_sum_unsigned = [](unsigned long long n, unsigned long long m, unsigned long long a,
unsigned long long b) -> unsigned long long {
unsigned long long ans = 0;
auto floor_sum_unsigned = [](Unsigned n, Unsigned m, Unsigned a, Unsigned b) -> Unsigned {
Unsigned ans = 0;
while (true) {
if (a >= m) {
ans += n * (n - 1) / 2 * (a / m);
Expand All @@ -26,26 +28,26 @@ long long floor_sum(long long n, long long m, long long a, long long b) {
b %= m;
}

unsigned long long y_max = a * n + b;
Unsigned y_max = a * n + b;
if (y_max < m) break;
// y_max < m * (n + 1)
// floor(y_max / m) <= n
n = (unsigned long long)(y_max / m);
b = (unsigned long long)(y_max % m);
n = (Unsigned)(y_max / m);
b = (Unsigned)(y_max % m);
std::swap(m, a);
}
return ans;
};

unsigned long long ans = 0;
Unsigned ans = 0;
if (a < 0) {
unsigned long long a2 = safe_mod(a, m);
ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m);
Unsigned a2 = safe_mod(a, m);
ans -= Unsigned(1) * n * (n - 1) / 2 * ((a2 - a) / m);
a = a2;
}
if (b < 0) {
unsigned long long b2 = safe_mod(b, m);
ans -= 1ULL * n * ((b2 - b) / m);
Unsigned b2 = safe_mod(b, m);
ans -= Unsigned(1) * n * ((b2 - b) / m);
b = b2;
}
return ans + floor_sum_unsigned(n, m, a, b);
Expand Down
2 changes: 1 addition & 1 deletion utilities/test/floor_sum.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ int main() {
while (T--) {
int N, M, A, B;
cin >> N >> M >> A >> B;
cout << floor_sum(N, M, A, B) << '\n';
cout << floor_sum<long long, unsigned long long>(N, M, A, B) << '\n';
}
}

0 comments on commit 7eaef17

Please sign in to comment.