WBLACK-Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Arun Sharma
Tester: Abhinav Sharma, Lavish Gupta
Editorialist: Devendra Singh

DIFFICULTY:

2601

PREREQUISITES:

Depth first search, Dynamic programming, Trees

PROBLEM:

Arun has a rooted tree of N vertices rooted at vertex 1. Each vertex can either be coloured black or white.

Initially, the vertices are coloured A_1, A_2, \ldots A_N, where A_i \in \{0, 1\} denotes the colour of the i-th vertex (here 0 represents white and 1 represents black). He wants to perform some operations to change the colouring of the vertices to B_1, B_2, \ldots B_N respectively.

Arun can perform the following operation any number of times. In one operation, he can choose any subtree and either paint all its vertices white or all its vertices black.

Help Arun find the minimum number of operations required to change the colouring of the vertices to B_1, B_2, \ldots B_N respectively.

EXPLANATION:

This problem can be solved using dynamic programming as it has optimal substructure and overlapping subproblems.
The tree is rooted at node 1.
Let $black_u$ be the minimum number of operations needed to make $A_i = B_i$ for every node $i$ in the subtree of node $u$, assuming the complete subtree of $u$ has already been painted black by an operation (that painting operation itself is not counted in $black_u$).
Let $white_u$ be the minimum number of operations needed to make $A_i = B_i$ for every node $i$ in the subtree of node $u$, assuming the complete subtree of $u$ has already been painted white by an operation (that painting operation itself is not counted in $white_u$).
Let $none_u$ be the minimum number of operations needed to make $A_i = B_i$ for every node $i$ in the subtree of node $u$ when no operation is applied to $u$ or to any of its ancestors, so node $u$ keeps its initial colour $A_u$.
Initialize $black_u$, $white_u$ and $none_u$ to $0$ for every $1 \le u \le N$.

Start the DFS at node 1. When the DFS returns to a node $u$ after all children of $u$ have been processed (post-order), the three values of $u$ are computed from the children's values as follows:
$none_u = \infty$ if $A_u \neq B_u$; otherwise $none_u = \sum_{x} \min(1 + black_x,\ 1 + white_x,\ none_x)$, where the sum is over all children $x$ of node $u$.
For each child we can paint its subtree black, paint it white, or leave the child untouched; whichever option needs the fewest operations is added to $none_u$. If $A_u \neq B_u$, leaving $u$ untouched can never give the right colour, hence $\infty$.
$black_u = 1 + \sum_{x} white_x$ if $B_u = 0$, and $black_u = \sum_{x} black_x$ if $B_u = 1$, where the sums are over all children $x$ of node $u$.
If node $u$ must end up white ($B_u = 0$), the black subtree of $u$ must first be repainted white with one operation; after that every child's subtree is white, so we add $white_x$ for every child $x$. If $u$ must end up black ($B_u = 1$), $u$ is already correct and needs no operation, so we add $black_x$ for every child $x$.
$white_u = 1 + \sum_{x} black_x$ if $B_u = 1$, and $white_u = \sum_{x} white_x$ if $B_u = 0$, where the sums are over all children $x$ of node $u$.
If node $u$ must end up black ($B_u = 1$), the white subtree of $u$ must first be repainted black with one operation; after that every child's subtree is black, so we add $black_x$ for every child $x$. If $u$ must end up white ($B_u = 0$), $u$ is already correct and needs no operation, so we add $white_x$ for every child $x$.
The answer to the problem is $\min(none_1,\ 1 + black_1,\ 1 + white_1)$, where the $+1$ counts the operation that paints the whole tree.

TIME COMPLEXITY:

O(N) for each test case.

SOLUTION:

Setter's solution
#include <bits/stdc++.h>


using namespace std;

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;

// Order-statistics multiset over ll (less_equal allows duplicate keys).
// Not referenced by the solution code that follows.
#define ordered_set tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update>

#define ll long long int

#define ld long double
// forn: ascending loop i = x .. n-1; fornb: descending loop i = n .. x.
#define forn(i, x, n) for (ll i = x; i < n; i++)
#define fornb(i, n, x) for (ll i = n; i >=x; i--)
#define all(x) x.begin(), x.end()
#define pii pair<ll, ll>
#define MOD 1000000007
// Capacity of every per-vertex global array; main() touches rows 0..n, so it must exceed N.
#define MAX 300007
#define endl "\n" // REMOVE in interaction problems: this endl does not flush
#define debug cout << "K"
// visited, color, dist, parent and graph2 are not referenced by the solution code that follows.
vector<ll> visited(MAX), color(MAX), dist(MAX, -1);
// Adjacency lists of the tree, 0-indexed vertices.
vector<ll> graph[MAX];
vector<ll> parent(MAX);
vector<pii> graph2[MAX];


