RNDRATIO - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Anik Sarker
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Combinatorics, Inclusion-Exclusion, and Expectation.

PROBLEM:

Given N intervals [L_i, R_i] with 1 \leq L_i \leq R_i \leq 10^5, consider all possible ways of choosing N values, where the i-th value is chosen from the i-th interval. Let A be the product of the chosen values and B be their gcd. Find the expected value of A/B over all possible ways.

QUICK EXPLANATION

EXPLANATION

Let us solve an easier problem. Instead of the expected value of A/B where A is the product and B is the greatest common divisor, let us calculate the expected value of the greatest common divisor of chosen values.

Let us count the number of valid N-tuples whose GCD is exactly x; call this g_x. Counting g_x directly is hard, but we can easily count the N-tuples whose GCD is a multiple of x; call this f_x. This just means that every value in the tuple must be a multiple of x.

For each x, we count the valid N-tuples in which every element is a multiple of x. The i-th value must be a multiple of x lying between L_i and R_i, and there are exactly \displaystyle\Big\lfloor \frac{R_i}{x} \Big\rfloor - \Big\lfloor \frac{L_i-1}{x} \Big\rfloor such values. Since the values are chosen independently of each other, \displaystyle f_x = \prod_{i=1}^{N} \Big(\Big\lfloor \frac{R_i}{x} \Big\rfloor - \Big\lfloor \frac{L_i-1}{x} \Big\rfloor \Big). Doing this for every x from 1 to MX, where MX = 10^5 is the maximum possible value, takes O(N*MX) time. We’ll optimize it later.
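Here is a minimal C++ sketch of this naive O(N*MX) counting step. It assumes exact 64-bit arithmetic with no modulus, so it is only meant for small inputs. The function name `countMultipleTuples` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// f[x] = number of tuples whose values are ALL multiples of x, for x in [1, mx].
// Naive O(N * mx); no modulus, so only safe while the products fit in 64 bits.
std::vector<std::uint64_t> countMultipleTuples(
        const std::vector<std::pair<int, int>>& intervals, int mx) {
    std::vector<std::uint64_t> f(mx + 1, 0);
    for (int x = 1; x <= mx; ++x) {
        std::uint64_t ways = 1;
        for (const auto& [l, r] : intervals) {
            // number of multiples of x inside [l, r]
            std::uint64_t choices = r / x - (l - 1) / x;
            ways *= choices;  // positions are chosen independently
        }
        f[x] = ways;
    }
    return f;
}
```

For the intervals [4, 6] and [2, 5], this gives f_1 = 12, f_2 = 4, f_3 = 1 and f_6 = 0.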

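Now we have f_x for every x, and we need g_x. The tuples whose GCD is a multiple of x are exactly those whose GCD is x, 2x, 3x, ..., so g_x = f_x - \sum_{k \geq 2} g_{kx}. Processing x in decreasing order, this takes O(MX*log(MX)) time in total (a harmonic sum), as follows.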

Pseudo Code

// f[i] -> number of tuples whose GCD is a multiple of i
// g[i] -> number of tuples whose GCD is exactly i
for(int i = MX; i >= 1; i--){
    g[i] = f[i]; // start from all tuples whose GCD is a multiple of i
    for(int j = 2*i; j <= MX; j += i){
        g[i] -= g[j]; // remove tuples whose GCD is exactly j, a proper multiple of i
    }
}
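The following C++ sketch puts the pieces of the easier problem (the expected GCD) together. It assumes exact 64-bit counts and floating-point output rather than the modular arithmetic the real solutions use. `expectedGcd` is a hypothetical name, and `mx` must be at least the largest R_i.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Expected GCD of (x_1, ..., x_N) with each x_i uniform in [L_i, R_i].
// f[x]: tuples whose GCD is a multiple of x; g[x]: tuples whose GCD is exactly x.
// Exact 64-bit counts (small inputs only); mx must be >= max R_i.
double expectedGcd(const std::vector<std::pair<int, int>>& intervals, int mx) {
    std::vector<std::int64_t> f(mx + 1), g(mx + 1);
    std::int64_t total = 1;
    for (const auto& [l, r] : intervals) total *= r - l + 1;
    for (int x = 1; x <= mx; ++x) {
        f[x] = 1;
        for (const auto& [l, r] : intervals) f[x] *= r / x - (l - 1) / x;
    }
    double sumOfGcds = 0;
    for (int i = mx; i >= 1; --i) {  // decreasing i, as in the pseudo code
        g[i] = f[i];
        for (int j = 2 * i; j <= mx; j += i) g[i] -= g[j];
        sumOfGcds += static_cast<double>(i) * g[i];
    }
    return sumOfGcds / total;
}
```

Take the example used later in this editorial, with intervals [4, 6] and [2, 5]. The exact GCD counts are g_1 = 6, g_2 = 3 and g_3 = g_4 = g_5 = 1, so the expected GCD is 24/12 = 2.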

Let’s come back to the original problem. We want to find the expected value of the product of values in tuple divided by GCD of values in the tuple.

For a fixed x, let us compute the sum of the products of the elements over all N-tuples in which every element is a multiple of x.

Consider the following example, with N = 2 and intervals [4, 6] and [2, 5]:

2
4 6
2 5

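Let’s consider each value of x one by one. Let g_x denote the sum, over all tuples whose values are all divisible by x, of the product of the values in the tuple. (From here on, g_x is the "multiple of x" quantity and f_x is the "exactly x" quantity. This is the reverse of the notation used for the easier problem.)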

g_1 = 4*2+4*3+4*4+4*5+5*2+5*3+5*4+5*5+6*2+6*3+6*4+6*5 = 1^2*(4+5+6)*(2+3+4+5) = 210
g_2 = 4*2+4*4+6*2+6*4 = (4+6)*(2+4) = 2^2*(2+3)*(1+2) = 60
g_3 = 6*3 = (6)*(3) = 3^2*2*1 = 18
g_4 = 4*4 = (4)*(4) = 4^2*1*1 = 16
g_5 = 5*5 = (5)*(5) = 5^2*1*1 = 25
g_6 = 0

It is easy to see from the above that the sum of products can be written as a product of sums: for each position of the tuple, we take the sum of the possible choices for that position.

For example, for x = 2, we can write g_2 as (4+6)*(2+4), where the first element of the tuple can take the values 4 and 6 and the second can take the values 2 and 4. We can also factor x out of every term and multiply it in separately. So g_2 becomes 2^N*(2+3)*(1+2), which is the same as 2^N*(sum(3)-sum(1))*(sum(2)-sum(0)), where sum(n) = n*(n+1)/2 is the sum of the first n natural numbers.

Consider a general x and an interval [L_i, R_i]. Let a be the largest integer with a*x \leq L_i-1, and let b be the largest integer with b*x \leq R_i. The multiples of x in [L_i, R_i] are exactly (a+1)*x, (a+2)*x, ..., b*x, so their sum is x*(sum(b)-sum(a)). Clearly, a = \displaystyle\Big\lfloor \frac{L_i-1}{x} \Big\rfloor and b = \displaystyle\Big\lfloor \frac{R_i}{x} \Big\rfloor.
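Here is a small C++ sketch of this formula. It assumes exact 64-bit arithmetic. The names `triangular` and `sumOfMultiples` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// sum(n) = 1 + 2 + ... + n
std::int64_t triangular(std::int64_t n) { return n * (n + 1) / 2; }

// Sum of all multiples of x lying in [l, r], assuming 1 <= l <= r and x >= 1.
std::int64_t sumOfMultiples(std::int64_t l, std::int64_t r, std::int64_t x) {
    std::int64_t a = (l - 1) / x;  // number of multiples of x that are <= l - 1
    std::int64_t b = r / x;        // number of multiples of x that are <= r
    return x * (triangular(b) - triangular(a));
}
```

For x = 2, the interval [4, 6] gives 4 + 6 = 10 and the interval [2, 5] gives 2 + 4 = 6, matching the factors of g_2 above.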

So we have \displaystyle g_x = x^N*\prod_{i = 1}^N \Big[ sum\Big( \Big\lfloor \frac{R_i}{x} \Big\rfloor \Big) - sum\Big( \Big\lfloor \frac{L_i-1}{x} \Big\rfloor \Big) \Big]. From this, we can apply inclusion-exclusion, exactly as in the easier problem, to calculate f_x: the sum of the products over all tuples whose GCD is exactly x.
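The following C++ sketch evaluates this formula naively for every x. It assumes exact 64-bit arithmetic, so it only suits small inputs. `sumOfProducts` is a hypothetical name. Multiplying by x once per interval produces the x^N factor.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

std::int64_t triangular(std::int64_t n) { return n * (n + 1) / 2; }

// g[x] = x^N * prod_i [sum(R_i / x) - sum((L_i - 1) / x)], i.e. the sum of the
// products over all tuples whose values are all multiples of x. O(N * mx).
std::vector<std::int64_t> sumOfProducts(
        const std::vector<std::pair<int, int>>& intervals, int mx) {
    std::vector<std::int64_t> g(mx + 1, 0);
    for (int x = 1; x <= mx; ++x) {
        std::int64_t value = 1;
        for (const auto& [l, r] : intervals) {
            // one factor x per interval -> x^N overall
            value *= x * (triangular(r / x) - triangular((l - 1) / x));
        }
        g[x] = value;
    }
    return g;
}
```

On the example, this reproduces g_1 = 210, g_2 = 60, g_3 = 18, g_4 = 16, g_5 = 25 and g_6 = 0.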

Considering the above example again, we get
f_1 = g_1-f_2-f_3-f_4-f_5-f_6 = 107
f_2 = g_2-f_4-f_6 = 44
f_3 = g_3-f_6 = 18
f_4 = g_4 = 16
f_5 = g_5 = 25
f_6 = g_6 = 0

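Every tuple with GCD exactly i contributes its product divided by i. Hence the expected value of A/B is \displaystyle\sum_{i = 1}^{MX} \frac{f_i}{i} divided by the total number of ways to choose a tuple, \displaystyle\prod_{k = 1}^{N} (R_k - L_k + 1). For the example above, this is (107/1 + 44/2 + 18/3 + 16/4 + 25/5)/(3*4) = 144/12 = 12.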
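Below is a C++ sketch of this slow O(N*MX) approach, working modulo 998244353 as the solutions below do, together with a brute force for cross-checking on tiny inputs. The names `expectedRatioSlow`, `expectedRatioBrute`, `enumerate`, `power` and `inverse` are hypothetical, and `mx` must be at least the largest R_i.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

using i64 = std::int64_t;
constexpr i64 MOD = 998244353;

i64 power(i64 base, i64 exp) {
    i64 result = 1;
    base %= MOD;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return result;
}
i64 inverse(i64 a) { return power(a, MOD - 2); }  // MOD is prime (Fermat)

// Slow solution: g[x] by the product formula, f[x] by inclusion-exclusion.
i64 expectedRatioSlow(const std::vector<std::pair<int, int>>& intervals, int mx) {
    std::vector<i64> g(mx + 1), f(mx + 1);
    for (int x = 1; x <= mx; ++x) {
        i64 value = 1;
        for (const auto& [l, r] : intervals) {
            i64 b = r / x, a = (l - 1) / x;
            i64 s = (b * (b + 1) / 2 - a * (a + 1) / 2) % MOD;
            value = value * (x * s % MOD) % MOD;
        }
        g[x] = value;
    }
    i64 numerator = 0;
    for (int i = mx; i >= 1; --i) {
        f[i] = g[i];  // then remove tuples whose GCD is a proper multiple of i
        for (int j = 2 * i; j <= mx; j += i) f[i] = (f[i] - f[j] + MOD) % MOD;
        numerator = (numerator + f[i] * inverse(i)) % MOD;  // A / B with B = i
    }
    i64 ways = 1;
    for (const auto& [l, r] : intervals) ways = ways * (r - l + 1) % MOD;
    return numerator * inverse(ways) % MOD;
}

// Brute force: visit every tuple, accumulating A * B^{-1} (mod MOD).
void enumerate(const std::vector<std::pair<int, int>>& iv, std::size_t idx,
               i64 prod, i64 g, i64& sum, i64& count) {
    if (idx == iv.size()) {
        sum = (sum + prod * inverse(g)) % MOD;
        ++count;
        return;
    }
    for (int v = iv[idx].first; v <= iv[idx].second; ++v)
        enumerate(iv, idx + 1, prod * v % MOD, std::gcd(g, static_cast<i64>(v)), sum, count);
}

i64 expectedRatioBrute(const std::vector<std::pair<int, int>>& intervals) {
    i64 sum = 0, count = 0;
    enumerate(intervals, 0, 1, 0, sum, count);  // gcd(0, v) = v starts the GCD
    return sum * inverse(count % MOD) % MOD;
}
```

Both functions return 12 for the example, matching 144/12. The slow version is the one the next paragraph refers to as worth 50 points.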

This gives us 50 points. Computing g takes O(N*MX) time, so one final optimization is needed for the full score.

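For each g_x, let us ignore the x^N factor for now; we will multiply it in at the end. For a fixed value v, \displaystyle\Big\lfloor \frac{v}{x} \Big\rfloor takes at most 2\sqrt{v} distinct values as x ranges over [1, v]. There are at most \sqrt{v} choices of x \leq \sqrt{v}, and for x > \sqrt{v} the quotient itself is less than \sqrt{v}. Hence, for a fixed interval [L_i, R_i], \displaystyle\Big[ sum\Big( \Big\lfloor \frac{R_i}{x} \Big\rfloor \Big) - sum\Big( \Big\lfloor \frac{L_i-1}{x} \Big\rfloor \Big) \Big] is constant on each of at most about 4\sqrt{MX} contiguous ranges of x: the ranges on which both quotients are constant. This allows us to apply one range update per range instead of one update per x.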
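The ranges can be enumerated with the jump x -> v/(v/x) + 1, which is the same loop as in the setter's and editorialist's precomputation. Here is a small C++ sketch. `blockStarts` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <set>
#include <vector>

// Starting points of the maximal ranges of x in [1, v] on which v / x
// (integer division) stays constant. After a range starting at x, the next
// range starts at v / (v / x) + 1.
std::vector<int> blockStarts(int v) {
    std::vector<int> starts;
    for (int x = 1; x <= v; x = v / (v / x) + 1) starts.push_back(x);
    return starts;
}
```

Each range corresponds to one distinct quotient, and the quotients strictly decrease from range to range. So the number of ranges equals the number of distinct values, which is at most 2\sqrt{v}. For example, v = 100 gives 19 ranges.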

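Specifically, \displaystyle\Big[ sum\Big( \Big\lfloor \frac{R_i}{x} \Big\rfloor \Big) - sum\Big( \Big\lfloor \frac{L_i-1}{x} \Big\rfloor \Big) \Big] can change only at values of x where \displaystyle\Big\lfloor \frac{R_i}{x} \Big\rfloor \neq \Big\lfloor \frac{R_i}{x-1} \Big\rfloor or \displaystyle\Big\lfloor \frac{L_i-1}{x} \Big\rfloor \neq \Big\lfloor \frac{L_i-1}{x-1} \Big\rfloor. For every value v from 0 to MX, we precompute the sorted list of such breakpoints of \displaystyle\Big\lfloor \frac{v}{x} \Big\rfloor. Merging the lists for L_i-1 and R_i gives the ranges of x on which the factor is constant. We then use the concept of difference arrays, adapted to products. To multiply every position in [st, en] by a non-zero value, multiply it into a "start" array at st and into a separate "stop" array at en+1. Each such update takes O(1). A final prefix pass multiplies in the start array and divides by the stop array (via modular inverses) to recover the product for every x.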
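Here is a minimal C++ sketch of the multiplicative difference array, modelled on the setter's `Mul`/`Del` arrays. The struct `RangeProduct` and its methods are hypothetical names. It assumes every factor is non-zero modulo 998244353, since zero factors are handled separately.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using i64 = std::int64_t;
constexpr i64 MOD = 998244353;

i64 power(i64 base, i64 exp) {
    i64 result = 1;
    base %= MOD;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return result;
}

// Positions 1..n. mul[st] collects factors that start at st; del[en + 1]
// collects factors that stop after en. One modular inverse per position.
struct RangeProduct {
    std::vector<i64> mul, del;
    explicit RangeProduct(int n) : mul(n + 2, 1), del(n + 2, 1) {}

    // Multiply every position in [st, en] by val (val must be non-zero mod MOD).
    void update(int st, int en, i64 val) {
        val %= MOD;
        mul[st] = mul[st] * val % MOD;
        del[en + 1] = del[en + 1] * val % MOD;
    }

    // result[i] = product of all factors whose range covers position i.
    std::vector<i64> build() const {
        std::vector<i64> result(mul.size(), 1);
        i64 cur = 1;
        for (std::size_t i = 1; i < mul.size(); ++i) {
            cur = cur * mul[i] % MOD * power(del[i], MOD - 2) % MOD;
            result[i] = cur;
        }
        return result;
    }
};
```

Multiplying [2, 4] by 3 and [3, 5] by 5 yields 1, 3, 15, 15, 5, 1 at positions 1..6. Because all divisors for a position are grouped in `del`, `build` computes exactly one inverse per position.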

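Lastly, we need to take special care of zero factors. Consider x = 4 and the intervals [4, 9] and [9, 11]. The interval [9, 11] contains no multiple of 4, so its factor is 0, and a zero cannot be divided out later with a modular inverse. Instead, we keep a separate additive difference array that counts, for each x, how many intervals contribute a zero factor (or, equivalently, how many contribute a non-zero one). We then set g_x = 0 whenever some interval contributes a zero.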
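The following C++ sketch shows the zero-counting difference array. `zeroCount` is a hypothetical name. For each interval, it walks x upward over the maximal ranges on which both (L-1)/x and R/x are constant. A range contributes a zero exactly when the two quotients are equal, i.e. when [L, R] contains no multiple of x. The sketch assumes R \leq mx.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// zeros[x] = number of intervals [l, r] containing no multiple of x, x in [1, mx].
std::vector<int> zeroCount(const std::vector<std::pair<int, int>>& intervals, int mx) {
    std::vector<int> diff(mx + 2, 0);
    for (const auto& [l, r] : intervals) {
        int lo = l - 1;
        for (int x = 1; x <= mx;) {
            // first x at which r / x (resp. lo / x) changes; a quotient of 0 never changes
            int nextR = (r / x == 0) ? mx + 1 : r / (r / x) + 1;
            int nextL = (lo / x == 0) ? mx + 1 : lo / (lo / x) + 1;
            int end = std::min({nextR, nextL, mx + 1});  // exclusive end of the range
            if (r / x == lo / x) {  // factor is zero on [x, end)
                ++diff[x];
                --diff[end];
            }
            x = end;
        }
    }
    std::vector<int> zeros(mx + 1, 0);
    int running = 0;
    for (int x = 1; x <= mx; ++x) {
        running += diff[x];
        zeros[x] = running;
    }
    return zeros;
}
```

For the intervals [4, 9] and [9, 11] with mx = 12, the only values of x with no zero factor are 1, 2, 3, 5 and 9.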

Refer to the implementations below if anything is unclear.

Optimizations
  • Try to reduce the number of modulo operations; that should be enough.
  • Applying updates requires modular inverses, each of which takes O(log(MOD)) time. If each update were applied individually, we would need about N*\sqrt{MX} inverses. Instead, we can group them: accumulate all divisors for a position into a single value and invert it once, so only about MX inverses are computed.

TIME COMPLEXITY

Precomputation takes O(MX*\sqrt{MX}) and each test case takes O(N*\sqrt{MX}+MX*log(MOD))

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
const int maxn = 100005;
const int maxv = 100005;
const int mod = 998244353;

inline int add(int a, int b) {return a + b >= mod ? a + b - mod : a + b;}
inline int sub(int a, int b) {return add(a, mod - b);}
inline int mul(int a, int b) {return (a * 1LL * b) % mod;}
inline int bigMod(int n, int r){
	int ret = 1;
	while(r){
	    if(r & 1) ret = mul(ret, n);
	    n = mul(n, n); r >>= 1;
	}
	return ret;
}
inline int inv(int n) {return bigMod(n, mod - 2);}
inline int Div(int a, int b) {return mul(a, inv(b));}

int L[maxn], R[maxn];
vector<int> Fac[maxv];
int Mul[maxv], Zero[maxv], Del[maxv];
int Cnt[maxv], Sum[maxv], Inv[maxv];

int main(){
	Sum[0] = 0;
	for(int v=1; v<maxv; v++) Sum[v] = add(Sum[v-1], v);
	for(int v=1; v<maxv; v++) Inv[v] = inv(v);

	for(int v=1; v<maxv; v++){
	    for (int i=1, la; i<=v; i=la+1) {
	        la = v / (v / i);
	        Fac[v].push_back(i);
	    }
	    Fac[v].push_back(v+1);
	}

	int t;
	scanf("%d", &t);

	for(int cs=1; cs<=t; cs++){
	    int n;
	    scanf("%d", &n);

	    for(int i=1; i<=n; i++) scanf("%d %d", &L[i], &R[i]);
	    for(int v=1; v<maxv; v++) Mul[v] = Del[v] = 1, Zero[v] = 0;

	    for(int i=1; i<=n; i++){
	        vector<int> vec;
	        merge(Fac[L[i] - 1].begin(), Fac[L[i] - 1].end(),
	        Fac[R[i]].begin(), Fac[R[i]].end(), back_inserter(vec));
	        vec.erase(unique(vec.begin(), vec.end()), vec.end());

	        for(int j=1; j<vec.size(); j++){
	            int v = vec[j-1], w = vec[j];
	            int tot = sub(Sum[R[i] / v], Sum[(L[i] - 1) / v]);

	            if(tot == 0) Zero[v]++, Zero[w]--;
	            else Mul[v] = mul(Mul[v], tot), Del[w] = mul(Del[w], tot);
	        }
	        Zero[vec.back()]++;
	    }

	    Cnt[0] = 1;
	    for(int v=1; v<maxv; v++) Cnt[v] = mul(Cnt[v-1], Div(Mul[v], Del[v]));
	    for(int v=1; v<maxv; v++) Cnt[v] = mul(Cnt[v], bigMod(v, n));

	    for(int v=1; v<maxv; v++){
	        Zero[v] += Zero[v-1];
	        if(Zero[v]) Cnt[v] = 0;
	    }

	    for(int v=maxv-1; v>=1; v--){
	        for(int w=v+v; w<maxv; w+=v){
	            Cnt[v] = sub(Cnt[v], Cnt[w]);
	        }
	    }

	    int ans = 0;
	    for(int v=1; v<maxv; v++) ans = add(ans, Div(Cnt[v], v));

	    int way = 1;
	    for(int i=1; i<=n; i++) way = mul(way, R[i] - L[i] + 1);

	    printf("%d\n", Div(ans, way));
	}
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std; 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (998244353)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
#define int ll
 
int LIMIT = 300; 
int ap[123456],an[123456],bp[123456],bn[123456],az[123456],bz[123456];
int foo[123456],bar[123456];
vector<vii> vec(123456);
vector<vi> divisor(123456);
int getpow(int a,int b){
	int ans=1;
	while(b){
		if(b%2){
			ans*=a;
			ans%=mod;
		}
		a*=a;
		a%=mod;
		b/=2;
	}
	return ans;
}
signed main(){
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	cin>>t;
	int i,j;
	int val,val1,val2;
	rep(i,1e5+10){
		int last=-10;
		LIMIT= max((int)sqrt(i),2LL);
		for(j=1;;j++){
			val=i/j;
			if(val==last)
				continue;
			if(val<=LIMIT)
				break;
			last=val;
			vec[i].pb(mp(j,val));
		}
		fd(j,min(LIMIT,i),0){
			val=i/(j+1)+1;
			vec[i].pb(mp(val,j));
		}
		vec[i].pb(mp(1e5+10,-1));
	}
	f(i,1,1e5+100){
		for(j=2*i;j<1e5+100;j+=i){
			divisor[j].pb(i);
		}
	}
	while(t--){
		int n;
		cin>>n;
		int i;
		int st,en;
		rep(i,123456){
			az[i]=0;
			bz[i]=0;
			ap[i]=1;
			an[i]=1;
			bp[i]=1;
			bn[i]=1;
		}
		int l,r;
		rep(i,n){
			cin>>l>>r;
			l--;
			int p=0,q=0;
			while(p+1<vec[l].size() && q+1<vec[r].size()){
				st=max(vec[l][p].ff,vec[r][q].ff);
				en=min(vec[l][p+1].ff,vec[r][q+1].ff);
				val1=vec[r][q].ss-vec[l][p].ss;
				val2=(vec[r][q].ss)*(vec[r][q].ss+1) - (vec[l][p].ss)*(vec[l][p].ss+1);
				val2/=2;
				//cout<<st<<" "<<en<<endl;
				val2%=mod;
				if(val1==0){
					az[st]++;
					az[en]--;
				}
				else{
					ap[st]*=val1;
					ap[st]%=mod;
					an[en]*=val1;
					an[en]%=mod;
				}
				if(val2==0){
					bz[st]++;
					bz[en]--;
				}
				else{
					bp[st]*=val2;
					bp[st]%=mod;
					bn[en]*=val2;
					bn[en]%=mod;
				}
				if(en==vec[l][p+1].ff){
					p++;
				}
				if(en==vec[r][q+1].ff){
					q++;
				}
			}
		}
		int val=1;
		f(i,1,1e5+10){
			az[i]+=az[i-1];
			val*=ap[i];
			val%=mod;
			val*=getpow(an[i],mod-2);
			val%=mod;
			if(az[i]>0){
				foo[i]=0;
			}
			else{
				foo[i]=val;
			}
		}
		val=1;
		f(i,1,1e5+10){
			bz[i]+=bz[i-1];
			val*=bp[i];
			val%=mod;
			val*=getpow(bn[i],mod-2);
			val%=mod;
			if(bz[i]>0){
				bar[i]=0;
			}
			else{
				bar[i]=val;
			}
			bar[i]*=getpow(i,n);
			bar[i]%=mod;
		}
		int sum1=0,sum2=0;
		fd(i,1e5+5,1){
			rep(j,divisor[i].size()){
			
				foo[divisor[i][j]]-=foo[i];
				if(foo[divisor[i][j]]<0)
					foo[divisor[i][j]]+=mod;
				bar[divisor[i][j]]-=bar[i];
				if(bar[divisor[i][j]]<0)
					bar[divisor[i][j]]+=mod;
			}
			val1=bar[i];
			val1*=getpow(i,mod-2);
			val1%=mod;
			sum1+=val1;
			sum1%=mod;
			sum2+=foo[i];
			sum2%=mod;
		}
		sum1*=getpow(sum2,mod-2);
		sum1%=mod;
		cout<<sum1<<endl;
	}
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class RNDRATIO{
	//SOLUTION BEGIN
	int mx = (int)100000;
	long MOD = 998244353;
	int[][] list;
	void pre() throws Exception{
	    //Precomputation
	    list = new int[1+mx][];
	    for(int i = 0; i<= mx; i++){
	        ArrayList<Integer> curList = new ArrayList<>();
	        for(int K = 1; K <= i; ){
	            curList.add(K);
	            K = (i/(i/K))+1;
	        }
	        curList.add(i+1);
	        curList.add(mx+1);
	        list[i] = new int[curList.size()];
	        int cnt = 0;
	        for(int x:curList)list[i][cnt++] = x;
	    }
	}
	void solve(int TC) throws Exception{
	    int n = ni();
	    long[] productOfSums = new long[2+mx], productOfSumsInverse = new long[2+mx];//difference array concept
	    Arrays.fill(productOfSums, 1);Arrays.fill(productOfSumsInverse, 1);
	    int[] nonZero = new int[2+mx];//difference array for checking for zeroes
	    long ways = 1;
	    long[][] RNG = new long[n][];
	    for(int i = 0; i< n; i++){
	        int L = ni()-1, R = ni();
	        RNG[i] = new long[]{L, R};
	        ways *= (R-L);
	        if(ways >= MOD)ways %= MOD;
	        int p = 0, q = 0;
	        while(p+1 < list[L].length && q+1 < list[R].length){
	            int st = Math.max(list[L][p], list[R][q]);
	            int en = Math.min(list[L][p+1], list[R][q+1])-1;
	            //st <= x <= en have same value of sum(R/x)-sum((L-1)/x)
	            long v1 = L/st, v2 = R/st;
	            long val = (v2*v2+v2-v1*v1-v1)/2;
	            if(val > 0){
	                productOfSums[st] *= val;
	                if(productOfSums[st] >= MOD)productOfSums[st] %= MOD;
	                productOfSumsInverse[en+1] *= val;
	                if(productOfSumsInverse[en+1] >= MOD)productOfSumsInverse[en+1] %= MOD;
	                nonZero[st]++;
	                nonZero[en+1]--;
	            }
	            if(en+1 == list[L][p+1])p++;
	            if(en+1 == list[R][q+1])q++;
	        }
	    }
	    //Recovering original arrays from difference arrays
	    for(int i = 1; i<= mx; i++){
	        nonZero[i] += nonZero[i-1];
	        productOfSums[i] *= productOfSums[i-1];
	        productOfSumsInverse[i] *= productOfSumsInverse[i-1];
	        if(productOfSums[i] >= MOD)productOfSums[i] %= MOD;
	        if(productOfSumsInverse[i] >= MOD)productOfSumsInverse[i] %= MOD;
	    }
	    for(int i = 0; i<= mx; i++){
	        if(nonZero[i] != n)productOfSums[i] = 0;
	        productOfSums[i] *= pow(i, n);
	        if(productOfSums[i] >= MOD)productOfSums[i] %= MOD;
	        productOfSums[i] *= inv(productOfSumsInverse[i]);
	        if(productOfSums[i] >= MOD)productOfSums[i] %= MOD;
	    }
	    
	    //Inclusion-Exclusion
	    for(int i = mx; i >= 1; i--){
	        for(int j = i+i; j <= mx; j += i){
	            productOfSums[i] += MOD-productOfSums[j];
	            if(productOfSums[i] >= MOD)productOfSums[i] -= MOD;
	        }
	    }
	    //Final answer computation
	    long ans = 0;
	    for(int i = 1; i<= mx; i++){
	        if(nonZero[i] == n){
	            ans += (productOfSums[i]*inv(i))%MOD;
	            if(ans >= MOD)ans -= MOD;
	        }
	    }
	    pn((ans*inv(ways))%MOD);
	}
	long sum(long[][] rng, int idx, long prod, long gcd){
	    if(idx == rng.length)return (prod*inv(gcd))%MOD;//%MOD;
	    long ans = 0;
	    for(long i = rng[idx][0]+1; i<= rng[idx][1]; i++)ans = (ans+sum(rng, idx+1, prod*i, gcd(gcd, i)))%MOD;
	    return ans;
	}
	long gcd(long a, long b){
	    return b == 0?a:gcd(b, a%b);
	}
	long inv(long a){
	    return pow(a, MOD-2);
	}
	long pow(long a, long p){
	    long o = 1;a %= MOD;
	    for(;p>0;p>>=1){
	        if((p&1)==1)o = (o*a)%MOD;
	        a = (a*a)%MOD;
	    }
	    return o;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new RNDRATIO().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome as always.