MUSICAL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: satyam_343
Editorialist: iceknight1093

DIFFICULTY:

2831

PREREQUISITES:

Familiarity with probability and expected value

PROBLEM:

N people play a game of musical chairs with N chairs arranged in a circle. The game lasts K rounds: in round i, every person moves L_i chairs forward and sits down.
Each chair has a color between 1 and M, while the people don’t have assigned colors yet.
Colors are assigned to the people one by one over N updates; let A_x denote the color assigned to person x.

During the game, if anyone sits on a chair whose color matches their assigned color, they are eliminated.

After each color assignment, find the expected number of people remaining if the game is played, assuming everyone whose color is still unassigned picks one uniformly at random from [1, M]. The expected value is printed as a fraction modulo 10^9+7.

EXPLANATION:

For convenience, let’s label the people and chairs 0, 1, 2, \ldots, N-1.
All indices below are taken modulo N.

Consider a person x. This person will:

  • First sit on chair x+L_1
  • Then sit on chair x+L_1+L_2
  • Then sit on chair x+L_1+L_2+L_3
    \vdots
  • Finally sit on chair x+L_1+L_2+\ldots + L_K

If person x's assigned color matches the color of any of these chairs, they’ll be eliminated; otherwise they won’t.

Notice that, since there are only N chairs, the set of chairs (and hence, the set of colors of chairs) they sit on will be of size \leq N (which is bounded by 2000), even though K can be as large as 2\cdot 10^5.
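A direct simulation of one person’s path makes this concrete. The following is a minimal Python sketch; the function name `visited_chairs` and the sample values are hypothetical, chosen only for illustration:

```python
from typing import List


def visited_chairs(x: int, n: int, moves: List[int]) -> List[int]:
    """Chairs person x sits on, in order: x+L_1, x+L_1+L_2, ... (mod n)."""
    chairs = []
    pos = x
    for step in moves:
        pos = (pos + step) % n  # move L_i chairs forward around the circle
        chairs.append(pos)
    return chairs


# Hypothetical tiny instance: N = 3 chairs and K = 5 moves of length 1.
path = visited_chairs(0, 3, [1, 1, 1, 1, 1])
print(path)               # [1, 2, 0, 1, 2]
print(sorted(set(path)))  # [0, 1, 2]: at most N distinct chairs
```

The path has K entries, but only its set of distinct chairs matters, and that set never exceeds N elements.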

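So, let’s first find, for each person, which colors are ‘bad’ for them (in the sense that if they were assigned this color, they’d be eliminated).
That’s not hard to do: as mentioned above, go through all the chairs this person can reach and insert their colors into a set.
The only information we need about the moves is the set of distinct values of L_1 + L_2 + \ldots + L_i modulo N (for 1 \leq i \leq K), which can be marked once in \mathcal{O}(K) time; person x then reaches exactly the chairs x+j over the marked values j. Since there are at most N such values per person, this takes \mathcal{O}(N^2 \log N) time over all people.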
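Here is a minimal Python sketch of this step, assuming 0-indexed people and chairs and colors given as plain integers; the helper names `reachable_offsets` and `count_bad_colors` are hypothetical. It uses Python’s hash-based `set`, so the per-person work is expected \mathcal{O}(N) rather than the \mathcal{O}(N \log N) of an ordered set:

```python
from typing import List


def reachable_offsets(n: int, moves: List[int]) -> List[int]:
    """Distinct values of L_1 + ... + L_i (mod n), for 1 <= i <= K, in O(K)."""
    mark = [False] * n
    prefix = 0
    for step in moves:
        prefix = (prefix + step) % n
        mark[prefix] = True
    return [j for j in range(n) if mark[j]]


def count_bad_colors(n: int, colors: List[int], moves: List[int]) -> List[int]:
    """bad[x] = number of distinct chair colors that person x ever sits on."""
    offsets = reachable_offsets(n, moves)
    bad = []
    for x in range(n):
        # Person x reaches chair x + j for every marked offset j.
        seen = {colors[(x + j) % n] for j in offsets}
        bad.append(len(seen))
    return bad


print(count_bad_colors(3, [1, 2, 2], [1, 1]))  # [1, 2, 2]
```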

Let \text{bad}[x] denote the number of colors that are bad for person x.
We’ll now try to compute the expected number of people remaining.

By linearity of expectation, it’s enough to compute, for each person x, the probability that they remain; the answer is then the sum of these probabilities.
Computing each probability is easy: the chairs a person visits are fixed in advance, so whether x is eliminated depends only on x's own color.
For a person x:

  • If A_x has been set already, the probability is either 0 or 1, depending on whether A_x is a bad color for x or not.
  • If A_x hasn’t been set yet, the probability that x remains is the probability that a bad color isn’t picked, which is exactly \frac{M - \text{bad}[x]}{M}.
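The two cases can be written down with exact fractions. This is a minimal sketch using Python’s `fractions.Fraction`; the function `survival_probability` and the small list of people are hypothetical examples, with each person described by their set of bad colors and their assigned color (`None` if unassigned):

```python
from fractions import Fraction
from typing import Optional, Set


def survival_probability(m: int, bad_colors: Set[int], assigned: Optional[int]) -> Fraction:
    """Probability that one person is never eliminated.

    bad_colors: the distinct colors of the chairs this person visits.
    assigned:   the person's color, or None if it is still uniform on [1, M].
    """
    if assigned is not None:
        # Fixed color: the outcome is deterministic.
        return Fraction(0) if assigned in bad_colors else Fraction(1)
    # Uniform color: survive iff one of the M - bad[x] good colors is picked.
    return Fraction(m - len(bad_colors), m)


# By linearity of expectation, the expected count is the sum over people.
people = [({2}, None), ({1, 2}, None), ({1, 2}, 3)]  # hypothetical (bad set, color)
m = 3
expected = sum(survival_probability(m, b, a) for b, a in people)
print(expected)  # 2/3 + 1/3 + 1 = 2
```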

Recomputing the answer after an update is easy: only the updated person’s probability changes, going from \frac{M - \text{bad}[x]}{M} to either 0 or 1.
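To sanity-check both linearity and the single-term update on small inputs, one can compare against exhaustive enumeration. The sketch below is a hypothetical testing aid, not part of either reference solution: `brute_force` averages the survivor count over all M^{(\text{unassigned})} equally likely colorings, while `by_linearity` sums the per-person probabilities. People are 0-indexed, and `assigned` maps a person to their fixed color:

```python
from fractions import Fraction
from itertools import product
from typing import Dict, List, Set


def colors_seen(x: int, n: int, colors: List[int], moves: List[int]) -> Set[int]:
    """Colors of the chairs person x sits on, by direct simulation."""
    seen, pos = set(), x
    for step in moves:
        pos = (pos + step) % n
        seen.add(colors[pos])
    return seen


def brute_force(n: int, m: int, colors: List[int], moves: List[int],
                assigned: Dict[int, int]) -> Fraction:
    """Average number of survivors over all equally likely colorings of the free people."""
    seen = [colors_seen(x, n, colors, moves) for x in range(n)]
    free = [x for x in range(n) if x not in assigned]
    total = 0
    for choice in product(range(1, m + 1), repeat=len(free)):
        color_of = dict(assigned)
        color_of.update(zip(free, choice))
        total += sum(color_of[x] not in seen[x] for x in range(n))
    return Fraction(total, m ** len(free))


def by_linearity(n: int, m: int, colors: List[int], moves: List[int],
                 assigned: Dict[int, int]) -> Fraction:
    """Sum of per-person survival probabilities."""
    result = Fraction(0)
    for x in range(n):
        seen = colors_seen(x, n, colors, moves)
        if x in assigned:
            result += 0 if assigned[x] in seen else 1
        else:
            result += Fraction(m - len(seen), m)
    return result


# Hypothetical instance: N = M = 3, chair colors [1, 2, 2], moves [1, 1].
# Each step of the loop fixes one more person's color.
n, m, colors, moves = 3, 3, [1, 2, 2], [1, 1]
for assigned in ({}, {0: 1}, {0: 1, 1: 3}):
    exact = brute_force(n, m, colors, moves, assigned)
    assert exact == by_linearity(n, m, colors, moves, assigned)
    print(exact)  # 4/3, then 5/3, then 7/3
```

Between consecutive lines of output, only the newly assigned person’s term changes (2/3 \to 1, then 1/3 \to 1).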

In summary, our algorithm is as follows:

  • Compute \text{bad}[x] for all x. This can be done in \mathcal{O}(N^2 \log N) time.
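  • Before any updates, the answer is \sum_{x=0}^{N-1} \frac{M - \text{bad}[x]}{M}.
  • When person x's color is updated to A_x, subtract \frac{M - \text{bad}[x]}{M} from the sum, and add either 0 or 1 depending on whether A_x is a bad color for x or not.
    This check takes \mathcal{O}(N) per update by rescanning the chairs x can reach; since there are only N updates, that’s \mathcal{O}(N^2) in total.
  • Since the answer is needed modulo 10^9+7, it’s convenient to keep M times the sum (an integer) and multiply by the modular inverse of M when printing, as the editorialist’s code does.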
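Putting the steps together, here is a minimal sketch of one test case as a function, so it can be called without reading input. The name `solve_case` and its argument layout are hypothetical; people in `updates` are 1-indexed as in the input read by the reference codes. Like the tester’s `contribution` array, it keeps each person’s current term, so assigning the same person twice would also be handled:

```python
from typing import List, Tuple

MOD = 10**9 + 7


def solve_case(n: int, m: int, colors: List[int], moves: List[int],
               updates: List[Tuple[int, int]]) -> List[int]:
    """Expected number remaining after each update, modulo MOD.

    colors holds chair colors in [1, M]; updates are (person, color) pairs
    with people numbered 1..N.
    """
    # Offsets j such that some prefix sum L_1 + ... + L_i is j (mod N): O(K).
    mark = [False] * n
    prefix = 0
    for step in moves:
        prefix = (prefix + step) % n
        mark[prefix] = True
    offsets = [j for j in range(n) if mark[j]]

    # bad[x] = number of distinct colors person x sits on.
    bad = [len({colors[(x + j) % n] for j in offsets}) for x in range(n)]

    # term[x] = M * P(x remains); total = M * (expected number remaining).
    term = [m - b for b in bad]
    total = sum(term)
    inv_m = pow(m, MOD - 2, MOD)  # inverse of M via Fermat's little theorem

    answers = []
    for person, color in updates:
        x = person - 1
        total -= term[x]  # drop x's old contribution
        # x survives for sure unless some reachable chair has the same color.
        survives = all(colors[(x + j) % n] != color for j in offsets)
        term[x] = m if survives else 0
        total += term[x]
        answers.append(total % MOD * inv_m % MOD)
    return answers
```

For example, `solve_case(3, 3, [1, 2, 2], [1, 1], [(1, 1), (2, 3), (3, 2)])` returns the modular encodings of 5/3, 7/3 and 2.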

TIME COMPLEXITY:

\mathcal{O}(N^2\log N + K) per testcase.

CODE:

Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif 
#define ll long long 
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){   
        char g=getchar();
        if(g=='-'){  
            assert(fi==-1);     
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
const ll MOD=1e9+7;
vector<ll> readv(ll n,ll l,ll r){
    vector<ll> a;
    ll x;
    for(ll i=1;i<n;i++){  
        x=readIntSp(l,r);  
        a.push_back(x);   
    }
    x=readIntLn(l,r);
    a.push_back(x);
    return a;  
}
const ll MAX=3000300;   
ll sum_n=0;     
void dbug(vector<ll> a){
    for(auto t:a){
        cout<<t<<" ";
    }   
    cout<<endl; 
}
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;  
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
ll gt(ll n,ll freq,ll k){
    ll pw=(binpow(2,k,MOD-1)*freq)%(MOD-1);
    ll now=(binpow(n,pw+1,MOD)-binpow(n,freq,MOD)+MOD)*inverse(n-1,MOD);
    now%=MOD;
    return now;
}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
bool check_distinct(vector<ll> a){
    sort(a.begin(),a.end());
    ll n=a.size();
    for(ll i=1;i<n;i++){
        assert(a[i]!=a[i-1]);
    }
    return true;
}
ll g(ll x){
    return x;  
}
ll sum_k=0;
void solve(){    
    ll n=readIntSp(1,2000),m=readIntSp(1,g(2e5)),k=readIntLn(1,g(2e5));
    vector<ll> c=readv(n,1,m);
    vector<ll> l=readv(k,1,g(1e9));
    sum_n+=n;
    sum_k+=k;
    vector<ll> resultant(n,0);
    ll now=0;
    for(auto &it:c){
        it--;
    }
    // resultant[j] = 1 iff some prefix sum L_1 + ... + L_i is j (mod n)
    for(auto it:l){
        now=(now+it)%n;
        resultant[now]=1;
    }  
    ll ans=0;
    vector<ll> contribution(n,0);
    // contribution[i] = (m - number of bad colors for person i) / m, stored mod MOD
    for(ll i=0;i<n;i++){
        set<ll> found;
        for(ll j=0;j<n;j++){
            if(resultant[j]){
                found.insert(c[(i+j)%n]);
            }
        }
        ll options=m-found.size();
        contribution[i]=(options*inverse(m,MOD))%MOD;
        ans=(ans+contribution[i])%MOD;
    }
    // Each query fixes person x's color to y, so x's term becomes 0 or 1
    for(ll q=1;q<=n;q++){
        ll x=readIntSp(1,n),y=readIntLn(1,m);
        x--,y--;
        ans-=contribution[x];
        contribution[x]=1;
        for(ll j=0;j<n;j++){
            if(resultant[j]){  
                if(c[(x+j)%n]==y){
                    contribution[x]=0;
                }
            }
        }
        ans=(ans+contribution[x]+MOD)%MOD;
        cout<<ans<<" ";
    }
    cout<<"\n";
    return;  
}  
int main(){
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                              
    #ifndef ONLINE_JUDGE                   
    freopen("input.txt", "r", stdin);                                                
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif         
    ll test_cases=readIntLn(1,g(1000));
    while(test_cases--){
        solve();
    }
    assert(sum_n<=g(3000)); 
    assert(sum_k<=g(4e5));
    assert(getchar()==-1);
    return 0;
}

Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n, m, k = map(int, input().split())
    col = list(map(int, input().split()))
    moves = list(map(int, input().split()))
    bad = [0]*n
    minv = pow(m, mod-2, mod)
    
    # mark[j] = 1 iff some prefix sum L_1 + ... + L_i is congruent to j (mod n)
    mark = [0]*n
    for i in range(k):
        if i > 0: moves[i] += moves[i-1]
        mark[moves[i]%n] = 1
    
    # ans holds M times the expected number remaining, so it stays an integer
    ans = 0
    for i in range(n):
        s = set()
        for j in range(n):
            if mark[j] == 0: continue
            s.add(col[(i+j)%n])
        bad[i] = len(s)
        ans += m - bad[i]
    
    for i in range(n):
        x, y = map(int, input().split())
        x -= 1
        # Replace x's old term (M - bad[x]) with M, then remove it if y is bad for x
        ans -= m - bad[x]
        ans += m
        for j in range(n):
            if mark[j] == 0: continue
            if col[(x+j)%n] == y:
                ans -= m
                break
        print(ans * minv % mod, end = ' ')
    print()

An \mathcal{O}(N \cdot K) solution can also pass this problem.