// Initial (A) and target (B) colours, 0-indexed; 0 = white, 1 = black.
vector<ll> A(MAX) , B(MAX);
// dp[v][s]: memoised result of dfs(v, s, ...); LLONG_MAX marks "impossible / not computed".
ll dp[MAX][3];
// visited2[v][s] is set to 1 when dfs starts a (feasible) computation of dp[v][s].
ll visited2[MAX][3];
// Memoised tree DP for the subtree of `node` with parent `p`.
// `state` is the colour situation of `node` (colours as in the statement: 0 = white, 1 = black):
//   state 0 / 1: the whole subtree of `node` is painted colour `state` by an
//                operation at `node`. Only allowed when state == B[node]
//                (otherwise LLONG_MAX is stored and returned). The returned
//                count INCLUDES that painting operation.
//   state 2:     no operation covers `node`; only allowed when A[node] == B[node].
// Results are cached in dp[node][state]; visited2[node][state] marks a cached entry.
// The root is called with p = 0, i.e. its own index; no neighbour of 0 is 0, so no child is skipped.
ll dfs(ll node , ll state  ,  ll p)
{

    
    ll val =A[node];
    if(state!=2)
    val =state;
    // Infeasible state: the node would not end up with its target colour.
    if((state!=2 && val!=B[node]) || (state==2 && A[node]!=B[node]))
    {
        dp[node][state] = LLONG_MAX;
        return LLONG_MAX;
    }
    visited2[node][state] =1;
    ll ans = 0;
    if(state==1)
    {
    // Subtree painted colour 1. A child with B == 1 is already correct: its
    // dp[child][1] counts a paint that is not needed here, hence the -1.
    // A child with B == 0 must be repainted; dp[child][0] already counts that.
    for(auto child : graph[node])
    {
        ll tmp = LLONG_MAX;
        if(child==p)
        continue;
            if(B[child]==0)
            {
                if(visited2[child][0]){}
                else
                dp[child][0] = dfs(child , 0 , node);
                tmp = min(tmp ,dp[child][0]);
            }
            if(B[child]==1)
            {
                if(visited2[child][1]==1){}
                else
                dp[child][1] = dfs(child , 1 , node);
                tmp = min(tmp ,dp[child][1] -1);
            }
        if(tmp!=LLONG_MAX)
        ans+=tmp;
        
    }
    // +1 for the operation that painted this subtree colour 1.
    dp[node][state] = ans+1;
    return dp[node][state];
    }
    else
    if(state==0){
    // Subtree painted colour 0: mirror image of the state-1 case above.
    for(auto child : graph[node])
    {
        ll tmp = LLONG_MAX;

        if(child==p)
        continue;
            if(B[child]==0)
            {
                if(visited2[child][0]==1){}
                else
                dp[child][0] = dfs(child , 0 , node);
                tmp = min(tmp ,dp[child][0]-1);
            }
            if(B[child]==1)
            {
                if(visited2[child][1]){}
                else
                dp[child][1] = dfs(child , 1 , node);
                tmp = min(tmp ,dp[child][1]);
            }

        if(tmp!=LLONG_MAX)
        ans+=tmp;
    }
    // +1 for the operation that painted this subtree colour 0.
    dp[node][state] = ans+1;
    return dp[node][state];
    }
    else
    {
        if(A[node]!=B[node])
        {
            dp[node][2] = LLONG_MAX;
            return dp[node][2];
        }
        // Node left untouched. Each child either needs a paint with B[child]
        // (forced when A[child] != B[child]), or, when A[child] == B[child],
        // takes the cheaper of staying untouched and being painted with B[child].
        for(auto child : graph[node])
        {
        ll tmp = LLONG_MAX;
        if(child==p){
        continue;
        }
            if(B[child]==0 && A[child]==1)
            {
                ll a = LLONG_MAX;
                if(visited2[child][0]==1){
                    a = dp[child][0];
                }
                else
                {
                dp[child][0] = dfs(child , 0 , node);
                a = dp[child][0];
                }
                tmp = min(tmp ,a);
            }
            if(B[child]==1 && A[child] == 0)
            {
                ll a = LLONG_MAX;
                if(visited2[child][1]==1){
                    a = dp[child][1];
                }
                else{
                dp[child][1] = dfs(child , 1 , node );
                a = dp[child][1];
                }
                tmp = min(tmp ,a);
            }
            if(B[child]==A[child])
            {
                ll a = LLONG_MAX;
                if(visited2[child][2]==1){
                }
                else
                dp[child][2] = dfs(child , 2 , node);
                tmp = min(tmp ,dp[child][2]);
                if(visited2[child][B[child]]==1){
                    a = dp[child][B[child]];
                }
                else{
                dp[child][B[child]] = dfs(child , B[child] , node);
                a = dp[child][B[child]];
                }
                tmp = min(tmp ,a);
            }
        if(tmp!=LLONG_MAX)
        ans+=tmp;
    }
    // No +1 here: nothing is painted at this node.
    return dp[node][state] = ans;
    }
}

