MODINDROME - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Sieve, binary exponentiation, dynamic programming

PROBLEM:

An array B is said to be good with respect to positive integer X if the array C computed as C_i = B_i\bmod X is a palindrome.
Define f(B) to be the number of positive integers X such that B is good with respect to X, and f(B) = -1 if there are infinitely many.

Given N and M, compute the sum of f(A) across all integer arrays A of length N with values between 1 and M.

EXPLANATION:

Let’s analyze when the array A is good with respect to X.

It must be a palindrome when looked at modulo X.
So, for each i, we want A_i\bmod X = A_{N+1-i}\bmod X.
This is equivalent to saying that X divides |A_i - A_{N+1-i}|.

So, A is good with respect to X if and only if X divides all of
|A_1 - A_N|, |A_2 - A_{N-1}|, |A_3 - A_{N-2}|, \ldots

Now, for the array A, any valid X must divide all these values - so it must divide their GCD.
Let g = \gcd_{i=1}^N(|A_i - A_{N+1-i}|)

Any valid X must be a factor of g. So, we have two cases:

  1. g = 0.
    Here, every positive integer X works, so there are infinitely many valid X and f(A) = -1.
  2. g \gt 0.
    Here, A will be good with respect to any factor of g. So, f(A) = \text{facs}(g), where \text{facs}(g) denotes the number of positive factors of g.
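
The characterisation above can be checked on a single array with a minimal sketch. The names `f_direct` and `f_bruteforce` are hypothetical helpers introduced only for illustration; the brute force only tries X up to max(A), which is enough when g > 0 because any larger X leaves the array unchanged.

```python
from math import gcd

def f_direct(a):
    """f(A) via the gcd of opposite differences (returns -1 when g = 0)."""
    n = len(a)
    g = 0
    for i in range(n // 2):
        g = gcd(g, abs(a[i] - a[n - 1 - i]))
    if g == 0:
        return -1
    # number of positive divisors of g, by trial division
    return sum(1 for x in range(1, g + 1) if g % x == 0)

def f_bruteforce(a):
    """Count X in [1, max(a)] for which a mod X is a palindrome.
    Only meaningful when the true answer is finite."""
    count = 0
    for x in range(1, max(a) + 1):
        c = [v % x for v in a]
        if c == c[::-1]:
            count += 1
    return count

print(f_direct([1, 5, 3, 7]))      # differences 6 and 2, gcd 2 -> 2
print(f_bruteforce([1, 5, 3, 7]))  # -> 2
```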

Since the array elements must be between 1 and M, their differences must be between 0 and M-1.
So, the GCD of all opposite differences must be between 0 and M-1 as well.

Suppose we’re able to calculate, for each 0 \leq g \lt M, the value ct_g, which denotes the number of arrays of length N with elements between 1 and M such that the GCD of opposite differences is exactly g.
Then, the answer would simply be

-ct_0 + \sum_{g=1}^{M-1} ct_g \cdot \text{facs}(g)

because anything with GCD 0 contributes -1, and anything with GCD g \gt 0 contributes \text{facs}(g) to the sum.

The values of \text{facs}(g) for all g from 1 to M can be computed in \mathcal{O}(M\log M) using a sieve.
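
One possible way to write that sieve is sketched here; `divisor_counts` is a hypothetical name, and the idea is simply that every d adds one to each of its multiples.

```python
def divisor_counts(m):
    """facs[x] = number of positive divisors of x, for 1 <= x <= m."""
    facs = [0] * (m + 1)
    for d in range(1, m + 1):
        # d divides d, 2d, 3d, ...; the total work is a harmonic sum
        for multiple in range(d, m + 1, d):
            facs[multiple] += 1
    return facs

print(divisor_counts(12)[1:])  # [1, 2, 2, 3, 2, 4, 2, 4, 3, 4, 2, 6]
```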
We now focus on computing ct_g.


Let’s fix the value of g \gt 0.
We want the GCD of all opposite differences to be g - meaning each opposite difference must be, to start with, a multiple of g.

This means we must count the number of ways the opposite difference can be a multiple of g.
For each d = 0, g, 2g, 3g, \ldots, we count the ordered pairs of elements whose difference is exactly d. For d = 0 there are M pairs, namely (x, x). For d \gt 0 there are 2(M-d) pairs:

  • (x, x+d) for x = 1, 2, 3, \ldots, M-d
  • (x, x-d) for x = d+1, d+2, \ldots, M

So, the number of ways of obtaining one opposite pair whose difference is a multiple of g, is

M + 2(M-g) + 2(M-2g) + \ldots

summed up across all multiples of g that are less than M.

Let this value be k_g. It can be computed either by simply iterating through all the values of d, or by using a formula, given that it’s an arithmetic progression.
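
Both options can be sketched as follows, counting ordered pairs so that each nonzero difference d is counted in both directions (the two families of pairs listed above). The function names are hypothetical, and the closed form is my own rearrangement of the sum: with t positive multiples of g below M, the total is M + 2tM - g t (t+1).

```python
def k_by_iteration(m, g):
    """Ordered pairs (x, y) in [1, m]^2 with |x - y| a multiple of g."""
    total = m                      # d = 0: the pairs (x, x)
    for d in range(g, m, g):       # d = g, 2g, ... with d < m
        total += 2 * (m - d)       # (x, x+d) and (x, x-d)
    return total

def k_by_formula(m, g):
    """Same value in O(1), using the arithmetic progression."""
    t = (m - 1) // g               # number of positive multiples of g below m
    return m + 2 * t * m - g * t * (t + 1)

def k_bruteforce(m, g):
    """Direct enumeration, for checking small cases only."""
    return sum(1 for x in range(1, m + 1)
                 for y in range(1, m + 1) if (x - y) % g == 0)

print(k_by_iteration(5, 2), k_by_formula(5, 2), k_bruteforce(5, 2))  # 13 13 13
```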

Once k_g is known, we have \left\lfloor \frac{N}{2} \right\rfloor opposite pairs of elements, each of which can receive one of these k_g pairs.
So, there are

k_g ^ {\left\lfloor \frac{N}{2} \right\rfloor}

configurations in total.

However, these are all configurations where the GCD is a multiple of g, not exactly g.

To obtain the number of configurations with GCD equal to g, we must subtract out those where the GCD isn’t g.

  • If the GCD is 0, all opposite pairs of elements must be equal. There are M^{\left\lfloor \frac{N}{2} \right\rfloor} such configurations.
    • Note that this is also the value of ct_0.
  • If the GCD is a positive multiple of g, say x\cdot g, then by definition there are ct_{x\cdot g} configurations with GCD equal to x\cdot g.

So, we obtain

ct_g = k_g ^ {\left\lfloor \frac{N}{2} \right\rfloor} - M^{\left\lfloor \frac{N}{2} \right\rfloor} - ct_{2g} - ct_{3g} - \ldots

This can be computed, again, by simply iterating over multiples of g.
As long as the values of ct_{g} are cached when computing them (or simply computed in descending order of g), there’s no extra work necessary here: the complexity is just the number of multiples of g.
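
A minimal sketch of this descending-order computation, assuming the modulus 998244353 used by the solutions below. The name `exact_gcd_counts` is hypothetical, and the function counts only the paired positions (the middle element for odd N is ignored here).

```python
MOD = 998244353  # modulus used by the reference solutions

def exact_gcd_counts(n, m):
    """ct[g] for 0 <= g < m: assignments of the floor(n/2) opposite pairs
    whose differences have gcd exactly g, modulo MOD."""
    half = n // 2
    ct = [0] * m
    ct[0] = pow(m, half, MOD)              # every pair has equal elements
    for g in range(m - 1, 0, -1):          # descending, so ct[2g], ct[3g], ... are ready
        k = m + sum(2 * (m - d) for d in range(g, m, g))
        value = pow(k, half, MOD) - ct[0]  # gcd is a multiple of g, but not 0
        for multiple in range(2 * g, m, g):
            value -= ct[multiple]          # drop gcd = 2g, 3g, ...
        ct[g] = value % MOD
    return ct

print(exact_gcd_counts(2, 3))  # [3, 4, 2]
print(exact_gcd_counts(4, 3))  # [9, 56, 16]
```

For N = 2 and M = 3 this matches a count by hand: 3 pairs with difference 0, 4 with difference 1, and 2 with difference 2.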

Once all the ct_g values are known, the answer can be computed as a simple summation in \mathcal{O}(M) as mentioned earlier.

Note that when N is odd, we haven’t accounted for the middle element, which isn’t actually paired with anything else.
However, this also means it can take any value at all without changing the value of f(A) - so we can simply multiply the obtained answer by M.
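
To gain some confidence in the whole derivation, including the odd-N adjustment, here is a sketch that compares the formula against exhaustive enumeration on tiny inputs. Both function names are hypothetical, and the brute force is only feasible for very small N and M.

```python
from itertools import product
from math import gcd

MOD = 998244353

def answer_fast(n, m):
    """Sum of f(A) over all arrays, following the steps described above."""
    half = n // 2
    facs = [0] * (m + 1)
    for d in range(1, m + 1):
        for mult in range(d, m + 1, d):
            facs[mult] += 1
    ct0 = pow(m, half, MOD)
    ct = [0] * (m + 1)
    total = -ct0                           # arrays with gcd 0 contribute -1 each
    for g in range(m - 1, 0, -1):
        k = m + sum(2 * (m - d) for d in range(g, m, g))
        value = pow(k, half, MOD) - ct0
        for mult in range(2 * g, m, g):
            value -= ct[mult]
        ct[g] = value % MOD
        total += ct[g] * facs[g]
    if n % 2 == 1:
        total *= m                         # free choice of the middle element
    return total % MOD

def answer_bruteforce(n, m):
    """Enumerate every array; exponential, for tiny inputs only."""
    total = 0
    for a in product(range(1, m + 1), repeat=n):
        g = 0
        for i in range(n // 2):
            g = gcd(g, abs(a[i] - a[n - 1 - i]))
        total += -1 if g == 0 else sum(1 for x in range(1, g + 1) if g % x == 0)
    return total % MOD

print(answer_fast(2, 3), answer_bruteforce(2, 3))  # 5 5
```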


As for the time complexity, the only real work we do is iterating over all integers and their multiples from 1 to M, and then a couple of exponentiations.
The former has a complexity of \mathcal{O}(\frac{M}{1} + \frac{M}{2} + \ldots + \frac{M}{M}) = \mathcal{O}(M\log M), while the latter has a complexity of \mathcal{O}(\log N) if binary exponentiation is used, and is done \mathcal{O}(M) times in total, for \mathcal{O}(M\log N).

TIME COMPLEXITY:

\mathcal{O}(M\log M + M\log{N}) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;

    ll res = 1;

    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }

    return res;
}

void solve(int test_case){
    ll n,m; cin >> n >> m;
    vector<ll> w(m+5);
    w[0] = m;
    rep1(i,m-1){
        w[i] = 2*(m-i);
    }

    vector<ll> dp(m+5);
    rev(i,m,1){
        ll pick_ways = 0;
        for(int j = 0; j <= m; j += i){
            pick_ways += w[j];
        }

        ll ways = bexp(pick_ways,n/2)-bexp(w[0],n/2);
        ways = (ways%MOD+MOD)%MOD;
        if(n&1) ways = ways*m%MOD;

        dp[i] = ways;

        for(int j = 2*i; j <= m; j += i){
            dp[i] -= dp[j];
        }

        dp[i] = (dp[i]%MOD+MOD)%MOD;
    }

    vector<ll> divs(m+5);
    rep1(i,m){
        for(int j = i; j <= m; j += i){
            divs[j]++;
        }
    }

    ll ans = 0;
    rep1(i,m){
        ans += dp[i]*divs[i];
        ans %= MOD;
    }

    {
        ll bad = bexp(w[0],n/2);
        if(n&1) bad = bad*m%MOD;
        ans -= bad;
        ans = (ans%MOD+MOD)%MOD;
    }

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int facN = 1e6 + 5;
const int mod = 998244353;
int ff[facN], iff[facN];
bool facinit = false;

int power(int x, int y){
    if (y == 0) return 1;

    int v = power(x, y / 2);
    v = 1LL * v * v % mod;

    if (y & 1) return 1LL * v * x % mod;
    else return v;
}

void factorialinit(){
    facinit = true;
    ff[0] = iff[0] = 1;

    for (int i = 1; i < facN; i++){
        ff[i] = 1LL * ff[i - 1] * i % mod;
    }

    iff[facN - 1] = power(ff[facN - 1], mod - 2);
    for (int i = facN - 2; i >= 1; i--){
        iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
    }
}

int C(int n, int r){
    if (!facinit) factorialinit();

    if (n == r) return 1;

    if (r < 0 || r > n) return 0;
    return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}

int P(int n, int r){
    if (!facinit) factorialinit();

    assert(0 <= r && r <= n);
    return 1LL * ff[n] * iff[n - r] % mod;
}

int Solutions(int n, int r){
    //solutions to x1 + ... + xn = r 
    //xi >= 0

    return C(n + r - 1, n - 1);
}

void Solve() 
{
    int n, m; cin >> n >> m;
    
    // diff = 3 => ways = 2, 2 * 2 
    // diff = 2 => ways = 4, 4 * 2 
    // diff = 1 => ways = 6, 6 * 1 
    // diff = 0 => ways = 4, 4 * -1 
    
    vector <int> f(m + 1, 0);
    // number of arrays with gcd of common differences = d 
    vector <int> a(m + 1, 0);
    for (int i = 0; i <= m; i++){
        // number of pairs with common difference = i
        if (i == 0){
            a[i] = m;
        } else {
            a[i] = 2 * m - 2 * i;
        }
    }
    
    int g = (n / 2);
    int all_zero = power(a[0], g);
    
    for (int d = m; d >= 1; d--){
        int ways = 0;
        for (int i = 0; i <= m; i += d){
            ways += a[i];
        }
        
        ways %= mod;
        
        ways = power(ways, g);
        
        f[d] = ways - all_zero;
        f[d] %= mod;
        if (f[d] < 0) f[d] += mod;
        
        for (int j = d + d; j <= m; j += d){
            f[d] -= f[j];
        }
        f[d] %= mod;
        if (f[d] < 0) f[d] += mod;
        assert(f[d] >= 0);
    }
    
    f[0] = all_zero;
    if (n % 2 == 1){
        for (int d = 0; d <= m; d++){
            f[d] *= m;
            f[d] %= mod;
        }
    }
    
    int ans = 0;
    // if gcd = g, can choose any divisors 
    // so need to find number of divisors 
    // count divisors by simple O(m sqrt(m)) trial division
    
    vector <int> cnt(m + 1, 0);
    cnt[0] = -1;
    for (int i = 1; i <= m; i++){
        for (int j = 1; j * j <= i; j++){
            if (i % j == 0){
                cnt[i]++;
                if (j * j != i){
                    cnt[i]++;
                }
            }
        }
    }
    
    for (int i = 0; i <= m; i++){
        ans += cnt[i] * f[i];
        ans %= mod;
    }
    
    if (ans < 0) ans += mod;
    
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}

Editorialist's code (PyPy3)
mod = 998244353
for _ in range(int(input())):
    n, m = map(int, input().split())
    
    dp = [0]*(m+1)
    facs = [0]*(m+1)
    ans = 0
    for g in reversed(range(1, m+1)):
        # gcd = g
        # floor(n/2) pairs, for each choose a pair with difference = x*g
        # remove arrays where gcd = 0 or gcd > g
        pairs = m
        for d in range(g, m, g):
            facs[d] += 1
            pairs += 2*(m-d)
            if d > g: dp[g] -= dp[d]
        dp[g] += pow(pairs, n//2, mod) - pow(m, n//2, mod)
        dp[g] %= mod

    for i in range(1, m+1): ans += facs[i] * dp[i]    
    ans -= pow(m, n//2, mod)
    if n%2 == 1: ans *= m
    print(ans % mod)


I actually had a different solution that doesn't require DP or a sieve (O(M log N)):

# ﷽
# my template (https://github.com/Mohamed-Elnageeb , https://codeforces.com/profile/the_last_smilodon) :
import sys
input = lambda: sys.stdin.buffer.readline().decode().rstrip()


def II():
    return int(input())

def MII():
    return map(int, input().split())


###############################################################################
###############################################################################
###############################################################################
mod = 998244353

def solve():
    n,m = MII()
    sub = pow(m,(n)//2,mod) * ( m + 1)  # arrays with infinitely many valid X: each must contribute -1 and is also counted m times in the loop below
    ans = 0 

    for num in range(1,m+1) :
        
        val1 = m//num 
        val2 = val1+1 

        c1 = num - (m%num)
        c2 = m%num  
        
        ans += pow((c1*(val1**2) + c2*(val2**2))%mod , n//2, mod) # either choose from the first pool with val1^2 choices or from second pool with val2^2 choices 
        
        ans %= mod 
        
    ans -= sub
    if n&1 : ans *= m 
    
    print(ans%mod)

###############################################################################
###############################################################################
###############################################################################
 
for t in range(II()):
    
    solve()
    
    

What is the logic behind multiplying pow(m,n//2,mod) by (m+1)?

I am on the same track, but my solution is O(N * M): basically, I have to recalculate the answer for each possible value of X.

I must subtract 1 for every array with infinitely many valid X. Also, when I add up the contributions of the m values of X, each of those arrays gets counted m times. That count should not be there, so I subtract m * (number of such arrays) as well. Hence we subtract (m+1) * (number of such arrays) in total.
The editorial does the same thing, but subtracts them in each iteration instead.