BAWTREE - Editorial

PROBLEM LINK:

Practice
Div1
Div2
Div3

Setter: Lavish Gupta
Tester: Samarth Gupta
Editorialist: Ajit Sharma Kasturi

DIFFICULTY:

MEDIUM

PREREQUISITES:

Trees, Dynamic programming, Depth First Search

PROBLEM:

We are given a tree with N nodes where each node is black or white in color. In one step, we can choose a vertex u and toggle u along with all the neighbors of u.

We are asked to find the minimum number of steps required to make all the nodes black or report that it is not possible.

QUICK EXPLANATION:

  • We can initially root the tree at some vertex.

  • Let color 0 be black and color 1 be white.

  • Let us define an N \cdot 2 \cdot 2 dp table. Let dp[u][state][toggle] denote the minimum number of steps required to make vertex u have color state and every other vertex in the subtree of u black, where toggle is 1 if a step is applied on vertex u and 0 otherwise.

  • This dp can be computed bottom-up with a depth-first search.

  • If vertex u is a leaf, then dp[u][color[u]][0]=0 and dp[u][color[u] \oplus 1][1]=1.

  • While calculating dp[u][state][toggle], first determine the final state that every direct child of u must reach. It depends only on toggle (a step on u flips all its children, so fin = toggle). Let it be fin.

  • After that, for every child x, take the minimum of dp[x][fin][0] and dp[x][fin][1]. Let par be the parity of the number of children that are toggled in these choices.

  • We also need the parity of the number of toggled children required to make the color of u equal to state; it is state \oplus color[u] \oplus toggle. If this equals par, we are done. Otherwise, for exactly one child x, take the maximum of dp[x][fin][0] and dp[x][fin][1] instead of the minimum, choosing the child for which this increases the total the least. Finally, add toggle for the step applied on u itself.

EXPLANATION:

First, root the tree at an arbitrary vertex.

The first observation is that we never need to apply more than one step at the same vertex. Steps commute (each one just flips a fixed set of colors), so two steps on the same vertex cancel out and are the same as applying no step there at all.
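Because of this observation, every solution is simply a set of vertices that are each toggled once. That makes a tiny brute force over all 2^N subsets possible, which can be handy for cross-checking a DP on small random trees. A minimal sketch, assuming 0-indexed vertices, an edge list as input, and colors 0 = black, 1 = white; the name `bruteForce` is hypothetical:

```cpp
#include <algorithm>
#include <bitset>
#include <climits>
#include <utility>
#include <vector>

// Hypothetical cross-checking helper: tries every subset of vertices, toggling
// each chosen vertex exactly once, so it is only usable for n <= ~20.
// color[v] is 0 (black) or 1 (white); vertices are 0-indexed.
// Returns the minimum number of steps, or -1 if no subset makes every vertex black.
int bruteForce(const std::vector<int>& color,
               const std::vector<std::pair<int, int>>& edges) {
    int n = static_cast<int>(color.size());
    // closed[u]: bitmask of u together with all of its neighbours
    std::vector<unsigned> closed(n);
    for (int u = 0; u < n; ++u) closed[u] = 1u << u;
    for (const auto& e : edges) {
        closed[e.first] |= 1u << e.second;
        closed[e.second] |= 1u << e.first;
    }
    unsigned start = 0;  // bit v set <=> vertex v is white
    for (int v = 0; v < n; ++v)
        if (color[v] == 1) start |= 1u << v;

    int best = INT_MAX;
    for (unsigned mask = 0; mask < (1u << n); ++mask) {
        unsigned cur = start;
        for (int u = 0; u < n; ++u)
            if ((mask >> u) & 1u) cur ^= closed[u];  // one step on u
        if (cur == 0)  // every vertex is black
            best = std::min(best, static_cast<int>(std::bitset<32>(mask).count()));
    }
    return best == INT_MAX ? -1 : best;
}
```

For example, on two adjacent vertices colored 1 and 0 it returns -1: any step flips both vertices, so they can never end up with the same color.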

Let color 0 represent black and color 1 represent white.

Also, let color[u] denote the color of vertex u.

This problem can be solved by dynamic programming. Let dp[u][state][toggle] denote the minimum number of steps required to make vertex u have color state and every other vertex in the subtree of u black, where toggle is 1 if a step is applied on vertex u and 0 otherwise.

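It is important to note that, while calculating the dp state of u, we only consider the subtree of u and ignore everything else. In particular, the effect that a step on u has on its parent is accounted for later, when the parent's dp is computed.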

Let us initialize all the dp values to infinity (meaning "impossible").

Let us first see the base case.

  • Let u be a leaf. It takes 0 operations to keep u at color[u], and 1 operation (a toggle on u) to change its color to color[u] \oplus 1. (Here \oplus is XOR; it is used only to denote the other color.)

  • Therefore, the dp transitions for a leaf vertex u are dp[u][color[u]][0] = 0 and dp[u][color[u] \oplus 1][1] = 1; the other two states remain infinity.
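A minimal sketch of this base case, storing one vertex's table as dp[state][toggle] with the same INF sentinel as the editorialist's code; the alias `Table` and the helper `leafDp` are hypothetical names:

```cpp
#include <array>

const int INF = 1000000000;  // "impossible", same sentinel as the editorialist's code

// dp table of one vertex, indexed as dp[state][toggle].
using Table = std::array<std::array<int, 2>, 2>;

// Base case for a leaf with colour `color` (0 = black, 1 = white).
Table leafDp(int color) {
    Table dp;
    for (auto& row : dp) row.fill(INF);  // every state starts as impossible
    dp[color][0] = 0;      // no step: the leaf keeps its colour
    dp[color ^ 1][1] = 1;  // one step on the leaf flips its colour
    return dp;
}
```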

Now let us see how to make the transitions when vertex u is not a leaf. Let us assume state = 0 (the case state = 1 is handled similarly). Let us also assume that the initial color[u] = 0 (the case color[u] = 1 is handled similarly).

Therefore, we are now trying to calculate dp[u][0][0] and dp[u][0][1] where color[u] =0.
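The two cases below are instances of one general transition, valid for any state and color[u]. A step on a child x flips u once and a step on u flips u once, so, writing c_x \in \{0,1\} for whether child x is toggled:

```latex
\[
\text{final color of } u \;=\; color[u] \oplus toggle \oplus \bigoplus_{x \in children(u)} c_x
\]
\[
dp[u][state][toggle] \;=\; toggle \;+\;
\min_{\substack{c_x \in \{0,1\} \text{ for each child } x \\[2pt]
\bigoplus_{x} c_x \,=\, state \,\oplus\, color[u] \,\oplus\, toggle}}
\;\sum_{x \in children(u)} dp[x][toggle][c_x]
\]
```

A minimum over an empty set of valid choices is infinity. With state = 0 and color[u] = 0, toggle = 0 requires an even number of toggled children, and toggle = 1 requires an odd number.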

Case 1: toggle = 0

  • This means that we haven’t applied any step on vertex u.

  • Also, by the definition of the dp, every vertex in the subtree of u other than u itself must end up black, i.e., color 0.

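  • Since no step is applied on u, each child keeps the color it ends up with inside its own subtree, so it must reach state 0. Therefore, for every child x of u, take the minimum of dp[x][0][0] and dp[x][0][1] and add it to dp[u][0][0].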

  • But we are not done yet. Since color[u] = 0 initially and no step is applied on u, an even number of the direct children of u must be toggled in order to keep color[u] = 0 = state.

  • If the number of toggled children chosen by taking the minimums above is even, we are done. If it is odd, we must pick exactly one child x of u and change our previous decision: take the maximum of dp[x][0][0] and dp[x][0][1] instead of the minimum, and update dp[u][0][0] accordingly. We pick the child x for which this change adds the least extra cost to dp[u][0][0].
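The step "take the minimums, then repair the parity by switching exactly one child" is the same in both cases, so it can be isolated in a helper. Switching more than one child never helps: each switch costs a non-negative amount, and switching three choices fixes the parity no better than switching only the cheapest one. A minimal sketch, assuming each child's two options are passed as the pair {dp[x][fin][0], dp[x][fin][1]}; the name `combineChildren` is hypothetical:

```cpp
#include <algorithm>
#include <utility>
#include <vector>

const int INF = 1000000000;  // "impossible"

// options[i] = {dp[x][fin][0], dp[x][fin][1]} for the i-th child x of u.
// Returns the cheapest sum that picks one option per child such that the number
// of toggled children has parity neededParity, or INF if that is impossible.
int combineChildren(const std::vector<std::pair<int, int>>& options, int neededParity) {
    long long total = 0;
    int parity = 0;     // parity of toggled children among the greedy choices
    int bestFix = INF;  // cheapest extra cost of switching one child to its other option
    for (const auto& opt : options) {
        int notToggled = opt.first, toggled = opt.second;
        int lo = std::min(notToggled, toggled);
        int hi = std::max(notToggled, toggled);
        if (lo >= INF) return INF;  // this child cannot reach state fin at all
        total += lo;
        if (toggled < notToggled) parity ^= 1;  // greedy choice toggles this child
        if (hi < INF) bestFix = std::min(bestFix, hi - lo);
    }
    if (parity != neededParity) {
        if (bestFix >= INF) return INF;  // no child can switch, parity cannot be fixed
        total += bestFix;
    }
    return static_cast<int>(total);
}
```

In the setting above (state = 0, color[u] = 0), Case 1 corresponds to fin = 0 with neededParity = 0, and Case 2 below to fin = 1 with neededParity = 1, plus 1 for the step on u.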

Case 2: toggle = 1

  • This means that we have applied exactly 1 step on vertex u.

  • Also, by the definition of the dp, every vertex in the subtree of u other than u itself must end up black, i.e., color 0.

  • Therefore, for every child x of u, take the minimum of dp[x][1][0] and dp[x][1][1] and add it to dp[u][0][1]. Here, unlike the previous case, we want the state of x to be 1, because the step applied on vertex u will later turn it into 0.

  • Since color[u] = 0 initially, an odd number of the direct children of u must be toggled so that u becomes color 1; the step on u then turns it back into color 0 = state.

  • If the number of toggled children chosen by taking the minimums above is odd, we are done. If it is even, we must pick exactly one child x of u and change our previous decision: take the maximum of dp[x][1][0] and dp[x][1][1] instead of the minimum, and update dp[u][0][1] accordingly. Again, we pick the child x that adds the least extra cost to dp[u][0][1].

  • Finally, we need to add 1 to dp[u][0][1] for the toggle we are applying on vertex u.

These dp states can be computed bottom-up with a depth-first search from the root. The final answer is min(dp[root][0][0], dp[root][0][1]), or -1 if both values are infinity.
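A minimal sketch of the whole procedure, assuming 0-indexed vertices and an edge list as input; the function name `minSteps` is hypothetical. It deliberately differs from the editorialist's code in one way: instead of a recursive DFS, it builds a preorder with an explicit stack and processes it backwards, since a path-shaped tree with N = 10^5 vertices may exhaust the default call stack on some systems. Leaves need no special case here: with no children, only the two base-case states come out finite.

```cpp
#include <algorithm>
#include <array>
#include <utility>
#include <vector>

const int INF = 1000000000;  // "impossible"

// Minimum number of steps to make every vertex black (colour 0), or -1.
// color[v] is 0 (black) or 1 (white); vertices are 0-indexed, rooted at 0.
int minSteps(const std::vector<int>& color,
             const std::vector<std::pair<int, int>>& edges) {
    int n = static_cast<int>(color.size());
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    // Iterative preorder from the root: every parent appears before its children.
    std::vector<int> order, parent(n, -1);
    order.reserve(n);
    std::vector<int> stack{0};
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u])
            if (v != parent[u]) {
                parent[v] = u;
                stack.push_back(v);
            }
    }

    // dp[u][state][toggle], filled in reverse preorder (children before parents).
    std::vector<std::array<std::array<int, 2>, 2>> dp(n);
    for (int i = n - 1; i >= 0; --i) {
        int u = order[i];
        for (int state = 0; state < 2; ++state)
            for (int toggle = 0; toggle < 2; ++toggle) {
                // Children must reach colour `toggle` (fin), and the number of
                // toggled children must have parity state ^ color[u] ^ toggle.
                int needed = state ^ color[u] ^ toggle;
                long long total = toggle;  // cost of the step on u itself
                int parity = 0, bestFix = INF;
                bool ok = true;
                for (int x : adj[u]) {
                    if (x == parent[u]) continue;
                    int c0 = dp[x][toggle][0], c1 = dp[x][toggle][1];
                    int lo = std::min(c0, c1), hi = std::max(c0, c1);
                    if (lo >= INF) { ok = false; break; }
                    total += lo;
                    if (c1 < c0) parity ^= 1;
                    if (hi < INF) bestFix = std::min(bestFix, hi - lo);
                }
                if (ok && parity != needed) {
                    if (bestFix >= INF) ok = false;
                    else total += bestFix;  // switch exactly one child
                }
                dp[u][state][toggle] = ok ? static_cast<int>(total) : INF;
            }
    }
    int ans = std::min(dp[0][0][0], dp[0][0][1]);
    return ans >= INF ? -1 : ans;
}
```

For example, minSteps({1, 1, 1}, {{0, 1}, {1, 2}}) returns 1: on a path of three white vertices, one step on the middle vertex turns all of them black.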

You can have a look at the code for better understanding.

TIME COMPLEXITY:

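O(N) for each test case: every vertex is processed once, and the work done at a vertex is proportional to its number of children. (The setter's and tester's solutions sort the children's cost differences, so they run in O(N log N).)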

SOLUTION:

Editorialist's solution

#include <bits/stdc++.h>
using namespace std;

const int INF = 1e9;

void dfs(int i, int par, vector<vector<int>> &tree, vector<int> &color, vector<vector<vector<int>>> &dp)
{
     bool leaf = true;
     for (int child : tree[i])
     {
          if (child == par)
               continue;
          leaf = false;
          dfs(child, i, tree, color, dp);
     }

     if (leaf)
     {
          dp[i][color[i]][0] = 0;
          dp[i][color[i] ^ 1][1] = 1;
          return;
     }

     for (int wanted_state = 0; wanted_state < 2; wanted_state++)
     {
          for (int toggle = 0; toggle < 2; toggle++)
          {
               // This is the total number of child toggles parity (whether it should be even or odd) in order to make the state of
               // vertex i equal to the wanted_state

               int needed_parity_of_child_toggles = wanted_state ^ color[i] ^ toggle;

               vector<int> selected_child_state;
               vector<int> other_child_state;
               bool impossible = false;
               int parity_of_selected_child_toggles = 0;

               for (int child : tree[i])
               {
                    if (child == par)
                         continue;

                    //Here we assume all nodes in subtree of i except i should be black(0) and solve accordingly
                    if (dp[child][toggle][0] == INF && dp[child][toggle][1] == INF)
                    {
                         impossible = true;
                         break;
                    }

                    if (dp[child][toggle][0] <= dp[child][toggle][1])
                    {
                         selected_child_state.push_back(dp[child][toggle][0]);
                         other_child_state.push_back(dp[child][toggle][1]);
                    }
                    else
                    {
                         selected_child_state.push_back(dp[child][toggle][1]);
                         other_child_state.push_back(dp[child][toggle][0]);
                         parity_of_selected_child_toggles ^= 1;
                    }
               }

               if (impossible)
                    continue;

               int tot = toggle;
               for (int x : selected_child_state)
                    tot += x;

               if (parity_of_selected_child_toggles != needed_parity_of_child_toggles)
               {
                    // Now we need to replace one of the selected_child_state with its other_child_state to get the required parity of
                    // child toggles

                    int extra = INF;

                    for (int cur = 0; cur < other_child_state.size(); cur++)
                    {
                         if (other_child_state[cur] != INF)
                              extra = min(extra, other_child_state[cur] - selected_child_state[cur]);
                    }

                    if (extra == INF)
                    {
                         impossible = true;
                         continue;
                    }

                    tot += extra;
               }

               dp[i][wanted_state][toggle] = tot;
          }
     }
}

int main()
{
     int tests;
     cin >> tests;
     while (tests--)
     {
          int n;
          cin >> n;

          vector<vector<int>> tree(n + 1);
          vector<int> color(n + 1);

          //Initialize the values to infinity
          vector<vector<vector<int>>> dp(n + 1, vector<vector<int>>(2, vector<int>(2, INF)));

          for (int i = 1; i <= n; i++)
          {
               cin >> color[i];
          }

          for (int i = 1; i < n; i++)
          {
               int x, y;
               cin >> x >> y;
               tree[x].push_back(y);
               tree[y].push_back(x);
          }

          dfs(1, -1, tree, color, dp);

          int ans = min(dp[1][0][0], dp[1][0][1]);

          if (ans == INF)
               ans = -1;

          cout << ans << endl;
     }
     return 0;
}

Setter's solution
/* Lavish Gupta */
#define ll long long
#define dd long double
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define mp make_pair
#define mt make_tuple
#define fo(i , n) for(ll i = 0 ; i < n ; i++)
//#include<bits/stdc++.h>
#include<iomanip>
#include<cmath>
#include<cstdio>
#include<utility>
#include<iostream>
#include<vector>
#include<string>
#include<algorithm>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<bitset>
dd pi = acos(-1) ;
ll z =  1000000007 ;
ll inf = 1000000000000000 ;
ll p1 = 37 ;
ll p2 = 53 ;
ll mod1 =  202976689 ;
ll mod2 =  203034253 ;
ll fact[20] ;
ll gdp(ll a , ll b){return (a - (a%b)) ;}
ll ld(ll a , ll b){if(a < 0) return -1*gdp(abs(a) , b) ; if(a%b == 0) return a ; return (a + (b - a%b)) ;} // least number >=a divisible by b
ll gd(ll a , ll b){if(a < 0) return(-1 * ld(abs(a) , b)) ;    return (a - (a%b)) ;} // greatest number <= a divisible by b
ll gcd(ll a , ll b){ if(b > a) return gcd(b , a) ; if(b == 0) return a ; return gcd(b , a%b) ;}
ll e_gcd(ll a , ll b , ll &x , ll &y){ if(b > a) return e_gcd(b , a , y , x) ; if(b == 0){x = 1 ; y = 0 ; return a ;}
ll x1 , y1 ; e_gcd(b , a%b , x1 , y1) ; x = y1 ; y = (x1 - ((a/b) * y1)) ; return e_gcd(b , a%b , x1 , y1) ;}
ll power(ll a ,ll b , ll p){if(b == 0) return 1 ; ll c = power(a , b/2 , p) ; if(b%2 == 0) return ((c*c)%p) ; else return ((((c*c)%p)*a)%p) ;}
ll inverse(ll a ,ll n){return power(a , n-2 , n) ;}
ll max(ll a , ll b){if(a > b) return a ; return b ;}
ll min(ll a , ll b){if(a < b) return a ; return b ;}
ll left(ll i){return ((2*i)+1) ;}
ll right(ll i){return ((2*i) + 2) ;}
ll ncr(ll n , ll r){if(n < r|| (n < 0) || (r < 0)) return 0 ; return ((((fact[n] * inverse(fact[r] , z)) % z) * inverse(fact[n-r] , z)) % z);}
void swap(ll&a , ll&b){ll c = a ; a = b ; b = c ; return ;}
//ios_base::sync_with_stdio(0);
//cin.tie(0); cout.tie(0);
using namespace std ;
// #include <ext/pb_ds/assoc_container.hpp> 
// #include <ext/pb_ds/tree_policy.hpp> 
// using namespace __gnu_pbds; 
// #define ordered_set tree<pll, null_type,less<pll>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val)  no. of elements strictly less than val
// s.find_by_order(i)  itertor to ith element (0 indexed)
//__builtin_popcount(n) -> returns number of set bits in n
#define tll tuple<ll ,ll ,ll>
#define pll pair<ll ,ll>
const int N = 1e5 + 5 ;
ll dp[N][2][2] ;
 
void get_min(vector<ll> &v , ll parity , ll &ans)
{
    if(parity == 1)
    {
        ans += v[0] ;
    }
    ll ind = parity ;
    for(; ind+1 < v.size() ; ind += 2)
    {
        ll curr = v[ind] + v[ind+1] ;
        if(curr < 0)
            ans += curr ;
        else
            break ;
    }
    return ;
}
 
void dfs(vector<ll> adj[] , ll col[] , ll u , ll par)
{
    ll flag = 0 ;
    vector<ll> children ;
    fo(i , adj[u].size())
    {
        ll v = adj[u][i] ;
        if(v == par)
            continue ;
        children.pub(v) ;
        dfs(adj , col , v , u) ;
        flag = 1 ;
    }
    if(flag == 0)
    {
        if(col[u] == 0)
        {
            dp[u][0][0] = 0 ;
            dp[u][1][1] = 1 ;
        }
        else
        {
            dp[u][1][0] = 0 ;
            dp[u][0][1] = 1 ;
        }
        return ;
    }
    ll ans = 0 ;
    fo(j , 2)
    {
        fo(k , 2)
        {
            ans = 0 ;
            ll jd = k ;
            ll parity = (col[u] + j + k) % 2 ;
 
            vector<ll> v1 ;
            fo(i , children.size())
            {
                ll v = children[i] ;
                ans += dp[v][jd][0] ;
                v1.pub(dp[v][jd][1] - dp[v][jd][0]) ;
            }
            sort(v1.begin() , v1.end()) ;
            get_min(v1 , parity , ans) ;
 
            ans = min(ans + k , inf) ;
            dp[u][j][k] = ans ;
        }
    }
    return ;
}
 
void solve()
{
    ll n ;
    cin >> n ;
    ll col[n] ;
    fo(i , n)
    {
        cin >> col[i] ;
        fo(j , 4)
        {
            dp[i][j/2][j%2] = inf ;
        }
    }
 
    vector<ll> adj[n] ;
    fo(i , n-1)
    {
        ll u , v ;
        cin >> u >> v ;
        u-- ; v-- ;
        adj[u].pub(v) ;
        adj[v].pub(u) ;
    }
 
    dfs(adj , col, 0 , -1) ;
    ll ans = min(dp[0][0][0] , dp[0][0][1]) ;
    if(ans >= inf)
    {
        cout << -1 << endl ;
    }
    else
        cout << ans << endl ;
    return ;
}
 
int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    #ifndef ONLINE_JUDGE
    freopen("inputf.txt" , "r" , stdin) ;
    freopen("outputf.txt" , "w" , stdout) ;
    freopen("errorf.txt" , "w" , stderr) ; 
    #endif
    
    ll t ;
    cin >> t ;
    while(t--)
        solve() ;
    return 0 ;
}

Tester's solution
#include <bits/stdc++.h>
using namespace std;

long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}
 
void readEOF(){
    assert(getchar()==EOF);
}
// f = 0 means even, f = 1 means odd
int solve(vector<int> tree[], int v, int p, int dp[][2][2], int x, int y, int a, int b, int f){
    // {x, y} even remaining from {a, b}
    // min dp[j][x][y] + dp[j][a][b] -> all(dp[j][a][b]) + (dp[j][x][y] - dp[j][a][b])(even/odd)
    long long sum = 0;
    vector<int> vec;
    for(auto j : tree[v])
        if(j != p)
            sum += dp[j][a][b], vec.push_back(dp[j][x][y] - dp[j][a][b]);
    sort(vec.begin(), vec.end());
    reverse(vec.begin(), vec.end());
    if(f == 1){
        int u = vec.back();
        vec.pop_back();
        sum += u;
    }
    while(vec.size() >= 2){
        int u = vec.back();
        vec.pop_back();
        int v = vec.back();
        vec.pop_back();
        if(u + v >= 0)
            break;
        sum += (u + v);
    }
    sum = min(sum, (long long)1e8);
    return sum;
}
// dp[node][final_state][is_toggled] -> all subtrees are black but this node can be white/black
void dfs(vector<int> tree[], int v, int p, int dp[][2][2], vector<int> &vec){
    int leaf = 1;
    for(auto j : tree[v])
        if(j != p){
            leaf = 0;
            dfs(tree, j, v, dp, vec);
        }
    if(leaf){
        if(vec[v] == 0) // black
            dp[v][0][0] = 0, dp[v][1][1] = 1;
        else
            dp[v][1][0] = 0, dp[v][0][1] = 1;
    }
    else{
        if(vec[v] == 0){
            // dp[v][0][0], dp[v][1][1] = even children nodes toggled
            // dp[v][1][0], dp[v][0][1] = odd children nodes toggled
            
            // dp[v][0][0] -> dp[j][0][1] (even) + dp[j][0][0] (remaining)
            dp[v][0][0] = min(dp[v][0][0], solve(tree, v, p, dp, 0, 1, 0, 0, 0));
            // dp[v][1][1] -> dp[j][1][1] (even) + dp[j][1][0] (remaining)
            dp[v][1][1] = min(dp[v][1][1], 1 + solve(tree, v, p, dp, 1, 1, 1, 0, 0));
            // dp[v][1][0] -> dp[j][0][1] (odd) + dp[j][0][0] (remaining)
            dp[v][1][0] = min(dp[v][1][0], solve(tree, v, p, dp, 0, 1, 0, 0, 1));
            // dp[v][0][1] -> dp[j][1][1] (odd) + dp[j][1][0] (remaining)
            dp[v][0][1] = min(dp[v][0][1], 1 + solve(tree, v, p, dp, 1, 1, 1, 0, 1));
        }
        else{
            // dp[v][0][0], dp[v][1][1] = odd children nodes toggled
            // dp[v][1][0], dp[v][0][1] = even children nodes toggled
            
            // dp[v][0][0] -> dp[j][0][1] (odd) + dp[j][0][0] (remaining)
            dp[v][0][0] = min(dp[v][0][0], solve(tree, v, p, dp, 0, 1, 0, 0, 1));
            // dp[v][1][1] -> dp[j][1][1] (odd) + dp[j][1][0] (remaining)
            dp[v][1][1] = min(dp[v][1][1], 1 + solve(tree, v, p, dp, 1, 1, 1, 0, 1));
            // dp[v][1][0] -> dp[j][0][1] (even) + dp[j][0][0] (remaining)
            dp[v][1][0] = min(dp[v][1][0], solve(tree, v, p, dp, 0, 1, 0, 0, 0));
            // dp[v][0][1] -> dp[j][1][1] (even) + dp[j][1][0] (remaining)
            dp[v][0][1] = min(dp[v][0][1], 1 + solve(tree, v, p, dp, 1, 1, 1, 0, 0));
        }
    }
}
int main() {
	// your code goes here
	int t = readIntLn(1, 5e4);
	int sum = 0;
	while(t--){
	    int n = readIntLn(1, 1e5);
	    sum += n;
	    assert(sum <= 5e5);
	    vector<int> vec(n + 1);
	    for(int i = 1 ; i <= n ; i++){
	        if(i == n)
	            vec[i] = readIntLn(0, 1);
	        else
	            vec[i] = readIntSp(0, 1);
	    }
	    vector<int> tree[n + 1];
	    for(int i = 1 ; i < n ; i++){
	        int x, y;
	        x = readIntSp(1, n);
	        y = readIntLn(1, n);
	        tree[x].push_back(y);
	        tree[y].push_back(x);
	    }
	    int dp[n + 1][2][2];
	    for(int i = 0; i <= n ; i++)
	        dp[i][0][0] = dp[i][0][1] = dp[i][1][0] = dp[i][1][1] = 1e8;
	    dfs(tree, 1, 0, dp, vec);
	    int ans = min(dp[1][0][0], dp[1][0][1]);
	    if(ans == 1e8)
	        ans = -1;
	    cout << ans << '\n';
	}
	readEOF();
	return 0;
}

        

Please comment below if you have any questions, alternate solutions, or suggestions. :slight_smile:


@ajit123q “Therefore, the dp transitions for vertex u being a leaf are dp[u][color[u]][0]=0 and dp[u][color[u]][1] = 1.” Shouldn't it be “color[u] xor 1”?

Thanks, fixed!