int main()
{

    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    ll t=1;
    cin>>t;
    while (t--)
    {
        ll n;
        cin>>n;
        
        forn(i , 0 , n)
        cin>>A[i];
        forn(i, 0, n)
        cin>>B[i];  
        forn(i ,0,n-1)
        {
            ll a , b;
            cin>>a>>b;
            a--;
            b--;

            graph[a].push_back(b);
            graph[b].push_back(a);
        }
        forn(i ,0, n+1)
        forn(j  ,0 ,3)
        {
            dp[i][j] = LLONG_MAX;
            visited2[i][j] = 0;}

        dfs(0 ,0 ,0 ); dfs(0 , 1  , 0); dfs(0 , 2  ,0);
        cout<<min(dp[0][1],  min(dp[0][0] ,dp[0][2]))<<endl;
        forn(i, 0 ,n+1)
        graph[i].clear();
    }
}
Tester-1's Solution
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
// Strict input-checker reader: parses one integer that must be immediately
// followed by the character `endd`, and asserts that it lies in [l, r].
// Assertion failures on: any character other than an optional leading '-',
// digits and `endd`; an empty token; a repeated or misplaced '-'; leading
// zeros; "-0"; and more digits than a signed 64-bit value can hold.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        // Keep getchar()'s int result so EOF stays distinct from every byte value.
        int g=getchar();
        if(g=='-'){
            // A sign is allowed exactly once, and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);          // no leading zeros
            assert(fi!=0 || is_neg==false);   // no "-0"
            // Validate the length BEFORE accumulating, so x never overflows
            // (signed overflow is undefined behaviour).
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x*=10;
            x+=g-'0';
        } else if(g==endd){
            // Reject an empty token, e.g. two consecutive separators or a lone '-'.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
// Strict input-checker reader: collects characters up to (and consumes, but
// does not include) the terminator `endd`, then asserts that the collected
// length is in [l, r]. Reaching EOF before `endd` is an assertion failure.
std::string readString(int l,int r,char endd){
    std::string ret="";
    int cnt=0;
    while(true){
        // getchar() returns int: storing it in a char would make the EOF check
        // below impossible where char is unsigned (endless loop) and would
        // confuse byte 0xFF with EOF where char is signed.
        int g=getchar();
        assert(g!=EOF);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Convenience wrappers around the strict readers: the *Sp variants expect the
// token to be followed by a single space, the *Ln variants by a newline.
long long readIntSp(long long l,long long r){
    const char sep = ' ';
    return readInt(l, r, sep);
}
long long readIntLn(long long l,long long r){
    const char eol = '\n';
    return readInt(l, r, eol);
}
string readStringLn(int l,int r){
    const char eol = '\n';
    return readString(l, r, eol);
}
string readStringSp(int l,int r){
    const char sep = ' ';
    return readString(l, r, sep);
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
// Template limits; none of these is referenced by the code that follows
// (solve() accepts n up to 3e5 and main() accepts t up to 2e4).
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back
 
// Checker statistics; only sum_n is used (main() asserts the total of N).
int sum_n = 0, sum_m = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll mod = 1000000007;

using ii = pair<ll,ll>;
// Per-test data: initial colours a, target colours b (0-indexed),
// adjacency lists g, and dp[v][0..2] filled by dfs().
vector<int> a,b;
vector<vector<int> > g;
vector<vector<int> > dp;

// Vertex counter written by ch_tree() and checked in solve().
int cnt = 0;
// Adds to the global `cnt` one for every vertex reachable from `c` without
// stepping back to its parent `p`; solve() uses it to verify the input is a tree.
void ch_tree(int c, int p){
    ++cnt;
    const int deg = (int)g[c].size();
    for (int k = 0; k < deg; ++k) {
        const int nxt = g[c][k];
        if (nxt == p)
            continue;
        ch_tree(nxt, c);
    }
}

// Post-order DP over the subtree of c (parent p), vertices 0-indexed.
//   dp[c][0] / dp[c][1]: operations needed inside c's subtree when it has just
//                        been painted entirely colour 0 / colour 1 (that paint
//                        is not counted);
//   dp[c][2]:            operations when no paint at c or above covers c
//                        (1e7 serves as "impossible" when a[c] != b[c]).
void dfs(int c, int p){
    // Children first.
    for (int h : g[c])
        if (h != p)
            dfs(h, c);

    if (a[c] != b[c])
        dp[c][2] = 1e7;

    // Once c shows its final colour b[c], every child's subtree shows it too.
    const int want = b[c] ? 1 : 0;
    for (int h : g[c]) {
        if (h == p)
            continue;
        const int same = dp[h][want];
        dp[c][0] += same;
        dp[c][1] += same;
        // c untouched: leave the child untouched too, or paint its subtree once.
        dp[c][2] += min({dp[h][0] + 1, dp[h][1] + 1, dp[h][2]});
    }

    // Painted with the wrong colour: one more operation repaints c's subtree.
    dp[c][1 - want]++;
}

void solve(){
    int n = readIntLn(1,3e5);

    sum_n+=n;

    a.resize(n), b.resize(n);
    rep(i,n){
        if(i<n-1) a[i] = readIntSp(0,1);
        else a[i] = readIntLn(0,1);
    }

    rep(i,n){
        if(i<n-1) b[i] = readIntSp(0,1);
        else b[i] = readIntLn(0,1);
    }

    g.assign(n, vector<int>());
    dp.assign(n, vector<int>(3, 0));

    int x,y;

    rep(i,n-1){
        x = readIntSp(1,n);
        y = readIntLn(1,n);
        x--, y--;

        g[x].pb(y);
        g[y].pb(x);
    }

    cnt = 0;
    ch_tree(0,-1);
    assert(cnt==n);

    dfs(0,-1);
    dp[0][0]++;
    dp[0][1]++;
    cout<<min({dp[0][0], dp[0][1], dp[0][2]})<<'\n';


}


 
signed main()
{

    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;
    
    int t = 1;
    
    t = readIntLn(1,2e4);
    
    for(int i=1;i<=t;i++)
    {    
       solve();
    }
   
    assert(getchar() == -1);
    assert(sum_n<=3e5);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    //cerr<<"Sum of lengths : " << sum_n <<" "<<sum_m<<'\n';
    //cerr<<"Maximum answer : " << max_n <<'\n';
    // // cerr<<"Total operations : " << total_ops << '\n';
    // cerr<<"Answered yes : " << yess << '\n';
    // cerr<<"Answered no : " << nos << '\n';

    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

Tester-2's Solution
// Shorthand macros used throughout this solution.
// NOTE(review): they are defined before <bits/stdc++.h> is included, so they
// would also rewrite any identically named identifier inside the standard headers.
#define ll long long
#define dd long double
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define mp make_pair
#define mt make_tuple
// fo(i, n): loop i = 0 .. n-1 with a long long counter (compares against n as written, e.g. a size_t).
#define fo(i , n) for(ll i = 0 ; i < n ; i++)
#define tll tuple<ll ,ll , ll> 
#define pll pair<ll ,ll> 
#include<bits/stdc++.h>
/*#include<iomanip>   
#include<cmath>
#include<cstdio>
#include<utility>
#include<iostream>
#include<vector>
#include<string>
#include<algorithm>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<bitset>*/
// Template constants. In the code that follows only z (the 1e9+7 modulus used
// by ncr) and inf (the "impossible" cost used by dfs) are referenced.
dd pi = acos(-1) ;
ll z =  1000000007 ;
ll inf = 1e12 ;
ll p1 = 37 ;
ll p2 = 53 ;
ll mod1 =  202976689 ;
ll mod2 =  203034253 ;
// Factorial table read by ncr(); only 100 entries, and not filled by any visible code.
ll fact[100] ;
// gdp(a, b): a rounded down to a multiple of b (for a >= 0, b > 0).
long long gdp(long long a , long long b){return (a - (a%b)) ;}
// ld(a, b): least number >= a that is divisible by b (b > 0).
long long ld(long long a , long long b){
    // Negate directly instead of calling abs(): here (before
    // `using namespace std`) an unqualified abs may bind to C's int abs(int)
    // and truncate 64-bit values.
    if(a < 0) return -1*gdp(-a , b) ;
    if(a%b == 0) return a ;
    return (a + (b - a%b)) ;
}
// gd(a, b): greatest number <= a that is divisible by b (b > 0).
long long gd(long long a , long long b){
    if(a < 0) return(-1 * ld(-a , b)) ;
    return (a - (a%b)) ;
}
// gcd(a, b): greatest common divisor of |a| and |b|; gcd(0, 0) == 0.
// Signs are stripped first: with a negative argument the Euclidean recursion
// never reaches b == 0 (e.g. gcd(-4, 6) used to cycle forever).
long long gcd(long long a , long long b){
    if(a < 0) a = -a ;
    if(b < 0) b = -b ;
    if(b > a) return gcd(b , a) ;
    if(b == 0) return a ;
    return gcd(b , a%b) ;
}
// e_gcd(a, b, x, y): extended Euclid for a, b >= 0. Returns g = gcd(a, b) and
// stores Bezout coefficients such that a*x + b*y == g.
long long e_gcd(long long a , long long b , long long &x , long long &y){
    if(b > a) return e_gcd(b , a , y , x) ;   // swap roles (and outputs) so a >= b
    if(b == 0){x = 1 ; y = 0 ; return a ;}
    long long x1 , y1 , g;
    g = e_gcd(b , a%b , x1 , y1) ;
    x = y1 ;
    y = (x1 - ((a/b) * y1)) ;
    return g ;
}
// power(a, b, p): a^b mod p by recursive squaring, for b >= 0 and p > 0.
// Intermediate products are below p*p, so p must stay under ~3.03e9 to avoid
// 64-bit overflow.
long long power(long long a ,long long b , long long p){
    // Reduce the base first: multiplying an unreduced a (e.g. 1e12) by a
    // residue up to p-1 overflows long long.
    a %= p ;
    if(b == 0) return 1 ;
    long long c = power(a , b/2 , p) ;
    if(b%2 == 0) return ((c*c)%p) ;
    else return ((((c*c)%p)*a)%p) ;
}
// inverse(a, n): modular inverse via Fermat's little theorem; valid only when
// n is prime and a is not a multiple of n.
long long inverse(long long a ,long long n){return power(a , n-2 , n) ;}
// Non-template max/min for long long (declared before `using namespace std`).
ll max(ll a , ll b){
    return (a > b) ? a : b ;
}
ll min(ll a , ll b){
    return (a < b) ? a : b ;
}
// Child indices of node i in a 0-indexed binary heap / tree array.
ll left(ll i){
    return 2*i + 1 ;
}
ll right(ll i){
    return 2*i + 2 ;
}
// ncr(n, r): C(n, r) mod z using Fermat inverses of fact[]; returns 0 when
// r > n or either argument is negative.
// NOTE(review): fact has only 100 entries, so n >= 100 indexes out of bounds; and
// no visible code fills fact, so as written the result is always 0.
ll ncr(ll n , ll r){if(n < r|| (n < 0) || (r < 0)) return 0 ; return ((((fact[n] * inverse(fact[r] , z)) % z) * inverse(fact[n-r] , z)) % z);}
// Exchanges the values of two long long variables in place.
void swap(ll&a , ll&b){
    const ll held = b ;
    b = a ;
    a = held ;
}
//ios_base::sync_with_stdio(0);
//cin.tie(0); cout.tie(0);
using namespace std ;
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 
// Order-statistics set of ints (policy-based tree); not used by this solution.
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val)  no. of elements strictly less than val
// s.find_by_order(i)  iterator to ith element (0 indexed)
//__builtin_popcount(n) -> returns number of set bits in n
// Random generator seeded from the steady clock (seed also kept in `seed`); not used by this solution.
ll seed;
mt19937 rnd(seed=chrono::steady_clock::now().time_since_epoch().count()); // include bits

// Cheapest way to finish a child's subtree when its parent is left untouched:
// leave the child untouched as well (v[2]), or spend one operation painting the
// child's whole subtree with its final colour `final` (1 + v[final]).
ll get_ans(vector<ll> &v, ll final)
{
    const ll untouched = v[2] ;
    const ll repainted = 1 + v[final] ;
    return min(untouched , repainted) ;
}


// Post-order DP over the subtree of u (parent p), vertices 0-indexed.
//   dp[u][0] / dp[u][1]: moves needed inside u's subtree after it has been
//                        painted entirely colour 0 / colour 1 (paint not counted);
//   dp[u][2]:            moves when no operation is done on u (inf if a[u] != b[u]).
// Every value is capped at inf.
void dfs(vector<ll> adj[] , vector<vector<ll> > &dp , vector<ll> &a, vector<ll> &b, ll u, ll p)
{
    dp[u] = {0 , 0 , 0} ;
    if(a[u] != b[u])
        dp[u][2] = inf ;   // u would keep a wrong initial colour

    for(const ll v : adj[u])
    {
        if(v == p)
            continue ;
        dfs(adj , dp , a , b , v , u) ;
        dp[u][0] += dp[v][0] ;
        dp[u][1] += dp[v][1] ;
        dp[u][2] += get_ans(dp[v] , b[v]) ;
    }

    for(ll &val : dp[u])
        val = min(val , inf) ;

    // Painted with the colour u must not end with: one more move repaints u's
    // subtree to b[u], which leads to the dp[u][b[u]] situation. For a leaf
    // dp[u][b[u]] is 0, so this yields 1 there as well.
    dp[u][1-b[u]] = 1 + dp[u][b[u]] ;
}


// Reads one test case (n, initial colours a, target colours b, n-1 edges),
// runs the tree DP from vertex 0 and prints the minimum number of operations.
void solve()
{
    ll n ;
    cin >> n ;
    vector<ll> a(n) , b(n) ;
    fo(i , n)
        cin >> a[i] ;
    fo(i , n)
        cin >> b[i] ;

    // Adjacency lists. A std::vector replaces the original `vector<ll> adj[n]`,
    // a variable-length array, which is a GCC extension and not standard C++.
    // adj.data() still matches dfs's `vector<ll> adj[]` parameter.
    vector<vector<ll> > adj(n) ;
    fo(i , n-1)
    {
        ll u , v ;
        cin >> u >> v ;
        u-- ; v-- ;
        adj[u].pub(v) ;
        adj[v].pub(u) ;
    }

    vector<vector<ll> > dp(n , vector<ll> (3)) ;
    // dp[i][0] represents min number of moves if the complete subtree of i is white
    // dp[i][1] represents min number of moves if the complete subtree of i is black
    // dp[i][2] represents min number of moves if no operation is done on i^th node

    dfs(adj.data() , dp , a , b , 0 , -1) ;

    ll ans = get_ans(dp[0] , b[0]) ;
    // '\n' rather than endl: avoids flushing the output once per test case.
    cout << ans << '\n' ;
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    #ifndef ONLINE_JUDGE
    freopen("inputf.txt" , "r" , stdin) ;
    freopen("outputf.txt" , "w" , stdout) ;
    freopen("errorf.txt" , "w" , stderr) ;
    #endif
    
    ll t = 1;
    cin >> t ;
   
    while(t--)
    {
        solve() ;
    }
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
 
    return 0;
}
Editorialist's solution
#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define pb push_back
#define all(_obj) _obj.begin(), _obj.end()
#define F first
#define S second
#define pll pair<ll, ll>
#define vll vector<ll>
// NOTE: INF is not referenced below; dfs uses 1e9 as its "impossible" value.
ll INF = 1e18;
// N: capacity of the 1-indexed per-vertex arrays (presumably sized for N <= 3e5 plus slack).
const int N = 3e5 + 11, mod = 1e9 + 7;
// Non-template max/min for long long; min is what dfs and sol call unqualified.
ll max(ll a, ll b)
{
    if (a > b)
        return a;
    return b;
}
ll min(ll a, ll b)
{
    if (a > b)
        return b;
    return a;
}
// Clock-seeded random generator; not used by this solution.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Per-test arrays, 1-indexed by vertex: initial colours a, target colours b,
// the three DP values black/white/none, and adjacency lists v.
int a[N], b[N];
long long black[N], white[N], none[N];
vll v[N];
// Post-order DP for the subtree of u (parent p), vertices 1-indexed:
//   black[u] / white[u]: operations needed inside u's subtree after it was painted
//                        entirely black / white (that paint not counted);
//   none[u]:             operations when u is left untouched (1e9 if impossible).
void dfs(int u, int p)
{
    const bool target_black = b[u] != 0;

    // If the subtree was painted the colour u must NOT end with, u needs one repaint.
    if (target_black)
        ++white[u];
    else
        ++black[u];

    for (auto x : v[u])
    {
        if (x == p)
            continue;
        dfs(x, u);

        // u untouched: leave the child untouched, or paint its subtree once.
        none[u] += min(min(none[x], 1 + white[x]), 1 + black[x]);

        // Once u shows its final colour, the child's subtree shows that colour too.
        const long long child_cost = target_black ? black[x] : white[x];
        black[u] += child_cost;
        white[u] += child_cost;
    }

    // u can only stay untouched if it already has its target colour.
    if (a[u] != b[u])
        none[u] = 1e9;
}
void sol(void)
{
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i], v[i].clear(), black[i] = white[i] = none[i] = 0;
    for (int i = 1; i <= n; i++)
        cin >> b[i];
    for (int i = 1; i <= n - 1; i++)
    {
        int x, y;
        cin >> x >> y;
        v[x].pb(y);
        v[y].pb(x);
    }
    dfs(1, -1);
    cout << min(min(1 + black[1], none[1]), 1 + white[1]) << '\n';
    return;
}
// Entry point: fast stream I/O, then one sol() call per test case.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);

    int cases = 1;
    cin >> cases;
    while (cases--)
    {
        sol();
    }
}

1 Like

https://www.codechef.com/viewsolution/65935551
can anyone help me out what i have done wrong or missed something

Fails for this:

1
1
0
1
2 Likes

Gotcha. I was not checking for the root :smiling_face_with_tear: Thanks

Link to editorial is not added on this problem page.

2 Likes

black_ u+=((B_u) ?1+ ∑white_x ​ : ∑black_x) over all children x of node u
If in the end we need the colour of node u as white we need to first colour the whole subtree of u white otherwise we can just leave the node u untouched and calculate the answer for black_x for all children x of node u and add them to black_u

Please explain this, @devendra7700, @cubefreak777: it's not clear to me. It would be a great help.

.

3 Likes

In this case we are trying to calculate the answer for minimum number of operations needed for subtree of u to look exactly the same as needed in the final tree assuming that the whole subtree was first painted black by some operation. Now since every node is black in the subtree of u, if we need u to be white in the final tree we need to first colour the subtree of u white using an operation and then add the answer for all child nodes assuming they are all painted white. If we need u to be black we can just skip any operation on u as it is already black and add the answer for children of u assuming they are all painted black.

3 Likes

Now it's crystal clear, thanks a lot bro :grinning:

What’s wrong with my solution? Anyone??

https://www.codechef.com/viewsolution/65925146

@devendra7700 @cubefreak777

Could someone please correct me where my approach will fail.

My approach:-
For all the nodes where a[u] != b[u] we need to change the color once, except if this node belongs to a subtree where all the nodes are to be changed to the same color, and the head of this subtree also has the same b[i].

https://www.codechef.com/viewsolution/65969072

Try this:

1
4
0 1 1 1
0 0 1 0
2 1
1 4
3 4

Correct output is 2

1 Like

Try this:

1
5
1 0 0 0 0
1 1 0 1 1
2 1
1 4
4 3
3 5

Correct output is 3

1 Like

Does Bu = 1 represent white or black @devendra7700

1 Like

image

Here, if B_u = 1, which should mean the final state is black, shouldn’t we be summing the answer for when all children are turned black? Why are we doing “1+∑white_x” if B_u = 1?

P.S.: Sorry if I am being silly. I am really confused.

1 Like

Thanks, updated

1 Like

my solution
#include <bits/stdc++.h>

#define ull unsigned long long int

#define ui unsigned int

#define ll long long int

// f(i, n): loop i = 0 .. n-1 with a long long counter.
#define f(i, n) for (ll i = 0; i < n; i++)

// The usual modulus 1e9 + 7. (The previous `10e9 + 7` was 1e10 + 7:
// `10e9` means 10 * 10^9.)
const ll m = 1000000007;

// const ll N = 10e5 + 5;   // NOTE: 10e5 is 1e6, not 1e5

using namespace std;

// Computes a^b mod m by square-and-multiply over the bits of b; the base is
// reduced modulo m before the loop.
ll binpow(ll a, ll b, ll m)
{
    ll base = a % m;
    ll acc = 1;
    for (; b > 0; b >>= 1)
    {
        if (b & 1)
            acc = acc * base % m;
        base = base * base % m;
    }
    return acc;
}

// Post-order tree DP (vertices 0-indexed, the first call's `curr` is the root).
// For every vertex v it fills:
//   dpb[v] - operations needed inside v's subtree if the whole subtree was just
//            painted black (colour 1); that painting operation is not counted;
//   dpw[v] - the same if it was just painted white (colour 0);
//   dp[v]  - the minimum operations for v's subtree when no ancestor operation
//            covers it (v itself may still be painted).
// arr1 holds the initial colours and arr2 the target colours.
// The template arguments were restored: the scraped signature `vector edges[]`,
// `vector &visited`, ... does not compile.
void dfs(std::vector<long long> edges[], std::vector<bool> &visited, std::vector<long long> &dp,
         std::vector<long long> &dpw, std::vector<long long> &dpb, int curr,
         std::vector<long long> &arr1, std::vector<long long> &arr2)
{
    if (visited[curr])
        return;
    visited[curr] = true;

    // Sum over children only. In a tree, a neighbour that is still unvisited when
    // we reach it is a child; the parent is already visited.
    long long black = 0;     // sum of dpb over children
    long long white = 0;     // sum of dpw over children
    long long untouched = 0; // sum of dp over children
    for (long long nb : edges[curr])
    {
        if (visited[nb])
            continue;
        dfs(edges, visited, dp, dpw, dpb, static_cast<int>(nb), arr1, arr2);
        black += dpb[nb];
        white += dpw[nb];
        untouched += dp[nb];
    }

    if (arr2[curr] == 1)
    {
        dpb[curr] += black;      // already black: children stay in the black case
        dpw[curr] += 1 + black;  // repaint the subtree black first
    }
    else
    {
        dpb[curr] += 1 + white;  // repaint the subtree white first
        dpw[curr] += white;      // already white: children stay in the white case
    }

    if (arr1[curr] == arr2[curr])
    {
        // Leave curr untouched (each child independently optimal), or paint it once.
        dp[curr] += untouched;
        dp[curr] = std::min(dp[curr], 1 + std::min(dpw[curr], dpb[curr]));
    }
    else if (arr2[curr] == 1)
    {
        dp[curr] += 1 + black;   // must paint curr's subtree black
    }
    else
    {
        dp[curr] += 1 + white;   // must paint curr's subtree white
    }
}

// Entry point: for each test case read n, the initial and target colours and
// the n-1 edges, run the DP from vertex 0 and print dp[0].
int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    ll t;
    cin >> t;
    while (t--)
    {
        ll n;
        cin >> n;
        vector<ll> arr1(n), arr2(n);
        // Adjacency lists. A std::vector replaces the non-standard VLA
        // `vector<ll> edges[n]`; edges.data() still matches dfs's `edges[]` parameter.
        vector<vector<ll>> edges(n);
        f(i, n) cin >> arr1[i];
        f(i, n) cin >> arr2[i];
        f(i, n - 1)
        {
            int u, v;
            cin >> u >> v;
            u--;
            v--;
            edges[u].push_back(v);
            edges[v].push_back(u);
        }
        vector<ll> dp(n, 0);
        vector<ll> dpw(n, 0);
        vector<ll> dpb(n, 0);
        vector<bool> visited(n, 0);
        dfs(edges.data(), visited, dp, dpw, dpb, 0, arr1, arr2);
        cout << dp[0] << '\n';
    }
    return 0;
}

https://www.codechef.com/viewsolution/66028146

I followed editorial,
still getting TLE.
Can anyone help?

 
 void dfs(int v, vi adj[],vl & black, vl & white, vl & none, vl & vis, int A[],int  B[]){
     
     vis[v]=1;
     ll sumbx=0;
     ll sumwx=0;
     ll sumnx=0;
     ll summinch=0; 
     for(auto i : adj[v]){
         if(vis[i]==0){
             dfs(i,adj,black,white,none,vis,A,B);
             sumbx+=black[i];
             sumwx+=white[i];
             sumnx+=none[i];
             summinch+=min({1+black[i],1+white[i],none[i]});
         }   
     }
     none[v]=(A[v]!=B[v])? (int)1e9 : summinch;  // if not equal then inf since no operation , else min of each child added.
     black[v]= (B[v]==0)?  1+sumwx : sumbx ;//if reqd to be white then 1 + sum of all child white, else sum of all child black;
     white[v]= (B[v]==1)?  1 +sumbx : sumwx ;
 }
 
 
int main(){
    int t;
    cin>>t;
    while(t--){
      int n;
      cin>>n;
      int A[n],B[n];
      for (int i = 0; i < n; i++)
      {
          cin>>A[i];
      }
      for (int i = 0; i < n; i++)
      {
          cin>>B[i];
      }
      vi adj[n];
      for (int i = 0; i < n-1; i++)
      {
          int x,y;
          cin>>x>>y; x--;y--;
          adj[x].pb(y);
          adj[y].pb(x);  
      }

      vl black(n,0);
      vl white(n,0);
      vl none(n,0);
      vl vis(n,0);
      dfs(0,adj,black,white, none,vis,A,B);
      
      cout<<min({none[0],1+black[0],1+white[0]})<<"\n";
      
    }
 
  return 0;
  }

Simple solution easy to understand.

Can you please tell me where am I going wrong
https://www.codechef.com/viewsolution/66050471