Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ABC372G で使いにくかった点の修正 #343

Merged
merged 2 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions rational/rational_number.hpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
#pragma once
#include <cassert>
#include <limits>

// Rational number + {infinity(1 / 0), -infiity(-1 / 0), nan(0 / 0)} (有理数)
// Do not compare any number with nan
// Verified: Yandex Cup 2022 Final E https://contest.yandex.com/contest/42710/problems/K
template <class Int, bool AlwaysReduce = false> struct Rational {
Int num, den;
template <class Int, bool AutoReduce = false> struct Rational {
Int num, den; // den >= 0

static constexpr Int my_gcd(Int a, Int b) {
// return __gcd(a, b);
if (a < 0) a = -a;
if (b < 0) b = -b;
while (a and b) {
if (a > b)
if (a > b) {
a %= b;
else
} else {
b %= a;
}
}
return a + b;
}

constexpr Rational(Int num = 0, Int den = 1) : num(num), den(den) { normalize(); }
constexpr void normalize() noexcept {
if constexpr (AlwaysReduce) { // reduction
if constexpr (AutoReduce) { // reduction
Int g = my_gcd(num, den);
if (g) num /= g, den /= g;
} else {
Expand All @@ -31,12 +34,16 @@ template <class Int, bool AlwaysReduce = false> struct Rational {
}
if (den < 0) num = -num, den = -den; // denominator >= 0
}

constexpr bool is_finite() const noexcept { return den != 0; }
constexpr bool is_infinite_or_nan() const noexcept { return den == 0; }

constexpr Rational operator+(const Rational &r) const noexcept {
if (!den and !r.den) return Rational(num + r.num, den);
if (is_infinite_or_nan() and r.is_infinite_or_nan()) return Rational(num + r.num, 0);
return Rational(num * r.den + den * r.num, den * r.den);
}
constexpr Rational operator-(const Rational &r) const noexcept {
if (!den and !r.den) return Rational(num - r.num, den);
if (is_infinite_or_nan() and r.is_infinite_or_nan()) return Rational(num - r.num, 0);
return Rational(num * r.den - den * r.num, den * r.den);
}
constexpr Rational operator*(const Rational &r) const noexcept {
Expand All @@ -51,23 +58,36 @@ template <class Int, bool AlwaysReduce = false> struct Rational {
constexpr Rational &operator/=(const Rational &r) noexcept { return *this = *this / r; }
constexpr Rational operator-() const noexcept { return Rational(-num, den); }
constexpr Rational abs() const noexcept { return Rational(num > 0 ? num : -num, den); }

constexpr Int floor() const {
assert(is_finite());
if (num > 0) {
return num / den;
} else {
return -((-num + den - 1) / den);
}
}

constexpr bool operator==(const Rational &r) const noexcept {
if constexpr (AlwaysReduce) {
if (is_infinite_or_nan() or r.is_infinite_or_nan()) {
return num == r.num and den == r.den;
} else {
return num * r.den == r.num * den;
}
}

constexpr bool operator!=(const Rational &r) const noexcept { return !(*this == r); }

constexpr bool operator<(const Rational &r) const noexcept {
if (den == 0 and r.den == 0)
if (is_infinite_or_nan() and r.is_infinite_or_nan())
return num < r.num;
else if (den == 0)
else if (is_infinite_or_nan()) {
return num < 0;
else if (r.den == 0)
} else if (r.is_infinite_or_nan()) {
return r.num > 0;
else
} else {
return num * r.den < den * r.num;
}
}
constexpr bool operator<=(const Rational &r) const noexcept {
return (*this == r) or (*this < r);
Expand All @@ -76,6 +96,7 @@ template <class Int, bool AlwaysReduce = false> struct Rational {
constexpr bool operator>=(const Rational &r) const noexcept {
return (r == *this) or (r < *this);
}

constexpr explicit operator double() const noexcept { return (double)num / (double)den; }
constexpr explicit operator long double() const noexcept {
return (long double)num / (long double)den;
Expand Down
34 changes: 18 additions & 16 deletions utilities/floor_sum.hpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
#pragma once
#include <utility>

// CUT begin
// \sum_{i=0}^{n-1} floor((ai + b) / m)
// 0 <= n < 2e32
// 1 <= m < 2e32
// 0 <= n < 2e32 (if Int is long long)
// 1 <= m < 2e32 (if Int is long long)
// 0 <= a, b < m
// Complexity: O(lg(m))
long long floor_sum(long long n, long long m, long long a, long long b) {
auto safe_mod = [](long long x, long long m) -> long long {
template <class Int, class Unsigned> Int floor_sum(Int n, Int m, Int a, Int b) {
static_assert(-Int(1) < 0, "Int must be signed");
static_assert(-Unsigned(1) > 0, "Unsigned must be unsigned");
static_assert(sizeof(Unsigned) >= sizeof(Int), "Unsigned must be larger than Int");

auto safe_mod = [](Int x, Int m) -> Int {
x %= m;
if (x < 0) x += m;
return x;
};
auto floor_sum_unsigned = [](unsigned long long n, unsigned long long m, unsigned long long a,
unsigned long long b) -> unsigned long long {
unsigned long long ans = 0;
auto floor_sum_unsigned = [](Unsigned n, Unsigned m, Unsigned a, Unsigned b) -> Unsigned {
Unsigned ans = 0;
while (true) {
if (a >= m) {
ans += n * (n - 1) / 2 * (a / m);
Expand All @@ -26,26 +28,26 @@ long long floor_sum(long long n, long long m, long long a, long long b) {
b %= m;
}

unsigned long long y_max = a * n + b;
Unsigned y_max = a * n + b;
if (y_max < m) break;
// y_max < m * (n + 1)
// floor(y_max / m) <= n
n = (unsigned long long)(y_max / m);
b = (unsigned long long)(y_max % m);
n = (Unsigned)(y_max / m);
b = (Unsigned)(y_max % m);
std::swap(m, a);
}
return ans;
};

unsigned long long ans = 0;
Unsigned ans = 0;
if (a < 0) {
unsigned long long a2 = safe_mod(a, m);
ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m);
Unsigned a2 = safe_mod(a, m);
ans -= Unsigned(1) * n * (n - 1) / 2 * ((a2 - a) / m);
a = a2;
}
if (b < 0) {
unsigned long long b2 = safe_mod(b, m);
ans -= 1ULL * n * ((b2 - b) / m);
Unsigned b2 = safe_mod(b, m);
ans -= Unsigned(1) * n * ((b2 - b) / m);
b = b2;
}
return ans + floor_sum_unsigned(n, m, a, b);
Expand Down
2 changes: 1 addition & 1 deletion utilities/test/floor_sum.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ int main() {
while (T--) {
int N, M, A, B;
cin >> N >> M >> A >> B;
cout << floor_sum(N, M, A, B) << '\n';
cout << floor_sum<long long, unsigned long long>(N, M, A, B) << '\n';
}
}
Loading