COOLSBST - Editorial

PROBLEM LINK:

Practice
Div-3 Contest
Div-2 Contest
Div-1 Contest

Author & Editorialist: Nishant Shah
Tester: Radoslav Dimitrov

DIFFICULTY:

MEDIUM-HARD

PREREQUISITES:

Number Theory, Combinatorics, Square-root Decomposition

PROBLEM:

The coolness of a non-empty set of N distinct integers S = \{x_1, x_2, \ldots, x_N\} is the number of integers X that divide \prod_{i=1}^N x_i and for which a primitive root modulo X exists.

Find the sum of coolness values over all possible non-empty sets having size in the range [A,B], and elements in the range [L,R].

QUICK EXPLANATION:

Consider separately each family of values X for which a primitive root modulo X exists, and derive a formula for each family for a fixed set size B.

To handle all sizes at once, we sum these formulas over the set size. Each term then becomes a prefix sum of binomial coefficients \binom{N}{R} for a fixed N, and such prefix sums can be computed using square-root decomposition.

EXPLANATION:

Let’s first see what exactly we have to calculate.

Call an integer X nice if a primitive root modulo X exists.

Then X must be one of the following:

(1) X = 1
(2) X = 2
(3) X = 4
(4) X = p^k
(5) X = 2*p^k

where p is an odd prime and k \ge 1.

(Source: Multiplicative group of integers modulo n - Wikipedia)
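The classification above can be sanity-checked for small X. The sketch below is not part of the reference solutions, and the helper names `hasPrimitiveRootBrute` and `isNice` are illustrative. It compares the list against a brute-force search for a unit of order \varphi(X), treating X = 1 as nice, as in case (1).

```cpp
#include <cassert>
#include <numeric>

// Brute force: X has a primitive root iff some g coprime to X has
// multiplicative order phi(X). X = 1 is treated as nice, as in case (1).
bool hasPrimitiveRootBrute(int X) {
    if (X == 1) return true;
    int phi = 0;
    for (int a = 1; a < X; a++)
        if (std::gcd(a, X) == 1) phi++;
    for (int g = 1; g < X; g++) {
        if (std::gcd(g, X) != 1) continue;
        int order = 1;
        long long cur = g;                 // cur = g^order mod X
        while (cur != 1) {
            cur = cur * g % X;
            order++;
        }
        if (order == phi) return true;
    }
    return false;
}

// The list above: X in {1, 2, 4, p^k, 2*p^k}, p an odd prime, k >= 1.
bool isNice(int X) {
    if (X == 1 || X == 2 || X == 4) return true;
    if (X % 4 == 0) return false;          // 8, 12, 16, ... are never nice
    if (X % 2 == 0) X /= 2;                // drop the single factor 2 of 2*p^k
    int p = X;                             // smallest prime factor of the odd X >= 3
    for (int d = 3; d * d <= X; d += 2)
        if (X % d == 0) { p = d; break; }
    while (X % p == 0) X /= p;
    return X == 1;                         // nice iff X was a power of p
}
```

For instance, `isNice(18)` and `isNice(25)` hold, while `isNice(8)` and `isNice(15)` do not.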

Since these five families of X are disjoint, we can find the contribution of each case independently and add them up.

For now, consider the problem of finding the total coolness of all sets of size exactly B, and let K = \prod_{x \in S} x denote the product of the elements of a set.
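Before deriving the per-case formulas, it helps to have a brute-force reference for tiny ranges. This sketch (all names are hypothetical) enumerates every subset of [L,R] of size exactly B, multiplies its elements into K, and counts the divisors of K that appear in the list above. It is exponential and only meant for checking the closed forms.

```cpp
#include <cassert>

// Nice X: 1, 2, 4, p^k or 2*p^k with p an odd prime (the list above).
bool isNice(long long X) {
    if (X == 1 || X == 2 || X == 4) return true;
    if (X % 4 == 0) return false;
    if (X % 2 == 0) X /= 2;
    long long p = X;
    for (long long d = 3; d * d <= X; d += 2)
        if (X % d == 0) { p = d; break; }
    while (X % p == 0) X /= p;
    return X == 1;
}

// Coolness of one set with product K: the number of nice divisors of K.
long long coolnessOfProduct(long long K) {
    long long count = 0;
    for (long long d = 1; d * d <= K; d++) {
        if (K % d != 0) continue;
        count += isNice(d);
        if (d != K / d) count += isNice(K / d);
    }
    return count;
}

// Total coolness over all subsets of {L, ..., R} with exactly B elements.
// Bitmask enumeration: only usable when R - L + 1 is about 12 or less.
long long bruteExactly(int L, int R, int B) {
    int n = R - L + 1;
    long long total = 0;
    for (int mask = 1; mask < (1 << n); mask++) {
        int size = 0;
        long long K = 1;
        for (int i = 0; i < n; i++)
            if (mask >> i & 1) { size++; K *= L + i; }
        if (size == B) total += coolnessOfProduct(K);
    }
    return total;
}
```

For example, with L = 1, R = 4, B = 2 the six products 2, 3, 4, 6, 8, 12 have 2, 2, 3, 4, 3 and 5 nice divisors, for a total of 19, which agrees with the sum of the five case formulas derived below.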

Case 1 : X = 1

1 divides every integer, so 1 divides K for every non-empty set.

Hence, this case contributes exactly 1 for every possible set.

\therefore ans1 = \binom{R-L+1}{B}

Case 2 : X = 2

2 divides K when set S contains at least one even integer. Conversely, 2 does not divide K when all the elements in S are odd.

So we consider all possible sets and subtract those having only odd elements.

If we denote by Oddcount the number of odd integers in the range [L,R] (all divisions here are integer divisions), Oddcount = R - R/2 - ((L-1) - (L-1)/2)

then
ans2 = \binom{R-L+1}{B} - \binom{Oddcount}{B}
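A minimal sketch of Case 2 for sets of size exactly B. It uses exact 64-bit binomials instead of arithmetic modulo 998244353, so it is only suitable for small inputs; `binom`, `oddCount` and `caseTwo` are hypothetical helper names.

```cpp
#include <cassert>

// Exact binomial coefficient C(n, k); 0 when n < 0, k < 0 or k > n.
// 64-bit arithmetic, so only for small n (the real solutions work mod 998244353).
long long binom(long long n, long long k) {
    if (n < 0 || k < 0 || k > n) return 0;
    long long r = 1;
    for (long long i = 1; i <= k; i++) r = r * (n - k + i) / i;  // r = C(n-k+i, i)
    return r;
}

// Number of odd integers in [L, R] (L >= 1), with integer division.
long long oddCount(long long L, long long R) {
    return (R - R / 2) - ((L - 1) - (L - 1) / 2);
}

// Case 2 for sets of size exactly B: all sets minus the all-odd sets.
long long caseTwo(long long L, long long R, long long B) {
    return binom(R - L + 1, B) - binom(oddCount(L, R), B);
}
```

For L = 1, R = 4, B = 2 this gives \binom{4}{2} - \binom{2}{2} = 5.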

Case 3 : X = 4

This case contributes 1 to the answer for every set whose product K is a multiple of 4.

ans2 already counts all the sets for which K is even, so we subtract the sets for which K is a multiple of 2 but not of 4.

K \equiv 2 \pmod 4 exactly when one element is divisible by 2 but not by 4 and all the other B-1 elements are odd. So we choose B-1 odd integers and one integer that is a multiple of 2 but not of 4.

If Only2 denotes the number of integers in the range [L,R] that are divisible by 2 but not by 4, Only2 = R/2 - R/4 - ((L-1)/2 - (L-1)/4)

then ans3 = ans2 - Only2 \cdot \binom{Oddcount}{B-1}
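Case 3 can be sketched the same way (exact 64-bit arithmetic, small inputs only; `only2Count` and `caseFour` are hypothetical names):

```cpp
#include <cassert>

// Exact binomial coefficient C(n, k); 0 when n < 0, k < 0 or k > n.
long long binom(long long n, long long k) {
    if (n < 0 || k < 0 || k > n) return 0;
    long long r = 1;
    for (long long i = 1; i <= k; i++) r = r * (n - k + i) / i;
    return r;
}

// Number of odd integers in [L, R] (L >= 1).
long long oddCount(long long L, long long R) {
    return (R - R / 2) - ((L - 1) - (L - 1) / 2);
}

// Number of integers in [L, R] divisible by 2 but not by 4.
long long only2Count(long long L, long long R) {
    return (R / 2 - R / 4) - ((L - 1) / 2 - (L - 1) / 4);
}

// Case 3, size exactly B: sets with K even, minus sets with K = 2 (mod 4),
// i.e. one element = 2 (mod 4) together with B-1 odd elements.
long long caseFour(long long L, long long R, long long B) {
    long long odd = oddCount(L, R);
    long long even = binom(R - L + 1, B) - binom(odd, B);
    return even - only2Count(L, R) * binom(odd, B - 1);
}
```

For L = 1, R = 4, B = 2 the sets with 4 | K are \{1,4\}, \{2,4\}, \{3,4\}, and `caseFour(1, 4, 2)` returns 3.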

Case 4 : X = p^k

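Here p is an odd prime. For a fixed odd prime p, the values X = p, p^2, \ldots, p^{v_p(K)} all divide K, where v_p(K) is the exponent of p in K, so p contributes v_p(K) to the coolness. Since exponents add under multiplication, every prime power in the factorization of every element contributes exactly 1.

So whenever an integer of the range [L,R] occurs in the set S, its contribution is the sum of the exponents of the odd primes in its factorization (for example, 45 = 3^2 \cdot 5 contributes 3).

Let Powersum denote the sum of these exponent sums over all integers in the range [L,R].

We can compute the exponent sum of every integer with a sieve and then take prefix sums, so Powersum for any range is a difference of two prefix values.

ans4 = Powersum \cdot \binom{R-L}{B-1}

where \binom{R-L}{B-1} is the number of sets of size B that contain a fixed integer of [L,R] (the other B-1 elements are chosen from the remaining R-L integers).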
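A sketch of the Case 4 precomputation, using a sieve in the same style as the setter's precalc() but with exact arithmetic; `buildPowers`, `powersum` and `caseP` are hypothetical names, and the table size passed to `buildPowers` is an assumption.

```cpp
#include <cassert>
#include <vector>

std::vector<long long> prefixPowers;  // prefixPowers[x] = odd-prime exponent sums over 1..x

// For every odd prime i and every multiple j of i, powers[j] = powers[j / i] + 1.
// The last write comes from the largest odd prime of j, so powers[j] ends up
// as the sum of exponents of odd primes in j.
void buildPowers(int maxV) {
    std::vector<int> powers(maxV, 0);
    for (int i = 3; i < maxV; i += 2)
        if (powers[i] == 0)                    // no smaller odd prime divides i: i is prime
            for (int j = i; j < maxV; j += i)
                powers[j] = powers[j / i] + 1;
    prefixPowers.assign(maxV, 0);
    for (int x = 1; x < maxV; x++) prefixPowers[x] = prefixPowers[x - 1] + powers[x];
}

// Powersum over [L, R], requires 1 <= L <= R < maxV.
long long powersum(int L, int R) { return prefixPowers[R] - prefixPowers[L - 1]; }

// Exact binomial coefficient C(n, k); 0 when n < 0, k < 0 or k > n.
long long binom(long long n, long long k) {
    if (n < 0 || k < 0 || k > n) return 0;
    long long r = 1;
    for (long long i = 1; i <= k; i++) r = r * (n - k + i) / i;
    return r;
}

// Case 4, size exactly B: each element lies in C(R-L, B-1) of the size-B sets.
long long caseP(int L, int R, int B) { return powersum(L, R) * binom(R - L, B - 1); }
```

For [1,10] the exponent sums are 0, 0, 1, 0, 1, 1, 1, 0, 2, 1, so Powersum = 7.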

Case 5 : X = 2*p^k

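This is similar to Case 2: 2p^k divides K exactly when p^k divides K and K is even. So we start from the answer of Case 4 and subtract the contribution of the odd prime powers in the sets where K is odd, i.e. the sets consisting only of odd integers.

Let Powersumodd denote the sum of the exponent sums of odd primes over the odd integers in the range [L,R]. An even integer 2y has the same exponent sum as y, and the even integers of [L,R] correspond to y \in [\lceil L/2 \rceil, \lfloor R/2 \rfloor], so
Powersumodd = Powersum(L,R) - Powersum(\lceil L/2 \rceil, \lfloor R/2 \rfloor)

Then, since a fixed odd integer appears in \binom{Oddcount-1}{B-1} all-odd sets of size B,
ans5 = Powersum \cdot \binom{R-L}{B-1} - Powersumodd \cdot \binom{Oddcount-1}{B-1}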
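The halving identity for Powersumodd and the Case 5 formula can be checked with a small sketch (exact arithmetic; the names are hypothetical). With integer division, [\lceil L/2 \rceil, \lfloor R/2 \rfloor] is the prefix difference pre[R/2] - pre[(L-1)/2], which is also what both reference solutions compute.

```cpp
#include <cassert>
#include <vector>

std::vector<int> omegaOdd;    // omegaOdd[v] = sum of exponents of odd primes in v
std::vector<long long> pre;   // pre[x] = omegaOdd[1] + ... + omegaOdd[x]

void build(int maxV) {
    omegaOdd.assign(maxV, 0);
    for (int i = 3; i < maxV; i += 2)
        if (omegaOdd[i] == 0)                 // i is an odd prime
            for (int j = i; j < maxV; j += i) omegaOdd[j] = omegaOdd[j / i] + 1;
    pre.assign(maxV, 0);
    for (int x = 1; x < maxV; x++) pre[x] = pre[x - 1] + omegaOdd[x];
}

long long powersum(int L, int R) { return pre[R] - pre[L - 1]; }

// Even x = 2y in [L, R]  <=>  y in [ceil(L/2), floor(R/2)] = ((L-1)/2, R/2].
long long powersumOdd(int L, int R) {
    return powersum(L, R) - (pre[R / 2] - pre[(L - 1) / 2]);
}

long long binom(long long n, long long k) {
    if (n < 0 || k < 0 || k > n) return 0;
    long long r = 1;
    for (long long i = 1; i <= k; i++) r = r * (n - k + i) / i;
    return r;
}

long long oddCount(long long L, long long R) {
    return (R - R / 2) - ((L - 1) - (L - 1) / 2);
}

// Case 5, size exactly B: Case 4 minus the all-odd sets' contribution.
long long caseTwoP(int L, int R, int B) {
    return powersum(L, R) * binom(R - L, B - 1)
         - powersumOdd(L, R) * binom(oddCount(L, R) - 1, B - 1);
}
```

For [1,10], Powersum = 7 and the even integers contribute Powersum(1,5) = 2, leaving Powersumodd = 5.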

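Now consider the problem of finding this answer summed over all sizes from 0 to B. Since \sum_{b=0}^{B} \binom{N}{b-d} = S(N, B-d) for a fixed shift d, we can solve it by replacing every occurrence of C(N,R) above by S(N,R).

Here C(N,R) = \binom{N}{R} and S(N,R) = \sum_{i=0}^{R} \binom{N}{i}, with S(N,R) = 0 when N < 0 or R < 0. The answer for sizes in [A,B] is then the value for B minus the value for A-1; the empty set, counted once by Case 1, cancels.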
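Putting the five cases together with C replaced by S gives the following sketch. It computes S naively and exactly, so it only illustrates the formulas for small ranges; the names `upTo`, `prefixBinom` and `answer` are hypothetical, but the structure mirrors the two-call pattern `solve(L,R,B) - solve(L,R,A-1)` of the setter's main().

```cpp
#include <cassert>
#include <vector>

std::vector<long long> pre;  // pre[x] = odd-prime exponent sums over 1..x

void build(int maxV) {
    std::vector<int> omegaOdd(maxV, 0);
    for (int i = 3; i < maxV; i += 2)
        if (omegaOdd[i] == 0)
            for (int j = i; j < maxV; j += i) omegaOdd[j] = omegaOdd[j / i] + 1;
    pre.assign(maxV, 0);
    for (int x = 1; x < maxV; x++) pre[x] = pre[x - 1] + omegaOdd[x];
}

long long binom(long long n, long long k) {
    if (n < 0 || k < 0 || k > n) return 0;
    long long r = 1;
    for (long long i = 1; i <= k; i++) r = r * (n - k + i) / i;
    return r;
}

// S(n, r) = sum of C(n, i) for 0 <= i <= r; 0 when n < 0 or r < 0.
long long prefixBinom(long long n, long long r) {
    if (n < 0 || r < 0) return 0;
    long long s = 0;
    for (long long i = 0; i <= r && i <= n; i++) s += binom(n, i);
    return s;
}

// Total coolness over all sets of size 0..B: the five case formulas with C replaced by S.
long long upTo(int L, int R, int B) {
    long long n = R - L + 1;
    long long odd = (R - R / 2) - ((L - 1) - (L - 1) / 2);
    long long only2 = (R / 2 - R / 4) - ((L - 1) / 2 - (L - 1) / 4);
    long long ps = pre[R] - pre[L - 1];
    long long psOdd = ps - (pre[R / 2] - pre[(L - 1) / 2]);
    long long ans1 = prefixBinom(n, B);
    long long ans2 = ans1 - prefixBinom(odd, B);
    long long ans3 = ans2 - only2 * prefixBinom(odd, B - 1);
    long long ans4 = ps * prefixBinom(n - 1, B - 1);
    long long ans5 = ans4 - psOdd * prefixBinom(odd - 1, B - 1);
    return ans1 + ans2 + ans3 + ans4 + ans5;
}

// Sizes in [A, B]: the empty set (counted once by case 1) cancels in the difference.
long long answer(int L, int R, int A, int B) { return upTo(L, R, B) - upTo(L, R, A - 1); }
```

For L = 1, R = 4 and sizes 1 to 2 this gives 8 + 19 = 27.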

So we have to calculate S(N,R) values efficiently.
The main idea is to use square-root decomposition in some form. Below is one way:

We know that

C(N,R) = C(N-1,R) + C(N-1,R-1)
\therefore S(N,R) = S(N-1,R) + S(N-1,R-1)

and, applying the definition S(N,R) = S(N,R-1) + C(N,R) to N-1, S(N-1,R) = S(N-1,R-1) + C(N-1,R).

Substituting, we get S(N,R) = C(N-1,R) + 2*S(N-1,R-1) for N \ge 1.
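The recurrence is easy to confirm numerically on a small Pascal's triangle. In this sketch (names are hypothetical) S(N,R) is computed straight from the definition and compared with C(N-1,R) + 2*S(N-1,R-1) for 1 \le N \le 40, including R \ge N. At N = 0 the identity does not apply (it would give 0 instead of 1), which is why the setter's S() handles R \ge N and R = 0 as base cases.

```cpp
#include <cassert>
#include <vector>

using Table = std::vector<std::vector<long long>>;

// Pascal's triangle C[n][k] for 0 <= k <= n <= maxN (exact in 64 bits for maxN <= 60).
Table pascal(int maxN) {
    Table C(maxN + 1);
    for (int n = 0; n <= maxN; n++) {
        C[n].assign(n + 1, 1);
        for (int k = 1; k < n; k++) C[n][k] = C[n - 1][k - 1] + C[n - 1][k];
    }
    return C;
}

long long binomAt(const Table& C, int n, int k) {
    if (n < 0 || k < 0 || k > n) return 0;
    return C[n][k];
}

// S(N, R) from the definition; 0 when N < 0 or R < 0.
long long prefixSum(const Table& C, int N, int R) {
    if (N < 0 || R < 0) return 0;
    long long s = 0;
    for (int k = 0; k <= R && k <= N; k++) s += C[N][k];
    return s;
}

// Checks S(N,R) == C(N-1,R) + 2*S(N-1,R-1) for 1 <= N <= maxN, 0 <= R <= N+1.
bool recurrenceHolds(int maxN) {
    Table C = pascal(maxN);
    for (int N = 1; N <= maxN; N++)
        for (int R = 0; R <= N + 1; R++)
            if (prefixSum(C, N, R) != binomAt(C, N - 1, R) + 2 * prefixSum(C, N - 1, R - 1))
                return false;
    return true;
}
```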

This enables us to use square-root decomposition on N.

We precalculate S(N,R) for all R only when N is a multiple of BLOCK, so we maintain MAXN/BLOCK rows.

Now, for a query S(N,R), we apply the recurrence above repeatedly. Each step decreases N (and R) by 1, so we reach a precomputed row, or the base case R = 0, in at most BLOCK steps. (If R \ge N initially, the answer is simply 2^N.)

Choosing BLOCK = O(\sqrt{MAXN}), precomputation takes O(MAXN\sqrt{MAXN}) time and memory, and each query takes O(\sqrt{MAXN}) time.

This was enough to pass within the time limit.
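The square-root decomposition can also be sketched iteratively. It unrolls the same recurrence as the setter's recursive S(): after t steps, S(N,R) = \sum_{j<t} 2^j C(N-1-j, R-j) + 2^t S(N-t, R-t). The names and the table sizes passed to `init` are assumptions for this sketch; the setter uses MAXN = 1e5+1 and BLOCK = 200.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using u64 = std::uint64_t;
const u64 MOD = 998244353;

int BLOCK = 1;                                 // set by init()
std::vector<u64> fact, ifact, pw2;
std::vector<std::vector<std::uint32_t>> rows;  // rows[b][r] = S(b*BLOCK, r) mod MOD, r <= b*BLOCK

u64 power(u64 a, u64 e) {
    u64 r = 1;
    for (a %= MOD; e > 0; e >>= 1, a = a * a % MOD)
        if (e & 1) r = r * a % MOD;
    return r;
}

u64 C(int n, int k) {
    if (n < 0 || k < 0 || k > n) return 0;
    return fact[n] * ifact[k] % MOD * ifact[n - k] % MOD;
}

// Factorials, powers of two, and one prefix-sum row per multiple of BLOCK.
void init(int maxN, int block) {
    BLOCK = block;
    fact.assign(maxN, 1); ifact.assign(maxN, 1); pw2.assign(maxN, 1);
    for (int i = 1; i < maxN; i++) {
        fact[i] = fact[i - 1] * i % MOD;
        pw2[i] = pw2[i - 1] * 2 % MOD;
    }
    ifact[maxN - 1] = power(fact[maxN - 1], MOD - 2);  // Fermat inverse
    for (int i = maxN - 1; i > 0; i--) ifact[i - 1] = ifact[i] * i % MOD;
    rows.clear();
    for (int n = 0; n < maxN; n += block) {
        std::vector<std::uint32_t> row(n + 1);
        u64 acc = 0;
        for (int r = 0; r <= n; r++) { acc = (acc + C(n, r)) % MOD; row[r] = (std::uint32_t)acc; }
        rows.push_back(std::move(row));
    }
}

// S(N, R) mod MOD for 0 <= N < maxN. Invariant: the answer is acc + mult * S(N, R).
u64 S(int N, int R) {
    if (N < 0 || R < 0) return 0;
    if (R >= N) return pw2[N];
    u64 acc = 0, mult = 1;
    while (true) {                             // R < N holds throughout
        if (R == 0) return (acc + mult) % MOD;
        if (N % BLOCK == 0) return (acc + mult * rows[N / BLOCK][R]) % MOD;
        acc = (acc + mult * C(N - 1, R)) % MOD;
        mult = mult * 2 % MOD;
        N--;
        R--;
    }
}
```

With the setter's parameters the triangular row table holds about 2.5 \cdot 10^7 values (roughly 100 MB as 32-bit integers), and each query performs at most BLOCK iterations.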

SOME OPTIMIZATIONS

(1) Answering one query for sizes at most B requires 5 different values of S(N,R), but only 2 of them need the square-root structure; the other 3 can be calculated from these in O(1).

(2) This approach can also be performed offline by collecting all the required S(N,R) values and sorting them by N. Then we only have to keep one row at a time, which costs only O(MAXN) memory.

(3) If there are many more queries, it is better to use Mo’s algorithm here, which answers them in O(MAXN\sqrt{MAXN} + Q\log Q).
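For optimization (1), write N = R-L+1 and O = Oddcount. The five values needed are S(N,B), S(O,B), S(O,B-1), S(N-1,B-1) and S(O-1,B-1). Only the last two require the square-root structure; the setter's solve() derives the other three from them using the recurrence (stated for first argument at least 1) and the definition of S:

```latex
\begin{aligned}
S(N,B)   &= C(N-1,B) + 2\,S(N-1,B-1)\\
S(O,B)   &= C(O-1,B) + 2\,S(O-1,B-1)\\
S(O,B-1) &= S(O,B) - C(O,B)
\end{aligned}
```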

Outline of Mo's algorithm

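Each query S(N,R) is treated as a point (N,R), and we move from one query to the next by changing N or R by 1 at a time.

To change R, we just add or subtract a single C(N,R) value. To change N, we can use the recurrence S(N,R) = C(N-1,R) + 2*S(N-1,R-1); since it moves N and R together, one further C(N,R) update restores R. Going backwards requires dividing by 2, i.e. multiplying by the modular inverse of 2.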
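The pointer moves of such a Mo's-style traversal can be sketched as O(1) updates. Moving R uses the definition of S. For moving N with R fixed, this sketch assumes the identity S(N+1,R) = 2\,S(N,R) - C(N,R), which follows from the same Pascal identity and is the update the tester uses to build prec[][]; its inverse divides by 2. The struct name `PrefixBinomWalker` is hypothetical, and the query ordering of Mo's algorithm itself is not shown.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
const u64 MOD = 998244353;
const u64 INV2 = (MOD + 1) / 2;  // modular inverse of 2

std::vector<u64> fact, ifact;

u64 power(u64 a, u64 e) {
    u64 r = 1;
    for (a %= MOD; e > 0; e >>= 1, a = a * a % MOD)
        if (e & 1) r = r * a % MOD;
    return r;
}

void initFactorials(int maxN) {
    fact.assign(maxN, 1);
    ifact.assign(maxN, 1);
    for (int i = 1; i < maxN; i++) fact[i] = fact[i - 1] * i % MOD;
    ifact[maxN - 1] = power(fact[maxN - 1], MOD - 2);
    for (int i = maxN - 1; i > 0; i--) ifact[i - 1] = ifact[i] * i % MOD;
}

u64 C(int n, int k) {
    if (n < 0 || k < 0 || k > n) return 0;
    return fact[n] * ifact[k] % MOD * ifact[n - k] % MOD;
}

// Maintains cur == S(n, r) mod MOD for n >= 0, r >= 0.
struct PrefixBinomWalker {
    int n = 0, r = 0;
    u64 cur = 1;                                       // S(0, 0) = 1
    void incR() { r++; cur = (cur + C(n, r)) % MOD; }
    void decR() { cur = (cur + MOD - C(n, r)) % MOD; r--; }
    // S(n+1, r) = S(n, r) + S(n, r-1) = 2*S(n, r) - C(n, r)
    void incN() { cur = (2 * cur + MOD - C(n, r)) % MOD; n++; }
    // Inverse of incN: S(n-1, r) = (S(n, r) + C(n-1, r)) / 2
    void decN() { n--; cur = (cur + C(n, r)) % MOD * INV2 % MOD; }
    // Walk to (N, R) one unit step at a time and return S(N, R).
    u64 moveTo(int N, int R) {
        while (r < R) incR();
        while (n < N) incN();
        while (r > R) decR();
        while (n > N) decN();
        return cur;
    }
};
```

Each step costs one binomial coefficient, so the total cost is proportional to the total distance the pointer travels, which is exactly what Mo's ordering of the queries keeps small.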

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;

template<const int &MOD>
struct _m_int {
	int val;
 
	_m_int(int64_t v = 0) {
		if (v < 0) v = v % MOD + MOD;
		if (v >= MOD) v %= MOD;
		val = v;
	}
 
	static int mod_inv(int a, int m = MOD) {
		// https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Example
		int g = m, r = a, x = 0, y = 1;
 
		while (r != 0) {
			int q = g / r;
			g %= r; swap(g, r);
			x -= q * y; swap(x, y);
		}
 
		return x < 0 ? x + m : x;
	}
 
	explicit operator int() const { return val; }
	explicit operator int64_t() const { return val; }
 
	_m_int& operator+=(const _m_int &other) {
		val -= MOD - other.val;
		if (val < 0) val += MOD;
		return *this;
	}
 
	_m_int& operator-=(const _m_int &other) {
		val -= other.val;
		if (val < 0) val += MOD;
		return *this;
	}
 
	static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
		return x % m;
	}
 
	_m_int& operator*=(const _m_int &other) {
		val = fast_mod((uint64_t) val * other.val);
		return *this;
	}
 
	_m_int& operator/=(const _m_int &other) {
		return *this *= other.inv();
	}
 
	friend _m_int operator+(const _m_int &a, const _m_int &b) { return _m_int(a) += b; }
	friend _m_int operator-(const _m_int &a, const _m_int &b) { return _m_int(a) -= b; }
	friend _m_int operator*(const _m_int &a, const _m_int &b) { return _m_int(a) *= b; }
	friend _m_int operator/(const _m_int &a, const _m_int &b) { return _m_int(a) /= b; }
 
	_m_int& operator++() {
		val = val == MOD - 1 ? 0 : val + 1;
		return *this;
	}
 
	_m_int& operator--() {
		val = val == 0 ? MOD - 1 : val - 1;
		return *this;
	}
 
	_m_int operator++(int) { _m_int before = *this; ++*this; return before; }
	_m_int operator--(int) { _m_int before = *this; --*this; return before; }
 
	_m_int operator-() const {
		return val == 0 ? 0 : MOD - val;
	}
 
	bool operator==(const _m_int &other) const { return val == other.val; }
	bool operator!=(const _m_int &other) const { return val != other.val; }
 
	_m_int inv() const {
		return mod_inv(val);
	}
 
	_m_int pow(int64_t p) const {
		if (p < 0)
			return inv().pow(-p);
 
		_m_int a = *this, result = 1;
 
		while (p > 0) {
			if (p & 1)
				result *= a;
 
			a *= a;
			p >>= 1;
		}
 
		return result;
	}
 
	friend ostream& operator<<(ostream &os, const _m_int &m) {
		return os << m.val;
	}
};
 
extern const int MOD = 998244353;
using mod_int = _m_int<MOD>;

const int MAXN = 1e5+1;
const int BLOCK = 200;

int powers[MAXN];

mod_int fact[MAXN],ifact[MAXN],pw2[MAXN],calc[(MAXN/BLOCK) + 2][MAXN];

mod_int C(int N,int R)
{
	if(R>N) return 0;
	return fact[N]*ifact[R]*ifact[N-R];
}

//S(n,r) = Summation(i = 0 to r) C(n,i)
mod_int S(int N,int R)
{
   if(N < 0 || R < 0) return 0;

   if(R == 0) return 1;
  
   if(R >= N) return pw2[N];
	
   if((N % BLOCK) == 0) return calc[N/BLOCK][R];
   
   return C(N-1,R) + (2*S(N-1,R-1));
}

void precalc()
{
	for(int i=3;i<MAXN;i+=2)
	  if(powers[i] == 0)  
		for(int j=i;j<MAXN;j+=i)
		  powers[j] = powers[j/i] + 1;

	for(int i=2;i<MAXN;i++) powers[i]+=powers[i-1];

	fact[0] = 1;    
	for(int i=1;i<MAXN;i++) fact[i] = (fact[i-1]*i);
  
	ifact[MAXN-1] = fact[MAXN-1].inv();
	for(int i=MAXN-2;i>=0;i--) ifact[i] = (ifact[i+1]*(i+1));
		
	for(int i=0,j=0;i<MAXN;i+=BLOCK,j++)  
	{
		calc[j][0] = 1;
		for(int k=1;k<=i;k++) calc[j][k] = calc[j][k-1]+C(i,k);
	}
	
	pw2[0] = 1;
	for(int i=1;i<MAXN;i++) pw2[i] = (pw2[i-1] + pw2[i-1]);
}

mod_int solve(int L,int R,int B)
{
   L--; 
	
   int N = R - L;
   int Oddcount = R - R/2 - (L - L/2);
   int Only2 = (R/2 - R/4) - (L/2 - L/4);
   int Powersum = powers[R] - powers[L];
   int Powersumodd = Powersum - (powers[R/2] - powers[L/2]);
   
   mod_int S4 = S(N-1,B-1); //S(N-1,B-1)
   mod_int S5 = S(Oddcount-1,B-1); //S(oddcount-1,B-1)
   mod_int S2 = C(Oddcount-1,B) + S5 + S5; //S(oddcount,B)
   mod_int S3 = S2 - C(Oddcount,B); //S(oddcount,B-1)  
   mod_int S1 = C(N-1,B)+ S4 +S4 ; //S(N,B)

   //case 1 : X = 1 
   mod_int ans1 = S1;

   //case 2 : X = 2
   mod_int ans2 = S1 - S2;
	   
   //case 3 : X = 4
   mod_int ans3 = S1 - S2 - (Only2*S3);

   //case 4 : X = p^k
   mod_int ans4 = (Powersum*S4);
	
   //case 5 : X = 2*(p^k)
   mod_int ans5 = (Powersum*S4) - (Powersumodd*S5) ;
  
   return ans1 + ans2 + ans3 + ans4 + ans5;
}

signed main()
{
	precalc();	
	
	int T,L,R,A,B;
	cin >> T;

	while(T--)
	{  
	   cin >> L >> R >> A >> B;
	   cout << solve(L,R,B) - solve(L,R,A-1) << '\n';
	}
} 
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
 
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
 
using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (int)1e5 + 42;
const int B = 250;
const int mod = 998244353;
const int inv2 = (mod + 1) / 2;
 
int pw2[MAXN];
int pw(int x, int p) {
	int r = 1;
	while(p) {
		if(p % 2 == 1) r = r * 1ll * x % mod;
		x = x * 1ll * x % mod;
		p >>= 1;
	}
 
	return r;
}
 
void fix(int &x) {
	if(x >= mod) x -= mod;
	if(x < 0) x += mod;
}
 
int inv[MAXN], fact[MAXN], ifact[MAXN];
 
void comb_prepare() {
	pw2[0] = 1;
	for(int i = 1; i < MAXN; i++) {
		pw2[i] = (2 * pw2[i - 1]) % mod;
	}
 
	fact[0] = 1;
	for(int i = 1; i < MAXN; i++) {
		fact[i] = fact[i - 1] * 1ll * i % mod;
	}
 
	ifact[MAXN - 1] = pw(fact[MAXN - 1], mod - 2);
	for(int i = MAXN - 2; i >= 0; i--) {
		ifact[i] = (i + 1) * 1ll * ifact[i + 1] % mod;
		inv[i + 1] = ifact[i + 1] * 1ll * fact[i] % mod;
	}
}
 
int C(int n, int k) {
	if(k > n || k < 0) return 0;
	return ((ifact[n - k] * 1ll * ifact[k]) % mod) * 1ll * fact[n] % mod;
}
 
int prec[MAXN / B + 2][MAXN];
 
int lp[MAXN];
int sum_powers[MAXN];
int cnt_odd[MAXN];
int cnt_2[MAXN];
 
void precompute() {
	comb_prepare();
	for(int k = 0; k < MAXN; k += B) {
		int i = k / B; 
		prec[i][0] = 1;
		for(int n = 1; n < MAXN; n++) {
			prec[i][n] = 2 * prec[i][n - 1] - C(n - 1, k);
			fix(prec[i][n]);	
		}
	}
 
	for(int i = 2; i < MAXN; i++) {
		for(int j = i; j < MAXN; j += i) {
			if(lp[j] == 0) lp[j] = i;
		}
	}
	
	cnt_2[2] = 1;
	cnt_odd[1] = 1;
	cnt_odd[2] = 1;
	for(int i = 3; i < MAXN; i++) {
		int x = i;
		cnt_2[i] = cnt_2[i - 1] + (i % 2 == 0 && i % 4 != 0);
		cnt_odd[i] = cnt_odd[i - 1] + (i % 2 == 1);
		sum_powers[i] = sum_powers[i - 1];
		while(lp[x] == 2) x /= 2;
		while(x != 1) {
			x /= lp[x];
			sum_powers[i]++;
		}
	}
}
 
int comb_sum(int n, int k) {
	if(k < 0 || n < 0) return 0;
	int ret = prec[k / B][n];
	for(int st = (k / B) * B + 1; st <= k; st++) {
		ret += C(n, st);
		fix(ret);
	} 
 
	return ret;
}
 
int l, r, a, b;
 
void read() {
	cin >> l >> r >> a >> b;
}
 
int solve(int l, int r, int k) {
	int comb_n_k = comb_sum(r - l + 1, k);
	int comb_odd_k = comb_sum(cnt_odd[r] - cnt_odd[l - 1], k);
	int comb_n1_k1 = comb_sum(r - l, k - 1); 
	int comb_odd1_k1 = comb_sum(cnt_odd[r] - cnt_odd[l - 1] - 1, k - 1);
 
	int answer = comb_n_k;
	int c2 = comb_n_k - comb_odd_k; fix(c2);
	answer += c2;
	fix(answer);
	
	int c4 = c2 - ((comb_odd_k - C(cnt_odd[r] - cnt_odd[l - 1], k) + mod) * 1ll * (cnt_2[r] - cnt_2[l - 1])) % mod; fix(c4);
	answer += c4;
	fix(answer);
	
	int cp = (sum_powers[r] - sum_powers[l - 1]) * 1ll * comb_n1_k1 % mod;
	answer += cp;
	fix(answer);
 
	// props to the author for the idea \/
	int P = sum_powers[r] - sum_powers[l - 1] - sum_powers[r / 2] + sum_powers[(l - 1) / 2];
	int cp2 = cp - (comb_odd1_k1 * 1ll * P % mod); fix(cp2);
	answer += cp2;
	fix(answer);
	return answer;
}
 
void solve() {
	int answer = solve(l, r, b) - solve(l, r, a - 1);
	fix(answer);
	cout << answer << endl;
}
 
int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	
	precompute();
 
	int T;
	cin >> T;
	while(T--) {
		read();
		solve();
	}
 
	return 0;
}

VIDEO EDITORIAL:

Feel free to ask any questions. Suggestions are welcome :)


Great Problem!!!


Idea about sqrt is genius! Thanks for amazing problem!
