PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Jatin Yadav
Tester: Riley Borgard
Editorialist: Aman Dwivedi
DIFFICULTY
Medium
PREREQUISITES
Tree, DFS, RMQ
PROBLEM
You are given a rooted tree with N nodes numbered 1, 2, \ldots, N. Node 1 is the root node. Some of the nodes have a token in them. In one move, you can choose a non-root node that has a token, but its parent doesn’t, and move the token from this node to its parent. What is the maximum number of moves you can make?
Note: When a token is moved out of a node, the node becomes empty, and other tokens will be able to move there.
QUICK EXPLANATION:
Run a DFS on the tree. DFS(s): first call DFS on every child v of s. Then, if node s initially had a token, do nothing; otherwise find the deepest token in the subtree of s and move it up to node s.
EXPLANATION:
The idea is that we can take the deepest token that has a free ancestor and move it to the closest (deepest) free ancestor.
Proof
Let us prove it by contradiction:
Consider an optimal solution that never moves this token. Because the token has a free ancestor, it has more ancestor vertices than ancestor tokens.
So in the end:
- It still has a free ancestor. This cannot be optimal: we could move this token up to that free ancestor and gain extra moves, which contradicts the assumption that the solution already maximizes the number of moves.
Hence, to do so we will go through all the vertices from the deepest. If there is a token in it we will try to find the closest free ancestor and move this token there.
To simplify the implementation, we can do DFS on the tree as
DFS(s): by the time we finish node s, all the nodes of the subtree rooted at s have already been explored. If node s initially had a token, we do nothing; otherwise we find the deepest token in its subtree and shift it to node s.
Subtask 1:
T\le10,N\le17
Since the value of N is so small we can try every possible combination to shift which token to which node etc and find such a combination that maximizes the number of moves.
But yes definitely it is an overkill solution.
Solution
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int, int>
#define F first
#define S second
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())
#define ld double
// Debug printer for pairs: writes the pair as "(first, second)".
template<class T,class U>
ostream& operator<<(ostream& os,const pair<T,U>& p){
    os << "(" << p.first;
    os << ", " << p.second;
    return os << ")";
}
// Debug printer for vectors: writes the elements as "{a, b, c}" ("{}" when empty).
template<class T>
ostream& operator <<(ostream& os,const vector<T>& v){
    os << "{";
    bool first = true;
    for(const auto& item : v){
        if(!first){
            os << ", ";
        }
        first = false;
        os << item;
    }
    return os << "}";
}
#ifdef LOCAL
#define cerr cout
#else
#endif
#define TRACE
#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
// Base case of trace(): prints the final "name : value" pair and ends the line.
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
    cerr << name << " : " << arg1;
    cerr << std::endl;
}
// Recursive case of trace(): `names` is the stringified, comma-separated
// argument list. Prints the text before the first comma (searching from the
// second character on) together with the first value, then recurses on the rest.
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
    const char* sep = strchr(names + 1, ',');
    const auto nameLen = sep - names;
    cerr.write(names, nameLen);
    cerr << " : " << arg1 << " | ";
    __f(sep + 1, args...);
}
#else
#define trace(...)
#endif
// Strict validator-style integer reader.
// Reads characters from stdin up to and including the terminator `endd` and
// returns the parsed value, asserting that it lies in [l, r].
// Accepted syntax: an optional single '-' before the first digit, at least one
// digit, no leading zeros, no "-0", at most 19 digits (a 19-digit number must
// start with 1 so it always fits in long long). Any other character, including
// end of input, trips an assertion.
long long readInt(long long l,long long r,char endd){
    long long x = 0;
    int cnt = 0;          // digits consumed so far
    int fi = -1;          // first digit, -1 while none has been read
    bool is_neg = false;
    while(true){
        // Keep the int result of getchar() so EOF stays distinct from every byte.
        const int g = getchar();
        if(g == '-'){
            assert(fi == -1);   // the sign must precede all digits
            assert(!is_neg);    // and may appear only once
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9'){
            if(cnt == 0){
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            // Check the length BEFORE accumulating so x can never overflow.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + (g - '0');
        } else if(g == static_cast<unsigned char>(endd)){
            assert(cnt > 0);    // an empty token is not a number
            if(is_neg){
                x = -x;
            }
            if(!(l <= x && x <= r)) cerr << l << "<=" << x << "<=" << r << endl;
            assert(l <= x && x <= r);
            return x;
        } else {
            assert(false);      // unexpected character or end of input
        }
    }
}
// Reads characters from stdin up to (and consuming) the terminator `endd` and
// returns them without the terminator. Asserts that input does not end before
// the terminator and that the length lies in [l, r].
string readString(int l,int r,char endd){
    string ret;
    int cnt = 0;
    while(true){
        // getchar() returns int: storing it in a char would make byte 0xFF
        // look like EOF (signed char) or make EOF undetectable (unsigned char).
        const int g = getchar();
        assert(g != EOF);
        if(g == static_cast<unsigned char>(endd)){
            break;
        }
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
// Reads an integer in [l, r] that is terminated by a single space (delegates to readInt).
long long readIntSp(long long l,long long r){
    const char terminator = ' ';
    return readInt(l, r, terminator);
}
// Reads an integer in [l, r] that is terminated by a newline (delegates to readInt).
long long readIntLn(long long l,long long r){
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
// Reads a newline-terminated string whose length is in [l, r] (delegates to readString).
string readStringLn(int l,int r){
    const char terminator = '\n';
    return readString(l, r, terminator);
}
// Reads a space-terminated string whose length is in [l, r] (delegates to readString).
string readStringSp(int l,int r){
    const char terminator = ' ';
    return readString(l, r, terminator);
}
// Reads n integers in [l, r] from one line: the first n-1 are space-terminated,
// the last one is newline-terminated. Reads nothing when n is 0.
template<class T>
vector<T> readVector(int n, long long l, long long r){
    vector<T> values(n);
    const int last = n - 1;
    for(int idx = 0; idx < last; ++idx){
        values[idx] = readIntSp(l, r);
    }
    if(n > 0){
        values[last] = readIntLn(l, r);
    }
    return values;
}
const int SN = 1000;
// O(N 2^N)
// Brute-force bitmask DP for small n.
//   n    - number of nodes, 0-indexed, node 0 is the root
//   par  - par[i] is the parent of node i (must satisfy 0 <= par[i] < i); par[0] is ignored
//   type - type[i] == '1' iff node i initially holds a token
// Returns (sum of initial token depths) - dp[0][all tokens], where dp is the
// minimum defined below; the root has depth 0.
int get(int n, vector<int> par, string type){
    vector<vector<int>> kids(n);
    for(int v = 1; v < n; ++v){
        assert(par[v] >= 0 && par[v] < v);
        kids[par[v]].push_back(v);
    }

    vector<int> depth(n);
    vector<int> masks(n);   // masks[s]: bitmask of token nodes inside the subtree of s
    int depthSum = 0;       // total depth of all initial tokens
    // anc[u][v]: u is a proper ancestor of v, or u == v (for v != 0).
    vector<vector<bool>> anc(n, vector<bool>(n, false));
    // eligible[mask][i]: i belongs to mask and no lower-numbered member of mask
    // is an ancestor of i (ancestors always have smaller indices here).
    vector<vector<bool>> eligible(1 << n, vector<bool>(n, false));
    for(int v = 1; v < n; ++v){
        depth[v] = depth[par[v]] + 1;
        depthSum += (type[v] - '0') * depth[v];
        anc[0][v] = true;
        for(int u = v; u != 0; u = par[u]){
            anc[u][v] = true;
        }
    }

    const int total = 1 << n;
    for(int mask = 0; mask < total; ++mask){
        for(int i = 0; i < n; ++i){
            if(((mask >> i) & 1) == 0) continue;
            bool ok = true;
            for(int j = 0; j < i; ++j){
                if(((mask >> j) & 1) && anc[j][i]) ok = false;
            }
            eligible[mask][i] = ok;
        }
    }

    const int INF = 1 << 29;
    // dp[s][sub]: for a non-empty subset `sub` of the tokens in s's subtree,
    // the minimum over eligible i in sub of
    //   depth[s] + sum over children v of dp[v][(sub without i) & masks[v]].
    vector<vector<int>> dp(n, vector<int>(1 << n, INF));
    for(int s = n - 1; s >= 0; --s){
        if(type[s] == '1') masks[s] |= 1 << s;
        for(int v : kids[s]){
            masks[s] |= masks[v];
        }
        dp[s][0] = 0;
        // Enumerate every non-empty submask of masks[s].
        for(int sub = masks[s]; sub != 0; sub = (sub - 1) & masks[s]){
            int best = dp[s][sub];
            for(int i = 0; i < n; ++i){
                if(!eligible[sub][i]) continue;
                const int rest = sub ^ (1 << i);
                int cost = depth[s];
                for(int v : kids[s]){
                    // Clamp so repeated INF additions cannot overflow.
                    cost = min(cost + dp[v][rest & masks[v]], INF);
                }
                best = min(best, cost);
            }
            dp[s][sub] = best;
        }
    }
    return depthSum - dp[0][masks[0]];
}
// Driver for the brute-force solution: reads t test cases, each consisting of
// n, the token string and the 1-indexed parents of nodes 2..n, and prints
// get(...) for each case on its own line.
int main(){
    int t; cin >> t;
    while(t--){
        int n; cin >> n;
        string type; cin >> type;
        vector<int> par(n);   // par[0] stays 0 for the root
        for(int i = 1; i < n; i++){
            cin >> par[i];
            par[i]--;         // convert to 0-indexed
        }
        // '\n' instead of endl: no need to flush after every test case.
        cout << get(n, par, type) << '\n';
    }
}
Subtask 2:
The sum of N over all test cases doesn’t exceed 2000.
So during DFS if we are at some node say s (such that its subtree is already explored) which doesn’t have a token. Then we need to find the deepest node in its subtree which has a token so that we can shift that token to this node.
We can find this deepest node by running another DFS over this subtree; N is small enough in this subtask to afford it. Once we have found that node, we move its token up to node s and add the number of moves this takes to our answer.
Subtask 3:
The sum of N over all test cases doesn’t exceed 10^5.
The idea is the same i.e during DFS if there is a node that doesn’t have a token then we will simply find the deepest node that has a token on its subtree and will shift that token to this node.
But finding the deepest node in the subtree by traversing each node again will result in TLE as the value of N is large enough this time.
To optimize it, we can maintain the depths of nodes with the help of multiset + offset, use small to large merging.
This results in a O(N*log^2 N) solution which will be good enough to pass this subtask.
Solution
#include <bits/stdc++.h>
#define ll long long
#define sz(x) ((int) (x).size())
#define all(x) (x).begin(), (x).end()
#define vi vector<int>
#define pii pair<int, int>
#define rep(i, a, b) for(int i = (a); i < (b); i++)
using namespace std;
// Min-heap alias: a priority_queue ordered with greater<T>, so top() is the
// smallest element. Not referenced by the rest of this solution.
template<typename T>
using minpq = priority_queue<T, vector<T>, greater<T>>;
// O(n log^2 n) solution
// dfs on the tree
// if a node is empty and has a token in the subtree, jump the deepest token up
// maintain the depths with a multiset + offset, use small to large merging
// Depths collected for one subtree, stored lazily: an entry d in `depths`
// stands for the value d + offset, so shifting every entry is a single
// change to `offset`.
struct tokens {
    multiset<int> depths;
    int offset{0};
};
void solve() {
int n;
string s;
cin >> n >> s;
vector<vi> g(n + 1);
rep(i, 2, n + 1) {
int p;
cin >> p;
g[p].push_back(i);
}
ll ans = 0;
vector<tokens> ma(n + 1);
function<void(int)> dfs = [&](int x) {
for(int y : g[x]) {
dfs(y);
ma[y].offset++;
if(sz(ma[x].depths) < sz(ma[y].depths)) {
for(int d : ma[x].depths) {
ma[y].depths.insert(d + ma[x].offset - ma[y].offset);
}
ma[x].depths.swap(ma[y].depths);
swap(ma[x].offset, ma[y].offset);
}else {
for(int d : ma[y].depths) {
ma[x].depths.insert(d + ma[y].offset - ma[x].offset);
}
}
}
ma[x].depths.insert(-ma[x].offset);
if(s[x - 1] == '0') {
int d = *prev(ma[x].depths.end()) + ma[x].offset;
ma[x].depths.erase(prev(ma[x].depths.end()));
ans += d;
}
};
dfs(1);
cout << ans << '\n';
}
// Entry point: speeds up iostream, then runs solve() once per test case.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int te;
    cin >> te;
    while(te--) {
        solve();
    }
}
Subtask 4:
Original Constraints
Here N is large enough that our O(N*log^2N) solution gets a TLE verdict, so we optimize it further to an O(N*logN) solution using an Euler Tour and a Range Minimum Query structure (RMQ; here it is used as a range-maximum query over depths).
We do an Euler Tour of the given tree and build the RMQ structure over it, so that every subtree corresponds to a contiguous range. When we are at a node s that doesn’t have a token, we query the range of its subtree to find the deepest token.
Hence we are able to optimize our solution to O(N*log(N)).
TIME COMPLEXITY:
O(N*log(N)) per test case
SOLUTIONS:
Setter
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int, int>
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())
#ifdef LOCAL
#include <print.h>
#else
#define trace(...)
#endif
// Strict validator-style integer reader.
// Reads characters from stdin up to and including the terminator `endd` and
// returns the parsed value, asserting that it lies in [l, r].
// Accepted syntax: an optional single '-' before the first digit, at least one
// digit, no leading zeros, no "-0", at most 19 digits (a 19-digit number must
// start with 1 so it always fits in long long). Any other character, including
// end of input, trips an assertion.
long long readInt(long long l,long long r,char endd){
    long long x = 0;
    int cnt = 0;          // digits consumed so far
    int fi = -1;          // first digit, -1 while none has been read
    bool is_neg = false;
    while(true){
        // Keep the int result of getchar() so EOF stays distinct from every byte.
        const int g = getchar();
        if(g == '-'){
            assert(fi == -1);   // the sign must precede all digits
            assert(!is_neg);    // and may appear only once
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9'){
            if(cnt == 0){
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            // Check the length BEFORE accumulating so x can never overflow.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + (g - '0');
        } else if(g == static_cast<unsigned char>(endd)){
            assert(cnt > 0);    // an empty token is not a number
            if(is_neg){
                x = -x;
            }
            if(!(l <= x && x <= r)) cerr << l << "<=" << x << "<=" << r << endl;
            assert(l <= x && x <= r);
            return x;
        } else {
            assert(false);      // unexpected character or end of input
        }
    }
}
// Reads characters from stdin up to (and consuming) the terminator `endd` and
// returns them without the terminator. Asserts that input does not end before
// the terminator, that every character lies in [minc, maxc], and that the
// length lies in [l, r].
string readString(int l,int r,char endd, char minc = 'a', char maxc = 'z'){
    string ret;
    int cnt = 0;
    while(true){
        // getchar() returns int: storing it in a char would make byte 0xFF
        // look like EOF (signed char) or make EOF undetectable (unsigned char).
        const int g = getchar();
        assert(g != EOF);
        if(g == static_cast<unsigned char>(endd)){
            break;
        }
        const char c = static_cast<char>(g);
        assert(c >= minc && c <= maxc);
        cnt++;
        ret += c;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
// Reads an integer in [l, r] that is terminated by a single space (delegates to readInt).
long long readIntSp(long long l,long long r){
    const char terminator = ' ';
    return readInt(l, r, terminator);
}
// Reads an integer in [l, r] that is terminated by a newline (delegates to readInt).
long long readIntLn(long long l,long long r){
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
// Reads a newline-terminated string with length in [l, r] and characters in
// [minc, maxc] (delegates to readString).
string readStringLn(int l,int r, char minc = 'a', char maxc = 'z'){
    const char terminator = '\n';
    return readString(l, r, terminator, minc, maxc);
}
// Reads a space-terminated string with length in [l, r] and characters in
// [minc, maxc] (delegates to readString).
string readStringSp(int l,int r, char minc = 'a', char maxc = 'z'){
    const char terminator = ' ';
    return readString(l, r, terminator, minc, maxc);
}
// Reads n integers in [l, r] from one line: the first n-1 are space-terminated,
// the last one is newline-terminated. Reads nothing when n is 0.
template<class T>
vector<T> readVector(int n, long long l, long long r){
    vector<T> values(n);
    const int last = n - 1;
    for(int idx = 0; idx < last; ++idx){
        values[idx] = readIntSp(l, r);
    }
    if(n > 0){
        values[last] = readIntLn(l, r);
    }
    return values;
}
// Iterative (bottom-up) segment tree that answers "index of the maximum" queries.
//   A   - current values
//   t   - tree of indices into A; leaves are t[n + i] = i
//   def - sentinel index -1 meaning "no element"
template<class T>
struct segtree{
    int n;
    vector<T> t, A;
    T def;

    // Returns whichever index has the larger value in A; on a tie the second
    // argument wins. The sentinel -1 yields to the other argument.
    inline T combine(T a, T b){
        if(a == -1){
            return b;
        }
        if(b == -1){
            return a;
        }
        if(A[a] > A[b]){
            return a;
        }
        return b;
    }

    // Builds the tree over a copy of `inp`.
    segtree(vector<T> inp) : n((int)inp.size()), A(inp), def(-1){
        t.assign(2 * n, def);
        for(int leaf = 0; leaf < n; ++leaf){
            t[n + leaf] = leaf;
        }
        for(int node = n - 1; node >= 1; --node){
            t[node] = combine(t[2 * node], t[2 * node + 1]);
        }
    }

    // Sets A[p] = value and recomputes every ancestor of leaf p.
    void modify(int p, T value) {
        A[p] = value;
        int node = (p + n) >> 1;
        while(node > 0){
            t[node] = combine(t[2 * node], t[2 * node + 1]);
            node >>= 1;
        }
    }

    // Index of the maximum over the inclusive interval [l, r].
    T query(int l, int r) {
        T fromLeft = def, fromRight = def;
        int lo = l + n;
        int hi = r + 1 + n;   // exclusive upper end
        while(lo < hi){
            if(lo & 1) fromLeft = combine(fromLeft, t[lo++]);
            if(hi & 1) fromRight = combine(t[--hi], fromRight);
            lo >>= 1;
            hi >>= 1;
        }
        return combine(fromLeft, fromRight);
    }
};
// O(N log N) solution.
//   n    - number of nodes, 0-indexed, node 0 is the root
//   par  - par[i] is the parent of node i (must satisfy 0 <= par[i] < i); par[0] is ignored
//   type - type[i] == '1' iff node i initially holds a token
// `sum` starts as the total depth of all tokens (the root has depth 1). Nodes
// are then processed from the highest index down (children before parents):
// each node stores its depth in the segment tree and subtracts it from `sum`;
// a node without a token adds back the largest depth stored in its subtree
// range and resets that entry to 0. The final `sum` is returned.
long long get(int n, vector<int> par, string type){
    vector<vector<int>> kids(n);
    for(int v = 1; v < n; ++v){
        assert(par[v] >= 0 && par[v] < v);
        kids[par[v]].push_back(v);
    }

    vector<int> depth(n, 1);
    long long sum = (type[0] - '0');
    for(int v = 1; v < n; ++v){
        depth[v] = depth[par[v]] + 1;
        sum += (type[v] - '0') * depth[v];
    }

    // Preorder numbering with an explicit stack; the subtree of s occupies
    // the index range [st[s], en[s]].
    vector<int> st(n), en(n);
    int timer = 0;
    stack<int> pending;
    pending.push(0);
    while(!pending.empty()){
        const int s = pending.top();
        pending.pop();
        st[s] = timer++;
        for(int v : kids[s]){
            pending.push(v);
        }
    }
    // The first child is pushed first and therefore visited last, so its
    // range end is also the end of the parent's range.
    for(int s = n - 1; s >= 0; --s){
        en[s] = kids[s].empty() ? st[s] : en[kids[s].front()];
    }

    segtree<int> stree(vector<int>(n, 0));
    for(int s = n - 1; s >= 0; --s){
        stree.modify(st[s], depth[s]);
        sum -= depth[s];
        if(type[s] != '0'){
            continue;
        }
        const int u = stree.query(st[s], en[s]);
        sum += stree.A[u];
        stree.modify(u, 0);
    }
    return sum;
}
const int SN = 1000000;
int main(){
int t = readIntLn(1, SN);
int sn = 0;
while(t--){
int n = readIntLn(1, SN);
sn += n;
assert(sn <= SN);
string type = readStringLn(n, n, '0', '1');
vector<int> par = readVector<int>(n - 1, 0, n);
reverse(all(par)); par.push_back(0); reverse(all(par));
for(int i = 1; i < n; i++){
par[i]--;
}
cout << get(n, par, type) << endl;
}
}
Tester
#pragma GCC optimize ("Ofast")
#include <bits/stdc++.h>
#define ll long long
#define sz(x) ((int) (x).size())
#define all(x) (x).begin(), (x).end()
#define vi vector<int>
#define pii pair<int, int>
#define rep(i, a, b) for(int i = (a); i < (b); i++)
using namespace std;
// Min-heap alias: a priority_queue ordered with greater<T>, so top() is the
// smallest element. Not referenced by the rest of this solution.
template<typename T>
using minpq = priority_queue<T, vector<T>, greater<T>>;
// O(n log n), optimized from O(n log^2 n) solution
// to query deepest token in subtree, use euler tour tree and RMQ
// Solves one test case read from cin and prints the answer.
// a[pos] holds the depth stored at Euler-tour position pos (-1 when nothing is
// stored; a[n] is a permanent -1 sentinel). `tree` is a recursive segment tree
// over positions 0..n-1 whose nodes store the position with the largest a[].
// In post-order every node stores its own depth; a node without a token then
// takes the largest entry of its subtree range, adds (that depth - own depth)
// to the answer and resets the entry to -1.
void solve() {
    int n;
    string s;
    cin >> n >> s;
    vector<vector<int>> g(n + 1);
    for(int v = 2; v <= n; ++v) {
        int p;
        cin >> p;
        g[p].push_back(v);
    }
    long long ans = 0;
    vector<int> tree(4 * n, n);
    vector<int> a(n + 1, -1);
    // Position with the larger value; the second argument wins ties.
    auto pick = [&](int j, int k) {
        return a[j] > a[k] ? j : k;
    };
    function<int(int, int, int, int, int)> query = [&](int i, int l, int r, int L, int R) -> int {
        if(r < L || R < l) return n;
        if(L <= l && r <= R) return tree[i];
        const int m = (l + r) / 2;
        const int j = query(2 * i + 1, l, m, L, R);
        const int k = query(2 * i + 2, m + 1, r, L, R);
        return pick(j, k);
    };
    function<void(int, int, int, int, int)> upd = [&](int i, int l, int r, int k, int x) {
        if(l == r) {
            a[k] = x;
            tree[i] = k;
            return;
        }
        const int m = (l + r) / 2;
        if(k <= m) {
            upd(2 * i + 1, l, m, k, x);
        } else {
            upd(2 * i + 2, m + 1, r, k, x);
        }
        tree[i] = pick(tree[2 * i + 1], tree[2 * i + 2]);
    };
    vector<int> tin(n + 1), tout(n + 1), dep(n + 1);
    int ti = 0;
    function<void(int)> dfs = [&](int x) {
        tin[x] = ti++;
        for(int y : g[x]) {
            dep[y] = dep[x] + 1;
            dfs(y);
        }
        tout[x] = ti;   // subtree of x covers positions [tin[x], tout[x] - 1]
        upd(0, 0, n - 1, tin[x], dep[x]);
        if(s[x - 1] != '0') {
            return;
        }
        const int j = query(0, 0, n - 1, tin[x], tout[x] - 1);
        ans += a[j] - dep[x];
        upd(0, 0, n - 1, j, -1);
    };
    dfs(1);
    cout << ans << '\n';
}
// Entry point: speeds up iostream, then runs solve() once per test case.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int te;
    cin >> te;
    while(te--) {
        solve();
    }
}
Editorialist
#include<bits/stdc++.h>
using namespace std;
#define int long long
// Shared state for the editorialist solution; solve() resets the per-node
// entries 0..n-1 at the start of every test case.
const int mxN=1e6+5;     // capacity of the per-node arrays
vector <int> adj[mxN];   // adj[p]: children of node p (0-indexed)
int n;                   // number of nodes in the current test case
string s;                // input string; s[v] is compared against '0' / '1' per node
bool visited[mxN];       // set to true by euler_tour(), cleared again by dfs()
int level[mxN];          // depth assigned in euler_tour(); the root gets 1
int tin[mxN];            // preorder index of each node
int sz[mxN];             // subtree size computed in euler_tour()
int timer;               // next preorder index to hand out
int ans;                 // answer accumulated for the current test case
vector <int> a;          // nodes in preorder: a[i] is the node whose tin is i
pair<int,int> t[4*mxN];  // segment tree nodes: (maximum value, position of that maximum)
// Builds the max segment tree in the global t[] over arr[tl..tr]; node v
// covers [tl, tr] and stores (maximum value, position). The left child wins ties.
void build(int arr[],int v,int tl,int tr)
{
    if(tl==tr)
    {
        t[v].first=arr[tl];
        t[v].second=tl;
        return;
    }
    int mid=(tl+tr)/2;
    int lc=v*2;
    int rc=v*2+1;
    build(arr,lc,tl,mid);
    build(arr,rc,mid+1,tr);
    t[v]=(t[lc].first>=t[rc].first)?t[lc]:t[rc];
}
// Returns (maximum value, position) over [l, r] inside node v, which covers
// [tl, tr]. An empty range yields {-1,-1}; on a tie between the two halves
// the right half's result is returned.
pair<int,int> find_max(int v,int tl,int tr,int l,int r)
{
    if(l>r)
    {
        return {-1,-1};
    }
    if(l==tl && r==tr)
    {
        return t[v];
    }
    int mid=(tl+tr)/2;
    pair<int,int> leftBest=find_max(v*2,tl,mid,l,min(r,mid));
    pair<int,int> rightBest=find_max(v*2+1,mid+1,tr,max(l,mid+1),r);
    return (leftBest.first>rightBest.first)?leftBest:rightBest;
}
// Point update: sets the value at position pos to new_val and recomputes the
// (maximum, position) pairs on the path back to node v. The left child wins ties.
void update(int v,int tl,int tr,int pos,int new_val)
{
    if(tl==tr)
    {
        t[v].first=new_val;
        return;
    }
    int mid=(tl+tr)/2;
    int lc=v*2;
    int rc=v*2+1;
    if(pos>mid)
        update(rc,mid+1,tr,pos,new_val);
    else
        update(lc,tl,mid,pos,new_val);
    t[v]=(t[lc].first>=t[rc].first)?t[lc]:t[rc];
}
// Post-order pass. After the children are done, a node whose s[v] is '0' takes
// the largest value in its subtree range [tin[v], tin[v]+sz[v]-1]; if that
// value is non-zero, the difference to level[v] is added to ans, the old
// position is reset to 0 and level[v] is stored at tin[v].
// visited[] was set by euler_tour(), so here "true" means "not processed yet".
void dfs(int v)
{
    visited[v]=false;
    for(auto child: adj[v])
    {
        if(visited[child])
            dfs(child);
    }
    if(s[v]!='0')
        return;
    int lo=tin[v];
    int hi=lo+sz[v]-1;
    pair <int,int> best=find_max(1,0,n-1,lo,hi);
    if(best.first==0)
        return;
    ans+=(best.first-level[v]);
    update(1,0,n-1,best.second,0);
    update(1,0,n-1,tin[v],level[v]);
}
// Preorder traversal starting at v with depth he: records tin[], level[] and
// the preorder list a, marks nodes visited, and returns the size of v's
// subtree (also stored in sz[v]).
int euler_tour(int v,int he)
{
    tin[v]=timer++;
    level[v]=he;
    visited[v]=true;
    a.push_back(v);
    for(auto child: adj[v])
    {
        if(visited[child])
            continue;
        sz[v]+=euler_tour(child,he+1);
    }
    return ++sz[v];
}
// Solves one test case: reads n, the string s and the 1-indexed parents of
// nodes 2..n, resets the per-node globals, runs the Euler tour, builds the
// segment tree and prints the answer accumulated by dfs().
void solve()
{
    cin>>n;
    cin>>s;
    ans=0;
    timer=0;
    a.clear();
    for(int i=0;i<n;i++)
    {
        adj[i].clear();
        visited[i]=false;
        level[i]=0;
        tin[i]=-1;
        sz[i]=0;
    }
    for(int i=1;i<n;i++)
    {
        int x;
        cin>>x;
        adj[x-1].push_back(i);   // store node i as a child of its 0-indexed parent
    }
    euler_tour(0,1);             // the root gets level 1; the returned size is not needed
    // Heap-allocated instead of a variable-length array: a VLA is not standard
    // C++ and would put up to n elements on the stack.
    vector<int> arr(n);
    for(int i=0;i<n;i++)
    {
        // Preorder position i holds the level of its node if s marks it '1', else 0.
        if(s[a[i]]=='1')
            arr[i]=level[a[i]];
        else
            arr[i]=0;
    }
    build(arr.data(),1,0,n-1);
    dfs(0);
    cout<<ans<<"\n";
}
// Entry point: speeds up iostream, then runs solve() once per test case.
int32_t main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int tc;
    cin>>tc;
    while(tc--)
    {
        solve();
    }
    return 0;
}