MINCOST-Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setters: Kshitij Sodani
Tester: Manan Grover, Abhinav Sharma
Editorialist: Devendra Singh

DIFFICULTY:

2694

PREREQUISITES:

Depth-first Search, Dynamic Programming

PROBLEM:

You are given a tree of N vertices. The i-th vertex has two things assigned to it: a value A_i and a range [L_i, R_i]. It is guaranteed that L_i \leq A_i \leq R_i.

You perform the following operation exactly once:

  • Pick a (possibly empty) subset S of vertices of the tree. S must satisfy the additional condition that no two vertices in S are directly connected by an edge, i.e., it must be an independent set.
  • Then, for each x \in S, you can set A_x to any value in the range [L_x, R_x] (endpoints included).

Your task is to minimize the value of \sum |A_u - A_v| over all unordered pairs (u, v) such that u and v are connected by an edge.

EXPLANATION:

Observation: Suppose we decide to change the value of a node u. Because S is an independent set, none of u's neighbours change, so the cost of the edges around u depends only on the new A_u. The best choice is the median of the A values of u's neighbours if that median lies in [L_u, R_u]; otherwise it is whichever endpoint, L_u or R_u, is nearest to the median.

Median minimizes the sum of absolute deviations

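Given values s_1, s_2, \dots, s_N, we are looking for:
\arg\min_{x} \sum_{i=1}^{N} |s_i - x|

Note that d|x|/dx = sign(x) for x \neq 0.
Hence, differentiating the sum with respect to x gives \sum_{i=1}^{N} sign(x - s_i).
This is negative while more of the s_i lie above x than below it, and positive while more lie below, so the minimum is reached where the two counts balance, i.e. at x = median(s_1, s_2, \dots, s_N). For even N, every x between the two middle values is optimal; the solutions below take the element at 0-based index \lfloor N/2 \rfloor of the sorted list.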
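To check the median claim numerically, here is a minimal C++ sketch (C++11 or later). The names sumAbsDev, upperMedian and bruteMinAbsDev are hypothetical helpers written for this illustration, not part of the original solutions; the brute force simply scans every integer in a given window.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <vector>

// f(x) = sum_i |s_i - x|, the sum of absolute deviations.
long long sumAbsDev(const std::vector<long long>& s, long long x) {
    long long total = 0;
    for (long long v : s) total += std::llabs(v - x);
    return total;
}

// Element at 0-based index n/2 of a sorted copy (the index the solutions use).
long long upperMedian(std::vector<long long> s) {
    std::sort(s.begin(), s.end());
    return s[s.size() / 2];
}

// Brute force: smallest f(x) over every integer x in [lo, hi].
long long bruteMinAbsDev(const std::vector<long long>& s, long long lo, long long hi) {
    long long best = sumAbsDev(s, lo);
    for (long long x = lo + 1; x <= hi; ++x) best = std::min(best, sumAbsDev(s, x));
    return best;
}
```

For s = {7, 1, 4, 10} the sorted list is {1, 4, 7, 10}, upperMedian returns 7, and f(7) = 12 matches the brute-force minimum; every x in [4, 7] gives the same cost, as expected for even N.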

If the median is less than L_u, we set A_u to L_u. To see why, suppose we set A_u to some value L' > L_u instead. Decreasing L' by 1 changes the sum of absolute deviations by (number of values \geq L') - (number of values < L'). Since L' is greater than the median, at most half of the values are \geq L', so this change is at most 0 and the sum does not increase. We can therefore keep decreasing L' until it reaches L_u.
Similarly, if the median is greater than R_u, we set A_u to R_u.
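The clamping rule can be sketched the same way. This is an illustrative sketch assuming nb holds the unchanged values of the vertex's neighbours and [lo, hi] is its allowed range; starCost, clampedMedian and bruteBestCost are hypothetical names, and std::max/std::min are used instead of C++17 std::clamp so the snippet also builds as C++11.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <vector>

// Cost of the edges around a vertex set to x, with every neighbour unchanged.
long long starCost(const std::vector<long long>& nb, long long x) {
    long long total = 0;
    for (long long v : nb) total += std::llabs(v - x);
    return total;
}

// Upper median of the neighbour values, clamped into [lo, hi].
// Assumes nb is non-empty (true for every vertex of a tree with N >= 2).
long long clampedMedian(std::vector<long long> nb, long long lo, long long hi) {
    std::sort(nb.begin(), nb.end());
    long long m = nb[nb.size() / 2];
    return std::max(lo, std::min(hi, m));
}

// Brute force over every integer in [lo, hi], for checking only.
long long bruteBestCost(const std::vector<long long>& nb, long long lo, long long hi) {
    long long best = starCost(nb, lo);
    for (long long x = lo + 1; x <= hi; ++x) best = std::min(best, starCost(nb, x));
    return best;
}
```

With neighbours {2, 5, 9} the median is 5: for the range [6, 8] it is clamped up to 6 (cost 8), for [1, 3] down to 3 (cost 9), and for [4, 6] it is used as is (cost 7).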

The rest of the problem can be solved with dynamic programming on the tree.
Root the tree at node 1. For each node i, compute optimal_i, the value A_i should take if node i is changed, by sorting the A values of all its neighbours and applying the observation above. Then define:
  • dp[u][0]: the minimum total cost of the edges inside the subtree of u, over all valid choices of S within the subtree in which u is not changed (u \notin S).
  • dp[u][1]: the same minimum when u is changed to optimal_u (u \in S).
If u is changed, none of its neighbours can be changed. In particular, the parent of u keeps its original value, which is why optimal_u can be computed from all neighbours of u, parent included. Initialize all dp states to 0.

Run a depth-first search from node 1. Once every child x of u has been processed:
dp[u][1] = \sum_{x} (dp[x][0] + |A[x] - optimal[u]|)
dp[u][0] = \sum_{x} \min(dp[x][0] + |A[x] - A[u]|, dp[x][1] + |optimal[x] - A[u]|)
where both sums range over the children x of u.

The answer to the problem is min(dp[1][0], dp[1][1])
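Putting the pieces together, below is a sketch of the same DP that replaces the recursive DFS with a BFS order processed in reverse, which may be safer on path-like trees with N up to 10^5 where recursion gets deep. It assumes 0-indexed vertices with vertex 0 as the root, N \geq 2, and the edges given as pairs; minCost, bruteMinCost, Edges and the other helpers are hypothetical names for this illustration. The exhaustive bruteMinCost is only meant for cross-checking tiny inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <utility>
#include <vector>

using Edges = std::vector<std::pair<int, int>>;

// DP from the editorial; keep[u] = dp[u][0], change[u] = dp[u][1].
long long minCost(int n, const std::vector<long long>& A, const std::vector<long long>& Lo,
                  const std::vector<long long>& Hi, const Edges& edges) {
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    // optimal[u]: upper median of the neighbours' values, clamped to [Lo[u], Hi[u]].
    std::vector<long long> optimal(n);
    for (int u = 0; u < n; ++u) {
        std::vector<long long> cur;
        for (int w : adj[u]) cur.push_back(A[w]);
        std::sort(cur.begin(), cur.end());
        optimal[u] = std::max(Lo[u], std::min(Hi[u], cur[cur.size() / 2]));
    }
    // BFS from the root: every parent appears before its children in `order`.
    std::vector<int> parent(n, -1), order{0};
    std::vector<char> seen(n, 0);
    seen[0] = 1;
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int w : adj[u]) {
            if (!seen[w]) {
                seen[w] = 1;
                parent[w] = u;
                order.push_back(w);
            }
        }
    }
    std::vector<long long> keep(n, 0), change(n, 0);
    // Reverse BFS order: each child is final before it is added to its parent.
    for (int i = n - 1; i >= 1; --i) {
        int x = order[i], u = parent[x];
        change[u] += keep[x] + std::llabs(A[x] - optimal[u]);
        keep[u] += std::min(keep[x] + std::llabs(A[u] - A[x]),
                            change[x] + std::llabs(A[u] - optimal[x]));
    }
    return std::min(keep[0], change[0]);
}

// Sum of |B_u - B_v| over all edges.
long long edgeCost(const std::vector<long long>& B, const Edges& edges) {
    long long c = 0;
    for (const auto& e : edges) c += std::llabs(B[e.first] - B[e.second]);
    return c;
}

// Tries every value in [Lo[v], Hi[v]] for each vertex v in chosen[idx..].
long long tryValues(std::vector<long long>& B, const std::vector<int>& chosen, std::size_t idx,
                    const std::vector<long long>& Lo, const std::vector<long long>& Hi,
                    const Edges& edges) {
    if (idx == chosen.size()) return edgeCost(B, edges);
    int v = chosen[idx];
    long long saved = B[v], best = -1;
    for (long long x = Lo[v]; x <= Hi[v]; ++x) {
        B[v] = x;
        long long c = tryValues(B, chosen, idx + 1, Lo, Hi, edges);
        if (best < 0 || c < best) best = c;
    }
    B[v] = saved;
    return best;
}

// Exhaustive search over every independent set S and every value assignment.
long long bruteMinCost(int n, std::vector<long long> A, const std::vector<long long>& Lo,
                       const std::vector<long long>& Hi, const Edges& edges) {
    long long best = -1;
    for (int mask = 0; mask < (1 << n); ++mask) {
        bool independent = true;
        for (const auto& e : edges)
            if (((mask >> e.first) & 1) && ((mask >> e.second) & 1)) independent = false;
        if (!independent) continue;
        std::vector<int> chosen;
        for (int v = 0; v < n; ++v)
            if ((mask >> v) & 1) chosen.push_back(v);
        long long c = tryValues(A, chosen, 0, Lo, Hi, edges);
        if (best < 0 || c < best) best = c;
    }
    return best;
}
```

For example, on a star whose centre has value 5 and range [3, 6], with leaves 1, 2, 9 and ranges [1, 3], [1, 4], [7, 9], changing only the centre costs 2 + 1 + 6 = 9, while changing all three leaves costs 2 + 1 + 2 = 5, so minCost returns 5.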

TIME COMPLEXITY:

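O(N \log N) per test case: sorting the neighbour values of every vertex costs O(N \log N) in total, since the degrees sum to 2(N-1), and the DFS itself is O(N).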
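If the sorting term matters, one possible alternative (not used by any of the solutions here) is std::nth_element, which puts the element that belongs at index \lfloor k/2 \rfloor into its sorted position in linear time on average, making the optimal-value step O(N) on average overall. A minimal sketch; clampedMedianLinear and clampedMedianSorted are hypothetical names, and nb is assumed non-empty.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Clamped upper median via std::nth_element (average linear time).
long long clampedMedianLinear(std::vector<long long> nb, long long lo, long long hi) {
    auto mid = nb.begin() + nb.size() / 2;
    std::nth_element(nb.begin(), mid, nb.end());
    return std::max(lo, std::min(hi, *mid));
}

// Same value computed with a full sort, as in the solutions below.
long long clampedMedianSorted(std::vector<long long> nb, long long lo, long long hi) {
    std::sort(nb.begin(), nb.end());
    return std::max(lo, std::min(hi, nb[nb.size() / 2]));
}
```

Both functions pick the same element, so swapping one for the other does not change the DP; only the cost of computing optimal[] changes.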

SOLUTION:

Setter's solution
#include <bits/stdc++.h>
using namespace std;
typedef long long llo;
#define a first
#define b second
#define pb push_back
#define endl '\n'



void setIO(string name) {
	ios_base::sync_with_stdio(0); cin.tie(0);
	freopen((name+".in").c_str(),"r",stdin);

	freopen((name+".out").c_str(),"w",stdout);
}
vector<llo> adj[200001];
llo dp[200001][2];
llo aa[200001];
llo ll[200001];
llo rr[200001];
llo cc[200001];
void dfs(llo no,llo par=-1){
	dp[no][0]=0;
	dp[no][1]=0;
	for(auto j:adj[no]){
		if(j!=par){
			dfs(j,no);
			dp[no][1]+=dp[j][0]+abs(aa[j]-cc[no]);
			dp[no][0]+=min(dp[j][1]+abs(aa[no]-cc[j]),dp[j][0]+abs(aa[no]-aa[j]));
		}
	}

}
int main(){
	ios_base::sync_with_stdio(0); cin.tie(0); // fast I/O, since setIO is not called
		llo t;
		cin>>t;
		while(t--){
			llo n;
			cin>>n;
			for(llo i=0;i<n;i++){
				adj[i].clear();
			}
			for(llo i=0;i<n;i++){
				cin>>ll[i]>>aa[i]>>rr[i];
			}
			for(llo i=0;i<n-1;i++){
				llo u,v; // named u,v so they do not shadow the global array aa
				cin>>u>>v;
				u--;
				v--;
				adj[u].pb(v);
				adj[v].pb(u);
			}
			for(llo i=0;i<n;i++){
				vector<llo> cur;
				for(auto j:adj[i]){
					cur.pb(aa[j]);
				}
				sort(cur.begin(),cur.end());
				llo x=cur.size();
				cc[i]=cur[x/2];
				if(cc[i]<ll[i]){
					cc[i]=ll[i];
				}
				else if(cc[i]>rr[i]){
					cc[i]=rr[i];
				}
			}
			dfs(0);
			cout<<min(dp[0][0],dp[0][1])<<endl;
		}
	return 0;
}

Tester-1's Solution
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
#define asc(i,a,n) for(I i=a;i<n;i++)
#define dsc(i,a,n) for(I i=n-1;i>=a;i--)
#define forw(it,x) for(A it=(x).begin();it!=(x).end();it++)
#define bacw(it,x) for(A it=(x).rbegin();it!=(x).rend();it++)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define lb(x) lower_bound(x)
#define ub(x) upper_bound(x)
#define fbo(x) find_by_order(x)
#define ook(x) order_of_key(x)
#define all(x) (x).begin(),(x).end()
#define sz(x) (I)((x).size())
#define clr(x) (x).clear()
#define U unsigned
#define I long long int
#define S string
#define C char
#define D long double
#define A auto
#define B bool
#define CM(x) complex<x>
#define V(x) vector<x>
#define P(x,y) pair<x,y>
#define OS(x) set<x>
#define US(x) unordered_set<x>
#define OMS(x) multiset<x>
#define UMS(x) unordered_multiset<x>
#define OM(x,y) map<x,y>
#define UM(x,y) unordered_map<x,y>
#define OMM(x,y) multimap<x,y>
#define UMM(x,y) unordered_multimap<x,y>
#define BS(x) bitset<x>
#define L(x) list<x>
#define Q(x) queue<x>
#define PBS(x) tree<x,null_type,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define PBM(x,y) tree<x,y,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define pi (D)acos(-1)
#define md 1000000007
#define rnd randGen(rng)
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
 
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
P(I,I) dfs(I x,I pr,V(I) tr[],I a[],I l[],I r[],I vis[]){
  vis[x]++;
  V(I) v;
  asc(i,0,sz(tr[x])){
    v.pb(a[tr[x][i]]);
  }
  sort(all(v));
  I k=max(l[x],min(r[x],v[sz(v)/2]));
  I f=0;
  asc(i,0,sz(v)){
    f+=abs(v[i]-k);
  }
  I s=0;
  asc(i,0,sz(tr[x])){
    I y=tr[x][i];
    if(y!=pr){
      P(I,I) temp=dfs(y,x,tr,a,l,r,vis);
      f+=temp.se;
      s+=min(temp.fi,temp.se+abs(a[x]-a[y]));
    }
  }
  return {f,s};
}
int main(){
  mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
  uniform_int_distribution<I> randGen;
  ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
  #ifndef ONLINE_JUDGE
  freopen("input.txt", "r", stdin);
  freopen("output.txt", "w", stdout);
  #endif
  I t;
  t=readInt(1,10000,'\n');
  I ns=0;
  while(t--){
    I n;
    n=readInt(3,100000,'\n');
    ns+=n;
    assert(ns<=200000);
    I a[n+1],l[n+1],r[n+1];
    asc(i,1,n+1){
      l[i]=readInt(1,1000000000,' ');
      a[i]=readInt(1,1000000000,' ');
      r[i]=readInt(1,1000000000,'\n');
    }
    V(I) tr[n+1];
    asc(i,0,n-1){
      I u,v;
      u=readInt(1,n,' ');
      v=readInt(1,n,'\n');
      tr[u].pb(v);
      tr[v].pb(u);
    }
    I vis[n+1]={};
    P(I,I) ans=dfs(1,0,tr,a,l,r,vis);
    asc(i,1,n+1){
      assert(vis[i]);
    }
    cout<<min(ans.fi,ans.se)<<"\n";
  }
  return 0;
}
Tester-2's Solution
#include <bits/stdc++.h>
using namespace std;
#define ll long long
ll max(ll l, ll r){ if(l > r) return l ; return r;}
ll min(ll l, ll r){ if(l < r) return l ; return r;}

 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 1;
const int MAX_N = 100;
const int SUM_N = 300000;
const int MAX_VAL = 100; 
const int SUM_VAL = 20005 ;
const int OFFSET = 10000 ;

#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back
#define int long long
 
ll sum_n = 0, sum_m = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll mod = 998244353;

using ii = pair<ll,ll>;

vector<int> a,l,r;
vector<vector<int> > g, dp;
int cnt;

void dfs(int c, int p){
    cnt++;
    vector<int> tmp;

    for(auto h:g[c]){
        if(h!=p) dfs(h,c);
        tmp.pb(a[h]);
    }

    sort(tmp.begin(), tmp.end());
    for(auto h:g[c]){
        if(h!=p) dp[c][0] += min(dp[h][0]+abs(a[h]-a[c]), dp[h][1]);
    }
    ll z1,z2,z;
    if(tmp.size()&1) z1 = z2 = tmp[tmp.size()/2];
    else{
        z1 = tmp[tmp.size()/2-1];
        z2 = tmp[tmp.size()/2];
    } 

    if(l[c]>z2) z = l[c];
    else if(r[c]<z1) z = r[c];
    else if(z1>=l[c] && z1<=r[c]) z = z1;
    else z = l[c];

    for(auto h:g[c]){
        dp[c][1]+=dp[h][0]+abs(z-a[h]);
    }

}


void solve()
{   
    int n = readIntLn(2,1e5);
    sum_n+=n;

    a.resize(n);
    l.resize(n);
    r.resize(n);
    g.assign(n, vector<int>());
    dp.assign(n, vector<int>(2,0));

    int x,y;

    rep(i,n){
        l[i] = readIntSp(1,1e9);
        a[i] = readIntSp(l[i],1e9);
        r[i] = readIntLn(a[i],1e9); 
    }

    rep(i,n-1){
        x = readIntSp(1,n);
        y = readIntLn(1,n);
        assert(x!=y);
        --x, --y;
        g[x].pb(y);
        g[y].pb(x);
    }

    cnt = 0;
    dfs(0,-1);

    assert(cnt==n);

    cout<<min(dp[0][0], dp[0][1])<<'\n';

    
}
 
signed main()
{
    fast;
    #ifndef ONLINE_JUDGE
    freopen("input.txt" , "r" , stdin) ;
    freopen("output.txt" , "w" , stdout) ;
    #endif
    
    int t = 1;
    
    t = readIntLn(1,1e4);

    for(int i=1;i<=t;i++)
    {    
        solve() ;
    }
    
    assert(getchar() == -1);
    assert(sum_n<=2e5);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    // cerr<<"Sum of lengths : " << sum_n << '\n';
    // cerr<<"Maximum length : " << max_n << '\n';
    // cerr<<"Minimum length : " << min_n << '\n';
    // cerr << "Sum o f product : " << sum_nk << '\n' ;
    // cerr<<"Total operations : " << total_ops << '\n';
    // cerr<<"Answered yes : " << yess << '\n';
    // cerr<<"Answered no : " << nos << '\n';
}
Editorialist's Solution
#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define pb push_back
#define all(_obj) _obj.begin(), _obj.end()
#define F first
#define S second
#define pll pair<ll, ll>
#define vll vector<ll>
ll INF = 1e18;
const int N = 2e5 + 11, mod = 1e9 + 7;
ll max(ll a, ll b) { return ((a > b) ? a : b); }
ll min(ll a, ll b) { return ((a > b) ? b : a); }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
int L[N], R[N], A[N];
vll v[N];
ll dp[N][2], optimal[N];
void dfs(int u, int p = -1)
{
    dp[u][0] = 0;
    dp[u][1] = 0;
    for (auto x : v[u])
    {
        if (x != p)
        {
            dfs(x, u);
            dp[u][1] += dp[x][0] + abs(A[x] - optimal[u]);
            dp[u][0] += min(dp[x][1] + abs(A[u] - optimal[x]), dp[x][0] + abs(A[u] - A[x]));
        }
    }
}
void sol(void)
{
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> L[i] >> A[i] >> R[i],v[i].clear();
    for (int i = 0; i < n - 1; i++)
    {
        int a, b;
        cin >> a >> b;
        v[a].pb(b);
        v[b].pb(a);
    }
    for (int i = 1; i <= n; i++)
    {
        vll cur;
        for (auto x : v[i])
        {
            cur.pb(A[x]);
        }
        sort(all(cur));
        optimal[i] = cur[cur.size() / 2];
        if (optimal[i] < L[i])
        {
            optimal[i] = L[i];
        }
        else if (optimal[i] > R[i])
        {
            optimal[i] = R[i];
        }
    }
    dfs(1);
    cout << min(dp[1][0], dp[1][1]) << endl;
    return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);
    int test = 1;
    cin >> test;
    while (test--)
        sol();
}



Inspired by: Problem A, Codeforces Round 722


Can anyone tell me why my solution is giving WA?
https://www.codechef.com/viewsolution/68970482

Instead of using the median, we can also ternary search for the optimal value and then run the DP.
Solution
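A minimal sketch of the search idea from the comment above, assuming the cost around a changed vertex is f(x) = \sum |nb_i - x| over its unchanged neighbour values nb, with x restricted to [lo, hi]. Because f is convex, f(x+1) - f(x) is non-decreasing, so a binary search on its sign (the usual integer form of ternary search) finds a minimiser in O(deg \cdot \log(hi - lo)). costAt and bestValueBySearch are hypothetical names; this is not the commenter's linked code.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// f(x): cost of the edges around a vertex set to x.
long long costAt(const std::vector<long long>& nb, long long x) {
    long long total = 0;
    for (long long v : nb) total += std::llabs(v - x);
    return total;
}

// Smallest x in [lo, hi] with f(x + 1) - f(x) >= 0 (or hi); such an x minimises f.
long long bestValueBySearch(const std::vector<long long>& nb, long long lo, long long hi) {
    while (lo < hi) {
        long long mid = lo + (hi - lo) / 2;
        if (costAt(nb, mid + 1) - costAt(nb, mid) < 0)
            lo = mid + 1;  // f still decreasing, so a minimiser lies to the right of mid
        else
            hi = mid;      // f no longer decreasing, so mid or something left of it is optimal
    }
    return lo;
}
```

On neighbours {2, 5, 9} this returns 5 for the range [1, 10], 6 for [6, 8] and 3 for [1, 3], the same values as the clamped median.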

I have a doubt. If we change the value of a node, it is optimal to change it to the median. But if we do not change the value of a node, why is it optimal to consider only the optimal values of its children (assuming they are changed)? Changing a child to some other value might sometimes give a better dp value.
Basically, if we are at node u and do not change it, while some child x is changed, why is it optimal to take only that one value out of the range [l_x, r_x]?

@piyush_3004 - you can use this feature to check which test case your code is failing on -


Thanks, I found the mistake :upside_down_face::upside_down_face:
If it had passed, I would have reached 5 stars :upside_down_face: