PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
Sieve, binary exponentiation, dynamic programming
PROBLEM:
An array B is said to be good with respect to a positive integer X if the array C computed as C_i = B_i\bmod X is a palindrome.
Define f(B) to be the number of positive integers X such that B is good with respect to X, and f(B) = -1 if there are infinitely many.
Given N and M, compute the sum of f(A) across all integer arrays A of length N with values between 1 and M.
EXPLANATION:
Let’s analyze when the array A is good with respect to X.
It must be a palindrome when looked at modulo X.
So, for each i, we want A_i\bmod X = A_{N+1-i}\bmod X.
This is equivalent to saying that X divides |A_i - A_{N+1-i}|.
So, A is good with respect to X if and only if X divides all of
|A_1 - A_N|, |A_2 - A_{N-1}|, |A_3 - A_{N-2}|, \ldots
Now, for the array A, any valid X must divide all these values - so it must divide their GCD.
Let g = \gcd_{i=1}^N(|A_i - A_{N+1-i}|)
Any valid X must be a factor of g. So, we have two cases:
- g = 0.
Here, every element equals its opposite, so X can be any positive integer at all, and f(A) = -1.
- g \gt 0.
Here, A will be good with respect to any factor of g. So, f(A) = \text{facs}(g), where \text{facs}(g) denotes the number of positive factors of g.
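The two cases above can be checked on a single array with a small sketch. The helper names `opposite_gcd`, `count_divisors`, `f` and `is_good` are illustrative only and do not come from the reference solutions; `is_good` applies the definition directly, so it can be compared against the GCD-based `f`.

```python
from math import gcd

def opposite_gcd(a):
    """GCD of |a[i] - a[n-1-i]| over all mirrored pairs (0 if every pair is equal)."""
    n = len(a)
    g = 0
    for i in range(n // 2):
        g = gcd(g, abs(a[i] - a[n - 1 - i]))
    return g

def count_divisors(g):
    """Number of positive divisors of g, by plain trial division."""
    return sum(1 for x in range(1, g + 1) if g % x == 0)

def f(a):
    """f(A) as derived above: -1 when g = 0, otherwise facs(g)."""
    g = opposite_gcd(a)
    return -1 if g == 0 else count_divisors(g)

def is_good(a, x):
    """Direct definition: the array of residues modulo x is a palindrome."""
    c = [v % x for v in a]
    return c == c[::-1]
```

For A = [1, 5, 3, 7] the opposite differences are 6 and 2, so g = 2 and only X = 1 and X = 2 work.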
Since the array elements must be between 1 and M, their differences must be between 0 and M-1.
So, the GCD of all opposite differences must be between 0 and M-1 as well.
Suppose we’re able to calculate, for each 0 \leq g \lt M, the value ct_g, which denotes the number of arrays of length N with elements between 1 and M such that the GCD of opposite differences is exactly g.
Then, the answer would simply be
-ct_0 + \sum_{g=1}^{M-1} \text{facs}(g)\cdot ct_g
because anything with GCD 0 contributes -1, and anything with GCD g \gt 0 contributes \text{facs}(g) to the sum.
The values of \text{facs}(g) for all g from 1 to M can be computed in \mathcal{O}(M\log M) using a sieve.
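A minimal sketch of such a sieve follows: every d adds one to each of its multiples, which costs a harmonic sum of steps. The function name `divisor_counts` is an assumption made for illustration.

```python
def divisor_counts(limit):
    """Return facs where facs[x] is the number of positive divisors of x, for 1 <= x <= limit."""
    facs = [0] * (limit + 1)
    for d in range(1, limit + 1):
        # d is a divisor of d, 2d, 3d, ...
        for multiple in range(d, limit + 1, d):
            facs[multiple] += 1
    return facs
```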
We now focus on computing ct_g.
Let’s fix the value of g \gt 0.
We want the GCD of all opposite differences to be g - meaning each opposite difference must be, to start with, a multiple of g.
This means we must count the number of ways the opposite difference can be a multiple of g.
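For each d = 0, g, 2g, 3g, \ldots, we count the ordered pairs whose difference is exactly d. Each of the following two families contains exactly M-d pairs:
- (x, x+d) for x = 1, 2, 3, \ldots, M-d
- (x, x-d) for x = d+1, d+2, \ldots, M
When d \gt 0 the two families are disjoint, giving 2(M-d) pairs; when d = 0 they coincide, giving M pairs.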
So, the number of ways of obtaining one opposite pair whose difference is a multiple of g, is
M + \sum_{d} 2(M-d)
where the leading M counts the pairs with difference 0, and the sum is taken across all positive multiples d of g that don’t exceed M.
Let this value be k_g. It can be computed either by simply iterating through all the values of d, or by using a formula, given that it’s an arithmetic progression.
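Both options can be sketched side by side. The closed form below is worked out here from the arithmetic progression, with t = \left\lfloor \frac{M-1}{g} \right\rfloor positive multiples of g below M; it is not taken from the reference solutions, and the function names are illustrative.

```python
def pair_count_loop(m, g):
    """k_g by direct iteration over d = g, 2g, ... below m."""
    total = m  # pairs with difference 0
    for d in range(g, m, g):
        total += 2 * (m - d)  # (x, x+d) and (x, x-d)
    return total

def pair_count_formula(m, g):
    """k_g in closed form: m + sum_{j=1..t} 2(m - j*g) with t = (m-1) // g."""
    t = (m - 1) // g
    return m + 2 * m * t - g * t * (t + 1)
```

As a sanity check, g = 1 places no restriction on the pair, so k_1 should equal M^2.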
Once k_g is known, we have \left\lfloor \frac{N}{2} \right\rfloor opposite pairs of elements, each of which can receive one of these k_g pairs.
So, there are
k_g^{\left\lfloor \frac{N}{2} \right\rfloor}
configurations in total.
However, these are all configurations where the GCD is a multiple of g, not exactly g.
To obtain the number of configurations with GCD equal to g, we must subtract out those where the GCD isn’t g.
- If the GCD is 0, all opposite pairs of elements must be equal. There are M^{\left\lfloor \frac{N}{2} \right\rfloor} such configurations.
- Note that this is also the value of ct_0.
- If the GCD is a positive multiple of g, say x\cdot g, then by definition there are ct_{x\cdot g} configurations with GCD equal to x\cdot g.
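So, we obtain
ct_g = k_g^{\left\lfloor \frac{N}{2} \right\rfloor} - M^{\left\lfloor \frac{N}{2} \right\rfloor} - \sum_{x \geq 2,\ x\cdot g \lt M} ct_{x\cdot g}
This can be computed, again, by simply iterating over multiples of g.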
As long as the values of ct_{g} are cached when computing them (or simply computed in descending order of g), there’s no extra work necessary here: the complexity is just the number of multiples of g.
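The descending-order computation might look like the sketch below. It counts only the \left\lfloor \frac{N}{2} \right\rfloor opposite pairs and ignores any middle element; the function name `exact_gcd_counts` and the modulus 998244353 (the one used in the solutions' code) are the assumptions here.

```python
MOD = 998244353  # modulus used by the solutions in the CODE section

def exact_gcd_counts(n, m):
    """ct[g] for 0 <= g < m: pair configurations whose opposite-difference GCD is exactly g."""
    half = n // 2
    ct = [0] * m
    ct[0] = pow(m, half, MOD)  # every opposite pair is equal
    for g in range(m - 1, 0, -1):  # descending, so ct[2g], ct[3g], ... are already known
        k = m
        for d in range(g, m, g):
            k += 2 * (m - d)
        value = pow(k, half, MOD) - ct[0]
        for multiple in range(2 * g, m, g):
            value -= ct[multiple]
        ct[g] = value % MOD
    return ct
```

For N = 4 and M = 3 this yields ct = [9, 56, 16], which sums to 81 = 3^4, the total number of arrays.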
Once all the ct_g values are known, the answer can be computed as a simple summation in \mathcal{O}(M) as mentioned earlier.
Note that when N is odd, we haven’t accounted for the middle element, which isn’t actually paired with anything else.
However, this also means it can take any value at all without changing the value of f(A) - so we can simply multiply the obtained answer by M.
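One way to gain confidence in the whole derivation, including the factor of M for odd N, is to compare it with an exhaustive enumeration on tiny inputs. The sketch below is such a cross-check rather than any of the reference solutions; `brute_force_sum` and `fast_sum` are hypothetical names, and the modulus is the one used in the solutions' code.

```python
from itertools import product
from math import gcd

MOD = 998244353  # modulus used by the solutions in the CODE section

def brute_force_sum(n, m):
    """Sum of f(A) over all m^n arrays, straight from the definition of g."""
    total = 0
    for a in product(range(1, m + 1), repeat=n):
        g = 0
        for i in range(n // 2):
            g = gcd(g, abs(a[i] - a[n - 1 - i]))
        if g == 0:
            total -= 1
        else:
            total += sum(1 for x in range(1, g + 1) if g % x == 0)
    return total % MOD

def fast_sum(n, m):
    """Same sum via ct_g, facs(g), and a final factor of m when n is odd."""
    half = n // 2
    facs = [0] * m
    for d in range(1, m):
        for multiple in range(d, m, d):
            facs[multiple] += 1
    all_equal = pow(m, half, MOD)  # ct_0, contributes -1 per configuration
    ct = [0] * m
    answer = -all_equal
    for g in range(m - 1, 0, -1):
        k = m + sum(2 * (m - d) for d in range(g, m, g))
        ct[g] = pow(k, half, MOD) - all_equal
        for multiple in range(2 * g, m, g):
            ct[g] -= ct[multiple]
        ct[g] %= MOD
        answer += facs[g] * ct[g]
    if n % 2 == 1:
        answer *= m  # free middle element
    return answer % MOD
```

The brute force is exponential in N, so it is only usable for very small N and M.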
As for the time complexity, the only real work we do is iterating over all integers and their multiples from 1 to M, and then a couple of exponentiations.
The former has a complexity of \mathcal{O}(\frac{M}{1} + \frac{M}{2} + \ldots + \frac{M}{M}) = \mathcal{O}(M\log M), while the latter has a complexity of \mathcal{O}(\log N) if binary exponentiation is used, and is done \mathcal{O}(M) times in total, for \mathcal{O}(M\log N).
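For reference, binary exponentiation can be sketched as below. Python's built-in three-argument `pow` already does this, so the hand-written version (with the illustrative name `power_mod`) is shown only to make the \mathcal{O}(\log N) loop explicit.

```python
def power_mod(base, exponent, mod):
    """Compute base**exponent % mod with O(log exponent) multiplications."""
    result = 1 % mod
    base %= mod
    while exponent > 0:
        if exponent & 1:  # current lowest bit is set: multiply it in
            result = result * base % mod
        base = base * base % mod  # square for the next bit
        exponent >>= 1
    return result
```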
TIME COMPLEXITY:
\mathcal{O}(M\log M + M\log{N}) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}
template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
const int MOD = 998244353;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;
    ll res = 1;
    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}
void solve(int test_case){
    ll n,m; cin >> n >> m;
    vector<ll> w(m+5);
    w[0] = m;
    rep1(i,m-1){
        w[i] = 2*(m-i);
    }
    vector<ll> dp(m+5);
    rev(i,m,1){
        ll pick_ways = 0;
        for(int j = 0; j <= m; j += i){
            pick_ways += w[j];
        }
        ll ways = bexp(pick_ways,n/2)-bexp(w[0],n/2);
        ways = (ways%MOD+MOD)%MOD;
        if(n&1) ways = ways*m%MOD;
        dp[i] = ways;
        for(int j = 2*i; j <= m; j += i){
            dp[i] -= dp[j];
        }
        dp[i] = (dp[i]%MOD+MOD)%MOD;
    }
    vector<ll> divs(m+5);
    rep1(i,m){
        for(int j = i; j <= m; j += i){
            divs[j]++;
        }
    }
    ll ans = 0;
    rep1(i,m){
        ans += dp[i]*divs[i];
        ans %= MOD;
    }
    {
        ll bad = bexp(w[0],n/2);
        if(n&1) bad = bad*m%MOD;
        ans -= bad;
        ans = (ans%MOD+MOD)%MOD;
    }
    cout << ans << endl;
}
int main()
{
    fastio;
    int t = 1;
    cin >> t;
    rep1(i, t) {
        solve(i);
    }
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int facN = 1e6 + 5;
const int mod = 998244353;
int ff[facN], iff[facN];
bool facinit = false;
int power(int x, int y){
    if (y == 0) return 1;
    int v = power(x, y / 2);
    v = 1LL * v * v % mod;
    if (y & 1) return 1LL * v * x % mod;
    else return v;
}
void factorialinit(){
    facinit = true;
    ff[0] = iff[0] = 1;
    for (int i = 1; i < facN; i++){
        ff[i] = 1LL * ff[i - 1] * i % mod;
    }
    iff[facN - 1] = power(ff[facN - 1], mod - 2);
    for (int i = facN - 2; i >= 1; i--){
        iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
    }
}
int C(int n, int r){
    if (!facinit) factorialinit();
    if (n == r) return 1;
    if (r < 0 || r > n) return 0;
    return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}
int P(int n, int r){
    if (!facinit) factorialinit();
    assert(0 <= r && r <= n);
    return 1LL * ff[n] * iff[n - r] % mod;
}
int Solutions(int n, int r){
    //solutions to x1 + ... + xn = r
    //xi >= 0
    return C(n + r - 1, n - 1);
}
void Solve()
{
    int n, m; cin >> n >> m;
    // diff = 3 => ways = 2, 2 * 2
    // diff = 2 => ways = 4, 4 * 2
    // diff = 1 => ways = 6, 6 * 1
    // diff = 0 => ways = 4, 4 * -1
    vector <int> f(m + 1, 0);
    // number of arrays with gcd of common differences = d
    vector <int> a(m + 1, 0);
    for (int i = 0; i <= m; i++){
        // number of pairs with common difference = i
        if (i == 0){
            a[i] = m;
        } else {
            a[i] = 2 * m - 2 * i;
        }
    }
    int g = (n / 2);
    int all_zero = power(a[0], g);
    for (int d = m; d >= 1; d--){
        int ways = 0;
        for (int i = 0; i <= m; i += d){
            ways += a[i];
        }
        ways %= mod;
        ways = power(ways, g);
        f[d] = ways - all_zero;
        f[d] %= mod;
        if (f[d] < 0) f[d] += mod;
        for (int j = d + d; j <= m; j += d){
            f[d] -= f[j];
        }
        f[d] %= mod;
        if (f[d] < 0) f[d] += mod;
        assert(f[d] >= 0);
    }
    f[0] = all_zero;
    if (n % 2 == 1){
        for (int d = 0; d <= m; d++){
            f[d] *= m;
            f[d] %= mod;
        }
    }
    int ans = 0;
    // if gcd = g, can choose any divisors
    // so need to find number of divisors
    // simple O(m sqrt(m)) trial division instead of a sieve
    vector <int> cnt(m + 1, 0);
    cnt[0] = -1;
    for (int i = 1; i <= m; i++){
        for (int j = 1; j * j <= i; j++){
            if (i % j == 0){
                cnt[i]++;
                if (j * j != i){
                    cnt[i]++;
                }
            }
        }
    }
    for (int i = 0; i <= m; i++){
        ans += cnt[i] * f[i];
        ans %= mod;
    }
    if (ans < 0) ans += mod;
    cout << ans << "\n";
}
int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
Editorialist's code (PyPy3)
mod = 998244353
for _ in range(int(input())):
    n, m = map(int, input().split())
    dp = [0]*(m+1)
    facs = [0]*(m+1)
    ans = 0
    for g in reversed(range(1, m+1)):
        # gcd = g
        # floor(n/2) pairs, for each choose a pair with gcd = x*g
        # remove arrays where gcd = 0 or gcd > g
        pairs = m
        for d in range(g, m, g):
            facs[d] += 1
            pairs += 2*(m-d)
            if d > g: dp[g] -= dp[d]
        dp[g] += pow(pairs, n//2, mod) - pow(m, n//2, mod)
        dp[g] %= mod
    for i in range(1, m+1): ans += facs[i] * dp[i]
    ans -= pow(m, n//2, mod)
    if n%2 == 1: ans *= m
    print(ans % mod)