RANDOM_ARRAY - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: yash_daga
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Linearity of expectation, dynamic programming

PROBLEM:

You’re given N arrays, each containing some integers.
Consider a sequence S built as follows:

  • If S is empty, choose a random element from any array and append it to S.
  • Otherwise, let x be the last chosen element, and suppose it’s from array i.
    Choose any element strictly greater than x from an array that’s not the i'th, and append it to S.

All choices are made uniformly at random among all valid options.
This process continues until it’s no longer possible to choose an element.
Find the expected sum of the obtained sequence.
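To make the process concrete, here is a small brute-force reference in Python. It is not one of the original solutions: it follows the rules literally and computes the exact expected sum with fractions. The function name `brute_expected_sum` and the input format (a list of Python lists, arrays numbered from 0) are assumptions made for illustration. It memoizes on the last chosen (value, array) pair, so it only suits tiny inputs, but it is handy for cross-checking faster solutions.

```python
from fractions import Fraction
from functools import lru_cache

def brute_expected_sum(arrays):
    """Exact expected sum of S, following the process rule by rule.

    Memoized on the last chosen (value, array) pair, so roughly O(M^2) work;
    only meant for tiny inputs.
    """
    # Every element is a (value, array index) pair; duplicates stay separate entries.
    elems = [(x, i) for i, arr in enumerate(arrays) for x in arr]

    @lru_cache(maxsize=None)
    def rest(last_value, last_array):
        # Expected sum of everything appended after (last_value, last_array).
        options = [(x, i) for x, i in elems if x > last_value and i != last_array]
        if not options:
            return Fraction(0)
        total = sum(x + rest(x, i) for x, i in options)
        return total / len(options)

    # The first element is chosen uniformly among all M elements.
    return sum(x + rest(x, i) for x, i in elems) / Fraction(len(elems))
```

For example, with arrays [1, 3] and [2], starting from 1, 3 or 2 yields the sequences (1, 2, 3), (3) and (2, 3), so the expected sum is (6 + 3 + 5)/3 = 14/3.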

EXPLANATION:

Let’s first merge all N arrays into a single sorted array of size M = \sum K_i, where K_i is the size of the i'th array.
Let A_i denote the value of the i'th element of this sorted array, and B_i denote which array it initially came from.
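A minimal sketch of this merge step, assuming the arrays are given as Python lists numbered from 0 (the helper name `merge_arrays` is hypothetical). A is the sorted list of values and B[i] records which array A[i] came from. The relative order of equal values does not matter, since the later steps treat each group of equal values together.

```python
def merge_arrays(arrays):
    # Pair each value with the index of its source array, then sort by value.
    merged = sorted((x, i) for i, arr in enumerate(arrays) for x in arr)
    A = [x for x, _ in merged]  # A[i]: i'th smallest value overall
    B = [i for _, i in merged]  # B[i]: which array A[i] came from
    return A, B
```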

We use the linearity of expectation: for each element, let’s find the probability P_i that it’s included in the sequence we build.
The final answer is then simply \sum_{i=1}^M P_iA_i.
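The linearity-of-expectation claim can be checked by brute force. The sketch below (hypothetical helper `inclusion_probabilities`, taking the merged A and B lists) enumerates every possible sequence and, for each element, adds up the probabilities of all prefixes that end at it. Since S is strictly increasing, an element appears at most once, so this sum is exactly P_i. The enumeration is exponential and only meant for tiny inputs.

```python
from fractions import Fraction

def inclusion_probabilities(A, B):
    """Return P, where P[i] is the probability that merged element i ends up in S.

    Enumerates every possible sequence explicitly, so it is exponential in the
    worst case; only meant for tiny inputs.
    """
    M = len(A)
    P = [Fraction(0)] * M

    def walk(last, prob):
        # prob = probability of the current prefix, which ends with element 'last'.
        P[last] += prob
        options = [j for j in range(M) if A[j] > A[last] and B[j] != B[last]]
        for j in options:
            walk(j, prob / len(options))

    for start in range(M):
        walk(start, Fraction(1, M))
    return P
```

On the arrays [1, 3] and [2] (so A = [1, 2, 3], B = [0, 1, 0]) this gives P = (1/3, 2/3, 1), and 1·1/3 + 2·2/3 + 3·1 = 14/3, the same expected sum as a direct simulation.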

To compute P_i, note that there are two possibilities:

  • A_i can be the first chosen element, which has a constant probability of \frac{1}{M}.
  • A_i can be chosen immediately after some other element; the probability of this depends on which element that is.

For the second case, let’s fix j \lt i to be the index of the element chosen just before A_i.
Then,

  • If A_i = A_j or B_i = B_j, A_i cannot follow A_j, so ignore such j.
  • Otherwise, for A_i to immediately follow A_j in S,
    • A_j must be present in the sequence in the first place, which happens with probability P_j.
    • Given that A_j is in S, let C_j be the number of valid choices for the next element.
      A_i is then chosen with probability \frac{1}{C_j}, so this case contributes \frac{P_j}{C_j} to P_i.

That is, we have

P_i = \frac{1}{M} + \sum_{\substack{j \lt i \\ A_i\neq A_j \\ B_i\neq B_j}} \frac{P_j}{C_j}
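Before optimizing, the recurrence can be evaluated directly. The following sketch (hypothetical name `probabilities_quadratic`, using exact fractions instead of modular arithmetic) computes each C_j by scanning all elements and then sums the allowed terms, for \mathcal{O}(M^2) total work. Because A is sorted, every allowed j has A_j \lt A_i and B_j \neq B_i, so A_i itself is a valid successor of A_j and C_j \geq 1; the division is always safe.

```python
from fractions import Fraction

def probabilities_quadratic(A, B):
    """Evaluate P_i = 1/M + sum of P_j / C_j over allowed j < i, in O(M^2).

    A must be sorted; B[i] is the (0-indexed) source array of A[i].
    """
    M = len(A)
    # C[j]: number of elements strictly greater than A[j] from a different array.
    C = [sum(1 for k in range(M) if A[k] > A[j] and B[k] != B[j]) for j in range(M)]
    P = [Fraction(0)] * M
    for i in range(M):
        P[i] = Fraction(1, M)  # A_i is the first pick
        for j in range(i):
            if A[j] != A[i] and B[j] != B[i]:
                # C[j] >= 1 here, since A[i] is itself a valid successor of A[j].
                P[i] += P[j] / C[j]
    return P
```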

We now have two things to do: compute all the C_i values, and evaluate the sum for P_i faster than the direct \mathcal{O}(M^2) loop over all j \lt i.

Computing C

C_i denotes the number of elements that are strictly greater than A_i and come from a different array than it.

One easy way to compute this quickly is binary search: count the elements of the merged list that are \gt A_i, and from this subtract the number of elements \gt A_i that belong to the same array as A_i.
The latter can be found, for instance, by binary searching on a sorted copy of array number B_i.
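A sketch of the binary-search approach using Python's `bisect_right`, which plays the role of `upper_bound` in the author's code; the function name `count_choices_bisect` and the list-of-lists input are assumptions. `bisect_right(lst, x)` counts the elements \leq x in a sorted list, so subtracting it from the list's length leaves the elements \gt x.

```python
from bisect import bisect_right

def count_choices_bisect(arrays):
    """C for every element, aligned with sorted((x, i) for x in arrays[i])."""
    merged = sorted((x, i) for i, arr in enumerate(arrays) for x in arr)
    all_values = [x for x, _ in merged]
    sorted_arrays = [sorted(arr) for arr in arrays]
    C = []
    for x, i in merged:
        # elements > x overall, minus elements > x in x's own array
        greater_total = len(all_values) - bisect_right(all_values, x)
        greater_same = len(sorted_arrays[i]) - bisect_right(sorted_arrays[i], x)
        C.append(greater_total - greater_same)
    return C
```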

An even faster approach, avoiding the binary search, is to store something like a frequency array F, where F_k is the number of elements of the k-th array seen so far.
Iterate in decreasing order of i, so that the overall number of greater elements can be maintained as you go.
Then, C_i simply equals the number of greater elements, minus the number of elements seen of B_i (which is just F_{B_i}).

Equal elements need some extra care: process all elements with the same value at once, computing their C values first and only then updating F and the count of greater elements.
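The frequency-array sweep might look like this (a sketch; `count_choices_sweep` is a hypothetical name, B holds 0-indexed array numbers, and n is the number of arrays). Each group of equal values is answered before it is added to the counters, so equal elements never count each other as greater, and the whole pass is \mathcal{O}(M + N).

```python
def count_choices_sweep(A, B, n):
    """C[i] for a sorted merged list (A, B) drawn from n arrays, in O(M + n)."""
    M = len(A)
    C = [0] * M
    greater = 0       # elements processed so far, all strictly greater than the current group
    seen = [0] * n    # seen[k]: how many of those came from array k (the F array)
    i = M - 1
    while i >= 0:
        # The group of equal values occupies indices lo+1 .. i.
        lo = i
        while lo >= 0 and A[lo] == A[i]:
            lo -= 1
        # Answer the whole group first...
        for j in range(lo + 1, i + 1):
            C[j] = greater - seen[B[j]]
        # ...and only then add it to the counters.
        for j in range(lo + 1, i + 1):
            greater += 1
            seen[B[j]] += 1
        i = lo
    return C
```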

Computing P

As it turns out, computing P_i quickly can be done using basically the same technique we used to compute the C_i values.

Note that we want the sum of \frac{P_j}{C_j} across all j \lt i such that A_i \neq A_j and B_i \neq B_j.
A running total easily gives the sum of \frac{P_j}{C_j} over all j \lt i; from this, we’ll subtract out the “bad” ones.
For that, also maintain the sum of \frac{P_j}{C_j} separately for each array (an array of length N suffices), so the sum corresponding to B_i can be subtracted immediately.
That leaves indices with A_j = A_i; as when computing C, handle them by first computing P for every element with this value and only then updating the sums.
Elements with C_j = 0 can never be followed by anything, so they are simply skipped when updating the sums.
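A sketch of this linear sweep for P (hypothetical name `probabilities_sweep`; exact fractions are used here for readability, whereas the reference codes work modulo 10^9+7). `total` holds the sum of \frac{P_j}{C_j} over everything processed so far, and `per_array[k]` holds the same sum restricted to array k.

```python
from fractions import Fraction

def probabilities_sweep(A, B, C, n):
    """P[i] for sorted (A, B) with successor counts C, in O(M + n) fraction operations."""
    M = len(A)
    P = [Fraction(0)] * M
    total = Fraction(0)               # sum of P_j / C_j over all processed j
    per_array = [Fraction(0)] * n     # the same sum, split by source array
    i = 0
    while i < M:
        hi = i
        while hi < M and A[hi] == A[i]:
            hi += 1
        # Elements A[i..hi-1] share a value, so they only see strictly smaller ones.
        for j in range(i, hi):
            P[j] = Fraction(1, M) + total - per_array[B[j]]
        # Only now add the group's own terms, skipping elements with no successor.
        for j in range(i, hi):
            if C[j] > 0:
                share = P[j] / C[j]
                total += share
                per_array[B[j]] += share
        i = hi
    return P
```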

Once every P_i is known, simply compute the answer as \sum P_iA_i. The reference codes below output it modulo 10^9+7, replacing each division by a modular inverse.
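All the reference codes print the expected value modulo 10^9+7: an exact value p/q corresponds to p \cdot q^{-1} \bmod (10^9+7), and since the modulus is prime, the inverse can be computed with Fermat's little theorem as q^{10^9+5}. A small sketch of this final step (the helper names `fraction_to_mod` and `answer_mod` are hypothetical), starting from exact probabilities such as those a fraction-based computation produces:

```python
from fractions import Fraction

MOD = 10**9 + 7

def fraction_to_mod(f):
    # p/q -> p * q^(MOD-2) mod MOD; valid because MOD is prime and q is not a multiple of it.
    return f.numerator % MOD * pow(f.denominator, MOD - 2, MOD) % MOD

def answer_mod(A, P):
    # Linearity of expectation: sum of P_i * A_i, computed in the field mod MOD.
    return sum(fraction_to_mod(p) * a for p, a in zip(P, A)) % MOD
```

For the arrays [1, 3] and [2], P = (1/3, 2/3, 1), and `answer_mod` returns 666666676, which is 14/3 modulo 10^9+7.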

TIME COMPLEXITY:

\mathcal{O}((N+\sum K_i)\log (N+\sum K_i)) per testcase.

CODE:

Author's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow
//In mod questions, add a check at the end: if ans<0 then ans+=mod;
//In case of a close MLE, change language to C++17 or C++14
//Check ans for n=1
#pragma GCC target ("avx2")    
#pragma GCC optimize ("O3")  
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long      
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define mod 1000000007ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
const long long N=1000005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
{
    if(a==0)
    return 0;
    int res=1;
    a%=p;
    while(b>0)
    {
        if(b&1)
        res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}
int small(vi &v, int a)
{
    return (upper_bound(all(v), a) - v.begin());
}
int32_t main()
{
    IOS;
    int t;
    cin>>t;
    while(t--)
    {
        int n;
        cin>>n;
        int sum[n], tot_sum=0, sz[n], co=0;
        fill(sum, 0);
        mii mp1[n];
        vi v, v1[n];
        map <pii, int> mp, choices;
        rep(i,0,n)
        {
            cin>>sz[i];
            co+=sz[i];
            rep(j,0,sz[i])
            {
                int a;
                cin>>a;
                v.pb(a);
                v1[i].pb(a);
                mp1[i][a]++;
                mp[{a, i}]++;
            }
            sort(all(v1[i]));
        }
        sort(all(v));
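        // first_pick = 1/M: probability that a given element is chosen first
        int first_pick = power(co, mod-2, mod), ans=0;
        // sum_cur: sum of P_j/C_j over already-processed elements equal to val_cur
        int sum_cur=0, val_cur=-1;
        // mp visits distinct (value, array) pairs in increasing order of value
        for(auto it:mp)
        {
            int a=it.ff.ff, id=it.ff.ss, num=it.ss;
            // P = 1/M + (all earlier P_j/C_j) - (those from the same array)
            int prob=(tot_sum + first_pick - sum[id] + mod)%mod;
            // ... minus those with the same value (from other arrays)
            if(a==val_cur)
                prob=(prob - sum_cur + mod)%mod;
            else
            {
                sum_cur = 0;
                val_cur = a;
            }
            // all num copies of this (value, array) pair share the same probability
            ans=(ans + (((a*prob)%mod)*num))%mod;
            // valid = C: elements > a overall, minus elements > a in the same array
            int valid=(co - small(v, a)) - (sz[id] - small(v1[id], a));
            // P/C for one copy; power(0, ...) returns 0, so valid == 0 adds nothing
            prob=(prob*power(valid, mod-2, mod))%mod;
            // combined contribution of all num copies to the running sums
            prob=(prob*num)%mod;
            sum_cur = (sum_cur + prob)%mod;
            tot_sum = (tot_sum + prob)%mod;
            sum[id] = (sum[id] + prob)%mod;
        }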
        cout<<ans<<"\n";
    }
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int facN = 1e6 + 5;
const int mod = 1e9 + 7; // 998244353
int ff[facN], iff[facN];
bool facinit = false;

int power(int x, int y){
	if (y == 0) return 1;

	int v = power(x, y / 2);
	v = 1LL * v * v % mod;

	if (y & 1) return 1LL * v * x % mod;
	else return v;
}

void factorialinit(){
	facinit = true;
	ff[0] = iff[0] = 1;

	for (int i = 1; i < facN; i++){
		ff[i] = 1LL * ff[i - 1] * i % mod;
	}

	iff[facN - 1] = power(ff[facN - 1], mod - 2);
	for (int i = facN - 2; i >= 1; i--){
		iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
	}
}

int C(int n, int r){
	if (!facinit) factorialinit();

	if (n == r) return 1;

	if (r < 0 || r > n) return 0;
	return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}

int P(int n, int r){
	if (!facinit) factorialinit();

	assert(0 <= r && r <= n);
	return 1LL * ff[n] * iff[n - r] % mod;
}

int inv(int x){
	return power(x, mod - 2);
}

int Solutions(int n, int r){
	//solutions to x1 + ... + xn = r 
	//xi >= 0

	return C(n + r - 1, n - 1);
}

void Solve() 
{
    int n; cin >> n;

    // what is global probability to pick
    // what is sum each row's probability
    // also need to know number of greater numbers globally 
    // number of greater numbers in row
    // keep changes lazily so only update when new number is smaller 

    map <int, int> ggf; 
    vector <map<int, int>> lgf(n);
    map <int, vector<int>> mp;
    int gpb = 0;
    vector <int> rpb(n, 0);

    int total = 0;

    for (int i = 0; i < n; i++){
        int k; cin >> k;
        total += k;

        int sum = k;

        while (k--){
            int x; cin >> x;

            mp[x].push_back(i);
            ggf[x]++;
            lgf[i][x]++;
        }

        for (auto [x, y] : lgf[i]){
            sum -= y;
            lgf[i][x] = sum;
        }
    }

    int tt = total;
    for (auto [x, y] : ggf){
        tt -= y;
        ggf[x] = tt;
    }

    gpb = inv(total);

    int ans = 0;

    for (auto [x, vec] : mp){
        map <int, int> rem;
        for (auto i : vec){
            int prob = gpb - rpb[i];
            if (prob < 0) prob += mod;

            rem[i] = prob;

            ans += prob * x % mod;

            //cout << prob << " " << x << "\n";
        }

        for (auto i : vec){
            int gg = ggf[x] - lgf[i][x];
            if (gg == 0) continue;

            int prob = rem[i];
            prob *= inv(gg);
            prob %= mod;

            gpb += prob;
            gpb %= mod;
            rpb[i] += prob;
            rpb[i] %= mod;
        }
    }

    ans %= mod;
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (Python)
import sys
input = sys.stdin.readline
mod = 10**9 + 7
for _ in range(int(input())):
    n = int(input())
    vals = []
    for i in range(n):
        a = list(map(int, input().split()))[1:]
        for x in a: vals.append((x, i))
    vals.sort()
    m = len(vals)

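    first = pow(m, mod-2, mod)  # 1/M: probability of being picked first

    # Pass 1, decreasing values: after[i] = C_i, the number of valid successors of element i
    more = 0      # number of elements strictly greater than the current value
    seen = [0]*n  # the same count, split by source array
    after = [0]*m
    i = m - 1
    while i >= 0:
        # answer the whole group of equal values first...
        j = i
        while j >= 0 and vals[i][0] == vals[j][0]:
            after[j] = more - seen[vals[j][1]]
            j -= 1
        # ...and only then count it as "greater" for the smaller values
        j = i
        while j >= 0 and vals[i][0] == vals[j][0]:
            more += 1
            seen[vals[j][1]] += 1
            j -= 1
        i = j

    # store 1/C_i instead; pow(0, mod-2, mod) == 0, so elements with C_i = 0 add nothing later
    for i in range(m):
        after[i] = pow(after[i], mod-2, mod)

    # Pass 2, increasing values: dp[i] = P_i
    pref = 0      # sum of P_j / C_j over all processed j
    seen = [0]*n  # the same sum, split by source array
    dp = [first]*m
    i = 0
    while i < m:
        j = i
        while j < m and vals[i][0] == vals[j][0]:
            dp[j] = (dp[j] + pref - seen[vals[j][1]])%mod
            j += 1
        j = i
        while j < m and vals[i][0] == vals[j][0]:
            pref = (pref + dp[j] * after[j]) % mod
            seen[vals[j][1]] = (seen[vals[j][1]] + dp[j] * after[j]) % mod
            j += 1
        i = j

    # linearity of expectation: the answer is the sum of P_i * A_i
    ans = 0
    for i in range(m):
        ans += dp[i] * vals[i][0] % mod
    print(ans % mod)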

UPDATE:

Bad comment deleted. Sorry.

The first number in each line is the size of the respective array, so the first example is A=[10] and B=[4].

Hi, 1 is the size of the array, not an element of it; the two arrays are A=[10], B=[4].

Sorry, my bad. Let me delete the comment above, it is useless. Thanks for the quick reply.

Why don’t you sort the merged array? Isn’t it OK for it to be sorted?
