MGICMENU - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Anik Sarker, Ezio Auditore

Tester: Teja Vardhan Reddy

Editorialist: Taranpreet Singh

DIFFICULTY:

PREREQUISITES:

Inclusion-Exclusion, Mobius Inversion, Prefix-sums, Observations and Divide and Conquer.

PROBLEM:

Given an integer M, you have to answer Q queries of the form L, R, K. For each query, count the number of ways to select up to K integers (an ordered sequence of length 1 to K, where elements may repeat), all within the range [L, R], such that the gcd of the selected elements is 1. The constraint 1 \leq L \leq R \leq M holds, and all solutions below report the answer modulo 2^{30}.

EXPLANATION

First of all, let g(d) denote the number of ways to select exactly K elements within the range [L, R] such that their GCD is a multiple of d.

It is easy to see that the number of multiples of d in the range [L, R] is \lfloor \frac{R}{d} \rfloor - \lfloor \frac{L-1}{d} \rfloor. The GCD of the selected elements is a multiple of d exactly when every selected element is a multiple of d. We have to select K elements, and each of them can independently be any of these multiples. This gives g(d) = \Big(\lfloor \frac{R}{d} \rfloor - \lfloor \frac{L-1}{d} \rfloor \Big)^K.
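A small Java sketch of this counting argument, using exact arithmetic and tiny inputs only: `g` evaluates the closed form, and `bruteG` enumerates every ordered K-tuple over [L, R] and checks whether its gcd is a multiple of d. The class and method names are hypothetical and do not come from the solutions below.

```java
// Sketch: g(d) = (number of multiples of d in [L, R])^K, cross-checked by brute force.
// Names (GOfD, countMultiples, g, bruteG) are illustrative only.
class GOfD {
    // Number of multiples of d in [L, R].
    static long countMultiples(long L, long R, long d) {
        return R / d - (L - 1) / d;
    }

    // g(d) in exact arithmetic (no modulus), so only for tiny inputs.
    static long g(long L, long R, long d, int K) {
        long c = countMultiples(L, R, d), res = 1;
        for (int i = 0; i < K; i++) res *= c;
        return res;
    }

    // Enumerate every ordered K-tuple over [L, R]; count those whose gcd is a multiple of d.
    static long bruteG(int L, int R, int d, int K) {
        return rec(L, R, d, K, 0);
    }

    // acc is the gcd of the elements chosen so far (gcd(0, x) = x).
    private static long rec(int L, int R, int d, int left, int acc) {
        if (left == 0) return acc % d == 0 ? 1 : 0;
        long total = 0;
        for (int x = L; x <= R; x++) total += rec(L, R, d, left - 1, gcd(acc, x));
        return total;
    }

    static int gcd(int a, int b) {
        return b == 0 ? a : gcd(b, a % b);
    }

    public static void main(String[] args) {
        // Multiples of 2 in [3, 10] are 4, 6, 8, 10, so g(2) = 4^3 = 64 for K = 3.
        System.out.println(g(3, 10, 2, 3) + " " + bruteG(3, 10, 2, 3));
    }
}
```

Both numbers printed are 64, since a tuple's gcd is even exactly when all four choices per position are even.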

Now, let us use Inclusion-Exclusion to calculate, from g, the number of ways to select K elements such that their gcd is one. Let f(d) be the number of ways to select K elements in the range [L, R] such that their gcd is exactly d. We want to express f(1) in terms of g only.

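Let’s assume M = 6. A selection whose gcd is a multiple of d has a gcd equal to one of the multiples of d, so each g(d) is the sum of f over the multiples of d:
g(1) = f(1)+f(2)+f(3)+f(4)+f(5)+f(6),
g(2) = f(2)+f(4)+f(6),
g(3) = f(3)+f(6),
g(4) = f(4),
g(5) = f(5),
g(6) = f(6).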

Let’s start with g(1), since it is the only term that contains f(1). We now need to remove g(1)-f(1) = f(2)+f(3)+f(4)+f(5)+f(6). Both f(2) and f(4) are part of g(2), so we subtract g(2), and the expression becomes g(1)-g(2) = f(1)+f(3)+f(5). To remove f(3), we subtract g(3), which also subtracts f(6): g(1)-g(2)-g(3) = f(1)+f(5)-f(6). To remove f(5), we subtract g(5): g(1)-g(2)-g(3)-g(5) = f(1)-f(6). Finally, adding g(6) = f(6) to both sides yields f(1).

The final expression is f(1) = g(1)-g(2)-g(3)-g(5)+g(6). Does this ring a bell? It is just Mobius inversion in play: the coefficient of g(d) is μ(d), where μ is the Mobius function (μ(4) = 0 because 4 is divisible by a square, and μ(6) = 1 because 6 has two distinct prime factors). For a general M, f(1) = \sum_{d = 1}^{M} μ(d)*g(d).
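The coefficients 1, -1, -1, 0, -1, 1 above can be checked against the Mobius function directly. The Java sketch below uses a linear sieve, which is a swapped-in technique (the solutions below use other sieves), and prints μ(1) through μ(6); the class and method names are illustrative.

```java
// Sketch: Mobius function via a linear sieve (names are illustrative).
class MobiusSieve {
    // Returns mu[0..n]; mu[0] is unused and left as 0.
    static int[] mobius(int n) {
        int[] mu = new int[n + 1];
        boolean[] composite = new boolean[n + 1];
        int[] primes = new int[n + 1];
        int cnt = 0;
        if (n >= 1) mu[1] = 1;
        for (int i = 2; i <= n; i++) {
            if (!composite[i]) { primes[cnt++] = i; mu[i] = -1; } // i is prime
            for (int j = 0; j < cnt && (long) i * primes[j] <= n; j++) {
                int p = primes[j];
                composite[i * p] = true;
                if (i % p == 0) { mu[i * p] = 0; break; } // p^2 divides i * p
                mu[i * p] = -mu[i];                       // one more distinct prime
            }
        }
        return mu;
    }

    public static void main(String[] args) {
        int[] mu = mobius(6);
        StringBuilder sb = new StringBuilder();
        for (int d = 1; d <= 6; d++) sb.append(mu[d]).append(d < 6 ? " " : "");
        System.out.println(sb); // 1 -1 -1 0 -1 1
    }
}
```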

So, \sum_{d = 1}^{M} μ(d)*g(d) = \sum_{d = 1}^{M} μ(d)*\Big(\lfloor \frac{R}{d} \rfloor - \lfloor \frac{L-1}{d} \rfloor \Big)^K counts only sequences of length exactly K. We need sequences of every length up to K, so we take this sum for each length n from 1 to K and add the results.

The expression becomes \sum_{n = 1}^{K}\sum_{d = 1}^{M} μ(d)*\Big(\lfloor \frac{R}{d} \rfloor - \lfloor \frac{L-1}{d} \rfloor \Big)^n. Swapping the order of summation, we have \sum_{d = 1}^{M} μ(d)*\Big(\sum_{n = 1}^{K}\Big(\lfloor \frac{R}{d} \rfloor - \lfloor \frac{L-1}{d} \rfloor \Big)^n \Big). The summation inside the bracket is a geometric progression. Writing c for the bracketed count, the closed form \frac{c(c^K - 1)}{c - 1} would need the modular inverse of c - 1, which does not exist modulo 2^{30} whenever c - 1 is even (the modulus is not prime). Instead, we compute the sum in O(log K) using divide and conquer.
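Before optimizing, it may help to sanity-check the rearranged formula on tiny inputs. This Java sketch (hypothetical names, exact arithmetic, no modulus) evaluates the double sum directly and compares it with brute-force enumeration of all sequences of length 1 to K over [L, R] whose gcd is 1. For [2, 4] and K = 2, no single element has gcd 1 and the valid pairs are (2,3), (3,2), (3,4), (4,3), so both values should be 4.

```java
// Sketch: check sum_d mu(d) * (c + c^2 + ... + c^K), c = R/d - (L-1)/d,
// against brute force. Exact arithmetic, tiny inputs only; names are illustrative.
class FormulaCheck {
    // Mobius function by trial division.
    static int mu(int n) {
        int res = 1;
        for (int p = 2; (long) p * p <= n; p++) {
            if (n % p == 0) {
                n /= p;
                if (n % p == 0) return 0; // p^2 divided the original n
                res = -res;
            }
        }
        if (n > 1) res = -res; // one remaining prime factor
        return res;
    }

    static long formula(int L, int R, int K) {
        long total = 0;
        for (int d = 1; d <= R; d++) {
            long c = R / d - (L - 1) / d, pw = 1, geo = 0;
            for (int n = 1; n <= K; n++) { pw *= c; geo += pw; } // geo = c + ... + c^K
            total += mu(d) * geo;
        }
        return total;
    }

    static long brute(int L, int R, int K) {
        long total = 0;
        for (int len = 1; len <= K; len++) total += count(L, R, len, 0);
        return total;
    }

    // Ordered sequences with `left` more elements; acc = gcd so far (gcd(0, x) = x).
    private static long count(int L, int R, int left, int acc) {
        if (left == 0) return acc == 1 ? 1 : 0;
        long t = 0;
        for (int x = L; x <= R; x++) t += count(L, R, left - 1, gcd(acc, x));
        return t;
    }

    static int gcd(int a, int b) {
        return b == 0 ? a : gcd(b, a % b);
    }

    public static void main(String[] args) {
        System.out.println(formula(2, 4, 2) + " " + brute(2, 4, 2)); // 4 4
    }
}
```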

GP Sum using Divide and Conquer

We want to calculate the GP sum with first term a, common ratio r and n terms.
It is enough to compute the GP sum with first term 1 and multiply it by a, so from now on assume the first term is 1.

Let gp(r, n) denote the GP sum with first term 1, common ratio r and n terms, i.e. 1 + r + \dots + r^{n-1}. Then the following three identities hold:

  • gp(r, 1) = 1
  • If n is even, gp(r, n) = (1+r)*gp(r^2, n/2)
  • Otherwise (n odd and n > 1), gp(r, n) = 1+r*gp(r, n-1)

These three identities allow us to use divide and conquer and calculate the GP sum in O(log n) time, since every odd step is immediately followed by an even step that halves n.
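A recursive Java sketch of the three identities, working modulo 2^{30} as all the solutions below do. `gp` follows the rules literally, `sumPowers` multiplies by the first term r to obtain r + r^2 + ... + r^n, and `naive` is a linear-time reference for testing. The names are assumptions; the setters implement the same idea iteratively, scanning n from its most significant bit.

```java
// Sketch: GP sums modulo 2^30 by divide and conquer (names are illustrative).
class GpSum {
    static final long MOD = 1L << 30;

    // gp(r, n) = 1 + r + ... + r^(n-1) (mod 2^30) for n >= 1.
    static long gp(long r, long n) {
        r %= MOD;                             // keep r < 2^30 so r * r fits in a long
        if (n == 1) return 1;                 // gp(r, 1) = 1
        if (n % 2 == 0)                       // (1 + r) * gp(r^2, n/2)
            return (1 + r) % MOD * gp(r * r % MOD, n / 2) % MOD;
        return (1 + r * gp(r, n - 1)) % MOD;  // 1 + r * gp(r, n-1)
    }

    // r + r^2 + ... + r^n: the GP with first term r, i.e. r * gp(r, n).
    static long sumPowers(long r, long n) {
        return r % MOD * gp(r, n) % MOD;
    }

    // Linear-time reference: adds r^0 .. r^(n-1) one by one.
    static long naive(long r, long n) {
        long sum = 0, term = 1;
        for (long i = 0; i < n; i++) {
            sum = (sum + term) % MOD;
            term = term * (r % MOD) % MOD;
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(gp(2, 10));       // 1023 = 1 + 2 + ... + 2^9
        System.out.println(sumPowers(3, 4)); // 120 = 3 + 9 + 27 + 81
    }
}
```

Reducing r before squaring keeps every product below 2^{60}, so no intermediate value overflows a long, and the recursion depth stays around 2 log2 n.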

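So, the required summation is \sum_{d = 1}^{M} μ(d)*GP(f(L, R, d), f(L, R, d), K), where GP(a, r, n) is the GP sum with first term a, common ratio r and n terms, and f(L, R, d) = \Big( \lfloor\frac{R}{d} \rfloor-\lfloor\frac{L-1}{d}\rfloor\Big) is the number of multiples of d in [L, R] (not to be confused with f(d) defined earlier).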

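This gives an O(Q*M*log(K)) solution, which is too slow for M up to 10^5 and Q up to 10^4 (the bounds asserted in the setters' code), so we need something faster.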
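The O(Q*M*log(K)) approach can be sketched in Java as follows, assuming the modulus 2^{30}; `SlowQuery`, `mobius` and `query` are hypothetical names. The sieve mirrors the first setter's Init(): starting from μ(1) = 1, each μ(i) is subtracted from every proper multiple of i. Only d up to R is needed, because f(L, R, d) = 0 for d > R.

```java
// Sketch: the O(R log K) per-query evaluation, modulo 2^30 (names are hypothetical).
class SlowQuery {
    static final long MOD = 1L << 30;

    // Mobius function up to n: mu[1] = 1, then subtract mu[i] from every proper multiple of i.
    static int[] mobius(int n) {
        int[] mu = new int[n + 1];
        mu[1] = 1;
        for (int i = 1; i <= n; i++)
            for (int j = 2 * i; j <= n; j += i) mu[j] -= mu[i];
        return mu;
    }

    // 1 + r + ... + r^(n-1) mod 2^30, n >= 1, by divide and conquer.
    static long gp(long r, long n) {
        r %= MOD;
        if (n == 1) return 1;
        if (n % 2 == 0) return (1 + r) % MOD * gp(r * r % MOD, n / 2) % MOD;
        return (1 + r * gp(r, n - 1)) % MOD;
    }

    // Sum over d = 1..R of mu(d) * (c + c^2 + ... + c^K), with c = R/d - (L-1)/d.
    static long query(int[] mu, int L, int R, long K) {
        long ans = 0;
        for (int d = 1; d <= R; d++) {
            if (mu[d] == 0) continue;
            long c = R / d - (L - 1) / d;   // f(L, R, d)
            long term = c * gp(c, K) % MOD; // GP with first term c, ratio c, K terms
            ans = (ans + (mu[d] == 1 ? term : MOD - term)) % MOD;
        }
        return ans;
    }

    public static void main(String[] args) {
        int[] mu = mobius(10);
        System.out.println(query(mu, 2, 4, 2)); // 4
    }
}
```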

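One final optimization comes from noticing that f(L, R, d) takes only O(\sqrt{R} + \sqrt{L}) different values as d varies. This follows from the fact that \lfloor\frac{R}{d}\rfloor takes at most 2\sqrt{R} distinct values (for d \leq \sqrt{R} there are at most \sqrt{R} choices of d, and for d > \sqrt{R} the quotient itself is below \sqrt{R}), and likewise \lfloor\frac{L-1}{d}\rfloor takes at most 2\sqrt{L} distinct values. So we divide the interval [1, R] into intervals such that \Big( \lfloor\frac{R}{d} \rfloor-\lfloor\frac{L-1}{d}\rfloor\Big) remains the same for all values of d in the same interval; since the two quotients change at most O(\sqrt{R} + \sqrt{L}) times in total, there are O(\sqrt{R} + \sqrt{L}) such intervals.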

To build them, divide the range [1, R] into intervals on which R/d stays the same, divide the range [1, R] into intervals on which (L-1)/d stays the same, and then merge the two lists by intersecting overlapping intervals.

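Each division can be done similarly to trial division up to \sqrt{val}: for every i \leq \sqrt{val}, d = i forms its own interval with quotient val/i, and all d with quotient exactly i form the interval [val/(i+1)+1, val/i]. It is left as an exercise; if you get stuck, refer to the following code snippet after trying.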

Code Snippet

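//Split [1, val] into intervals {lo, hi, q} such that val/d == q for every d in [lo, hi]
ArrayList<int[]> generate(int val){
    ArrayList<int[]> list = new ArrayList<>();
    for(int i = 1; i*i <= val; i++){
        //d = i on its own, with quotient val/i
        list.add(new int[]{i, i, val/i});
        //all d with quotient exactly i: val/(i+1) < d <= val/i
        if(i != val/i)list.add(new int[]{val/(i+1)+1, val/i, i});
    }
    //sort the intervals by their left end
    Collections.sort(list, (int[] i1, int[] i2) -> Integer.compare(i1[0], i2[0]));
    return list;
}
//Merging two interval lists that cover the same range [1, r]:
//L holds (l-1)/d (generate(l-1) followed by {l, r, 0}), R holds r/d (generate(r)).
//Each output interval is an intersection and stores r/d - (l-1)/d.
ArrayList<int[]> merge(ArrayList<int[]> L, ArrayList<int[]> R){
    ArrayList<int[]> list = new ArrayList<>();
    int i = 0, j = 0;
    while(i < L.size() && j< R.size()){
        int[] cur1 = L.get(i), cur2 = R.get(j);
        int[] cur = new int[]{Math.max(cur1[0], cur2[0]), Math.min(cur1[1], cur2[1]), cur2[2]-cur1[2]};
        list.add(cur);
        //advance the interval that ends first (both if they end together)
        if(cur1[1] < cur2[1])i++;
        else if(cur1[1] > cur2[1])j++;
        else {i++;j++;}
    }
    return list;
}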
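Another way to produce the same kind of intervals, without building and merging two lists, is the common "jump to the end of the block" loop: from the current d = lo, the last d with the same \lfloor R/d \rfloor is \lfloor R / \lfloor R/lo \rfloor \rfloor, and the same rule applies to L - 1 while that quotient is still positive. This Java sketch is a swapped-in alternative to the generate()/merge() snippet above, not the editorialist's code, and its names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: split d in [1, R] into runs on which both R/d and (L-1)/d are constant.
class Blocks {
    // Each entry is {lo, hi, value}: R/d - (L-1)/d == value for every d in [lo, hi].
    static List<long[]> split(long L, long R) {
        List<long[]> out = new ArrayList<>();
        long a = L - 1;
        long lo = 1;
        while (lo <= R) {
            long hi = R / (R / lo);                          // last d with the same R/d
            if (a / lo > 0) hi = Math.min(hi, a / (a / lo)); // last d with the same (L-1)/d
            out.add(new long[]{lo, hi, R / lo - a / lo});
            lo = hi + 1;
        }
        return out;
    }

    public static void main(String[] args) {
        for (long[] b : split(3, 10))
            System.out.println("[" + b[0] + ", " + b[1] + "] -> " + b[2]);
    }
}
```

For L = 3 and R = 10 it prints the runs [1, 1] -> 8, [2, 2] -> 4, [3, 3] -> 3, [4, 5] -> 2 and [6, 10] -> 1. Once (L-1)/d reaches 0 it stays 0, which is why only R/d limits the jump from that point on.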

Now, we take prefix sums of the Mobius function. For each interval, we compute the GP sum once and multiply it by the sum of the Mobius function over that interval (a difference of two prefix sums). Each query is then answered in time proportional to the number of intervals times O(log K), which leads to O(Q*\sqrt{M}*log(K)) time overall and is sufficient to fit the time limit.
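Putting the pieces together, here is a Java sketch of one query in O((\sqrt{R} + \sqrt{L}) log K): prefix sums of μ give the Mobius sum over each interval in O(1), and the GP sum is evaluated once per interval. Assumptions: modulus 2^{30}, the jump-based interval loop instead of generate()/merge(), and hypothetical names; `slow` is the direct per-d evaluation, kept only for cross-checking.

```java
// Sketch of the per-query evaluation with prefix sums of mu (names are hypothetical).
class FastQuery {
    static final long MOD = 1L << 30;
    final long[] prefMu; // prefMu[x] = mu(1) + ... + mu(x); may be negative

    FastQuery(int maxM) {
        int[] mu = new int[maxM + 1];
        mu[1] = 1;
        // Subtract mu[i] from every proper multiple of i (same idea as the first setter's Init()).
        for (int i = 1; i <= maxM; i++)
            for (int j = 2 * i; j <= maxM; j += i) mu[j] -= mu[i];
        prefMu = new long[maxM + 1];
        for (int i = 1; i <= maxM; i++) prefMu[i] = prefMu[i - 1] + mu[i];
    }

    // 1 + r + ... + r^(n-1) mod 2^30, n >= 1, by divide and conquer.
    static long gp(long r, long n) {
        r %= MOD;
        if (n == 1) return 1;
        if (n % 2 == 0) return (1 + r) % MOD * gp(r * r % MOD, n / 2) % MOD;
        return (1 + r * gp(r, n - 1)) % MOD;
    }

    // Sum over intervals [lo, hi] of (mu(lo) + ... + mu(hi)) * (c + c^2 + ... + c^K).
    long query(int L, int R, long K) {
        long a = L - 1, ans = 0, lo = 1;
        while (lo <= R) {
            long hi = R / (R / lo);
            if (a / lo > 0) hi = Math.min(hi, a / (a / lo));
            long c = R / lo - a / lo; // identical for every d in [lo, hi]
            long muSum = Math.floorMod(prefMu[(int) hi] - prefMu[(int) lo - 1], MOD);
            long geo = c * gp(c, K) % MOD; // c + c^2 + ... + c^K
            ans = (ans + muSum * geo) % MOD;
            lo = hi + 1;
        }
        return ans;
    }

    // Direct O(R log K) evaluation over every d, for cross-checking.
    long slow(int L, int R, long K) {
        long ans = 0;
        for (int d = 1; d <= R; d++) {
            long mu = prefMu[d] - prefMu[d - 1];
            long c = R / d - (L - 1) / d;
            ans = Math.floorMod(ans + mu * (c * gp(c, K) % MOD), MOD);
        }
        return ans;
    }

    public static void main(String[] args) {
        FastQuery fq = new FastQuery(100000);
        System.out.println(fq.query(2, 4, 2)); // 4
        System.out.println(fq.query(1, 100000, 1000000000L) + " " + fq.slow(1, 100000, 1000000000L));
    }
}
```

The second line prints the fast and slow answers for a large query, which should agree. Reducing the Mobius sum with Math.floorMod keeps it non-negative, so every product stays below 2^{60}.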

TIME COMPLEXITY

Overall time complexity is O(Q*\sqrt{M}*log(K)), plus a one-time sieve to compute the Mobius function up to M.

SOLUTIONS:

Setter 1 Solution
#include<bits/stdc++.h>
using namespace std;
#define MAX 100005
#define MOD (1LL<<30)
#define ll long long int

ll Cum[MAX];
int mobius[MAX];

void Init(){
	mobius[1] = 1;
	for(int i=1;i<MAX;i++){
	    for(int j=i+i;j<MAX;j+=i){
	        mobius[j] -= mobius[i];
	    }
	}
	for(int i=1;i<MAX;i++) Cum[i] = Cum[i-1] + mobius[i];
}

ll GeoSum(ll a,int n){
	ll sz = 0;
	ll ret = 0;
	ll mul = 1;
	int MSB = 31 - __builtin_clz(n);

	while(MSB >= 0){
	    ret = ret * (1 + mul); sz <<= 1; mul = (mul * mul) % MOD;
	    if((n>>MSB)&1) mul = (mul * a) % MOD, ret += mul;
	    ret %= MOD; MSB--;
	}
	return ret;
}

ll Solve(ll b,ll d,ll k){
	vector<ll>vec;
	for(int i=1;i*i<=b;i++){
	    vec.push_back(i);
	    vec.push_back(b/i);
	}

	for(int i=1;i*i<=d;i++){
	    vec.push_back(i);
	    vec.push_back(d/i);
	}


	vec.push_back(0);
	sort(vec.begin(),vec.end());
	vec.erase(unique(vec.begin(),vec.end()),vec.end());

	ll Sum = 0;
	for(int i=1;i<vec.size();i++){
	    ll Curr = Cum[vec[i]] - Cum[vec[i-1]];
	    if(Curr < 0) Curr += MOD;
	    ll Now = d/vec[i] - b/vec[i];
	    Sum += (Curr * GeoSum(Now,k)) % MOD;
	    if(Sum >= MOD) Sum -= MOD;
	}
	return Sum;
}

int main(){
	Init();

	int m,q;
	scanf("%d %d",&m,&q);

	assert(1<= m && m <= 100000);
	assert(1<= q && q <= 10000);

	for(int cs=1;cs<=q;cs++){
	    int l,r,k;
	    scanf("%d %d %d",&l,&r,&k);
	    assert(1<= l && l <= r && r <= m);
	    assert(1<= k && k <= 1000000000);

	    ll Ans = Solve(l-1,r,k);
	    printf("%lld\n",Ans);
	}
}
Setter 2 Solution
#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll mod = (1LL<<30);

ll t, m;
vector < ll > vec;
ll mob[300000];

ll seriessum(ll val, ll k)
{


	ll msb = 63 - __builtin_clzll(k);
	ll re = 0, cur = 1;
	while(msb + 1){
	    re = (re + cur * re) % mod;
	    cur = (cur * cur) % mod;
	    if((1LL << msb) & k) {
	        cur = (cur * val) % mod;
	        re += cur;
	        if(re >= mod) re -= mod;
	    }
	    msb--;
	}
	return re;
}


int main()
{

	mob[1] = 1;
	for(ll i = 1; i <= 100009; i++){
	    for(ll j = i + i; j <= 100009; j += i) mob[j] -= mob[i];
	    mob[i] += mob[i - 1];
	}

	cin >> m >> t;
	if(m < 1  || m > 100000) assert(false);
	if(t < 1 || t > 10000) assert(false);
	while(t--){

	    ll L, R, K;
	    scanf("%lld %lld %lld", &L, &R, &K);
	    if(L < 1 || L > m|| R < 1 || R > m) assert(false);
	    if(K < 1 || K > 1000000000) assert(false);
	    vec.clear();

	    for(ll i = 1; ; i++){
	        ll tmp = i * i;
	        if(tmp > (L-1)) break;
	        vec.push_back(i);
	        vec.push_back((L-1)/i);
	    }

	    for(ll i = 1; ; i++){
	        ll tmp = i * i;
	        if(tmp > R) break;
	        vec.push_back(i);
	        vec.push_back(R/i);
	    }
	    vec.push_back(0);
	    sort(vec.begin(), vec.end());
//        vec.erase(unique(vec.begin(), vec.end()), vec.end());

	    ll ans = 0;

	    for(ll i = 1; i < vec.size(); i++){
	        ll ccmob = (mob[vec[i]] - mob[vec[i - 1]]);
	        if(ccmob < 0) ccmob += mod;
	        if(ccmob >= mod) ccmob -= mod;

	        ll tmp = (R / vec[i]) - ((L - 1) / vec[i]);

	        ll tmp2 = seriessum(tmp, K);
	        ans = (ans + tmp2 * ccmob) % mod;
//            cout << i << ' ' << vec[i] << ' ' << tmp << ' ' << tmp2 << endl;

	    }

	    printf("%lld\n", ans);

	}

	return 0;
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;

#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1<<30)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;

#define int ll

int MAXM = 100006;
int haha = 700;
int BSIZE = 1000;
vector<vi> fact(123456);
int mobi[123456],pr[123456];
int cnt[123456],wow[123456];
int l[123456],r[123456],k[123456],ans[123456];
struct total{
	int l,r,ind,bl;
};
total quer[123456];
int comp(total a,total b){
	if(a.bl!=b.bl)
		return a.bl<b.bl;
	if(a.bl%2)
		return a.r>b.r;
	else
		return a.r<b.r;
}
void add(int val1){
	int i,val;
	rep(i,fact[val1].size()){
		val=fact[val1][i];
		cnt[wow[val]]-=mobi[val];
		wow[val]++;
		cnt[wow[val]]+=mobi[val];
	}

}
void remov(int val1){
	int i,val;
	rep(i,fact[val1].size()){
		val=fact[val1][i];
		cnt[wow[val]]-=mobi[val];
		wow[val]--;
		cnt[wow[val]]+=mobi[val];
	}

}
int gao[123456];

pii compute(int a,int r){
	if(r==1){
		return mp(a,a);
	}
	pii papa=compute(a,r/2);
	papa.ff+=papa.ss*papa.ff;
	papa.ff%=mod;
	papa.ss*=papa.ss;
	papa.ss%=mod;
	if(r%2){
		papa.ss*=a;
		papa.ss%=mod;
		papa.ff+=papa.ss;
		if(papa.ff>=mod)
			papa.ff-=mod;
	}
	return papa;

}

int gpsum(int a,int r){
	if(a==1)
		return r;
	pii papa=compute(a,r);
	return papa.ff;
}
signed main(){
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int i,j;
	rep(i,MAXM){
		mobi[i]=1;
	}
	f(i,2,MAXM){
		if(pr[i])
			continue;
		for(j=i;j<MAXM;j+=i){
			pr[j]=1;
			if((j/i)%(i)==0){
				mobi[j]=0;
			}
			else{
				mobi[j]*=-1;
			}

		}
	}
	//cerr<<mobi[2]<<" "<<mobi[3]<<" "<<mobi[6]<<endl;
	for(i=haha;i<MAXM;i++){
		if(mobi[i]==0)
			continue;
		for(j=i;j<=MAXM;j+=i){
			fact[j].pb(i);
		}
	}
	int maxi=0;
	rep(i,MAXM){
		maxi=max(maxi,(int)fact[i].size());
	}
	cerr<<maxi<<endl;
	int m,q;
	cin>>m>>q;
   	j=0;
	f(i,1,haha){
		if(mobi[i]){
			gao[j++]=i;
		}
	}	
	int BOUND=j,val;
	rep(i,q){
		cin>>l[i]>>r[i]>>k[i];
		rep(j,BOUND){
			val=gpsum(r[i]/gao[j]-(l[i]-1)/gao[j],k[i]);
			val*=mobi[gao[j]];
			ans[i]+=val;
		}
		ans[i]%=mod;
		quer[i].l = l[i];
		quer[i].r = r[i];
		quer[i].bl = l[i]/BSIZE;
		quer[i].ind = i;
	}    	
	sort(quer,quer+q,comp);
	int lo=1,ri=1;
	int poss=MAXM/haha+3;
	rep(i,q){
		int ind = quer[i].ind;
		while(l[ind]<lo){
			lo--;
			add(lo);
		}
		while(r[ind]>ri){
			ri++;
			add(ri);
		}
		while(lo<l[ind]){
			remov(lo);
			lo++;
		}
		while(r[ind]<ri){
			remov(ri);
			ri--;
		}
		f(j,1,poss){
			if(cnt[j]!=0){
				val=gpsum(j,k[ind]);
				val*=cnt[j];
				ans[ind]+=val;
			}
		}
		ans[ind]%=mod;
		ans[ind]+=mod;
		ans[ind]%=mod;
	}
	rep(i,q){
		cout<<ans[i]<<endl;
	}
	return 0;   
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class MGICMENU{
	//SOLUTION BEGIN
	//Into the Hardware Mode
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int M = ni();
	    int[] mob = mobius(M);
	    for(int i = 1; i<= M; i++)mob[i] = mob[i-1]+mob[i];
	    for(int qq = ni(); qq>0; qq--){
	        int l = ni(), r = ni();long k = nl();
	        ArrayList<int[]> R = generate(r), L = generate(l-1);
	        L.add(new int[]{l, r, 0});
	        ArrayList<int[]> list = merge(L, R);
	        long ans = 0;
	        for(int[] i:list)ans = add(ans, mul((mob[i[1]]+mod-mob[i[0]-1])%mod, mul(i[2], gp(i[2], k))));
	        pn(ans);
	    }
	}
	long gp(long r, long n){
	    if(n == 1)return 1;
	    if(n%2 == 0)return mul((1+r), gp(mul(r, r), n/2));
	    return (1+mul(r,gp(r, n-1)))%mod;
	}
	ArrayList<int[]> generate(int val){
	    ArrayList<int[]> list = new ArrayList<>();
	    for(int i = 1; i*i <= val; i++){
	        list.add(new int[]{i, i, val/i});
	        if(i != val/i)list.add(new int[]{val/(i+1)+1, val/i, i});
	    }
	    Collections.sort(list, (int[] i1, int[] i2) -> Integer.compare(i1[0], i2[0]));
	    return list;
	}
	ArrayList<int[]> merge(ArrayList<int[]> L, ArrayList<int[]> R){
	    ArrayList<int[]> list = new ArrayList<>();
	    int i = 0, j = 0;
	    while(i < L.size() && j< R.size()){
	        int[] cur1 = L.get(i), cur2 = R.get(j);
	        int[] cur = new int[]{Math.max(cur1[0], cur2[0]), Math.min(cur1[1], cur2[1]), cur2[2]-cur1[2]};
	        list.add(cur);
	        if(cur1[1] < cur2[1])i++;
	        else if(cur1[1] > cur2[1])j++;
	        else {i++;j++;}
	    }
	    return list;
	}
	//Slow O(r) function, returns number of ways to select k elements in range [l, r] such that their gcd = 1, mob is mobius function
	long f(int[] mob, long l, long r, long k){
	    long ans = 0;
	    for(int i = 1; i<= r; i++)ans += mob[i]*pow(r/i-(l-1)/i, k);
	    return ans;
	}
	long mul(long a, long b){
	    if(a>=mod)a%=mod;
	    if(b>=mod)b%=mod;
	    a*=b;
	    if(a>=mod)a%=mod;
	    return a;
	}
	long add(long a, long b){
	    if(Math.abs(a)>=mod)a%=mod;
	    if(a<0)a+=mod;
	    if(Math.abs(b)>=mod)b%=mod;
	    if(b<0)b+=mod;
	    a+=b;
	    if(Math.abs(a)>=mod)a%=mod;
	    return a;
	}
	long pow(long a, long p){
	    long o = 1;
	    while(p>0){
	        if(p%2==1)o = (a*o)%mod;
	        a = (a*a)%mod;
	        p>>=1;
	    }
	    return o;
	}
	int[] mobius(int mx){
	    mx++;
	    int[] mob = new int[mx];
	    mob[1] = 1;
	    boolean[] p = new boolean[mx];
	    Arrays.fill(p, true);
	    for(int i = 2; i< mx; i++){
	        if(p[i]){
	            for(int j = 1; i*j < mx; j++){
	                if(j>1)p[j*i] = false;
	                if(j%i!= 0)mob[j*i] = -1*mob[j];
	                else mob[j*i] = 0;
	            }
	        }
	    }
	    return mob;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	long IINF = (long)1e18, mod = (long)1l<<30l;
	final int INF = (int)1e9, MX = (int)2e5+5;
	DecimalFormat df = new DecimalFormat("0.00000000000");
	double PI = 3.141592653589793238462643383279502884197169399, eps = 1e-6;
	static boolean multipleTC = false, memory = false, fileIO = false;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    if(fileIO){
	        in = new FastReader("input.txt");
	        out = new PrintWriter("output.txt");
	    }else {
	        in = new FastReader();
	        out = new PrintWriter(System.out);
	    }
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    if(memory)new Thread(null, new Runnable() {public void run(){try{new MGICMENU().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
	    else new MGICMENU().run();
	}
	long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
	int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}
 
	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }
 
	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }
 
	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }
 
	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }
	        return str;
	    }
	}
}

Feel free to share your approach, even if it's the same. Suggestions are welcome, as always.