PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: piyush_2007
Tester: wasd2401
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Math
PROBLEM:
Given N and K, consider the array A formed by concatenating [1, 2, \ldots, N], K times.
You’re also given integers a,b,c,d.
The score of a sequence equals
\displaystyle\frac{(g+a)\cdot (g+b)}{(m+c)\cdot (m+d)},
where g equals the GCD of the elements of the sequence, and m equals its length.
Find the sum of the scores of all non-empty subsequences of A.
EXPLANATION:
The array A itself can be quite long, but its elements remain small: they’re all at most N.
This means the GCD of any subsequence is also at most N.
So, let’s fix g, the GCD of the subsequence we’re considering.
Then, the numerator of the score of any subsequence with GCD g is always (g+a)\cdot (g+b), and is hence constant.
This means it’s enough to compute the sum of \frac{1}{(m+c)\cdot (m+d)} across all subsequences whose GCD is g.
Let this value be D(g).
It’s somewhat hard to obtain a subsequence whose GCD is exactly g, but it’s quite easy to obtain one whose GCD is a multiple of g: just choose any subset of elements that are multiples of g themselves!
In particular, there are L_g = \left\lfloor \frac{N}{g} \right\rfloor \cdot K elements of A that are multiples of g, and any subset of them will have a GCD that’s a multiple of g.
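Now, note that
\displaystyle D(g) = \sum_{i=1}^{L_g} \binom{L_g}{i} \frac{1}{(i+c)\cdot (i+d)} \;-\; \sum_{j \geq 2,\ j\cdot g \leq N} D(j\cdot g)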
This is because the first term computes the sum we want across all multiples of g, so we remove from it the values that correspond specifically to larger multiples of g.
Iterating across all numbers and their multiples till N takes \mathcal{O}(N\log N) time, so we focus on computing the first term, i.e. \displaystyle\sum_{i=1}^{L_g} \binom{L_g}{i} \frac{1}{(i+c)\cdot (i+d)}.
To compute \displaystyle\sum_{i=1}^{L_g} \binom{L_g}{i} \frac{1}{(i+c)\cdot (i+d)}, we’ll need a few tricks.
The first is that of partial fractions, to write \displaystyle\frac{1}{(i+c)\cdot (i+d)} = \frac{x}{i+c} + \frac{y}{i+d} for some constants x and y.
You can work the math out to verify that x = \frac{1}{d-c} and y = \frac{1}{c-d}.
This means it’s enough for us to be able to compute
\displaystyle\sum_{i=1}^{L_g} \binom{L_g}{i} \frac{1}{i+c}
(multiplied by some constant), then do the same thing with the denominator being i+d instead.
To do this, we turn to generating functions.
Since we want binomial coefficients in the sum, we turn to the binomial expansion.
Note that we want i+c in the denominator, and one way of obtaining constants related to the degree in the denominator is via integration.
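Specifically, we can start from the binomial expansion (1+t)^{L_g} = \sum_{i=0}^{L_g} \binom{L_g}{i} t^i, multiply both sides by t^{c-1}, and integrate from 0 to x:
\displaystyle\int_0^x t^{c-1} (1+t)^{L_g} \, dt = \sum_{i=0}^{L_g} \binom{L_g}{i} \int_0^x t^{i+c-1} \, dt = \sum_{i=0}^{L_g} \binom{L_g}{i} \frac{x^{i+c}}{i+c}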
Note that the right side of the above equation, evaluated at x = 1, is exactly what we want (minus the term for i = 0, which can be subtracted out separately).
For this though, we’ll need to be able to compute the integral on the left (with x = 1).
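Let’s define
\displaystyle I(a, b) = \int_0^1 (1+t)^a \, t^b \, dt
(here a and b are just the parameters of the integral, not the a and b from the problem statement).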
Assume b\gt 0, since if b = 0 it’s trivial to compute: I(a, 0) = \frac{2^{a+1} - 1}{a+1}.
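Integrating this by parts, with u(t) = t^{b} and v'(t) = (1+t)^a (so that v(t) = \frac{(1+t)^{a+1}}{a+1}), we get
\displaystyle I(a, b) = \left[ \frac{t^b (1+t)^{a+1}}{a+1} \right]_0^1 - \frac{b}{a+1} \int_0^1 t^{b-1} (1+t)^{a+1} \, dt = \frac{2^{a+1}}{a+1} - \frac{b}{a+1} \cdot I(a+1, b-1)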
This gives a recursive formulation for I(a, b), with I(a, 0) being the base case.
Since b decreases by 1 at each step, this can be computed in \mathcal{O}(b) recursive steps.
In our case, we want I(L_g, c-1) and I(L_g, d-1), which can be computed in \mathcal{O}(c+d).
Given that D(g) can be computed in \mathcal{O}(d) time, we now have a solution that works in \mathcal{O}(Nd\log{MOD} + N\log N).
To speed it up further, note that there are only \mathcal{O}(\sqrt N) values of \left\lfloor \frac{N}{g} \right\rfloor, so running the \mathcal{O}(d) algorithm for only all of them and reusing results gives an algorithm that’s \mathcal{O}(d\sqrt{N}\log{MOD} + N\log N), and hence fast enough.
TIME COMPLEXITY:
\mathcal{O}(d\sqrt{N}\log{MOD} + N\log N) per testcase.
Across T tests, the first term is bounded by 2000\cdot \sqrt{2\cdot 10^5}\cdot\log{MOD} and the latter by
2\cdot 10^5 \log({2\cdot 10^5}), so this is fast enough.
CODE:
Author's code (C++)
// ॐ
#include <bits/stdc++.h>
using namespace std;
// Numeric value of pi; not referenced anywhere else in this listing.
#define PI 3.14159265358979323846
// Shorthand for `long long int`, used for the input values and loop bounds below.
#define ll long long int
// Shorthand for `long double`; not referenced anywhere else in this listing.
#define ld long double
const int MOD = 998244353; // check mod

// Integer residue modulo MOD.
//
// Every constructor and operator keeps `val` inside [0, MOD), so values can be
// combined with + - * / freely. The converting constructor is intentionally
// implicit so plain integers can appear on either side of an operator.
struct mod_int {
    int val;

    // Reduces any 64-bit value, including negatives, into [0, MOD).
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;   // lands in (0, MOD]; the next check folds MOD to 0
        if (v >= MOD)
            v %= MOD;
        val = static_cast<int>(v);
    }

    // Extended Euclid: returns x in [0, m) with a * x == 1 (mod m) when
    // gcd(a, m) == 1. For a == 0 the loop never runs and 0 is returned.
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
        while (r != 0) {
            int q = g / r;
            g %= r; std::swap(g, r);
            x -= q * y; std::swap(x, y);
        }
        return x < 0 ? x + m : x;
    }

    explicit operator int() const {
        return val;
    }

    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }

    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }

    // Returns x mod m.
    //
    // Only 32-bit Windows builds take the hand-written x86 `divl` path; every
    // other target uses the plain `%` operator. The assembly lives under
    // `#else` so that it is not compiled at all on targets where it is never
    // executed (it would fail to assemble on non-x86 machines).
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
#if !defined(_WIN32) || defined(_WIN64)
        return static_cast<unsigned>(x % m);
#else
        // `divl` divides the 64-bit value x_high:x_low by m; the quotient must
        // fit in 32 bits, which holds for a product of two residues below m.
        unsigned x_high = x >> 32, x_low = (unsigned) x;
        unsigned quot, rem;
        asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
        return rem;
#endif
    }

    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }

    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }

    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }

    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }

    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }

    // Postfix forms return the value held before the step.
    mod_int operator++(int) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int) { mod_int before = *this; --*this; return before; }

    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }

    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }

    // Multiplicative inverse modulo MOD (see mod_inv for the zero case).
    mod_int inv() const {
        return mod_inv(val);
    }

    // Binary exponentiation; p must be non-negative.
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
        while (p > 0) {
            if (p & 1)
                result *= a;
            a *= a;
            p >>= 1;
        }
        return result;
    }

    friend std::ostream& operator<<(std::ostream &stream, const mod_int &m) {
        return stream << m.val;
    }

    // Reads a signed 64-bit integer and reduces it into [0, MOD), so negative
    // or oversized input cannot leave `val` out of range. `m` is left
    // untouched when extraction fails.
    friend std::istream& operator>>(std::istream &stream, mod_int &m) {
        long long raw = 0;
        if (stream >> raw)
            m = mod_int(raw);
        return stream;
    }
};
// Evaluates the editorial's I(a, b) modulo MOD via the recurrence
//   f(a, 0) = (2^(a+1) - 1) / (a+1)
//   f(a, b) = 2^(a+1) / (a+1) - b * f(a+1, b-1) / (a+1)
// Each call lowers b by one, so the recursion is b levels deep.
mod_int f(ll a,ll b){
    const mod_int denom = a + 1;
    const mod_int pow2 = mod_int(2).pow(a + 1);
    if (b != 0) {
        // Boundary term minus the remaining integral, both divided by (a+1).
        const mod_int tail = f(a + 1, b - 1) * b;
        return pow2 / denom - tail / denom;
    }
    return (pow2 - 1) / denom;
}
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
// freopen("inp.txt","r",stdin);
// freopen("out.txt","w",stdout);
int test=1;
cin>>test;
assert(test>=1 && test<=100);
int sum_n=0,sum_d=0;
while(test--){
ll n,k,a,b,c,d;
cin>>n>>k>>a>>b>>c>>d;
assert(n>=1 && n<=2e5);
assert(k>=1 && k<=1e9);
assert(a>=0 && b>a && c>b && d>c && d<=2000);
sum_n+=n;
sum_d+=d;
mod_int prev=0;
vector<mod_int> dp(n+1,0);
mod_int ans=0;
for(int i=n;i>=1;i--){
if(i<n && (n/i)==(n/(i+1))){
dp[i]=prev;
}
else{
dp[i]=((f((n/i)*k,c-1)-mod_int(1)/c)-(f((n/i)*k,d-1)-mod_int(1)/d))/(d-c);
prev=dp[i];
}
for(int j=2*i;j<=n;j+=i){
dp[i]-=dp[j];
}
mod_int mul=1LL*(i+a)*(i+b);
ans+=dp[i]*mul;
}
cout<<ans<<'\n';
}
assert(sum_n<=2e5 && sum_d<=2000);
return 0;
}
Editorialist's code (Python)
# Modulus for all arithmetic below; inverses are taken as pow(x, mod-2, mod).
mod = 998244353
import sys
# I() recurses once per unit of its second argument, so raise the recursion limit.
sys.setrecursionlimit(2500)
def I(a, b):
    """Evaluate the editorial's I(a, b) modulo `mod`.

    Base case:  I(a, 0) = (2^(a+1) - 1) / (a+1)
    Recurrence: I(a, b) = (2^(a+1) - b * I(a+1, b-1)) / (a+1)

    Division is multiplication by pow(a+1, mod-2, mod). The function recurses
    once per unit of b.
    """
    top = pow(2, a + 1, mod)
    inv_next = pow(a + 1, mod - 2, mod)
    if b == 0:
        correction = 1
    else:
        correction = b * I(a + 1, b - 1) % mod
    return (top - correction) * inv_next % mod
def calc(n, c):
    """Return the sum over i >= 1 of choose(n, i) / (i + c), modulo `mod`."""
    # I(n, c-1) includes the i = 0 term, which equals 1/c; remove it.
    zero_term = pow(c, mod - 2, mod)
    return (I(n, c - 1) - zero_term) % mod
# Driver: one answer per test case, read from stdin.
for _ in range(int(input())):
    n, k = map(int, input().split())
    a, b, c, d = map(int, input().split())

    inv_gap = pow(d - c, mod - 2, mod)  # 1 / (d - c) from the partial fractions
    exact = [0] * (n + 1)               # exact[g]: the editorial's D(g)
    cached = 0                          # value for the current n // g, before subtraction
    total = 0

    for g in range(n, 0, -1):
        if n // g != n // (g + 1):
            # New value of n // g: recompute; otherwise reuse the cached one.
            length = (n // g) * k
            cached = (calc(length, c) - calc(length, d)) * inv_gap % mod
        # Remove the contribution of larger multiples of g (empty slice if none).
        larger = sum(exact[2 * g::g])
        exact[g] = (cached - larger) % mod
        total = (total + exact[g] * (g + a) * (g + b)) % mod

    print(total)