FIZZBUZZ2307 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: naisheel, jalp1428
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

2275

PREREQUISITES:

DFS, prefix sums

PROBLEM:

You have a directed tree. For each i = 1, 2, 3, \ldots, N, find the number of distinct vertices that can occur at the i-th position of one of its topological orders.
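To make the statement concrete, here is a brute-force checker for very small N. It is a sketch that assumes vertices are numbered 1 to N and edges are given as (parent, child) pairs. The name `bruteForce` is hypothetical. It tries every permutation and keeps only the valid topological orders, so it is useful only for testing a real solution.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <set>
#include <utility>
#include <vector>
using namespace std;

// Brute force for tiny n only (it tries all n! permutations).
// For each position i (1-indexed), counts the distinct vertices that occupy
// position i in at least one topological order.
// Vertices are 1..n; each edge (u, v) is directed u -> v, so u must come first.
vector<int> bruteForce(int n, const vector<pair<int, int>>& edges) {
    assert(n >= 1 && n <= 9);                // 9! = 362880 permutations at most
    vector<int> perm(n);
    iota(perm.begin(), perm.end(), 1);       // smallest permutation: 1, 2, ..., n
    vector<set<int>> seen(n + 1);            // seen[i] = vertices observed at position i
    do {
        vector<int> pos(n + 1);
        for (int i = 0; i < n; ++i) pos[perm[i]] = i + 1;
        bool ok = true;
        for (auto [u, v] : edges)
            if (pos[u] > pos[v]) { ok = false; break; }
        if (ok)
            for (int i = 0; i < n; ++i) seen[i + 1].insert(perm[i]);
    } while (next_permutation(perm.begin(), perm.end()));
    vector<int> ans(n + 1, 0);               // ans[0] is unused
    for (int i = 1; i <= n; ++i) ans[i] = (int)seen[i].size();
    return ans;
}
```

For the tree with edges 1→2, 1→3, 2→4, the valid orders are 1 2 3 4, 1 2 4 3 and 1 3 2 4. The function therefore reports 1, 2, 3, 2 for positions 1 through 4.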

EXPLANATION:

First, find the root of the tree: it’s the only vertex with no incoming edges.
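A minimal sketch of this step follows. It assumes vertices are numbered 1 to N and the N-1 edges are stored as (parent, child) pairs in the order they are read. `findRoot` is a hypothetical helper name and is not taken from the reference solutions.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Returns the root: the unique vertex (1..n) with no incoming edge.
// Each edge (u, v) is directed u -> v, as in the input.
int findRoot(int n, const vector<pair<int, int>>& edges) {
    vector<int> indeg(n + 1, 0);
    for (auto [u, v] : edges) ++indeg[v];
    int root = -1;
    for (int i = 1; i <= n; ++i)
        if (indeg[i] == 0) {
            assert(root == -1);  // a valid directed tree has exactly one such vertex
            root = i;
        }
    return root;
}
```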

Let’s look at the problem from the opposite direction.
For each vertex u, which positions of a topological order can it appear in?

Observe that wherever u appears in a topological order:

  • All of its ancestors must appear before it.
  • All of its descendants must appear after it.
  • Other vertices don’t have any relation to it at all.

In particular, if \text{dep}[u] denotes the depth of u (with the root having a depth of 1) and \text{sub}[u] denotes the number of vertices in the subtree of u (not including u itself),

  • u definitely can’t appear at positions 1, 2, 3, \ldots, \text{dep}[u]-1 because of the ancestor condition.
  • u definitely can’t appear at positions N, N-1, N-2, \ldots, N-\text{sub}[u]+1 because of the descendant condition.

Finding the \text{dep} and \text{sub} arrays is easily done with a DFS.
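One possible way to compute both arrays is sketched below. It assumes `adj[u]` holds the children of u (vertices 1 to N), and `depthsAndSubtrees` is a hypothetical name. It uses an explicit stack instead of recursion, because the tree can be a path of up to 10^5 vertices and deep recursion may overflow the stack on some judges. Vertices are recorded in the order they are visited, so walking that order backwards processes every child before its parent. That is all the subtree-size accumulation needs.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// dep[u]: depth of u, with the root at depth 1.
// sub[u]: number of vertices in the subtree of u, excluding u itself.
// Vertices are 1..n and adj[u] lists the children of u. An explicit stack is
// used instead of recursion so a path-shaped tree cannot overflow the call stack.
pair<vector<int>, vector<int>> depthsAndSubtrees(int n, int root,
                                                 const vector<vector<int>>& adj) {
    vector<int> dep(n + 1, 0), sub(n + 1, 0), order;
    vector<int> stk = {root};
    dep[root] = 1;
    while (!stk.empty()) {
        int u = stk.back();
        stk.pop_back();
        order.push_back(u);                  // u is recorded before all of its descendants
        for (int v : adj[u]) {
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }
    assert((int)order.size() == n);          // every vertex is reachable from the root
    // Walk the visiting order backwards: each child is finished before its parent.
    for (int i = n - 1; i >= 0; --i) {
        int u = order[i];
        for (int v : adj[u]) sub[u] += sub[v] + 1;
    }
    return {dep, sub};
}
```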

So, u can only appear at positions between \text{dep}[u] and N-\text{sub}[u].
It’s not hard to see that it can appear at all these positions!
To place u at some position x (where \text{dep}[u] \leq x \leq N-\text{sub}[u]) we can:

  • Place the ancestors of u at positions 1, 2, 3, \ldots, \text{dep}[u]-1 in order from the root downwards.
  • Place the descendants of u at positions N, N-1, N-2, \ldots, N-\text{sub}[u]+1 in some valid order.
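    This is always possible: a descendant of u only has to come after its own ancestors, and each of those is either u, an ancestor of u (both placed earlier), or another descendant of u. So the descendants only need to be ordered correctly among themselves, which any topological order of the subtree of u (with u removed) achieves.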
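  • Place u at position x, which is empty.
    This is also possible because, as noted above, u has no relations with anything other than its ancestors/descendants, so once they’re placed u can be freely placed anywhere between them.
  • Fill the remaining N-\text{dep}[u]-\text{sub}[u] empty positions with the other vertices, in any order that is valid among themselves (for example, the order in which they appear in some topological order of the whole tree).
    This works because the parent of such a vertex is either an ancestor of u or another unrelated vertex, and all ancestors of u are already at the front.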
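The construction can be checked mechanically. The sketch below uses hypothetical names `placeAt` and `isTopological`, assumes vertices 1 to N with `adj[v]` listing the children of v, and builds an order with u at position x. The ancestors go first from the root down, the descendants of u fill the last \text{sub}[u] slots, u goes at x, and the remaining unrelated vertices fill the leftover slots. Each group is taken in the order it appears in one preorder of the whole tree. That preorder is itself a topological order, so each group stays valid internally.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Builds a topological order with vertex u at 1-indexed position x, following
// the construction above. Vertices are 1..n, adj[v] lists the children of v,
// and dep[u] <= x <= n - sub[u] is required. Returns order[1..n] (order[0] unused).
vector<int> placeAt(int n, int root, const vector<vector<int>>& adj, int u, int x) {
    // Preorder of the whole tree (parents before children) plus parent links.
    vector<int> pre, par(n + 1, 0), stk = {root};
    while (!stk.empty()) {
        int v = stk.back();
        stk.pop_back();
        pre.push_back(v);
        for (int c : adj[v]) { par[c] = v; stk.push_back(c); }
    }
    // kind: 1 = ancestor of u, 2 = u itself, 3 = descendant of u, 0 = unrelated.
    vector<int> kind(n + 1, 0);
    kind[u] = 2;
    for (int a = u; a != root; ) { a = par[a]; kind[a] = 1; }
    for (int v : pre)                        // parents are classified before children
        if (kind[par[v]] == 2 || kind[par[v]] == 3) kind[v] = 3;
    // Filtering the preorder keeps each group in a valid relative order.
    vector<int> anc, desc, others;
    for (int v : pre) {
        if (kind[v] == 1) anc.push_back(v);          // root first, then downwards
        else if (kind[v] == 3) desc.push_back(v);
        else if (kind[v] == 0) others.push_back(v);
    }
    int d = (int)anc.size() + 1, s = (int)desc.size();    // d = dep[u], s = sub[u]
    assert(d <= x && x <= n - s);
    vector<int> order(n + 1, 0);
    for (int i = 0; i < d - 1; ++i) order[i + 1] = anc[i];      // positions 1..d-1
    for (int i = 0; i < s; ++i) order[n - s + 1 + i] = desc[i];  // last s positions
    order[x] = u;
    int k = 0;
    for (int p = d; p <= n - s; ++p)
        if (p != x) order[p] = others[k++];                      // fill the gaps
    return order;
}

// Returns true if order[1..n] is a permutation of 1..n in which every
// edge parent -> child points forward.
bool isTopological(int n, const vector<vector<int>>& adj, const vector<int>& order) {
    vector<int> pos(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        int v = order[i];
        if (v < 1 || v > n || pos[v] != 0) return false;
        pos[v] = i;
    }
    for (int v = 1; v <= n; ++v)
        for (int c : adj[v])
            if (pos[v] > pos[c]) return false;
    return true;
}
```

For the tree with edges 1→2, 1→3, 2→4, vertex 3 has \text{dep}[3]=2 and \text{sub}[3]=0. `placeAt(4, 1, adj, 3, x)` returns a valid order with 3 at position x for each x from 2 to 4.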

Let’s return to solving the actual problem.
Let \text{ans}[i] denote the number of vertices that can appear at position i.
Then, each u increases \text{ans}[i] by 1 for every i such that \text{dep}[u] \leq i \leq N-\text{sub}[u].
In other words, for each u we just want to add 1 to every entry in the range [\text{dep}[u], N-\text{sub}[u]]!
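For reference, the range additions can be written out directly. This sketch assumes 1-indexed `dep` and `sub` arrays (index 0 unused) computed as described above, and `countNaive` is a hypothetical name. It is mainly useful for cross-checking a faster version on small inputs.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Direct version of the counting step: vertex u adds 1 to every position i
// with dep[u] <= i <= n - sub[u]. Runs in O(n^2) time in the worst case.
vector<int> countNaive(int n, const vector<int>& dep, const vector<int>& sub) {
    vector<int> ans(n + 1, 0);               // ans[0] is unused
    for (int u = 1; u <= n; ++u) {
        // dep[u]-1 ancestors + u + sub[u] descendants never exceed n vertices
        assert(dep[u] <= n - sub[u]);
        for (int i = dep[u]; i <= n - sub[u]; ++i) ++ans[i];
    }
    return ans;
}
```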

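Doing this in a brute-force fashion would be too slow. A single range can have length close to N (for example, every leaf of a star has the range [2, N]), so the total work is \mathcal{O}(N^2).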
There are several ways to do it faster though: one of them is to use prefix sums.

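For each u, add 1 to \text{ans[dep[}u]] and subtract 1 from \text{ans}[N-\text{sub}[u]+1]. When u is a leaf, N-\text{sub}[u]+1 equals N+1, so the array needs one extra slot (or that subtraction can simply be skipped).
Then, take the prefix sums of the \text{ans} array and we’ll have the values we want!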
This works because:

  • Consider how each addition/subtraction affects the prefix sum array we create.
  • When we add 1 to \text{ans[dep[}u]] and subtract 1 from \text{ans}[N-\text{sub}[u]+1],
    • The prefix sum of any position \lt \text{dep}[u] is unaffected.
    • The prefix sum of every position between \text{dep}[u] and N-\text{sub}[u] increases by 1.
    • The prefix sum of every position after N-\text{sub}[u] is unaffected (both the +1 and the -1 contribute to it, and they cancel out).
  • In other words, we’ve effectively added 1 to the exact range of the prefix sum array we wanted to!
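A minimal sketch of the difference-array step follows. It assumes 1-indexed `dep` and `sub` arrays (index 0 unused), and `countWithPrefixSums` is a hypothetical name. The extra slot at index N+1 absorbs the -1 of every leaf, so no special case is needed.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// O(n) counting with a difference array: +1 at dep[u], -1 at n - sub[u] + 1,
// then prefix sums. diff has n + 2 slots because n - sub[u] + 1 can be n + 1.
vector<int> countWithPrefixSums(int n, const vector<int>& dep, const vector<int>& sub) {
    assert((int)dep.size() > n && (int)sub.size() > n);
    vector<int> diff(n + 2, 0);
    for (int u = 1; u <= n; ++u) {
        ++diff[dep[u]];
        --diff[n - sub[u] + 1];
    }
    vector<int> ans(n + 1, 0);               // ans[0] is unused and stays 0
    for (int i = 1; i <= n; ++i) ans[i] = ans[i - 1] + diff[i];
    return ans;
}
```

On the star with root 1 and leaves 2 to 5, it returns 1, 4, 4, 4, 4 for positions 1 through 5. The root can only be first, while every leaf can take any of the positions 2 to 5.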

The overall time complexity is \mathcal{O}(N) since each step (root-finding, DFS, processing updates) is itself linear.

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;

// -------------------- Input Checker Start --------------------

// This function reads a long long, character by character, and returns it as a whole long long. It makes sure that it lies in the range [l, r], and the character after the long long is endd. l and r should be in [-1e18, 1e18].
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            if(!(fi == -1))
                cerr << "- in between integer\n";
            assert(fi == -1);
            is_neg = true; // It's a negative integer
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0'; // fi is the first digit
            cnt++;
            
            // There shouldn't be leading zeroes. eg. "02" is not valid and assert will fail here.
            if(!(fi != 0 || cnt == 1))
                cerr << "Leading zeroes found\n";
            assert(fi != 0 || cnt == 1); 
            
            // "-0" is invalid
            if(!(fi != 0 || is_neg == false))
                cerr << "-0 found\n";
            assert(fi != 0 || is_neg == false); 
            
            // The maximum number of digits should be 19, and if it is 19 digits long, then the first digit should be a '1'.
            if(!(!(cnt > 19 || (cnt == 19 && fi > 1))))
                cerr << "Value greater than 1e18 found\n";
            assert(!(cnt > 19 || (cnt == 19 && fi > 1))); 
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                // We've reached the end, but the long long isn't in the right range.
                cerr << "Constraint violated: Lower Bound = " << l << " Upper Bound = " << r << " Violating Value = " << x << '\n'; 
                assert(false); 
            }
            return x;
        }
        else if((g == ' ') && (endd == '\n'))
        {
            cerr << "Extra space found. It should instead have been a new line.\n";
            assert(false);
        }
        else if((g == '\n') && (endd == ' '))
        {
            cerr << "A new line found where it should have been a space.\n";
            assert(false);
        }
        else
        {
            cerr << "Something weird has happened.\n";
            assert(false);
        }
    }
}

string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    if(!(l <= cnt && cnt <= r))
        cerr << "String length not within constraints\n";
    assert(l <= cnt && cnt <= r);
    return ret;
}

long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() 
{ 
    char g = getchar();
    if(g != EOF)
    {
        if(g == ' ')
            cerr << "Extra space found where the file should have ended\n";
        else if(g == '\n')
            cerr << "Extra newline found where the file should have ended\n";
        else
            cerr << "File didn't end where expected\n";
    }
    assert(g == EOF); 
}

vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}

bool checkStringContents(string &s, char l, char r) {
    for(char x: s) {
        if (x < l || x > r) {
            cerr << "String is not valid\n";
            return false;
        }
    }
    return true;
}

bool isStringBinary(string &s) {
    return checkStringContents(s, '0', '1');
}

bool isStringLowerCase(string &s) {
    return checkStringContents(s, 'a', 'z');
}
bool isStringUpperCase(string &s) {
    return checkStringContents(s, 'A', 'Z');
}

bool isArrayDistinct(vector<int> a) {
    sort(a.begin(), a.end());
    for(int i = 1 ; i < a.size() ; ++i) {
        if (a[i] == a[i-1])
        return false;
    }
    return 1;
}

bool isPermutation(vector<int> &a) {
    int n = a.size();
    vector<int> done(n);
    for(int x: a) {
      if (x <= 0 || x > n || done[x-1]) {
        cerr << "Not a valid permutation\n";
        return false;
      }
      done[x-1]=1;
    }
    return true;
}

// -------------------- Input Checker End --------------------

struct DSU {
    std::vector<int> f, siz;
    
    DSU() {}
    DSU(int n) {
        init(n);
    }
    
    void init(int n) {
        f.resize(n);
        std::iota(f.begin(), f.end(), 0);
        siz.assign(n, 1);
    }
    
    int leader(int x) {
        while (x != f[x]) {
            x = f[x] = f[f[x]];
        }
        return x;
    }
    
    bool same(int x, int y) {
        return leader(x) == leader(y);
    }
    
    bool merge(int x, int y) {
        x = leader(x);
        y = leader(y);
        if (x == y) {
            return false;
        }
        siz[x] += siz[y];
        f[y] = x;
        return true;
    }
    
    int size(int x) {
        return siz[leader(x)];
    }
};




void numberOfNodes(int s, int e,vector<int>&count1,vector<vector<int>>&adj) 
{ 
    vector<int>::iterator u; 
    count1[s]=1;
    for (u = adj[s].begin(); u != adj[s].end(); u++) { 
           
        if (*u == e) 
            continue; 

        numberOfNodes(*u, s,count1,adj); 
        
        count1[s] += count1[*u]; 
    } 
} 

void solve()
{
    int n;
    n=readIntLn(1,1e5);
    vector<vector<int>>adj(n+1);
    vector<int>indegree(n+1,0);
    DSU d(n+1);
    for(int i=0;i<n-1;i++)
    {
        int u,v;
        u=readIntSp(1,n);
        v=readIntLn(1,n);
        assert(u!=v);
        assert(d.leader(u)!=d.leader(v));
        d.merge(u,v);
        adj[u].push_back(v);
        indegree[v]++;
    }
    int rt=d.leader(1);
    for(int i=1;i<=n;i++){
        assert(rt==d.leader(i));
    }
    int root;
    for(int i=1;i<=n;i++)
    {
        if(indegree[i]==0){
            root=i;
            break;
        }
    }
    vector<int>depth(n+1);
    depth[root]=1;
    queue<int>q;
    q.push(root);
    while(!q.empty())
    {
        int node=q.front();
        q.pop();
        for(auto child:adj[node])
        {
            depth[child]=depth[node]+1;
            q.push(child);
        }
    }
    vector<int>count1(n+1,0);
    numberOfNodes(root,0,count1,adj);
    vector<pair<int,int>>mxmn(n+1);
    for(int i=1;i<=n;i++)
    {
        mxmn[i]={depth[i],n-count1[i]+1};
    }
    vector<int>ans(n+1,0);
    for(int i=1;i<=n;i++)
    {
        ans[mxmn[i].first]++;
        if(mxmn[i].second!=n)ans[mxmn[i].second+1]--;
    }
    for(int i=1;i<=n;i++)
    {
        ans[i]+=ans[i-1];
    }
    for(int i=1;i<=n;i++)
    {
        cout<<ans[i]<<" ";
    }
    cout<<endl;
}

int main()
{
    int t;
    t=readIntLn(1,3000);
    while(t--)
    {
        solve();
    }
    readEOF();
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            assert(!isspace(buffer[pos]));
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 3000);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1e5);
        in.readEoln();
        sn += n;
        vector<vector<int>> g(n);
        vector<int> deg(n);
        for (int i = 0; i < n - 1; i++) {
            int x = in.readInt(1, n);
            in.readSpace();
            int y = in.readInt(1, n);
            in.readEoln();
            x--;
            y--;
            assert(x != y);
            g[x].emplace_back(y);
            deg[y]++;
        }
        vector<int> sz(n, 1);
        vector<int> dep(n);
        function<void(int)> Dfs = [&](int v) {
            for (int to : g[v]) {
                dep[to] = dep[v] + 1;
                Dfs(to);
                sz[v] += sz[to];
            }
        };
        int r = -1;
        for (int i = 0; i < n; i++) {
            if (deg[i] == 0) {
                r = i;
            }
        }
        assert(r != -1);
        Dfs(r);
        assert(sz[r] == n);
        vector<int> ans(n + 1);
        for (int i = 0; i < n; i++) {
            ans[dep[i]] += 1;
            ans[n - sz[i] + 1] -= 1;
        }
        for (int i = 0; i < n; i++) {
            ans[i + 1] += ans[i];
            cout << ans[i] << " ";
        }
        cout << '\n';
    }
    assert(sn <= 1e5);
    in.readEof();
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector adj(n, vector<int>());
        vector<int> mark(n, 1);
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            mark[v] = 0;
        }
        int root = 0;
        while (!mark[root]) ++root;

        vector<int> dep(n, 1), subsz(n);
        auto dfs = [&] (const auto &self, int u) -> void {
            for (int v : adj[u]) {
                dep[v] = 1 + dep[u];
                self(self, v);
                subsz[u] += subsz[v] + 1;
            }
        };
        dfs(dfs, root);
        vector<int> ans(n+1);
        for (int i = 0; i < n; ++i) {
            ++ans[dep[i]-1];
            --ans[n-subsz[i]];
        }
        for (int i = 1; i < n; ++i) ans[i] += ans[i-1];
        for (int i = 0; i < n; ++i) cout << ans[i] << ' ';
        cout << '\n';
    }
}

I got scared by the fact that counting the number of possible topological sorts is #P-complete ;(

Getting a runtime error on test #3. Can anyone help?

from collections import defaultdict
for _ in range(int(input())):
    n=int(input())
    root={x for x in range(1,n+1)}
    g=defaultdict(list)
    for _ in range(n-1):
        u,v=map(int,input().split())
        g[u].append(v)
        root.remove(v)
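    root=root.pop()
    levels={root:0}
    below={}
    # Iterative DFS: a recursive one exceeds Python's default recursion
    # limit (1000) when the tree is a long path (depth up to 1e5).
    order=[]
    stack=[root]
    while stack:
        node=stack.pop()
        order.append(node)
        for ch in g[node]:
            levels[ch]=levels[node]+1
            stack.append(ch)
    # Children appear after their parent in `order`, so walking it backwards
    # finalizes every child's count before its parent needs it.
    for node in reversed(order):
        below[node]=sum(below[ch]+1 for ch in g[node])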
    arr=[0]* (n+1)
    for node in range(1,n+1):
        arr[levels[node]]+=1
        arr[n-below[node]]-=1
    final=[]
    sums=0
    for ele in arr:
        sums+=ele
        final.append(sums)
    final.pop()
    print(*final)