PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Setter: Arun Sharma
Tester: Abhinav Sharma, Lavish Gupta
Editorialist: Devendra Singh
DIFFICULTY:
2601
PREREQUISITES:
Depth first search, Dynamic programming, Trees
PROBLEM:
Arun has a rooted tree of N vertices rooted at vertex 1. Each vertex can either be coloured black or white.
Initially, the vertices are coloured A_1, A_2, \ldots A_N, where A_i \in \{0, 1\} denotes the colour of the i-th vertex (here 0 represents white and 1 represents black). He wants to perform some operations to change the colouring of the vertices to B_1, B_2, \ldots B_N respectively.
Arun can perform the following operation any number of times. In one operation, he can choose any subtree and either paint all its vertices white or all its vertices black.
Help Arun find the minimum number of operations required to change the colouring of the vertices to B_1, B_2, \ldots B_N respectively.
EXPLANATION:
This problem can be solved using dynamic programming as it has optimal substructure and overlapping subproblems.
The tree is rooted at node 1.
Let black_u denote the minimum number of further operations needed to make A_i=B_i for every node i in the subtree of node u, given that the whole subtree of u has just been painted black (the painting operation itself is not counted in black_u).
Let white_u denote the minimum number of further operations needed to make A_i=B_i for every node i in the subtree of node u, given that the whole subtree of u has just been painted white (the painting operation itself is not counted in white_u).
Let none_u denote the minimum number of operations needed to make A_i=B_i for every node i in the subtree of node u when no operation covers node u itself, so that u keeps its initial colour A_u.
Initialize black_u, white_u and none_u to 0 for every u with 1 \leq u \leq N.
Start the dfs at node 1. Let us suppose we are at some node u during the dfs traversal then for this node we have three values :
none_u = INF if A_u \neq B_u; otherwise none_u = \sum \min(1+black_x, 1+white_x, none_x), where the sum runs over all children x of node u.
For every child x we may paint its subtree black, paint its subtree white, or leave x untouched; whichever of the three options needs the fewest operations is added to none_u.
black_u = 1+\sum white_x if B_u = 0, and black_u = \sum black_x otherwise, where the sums run over all children x of node u.
Here the subtree of u is already black. If node u must end up white, we first have to paint the whole subtree of u white (one operation) and then add white_x for every child x of u; otherwise node u already has its final colour, so we leave it untouched and add black_x for every child x of u.
white_u = 1+\sum black_x if B_u = 1, and white_u = \sum white_x otherwise, where the sums run over all children x of node u.
Here the subtree of u is already white. If node u must end up black, we first have to paint the whole subtree of u black (one operation) and then add black_x for every child x of u; otherwise node u already has its final colour, so we leave it untouched and add white_x for every child x of u.
The answer to the problem is min(none_1,1+black_1,1+white_1);
TIME COMPLEXITY:
O(N) for each test case.
SOLUTION:
Setter's solution
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
#define ordered_set tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update>
#define ll long long int
#define ld long double
#define forn(i, x, n) for (ll i = x; i < n; i++)
#define fornb(i, n, x) for (ll i = n; i >=x; i--)
#define all(x) x.begin(), x.end()
#define pii pair<ll, ll>
#define MOD 1000000007
#define MAX 300007
#define endl "\n" // REMOVE in lleraction problem
#define debug cout << "K"
// Global scratch storage sized by MAX. NOTE(review): visited, color, dist,
// parent and graph2 do not appear to be referenced by the dfs/main of this
// solution - confirm before removing.
vector<ll> visited(MAX), color(MAX), dist(MAX, -1);
// graph[v]: adjacency list of vertex v (0-indexed); filled and cleared per test in main.
vector<ll> graph[MAX];
vector<ll> parent(MAX);
vector<pii> graph2[MAX];
// A[v] / B[v]: initial and target colour of vertex v, read in main.
vector<ll> A(MAX) , B(MAX);
// dp[v][s]: result of dfs(v, s, ...); reset to LLONG_MAX per test in main.
ll dp[MAX][3];
// visited2[v][s]: set to 1 by dfs once (v, s) has been entered with a feasible state.
ll visited2[MAX][3];
// state selects what happens at `node`: 0 -> its subtree is painted colour 0,
// 1 -> its subtree is painted colour 1, 2 -> no operation is applied at node.
// States 0/1 are compared directly against B[node]; the problem statement
// defines colour 0 as white and colour 1 as black.
//
// Computes dp[node][state] for the subtree of `node` and returns it. `p` is
// the parent of `node` and is skipped while walking the adjacency list.
// LLONG_MAX marks a state that cannot leave `node` with colour B[node].
// For states 0/1 the stored value includes the paint operation on `node`.
ll dfs(ll node , ll state , ll p)
{
// Colour `node` ends up with: the painted colour, or A[node] when untouched.
ll val =A[node];
if(state!=2)
val =state;
// Infeasible state: node would not end with its target colour.
if((state!=2 && val!=B[node]) || (state==2 && A[node]!=B[node]))
{
dp[node][state] = LLONG_MAX;
return LLONG_MAX;
}
// Mark (node, state) so that later callers reuse dp[node][state].
visited2[node][state] =1;
ll ans = 0;
if(state==1)
{
// The subtree of node has been painted colour 1.
for(auto child : graph[node])
{
ll tmp = LLONG_MAX;
if(child==p)
continue;
if(B[child]==0)
{
// Child needs colour 0, so it is repainted; dp[child][0] counts that paint.
if(visited2[child][0]){}
else
dp[child][0] = dfs(child , 0 , node);
tmp = min(tmp ,dp[child][0]);
}
if(B[child]==1)
{
// Child already received colour 1 from node's paint, so the paint
// operation counted inside dp[child][1] is subtracted again.
if(visited2[child][1]==1){}
else
dp[child][1] = dfs(child , 1 , node);
tmp = min(tmp ,dp[child][1] -1);
}
// tmp stays LLONG_MAX only when B[child] is neither 0 nor 1.
if(tmp!=LLONG_MAX)
ans+=tmp;
}
// +1 for the paint operation applied at node itself.
dp[node][state] = ans+1;
return dp[node][state];
}
else
if(state==0){
// Mirror image of the state==1 branch: the subtree has been painted colour 0.
for(auto child : graph[node])
{
ll tmp = LLONG_MAX;
if(child==p)
continue;
if(B[child]==0)
{
// Child already has colour 0 from node's paint; drop its own paint.
if(visited2[child][0]==1){}
else
dp[child][0] = dfs(child , 0 , node);
tmp = min(tmp ,dp[child][0]-1);
}
if(B[child]==1)
{
// Child needs colour 1 and must be repainted.
if(visited2[child][1]){}
else
dp[child][1] = dfs(child , 1 , node);
tmp = min(tmp ,dp[child][1]);
}
if(tmp!=LLONG_MAX)
ans+=tmp;
}
// +1 for the paint operation applied at node itself.
dp[node][state] = ans+1;
return dp[node][state];
}
else
{
// state == 2: no operation at node. The guard at the top of the function
// already returned when A[node] != B[node], so this check cannot fire here.
if(A[node]!=B[node])
{
dp[node][2] = LLONG_MAX;
return dp[node][2];
}
for(auto child : graph[node])
{
ll tmp = LLONG_MAX;
if(child==p){
continue;
}
// Child differs from its target (1 -> 0): it has to be painted colour 0.
if(B[child]==0 && A[child]==1)
{
ll a = LLONG_MAX;
if(visited2[child][0]==1){
a = dp[child][0];
}
else
{
dp[child][0] = dfs(child , 0 , node);
a = dp[child][0];
}
tmp = min(tmp ,a);
}
// Child differs from its target (0 -> 1): it has to be painted colour 1.
if(B[child]==1 && A[child] == 0)
{
ll a = LLONG_MAX;
if(visited2[child][1]==1){
a = dp[child][1];
}
else{
dp[child][1] = dfs(child , 1 , node );
a = dp[child][1];
}
tmp = min(tmp ,a);
}
// Child already matches its target: take the cheaper of leaving it
// untouched (state 2) and painting it with its own target colour.
if(B[child]==A[child])
{
ll a = LLONG_MAX;
if(visited2[child][2]==1){
}
else
dp[child][2] = dfs(child , 2 , node);
tmp = min(tmp ,dp[child][2]);
if(visited2[child][B[child]]==1){
a = dp[child][B[child]];
}
else{
dp[child][B[child]] = dfs(child , B[child] , node);
a = dp[child][B[child]];
}
tmp = min(tmp ,a);
}
if(tmp!=LLONG_MAX)
ans+=tmp;
}
// No operation at node, so nothing is added on top of the children.
return dp[node][state] = ans;
}
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
ll t=1;
cin>>t;
while (t--)
{
ll n;
cin>>n;
forn(i , 0 , n)
cin>>A[i];
forn(i, 0, n)
cin>>B[i];
forn(i ,0,n-1)
{
ll a , b;
cin>>a>>b;
a--;
b--;
graph[a].push_back(b);
graph[b].push_back(a);
}
forn(i ,0, n+1)
forn(j ,0 ,3)
{
dp[i][j] = LLONG_MAX;
visited2[i][j] = 0;}
dfs(0 ,0 ,0 ); dfs(0 , 1 , 0); dfs(0 , 2 ,0);
cout<<min(dp[0][1], min(dp[0][0] ,dp[0][2]))<<endl;
forn(i, 0 ,n+1)
graph[i].clear();
}
}
Tester-1's Solution
#include <bits/stdc++.h>
using namespace std;
/*
------------------------Input Checker----------------------------------
*/
// Reads one decimal integer from stdin that must be terminated by `endd`,
// and checks that it lies in [l, r].
//
// Strict format, enforced with assert: optional single leading '-', at least
// one digit, no leading zeros, no "-0", at most 19 digits (19 only when the
// first digit is 0 or 1), and no other characters before the delimiter.
// Returns the parsed value; prints "l r x" to stderr and asserts when the
// value is out of range.
long long readInt(long long l,long long r,char endd){
    long long x = 0;
    int cnt = 0;          // number of digits consumed so far
    int fi = -1;          // first digit, -1 until a digit has been read
    bool is_neg = false;
    while(true){
        // getchar() returns int; keeping it as int lets EOF fall through to
        // the "unexpected character" branch regardless of char signedness.
        const int g = getchar();
        if(g=='-'){
            // A sign is only legal once and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg = true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi = g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            // Validate the length before accumulating so x never overflows.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x = x*10 + (g-'0');
        } else if(g==static_cast<unsigned char>(endd)){
            // An empty token (or a lone '-') is not a number.
            assert(cnt>0);
            if(is_neg){
                x = -x;
            }
            if(!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (and consuming) the delimiter `endd` and
// returns them without the delimiter. Asserts that the input does not end
// before the delimiter and that the length lies in [l, r].
std::string readString(int l,int r,char endd){
    std::string ret;
    int cnt = 0;
    while(true){
        // Keep the int returned by getchar(): a char cannot tell EOF apart
        // from byte 0xFF and never equals -1 where char is unsigned.
        const int g = getchar();
        assert(g != EOF);
        if(g == static_cast<unsigned char>(endd)){
            break;
        }
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads an integer in [l, r] that is terminated by a single space.
long long readIntSp(long long l,long long r){
    const char delimiter = ' ';
    return readInt(l, r, delimiter);
}
// Reads an integer in [l, r] that is terminated by a newline.
long long readIntLn(long long l,long long r){
    const char delimiter = '\n';
    return readInt(l, r, delimiter);
}
// Reads a string whose length lies in [l, r] and that is terminated by a newline.
string readStringLn(int l,int r){
    const char delimiter = '\n';
    return readString(l, r, delimiter);
}
// Reads a string whose length lies in [l, r] and that is terminated by a single space.
string readStringSp(int l,int r){
    const char delimiter = ' ';
    return readString(l, r, delimiter);
}
/*
------------------------Main code starts here----------------------------------
*/
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back
int sum_n = 0, sum_m = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll mod = 1000000007;
using ii = pair<ll,ll>;
// a[v] / b[v]: initial and target colour of vertex v (0-indexed), resized per test in solve().
vector<int> a,b;
// g[v]: adjacency list of vertex v, rebuilt per test in solve().
vector<vector<int> > g;
// dp[v][0..2]: three DP values per vertex, zero-initialised in solve() and filled by dfs().
vector<vector<int> > dp;
// Number of vertices visited by ch_tree(); compared with n in solve().
int cnt = 0;
// Counts the vertices reachable from c without stepping back to p, adding
// them to the global counter cnt.
void ch_tree(int c, int p){
    ++cnt;
    for (const int next : g[c]) {
        if (next == p)
            continue;
        ch_tree(next, c);
    }
}
// Post-order DP over the subtree of c (p is the parent of c, -1 for the root).
// dp[c][k] for k in {0, 1} accumulates the cost inside the subtree when it
// arrives already painted with colour k; dp[c][2] is the cost when no
// operation is applied at c. Expects dp[c] to be zero on entry.
void dfs(int c, int p){
    // Without an operation c keeps a[c], so a mismatch gets a large penalty.
    if (a[c] != b[c])
        dp[c][2] = 1e7;

    const int want = b[c] ? 1 : 0;   // colour c must end with
    for (const int child : g[c]) {
        if (child == p)
            continue;
        dfs(child, c);
        // Either way c ends with colour `want`, which is what the child inherits.
        const int inherited = dp[child][want];
        dp[c][0] += inherited;
        dp[c][1] += inherited;
        dp[c][2] += min({dp[child][0] + 1, dp[child][1] + 1, dp[child][2]});
    }

    // Arriving with the other colour costs one repaint of the subtree of c.
    dp[c][1 - want]++;
}
void solve(){
int n = readIntLn(1,3e5);
sum_n+=n;
a.resize(n), b.resize(n);
rep(i,n){
if(i<n-1) a[i] = readIntSp(0,1);
else a[i] = readIntLn(0,1);
}
rep(i,n){
if(i<n-1) b[i] = readIntSp(0,1);
else b[i] = readIntLn(0,1);
}
g.assign(n, vector<int>());
dp.assign(n, vector<int>(3, 0));
int x,y;
rep(i,n-1){
x = readIntSp(1,n);
y = readIntLn(1,n);
x--, y--;
g[x].pb(y);
g[y].pb(x);
}
cnt = 0;
ch_tree(0,-1);
assert(cnt==n);
dfs(0,-1);
dp[0][0]++;
dp[0][1]++;
cout<<min({dp[0][0], dp[0][1], dp[0][2]})<<'\n';
}
// Validator/solver entry point: reads the number of tests, solves each one,
// then asserts that the input is exhausted and that the total n stays within
// the limit before reporting statistics on stderr.
signed main()
{
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
#endif
    fast;
    const int t = readIntLn(1,2e4);
    for (int done = 0; done < t; ++done)
    {
        solve();
    }
    // Nothing may follow the last test case.
    assert(getchar() == -1);
    assert(sum_n<=3e5);

    cerr << "SUCCESS\n";
    cerr << "Tests : " << t << '\n';
    //cerr<<"Sum of lengths : " << sum_n <<" "<<sum_m<<'\n';
    //cerr<<"Maximum answer : " << max_n <<'\n';
    // // cerr<<"Total operations : " << total_ops << '\n';
    // cerr<<"Answered yes : " << yess << '\n';
    // cerr<<"Answered no : " << nos << '\n';
    cerr << "Time : " << 1000 * static_cast<double>(clock()) / static_cast<double>(CLOCKS_PER_SEC) << "ms\n";
}
Tester-2's Solution
#define ll long long
#define dd long double
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define mp make_pair
#define mt make_tuple
#define fo(i , n) for(ll i = 0 ; i < n ; i++)
#define tll tuple<ll ,ll , ll>
#define pll pair<ll ,ll>
#include<bits/stdc++.h>
/*#include<iomanip>
#include<cmath>
#include<cstdio>
#include<utility>
#include<iostream>
#include<vector>
#include<string>
#include<algorithm>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<bitset>*/
dd pi = acos(-1) ;
// Modulus passed to inverse() by ncr().
ll z = 1000000007 ;
// Sentinel cost used by dfs() for states that cannot be realised.
ll inf = 1e12 ;
ll p1 = 37 ;
ll p2 = 53 ;
ll mod1 = 202976689 ;
ll mod2 = 203034253 ;
// Factorial table read by ncr(). NOTE(review): no code in this solution
// appears to fill it - confirm before relying on ncr().
ll fact[100] ;
// Rounds a towards zero to a multiple of b (for a >= 0 this is the greatest
// multiple of b that does not exceed a).
ll gdp(ll a , ll b)
{
    const ll quotient = a / b ;
    return quotient * b ;
}
// least number >= a divisible by b
ll ld(ll a , ll b)
{
    if(a < 0)
        return -gdp(abs(a) , b) ;
    const ll rem = a % b ;
    return (rem == 0) ? a : (a + (b - rem)) ;
}
// greatest number <= a divisible by b
ll gd(ll a , ll b)
{
    if(a >= 0)
        return a - (a % b) ;
    // Mirror the negative case through ld() on the absolute value.
    return -ld(abs(a) , b) ;
}
// Euclid's algorithm: greatest common divisor of two non-negative numbers.
ll gcd(ll a , ll b)
{
    while(true)
    {
        if(b > a)
        {
            // Keep the larger value in a.
            const ll tmp = a ;
            a = b ;
            b = tmp ;
            continue ;
        }
        if(b == 0)
            return a ;
        const ll rem = a % b ;
        a = b ;
        b = rem ;
    }
}
// Extended Euclid: returns g = gcd(a, b) and stores in x, y a pair of
// coefficients with a*x + b*y == g.
ll e_gcd(ll a , ll b , ll &x , ll &y)
{
    // Work with a >= b; swapping the arguments swaps the coefficients too.
    if(b > a)
        return e_gcd(b , a , y , x) ;
    if(b == 0)
    {
        x = 1 ;
        y = 0 ;
        return a ;
    }
    ll sub_x = 0 , sub_y = 0 ;
    const ll g = e_gcd(b , a % b , sub_x , sub_y) ;
    x = sub_y ;
    y = sub_x - ((a / b) * sub_y) ;
    return g ;
}
// Computes a^b modulo p by repeated squaring.
// The base is reduced modulo p first so that every product stays below p*p
// (safe for p up to about 3e9), and a zero exponent yields 1 mod p, i.e. 0
// when p == 1.
ll power(ll a ,ll b , ll p)
{
    if(b == 0)
        return (p == 1) ? 0 : 1 ;
    a %= p ;
    const ll half = power(a , b/2 , p) ;
    const ll squared = (half * half) % p ;
    if(b%2 == 0)
        return squared ;
    return (squared * a) % p ;
}
// Modular inverse of a modulo n via a^(n-2); valid when n is prime and does
// not divide a.
ll inverse(ll a ,ll n)
{
    const ll exponent = n - 2 ;
    return power(a , exponent , n) ;
}
// Larger of the two values.
ll max(ll a , ll b)
{
    return (a > b) ? a : b ;
}
// Smaller of the two values.
ll min(ll a , ll b)
{
    return (a < b) ? a : b ;
}
// Index 2*i + 1 (left child in a 0-indexed implicit binary tree).
ll left(ll i)
{
    const ll doubled = 2 * i ;
    return doubled + 1 ;
}
// Index 2*i + 2 (right child in a 0-indexed implicit binary tree).
ll right(ll i)
{
    const ll doubled = 2 * i ;
    return doubled + 2 ;
}
ll ncr(ll n , ll r){if(n < r|| (n < 0) || (r < 0)) return 0 ; return ((((fact[n] * inverse(fact[r] , z)) % z) * inverse(fact[n-r] , z)) % z);}
// Exchanges the two referenced values.
void swap(ll&a , ll&b)
{
    const ll old_a = a ;
    a = b ;
    b = old_a ;
}
//ios_base::sync_with_stdio(0);
//cin.tie(0); cout.tie(0);
using namespace std ;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val) no. of elements strictly less than val
// s.find_by_order(i) itertor to ith element (0 indexed)
//__builtin_popcount(n) -> returns number of set bits in n
ll seed;
mt19937 rnd(seed=chrono::steady_clock::now().time_since_epoch().count()); // include bits
// Cheapest way to settle a vertex whose DP triple is v: either leave it
// untouched (v[2]) or spend one operation painting it with colour `final`
// (1 + v[final]).
ll get_ans(vector<ll> &v, ll final)
{
    const ll untouched = v[2] ;
    const ll painted = 1 + v[final] ;
    return min(untouched , painted) ;
}
void dfs(vector<ll> adj[] , vector<vector<ll> > &dp , vector<ll> &a, vector<ll> &b, ll u, ll p)
{
ll c = 0 ;
dp[u] = {0 , 0 , 0} ;
if(a[u] != b[u])
dp[u][2] = inf ;
fo(i , adj[u].size())
{
ll v = adj[u][i] ;
if(v == p)
continue ;
c++ ;
dfs(adj , dp , a , b , v , u);
dp[u][0] += dp[v][0] ;
dp[u][1] += dp[v][1] ;
dp[u][2] += get_ans(dp[v] , b[v]) ;
}
fo(i , 3)
dp[u][i] = min(dp[u][i] , inf) ;
if(c != 0)
{
dp[u][1-b[u]] = 1 + dp[u][b[u]] ;
return ;
}
if(c == 0)
{
if(a[u] == b[u])
{
dp[u][1 - a[u]] = 1 ;
}
else
{
dp[u][2] = inf ;
dp[u][1 - b[u]] = 1 ;
}
return ;
}
return ;
}
// Reads one test case (n, initial colours, target colours, n-1 edges given
// 1-indexed), runs the tree DP from vertex 0 and prints the minimum number
// of operations.
void solve()
{
    ll n ;
    cin >> n ;
    vector<ll> a(n) , b(n) ;
    for(ll &colour : a)
        cin >> colour ;
    for(ll &colour : b)
        cin >> colour ;

    // Heap-backed adjacency lists: a variable-length array of vectors is a
    // compiler extension and would place the whole table on the stack.
    vector<vector<ll> > adj(n) ;
    for(ll e = 0 ; e < n-1 ; e++)
    {
        ll u , v ;
        cin >> u >> v ;
        u-- ; v-- ;
        adj[u].push_back(v) ;
        adj[v].push_back(u) ;
    }

    vector<vector<ll> > dp(n , vector<ll> (3)) ;
    // dp[i][0] represents min number of moves if the complete subtree of i is white
    // dp[i][1] represents min number of moves if the complete subtree of i is black
    // dp[i][2] represents min number of moves if no operation is done on i^th node
    dfs(adj.data() , dp , a , b , 0 , -1) ;

    // Debug dump of the table, kept for reference:
    // for(ll i = 0 ; i < n ; i++)
    //     cout << i << ": " << dp[i][0] << ' ' << dp[i][1] << ' ' << dp[i][2] << '\n' ;

    const ll ans = get_ans(dp[0] , b[0]) ;
    // '\n' instead of endl: no flush is needed after every test case.
    cout << ans << '\n' ;
    return ;
}
// Entry point: sets up fast I/O (and local file redirection when not on the
// judge), solves every test case and reports the elapsed CPU time on stderr.
int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
#ifndef ONLINE_JUDGE
    freopen("inputf.txt" , "r" , stdin) ;
    freopen("outputf.txt" , "w" , stdout) ;
    freopen("errorf.txt" , "w" , stderr) ;
#endif

    ll t = 1 ;
    cin >> t ;
    for( ; t != 0 ; --t)
    {
        solve() ;
    }

    cerr << "Time : " << 1000 * static_cast<double>(clock()) / static_cast<double>(CLOCKS_PER_SEC) << "ms\n" ;
    return 0 ;
}
Editorialist's solution
#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define pb push_back
#define all(_obj) _obj.begin(), _obj.end()
#define F first
#define S second
#define pll pair<ll, ll>
#define vll vector<ll>
ll INF = 1e18;
const int N = 3e5 + 11, mod = 1e9 + 7;
// Larger of the two values.
ll max(ll a, ll b)
{
    if (a > b)
        return a;
    return b;
}
// Smaller of the two values.
ll min(ll a, ll b)
{
    if (a > b)
        return b;
    return a;
}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// a[u] / b[u]: initial and target colour of vertex u (1-indexed), read in sol().
int a[N], b[N];
// Per-vertex DP values filled by dfs(); sol() resets them to 0 for every test.
long long black[N], white[N], none[N];
// v[u]: adjacency list of vertex u, cleared and refilled in sol().
vll v[N];
void dfs(int u, int p)
{
if (b[u])
white[u]++;
else
black[u]++;
for (auto x : v[u])
{
if (x == p)
continue;
dfs(x, u);
none[u] += min(min(none[x], 1 + white[x]), 1 + black[x]);
if (b[u])
{
black[u] += black[x];
white[u] += black[x];
}
if (!b[u])
{
white[u] += white[x];
black[u] += white[x];
}
}
if (a[u] != b[u])
none[u] = 1e9;
return;
}
// Handles one test case: resets the per-vertex state, reads both colourings
// and the edges, runs the DP from vertex 1 and prints the answer.
void sol(void)
{
    int n;
    cin >> n;

    for (int u = 1; u <= n; u++)
    {
        v[u].clear();
        black[u] = 0;
        white[u] = 0;
        none[u] = 0;
        cin >> a[u];
    }
    for (int u = 1; u <= n; u++)
        cin >> b[u];

    for (int e = 1; e < n; e++)
    {
        int from, to;
        cin >> from >> to;
        v[from].push_back(to);
        v[to].push_back(from);
    }

    dfs(1, -1);

    // Painting the whole tree at the root costs one extra operation.
    const long long paint_black = 1 + black[1];
    const long long paint_white = 1 + white[1];
    cout << min(min(paint_black, none[1]), paint_white) << '\n';
}
// Entry point: enables fast stream I/O and solves every test case.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);

    int test = 1;
    cin >> test;
    for (; test != 0; --test)
    {
        sol();
    }
}