# LAZYANC - Editorial

Author: Nishit Sharma
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

# DIFFICULTY:

2723

# PREREQUISITES:

Dynamic programming

# PROBLEM:

Given a tree on N nodes where node i has value A_i, for each u from 1 to N compute

\sum_{i=1}^{N} \left\lfloor \frac{A_i}{2^{d(u, i)}}\right\rfloor

where d(u, i) is the number of edges on the path between u and i.
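Before optimizing, it can help to have a slow reference to compare against. The sketch below is not part of the intended solution: it simply runs a BFS from every node, so it takes \mathcal{O}(N^2) time. The function name `bruteForce` and the 0-indexed edge list are assumptions made for illustration.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Reference brute force (hypothetical helper): BFS from every node u and
// add A_v >> d(u, v) for every node v. Nodes are 0-indexed.
std::vector<long long> bruteForce(int n,
                                  const std::vector<std::pair<int, int> >& edges,
                                  const std::vector<int>& a) {
    assert(static_cast<int>(a.size()) == n);
    std::vector<std::vector<int> > adj(n);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    std::vector<long long> ans(n, 0);
    for (int u = 0; u < n; ++u) {
        std::vector<int> dist(n, -1);
        std::queue<int> q;
        dist[u] = 0;
        q.push(u);
        while (!q.empty()) {
            int v = q.front();
            q.pop();
            // Shifting a 32-bit int by 31 or more is not allowed; the term is 0 there anyway.
            if (dist[v] < 31) ans[u] += a[v] >> dist[v];
            for (int w : adj[v]) {
                if (dist[w] == -1) {
                    dist[w] = dist[v] + 1;
                    q.push(w);
                }
            }
        }
    }
    return ans;
}
```

For the path 0 - 1 - 2 with values 4, 8, 16, this returns 12, 18, 21.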

# EXPLANATION:

The main observation here is as follows: if d(u, i) \gt 20, then \left\lfloor \frac{A_i}{2^{d(u, i)}}\right\rfloor = 0 no matter what the value of A_i is, since A_i \leq 10^6 \lt 2^{20} (so any distance of 20 or more already gives 0).

This means that we only care about vertices at distances \leq 20 from a given u.
Of course, this doesn't directly solve the problem, but it's a start.
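As a quick sanity check of the cutoff, here is a minimal sketch of a single term of the sum. The helper name `contribution` is hypothetical, and the bound A_i \leq 10^6 is taken from the observation above.

```cpp
#include <cassert>

// One term of the sum: floor(a / 2^d) for a value a at distance d.
// Since a <= 10^6 < 2^20, the term is 0 for every d >= 20; the early
// return also keeps the shift amount small.
long long contribution(long long a, int d) {
    assert(0 <= a && a <= 1000000);
    assert(d >= 0);
    return d >= 20 ? 0 : (a >> d);
}
```

With the largest allowed value, `contribution(1000000, 19)` is 1 and `contribution(1000000, 20)` is 0.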

Let's root the tree at some node, say 1.
With this root, let p^i(u) denote the i-th ancestor of u. In particular, p^0(u) = u and p^1(u) is the parent of u.

Note that u contributes a value of \left\lfloor \frac{A_u}{2^{i}}\right\rfloor to p^i(u), and vice versa.

In particular, this allows us to, at least, compute the answer for every u when only considering values that lie in its subtree: for each u, add \left\lfloor \frac{A_u}{2^{i}}\right\rfloor to the answer of p^i(u) for each i from 0 to 20.
This takes \mathcal{O}(20N) time.
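A minimal sketch of this first pass, assuming the parents have already been computed by a DFS from the root. The name `subtreeOnly` and the convention `parent[root] = -1` are choices made for this example.

```cpp
#include <cassert>
#include <vector>

// parent[v] is the parent of v in the rooted tree, or -1 for the root.
// Returns sub[v] = sum over x in subtree(v) of floor(a[x] / 2^{d(x, v)}).
std::vector<long long> subtreeOnly(const std::vector<int>& parent,
                                   const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    assert(static_cast<int>(parent.size()) == n);
    std::vector<long long> sub(n, 0);
    for (int u = 0; u < n; ++u) {
        int v = u;
        int val = a[u];
        // After i steps, v = p^i(u) and val = a[u] >> i.
        // The walk stops at the root or as soon as the value becomes 0.
        while (v != -1 && val > 0) {
            sub[v] += val;
            v = parent[v];
            val >>= 1;
        }
    }
    return sub;
}
```

Stopping once `val` reaches 0 bounds the walk at about 20 steps per node, which matches the \mathcal{O}(20N) bound.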

Now, let's look at a specific u. We've already computed the contribution of things in its subtree, so we need to look outside.
So, let's look at p^1(u). Consider some node v in the subtree of p^1(u), that is not in the subtree of u.
If d(v, p^1(u)) = k, then d(v, u) = k+1. Can we use this in some way?

Yes, we can!
Let's compute a 3D dynamic programming table: dp[u][k][x] stores the following:

• Consider the subtree of vertex u, and all nodes at a distance of k from u in this subtree.
• dp[u][k][x] holds the contribution of such nodes to a node at a distance of x from u.

So, coming back to our earlier discussion, the contribution of nodes in the subtree of p^1(u) to the answer of u can be computed using dp[p^1(u)][k][1] across all k.
Note that this will also include some values in the subtree of u, which shouldn't be counted: their contribution can be subtracted out separately using the appropriate cell in the dp table.

Note that this allows us to visit every relevant ancestor of u and do the same thing. That is, for each 0 \leq i \leq 20, visit p^i(u) and add the values of dp[p^i(u)][k][i] across all k, while also subtracting appropriate dp values to ensure that nothing is double-counted.
This will cover every node that is at a distance of \leq 20 from u, which is exactly what we wanted.
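The whole procedure can be sketched as below. This is an illustration under stated assumptions, not the setter's or tester's exact code: `solveWithTable`, `Table` and `LOG` are names invented here, the tree is given as a parent array with `parent[root] = -1`, and distances are cut off at 20 because 10^6 \lt 2^{20}. The cell subtracted for the child `last` is dp[last][k-1][i+1], since a node at depth k-1 below `last` sits at depth k below its parent.

```cpp
#include <array>
#include <cassert>
#include <vector>

constexpr int LOG = 20;  // 10^6 < 2^20, so distances >= 20 contribute 0
using Table = std::array<std::array<long long, LOG>, LOG>;

std::vector<long long> solveWithTable(const std::vector<int>& parent,
                                      const std::vector<int>& a) {
    const int n = static_cast<int>(a.size());
    assert(static_cast<int>(parent.size()) == n);
    // dp[v][k][x]: nodes of subtree(v) at depth k below v, contributing to a
    // node at distance x from v, i.e. each such node adds a >> (k + x).
    std::vector<Table> dp(n);  // elements are value-initialized to 0
    for (int u = 0; u < n; ++u) {
        int v = u;
        for (int k = 0; k < LOG && v != -1; ++k, v = parent[v]) {
            for (int x = 0; k + x < LOG; ++x) dp[v][k][x] += a[u] >> (k + x);
        }
    }
    std::vector<long long> ans(n, 0);
    for (int u = 0; u < n; ++u) {
        for (int k = 0; k < LOG; ++k) ans[u] += dp[u][k][0];  // own subtree
        int last = u;
        int v = parent[u];
        for (int i = 1; i < LOG && v != -1; ++i) {  // v = p^i(u)
            for (int k = 0; k + i < LOG; ++k) {
                ans[u] += dp[v][k][i];
                // Remove nodes that lie in subtree(last): already counted.
                // k >= 1 and k + i < LOG together keep i + 1 within bounds.
                if (k >= 1) ans[u] -= dp[last][k - 1][i + 1];
            }
            last = v;
            v = parent[v];
        }
    }
    return ans;
}
```

Each node stores a 20 x 20 table here, so the memory use is the \mathcal{O}(20^2 \cdot N) version of the idea.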

The algorithm given above takes \mathcal{O}(20^2\cdot N) time and space.
It's possible to optimize the space to \mathcal{O}(20N), but this optimization was unnecessary to get AC.
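One possible way to get \mathcal{O}(20N) space, similar in spirit to the setter's code below, is to sum the table over k up front: keep only shift[v][j], the total contribution of subtree(v) to a node at distance j from v. The sketch below makes the same assumptions as before (parent array with `parent[root] = -1`, cutoff at 20); `solveCompact` and `shift` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

std::vector<long long> solveCompact(const std::vector<int>& parent,
                                    const std::vector<int>& a) {
    const int LOG = 20;  // 10^6 < 2^20
    const int n = static_cast<int>(a.size());
    assert(static_cast<int>(parent.size()) == n);
    // shift[v][j] = sum over x in subtree(v) of a[x] >> (d(x, v) + j).
    // One extra column so that shift[.][LOG] exists and stays 0.
    std::vector<std::vector<long long> > shift(n, std::vector<long long>(LOG + 1, 0));
    for (int u = 0; u < n; ++u) {
        int v = u;
        for (int k = 0; k < LOG && v != -1; ++k, v = parent[v]) {
            for (int j = 0; k + j < LOG; ++j) shift[v][j] += a[u] >> (k + j);
        }
    }
    std::vector<long long> ans(n, 0);
    for (int u = 0; u < n; ++u) {
        ans[u] = shift[u][0];
        int last = u;
        int v = parent[u];
        for (int i = 1; i < LOG && v != -1; ++i) {  // v = p^i(u)
            // Everything below v seen from distance i, minus the part that
            // lies below `last`, which is one edge further from v.
            ans[u] += shift[v][i] - shift[last][i + 1];
            last = v;
            v = parent[v];
        }
    }
    return ans;
}
```

The time is still \mathcal{O}(20^2 \cdot N) because of the double loop that fills `shift`; only the memory drops.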

# TIME COMPLEXITY:

\mathcal{O}(20^2\cdot N) per test case.

# CODE:

Setter's code (C++)
#include<bits/stdc++.h>
#define ll long long int
#define fab(a,b,i) for(int i=a;i<b;i++)
#define pb push_back
#define db double
#define mp make_pair
#define endl "\n"
#define f first
#define se second
#define all(x) x.begin(),x.end()
#define vll vector<ll>
#define vi vector<int>
#define pii pair<int,int>
#define pll pair<ll,ll>
#define quick ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)

using namespace std;

const int MOD = 1e9 + 7;

ll add(ll x, ll y) {ll res = x + y; return (res >= MOD ? res - MOD : res);}
ll mul(ll x, ll y) {ll res = x * y; return (res >= MOD ? res % MOD : res);}
ll sub(ll x, ll y) {ll res = x - y; return (res < 0 ? res + MOD : res);}
ll power(ll x, ll y) {ll res = 1; x %= MOD; while (y) {if (y & 1)res = mul(res, x); y >>= 1; x = mul(x, x);} return res;}
ll mod_inv(ll x) {return power(x, MOD - 2);}
ll lcm(ll x, ll y) { ll res = x / __gcd(x, y); return (res * y);}

#define int ll
void dfs(int src, int par, vector<int> &a, vector<vector<int>> &v, vector<vector<int>> &values, vector<int> &parent) {

parent[src] = par;
for (int &i : v[src]) {
if (i ^ par) {
dfs(i, src, a, v, values, parent);
}
}

int curr = src;
int val = a[src];
while (curr != -1 and val > 0) {
values[curr].push_back(val);
curr = parent[curr];
val >>= 1;
}
}

int32_t main()
{

quick;
int t = 1;
cin >> t;
while (t--)
{
int n;
cin >> n;
vector<vector<int>> v(n);

for (int i = 0; i < n - 1; i++) {
int x, y;
cin >> x >> y;
x--, y--;
v[x].push_back(y);
v[y].push_back(x);
}

vector<int> a(n);
for (int i = 0; i < n; i++) {
cin >> a[i];
}

vector<vector<int>> values(n);
vector<int> parent(n, -1);
dfs(0, -1, a, v, values, parent);

vector<int> ans(n);

for (int &i : values[0]) {
ans[0] += i;
}

const int maxA = 1e6 + 5;

const int N = log2(maxA) + 3;

vector<vector<int>> moveNodes(n, vector<int> (N));

for (int i = 0; i < n; i++) {
for (int &j : values[i]) {
for (int k = 0; k < N; k++) {
int val = (j >> k);
moveNodes[i][k] += val;
if (val == 0) break;
}
}
}

for (int i = 1; i < n; i++)  {

ans[i] = moveNodes[i][0];
int last = i;
int curr = parent[i];
for (int j = 1; j < N - 1 and curr != -1; j++) {
int val = (moveNodes[curr][j] - moveNodes[last][j + 1]);
ans[i] += val;
last = curr;
curr = parent[curr];
}
}

for (int i = 0; i < n; i++) cout << ans[i] << " ";
cout << endl;

}

cerr << "time taken : " << (float)clock() / CLOCKS_PER_SEC << " secs" << endl;
return 0;
}

Tester's code (C++)
#include <bits/stdc++.h>

using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
string buffer;
int pos;

const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
const string number = "0123456789";
const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const string lower = "abcdefghijklmnopqrstuvwxyz";

input_checker() {
pos = 0;
while (true) {
int c = cin.get();
if (c == -1) {
break;
}
buffer.push_back((char) c);
}
}

int nextDelimiter() {
int now = pos;
while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
now++;
}
return now;
}

// Read one whitespace-delimited token starting at pos.
string readOne() {
assert(pos < (int) buffer.size());
int nxt = nextDelimiter();
string res;
while (pos < nxt) {
res += buffer[pos];
pos++;
}
// cerr << res << endl;
return res;
}

string readString(int minl, int maxl, const string &pattern = "") {
assert(minl <= maxl);
string res = readOne();
assert(minl <= (int) res.size());
assert((int) res.size() <= maxl);
for (int i = 0; i < (int) res.size(); i++) {
assert(pattern.empty() || pattern.find(res[i]) != string::npos);
}
return res;
}

int readInt(int minv, int maxv) {
assert(minv <= maxv);
int res = stoi(readOne());
assert(minv <= res);
assert(res <= maxv);
return res;
}

long long readLong(long long minv, long long maxv) {
assert(minv <= maxv);
long long res = stoll(readOne());
assert(minv <= res);
assert(res <= maxv);
return res;
}

void readSpace() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == ' ');
pos++;
}

void readEoln() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == '\n');
pos++;
}

void readEof() {
assert((int) buffer.size() == pos);
}
};

struct dsu {
int n;
vector<int> p;
vector<int> sz;

dsu(int _n) : n(_n) {
p = vector<int>(n);
iota(p.begin(), p.end(), 0);
sz = vector<int>(n, 1);
}

inline int get(int x) {
if (p[x] == x) {
return x;
} else {
return p[x] = get(p[x]);
}
}

inline bool unite(int x, int y) {
x = get(x);
y = get(y);
if (x == y) {
return false;
}
p[x] = y;
sz[y] += sz[x];
return true;
}

inline bool same(int x, int y) {
return (get(x) == get(y));
}

inline int size(int x) {
return sz[get(x)];
}

inline bool root(int x) {
return (x == get(x));
}
};

int main() {
input_checker in;
int sn = 0;
while (tt--) {
sn += n;
vector<vector<int>> g(n);
dsu uf(n);
for (int i = 0; i < n - 1; i++) {
x--;
y--;
g[x].emplace_back(y);
g[y].emplace_back(x);
uf.unite(x, y);
}
assert(uf.size(0) == n);
vector<int> a(n);
for (int i = 0; i < n; i++) {
}
vector<int> pv(n, -1);
{
function<void(int, int)> Dfs = [&](int v, int p) {
for (int to: g[v]) {
if (to == p) {
continue;
}
pv[to] = v;
Dfs(to, v);
}
};
Dfs(0, -1);
}
vector s(n, vector(20, vector<long long>(20)));
for (int i = 0; i < n; i++) {
int v = i;
for (int j = 0; j < 20; j++) {
if (v == -1) {
break;
}
for (int k = j; k < 20; k++) {
s[v][j][k] += a[i] >> k;
}
v = pv[v];
}
}
for (int i = 0; i < n; i++) {
long long ans = 0;
for (int j = 0; j < 20; j++) {
ans += s[i][j][j];
}
int last = i;
int v = pv[i];
for (int j = 1; j < 20; j++) {
if (v == -1) {
break;
}
for (int k = 0; j + k < 20; k++) {
ans += s[v][k][j + k];
if (k != 0) {
ans -= s[last][k - 1][j + k];
}
}
last = v;
v = pv[v];
}
cout << ans << " \n"[i == n - 1];
}
}
assert(sn <= 5e4);
return 0;
}

Tester's code (C++)
//Utkarsh.25dec
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
#include <array>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
long long readInt(long long l,long long r,char endd){
long long x=0;
int cnt=0;
int fi=-1;
bool is_neg=false;
while(true){
char g=getchar();
if(g=='-'){
assert(fi==-1);
is_neg=true;
continue;
}
if('0'<=g && g<='9'){
x*=10;
x+=g-'0';
if(cnt==0){
fi=g-'0';
}
cnt++;
assert(fi!=0 || cnt==1);
assert(fi!=0 || is_neg==false);

assert(!(cnt>19 || ( cnt==19 && fi>1) ));
} else if(g==endd){
if(is_neg){
x= -x;
}

if(!(l <= x && x <= r))
{
cerr << l << ' ' << r << ' ' << x << '\n';
assert(1 == 0);
}

return x;
} else {
assert(false);
}
}
}
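string readString(int l,int r,char endd){
string ret="";
int cnt=0;
while(true){
char g=getchar();
assert(g!=-1);
if(g==endd){
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt && cnt<=r);
return ret;
}
long long readIntSp(long long l,long long r){
return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
return readString(l,r,'\n');
}
string readStringSp(int l,int r){
return readString(l,r,' ');
}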
int par[N];
void checktree(int curr)
{
vis[curr]=1;
{
if(vis[it])
continue;
par[it]=curr;
checktree(it);
}
}
int sumN=0;
void solve()
{
sumN+=n;
assert(sumN<=50000);
for(int i=1;i<=n;i++)
{
vis[i]=0;
}
for(int i=1;i<n;i++)
{
int u,v;
assert(u!=v);
}
checktree(1);
for(int i=1;i<=n;i++)
{
assert(vis[i]==1);
vis[i]=0;
}
int A[n+1];
memset(A,0,sizeof(A));
for(int i=1;i<=n;i++)
{
if(i==n)
else
}
vector <int> vals[n+1];
for(int i=1;i<=n;i++)
{
int curr=i;
for(int j=0;j<=22;j++)
{
if((A[i]/(1<<j))>0)
vals[curr].pb(A[i]/(1<<j));
curr=par[curr];
if(curr==0)
break;
}
}
ll shifts[n+1][23];
memset(shifts,0,sizeof(shifts));
for(int i=1;i<=n;i++)
{
for(int j=0;j<=22;j++)
{
for(auto it:vals[i])
shifts[i][j]+=(it/(1<<j));
}
}
ll ans[n+1];
memset(ans,0,sizeof(ans));
for(int i=1;i<=n;i++)
{
ans[i]=shifts[i][0];
int x=i;
int y=par[i];
int cnt=1;
while(y!=0)
{
if(cnt>=21)
break;
ans[i]+=(shifts[y][cnt]-shifts[x][cnt+1]);
x=par[x];
y=par[y];
cnt++;
}
}
for(int i=1;i<=n;i++)
cout<<ans[i]<<' ';
cout<<'\n';
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
ios_base::sync_with_stdio(false);
cin.tie(NULL),cout.tie(NULL);
while(T--)
solve();
assert(getchar()==-1);
cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}


I have an O(20 N) solution.

Like in the editorial, first calculate the sum over the subtree of every node. To do this, calculate freq[i][j], the number of terms in the subtree sum of the i^{th} node that have the j^{th} bit set.

vector<vector<int>> freq(n,vector<int> (bit,0));
function<void(int,int)> sub_dfs = [&](int i,int par){
for(int j = 0; j < bit; j++){
if(1 << j & a[i]){
freq[i][j]++;
}
}
for(int node : graph[i]){
if(node != par){
sub_dfs(node,i);
for(int j = 0; j < bit-1; j++){
freq[i][j] += freq[node][j+1];
}
}
}
}; sub_dfs(0,-1);


Here we have taken 0 as the root. Now run another dfs that computes the full freq array of a child once the full freq array of its parent is known.

vector<long long int> ans(n);
function<void(int,int)> dfs = [&](int i,int par){
if(par != -1){
for(int j = 2; j < bit; j++){
int ex = freq[par][j-1] - freq[i][j];
freq[i][j-2] += ex;
}
}

for(int j = 0; j < bit; j++){
ans[i] += (1ll << j) * (long long int)freq[i][j];
}

for(int node : graph[i]){
if(node != par){
dfs(node,i);
}
}
}; dfs(0,-1);


This gives us the answer in O(20 N). Submission during contest - CodeChef | Competitive Programming | Participate & Learn


I solved it using a different approach.
For a node v, let score_{ij} denote the sum of a_x / 2^j over every node x at distance i from v (only considering 0 \le i < 20 and 0 \le j < 20).
We need to calculate score_{ij} for all nodes.
In the first DFS, we calculate score_{ij} for each node v, considering only the subtree rooted at v.
Then, in the second DFS, we compute the full score using the rest of the tree.
Couldn't solve it in contest though.

Accepted submission : CodeChef | Competitive Programming | Participate & Learn

@jaimahakal Nice solution. Your code is very clean and readable and quite small as well.
Are you on codeforces?