PHONYPERM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: piyush_2007
Tester: wasd2401
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Math

PROBLEM:

Given N and K, consider the array A formed by concatenating [1, 2, \ldots, N], K times.

You’re also given integers a, b, c, d with 0 \le a \lt b \lt c \lt d \le 2000.
The score of a sequence equals

\frac{(g+a)\cdot (g+b)}{(m+c)\cdot (m+d)}

where g equals the GCD of the elements of the sequence, and m equals its length.

Find the sum of the scores of all non-empty subsequences of A, modulo 998244353.
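An exhaustive brute force is useful for checking the formulas below on tiny inputs. The following is a minimal sketch in Python. It works with exact fractions instead of modulo 998244353, and it treats subsequences as distinct by position, which is the counting the binomial coefficients below assume. `brute_force` is a hypothetical helper name.

```python
from fractions import Fraction
from functools import reduce
from itertools import combinations
from math import gcd

def brute_force(N, K, a, b, c, d):
    """Exact sum of scores over all non-empty subsequences of A (tiny inputs only)."""
    A = list(range(1, N + 1)) * K          # [1, ..., N] concatenated K times
    total = Fraction(0)
    for m in range(1, len(A) + 1):
        # combinations() picks positions, so equal values at different
        # positions count as different subsequences
        for sub in combinations(A, m):
            g = reduce(gcd, sub)
            total += Fraction((g + a) * (g + b), (m + c) * (m + d))
    return total

print(brute_force(2, 1, 0, 1, 2, 3))  # 23/30
```

For N = 2, K = 1 and (a, b, c, d) = (0, 1, 2, 3), the subsequences [1], [2], [1, 2] contribute \frac{1}{6}, \frac{1}{2}, \frac{1}{10}, which sum to \frac{23}{30}.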

EXPLANATION:

The array A itself can be quite long, but its elements remain small: they’re all at most N.
This means the GCD of any subsequence is also at most N.

So, let’s fix g, the GCD of the subsequence we’re considering.
Then, the numerator of the score of any subsequence with GCD g is always (g+a)\cdot (g+b), and is hence constant.
This means it’s enough to compute the sum of \frac{1}{(m+c)\cdot (m+d)} across all subsequences whose GCD is g.
Let this value be D(g).

It’s somewhat hard to count subsequences whose GCD is exactly g, but counting those whose GCD is a multiple of g is easy. A subsequence has a GCD divisible by g if and only if every element in it is a multiple of g, so these subsequences are exactly the non-empty subsets of the elements of A that are multiples of g.
There are L_g = \left\lfloor \frac{N}{g} \right\rfloor \cdot K such elements, since each of the K copies of [1, 2, \ldots, N] contains \left\lfloor \frac{N}{g} \right\rfloor multiples of g.
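A quick way to check the count L_g is to build A for small N and K and count directly. Here is a minimal sketch; `count_multiples` is a hypothetical name.

```python
def count_multiples(N, K, g):
    """Number of elements of A = [1..N] repeated K times that are divisible by g."""
    A = list(range(1, N + 1)) * K
    return sum(1 for v in A if v % g == 0)

# Each copy of [1..N] holds floor(N/g) multiples of g, and there are K copies.
N, K = 10, 3
for g in range(1, N + 1):
    assert count_multiples(N, K, g) == (N // g) * K
```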

Now, note that

D(g) = \left(\sum_{i=1}^{L_g} \binom{L_g}{i} \frac{1}{(i+c)\cdot (i+d)} \right) - \left(D(2g) + D(3g) + D(4g) + \ldots\right)

This is because the first term sums \frac{1}{(m+c)\cdot (m+d)} over all non-empty subsequences whose GCD is a multiple of g. Subtracting D(kg) for every k \ge 2 with kg \le N removes exactly those subsequences whose GCD is a strictly larger multiple of g.
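The recurrence for D(g) can be checked in exact arithmetic before worrying about speed. In this sketch, `first_term` evaluates \sum_{i=1}^{L_g} \binom{L_g}{i} \frac{1}{(i+c)\cdot (i+d)} naively in \mathcal{O}(L_g) time, so it only suits tiny inputs. `solve_exact` is a hypothetical name. The values of g are processed from N down to 1, so that D(2g), D(3g), \ldots are already known when D(g) is computed.

```python
from fractions import Fraction
from math import comb

def first_term(L, c, d):
    """sum_{i=1}^{L} C(L, i) / ((i+c)(i+d)), summed term by term."""
    return sum((Fraction(comb(L, i), (i + c) * (i + d)) for i in range(1, L + 1)), Fraction(0))

def solve_exact(N, K, a, b, c, d):
    D = [Fraction(0)] * (N + 1)
    ans = Fraction(0)
    for g in range(N, 0, -1):
        L = (N // g) * K                       # elements of A divisible by g
        # subtract D(2g) + D(3g) + ... (those GCDs are strict multiples of g)
        D[g] = first_term(L, c, d) - sum(D[2 * g::g], Fraction(0))
        ans += D[g] * (g + a) * (g + b)        # numerator is constant for fixed g
    return ans

print(solve_exact(3, 1, 0, 1, 2, 3))  # 61/30
```

For N = 3, K = 1 and (a, b, c, d) = (0, 1, 2, 3), enumerating the seven subsequences of [1, 2, 3] by hand gives the same \frac{61}{30}.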

Processing g from N down to 1 and subtracting over all multiples of each g takes \mathcal{O}(N\log N) time in total (a harmonic sum), so we focus on computing the first term, i.e.

\sum_{i=1}^{L_g} \binom{L_g}{i} \frac{1}{(i+c)\cdot (i+d)}

Computing this sum efficiently takes a few tricks.

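The first trick is partial fractions: write \displaystyle\frac{1}{(i+c)\cdot (i+d)} = \frac{x}{i+c} + \frac{y}{i+d} for some constants x and y.
Multiplying both sides by (i+c)\cdot (i+d) gives x(i+d) + y(i+c) = 1 for every i. Hence x + y = 0 and xd + yc = 1, which yields x = \frac{1}{d-c} and y = \frac{1}{c-d}. Both are well-defined since c \lt d.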

This means it’s enough for us to be able to compute

\sum_{i=1}^{L_g} \binom{L_g}{i} \frac{1}{i+c}

(multiplied by x = \frac{1}{d-c}), and then do the same with i+d in the denominator (multiplied by y = \frac{1}{c-d}).
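The partial-fraction step can be checked exactly. With S(L, c) = \sum_{i=1}^{L} \binom{L}{i}\frac{1}{i+c}, the first term equals \frac{S(L, c) - S(L, d)}{d-c}. This small sketch uses hypothetical helper names and compares both sides using Python's `fractions.Fraction`.

```python
from fractions import Fraction
from math import comb

def S(L, c):
    """sum_{i=1}^{L} C(L, i) / (i + c)"""
    return sum((Fraction(comb(L, i), i + c) for i in range(1, L + 1)), Fraction(0))

def first_term_partial_fractions(L, c, d):
    # 1/((i+c)(i+d)) = x/(i+c) + y/(i+d) with x = 1/(d-c), y = 1/(c-d)
    return (S(L, c) - S(L, d)) / (d - c)

def first_term_direct(L, c, d):
    return sum((Fraction(comb(L, i), (i + c) * (i + d)) for i in range(1, L + 1)), Fraction(0))

for L in range(8):
    for c, d in [(1, 2), (2, 3), (3, 7)]:
        assert first_term_partial_fractions(L, c, d) == first_term_direct(L, c, d)
```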

To do this, we use generating functions.

Since the sum involves binomial coefficients, the natural starting point is the binomial expansion of (1+x)^{L_g}.
We also want i+c in the denominator, and integration is a natural way to divide by an exponent, since \int_0^x t^{k-1}dt = \frac{x^k}{k}.
So we first multiply by x^{c-1}, which makes the exponent of the i-th term equal to i+c-1, and then integrate:

\begin{align*} (1+x)^{L_g} &= \sum_{i=0}^{L_g} \binom{L_g}{i} x^i \\ (1+x)^{L_g} \cdot x^{c-1} &= \sum_{i=0}^{L_g} \binom{L_g}{i} x^{i+c-1} \\ \int_0^x (1+t)^{L_g}\cdot t^{c-1}dt &= \sum_{i=0}^{L_g} \binom{L_g}{i} \frac{x^{i+c}}{i+c} \end{align*}

Evaluated at x = 1, the right side becomes \displaystyle\sum_{i=0}^{L_g} \binom{L_g}{i} \frac{1}{i+c}. This is exactly what we want, except for the extra i = 0 term \frac{1}{c}, which is simply subtracted afterwards.
What remains is computing the integral on the left at x = 1.
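As a floating-point sanity check of the identity at x = 1, the integral can be approximated numerically and compared with \sum_{i=0}^{L_g} \binom{L_g}{i}\frac{1}{i+c}. The solution itself does not need this check. The sketch uses composite Simpson's rule, which is an arbitrary choice, and the helper names are hypothetical.

```python
from math import comb

def integral_simpson(L, c, n=2000):
    """Approximate int_0^1 (1+t)^L * t^(c-1) dt by composite Simpson's rule (n must be even)."""
    f = lambda t: (1 + t) ** L * t ** (c - 1)
    h = 1.0 / n
    total = f(0.0) + f(1.0)
    for k in range(1, n):
        total += (4 if k % 2 else 2) * f(k * h)
    return total * h / 3

def series_at_one(L, c):
    """Right side at x = 1, still including the i = 0 term 1/c."""
    return sum(comb(L, i) / (i + c) for i in range(L + 1))

for L, c in [(0, 1), (3, 2), (6, 4)]:
    assert abs(integral_simpson(L, c) - series_at_one(L, c)) < 1e-8
```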

Let’s define

I(a, b) = \int_0^1 (1+t)^{a}\cdot t^{b}dt

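(Here a and b are generic parameters, unrelated to the a and b from the statement.)
If b = 0, the integral is simply I(a, 0) = \frac{2^{a+1}-1}{a+1}, so assume b\gt 0.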
Integrate by parts with u(t) = t^{b} and v'(t) = (1+t)^a, so that v(t) = \frac{(1+t)^{a+1}}{a+1}. The boundary term vanishes at t = 0 because b \gt 0, and we get

\begin{align*} I(a, b) &= \left[ t^{b} \cdot\frac{(1+t)^{a+1}}{a+1}\right]_0^1 - \int_0^1 \frac{(1+t)^{a+1}}{a+1}\cdot b\cdot t^{b-1}dt \\ &= \frac{2^{a+1}}{a+1} - \frac{b}{a+1}\cdot I(a+1, b-1) \end{align*}

This gives a recursive formulation for I(a, b), with I(a, 0) being the base case.
Since b decreases by 1 at each step, this can be computed in \mathcal{O}(b) recursive steps.
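The recursion can also be unrolled into a loop. Start from the base case I(a+b, 0) and apply the recurrence b times to walk back to I(a, b). This avoids deep recursion; the editorialist's code instead raises Python's recursion limit. The sketch below works in exact arithmetic and is checked against the expansion I(a, b) = \sum_{j=0}^{a}\binom{a}{j}\frac{1}{j+b+1}. The function names are hypothetical.

```python
from fractions import Fraction
from math import comb

def I_unrolled(a, b):
    """I(a, b) = int_0^1 (1+t)^a t^b dt, by iterating the integration-by-parts recurrence."""
    val = Fraction(2 ** (a + b + 1) - 1, a + b + 1)       # base case I(a+b, 0)
    for k in range(b - 1, -1, -1):
        # val holds I(a+k+1, b-k-1); the recurrence turns it into I(a+k, b-k)
        val = Fraction(2 ** (a + k + 1), a + k + 1) - Fraction(b - k, a + k + 1) * val
    return val

def I_expanded(a, b):
    """Same integral by expanding (1+t)^a term by term: sum_j C(a, j) / (j + b + 1)."""
    return sum((Fraction(comb(a, j), j + b + 1) for j in range(a + 1)), Fraction(0))

for a in range(6):
    for b in range(6):
        assert I_unrolled(a, b) == I_expanded(a, b)
```

The loop performs the same b steps as the recursion, just without the call stack.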

In our case, we want I(L_g, c-1) and I(L_g, d-1). Together these take \mathcal{O}(c+d) = \mathcal{O}(d) steps, since c \lt d.
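The following sketches how to compute the first term for a given L_g modulo the prime 998244353, which both reference solutions use. It uses the unrolled form of the recurrence and Fermat inverses via Python's three-argument `pow`. Like the reference codes, it assumes that every number it divides by is nonzero modulo the prime. The names are hypothetical.

```python
MOD = 998244353

def inv(x):
    """Modular inverse by Fermat's little theorem (x must not be divisible by MOD)."""
    return pow(x, MOD - 2, MOD)

def I_mod(a, b):
    """I(a, b) = int_0^1 (1+t)^a t^b dt modulo MOD, via the unrolled recurrence."""
    val = (pow(2, a + b + 1, MOD) - 1) * inv(a + b + 1) % MOD     # base case I(a+b, 0)
    for k in range(b - 1, -1, -1):
        # val holds I(a+k+1, b-k-1); turn it into I(a+k, b-k)
        val = (pow(2, a + k + 1, MOD) - (b - k) * val) * inv(a + k + 1) % MOD
    return val

def first_term_mod(L, c, d):
    """sum_{i=1}^{L} C(L, i) / ((i+c)(i+d)) modulo MOD, assuming 1 <= c < d."""
    def S(e):
        # I(L, e-1) includes the i = 0 term 1/e, which is subtracted here
        return (I_mod(L, e - 1) - inv(e)) % MOD
    return (S(c) - S(d)) * inv(d - c) % MOD
```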


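So the first term for a single g takes \mathcal{O}(d) steps, each needing a modular inverse that costs \mathcal{O}(\log{MOD}). Computing it separately for every g gives a solution in \mathcal{O}(Nd\log{MOD} + N\log N).
To speed this up, note that the first term depends on g only through L_g = \left\lfloor \frac{N}{g} \right\rfloor \cdot K, and \left\lfloor \frac{N}{g} \right\rfloor takes only \mathcal{O}(\sqrt N) distinct values. For g \le \sqrt N there are at most \sqrt N choices of g, and for g \gt \sqrt N the quotient itself is below \sqrt N. Moreover, the values of g that share a quotient form a contiguous range. So computing the first term once per distinct quotient and reusing it for the rest of the range gives an algorithm that's \mathcal{O}(d\sqrt{N}\log{MOD} + N\log N), which is fast enough.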
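Putting everything together, here is a compact sketch of the whole algorithm modulo 998244353. Both reference codes compare \left\lfloor \frac{N}{g} \right\rfloor with the value for the previous g. This sketch instead walks over the blocks of equal quotient explicitly: for quotient q, the matching values of g are exactly \left\lfloor \frac{N}{q+1} \right\rfloor + 1, \ldots, \left\lfloor \frac{N}{q} \right\rfloor. Like the reference codes, it assumes every number the recurrence divides by is nonzero modulo the prime. All names are hypothetical.

```python
MOD = 998244353

def inv(x):
    return pow(x, MOD - 2, MOD)

def I_mod(a, b):
    """int_0^1 (1+t)^a t^b dt modulo MOD, via the unrolled integration-by-parts recurrence."""
    val = (pow(2, a + b + 1, MOD) - 1) * inv(a + b + 1) % MOD
    for k in range(b - 1, -1, -1):
        val = (pow(2, a + k + 1, MOD) - (b - k) * val) * inv(a + k + 1) % MOD
    return val

def first_term_mod(L, c, d):
    """sum_{i=1}^{L} C(L, i) / ((i+c)(i+d)) modulo MOD, for 1 <= c < d."""
    S = lambda e: (I_mod(L, e - 1) - inv(e)) % MOD
    return (S(c) - S(d)) * inv(d - c) % MOD

def solve(N, K, a, b, c, d):
    dp = [0] * (N + 1)
    ans = 0
    hi = N
    while hi >= 1:
        q = N // hi
        lo = N // (q + 1) + 1                # every g in [lo, hi] has N // g == q
        t = first_term_mod(q * K, c, d)      # once per block: O(sqrt N) blocks overall
        for g in range(hi, lo - 1, -1):
            dp[g] = (t - sum(dp[2 * g::g])) % MOD
            ans = (ans + dp[g] * (g + a) % MOD * (g + b)) % MOD
        hi = lo - 1
    return ans

print(solve(3, 1, 0, 1, 2, 3) == 61 * inv(30) % MOD)  # True
```

For N = 3, K = 1 and (a, b, c, d) = (0, 1, 2, 3), the exact answer is \frac{61}{30}, so the final check prints True.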

TIME COMPLEXITY:

\mathcal{O}(d\sqrt{N}\log{MOD} + N\log N) per testcase.
The sum of N over all tests is at most 2\cdot 10^5 and the sum of d is at most 2000. So across all tests, the first term is bounded by 2000\cdot \sqrt{2\cdot 10^5}\cdot\log{MOD} and the second by 2\cdot 10^5 \log({2\cdot 10^5}), which is fast enough.
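To get a rough sense of scale, the two bounds can be evaluated numerically. This assumes \log means \log_2 (the cost of binary exponentiation for a modular inverse) and ignores constant factors.

```python
import math

MOD = 998244353
SUM_D, SUM_N = 2000, 2 * 10**5   # limits on the sums of d and N over all tests

first = SUM_D * math.sqrt(SUM_N) * math.log2(MOD)   # d * sqrt(N) * log(MOD) part
second = SUM_N * math.log2(SUM_N)                   # N log N part
print(f"{first:.2e} {second:.2e}")
```

This comes out to roughly 2.7\cdot 10^7 and 3.5\cdot 10^6 basic steps.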

CODE:

Author's code (C++)
                                    //  ॐ
#include <bits/stdc++.h>
using namespace std;
#define PI 3.14159265358979323846
#define ll long long int
#define ld long double


const int MOD = 998244353;  // check mod
struct mod_int {
    int val;
 
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
 
        if (v >= MOD)
            v %= MOD;
 
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
 
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
 
        return x < 0 ? x + m : x;
    }
 
    explicit operator int() const {
        return val;
    }
 
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
 
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
 
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
 
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
 
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
 
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
 
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
 
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
 
    mod_int inv() const {
        return mod_inv(val);
    }
 
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, mod_int &m) {
        return stream>>m.val;   
    }
};

mod_int f(ll a,ll b){
    if(b==0){
       return (mod_int(2).pow(a+1)-1)/(a+1);
    }
    return (mod_int(2).pow(a+1)/(a+1) - (f(a+1,b-1)*b)/(a+1));
}


int main(){

    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // freopen("inp.txt","r",stdin);
    // freopen("out.txt","w",stdout);

    int test=1;
    cin>>test;

    assert(test>=1 && test<=100);
    int sum_n=0,sum_d=0;

    while(test--){

                     ll n,k,a,b,c,d;
                     cin>>n>>k>>a>>b>>c>>d;

                     assert(n>=1 && n<=2e5);
                     assert(k>=1 && k<=1e9);
                     assert(a>=0 && b>a && c>b && d>c && d<=2000);
                     sum_n+=n;
                     sum_d+=d;

                     mod_int prev=0;
                     vector<mod_int> dp(n+1,0);
                     mod_int ans=0;

                     for(int i=n;i>=1;i--){
                         if(i<n && (n/i)==(n/(i+1))){
                              dp[i]=prev;
                         }
                         else{
                              dp[i]=((f((n/i)*k,c-1)-mod_int(1)/c)-(f((n/i)*k,d-1)-mod_int(1)/d))/(d-c);
                              prev=dp[i];
                         }

                         for(int j=2*i;j<=n;j+=i){
                              dp[i]-=dp[j];
                         }

                         mod_int mul=1LL*(i+a)*(i+b);
                         ans+=dp[i]*mul;
                     }

                     cout<<ans<<'\n';
                   

    }

    assert(sum_n<=2e5 && sum_d<=2000);

    return 0;
}

Editorialist's code (Python)
mod = 998244353

import sys
sys.setrecursionlimit(2500)

def I(a, b):
    # I(a, b) = integral_0^1 (1+t)^a * t^b dt (mod p), via the integration-by-parts recurrence
    if b == 0: return (pow(2, a+1, mod)- 1) * pow(a+1, mod-2, mod) % mod
    return (pow(2, a+1, mod)- b * I(a+1, b-1) % mod) * pow(a+1, mod-2, mod) % mod

def calc(n, c):
    # sum i >= 1 (choose(n, i) / (i+c))
    return (I(n, c-1) - pow(c, mod-2, mod))%mod

for _ in range(int(input())):
    n, k = map(int, input().split())
    a, b, c, d = map(int, input().split())
    ans = 0
    dp = [0]*(n+1)
    for g in reversed(range(1, n+1)):
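        N = (n//g) * k  # number of elements of A divisible by g
        # recompute the first term only when n//g changes; otherwise reuse the previous one
        if n//g != n//(g+1): dp[g] = (calc(N, c) - calc(N, d)) * pow(d-c, mod-2, mod) % mod
        else: dp[g] = val
        val = dp[g]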

        if 2*g <= n: dp[g] -= sum(dp[2*g:n+1:g]) % mod
        dp[g] %= mod
        ans += dp[g] * (g+a) % mod * (g+b) % mod
    print(ans % mod)