EXPTPROD - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Anik Sarker, Ezio Auditore

Tester: Teja Vardhan Reddy

Editorialist: Taranpreet Singh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Math, Binary Exponentiation, Expectation and Probabilities…

PROBLEM:

Given an array A of length N such that 0 \leq A_i < N and an additional integer K, the following operation is to be applied exactly K times. Initially Score = 1.

  • Choose an element X of A uniformly at random.
  • Set Score = (Score*X) \bmod N.

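Find the expected value of Score after repeating the above procedure K times. The answer is printed modulo 10^9+7, as the numerator multiplied by the modular inverse of the denominator.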
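Before optimizing, it helps to have a reference answer for tiny inputs. The following is a minimal C++ sketch (the helper name `bruteForce` is hypothetical, not from the setters' code) that enumerates all N^K equally likely sequences of picks and returns the exact numerator and denominator of the expected value. It is exponential in K and only meant for cross-checking faster code.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Brute-force reference for tiny inputs: enumerate every ordered sequence of
// K picks (N^K of them), start from Score = 1, and return
// {sum of final scores, N^K}. The expected value is first / second.
std::pair<std::int64_t, std::int64_t> bruteForce(const std::vector<int>& A, int K) {
    const int N = static_cast<int>(A.size());
    assert(N >= 1 && K >= 0);
    std::int64_t total = 1;               // N^K; fits only for tiny N and K
    for (int step = 0; step < K; ++step) total *= N;

    std::int64_t sumOfScores = 0;
    std::vector<int> idx(K, 0);           // idx[t] = index picked at operation t
    for (std::int64_t seq = 0; seq < total; ++seq) {
        int score = 1;
        for (int t = 0; t < K; ++t) score = (score * A[idx[t]]) % N;
        sumOfScores += score;
        // Advance idx like an odometer in base N to the next sequence.
        for (int t = 0; t < K; ++t) {
            if (++idx[t] < N) break;
            idx[t] = 0;
        }
    }
    return {sumOfScores, total};
}
```

For example, with A = [1, 2, 2] and K = 2 it returns {13, 9}, i.e. an expected score of 13/9.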

EXPLANATION

First of all, let us write the expression for the expected value of Score. Read about Expected Value here.

Let f_p(x) denote the number of ways to obtain Score = x after p operations. Every sequence of K choices is equally likely, and Score < N always holds, so the expected value of Score is \frac{\sum_{x = 0}^{N-1} x*f_K(x)}{\sum_{x = 0}^{N-1}f_K(x)}.

For the denominator, note that \sum_{x = 0}^{N-1}f_K(x) is simply the number of ways to choose one of the N elements K times (ordered, with repetitions allowed), which is N^K.
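As a worked sanity check on a hypothetical tiny input (not one of the official samples), take N = 3, A = [1, 2, 2] and K = 2, so freq(0) = 0, freq(1) = 1 and freq(2) = 2. Counting the 3^2 = 9 ordered pairs of picks by their product modulo 3:

```latex
\begin{aligned}
f_2(0) &= 0,\\
f_2(1) &= \underbrace{1\cdot 1}_{(1,1)} + \underbrace{2\cdot 2}_{(2,2),\ 4 \bmod 3 = 1} = 5,\\
f_2(2) &= \underbrace{1\cdot 2}_{(1,2)} + \underbrace{2\cdot 1}_{(2,1)} = 4,\\
\sum_{x=0}^{2} f_2(x) &= 9 = 3^2,\qquad
E[\text{Score}] = \frac{0\cdot 0 + 1\cdot 5 + 2\cdot 4}{9} = \frac{13}{9}.
\end{aligned}
```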

Since Score starts at 1, f_0(1) = 1 and f_0(x) = 0 for x \neq 1. Let freq(x) denote the frequency of x in array A. Suppose we have calculated f_p(x) for all 0 \leq x < N for some fixed p. We now want to calculate f_{p+1}(x) from f_p and freq.

We can see that f_{p+1}(z) = \sum_{x= 0}^{N-1}\sum_{y = 0}^{N-1} g(x, y, z)*f_p(x)*freq(y), where g(x, y, z) = 1 if (x*y) \bmod N = z and g(x, y, z) = 0 otherwise. This works because there are f_p(x) ways to obtain Score = x after p operations and freq(y) ways to select y in the current operation, so the pair (x, y) contributes f_p(x)*freq(y) ways to obtain (x*y) \bmod N after p+1 operations, and this amount is added to f_{p+1}((x*y) \bmod N). Repeating this over all pairs (x, y) computes f_{p+1}(z) for all 0 \leq z < N. This takes O(N^2) time, since there are N^2 pairs (x, y).
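The transition above translates directly into code. This is a minimal C++ sketch, assuming counts are kept modulo 10^9+7 as in the setters' solutions; `step` and `naiveF` are illustrative names. It implements the O(K*N^2) approach, so it is only practical for small K, but it is handy for testing the faster version.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 1'000'000'007ULL;

// One transition f_p -> f_{p+1}: every current score x combined with every
// pickable value y (freq[y] ways) lands on (x*y) mod N.
// Assumes all entries are already reduced modulo MOD.
std::vector<u64> step(const std::vector<u64>& f, const std::vector<u64>& freq) {
    const std::size_t n = f.size();
    assert(freq.size() == n);
    std::vector<u64> next(n, 0);
    for (std::size_t x = 0; x < n; ++x) {
        if (f[x] == 0) continue;           // no ways to be at x: contributes nothing
        for (std::size_t y = 0; y < n; ++y) {
            std::size_t z = (x * y) % n;
            next[z] = (next[z] + f[x] * freq[y] % MOD) % MOD;
        }
    }
    return next;
}

// Naive O(K * N^2) computation of f_K, starting from f_0 = [x == 1].
std::vector<u64> naiveF(const std::vector<u64>& freq, long long K) {
    assert(freq.size() >= 2 && K >= 0);
    std::vector<u64> f(freq.size(), 0);
    f[1] = 1;
    for (long long p = 0; p < K; ++p) f = step(f, freq);
    return f;
}
```

With freq = [0, 1, 2] (the A = [1, 2, 2] example), two steps give f_2 = [0, 5, 4].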

Hence, repeating this process K times gives us f_K(x) and solves the problem in O(K*N^2) time, which is too slow for K up to 10^9. Let us optimize this.

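Now let us calculate f_{a+b}(x) assuming we have already calculated f_a(x) and f_b(x). The process remains the same: we consider every pair (x, y). There are f_a(x) ways to get Score = x after a operations. If the next b operations pick elements whose product is congruent to y modulo N, the score becomes (x*y) \bmod N, and since f_b starts from Score = 1, f_b(y) is exactly the number of ways to pick b elements whose product is congruent to y modulo N. Hence the pair (x, y) contributes f_a(x)*f_b(y) ways to obtain (x*y) \bmod N after a+b operations. Let us call this operation a convolution.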
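A minimal C++ sketch of this convolution, assuming both inputs are length-N vectors with entries already reduced modulo 10^9+7. The name `convolve` is illustrative; it plays the same role as `Multiply` in Setter 1's solution.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 1'000'000'007ULL;

// Multiplicative convolution modulo N: if a = f_p and b = f_q, the result is
// f_{p+q}. Each pair (x, y) contributes a[x] * b[y] ways to residue (x*y) mod N.
// Entries of a and b are assumed to be < MOD, so a[x] * b[y] fits in 64 bits.
std::vector<u64> convolve(const std::vector<u64>& a, const std::vector<u64>& b) {
    const std::size_t n = a.size();
    assert(b.size() == n);
    std::vector<u64> c(n, 0);
    for (std::size_t x = 0; x < n; ++x) {
        if (a[x] == 0) continue;
        for (std::size_t y = 0; y < n; ++y) {
            std::size_t z = (x * y) % n;
            c[z] = (c[z] + a[x] * b[y]) % MOD;
        }
    }
    return c;
}
```

Convolving f_1 = freq = [0, 1, 2] (from A = [1, 2, 2], N = 3) with itself gives [0, 5, 4] = f_2, and convolving that with f_1 again gives f_3 = [0, 13, 14], whose entries sum to 27 = 3^3.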

This is the key observation. Let us precompute f_p(x) for every p that is a power of 2 not exceeding K (hint: convolving f_p(x) with itself gives f_{2*p}(x)). This takes O(N^2*log_2(K)) time. Now write K as a sum of distinct powers of two (this is always possible; it is just the binary representation of K) and convolve together the precomputed f_p for every power of two p in that sum. This gives f_K(x) for all x. This step also takes O(N^2*log_2(K)) time, since K is a sum of at most \lfloor log_2(K) \rfloor + 1 powers of two and each convolution takes O(N^2) time. See here.
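Putting the two ideas together gives a standard binary-exponentiation loop over distributions. This is a C++ sketch under the same assumptions (modulus 10^9+7, N >= 2); `distributionAfter` is a hypothetical name, and the loop has the same shape as `Power` in Setter 1's solution.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 1'000'000'007ULL;

// O(N^2) multiplicative convolution modulo N (f_p, f_q -> f_{p+q}).
std::vector<u64> convolve(const std::vector<u64>& a, const std::vector<u64>& b) {
    const std::size_t n = a.size();
    std::vector<u64> c(n, 0);
    for (std::size_t x = 0; x < n; ++x)
        for (std::size_t y = 0; y < n; ++y) {
            std::size_t z = (x * y) % n;
            c[z] = (c[z] + a[x] * b[y]) % MOD;
        }
    return c;
}

// f_K via binary exponentiation: `base` runs through f_1, f_2, f_4, ...
// (each the previous one convolved with itself), and those matching set
// bits of K are folded into `result`. Total time O(N^2 log K).
std::vector<u64> distributionAfter(const std::vector<u64>& freq, long long K) {
    const std::size_t n = freq.size();
    assert(n >= 2 && K >= 0);
    std::vector<u64> result(n, 0);
    result[1] = 1;                         // f_0: Score = 1 in exactly one way
    std::vector<u64> base(n);
    for (std::size_t i = 0; i < n; ++i) base[i] = freq[i] % MOD;  // f_1
    while (K > 0) {
        if (K & 1) result = convolve(result, base);
        base = convolve(base, base);
        K >>= 1;
    }
    return result;
}
```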

After this, we compute the numerator \sum_{x = 0}^{N-1} x*f_K(x) and print the answer as the product of the numerator and the modular inverse of the denominator N^K, modulo 10^9+7.
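The final step as a small C++ sketch, assuming f_K is already available modulo 10^9+7. The modular inverse uses Fermat's little theorem (raising to MOD - 2), as the setters' `bigMod(Way, MOD-2)` does. The names `powMod` and `expectedScore` are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 1'000'000'007ULL;

// Fast modular exponentiation: base^exp mod MOD.
u64 powMod(u64 base, u64 exp) {
    u64 result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// (sum_x x * f_K(x)) * (N^K)^{-1} mod MOD. The Fermat inverse is valid
// because MOD is prime and 0 < N < MOD, so N^K is nonzero modulo MOD.
u64 expectedScore(const std::vector<u64>& fK, u64 N, u64 K) {
    assert(fK.size() == N);
    u64 numerator = 0;
    for (std::size_t x = 0; x < fK.size(); ++x)
        numerator = (numerator + x % MOD * fK[x]) % MOD;
    u64 denominator = powMod(N, K);
    return numerator * powMod(denominator, MOD - 2) % MOD;
}
```

For the A = [1, 2, 2], K = 2 example, 13/9 becomes 13 * 9^{-1} = 444444449 modulo 10^9+7.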

TIME COMPLEXITY

The time complexity is O(N^2*log(K)) per test case.

SOLUTIONS:

Setter 1 Solution
#include <bits/stdc++.h>
using namespace std;
#define MAX 1005
#define MOD 1000000007
#define ll long long int

ll bigMod(ll a,ll b){
	ll res=1;
	while(b){
	    if(b&1) res = (res*a) % MOD;
	    a = (a*a) % MOD; b>>=1;
	}
	return res;
}

struct Row{
	int n;
	vector<ll>m;
	Row(int _n) {n = _n; m.clear(); m.resize(n,0);}
};

Row Multiply(Row A, Row B){
	int n = A.n;
	Row result(n);
	for(int i=0;i<n;i++){
	    for(int j=0;j<n;j++){
	        int x = (i*j) % n;
	        result.m[x] += (A.m[i] * B.m[j]) % MOD;
	        if(result.m[x] >= MOD) result.m[x] -= MOD;
	    }
	}
	return result;
}

Row Power(Row mat,ll p){
	int n = mat.n;
	Row res(n);
	Row ans(n);

	ans.m[1] = 1;
	for(int i=0;i<n;i++) res.m[i] = mat.m[i];

	while(p){
	    if(p&1) ans=Multiply(ans,res);
	    res=Multiply(res,res);
	    p=p/2;
	}
	return ans;
}

long long readInt(long long l,long long r,char endd){
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true){
		char g=getchar();
		if(g=='-'){
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g && g<='9'){
			x*=10;
			x+=g-'0';
			if(cnt==0){
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);

			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd){
			if(is_neg){
				x= -x;
			}
			assert(l<=x && x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
string readString(int l,int r,char endd){
	string ret="";
	int cnt=0;
	while(true){
		char g=getchar();
		assert(g!=-1);
		if(g==endd){
			break;
		}
		cnt++;
		ret+=g;
	}
	assert(l<=cnt && cnt<=r);
	return ret;
}
long long readIntSp(long long l,long long r){
	return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
	return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
	return readString(l,r,'\n');
}
string readStringSp(int l,int r){
	return readString(l,r,' ');
}

int main(){
	int T = readIntLn(1,10);
	while(T--){
		int N = readIntSp(2,1000);
		int K = readIntLn(1,1e9);

	    Row mat(N);
		for(int i=0;i<N;i++){
		    int x;
			if(i!= N-1) x = readIntSp(0,N-1);
			else x = readIntLn(0,N-1);
	        mat.m[x]++;
		}

	    Row ret = Power(mat,K);

	    ll Sum = 0;
	    for(int i=0;i<N;i++) Sum += i * ret.m[i];
	    Sum %= MOD;

	    ll Way = bigMod(N,K);
	    Way = bigMod(Way, MOD-2);
	    Sum = (Sum * Way) % MOD;

	    printf("%lld\n",Sum);
	}
}
Setter 2 Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1e9 + 7;

int t, n, m, x, y, z, k;
ll cnt[1005];
ll res[1005];
ll tmp[1005];

ll bigMod(ll n, ll k)
{
	if(k == 0) return 1;
	ll re = bigMod(n, (k/2));
	re = (re * re) % mod;
	if(k % 2) re = (re * n) % mod;
	return re;
}

int main()
{
	cin >> t;
	if(t < 1 || t > 10) assert(false);

	while(t--){

	    scanf("%d %d",&n, &k);

	    if(n < 1 || n > 1000) assert(false);
	    if(k < 1 || k > 1000000000) assert(false);
	    memset(cnt, 0, sizeof(cnt));

	    for(int i = 1; i <= n; i++){
	        scanf("%d", &x);
	        if(x < 0 || x >= n) assert(false);
	        cnt[x]++;
	    }
	    memset(res, 0, sizeof(res));
	    memset(tmp, 0, sizeof(tmp));
	    res[1] = 1;

	    int msb = 31 - __builtin_clz(k);

	    while(msb + 1){
	        for(int i = 0; i < n; i++){
	            for(int j = 0; j < n; j++){
	                int xx = (i * j) % n;
	                tmp[xx] += (res[i] * res[j]) % mod;
	                if(tmp[xx] >= mod) tmp[xx] -= mod;
	            }
	        }

	        for(int i = 0; i < n; i++) res[i] = tmp[i], tmp[i] = 0;

	        if((1 << msb)&k){
	            for(int i = 0; i < n; i++){
	                for(int j = 0; j < n; j++){
	                    int xx = (i * j) % n;
	                    tmp[xx] += (res[i] * cnt[j]) % mod;
	                    if(tmp[xx] >= mod) tmp[xx] -= mod;
	                }
	            }
	            for(int i = 0; i < n; i++) res[i] = tmp[i], tmp[i] = 0;
	        }
	        msb--;
	    }

	    ll ans = 0;
	    for(ll i = 0; i < n; i++) ans = (ans + i * res[i]) % mod;

	    ll niche = bigMod(n, k);

	    ans = (ans * bigMod(niche, mod - 2)) % mod;

	    printf("%lld\n", ans);

	}

	return 0;
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;

#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;

#define int ll

int a[1234],b[1234],c[1234],d[1234];

int n;
int multiply(int a[],int b[]){
	int i,j;
	rep(i,n){
		d[i]=0;
	}
	int val;
	// int iinf = inf;
	// iinf*=inf;
	rep(i,n){
	    val=0;
		rep(j,n){
			d[val]+=a[i]*b[j];
			d[val]%=mod;
			val+=i;
	        if(val>=n)
	            val-=n;  
	    }
	}
	rep(i,n){
		a[i]=d[i];
	}
	return 0;
}

int getpow(int p,int q){
	int ans=1;
	while(q){
		if(q%2){
			ans*=p;
			ans%=mod;
		}
		p*=p;
		p%=mod;
		q/=2;
	}
	return ans;
}
signed main(){
	//std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	//cin>>t;
	scanf("%lld",&t);
	while(t--){
		int k;
		//cin>>n>>k;
		scanf("%lld",&n);
	    scanf("%lld",&k);
	
	    int i;
		rep(i,n){
			//cin>>a[i];
	        scanf("%lld",&a[i]);
	
		}
		rep(i,n){
			b[i]=0;
			c[i]=0;
		}
		rep(i,n){
			b[a[i]]++;
		}
		c[1]=1;
		while(k){
			if(k%2){
				multiply(c,b);
			}
			multiply(b,b);
			k/=2;
		}
		int ans=0,sumi=0;
		rep(i,n){
			ans+=c[i]*i;
			sumi+=c[i];
		}
	    ans%=mod;
	    sumi%=mod;
		ans*=getpow(sumi,mod-2);
		ans%=mod;
		cout<<ans<<"\n";
	}
	return 0;   
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class EXPTPROD{
	//SOLUTION BEGIN
	//Into the Hardware Mode
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni();long k = nl();
	    long[] freq = new long[n];
	    for(int i = 0; i< n; i++)freq[ni()]++;
	    long num = 0, den = pow(n, k);
	    long[] sum = pow(freq, k);
	    for(int i = 0; i< n; i++)num = (num+(i*sum[i])%mod)%mod;
	    pn((num*pow(den, mod-2))%mod);
	}
	long pow(long a, long k){
	    long ans = 1;
	    while(k>0){
	        if((k&1)==1)ans = (ans*a)%mod;
	        a = (a*a)%mod;
	        k>>=1;
	    }
	    return ans;
	}
	long[] pow(long[] a, long k){
	    long[] ans = new long[a.length];
	    ans[1] = 1;
	    while(k>0){
	        if((k&1)==1)ans = multiply(ans, a);
	        a = multiply(a, a);
	        k>>=1;
	    }
	    return ans;
	}
	long[] multiply(long[] a, long[] b){
	    int n = a.length;
	    long[] ans = new long[n];
	    for(int i = 0; i< n; i++)
	        for(int j = 0; j< n; j++)
	            ans[(i*j)%n] = (ans[(i*j)%n]+a[i]*b[j]%mod)%mod;
	    return ans;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	long IINF = (long)1e18, mod = (long)1e9+7;
	final int INF = (int)1e9, MX = (int)2e5+5;
	DecimalFormat df = new DecimalFormat("0.00000000000");
	double PI = 3.141592653589793238462643383279502884197169399, eps = 1e-6;
	static boolean multipleTC = true, memory = false, fileIO = false;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    if(fileIO){
	        in = new FastReader("input.txt");
	        out = new PrintWriter("output.txt");
	    }else {
	        in = new FastReader();
	        out = new PrintWriter(System.out);
	    }
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    if(memory)new Thread(null, new Runnable() {public void run(){try{new EXPTPROD().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
	    else new EXPTPROD().run();
	}
	long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
	int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}
 
	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }
 
	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }
 
	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }
 
	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach if you want to (even if it's the same). Suggestions are welcome, as always.

Can you please suggest some problems similar to this one?

Try AND, OR and XOR operations instead of multiplication.

I appreciated this problem! It is great for anyone who wants to learn exponentiation!

Can you please elaborate on why applying ‘a’ more steps after ‘b’ steps are over is the same as applying ‘a’ steps to the initial score?
In other words, I think f(a+b) depends not only on f(a) and f(b) but also on the initial score vector. Please clarify this.

Thanks in advance

try this…

@taran_1407 What if the initial score was not 1 (I mean, if it was some arbitrary number)?

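Suppose we have calculated f_K(x) for each x, and we start from Score = y. Then, for every 0 \leq x < N, the f_K(x) ways whose picks multiply to x (mod N) end at Score = (x*y) \bmod N, so f_K(x) contributes to ans((x*y) \bmod N). We can run a single loop over x to calculate this.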
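The reply above can be sketched as follows in C++ (the name `shiftStart` is illustrative, and counts are assumed to be modulo 10^9+7): it relabels each residue x of f_K to (s*x) \bmod N for a starting score s.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 1'000'000'007ULL;

// Given f_K computed from Score = 1, return the counts for start value s:
// a sequence whose picks multiply to x (mod N) ends at (s * x) mod N.
std::vector<u64> shiftStart(const std::vector<u64>& fK, std::size_t s) {
    const std::size_t n = fK.size();
    assert(n >= 1);
    std::vector<u64> g(n, 0);
    for (std::size_t x = 0; x < n; ++x) {
        std::size_t target = (s % n) * x % n;
        g[target] = (g[target] + fK[x]) % MOD;
    }
    return g;
}
```

The expected value is then computed from the relabelled counts exactly as before, still dividing by N^K.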

@taran_1407 @vijju123 Can you please suggest a lower-difficulty problem similar to this? I can't understand this fully. If you have encountered a similar problem, please suggest it. Thanks!

Maybe consider addition in place of multiplication; I think that is a simpler problem. Other than that, I don’t know any specific easier problem.
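To illustrate that suggestion, here is a hypothetical additive variant (Score = (Score + X) \bmod N), sketched in C++ under the assumption that the score starts at 0, the identity for addition. Only the target index in the convolution changes, from (x*y) \bmod N to (x+y) \bmod N; the exponentiation loop is unchanged. The names are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 1'000'000'007ULL;

// Additive convolution modulo N: the pair (x, y) lands on (x + y) mod N.
std::vector<u64> convolveAdd(const std::vector<u64>& a, const std::vector<u64>& b) {
    const std::size_t n = a.size();
    assert(b.size() == n);
    std::vector<u64> c(n, 0);
    for (std::size_t x = 0; x < n; ++x)
        for (std::size_t y = 0; y < n; ++y) {
            std::size_t z = (x + y) % n;
            c[z] = (c[z] + a[x] * b[y]) % MOD;
        }
    return c;
}

// Distribution of Score after K additions, assuming the start value is 0,
// so f_0 = [x == 0]. Entries of freq are assumed to be < MOD.
std::vector<u64> additiveDistribution(const std::vector<u64>& freq, long long K) {
    const std::size_t n = freq.size();
    assert(n >= 1 && K >= 0);
    std::vector<u64> result(n, 0), base(freq);
    result[0] = 1;
    while (K > 0) {
        if (K & 1) result = convolveAdd(result, base);
        base = convolveAdd(base, base);
        K >>= 1;
    }
    return result;
}
```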

Can you tell how this formula came about?

Can you tell how the formula for the expected value of Score was derived?

A more natural way to do this is to proceed as in binary exponentiation.

  • When K is even, compute f_{K/2} and convolve it with itself to obtain f_K.
  • When K is odd, compute f_{K-1} and convolve it with f_1 = freq to obtain f_K.
  • The base case is K = 0: f_0(1) = 1 and all other values are 0.
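A minimal recursive C++ sketch of these three cases (the names `convolve` and `fRec` are illustrative, and counts are taken modulo 10^9+7). The recursion depth is at most about 2*log_2(K), so it is fine for K up to 10^9.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;
constexpr u64 MOD = 1'000'000'007ULL;

// O(N^2) multiplicative convolution modulo N (f_p, f_q -> f_{p+q}).
std::vector<u64> convolve(const std::vector<u64>& a, const std::vector<u64>& b) {
    const std::size_t n = a.size();
    std::vector<u64> c(n, 0);
    for (std::size_t x = 0; x < n; ++x)
        for (std::size_t y = 0; y < n; ++y) {
            std::size_t z = (x * y) % n;
            c[z] = (c[z] + a[x] * b[y]) % MOD;
        }
    return c;
}

// f_0 = [x == 1]; f_K = f_{K/2} (*) f_{K/2} if K is even;
// f_K = f_{K-1} (*) freq if K is odd. Entries of freq assumed < MOD.
std::vector<u64> fRec(const std::vector<u64>& freq, long long K) {
    assert(freq.size() >= 2 && K >= 0);
    if (K == 0) {
        std::vector<u64> identity(freq.size(), 0);
        identity[1] = 1;
        return identity;
    }
    if (K % 2 == 0) {
        std::vector<u64> half = fRec(freq, K / 2);
        return convolve(half, half);
    }
    return convolve(fRec(freq, K - 1), freq);
}
```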

My submission uses this idea: CodeChef: Practical coding for everyone