EXPCOMP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: prince_patel_8
Tester: wasd2401
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics, linearity of expectation

PROBLEM:

You’re given a tree on N vertices. Each vertex is either black or white.
For each x from 1 to K (the number of black vertices), compute the expected number of connected components remaining if you delete exactly x black vertices, chosen uniformly at random, from the tree.

EXPLANATION:

There are a couple of different ways to approach this problem, though they all rely on linearity of expectation.
That is, for any two random variables X and Y, \mathbb{E}[X+Y] = \mathbb{E}[X] + \mathbb{E}[Y].
If the concept is new to you, this might help.

One way that turns out to be particularly easy to work with is to come up with an appropriate model for connected components.
Specifically, if you remove x vertices from a tree, observe that the number of connected components will exactly equal N - x, minus the number of remaining edges.
This holds because what remains is a subgraph of a tree, i.e. a forest: starting from the N-x remaining vertices as isolated components, each remaining edge joins two different components (a forest has no cycles), so every edge reduces the component count by exactly one.
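For a concrete check of this identity, the following Python sketch (the language of the editorialist's code) counts components with a disjoint-set union and compares the count with (vertices left) minus (edges left). The helper names `count_components` and `remaining_forest` and the example tree are hypothetical, chosen only for illustration.

```python
def count_components(vertices, edges):
    """Count connected components of the graph (vertices, edges) with a DSU."""
    parent = {v: v for v in vertices}

    def find(v):
        # Path halving keeps the DSU trees shallow.
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    components = len(vertices)
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            components -= 1  # every merge joins two different components
    return components


def remaining_forest(n, edges, removed):
    """Vertices and edges left after deleting the vertex set `removed`."""
    removed = set(removed)
    vertices = [v for v in range(1, n + 1) if v not in removed]
    kept = [(u, v) for u, v in edges if u not in removed and v not in removed]
    return vertices, kept


# Example tree: path 1-2-3-4 with an extra leaf 5 hanging off vertex 2.
n = 5
edges = [(1, 2), (2, 3), (3, 4), (2, 5)]
vertices, kept = remaining_forest(n, edges, removed=[2])
print(count_components(vertices, kept), len(vertices) - len(kept))  # prints: 3 3
```

Deleting vertex 2 leaves {1}, {3, 4} and {5}: 4 vertices and 1 edge, hence 3 components either way.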

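So, for a fixed x (1 \leq x \leq K), the expected number of components equals N-x minus the expected number of remaining edges.
Since each of the N-1 edges is either removed or remains, this equals (N-x) - (N-1-\mathbb{E}[\text{removed edges}]) = 1 - x + \mathbb{E}[\text{removed edges}], so it suffices to compute the expected number of removed edges.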
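This reduction can be checked by brute force on a small instance. The sketch below assumes a hypothetical example tree and helper names (`count_components_dfs`, `check`); it enumerates every x-subset of black vertices with exact `Fraction` arithmetic, counts components by an explicit DFS (so it does not rely on the identity itself), and returns both the brute-force expectation and 1 - x + \mathbb{E}[\text{removed edges}]. It is exponential in K and only meant for tiny inputs.

```python
from fractions import Fraction
from itertools import combinations


def count_components_dfs(n, edges, removed):
    """Count components of the tree minus `removed` by an explicit DFS."""
    adj = {v: [] for v in range(1, n + 1)}
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    seen = set(removed)  # deleted vertices are never visited
    comps = 0
    for s in range(1, n + 1):
        if s in seen:
            continue
        comps += 1
        seen.add(s)
        stack = [s]
        while stack:
            u = stack.pop()
            for w in adj[u]:
                if w not in seen:
                    seen.add(w)
                    stack.append(w)
    return comps


def check(n, edges, black, x):
    """Return (E[components] by brute force, 1 - x + E[removed edges])."""
    subsets = [set(S) for S in combinations(black, x)]
    total = len(subsets)
    comp_sum = sum(count_components_dfs(n, edges, S) for S in subsets)
    # An edge is removed when at least one endpoint is deleted.
    removed_sum = sum(
        sum(1 for u, v in edges if u in S or v in S) for S in subsets
    )
    return Fraction(comp_sum, total), 1 - x + Fraction(removed_sum, total)


n = 5
edges = [(1, 2), (2, 3), (3, 4), (2, 5)]
black = [1, 2, 3]
for x in range(1, len(black) + 1):
    print(x, *check(n, edges, black, x))
# 1 2 2
# 2 7/3 7/3
# 3 2 2
```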

Once again, we apply linearity of expectation: the expected number of removed edges is simply the sum, over all edges, of the probability that the edge is removed.
So, consider some edge (u, v). What’s the probability that it’s removed?
Well, the edge will be removed if and only if at least one of the vertices u and v is deleted.
This tells us that:

  • If u and v are both white, the edge will never be removed, since both its endpoints always remain.
  • If u is white and v is black (or vice versa), the edge will be removed if and only if the black endpoint is deleted.
    For this to happen, we can choose any subset of size x-1 of the remaining K-1 black vertices and add this vertex to it, so the number of ways is just \binom{K-1}{x-1}.
    To get the probability, divide this by \binom{K}{x} (the total number of choices).
    Note that this simplifies to \frac{x}{K}, so you don’t even need to compute binomial coefficients.
  • If u and v are both black, the edge will be removed if at least one of them is deleted.
    It’s easier to count the complement, i.e. the ways in which neither endpoint is deleted: all x deleted vertices must then come from the other K-2 black vertices, giving \binom{K-2}{x} ways.
    Divide this by \binom{K}{x} and subtract the result from 1 to get the desired probability.
    Once again, this simplifies to 1 - \frac{(K-x)\cdot (K-x-1)}{K\cdot (K-1)}, so explicitly computing binomial coefficients isn’t needed.
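The closed forms can be double-checked against the raw binomial ratios with exact arithmetic. In this sketch `p_bw` and `p_bb` are hypothetical helper names; `p_bb` uses the complement count \binom{K-2}{x} (neither black endpoint deleted) and is only meaningful for K \geq 2, since with fewer black vertices there are no black-black edges.

```python
from fractions import Fraction
from math import comb


def p_bw(K, x):
    """P(black-white edge removed) = C(K-1, x-1) / C(K, x), simplified to x / K."""
    return Fraction(x, K)


def p_bb(K, x):
    """P(black-black edge removed) = 1 - C(K-2, x) / C(K, x), simplified."""
    return 1 - Fraction((K - x) * (K - x - 1), K * (K - 1))


# Compare the simplified forms with the binomial ratios for small K.
# math.comb(n, k) returns 0 when k > n, matching "no valid choices".
for K in range(1, 12):
    for x in range(1, K + 1):
        assert p_bw(K, x) == Fraction(comb(K - 1, x - 1), comb(K, x))
        if K >= 2:
            assert p_bb(K, x) == 1 - Fraction(comb(K - 2, x), comb(K, x))
```

For example, with K = 4 and x = 2, a black-black edge survives in \binom{2}{2} = 1 of the \binom{4}{2} = 6 choices, so it is removed with probability 5/6.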

Notice that while there are N-1 edges, there are only three types of probabilities.
Let’s call them p_{ww}, p_{bb}, p_{bw}.
Let the number of edges of each type be c_{ww}, c_{bb}, c_{bw}, respectively.

Then, the sum of probabilities across all edges is simply
p_{ww}\cdot c_{ww} + p_{bb}\cdot c_{bb} + p_{bw}\cdot c_{bw}
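A minimal sketch of the counting step and of this sum, using exact fractions rather than modular arithmetic so the values are easy to inspect. The function names and the 1-indexed edge-list input are assumptions; p_{ww} = 0 because an edge between two white vertices can never be removed, and p_{bb} is set to 0 when K < 2, where c_{bb} = 0 anyway.

```python
from fractions import Fraction


def classify_edges(n, black, edges):
    """Return (c_ww, c_bb, c_bw) for a tree given as a 1-indexed edge list."""
    is_black = [False] * (n + 1)
    for v in black:
        is_black[v] = True
    c_ww = c_bb = c_bw = 0
    for u, v in edges:
        blacks = is_black[u] + is_black[v]  # 0, 1 or 2 black endpoints
        if blacks == 0:
            c_ww += 1
        elif blacks == 2:
            c_bb += 1
        else:
            c_bw += 1
    return c_ww, c_bb, c_bw


def expected_removed_edges(counts, K, x):
    """p_ww * c_ww + p_bb * c_bb + p_bw * c_bw for a fixed x."""
    c_ww, c_bb, c_bw = counts
    p_ww = Fraction(0)
    p_bw = Fraction(x, K)
    p_bb = 1 - Fraction((K - x) * (K - x - 1), K * (K - 1)) if K >= 2 else Fraction(0)
    return p_ww * c_ww + p_bb * c_bb + p_bw * c_bw


counts = classify_edges(5, [1, 2, 3], [(1, 2), (2, 3), (3, 4), (2, 5)])
print(counts)                                  # (0, 2, 2)
print(expected_removed_edges(counts, 3, 2))    # 10/3
```

The counts depend only on the tree and the colouring, so they are computed once; only the three probabilities change with x.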

Note that the edge counts of each type can be computed once, and remain the same for all x.
Further, for a fixed x, each probability requires only a few arithmetic operations (or a couple of binomial coefficients, if you don’t simplify), plus possibly a modular inverse, so each one can be computed in constant or \mathcal{O}(\log{MOD}) time.

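So, simply apply this to each x: compute the necessary probabilities, find p_{ww}\cdot c_{ww} + p_{bb}\cdot c_{bb} + p_{bw}\cdot c_{bw} (the expected number of removed edges), and add it to 1-x to get the answer. Equivalently, subtract the expected number of remaining edges, (N-1) minus this sum, from N-x.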
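Putting the pieces together, here is a hedged sketch of the whole per-testcase computation modulo 10^9+7, the modulus used by all three reference solutions. The function name `expected_components_mod` and its arguments are hypothetical; unlike the reference codes, it uses the simplified probabilities \frac{x}{K} and 1 - \frac{(K-x)(K-x-1)}{K(K-1)}, so it needs only two modular inverses per testcase instead of inverting a binomial coefficient for each x.

```python
MOD = 10**9 + 7


def expected_components_mod(n, black, edges):
    """Answers for x = 1..K modulo MOD, computed as 1 - x + E[removed edges]."""
    K = len(black)
    is_black = [False] * (n + 1)
    for v in black:
        is_black[v] = True
    # White-white edges are never removed (p_ww = 0), so only these two matter.
    c_bb = c_bw = 0
    for u, v in edges:
        blacks = is_black[u] + is_black[v]
        if blacks == 2:
            c_bb += 1
        elif blacks == 1:
            c_bw += 1
    inv_K = pow(K, MOD - 2, MOD)  # Fermat inverse, computed once per testcase
    inv_KK1 = pow(K * (K - 1) % MOD, MOD - 2, MOD) if K >= 2 else 0  # unused if c_bb == 0
    answers = []
    for x in range(1, K + 1):
        p_bw = x * inv_K % MOD
        p_bb = (1 - (K - x) * (K - x - 1) % MOD * inv_KK1) % MOD
        removed = (c_bb * p_bb + c_bw * p_bw) % MOD
        answers.append((1 - x + removed) % MOD)  # Python's % keeps this non-negative
    return answers


print(expected_components_mod(5, [1, 2, 3], [(1, 2), (2, 3), (3, 4), (2, 5)]))
# [2, 333333338, 2]   (333333338 represents 7/3 modulo 10^9+7)
```

Each x costs O(1) after the edge classification, which gives the \mathcal{O}(N) bound per testcase (plus \mathcal{O}(\log{MOD}) for the two inverses).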

TIME COMPLEXITY:

\mathcal{O}(N) or \mathcal{O}(N\log{MOD}) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
 
using namespace std;
 
#define int long long int
 
const int N = 2e5 + 10;
const int mod = (int)1e9 + 7;

vector<int> fact(N);
vector<int> inv(N);

int power(int x, int y, int p) {
    int res = 1;
    x = x % p;
    if (x == 0)
        return 0;
    while (y > 0) {
        if (y & 1)
            res = (res * x) % p;
        y = y >> 1;
        x = (x * x) % p;
    }
    return res;
}
void init() {
    fact[0] = 1;
    for (int i = 1; i < N; i++) {
        fact[i] = (fact[i - 1] % mod * i % mod) % mod;
    }
    for (int i = 0; i < N; i++) {
        inv[i] = power(fact[i], mod - 2, mod);
    }
}
int nCr(int n, int r) {
  if(r > n || r < 0 || n < 0) return 0;
  return (fact[n] % mod * inv[n - r] % mod * inv[r] % mod) % mod;
}

int mminvprime(int a, int b) {return power(a, b - 2, b);}
int mod_mul(int a, int b, int m) {a = a % m; b = b % m; return (((a * b) % m) + m) % m;}
int mod_div(int a, int b, int m) {a = a % m; b = b % m; return (mod_mul(a, mminvprime(b, m), m) + m) % m;}

int32_t main() {
  ios::sync_with_stdio(false);
  cin.tie(0);
  init();
  int tc;
  cin >> tc;
  while(tc--) {
    int n, k;
    cin >> n >> k;
    vector<int> a(n + 1);
    for(int i = 1; i <= k; i++) {
      int x;
      cin >> x;
      a[x] = 1;
    }
    vector<int> degree(n + 1);
    int extraEdgesCount = 0;
    for(int i = 0; i < n - 1; i++) {
      int u, v;
      cin >> u >> v;
      degree[u]++;
      degree[v]++;
      // both vertices are black
      if(a[u] && a[v]) {
        extraEdgesCount ++;
      }
    }
    int blackNodeDegreeSum = 0;
    int blackNodeCount = 0;
    for(int i = 1; i <= n; i++) {
      if(a[i]) {
        blackNodeCount ++;
        blackNodeDegreeSum += degree[i];
      }
    }
    vector<int> ans;
    for(int i = 1; i <= k; i++) {
        int countOfOne = nCr(blackNodeCount, i) % mod;

        int degreeSum = ((nCr(blackNodeCount - 1, i - 1) % mod) * 
        ((blackNodeDegreeSum - blackNodeCount + mod) % mod)) % mod;

        int extraEdges = (extraEdgesCount % mod * 
        (nCr(blackNodeCount - 2, i - 2) % mod)) % mod;
        assert(extraEdges >= 0);
        int expectedComponents = (countOfOne + degreeSum - extraEdges + mod) % mod;
        ans.push_back(mod_div(expectedComponents, countOfOne, mod));
    }

    for(int x: ans) {
      assert(x >= 0);
      cout << x << " ";
    }
    cout << '\n';
  }
}
Tester's code (C++)

// 2 Years Tribute to Competitive Programming

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace __gnu_pbds;
using namespace std;

#define osl tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update>
#define ll long long
#define ld long double
#define forl(i, a, b) for(ll i = a; i < b; i++)
#define rofl(i, a, b) for(ll i = a; i > b; i--)
#define fors(i, a, b, c) for(ll i = a; i < b; i += c)
#define fora(x, v) for(auto x : v)
#define vl vector<ll>
#define vb vector<bool>
#define pub push_back
#define pob pop_back
#define fbo find_by_order
#define ook order_of_key
#define yesno(x) cout << ((x) ? "YES" : "NO")
#define all(v) v.begin(), v.end()

const ll N = 2e5 + 4;
const ll mod = 1e9 + 7;
// const ll mod = 998244353;

vl v[N];
set<ll> t;
vl b(N);
vl fact(N,1);
ll modinverse(ll a) {
	ll m = mod, y = 0, x = 1;
	while (a > 1) {
		ll q = a / m;
		ll t = m;
		m = a % m;
		a = t;
		t = y;
		y = x - q * y;
		x = t;
	}
	if (x < 0) x += mod;
	return x;
}
ll gcd(ll a, ll b) {
	if (b == 0)
		return a;
	return gcd(b, a % b);
}
ll lcm(ll a, ll b) {
	return (a / gcd(a, b)) * b;
}
bool poweroftwo(ll n) {
	return !(n & (n - 1));
}
ll power(ll a, ll b, ll md = mod) {
	ll product = 1;
	a %= md;
	while (b) {
		if (b & 1) product = (product * a) % md;
		a = (a * a) % md;
		b /= 2;
	}
	return product % md;
}
ll c1,c2;
void barfi(ll n){
	b[n]=0;
	fora(x,v[n]){
		if(b[x]){
			barfi(x);
			ll x1=0,x2=0;
			if(t.count(n)) x1=1;
			if(t.count(x)) x2=1;
			if(x1+x2==2) c2++;
			if(x1+x2==1) c1++;
		}
	}
}
ll kulfi(ll n, ll r){
	if(n<r || r<0) return 0;
	ll p=modinverse(fact[r])*modinverse(fact[n-r]);
	p%=mod;
	return (p*fact[n])%mod;
}
void panipuri() {
	ll n, m = 0, k = -1, c = 0, sum = 0, q = 0, ans = 0, p = 1;
	string s;
	bool ch = true;
	cin >> n>>k;
	forl(i, 1, n+1) {
		v[i].clear();
		b[i]=1;
	}
	c1=0;
	c2=0;
	t.clear();
	vl a(k);
	forl(i,0,k){
		cin>>a[i];
		t.insert(a[i]);
	}
	forl(i,1,n){
		ll x,y;
		cin>>x>>y;
		v[x].pub(y);
		v[y].pub(x);
	}
	barfi(1);
	// cout<<c1<<' '<<c2<<'\n';
	forl(i,1,k+1){
		ll p1=kulfi(k,i)-kulfi(k-1,i)+mod;
		p1%=mod;
		ll p2=kulfi(k,i)-kulfi(k-2,i)+mod;
		p2%=mod;
		p1*=modinverse(kulfi(k,i));
		p1%=mod;
		p2*=modinverse(kulfi(k,i));
		p2%=mod;
		ans=(c1*p1+c2*p2+1-i+mod)%mod;
		// ans*=kulfi(k,i);
		// ans%=mod;
		cout<<ans<<' ';
	}
	return;
}
int main() {
	ios::sync_with_stdio(false);
	cin.tie(NULL);
	#ifndef ONLINE_JUDGE
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
	#endif
	int laddu = 1;
	cin >> laddu;
	forl(i,1,N){
		fact[i]=i*fact[i-1];
		fact[i]%=mod;
	}
	forl(i, 1, laddu + 1) {
		// cout << "Case #" << i << ": ";
		panipuri();
		cout << '\n';
	}
}
Editorialist's code (Python)
mod = 10**9 + 7
N = 2*10**5 + 100
fac = [1]*N
for i in range(1, N): fac[i] = fac[i-1] * i % mod
inv = fac[:]
for i in range(N): inv[i] = pow(fac[i], mod-2, mod)

def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod

for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    mark = [0]*n
    for x in a: mark[x-1] = 1

    bb, ww, bw = 0, 0, 0
    for i in range(n-1):
        u, v = map(int, input().split())
        if mark[u-1] == 0 and mark[v-1] == 0: ww += 1
        elif mark[u-1] == 0 or mark[v-1] == 0: bw += 1
        else: bb += 1
    for i in range(1, k+1):
        # number of components = n-i - (number of remaining edges)
        # both ends white -> always remains
        # both ends black -> remains in C(k-2, i) of the C(k, i) choices
        # one black and one white -> remains in C(k-1, i) of the C(k, i) choices
        ans = bb * C(k-2, i) % mod + bw * C(k-1, i) % mod
        ans = ans * pow(C(k, i), mod-2, mod) % mod
        ans = n - i - ww - ans
        print(ans%mod, end = ' ')
    print()

I’ve a question. First, it is mentioned that

But later,

Aren’t these contradictory statements? I believe the former is correct, since for x = 0 we have N - 1 edges, yielding 1 connected component. But then later we’re computing probability of removing edges? Whereas it should be that of keeping them?

If you look at the editorialist’s code it is much clearer. I agree you are correct. The solution is N - x - expected number of edges remaining.


Thanks!