PILGRIMS - Editorial

PROBLEM LINK:

Practice
Div-3 Contest
Div-2 Contest
Div-1 Contest

Setter: Ashish Vishal, Aditya Kumar Singh
Tester: Ronit Raj, Daanish Mahajan

DIFFICULTY:

Easy

PREREQUISITES:

Tree, DFS, BFS, Greedy, Sorting

PROBLEM:

Given a tree with N nodes (cities), where city 1 is the capital, and M pilgrims, each with an initial energy E.
Travelling along a road from one city to the next consumes a certain amount of energy.
Find the number of special cities (leaves of the tree other than the capital) that will be non-empty after all the pilgrims have finished their journeys.

EXPLANATION:

This problem can be solved with a DP over the tree that stores the energy required to travel from the capital
city to every city (and hence to every special city).
Let $E_U$ be the energy required for a pilgrim to reach city $U$, and let $D_U$ be the depth of city $U$
(the capital city has depth 1, and $E_1 = 0$).
Then, for a child $V$ of $U$ joined to it by a road of length $w(U, V)$, the energy required to reach $V$ is

$$E_V = E_U + D_U \cdot w(U, V).$$

This gives the minimum energy required to reach each special city from the
capital city. Then we put all the pilgrims’ initial energies into a multiset.

Also, note that the capital city is never a special city, even if it is a leaf (has only one neighbour).

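Initially, answer = 0.
Then, processing the special cities in increasing order of required energy (as the setter's solution does by sorting), perform the following steps for each one: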

  1. Let $E_{min}$ be the minimum energy required to reach the special city
    from the capital city.
  2. Find the smallest energy value in the multiset that is greater than or
    equal to $E_{min}$.

If such an energy value exists, delete it from the multiset and increase
the answer by 1.

After all special cities have been processed, the answer is the required result.

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
// ---- Shorthand macros/typedefs used by the setter's solution ----
// Append to a container.
#define pb push_back
// Insert into a set/multiset. NOTE(review): replaces every token named `is`; not referenced in this solution.
#define is insert
// Inclusive loop i = a..b (not referenced in this solution).
#define rep1(i,a,b) for(long long i=a;i<=b;i++)
// Pair member shorthands.
#define F first
#define S second
// Opens input.txt / output.txt file streams (not referenced in this solution).
#define file ifstream fin("input.txt");ofstream fout("output.txt");
// Disables C stdio synchronisation and unties the streams for faster I/O.
#define fast ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Loop i = 0..n-1; note it always declares a variable named `i`.
#define fr(n) for(long long i=0;i<n;i++)
// Half-open loop i = a..b-1.
#define rep(i,a,b) for(long long i=a;i<b;i++)
// Whole-range iterator pair, e.g. sort(ALL(x)).
#define ALL(x) (x).begin(), (x).end()
typedef long long int ll;
typedef long double ld;
typedef vector<ll> vi;


// Global per-node state; solve() resets indices 0..n at the start of every test case.
// NOTE(review): sized for node ids up to 1000000.
// v[u]: adjacency list of (neighbour, road length) pairs.
vector<pair<ll,ll> > v[1000001];
// minenergy[u]: energy needed to reach u from city 1; hei[u]: depth of u (solve() sets hei[1] = 1).
vi minenergy(1000001),hei(1000001);
// vis[u]: set once dfs() has expanded u.
vector<bool>vis(1000001);

// Iterative depth-first traversal from node `x` over the global adjacency lists `v`.
// For every not-yet-visited neighbour `next` of an expanded node `node` it records
//   hei[next]       = hei[node] + 1
//   minenergy[next] = minenergy[node] + hei[node] * (length of road node-next)
// The caller must initialise hei[x] and minenergy[x] before calling.
// An explicit stack is used, so long paths do not deepen the call stack.
void dfs(ll x)
{
    stack<ll> pending;
    pending.push(x);

    while (!pending.empty())
    {
        const ll node = pending.top();
        pending.pop();

        if (vis[node])
            continue;
        vis[node] = 1;

        for (const auto& edge : v[node])
        {
            const ll next = edge.first;
            if (vis[next])
                continue;
            hei[next] = hei[node] + 1;
            minenergy[next] = minenergy[node] + hei[node] * edge.second;
            pending.push(next);
        }
    }
}

void solve()
{
    ll n,m,a,b,w;
    cin>>n>>m;

    fr(n+1)
    {
        minenergy[i]=0;
        v[i].clear();
        hei[i]=0;
        vis[i]=0;
    }

    vi pegy(m);
    fr(m)cin>>pegy[i];

    fr(n-1)
    {
        cin>>a>>b>>w;
        v[a].pb({b,w});
        v[b].pb({a,w});
    }
    hei[1]=1;
    dfs(1);

    vi scity;
    rep(i,2,n+1)
    {
        if(v[i].size()==1)
        {
            scity.pb(minenergy[i]);            
        }
    }

    sort(ALL(scity));
    sort(ALL(pegy));
    
    int i=0,j=0,cnt=0;
    while(i<m && j<scity.size())
    {
        if(pegy[i]>=scity[j])
        {
            i++;j++;
            cnt++;
        }
        else
            i++;
    }
    cout<<cnt<<endl;
}
int32_t main()
{
    #ifndef ONLINE_JUDGE
    freopen("inputf.in", "r", stdin);
    freopen("outputf.in", "w", stdout);
    #endif
    fast;
    ll t=1;
    cin>>t; 
    while(t--)
    solve();
    return 0;
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
// Shorthand macros used by the tester's solution.
// ll: 64-bit signed integer type.
# define ll long long int  
// pb: append to a vector.
# define pb push_back
// pii: pair of ints; the adjacency lists store (neighbour, road length) pairs.
# define pii pair<int, int>
// mp: build a pair.
# define mp make_pair
 
// Validating integer reader: consumes one integer from stdin that must be
// terminated by the character `endd` (the terminator is consumed too).
// Rules enforced with assert():
//   * at most one leading '-', then at least one digit;
//   * no leading zeros and no "-0";
//   * at most 19 digits, and a 19-digit value must start with '1'
//     (keeps the accumulation inside the signed 64-bit range);
//   * the value must lie in [l, r].
// Any other character, or EOF before the terminator, fails an assert.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // value of the first digit, -1 until one is read
    bool is_neg=false;
    while(true){
        // Keep getchar()'s int result so EOF is distinguishable from data bytes.
        const int g=getchar();
        assert(g!=EOF);
        if(g=='-'){
            // A sign is allowed only once, and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);          // no leading zeros
            assert(fi!=0 || is_neg==false);   // no "-0"
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==static_cast<unsigned char>(endd)){
            // An empty token (or a lone "-") is not a number.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the terminator `endd`,
// consuming the terminator, and asserts that the token length lies in [l, r].
// Reaching EOF before the terminator fails an assert.
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        // getchar() returns int. Narrowing it to char makes EOF look like byte
        // 0xFF, and where char is unsigned it never equals -1 (infinite loop).
        const int g=getchar();
        assert(g != EOF);
        if(g==static_cast<unsigned char>(endd)){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads an integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l,long long r){
    const char terminator = ' ';
    return readInt(l, r, terminator);
}
// Reads an integer in [l, r] that must be followed by a newline.
long long readIntLn(long long l,long long r){
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
// Reads the rest of the line as a token whose length must lie in [l, r].
string readStringLn(int l,int r){
    const char terminator = '\n';
    return readString(l, r, terminator);
}
// Reads a space-terminated token whose length must lie in [l, r].
string readStringSp(int l,int r){
    const char terminator = ' ';
    return readString(l, r, terminator);
}
 
// Input limits, enforced by the validating reads in main().
constexpr int maxn = 100000, maxm = 100000, maxk = 1000000, maxt = 50;
constexpr ll maxe = 1000000000000000000LL;
// Per-test graph state, indexed by city number (1..n).
vector<pii> g[100010];   // g[u]: (neighbour, road length) pairs
bool visit[100010];      // visited flags written by checkTree()
ll dp[100010];           // dp[u]: energy needed to reach u from city 1
multiset<ll> mset;       // required energies of the special cities (leaves)
bool checkTree(int u, int pa){
    visit[u] = true;
    bool r = true;
    for(pii v : g[u]){
        if(v.first == pa)continue;
        if(visit[v.first])return false;
        r &= checkTree(v.first, u);
    }
    return r;
}
void dfs(int u, int pa, int d){
    bool in = false;
    for(pii v : g[u]){
        if(v.first == pa)continue;
        in = true;
        dp[v.first] = dp[u] + (ll)(d + 1) * v.second;
        dfs(v.first, u, d + 1);
    }
    if(!in){
        // cout << u << " " << dp[u] << endl;
        mset.insert(dp[u]);
    }
}
int main()
{   
    int t = readIntLn(1, maxt);
    while(t--){
        int n = readIntSp(2, maxn), m = readIntLn(1, maxm);
        vector<ll> v;
        for(int i = 0; i < m; i++){
            v.pb(i == m - 1 ? readIntLn(1, maxe) : readIntSp(1, maxe));
        }
        sort(v.begin(), v.end(), greater<ll>());
        for(int i = 0; i <= n; i++){
            g[i].clear(); visit[i] = false;
        }
        for(int i = 0; i < n - 1; i++){
            int u = readIntSp(1, n), v = readIntSp(1, n), k = readIntLn(1, maxk);
            g[u].pb(mp(v, k)); g[v].pb(mp(u, k));
            assert(u != v);
        }
        assert(checkTree(1, 0));
        memset(dp, 0, sizeof(dp));
        mset.clear();
        dfs(1, 0, 0);
        int ptr = 0, ans = 0;
        for(multiset<ll>::reverse_iterator it = mset.rbegin(); it != mset.rend(); it++){
            if(ptr < m && v[ptr] >= *it){
                ans++; ptr++;
            }
        }
        cout << ans << endl;
    }
    assert(getchar()==-1);
}
3 Likes

My code for problem D (Click Here) seems correct, but I am getting WA.
I have also handled the corner case (city 1 is not treated as a special city).

I ran a DFS and calculated the distances of the leaf nodes(special cities) and then I greedily assigned the special cities to the people.

Can anyone point out the mistake, where I am going wrong

Hey mate, you made a silly mistake while resetting the visit array: you set visit[1] = 1 on line 57 and then ran a loop from 0 to 100005 on line 58, which sets visit[1] back to 0.
I made a small change — running the loop from 2 to 100005 instead of from 0 — and it got accepted.

1 Like

@aadiupadhyay, thank you so much, mate — I made a silly blunder.

But I have one more doubt. I made a small change to my code: instead of writing a for loop that sets every element visit[i] to 0, I just wrote visit.clear() on line 41, but now I get WA. Can you tell me why this happens?

My code with a small change - Click Here

@aadiupadhyay, i got the reason why visit.clear() is not working and why running a for loop and individually assigning value works.

See this code - Click Here

Because I had given the vector a fixed size, calling its clear() method
makes the size of the vector 0.

So there is nothing left in the vector, and indexing it afterwards is undefined behaviour. I learned something new today: I had always thought that clear() sets all elements to zero. To reset every element while keeping the size, use `visit.assign(visit.size(), 0)` or `std::fill(visit.begin(), visit.end(), 0)`.

I did the same thing as the editorial says but it gave me WA.
Can anyone tell me what is wrong with my submission?

UPD: Found error, I was running the last loop till n instead of m. So stupid of me :frowning_face:.

1 Like

Getting WA. Solution: 45077377 | CodeChef
Any corner case which I am missing ?

1 Like

You were doing a silly mistake at line 73.
You assumed that going to node 1 from node 1 has a cost of 1 which should actually be 0.

So I corrected q.push({1,1}); to q.push({1,0});

The next mistake was on line 96, where you subtract 1 from W[i]. I guess you were compensating for the extra one you added on line 73, so I removed the “-1” and your code got AC.

You AC code - Click Here

@adikr_singh Please link your editorial with the problem, it’s difficult to find for it otherwise.

But how that was a mistake ?

Since I subtracted one when inserting the weight, it should not matter in the end, right?

I set the weight from city 1 to city 1 to one because I check whether a city has been visited by testing whether its entry in the W array is 0.

1 Like

Yes, exactly, when I was debugging it, it seemed correct, but I just experimented with it.

Also, I defined int as a long long int for safety reasons.

I am also wondering why your code was incorrect, since I just wrote it in a different way.

1 Like

Ohh, issue was overflow

1 Like

Yes, by looking at the constraints i defined int as long long int

Your WA code without long long int - Click Here

Yeah, got it, issue is overflow somewhere

1 Like

You can check for overflow using the technique described here

1 Like

The overflow is here, mate:
int wt=q.front().second;

q.front().second is a long long, so storing it in an int cannot hold large values.

1 Like
ll wt=q.front().second;

Here overflow was occurring Line 76

1 Like

Exactly

Can’t use advantages of these

I use CC IDE

import sys


def visi(a, b):
    """Return 1 if some node listed in ``a`` is still unvisited in ``b``, else 0.

    ``a`` is an iterable of node indices (an adjacency list) and ``b`` a list
    of 0/1 visited flags indexed by node.
    """
    for i in a:
        if b[i] == 0:
            return 1
    return 0


def dfs(b, c, n, m, sum, v, ans):
    """Append to ``ans`` the travel cost of every leaf reachable from node ``n``.

    b   -- adjacency lists: b[x] lists the neighbours of node x
    c   -- road lengths, indexable as c[x][y] (e.g. a list of dicts)
    n   -- start node; the caller marks it visited in ``v``
    m   -- depth of ``n`` (the root has depth 1)
    sum -- energy already spent to reach ``n``
    v   -- 0/1 visited flags, updated in place
    ans -- output list; leaf costs are appended in depth-first pre-order

    Moving from a node at depth m along a road of length w costs m * w.
    A node counts as a leaf when it has no unvisited neighbour. The input is
    expected to be a tree.

    Uses an explicit stack instead of recursion: Python's default recursion
    limit (about 1000) is far below the depth of a path-shaped tree.
    """
    stack = [(n, m, sum)]
    while stack:
        node, depth, cost = stack.pop()
        if visi(b[node], v) == 0:
            ans.append(cost)
            continue
        children = []
        for i in b[node]:
            if v[i] == 0:
                v[i] = 1
                children.append((i, depth + 1, cost + depth * c[node][i]))
        # Push in reverse so children are expanded in adjacency-list order.
        stack.extend(reversed(children))


def solve():
    """Solve one test case from standard input and print the answer.

    Input: "n m", a line with m pilgrim energies, then n-1 lines "x y z" for a
    road of length z between cities x and y (1-indexed). Prints how many
    special cities (leaves other than city 1) can each be given a distinct
    pilgrim whose energy is at least the cost of reaching that city.
    """
    readline = sys.stdin.readline
    n, m = map(int, readline().split())
    a = list(map(int, readline().split()))
    b = [[] for _ in range(n)]
    # Road lengths are kept per node in dicts: an n x n matrix needs O(n^2)
    # time and memory, which is what made large inputs time out.
    c = [{} for _ in range(n)]
    for _ in range(n - 1):
        x, y, z = map(int, readline().split())
        x -= 1
        y -= 1
        b[x].append(y)
        b[y].append(x)
        c[x][y] = z
        c[y][x] = z
    vis = [0] * n
    vis[0] = 1
    ans = []
    dfs(b, c, 0, 1, 0, vis, ans)
    ans.sort()
    a.sort()
    # Two pointers: each leaf, cheapest first, takes the weakest pilgrim
    # who can still afford it.
    i = j = 0
    count = 0
    while i < len(ans) and j < len(a):
        if a[j] >= ans[i]:
            i += 1
            count += 1
        j += 1
    print(count)


def main():
    """Read the number of test cases and solve each one.

    Errors are no longer swallowed by a bare ``except``; a failure now shows
    up as a runtime error instead of a silent wrong answer.
    """
    for _ in range(int(sys.stdin.readline())):
        solve()


if __name__ == "__main__":
    main()

What’s wrong here? It is giving TLE.