PSTONES - Editorial

PROBLEM LINK:

Division 1
Division 2

Author: Arthur Nascimento
Tester: Данило Мочернюк
Editorialist: Ritesh Gupta

DIFFICULTY:

HARD

PREREQUISITES:

DP on Trees, Bitmasking, DFS, Walsh-Hadamard

PROBLEM:

A tree consists of N nodes connected by N-1 edges in such a way that each node is reachable from any other node. In the middle of each edge, there is a hidden precious stone. There are K different colours of stones, conveniently numbered 1 through K, and for each edge, we also know the colour of the stone hidden there.

We organise an expedition to find some stones in the following way:

  • You choose a set of M nodes to be the expedition basis, in such a way that they form a connected subgraph i.e. for any two nodes in the basis, the path between them must not visit any node that is not in the basis.
  • The expedition gathers the stones in all roads which lie on the “frontier” of the expedition basis, i.e. all edges for which one endpoint lies in the basis and the other endpoint does not.

For each of the 2^K possible sets of colours, you want to know if it is possible that the expedition will return with exactly this set, i.e. there is a basis such that the expedition returns with at least one stone of each colour in this set and no stones with colours that are not in this set. You also want to find these answers for each possible value of M.
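To make the notion of a frontier concrete, here is a minimal sketch that computes the colour set collected by one given basis. The function name frontierMask and the edge representation are assumptions made for illustration; they do not appear in the reference solutions.

```cpp
#include <array>
#include <vector>

// Hypothetical helper, not part of the reference solutions.
// edges[i] = {u, v, colour}, with vertices and colours 0-indexed.
// inBasis[v] is true when vertex v belongs to the chosen basis.
int frontierMask(const std::vector<std::array<int, 3>>& edges,
                 const std::vector<bool>& inBasis) {
    int mask = 0;
    for (const auto& e : edges) {
        // An edge is on the frontier when exactly one endpoint is in the basis.
        if (inBasis[e[0]] != inBasis[e[1]]) mask |= 1 << e[2];
    }
    return mask;
}
```

The helper does not check that the basis is connected; it only illustrates which stones a basis collects.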

QUICK EXPLANATION:

  • Subtask 1: the tree is a line, so it can be solved by simply generating each connected subgraph.

  • Subtask 2: perform a DFS and calculate the answer for each node by combining the stored answers of its children’s subtrees, i.e., using Dynamic Programming.

  • Subtask 3: use the same Dynamic Programming, but write the transitions as a Walsh-Hadamard transform (OR convolution); this is O(N^2 * K * 2^K).

  • Subtask 4: construct each subtree uniquely by deciding, for each descendant of a particular node, whether to cut it off or not, and find all possible sets of colours with Dynamic Programming.

EXPLANATION:

Subtask 1: For subtask 1, generating all possible valid subgraphs is enough. Since the given tree is a line, every connected set of nodes is a contiguous segment of it, so generating the subgraphs is just like generating the subarrays of an array, which takes O(n^2) time.
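The enumeration for the line case could look like the sketch below. It assumes the vertices are already listed in path order (a real solution would first walk the path from one endpoint), and the name solveLine is hypothetical.

```cpp
#include <vector>

// Sketch for a path whose vertices are already listed in path order.
// edgeColour[i] is the 0-indexed colour of the edge between vertex i and i+1.
// Returns possible[M][mask] for M = 1..N (row 0 is unused).
std::vector<std::vector<char>> solveLine(const std::vector<int>& edgeColour, int K) {
    int n = (int)edgeColour.size() + 1;
    std::vector<std::vector<char>> possible(n + 1, std::vector<char>(1 << K, 0));
    for (int l = 0; l < n; l++) {
        for (int r = l; r < n; r++) {
            int mask = 0;
            if (l > 0) mask |= 1 << edgeColour[l - 1];   // edge leaving on the left
            if (r + 1 < n) mask |= 1 << edgeColour[r];   // edge leaving on the right
            possible[r - l + 1][mask] = 1;               // segment [l, r] has r-l+1 nodes
        }
    }
    return possible;
}
```

A segment has at most two frontier edges, so each of the O(n^2) segments is handled in constant time.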

Subtask 2: For subtask 2, let’s root the tree at an arbitrary vertex. Perform a DFS from the root and calculate the answer for each node on the way back up, using Dynamic Programming. For each node, consider all valid subtrees of size 0 through S (where S is the size of the subtree rooted at this node in the rooted tree) and find the colour combinations that can be obtained with each size. Represent a colour combination as a bitmask of k bits, so there are 2^k possible masks.

Formally, let the DP state be (current_node, size_of_subtree, mask). Its value is true if some connected subtree that contains current_node, lies inside its subtree, and has exactly size_of_subtree nodes can have mask as its frontier. There are three possibilities for a node:

  1. No child: In this case, there is no child to consider in the root’s subtree.
  2. One child: In this case, including the child extends the root’s subtree by the part of the child’s subtree that we pick, which has at least 1 node.
  3. More than one child
    1. Two children: In this case, we want all the possible size combinations of subtrees that have some mask associated with them, so we try out all (size_1 * 2^k) * (size_2 * 2^k) possibilities.
    2. More than two children: In this case, we go through the children one by one and combine them as in the previous case. The values merged so far are memoized, so they are not recalculated over and over; each new child only has to be combined with the previously stored values.

Note: We’re ignoring any edge between the current vertex and its parent during these calculations. However, the answer for the current node can be easily updated to include this edge (if it exists).
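The child-by-child combination described above can be sketched as follows. The Table alias and the name mergeTables are assumptions for illustration. The sketch also assumes that row 0 of a child's table holds exactly one mask, the colour of the edge to that child (the child is left out, so that edge lies on the frontier).

```cpp
#include <cstddef>
#include <vector>

// table[size][mask] != 0 when that (size, frontier mask) pair is achievable.
using Table = std::vector<std::vector<char>>;

// Combines the table accumulated so far with the table of one more child.
// Sizes add up, frontier masks are OR-ed together.
// Assumes both tables are non-empty and have the same number of masks.
Table mergeTables(const Table& acc, const Table& child) {
    int masks = (int)acc[0].size();
    Table res(acc.size() + child.size() - 1, std::vector<char>(masks, 0));
    for (std::size_t s1 = 0; s1 < acc.size(); s1++)
        for (int m1 = 0; m1 < masks; m1++) {
            if (!acc[s1][m1]) continue;
            for (std::size_t s2 = 0; s2 < child.size(); s2++)
                for (int m2 = 0; m2 < masks; m2++)
                    if (child[s2][m2]) res[s1 + s2][m1 | m2] = 1;
        }
    return res;
}
```

Starting from a table that contains only (size 1, empty mask) for the node itself and merging its children one at a time gives the node's table before its parent edge is taken into account.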

The time complexity of this approach is O(n^2 * 2^k * X), where X is the amortized cost of multiplying the sizes of all subtrees at each vertex. Using bitsets for these operations makes things even faster. In the very worst case, X is O(n^2/2) for a star tree. This may seem a little high, but it is paid at a single vertex only, and all the other vertices are handled in O(1). In a binary tree, it takes O(n^2 * log) for all nodes (i.e., X = N log N).

Subtask 3: Combining the arrays of two children of a node is a kind of multiplication of two arrays. More precisely, it is an OR-convolution, which can be computed with the Fast Walsh-Hadamard Transform: c_k = Σ a_i * a_j, summed over all pairs (i, j) with i | j = k. This reduces the O(2^k * 2^k) operations at each step to O(2^k * k).
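A minimal sketch of an OR-convolution through the subset-sum form of the transform is shown below. The names orTransform and orConvolution are hypothetical. The entries here count pairs rather than store booleans, so a real solution that only needs reachability would have to guard against overflow, for example by clamping or reducing modulo a prime; that choice is an assumption and is left out of the sketch.

```cpp
#include <cstddef>
#include <vector>

// In-place subset-sum ("OR") transform; a.size() must be a power of two.
// With inverse = true it undoes the forward transform.
void orTransform(std::vector<long long>& a, bool inverse) {
    int n = (int)a.size();
    for (int bit = 1; bit < n; bit <<= 1)
        for (int mask = 0; mask < n; mask++)
            if (mask & bit) a[mask] += inverse ? -a[mask ^ bit] : a[mask ^ bit];
}

// c[k] = sum of a[i] * b[j] over all pairs with (i | j) == k.
// Both inputs must have the same power-of-two length.
std::vector<long long> orConvolution(std::vector<long long> a,
                                     std::vector<long long> b) {
    orTransform(a, false);
    orTransform(b, false);
    for (std::size_t i = 0; i < a.size(); i++) a[i] *= b[i];  // pointwise product
    orTransform(a, true);
    return a;
}
```

Each transform costs O(2^k * k), which is where the bound quoted above comes from.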

Subtask 4: In a rooted tree, any subtree can be uniquely constructed by choosing some vertex x, adding all descendants of x (including x itself) to the subtree, and then removing every vertex of S together with its descendants, where S is a set of proper descendants of x in which no vertex is an ancestor of another.

Let’s do this for a particular x. Then how do we choose S?

Iterate over the descendants of x in DFS traversal order and, at each one, decide whether to put this node in S or not. If node u is chosen, then none of its descendants can be chosen any more. This is easy to do on the stored DFS order: the subtree of u is a contiguous interval of that order, so choosing u simply skips the interval.
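The skip-the-interval idea can be seen in isolation in the sketch below, which only counts the valid choices of S for a fixed x instead of tracking sizes and colours. The name countSubtreesWithTop and the input layout (subtree sizes indexed by preorder position) are assumptions for illustration.

```cpp
#include <vector>

// sub[p] = subtree size of the vertex at preorder position p (root at position 0).
// Counts the connected vertex sets whose topmost vertex sits at position x,
// i.e. the number of valid choices of S for this x.
long long countSubtreesWithTop(const std::vector<int>& sub, int x) {
    int end = x + sub[x];                 // first position outside x's subtree
    std::vector<long long> ways(end + 1, 0);
    ways[end] = 1;                        // every descendant has been decided
    for (int pos = end - 1; pos > x; pos--) {
        // Either keep pos (continue at pos + 1),
        // or put pos into S (jump over its whole subtree).
        ways[pos] = ways[pos + 1] + ways[pos + sub[pos]];
    }
    return ways[x + 1];
}
```

The full solution uses the same two transitions, with the removed size and the colour mask added to the state.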

Now, we are able to find all possible sets of colours with Dynamic Programming. States of the DP are current_node, size_of_subtrees_to_be_removed, and mask (target subset of colours).

Note that since x is not actually fixed, we need a modification in the DP: it returns the first moment at which it is possible to achieve our set of colours (i.e. it minimizes the largest DFS order of the elements in S). Now, given any x, we can check if some mask is possible by checking if the vertex returned by the DP is in x’s subtree.

TIME COMPLEXITY:

TIME: O(N^2 * 2^K)
SPACE: O(N^2 * 2^K)

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#define maxn 105
#define maxK (1<<12)
#define ll long long
#define pb push_back
#define pii pair<int,int>
#define mod 1000000007
#define debug(args...) //fprintf(stderr,args)
using namespace std;

vector<pii> L[maxn];   // adjacency list: (neighbour, colour of the edge)
int pre[maxn];         // pre[v] = preorder position of vertex v
int cnt = 0;

// The three arrays below are indexed by preorder position, not by vertex id.
int par[maxn];         // colour of the edge to the parent
int sub[maxn];         // subtree size
int nxt[maxn];         // first position after the subtree (vertex id until converted in main)

void dfs(int vx,int p,int nx,int c){
    debug("oi %d\n",vx);
    pre[vx] = cnt++;
    nxt[pre[vx]] = nx;
    sub[pre[vx]] = 1;
    par[pre[vx]] = c;

    // prox[i] = vertex visited right after the subtree of the i-th neighbour
    vector<int> prox;
    for(int i=L[vx].size()-1;i>=0;i--){
        prox.pb(nx);
        if(L[vx][i].first != p) nx = L[vx][i].first;
    }
    reverse(prox.begin(),prox.end());

    for(int i=0;i<L[vx].size();i++){
        pii u = L[vx][i];
        if(u.first == p) continue;
        dfs(u.first,vx,prox[i],u.second);
        sub[pre[vx]] += sub[pre[u.first]];
        nx = u.first;
    }
}

// dp[pos][sz][mask] = smallest possible value of the largest position in S, where S is
// chosen among positions >= pos, removes sz vertices in total and has colour set mask
// (n if impossible)
int dp[maxn][maxn][maxK];

int main(){
    int nt;
    scanf("%d",&nt);
    while(nt--){
        int n,k;
        scanf("%d%d",&n,&k);
        cnt = 0;
        for(int i=0;i<n;i++)
            L[i].clear();
        for(int i=0;i<n-1;i++){
            int a,b,c;
            scanf("%d%d%d",&a,&b,&c), a--, b--, c--;
            L[a].pb({b,c});
            L[b].pb({a,c});
        }

        pre[n] = n;
        dfs(0,0,n,0);
        // nxt held vertex ids during the DFS; convert them to preorder positions
        for(int i=0;i<n;i++)
            nxt[i] = pre[nxt[i]];

        for(int i=0;i<n;i++)
            debug("vx %d sub %d nxt %d par %d\n",i,sub[i],nxt[i],par[i]);

        // sentinel position n: nothing is achievable
        for(int sz=0;sz<=n;sz++)
            for(int mask=0;mask<(1<<k);mask++)
                dp[n][sz][mask] = n;

        for(int sz=0;sz<=n;sz++)
            for(int pos=n-1;pos>=0;pos--)
                for(int mask=0;mask<(1<<k);mask++){
                    // skip pos
                    dp[pos][sz][mask] = dp[pos+1][sz][mask];
                    // pos is the only (last) element of S
                    if(mask == (1<<par[pos]) && sz == sub[pos])
                        dp[pos][sz][mask] = pos;
                    // put pos into S and continue after its subtree
                    if((mask & (1<<par[pos])) && sz >= sub[pos])
                        dp[pos][sz][mask] = min(dp[pos][sz][mask], min(dp[nxt[pos]][sz-sub[pos]][mask],dp[nxt[pos]][sz-sub[pos]][mask-(1<<par[pos])]));
                    debug("dp[%d][%d][%d] = %d\n",pos,sz,mask,dp[pos][sz][mask]);
                }

        for(int sz=1;sz<=n;sz++){
            debug("%d: ",sz);
            for(int m=0;m<(1<<k);m++){
                int ok = 0;
                // try every vertex as the topmost vertex x of the basis
                for(int i=0;i<n;i++){
                    int pos = pre[i];
                    // the parent edge of x is always on the frontier
                    if((m&(1<<par[pos])) == 0 && pos != 0) continue;
                    if(sub[pos] < sz) continue;
                    int m2 = m;
                    if(pos != 0) m2 -= (1<<par[pos]);
                    // S must end inside the subtree of x
                    if(min(dp[pos][sub[pos]-sz][m],dp[pos][sub[pos]-sz][m2]) < nxt[pos])
                        ok = 1;
                    // S is empty: the basis is the whole subtree of x
                    if(m2 == 0 && sub[pos] == sz)
                        ok = 1;
                    //if(ok) break;
                }
                printf("%d",ok);
            }
            printf("\n");
        }
    }
}
Tester's Solution
#include<bits/stdc++.h>
#define pb push_back
#define x first
#define y second
#define sz(a) (int)(a.size())
using namespace std;
const int MAX = 105;
vector<pair<int, int> > g[MAX];
int n, k;
int parentEdge[MAX] , sub[MAX];
int T;
int dp[MAX][MAX][1 << 12];
int res[MAX][1 << 12];
void setmin(int &a, int b)
{
    a = min(a , b);
}
void setmax(int &a, int b)
{
    a = max(a , b);
}
int dfs(int u, int p = -1)
{
    int id = T;
    sub[id] = 1;
    for(auto v : g[u])
    {
        //cerr << u << " " << v.x << " " << v.y << endl;
        if(v.x == p)
            continue;
        parentEdge[++T] = v.y;
        sub[id] += dfs(v.x, u);
    }
    //cerr << u << " " << sub[id] << endl;
    return sub[id];
}
int main()
{
    int t;
    cin >> t;
    int total = 0;
    while(t--)
    {
        cin >> n >> k;
        total += n;
        assert(total <= 500);
        assert(1 <= n && n <= 100);
        assert(1 <= k && k <= 12);
        for(int i = 1; i < n; i++)
        {
            int u, v, c;
            cin >> u >> v >> c;
            assert(1 <= u && u <= n);
            assert(1 <= v && v <= n);
            assert(1 <= c && c <= k);
            u--;v--;c--;
            //cerr << u << " " << v << " " << c << endl;
            g[u].pb({v , c});
            g[v].pb({u , c});
        }
        assert(dfs(0) == n);
        for(int i = 0; i <= n; i++)
            for(int j = 0; j <= n; j++)
                for(int mask = 0; mask < (1 << k); mask++)
                    dp[i][j][mask] = n;
        dp[n][0][0] = 0;
        for(int pos = n - 1; pos >= 0; pos--)
        {
            dp[pos][0][0] = pos;
            int c = parentEdge[pos];
            dp[pos][sub[pos]][1 << c] = pos;
            for(int sz = 0; sz <= n; sz++)
            {
                for(int mask = 0; mask < (1 << k); mask++)
                {
                    setmin(dp[pos][sz][mask], dp[pos + 1][sz][mask]);
                    if(((mask >> c) & 1) && sz >= sub[pos])
                    {
                        setmin(dp[pos][sz][mask], dp[pos + sub[pos]][sz - sub[pos]][mask ^ (1 << c)]);
                        setmin(dp[pos][sz][mask], dp[pos + sub[pos]][sz - sub[pos]][mask]);
                    }
                    setmax(dp[pos][sz][mask], pos);
                }
            }
        }
        for(int i = 1; i < n; i++)
            res[sub[i]][(1 << parentEdge[i])] = 1;
        for(int sz = 1; sz <= n; sz++)
        {
            for(int mask = 0; mask < (1 << k); mask++)
            {
                for(int root = 0; root < n; root++)
                {
                    int c = root ? parentEdge[root] : -1;
                    if(c != -1 && !((mask >> c) & 1))
                    {
                        continue;
                    }
                    int mask2 = c == -1 ? mask : (mask ^ (1 << c));
                    int minVertex = min(dp[root][sz][mask] , dp[root][sz][mask2]);
                    if(minVertex < root + sub[root])
                    {
                        res[sub[root] - sz][mask] = 1;
                    }
                }
            }
        }
        res[n][0] = 1;
        for(int i = 1; i <= n; i++)
        {
            for(int j = 0; j < (1 << k); j++)
                cout << res[i][j];
            cout << endl;
        }
        for(int i = 1; i <= n; i++) 
            for(int j = 0; j < (1 << k); j++)
                res[i][j] = 0;
        T = 0;
        for(int i = 0; i < n; i++)
            g[i].clear();
    }
    return 0;
} 

We can also do it in O(N^2 * 2^K) with a slight variation of the fast Walsh-Hadamard transform.
Instead of multiplying polynomials again and again (which takes O(2^k * k) each time), we can keep each polynomial in point-value form and perform additions and multiplications directly on that form, which takes only O(2^k) each time.
For example, to calculate the OR multiplication of two polynomials, what we generally do is:

  1. Calculate point value form of both polynomials using FWHT.
  2. Multiply values of point value form to obtain point value form of our desired polynomial.
  3. Then derive actual desired result by taking inverse FWHT of final point form.

Actually, we do not need to perform the 1st and 3rd steps again and again. We just keep track of the point-value form of the polynomials and perform additions and multiplications on that form only, without taking the inverse FWHT at every step to obtain the actual polynomial. We perform the inverse FWHT only once, after all calculations are done, to obtain the actual polynomial.
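A small sketch of this trick for a product of several arrays: each input is transformed once, the running product stays in point-value form, and a single inverse transform is applied at the end. The names orTransform and orProduct are hypothetical, and all arrays are assumed to have the same power-of-two length.

```cpp
#include <cstddef>
#include <vector>

// Forward or inverse subset-sum ("OR") transform, in place.
void orTransform(std::vector<long long>& a, bool inverse) {
    int n = (int)a.size();
    for (int bit = 1; bit < n; bit <<= 1)
        for (int mask = 0; mask < n; mask++)
            if (mask & bit) a[mask] += inverse ? -a[mask ^ bit] : a[mask ^ bit];
}

// OR-product of several arrays: transform each one once,
// multiply pointwise, and invert only once at the end.
std::vector<long long> orProduct(std::vector<std::vector<long long>> polys) {
    // All ones is the point-value form of the identity (a single 1 at mask 0).
    std::vector<long long> acc(polys[0].size(), 1);
    for (auto& p : polys) {
        orTransform(p, false);
        for (std::size_t i = 0; i < acc.size(); i++) acc[i] *= p[i];  // O(2^k) per factor
    }
    orTransform(acc, true);
    return acc;
}
```

In the tree DP, the inputs would be stored in point-value form from the start, so even the per-factor forward transform in this sketch would disappear.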

Here is one more problem where this trick of working only with the point-value form comes in very handy: https://www.codechef.com/problems/PSUM

Basically, this trick is useful whenever we are repeatedly multiplying, adding, or subtracting polynomials but do not need the actual polynomial right at that moment. Converting back to the actual polynomial is then unnecessary, and we can just keep performing operations on the point-value form.


Can you please elaborate on this?
I read it multiple times but couldn’t understand much of it.

I think it means that there is a bijection between subtrees T and pairs (x, S), where x is a node and S is an independent set of nodes. Let’s see below how these two relate:

Subtree T -> Pair (x, S):
If the tree is rooted, any subtree has a vertex that is highest (i.e. closest to the root) -> declare this as x. Initialize S to be empty, and for each vertex L of T, include in S every child of L that is not in T. It’s clear that S is an independent set: otherwise, S would contain two nodes a and b where a is a proper ancestor of b, which is impossible (the parent of b is in T and the path from x to it passes through a, so a would have to be in T as well). This means that each subtree corresponds to a pair (x, S).

Pair (x, S) -> Subtree T:
Construct T in the following way: add x and all descendants of x; then, for each element Y in S, remove Y and all descendants of Y from T. It’s easy to prove that the resulting set of nodes in T forms a tree, hence T is indeed a subtree.

Hope this is helpful.

You can apply centroid decomposition to the initial algorithm described in subtask 4 to achieve O(2^k * n^2 log n).
