PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: kingmessi
Tester: watoac2001
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
DFS
PROBLEM:
You’re given a tree on N vertices, rooted at vertex 1.
Color each of its vertices either red or blue, such that the quantity |c_R - c_B| + |s_R - s_B| is minimized, where:
- c_R is the number of red vertices.
- s_R is the number of subtrees with strictly more red vertices than blue, or with an equal number of red and blue vertices but with the root being red.
- c_B and s_B are defined analogously for blue.
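As a concrete reference for the objective, here is a minimal sketch that evaluates |c_R - c_B| + |s_R - s_B| for a given coloring. The function name `objective`, the 0-indexed vertices rooted at 0, the `children` list representation and the 'R'/'B' characters are assumptions for illustration, not the problem's input format.

```cpp
#include <cassert>
#include <cstdlib>
#include <string>
#include <vector>

// Returns |c_R - c_B| + |s_R - s_B| for a coloring of a rooted tree.
// Assumptions (illustration only): vertices 0..n-1, root 0,
// children[v] lists the children of v, color[v] is 'R' or 'B'.
long long objective(const std::vector<std::vector<int>>& children,
                    const std::string& color) {
    int n = (int)color.size();
    assert((int)children.size() == n && n >= 1);
    // Preorder via an explicit stack: each parent appears before its children.
    std::vector<int> order, stk = {0};
    while (!stk.empty()) {
        int v = stk.back();
        stk.pop_back();
        order.push_back(v);
        for (int c : children[v]) stk.push_back(c);
    }
    assert((int)order.size() == n);     // the input must be a tree rooted at 0
    std::vector<long long> diff(n, 0);  // diff[v] = (#red - #blue) in subtree of v
    long long cDiff = 0, sDiff = 0;     // c_R - c_B and s_R - s_B
    // Reverse preorder: all children of v are finished before v itself.
    for (int i = n - 1; i >= 0; --i) {
        int v = order[i];
        int own = (color[v] == 'R') ? 1 : -1;
        diff[v] += own;
        for (int c : children[v]) diff[v] += diff[c];
        cDiff += own;
        // Red wins v's subtree: strictly more red, or a tie with a red root.
        bool redWins = diff[v] > 0 || (diff[v] == 0 && color[v] == 'R');
        sDiff += redWins ? 1 : -1;
    }
    return std::llabs(cDiff) + std::llabs(sDiff);
}
```

For example, with `children = {{1}, {}}` (a single edge), `objective(children, "RB")` is 0 and `objective(children, "RR")` is 4.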
EXPLANATION:
Looking at the quantity we want to minimize,
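- If N is even, we have a chance of making it 0: this needs an equal number of red and blue vertices, and red and blue each ‘winning’ an equal number of subtrees.
- If N is odd, the best we can do is 2 instead. Every vertex is red or blue, so c_R + c_B = N. Every vertex roots exactly one subtree and each subtree is won by exactly one color, so s_R + s_B = N as well. With N odd, both differences are odd, so each term is at least 1.
We'll show that these lower bounds are always attainable.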
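For small trees, both bounds can be sanity-checked by exhaustive search. The sketch below tries all 2^N colorings, so it is only practical for N up to about 20. The function name `bruteMinimum` and the parent-array convention (root 0, par[v] < v) are assumptions for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <vector>

// Brute force: minimum of |c_R - c_B| + |s_R - s_B| over all 2^n colorings.
// Assumption (illustration only): root 0 and par[v] < v for every v >= 1,
// so scanning v = n-1, ..., 0 finishes every child before its parent.
int bruteMinimum(const std::vector<int>& par) {
    int n = (int)par.size();
    assert(n >= 1 && n <= 20);
    for (int v = 1; v < n; ++v) assert(par[v] >= 0 && par[v] < v);
    int best = 1 << 30;
    for (int mask = 0; mask < (1 << n); ++mask) {  // bit v set: vertex v is red
        std::vector<int> diff(n, 0);                // (#red - #blue) per subtree
        int cDiff = 0, sDiff = 0;
        for (int v = n - 1; v >= 0; --v) {
            bool red = (mask >> v) & 1;
            int own = red ? 1 : -1;
            diff[v] += own;                         // children were already added
            cDiff += own;
            bool redWins = diff[v] > 0 || (diff[v] == 0 && red);
            sDiff += redWins ? 1 : -1;
            if (v > 0) diff[par[v]] += diff[v];     // pass the total up to the parent
        }
        best = std::min(best, std::abs(cDiff) + std::abs(sDiff));
    }
    return best;
}
```

On a path of three vertices (`{-1, 0, 1}`) it returns 2. On a path of four vertices (`{-1, 0, 1, 2}`) and on a star with three leaves (`{-1, 0, 0, 0}`) it returns 0, matching the bounds.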
Suppose we want to decide the color of vertex u.
First, let’s color each child subtree of u recursively.
We now know that:
- Each child with even size contributes 0 to the total answer, so we simply ignore them from now on.
- Each child with odd size contributes 2 to the total answer.
- For simplicity, let’s assume this 2 comes from having both c_R\gt c_B and s_R\gt s_B.
So, if there are k odd children, our answer is currently 2k, because we have k extra red vertices compared to blue, and k extra subtrees where red wins over blue.
However, notice that if you take some odd child, and flip all the colors in its subtree,
- The number of extra red vertices is now k-2.
- The number of extra red subtree wins is also now k-2.
So, if we choose \lfloor \frac{k}{2} \rfloor of the odd children and flip their colors, we’ll be left with an answer of either 0 or 2 (0 if k is even, 2 if k is odd).
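The bookkeeping for the flips can be written out explicitly. Here j is a symbol introduced only for this derivation: it is the number of odd children whose colors were flipped. The differences count only vertices and subtrees inside the children's subtrees. Even children contribute 0 to both.

```latex
\begin{aligned}
(c_R - c_B)_{\text{children}} &= (k - j) - j = k - 2j, \\
(s_R - s_B)_{\text{children}} &= (k - j) - j = k - 2j, \\
j = \left\lfloor \tfrac{k}{2} \right\rfloor \;\Longrightarrow\; k - 2j &= k \bmod 2 \in \{0, 1\}.
\end{aligned}
```

So the children contribute 2(k \bmod 2) in total. With j = 1, this gives the k - 2 from the single-flip case above.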
This is pretty close to what we want, the only thing remaining is to choose a color for u.
Further, observe that as a result of this process, the numbers of red and blue vertices in the subtree of u, excluding u itself, differ by at most one.
This means whichever color we give to u will also win the subtree of u.
If the current answer is 0, it doesn’t matter what the color of u is: whatever we color it, we’ll have one extra vertex of that color and one extra subtree win of that color.
So, in such a case, we color it red (since, for simplicity, our recursive formulation assumed that odd-sized subtrees always have more reds than blues).
If the current answer is 2, we instead color it blue: this equalizes the number of red and blue vertices, and the number of red and blue subtree wins, making the overall answer 0.
Performing this process starting from the root will indeed give us a valid coloring whose value is either 0 or 2 (depending on parity), and we’re done!
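A direct sketch of this recursive process, with the flips actually performed, could look as follows. The names `DirectColoring` and `colorTreeDirect`, the 0-indexed children lists rooted at 0, and the rule of flipping every second odd child (which flips \lfloor \frac{k}{2} \rfloor of them) are assumptions for illustration. Because `flipSubtree` walks entire subtrees, this version can take quadratic time.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Direct version of the construction, with the flips actually performed.
// Assumptions (illustration only): vertices 0..n-1, root 0, and
// children[v] lists the children of v. Worst case O(N^2) time.
struct DirectColoring {
    std::vector<std::vector<int>> ch;  // ch[v] = children of v
    std::string col;                   // color of each vertex, 'R' or 'B'
    std::vector<int> sz;               // subtree sizes

    explicit DirectColoring(std::vector<std::vector<int>> children)
        : ch(std::move(children)), col(ch.size(), '?'), sz(ch.size(), 0) {}

    // Swaps red and blue on every vertex in the subtree of v.
    void flipSubtree(int v) {
        col[v] = (col[v] == 'R') ? 'B' : 'R';
        for (int c : ch[v]) flipSubtree(c);
    }

    // After solve(v): if sz[v] is odd, red has one extra vertex and one extra
    // subtree win inside v's subtree; if sz[v] is even, both counts are tied.
    void solve(int v) {
        sz[v] = 1;
        int k = 0;  // number of odd-size children seen so far
        for (int c : ch[v]) {
            solve(c);
            sz[v] += sz[c];
            if (sz[c] % 2 == 1) {
                ++k;
                if (k % 2 == 0) flipSubtree(c);  // flips the 2nd, 4th, ... odd child
            }
        }
        // The descendants now leave red ahead by k % 2 (0 or 1) in both counts.
        col[v] = (k % 2 == 0) ? 'R' : 'B';
    }
};

std::string colorTreeDirect(std::vector<std::vector<int>> children) {
    assert(!children.empty());
    DirectColoring dc(std::move(children));
    dc.solve(0);
    return dc.col;
}
```

For instance, on a star whose center 0 has three leaves (`{{1, 2, 3}, {}, {}, {}}`), it returns "BRBR". The second leaf is flipped, and the center is blue because k = 3 is odd.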
The only remaining issue is implementation: while the algorithm is correct, it’s quadratic in its current form because of all the subtree flipping we do.
There are a couple of ways to optimize this.
Method 1 (linear)
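Notice that simply computing the initial color of each vertex doesn’t actually require us to do any flipping at all. We can simply pretend that some child subtrees had their colors flipped, which is enough to determine whether to color u red or blue. In fact, k has the same parity as the total size of u’s child subtrees, which is one less than the subtree size of u. So u is initially colored red exactly when its subtree size is odd, and blue when it is even.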
Now, suppose we’ve determined the initial colors of each vertex.
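If we’d actually performed the flips, the final color of u depends only on:
- The initial color of u; and
- The parity of the number of flipped subtrees containing u, that is, the number of vertices on the path from the root to u (u itself included) whose subtrees were chosen to be flipped.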
So, when determining the initial colors of each vertex, let’s also mark which subtrees must be flipped (without actually performing the flips, of course).
This is still \mathcal{O}(N) time.
Then, run a second DFS, this time maintaining the parity of the number of vertices on the path from the root whose subtrees were marked to be flipped.
This parity tells you whether each vertex’s final color is the opposite of its initial one, hence solving the problem in \mathcal{O}(N) time.
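Method 1 can also be written without recursion, which sidesteps deep recursion on path-like trees. This is a restructured sketch, not the editorialist's code below. The BFS order, the name `colorTreeLinear`, and the 0-indexed adjacency list rooted at 0 are assumptions. It marks every second odd child in processing order. That still flips \lfloor \frac{k}{2} \rfloor odd children per vertex, and any such choice works.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Method 1 without recursion: a BFS order replaces the two DFS passes.
// Assumptions (illustration only): vertices 0..n-1, adj is an undirected
// adjacency list of a tree, and the root is 0.
std::string colorTreeLinear(const std::vector<std::vector<int>>& adj) {
    int n = (int)adj.size();
    assert(n >= 1);
    std::vector<int> order{0}, par(n, -1), sz(n, 1), flip(n, 0), oddSeen(n, 0);
    // BFS: every parent appears in `order` before its children.
    for (int i = 0; i < (int)order.size(); ++i) {
        int u = order[i];
        for (int v : adj[u])
            if (v != par[u]) { par[v] = u; order.push_back(v); }
    }
    // Bottom-up pass: subtree sizes, and mark every second odd child of each
    // vertex, i.e. floor(k/2) of its k odd children, to be flipped.
    for (int i = n - 1; i > 0; --i) {
        int v = order[i];
        sz[par[v]] += sz[v];
        if (sz[v] % 2 == 1) {
            ++oddSeen[par[v]];
            if (oddSeen[par[v]] % 2 == 0) flip[v] = 1;
        }
    }
    // Top-down pass: the initial color is red iff the subtree size is odd; the
    // final color toggles once per marked vertex on the root-to-v path.
    std::string col(n, '?');
    std::vector<int> parity(n, 0);
    for (int v : order) {
        parity[v] = flip[v] ^ (v == 0 ? 0 : parity[par[v]]);
        bool red = (sz[v] % 2 == 1) != (parity[v] == 1);
        col[v] = red ? 'R' : 'B';
    }
    return col;
}
```

On the path 0-1-2-3 (`{{1}, {0, 2}, {1, 3}, {2}}`) it produces "BRBR", where both the vertex counts and the subtree wins are balanced.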
Method 2 (N log N)
When choosing which child subtrees to flip, we can be a bit smarter: instead of choosing \lfloor \frac{k}{2} \rfloor of them arbitrarily, choose the \lfloor \frac{k}{2} \rfloor smallest ones.
This simple-looking optimization brings the complexity of our ‘brute force’ down to \mathcal{O}(N\log N)!
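A sketch of Method 2 keeps the explicit flips but recolors only the \lfloor \frac{k}{2} \rfloor smallest odd children. The struct name `SmallestFirst` and the `recolors` counter are hypothetical. The counter is added only to observe how much flipping happens. `std::stable_sort` is used so that ties between equal-sized children are broken by their order in the children list, which keeps the output deterministic. Sorting the children adds \mathcal{O}(N\log N) in total.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Method 2: explicit flips, but only of the floor(k/2) smallest odd children.
// Assumptions (illustration only): vertices 0..n-1, root 0, children lists;
// `recolors` counts single-vertex recolorings performed by flipSubtree.
struct SmallestFirst {
    std::vector<std::vector<int>> ch;  // ch[v] = children of v
    std::string col;                   // 'R' / 'B' per vertex
    std::vector<int> sz;               // subtree sizes
    long long recolors = 0;

    explicit SmallestFirst(std::vector<std::vector<int>> children)
        : ch(std::move(children)), col(ch.size(), '?'), sz(ch.size(), 0) {}

    void flipSubtree(int v) {
        col[v] = (col[v] == 'R') ? 'B' : 'R';
        ++recolors;
        for (int c : ch[v]) flipSubtree(c);
    }

    void solve(int v) {
        sz[v] = 1;
        std::vector<int> odd;  // odd-size children of v
        for (int c : ch[v]) {
            solve(c);
            sz[v] += sz[c];
            if (sz[c] % 2 == 1) odd.push_back(c);
        }
        // Smallest odd subtrees first; only the first floor(k/2) are flipped.
        std::stable_sort(odd.begin(), odd.end(),
                         [&](int a, int b) { return sz[a] < sz[b]; });
        int k = (int)odd.size();
        for (int i = 0; i < k / 2; ++i) flipSubtree(odd[i]);
        col[v] = (k % 2 == 0) ? 'R' : 'B';
    }
};
```

For example, with children lists `{{2, 1, 5}, {}, {3, 4}, {}, {}, {}}`, only vertex 3 (inside the subtree of 2) and leaf 1 (a smallest odd child of the root) are recolored. So `recolors` ends at 2 and the coloring is "BBRBRR".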
Proof
Consider some vertex u, and let’s look at the number of times it’s flipped.
Clearly, it can only be flipped by one of its ancestors.
Let the ancestors that flip it be v_1, v_2, \ldots, v_m, in descending order of depth.
Let c_i be the child of v_i that contains u, and s_i be the subtree size of c_i.
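Now, since c_i was chosen to be flipped, v_i must also have an odd child that was not flipped and whose subtree size is at least s_i. We flip only the \lfloor \frac{k}{2} \rfloor smallest odd children, so at least as many odd children, each no smaller, remain unflipped.
Both of these subtrees lie inside the subtree of v_i, which in turn lies inside the subtree of c_{i+1} (for i < m).
This means s_{i+1} \geq 2s_i + 1, so s_i \geq 2^i - 1. Since s_m \leq N - 1, this doubling shows that m \leq \log_2{N}.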
So, this simple-looking optimization ensures that each vertex is flipped \mathcal{O}(\log N) times, bringing the overall complexity down to \mathcal{O}(N\log N).
TIME COMPLEXITY:
\mathcal{O}(N) per testcase with Method 1, or \mathcal{O}(N\log N) with Method 2.
CODE:
Author's code (C++)
//Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define di(a) int a;cin>>a;
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define sis string s;
#define sin string s;cin>>s;
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x)
#define btz(x) __builtin_ctz(x)
using namespace std;
using namespace __gnu_pbds;
typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;
const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
int power( int N, int M){
int power = N, sum = 1;
if(N == 0) sum = 0;
while(M > 0){if((M & 1) == 1){sum *= power;}
power = power * power;M = M >> 1;}
return sum;
}
const int N=200005;
vector<int> adj[N];
set<int> ch[N];
bool visited[N];
int par[N];
vector<int> depth(N);
void dfs(int current){
visited[current]=true;
for(int next_vertex : adj[current]){
if(visited[next_vertex])continue;
ch[current].insert(next_vertex);
par[next_vertex] = current;
depth[next_vertex] = depth[current]+1;
dfs(next_vertex);
}
}
void solve()
{
int n;
cin >> n;
rep(i,0,n+1)visited[i] = 0,adj[i].clear(),ch[i].clear();
rep(i,0,n-1){
int u,v;
cin >> u >> v;
adj[u].pb(v);
adj[v].pb(u);
}
dfs(1);
set<pii> s;
rep(i,1,n+1)s.insert({depth[i],i});
vi ans(n+1,-1);
ch[0].insert(1);
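// Repeatedly take the deepest remaining vertex nd and color it red (1).
// It is paired with one other remaining vertex colored blue (0):
// a remaining sibling if there is one, otherwise its parent.
while(s.size()){
auto it = s.end();
it--;
auto [d,nd] = (*it);
s.erase(it);
ans[nd] = 1;
if(s.size() == 0){
// nd was the last vertex left (this only happens when n is odd)
break;
}
ch[par[nd]].erase(nd);
if(ch[par[nd]].size()){
// pair nd with a remaining sibling x, colored blue
auto gt = ch[par[nd]].begin();
int x = (*gt);
ans[x] = 0;
ch[par[nd]].erase(gt);
s.erase({depth[x],x});
}
else{
// no sibling left: pair nd with its parent, colored blue,
// and detach the parent from its own parent
ans[par[nd]] = 0;
ch[par[par[nd]]].erase(par[nd]);
s.erase({depth[par[nd]],par[nd]});
}
}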
rep(i,1,n+1){
if(ans[i] > 1 || ans[i] < 0){assert(false);}
(ans[i]?cout << "R":cout << "B");
}cout << "\n";
}
signed main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
#ifdef NCR
init();
#endif
#ifdef SIEVE
sieve();
#endif
di(t)
while(t--)
solve();
return 0;
}
Tester's code (C++)
//****************************Template Begins****************************//
// Header Files
#include <bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<vector>
#include<utility>
#include<set>
#include<unordered_set>
#include<list>
#include<iterator>
#include<deque>
#include<queue>
#include<stack>
#include<set>
#include<bitset>
#include<map>
#include<unordered_map>
#include<stdio.h>
#include<complex>
#include<math.h>
#include<chrono>
#include<cstring>
#include<string>
// Header Files End
using namespace std;
#define fio ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)
#define ll int
#define umap unordered_map
#define uset unordered_set
#define lb lower_bound
#define ub upper_bound
#define fo(i,a,b) for(i=a;i<b;i++)
#define all(v) (v).begin(),(v).end()
#define all1(v) (v).begin()+1,(v).end()
#define allr(v) (v).rbegin(),(v).rend()
#define allr1(v) (v).rbegin()+1,(v).rend()
#define sort0(v) sort(all(v))
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<ll> vll;
typedef pair<ll, ll> pll;
#define max3(a,b,c) max(max((a),(b)),(c))
#define max4(a,b,c,d) max(max((a),(b)),max((c),(d)))
#define min3(a,b,c) min(min((a),(b)),(c))
#define min4(a,b,c,d) min(min((a),(b)),min((c),(d)))
#define pb push_back
#define ppb pop_back
#define mp make_pair
#define inf 9999999999999
#define endl '\n'
#include "ext/pb_ds/assoc_container.hpp"
#include "ext/pb_ds/tree_policy.hpp"
using namespace __gnu_pbds;
template<class T>
using ordered_set = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update> ;
template<class key, class value, class cmp = std::less<key>>
// find_by_order(k) returns iterator to kth element starting from 0;
// order_of_key(k) returns count of elements strictly smaller than k;
using ordered_map = tree<key, value, cmp, rb_tree_tag, tree_order_statistics_node_update>;
const ll mod1 = 998244353;
const ll mod = 1e9 + 7;
const ll MOD = 1e18 + 1e16;
ll mod_mul(ll a, ll b) {a = a % mod; b = b % mod; return (((a * b) % mod) + mod) % mod;}
ll inv(ll i) {if (i == 1) return 1; return (mod - ((mod / i) * inv(mod % i)) % mod) % mod;}
ll gcd(ll a, ll b) { if (b == 0) return a; return gcd(b, a % b);}
ll pwr(ll a, ll b) {a %= mod; ll res = 1; while (b > 0) {if (b & 1) res = res * a % mod; a = a * a % mod; b >>= 1;} return res;}
//****************************Template Ends*******************************//
//****************************Functions*******************************//
//sparse table
class STable
{
public:
ll LOG = 0;
ll n;
vector<vector<ll>> T;
STable(vector<ll> &a)
{
n = (ll)a.size();
while (1 << (LOG + 1) <= n)
LOG++;
T.resize(LOG + 1);
for (ll i = 0; i <= LOG; i++)
T[i].resize(n);
for (ll i = 0; i < n; i++)
T[0][i] = a[i];
for (ll i = 1; i <= LOG; i++)
{
for (ll j = 0; j < n - (1 << i) + 1; j++)
{
T[i][j] = combine(T[i - 1][j], T[i - 1][j + (1 << (i - 1))]);
}
}
}
ll combine(ll a, ll b)
{
return max(a, b);
}
ll qry(ll l, ll r)
{
ll exp = 0;
while (1 << (exp + 1) <= r - l + 1)
exp++;
return combine(T[exp][l], T[exp][r - (1 << exp) + 1]);
}
};
ll findMinNumber(ll n)
{
ll count = 0, ans = 1;
// Since 2 is the only even prime, compute its
// power separately.
while (n % 2 == 0)
{
count++;
n /= 2;
}
// If count is odd, it must be removed by dividing
// n by prime number.
if (count % 2)
ans *= 2;
for (ll i = 3; i <= sqrt(n); i += 2)
{
count = 0;
while (n % i == 0)
{
count++;
n /= i;
}
// If count is odd, it must be removed by
// dividing n by prime number.
if (count % 2)
ans *= i;
}
if (n > 2)
ans *= n;
return ans;
}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
const auto start_time = chrono::high_resolution_clock::now();
void output_run_time() {
// will work for ac,cc&&cf.
#ifndef ONLINE_JUDGE
auto end_time = chrono::high_resolution_clock::now();
chrono::duration<double> diff = end_time - start_time;
cout << "\n\n\nTime Taken : " << diff.count();
#endif
}
ll n, cnt = 0, mcnt = 0;
vector<vll> adj(200010);
vll ans(200010), val(200010), d(200010), par(200010);
vector<set<ll>> cc(200010);
void dfs(ll v, ll p)
{
for (ll u : adj[v])
{
if (u == p)
continue;
par[u] = v;
cc[v].insert(u);
d[u] = d[v] + 1;
dfs(u, v);
}
}
void get_score(ll v, ll p, ll &ans1)
{
if (val[v] < 0)
{
ans1--;
}
else if (val[v] > 0)
{
ans1++;
}
else
ans1 += ans[v];
for (ll u : adj[v])
{
if (u == p)
continue;
get_score(u, v, ans1);
}
}
int main() {
fio;
ll t;
cin >> t;
while (t--)
{
cin >> n;
ll x, y, i;
ll a[n][2];
cnt = 0;
adj.assign(n + 1, vll(0));
val.assign(n + 1, 0ll);
ans.assign(n + 1, 0ll);
fo(i, 1, n + 1)
{
cc[i].clear();
par[i] = -1;
}
fo(i, 1, n)
{
cin >> x >> y;
a[i][0] = x;
a[i][1] = y;
adj[x].pb(y);
adj[y].pb(x);
}
mcnt = 0;
d[1] = 0;
dfs(1, -1);
set<pll> s;
fo(i, 1, n + 1)
{
s.insert({ -d[i], i});
}
// int gg = 0;
while (s.size())
{
ll x = s.begin()->second;
ans[x] = 1;
s.erase({ -d[x], x});
if (s.size() == 0)
break;
cc[par[x]].erase(x);
if (cc[par[x]].size())
{
ll y = *cc[par[x]].begin();
cc[par[x]].erase(y);
ans[y] = -1;
s.erase({ -d[y], y});
}
else
{
ans[par[x]] = -1;
s.erase({ -d[par[x]], par[x]});
if (par[par[x]] != -1)
cc[par[par[x]]].erase(par[x]);
}
}
fo(i, 1, n + 1)
{
if (ans[i] == 1)
cout << 'B';
else
cout << 'R';
}
cout << endl;
}
// output_run_time();
return 0;
}
//remove #define endl '\n' for interactive problems
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
int main()
{
ios::sync_with_stdio(false); cin.tie(0);
int t; cin >> t;
while (t--) {
int n; cin >> n;
vector adj(n, vector<int>());
for (int i = 0; i < n-1; ++i) {
int u, v; cin >> u >> v;
adj[--u].push_back(--v);
adj[v].push_back(u);
}
vector<int> flip(n), subsz(n), col(n);
auto dfs = [&] (const auto &self, int u, int p) -> void {
subsz[u] = 1;
vector<int> odd;
for (int v : adj[u]) {
if (v == p) continue;
self(self, v, u);
subsz[u] += subsz[v];
if (subsz[v] & 1) odd.push_back(v);
}
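int k = odd.size();
// mark floor(k/2) of the odd children to be flipped (any choice works)
for (int i = 0; i < k/2; ++i) flip[odd[i]] = 1;
if (subsz[u]%2 == 0) {
// k is odd: ignoring flips above u, the children leave red ahead by one
// vertex and one subtree win, so color u blue (2) to balance both
col[u] = 2;
}
else {
// k is even: the children are balanced,
// so color u red (1), giving its subtree one extra red vertex and win
col[u] = 1;
}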
};
dfs(dfs, 0, 0);
// change = parity of the flip marks on the path from the root to u (u included)
auto fix = [&] (const auto &self, int u, int p, int change = 0) -> void {
if (change) col[u] = 3 - col[u];
for (int v : adj[u]) {
if (v == p) continue;
self(self, v, u, change ^ flip[v]);
}
};
fix(fix, 0, 0);
for (int i = 0; i < n; ++i) {
if (col[i] == 1) cout << 'R';
else cout << 'B';
}
cout << '\n';
}
}