# FIZZBUZZ2307 - Editorial

Authors: naisheel, jalp1428
Tester: tabr
Editorialist: iceknight1093

# DIFFICULTY:

2275

# PREREQUISITES:

DFS, prefix sums

# PROBLEM:

You are given a directed tree on $N$ vertices. For each $i = 1, 2, 3, \ldots, N$, find the number of distinct vertices that can occur at the $i$-th position of at least one of its topological orders.
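
For very small trees, the definition can be checked directly. The following is a minimal brute-force sketch, not part of the original solutions: the function name `bruteForce` and the edge-list input format are assumptions for illustration. It tries all $N!$ permutations, keeps the topological ones, and records which vertex appeared at which position.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: brute force straight from the definition, usable only for tiny n.
// Vertices are 1..n; an edge (u, v) means u must come before v.
// Returns cnt, where cnt[i - 1] = number of distinct vertices seen at position i.
std::vector<int> bruteForce(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<int> perm(n), pos(n + 1);
    std::iota(perm.begin(), perm.end(), 1);  // smallest permutation 1, 2, ..., n
    // seen[i][v] == 1 if vertex v occurred at (0-indexed) position i in some topological order
    std::vector<std::vector<char>> seen(n, std::vector<char>(n + 1, 0));
    do {
        for (int i = 0; i < n; ++i) pos[perm[i]] = i;
        bool ok = true;
        for (const auto& [u, v] : edges)
            if (pos[u] > pos[v]) { ok = false; break; }
        if (ok)
            for (int i = 0; i < n; ++i) seen[i][perm[i]] = 1;
    } while (std::next_permutation(perm.begin(), perm.end()));
    std::vector<int> cnt(n);
    for (int i = 0; i < n; ++i)
        cnt[i] = static_cast<int>(std::count(seen[i].begin(), seen[i].end(), 1));
    return cnt;
}
```

For the tree with edges $1 \to 2$ and $1 \to 3$, the only topological orders are $(1, 2, 3)$ and $(1, 3, 2)$, so `bruteForce(3, {{1, 2}, {1, 3}})` returns `{1, 2, 2}`.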

# EXPLANATION:

First, find the root of the tree: it’s the only vertex with no incoming edges.
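
A minimal sketch of this step, assuming the edges are given as $(u, v)$ pairs with $u$ the parent and vertices numbered $1$ to $N$ (`findRoot` is a hypothetical name): count incoming edges and return the vertex whose count is zero.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: returns the root of a directed tree on vertices 1..n.
// Each edge (u, v) points from parent u to child v, so every non-root vertex has indegree 1.
int findRoot(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<int> indegree(n + 1, 0);
    for (const auto& [u, v] : edges) ++indegree[v];
    for (int i = 1; i <= n; ++i)
        if (indegree[i] == 0) return i;  // the unique vertex with no incoming edge
    return -1;  // only reachable if the input is not a valid directed tree
}
```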

Let’s look at the problem from the opposite direction.
For each vertex $u$, which positions of a topological order can it appear in?

Observe that wherever $u$ appears in a topological order:

• All of its ancestors must appear before it.
• All of its descendants must appear after it.
• Other vertices have no relation to it at all.

In particular, let $\text{dep}[u]$ denote the depth of $u$ (with the root having a depth of $1$), and let $\text{sub}[u]$ denote the number of vertices in the subtree of $u$, not including $u$ itself. Then:

• $u$ can’t appear at any of the positions $1, 2, 3, \ldots, \text{dep}[u]-1$, because its $\text{dep}[u]-1$ ancestors must all come before it.
• $u$ can’t appear at any of the positions $N, N-1, N-2, \ldots, N-\text{sub}[u]+1$, because its $\text{sub}[u]$ descendants must all come after it.

Both the $\text{dep}$ and $\text{sub}$ arrays can be computed with a single DFS from the root.
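
One way to compute both arrays is sketched below, assuming the tree is stored as child lists `children[1..N]` (the name `depthAndSubtree` and this layout are assumptions). As a design choice, it uses an explicit stack instead of recursion, since a path-shaped tree can be very deep: a first pass records a preorder (every parent before its children) and sets depths, and a second pass walks that order backwards so every child's $\text{sub}$ is final before its parent reads it.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: dep[u] (root has depth 1) and sub[u] (# vertices in u's subtree,
// excluding u) for a directed tree given as child lists children[1..n]; index 0 is unused.
std::pair<std::vector<int>, std::vector<int>> depthAndSubtree(
        int root, const std::vector<std::vector<int>>& children) {
    int n = (int)children.size() - 1;
    std::vector<int> dep(n + 1, 0), sub(n + 1, 0), order;
    order.reserve(n);
    std::vector<int> todo = {root};
    dep[root] = 1;
    while (!todo.empty()) {
        int u = todo.back();
        todo.pop_back();
        order.push_back(u);  // preorder: u is recorded before any of its children
        for (int v : children[u]) {
            dep[v] = dep[u] + 1;
            todo.push_back(v);
        }
    }
    // Reverse preorder: every child is processed before its parent.
    for (int i = (int)order.size() - 1; i >= 0; --i) {
        int u = order[i];
        for (int v : children[u]) sub[u] += sub[v] + 1;
    }
    return {dep, sub};
}
```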

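So $u$ can only appear at positions between $\text{dep}[u]$ and $N-\text{sub}[u]$, inclusive.
In fact, it can appear at every one of these positions!
To place $u$ at some position $x$ (where $\text{dep}[u] \leq x \leq N-\text{sub}[u]$), we can:

• Place the ancestors of $u$ at positions $1, 2, 3, \ldots, \text{dep}[u]-1$, in order from the root downwards.
• Place the descendants of $u$ at positions $N, N-1, N-2, \ldots, N-\text{sub}[u]+1$, in some valid order.
  This is always possible because, as long as $u$ comes before them, the descendants of $u$ only care about their order among themselves, so they can occupy any $\text{sub}[u]$ positions after $u$.
• Place $u$ at position $x$, which is still empty.
  This is also possible because, as noted above, $u$ has no relation with anything other than its ancestors and descendants, so once those are placed, $u$ can go anywhere between them.
• Fill the remaining $N - \text{dep}[u] - \text{sub}[u]$ positions with the other vertices, keeping their relative order valid (for example, the order in which a BFS from the root visits them).
  Every ancestor of such a vertex is either an ancestor of $u$ (already at the front) or another such vertex, and none of its descendants is $u$ or a descendant of $u$, so no edge is violated.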
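
The construction can be turned into code to check it on small trees. This is a minimal sketch, assuming child lists `children[1..N]` and a parent array with `parent[root] == 0`; both helper names (`isTopologicalOrder`, `orderWithUAt`) are hypothetical. It uses a BFS order of the whole tree as the valid order from which the descendants of $u$ and the "other" vertices are filtered, since any subsequence of a topological order is still ordered correctly.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <vector>

// Hypothetical helper: true if `order` is a permutation of 1..n that respects
// every edge parent -> child of the tree given by `children` (index 0 unused).
bool isTopologicalOrder(const std::vector<int>& order,
                        const std::vector<std::vector<int>>& children) {
    int n = (int)children.size() - 1;
    if ((int)order.size() != n) return false;
    std::vector<int> pos(n + 1, 0);
    for (int i = 0; i < n; ++i) {
        if (order[i] < 1 || order[i] > n || pos[order[i]] != 0) return false;
        pos[order[i]] = i + 1;
    }
    for (int p = 1; p <= n; ++p)
        for (int c : children[p])
            if (pos[p] > pos[c]) return false;
    return true;
}

// Hypothetical helper following the construction: a topological order with u at
// 1-indexed position x. Requires dep[u] <= x <= n - sub[u].
std::vector<int> orderWithUAt(int root, int u, int x,
                              const std::vector<std::vector<int>>& children,
                              const std::vector<int>& parent) {
    int n = (int)children.size() - 1;
    std::vector<char> isAnc(n + 1, 0), inSub(n + 1, 0);
    std::vector<int> anc;  // ancestors of u, collected from the parent upwards
    for (int a = parent[u]; a != 0; a = parent[a]) { anc.push_back(a); isAnc[a] = 1; }
    std::reverse(anc.begin(), anc.end());  // now listed from the root downwards
    std::vector<int> todo = {u};           // mark u and its whole subtree
    while (!todo.empty()) {
        int v = todo.back(); todo.pop_back();
        inSub[v] = 1;
        for (int c : children[v]) todo.push_back(c);
    }
    // Split a BFS order (a valid topological order) into descendants of u and other vertices.
    std::vector<int> desc, others;
    std::queue<int> q;
    q.push(root);
    while (!q.empty()) {
        int v = q.front(); q.pop();
        if (inSub[v] && v != u) desc.push_back(v);
        else if (!inSub[v] && !isAnc[v]) others.push_back(v);
        for (int c : children[v]) q.push(c);
    }
    int dep = (int)anc.size() + 1, sub = (int)desc.size();
    assert(dep <= x && x <= n - sub);
    std::vector<int> order(anc);                                            // 1 .. dep-1
    order.insert(order.end(), others.begin(), others.begin() + (x - dep));  // dep .. x-1
    order.push_back(u);                                                     // x
    order.insert(order.end(), others.begin() + (x - dep), others.end());    // x+1 .. n-sub
    order.insert(order.end(), desc.begin(), desc.end());                    // n-sub+1 .. n
    return order;
}
```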

Let $\text{ans}[i]$ denote the number of vertices that can appear at position $i$.
Then, each vertex $u$ increases $\text{ans}[i]$ by $1$ for every $i$ such that $\text{dep}[u] \leq i \leq N-\text{sub}[u]$.
In other words, each vertex just adds $1$ to a contiguous range of $\text{ans}$!
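
Written out directly, the range updates look like the following minimal sketch (the name `countNaive` is hypothetical; `dep` and `sub` are assumed to be 1-indexed arrays as defined above, with index 0 unused). The inner loop is what makes it quadratic: in a star with $N-1$ leaves, every leaf has the range $[2, N]$, so about $N^2$ increments are performed.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: the direct O(n^2) version of the range updates.
// dep and sub are 1-indexed (index 0 unused); ans[i] for 1 <= i <= n is the answer for position i.
std::vector<int> countNaive(int n, const std::vector<int>& dep, const std::vector<int>& sub) {
    std::vector<int> ans(n + 1, 0);
    for (int u = 1; u <= n; ++u)
        for (int i = dep[u]; i <= n - sub[u]; ++i)  // every position u can occupy
            ++ans[i];
    return ans;
}
```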

Doing this by brute force, looping over every position of every range, would be too slow: in the worst case it takes $\mathcal{O}(N^2)$ time.
There are several ways to do it faster, though; one of them is to use prefix sums.

For each $u$, add $1$ to $\text{ans}[\text{dep}[u]]$ and subtract $1$ from $\text{ans}[N-\text{sub}[u]+1]$.
Then, take the prefix sums of the $\text{ans}$ array, and we’ll have exactly the values we want!
(When $\text{sub}[u] = 0$, the subtraction lands at index $N+1$; either allocate one extra slot or simply skip it.)
This works because:

• Consider how a single pair of updates affects the prefix sum array we create.
• When we add $1$ to $\text{ans}[\text{dep}[u]]$ and subtract $1$ from $\text{ans}[N-\text{sub}[u]+1]$:
  • The prefix sum of any position $\lt \text{dep}[u]$ is unaffected, since neither update contributes to it.
  • The prefix sum of every position between $\text{dep}[u]$ and $N-\text{sub}[u]$ increases by $1$, since only the $+1$ contributes to it.
  • The prefix sum of every position after $N-\text{sub}[u]$ is unaffected, since both the $+1$ and the $-1$ contribute to it and cancel out.
• In other words, we’ve added $1$ to exactly the range of the prefix sum array we wanted!
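
The difference-array trick can be sketched as follows (a minimal illustration; the name `countWithPrefixSums` is hypothetical, and `dep` and `sub` are assumed to be 1-indexed arrays as defined above). This sketch allocates the extra slot at index $N+1$ rather than skipping the update.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: O(n) range updates via a difference array plus prefix sums.
// dep and sub are 1-indexed (index 0 unused); returns ans with ans[i] for 1 <= i <= n.
std::vector<int> countWithPrefixSums(int n, const std::vector<int>& dep, const std::vector<int>& sub) {
    std::vector<int> ans(n + 2, 0);  // index n + 1 absorbs the -1 of ranges ending at n
    for (int u = 1; u <= n; ++u) {
        ++ans[dep[u]];            // range starts at dep[u]
        --ans[n - sub[u] + 1];    // and ends at n - sub[u]
    }
    for (int i = 1; i <= n + 1; ++i) ans[i] += ans[i - 1];  // prefix sums
    ans.resize(n + 1);  // keep indices 0..n; ans[1..n] are the answers
    return ans;
}
```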

The overall time complexity is $\mathcal{O}(N)$, since each step (finding the root, the DFS, and processing the updates) takes linear time.
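
Putting the steps together, a self-contained $\mathcal{O}(N)$ sketch might look like this. It is an illustration under stated assumptions, not the setters' code: the name `solveTree` is hypothetical, the tree is passed as a list of parent-to-child edges on vertices $1$ to $N$, and the DFS is done iteratively.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical end-to-end helper. Vertices 1..n; each edge (u, v) points from parent u to child v.
// Returns res with res[i - 1] = number of vertices that can occupy position i.
std::vector<int> solveTree(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> children(n + 1);
    std::vector<int> indeg(n + 1, 0);
    for (const auto& [u, v] : edges) { children[u].push_back(v); ++indeg[v]; }
    int root = 1;
    while (indeg[root] != 0) ++root;  // the unique vertex with no incoming edge
    // Iterative DFS: the preorder lists every parent before its children.
    std::vector<int> dep(n + 1, 0), sub(n + 1, 0), order, todo = {root};
    dep[root] = 1;
    while (!todo.empty()) {
        int u = todo.back(); todo.pop_back();
        order.push_back(u);
        for (int v : children[u]) { dep[v] = dep[u] + 1; todo.push_back(v); }
    }
    for (int i = n - 1; i >= 0; --i)  // reverse preorder: children before parents
        for (int v : children[order[i]]) sub[order[i]] += sub[v] + 1;
    // Each vertex covers positions dep[u] .. n - sub[u]; difference array, then prefix sums.
    std::vector<int> diff(n + 2, 0);
    for (int u = 1; u <= n; ++u) { ++diff[dep[u]]; --diff[n - sub[u] + 1]; }
    std::vector<int> res(n);
    int run = 0;
    for (int i = 1; i <= n; ++i) { run += diff[i]; res[i - 1] = run; }
    return res;
}
```

For example, with edges $1 \to 2$, $1 \to 3$, $3 \to 4$, $3 \to 5$ the ranges are $[1,1]$, $[2,5]$, $[2,3]$, $[3,5]$, $[3,5]$, giving the answers $1, 2, 4, 3, 3$.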

# TIME COMPLEXITY

$\mathcal{O}(N)$ per testcase.

# CODE:

Author's code (C++)
```cpp
#include<bits/stdc++.h>
using namespace std;

// -------------------- Input Checker Start --------------------

// This function reads a long long, character by character, and returns it as a whole long long. It makes sure that it lies in the range [l, r], and the character after the long long is endd. l and r should be in [-1e18, 1e18].
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            if(!(fi == -1))
                cerr << "- in between integer\n";
            assert(fi == -1);
            is_neg = true; // It's a negative integer
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0'; // fi is the first digit
            cnt++;

            // There shouldn't be leading zeroes. eg. "02" is not valid and assert will fail here.
            if(!(fi != 0 || cnt == 1))
                assert(fi != 0 || cnt == 1);

            // "-0" is invalid
            if(!(fi != 0 || is_neg == false))
                cerr << "-0 found\n";
            assert(fi != 0 || is_neg == false);

            // The maximum number of digits should be 19, and if it is 19 digits long, then the first digit should be a '1'.
            if(!(!(cnt > 19 || (cnt == 19 && fi > 1))))
                cerr << "Value greater than 1e18 found\n";
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                // We've reached the end, but the long long isn't in the right range.
                cerr << "Constraint violated: Lower Bound = " << l << " Upper Bound = " << r << " Violating Value = " << x << '\n';
                assert(false);
            }
            return x;
        }
        else if((g == ' ') && (endd == '\n'))
        {
            cerr << "Extra space found. It should instead have been a new line.\n";
            assert(false);
        }
        else if((g == '\n') && (endd == ' '))
        {
            cerr << "A new line found where it should have been a space.\n";
            assert(false);
        }
        else
        {
            cerr << "Something weird has happened.\n";
            assert(false);
        }
    }
}

string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();

        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    if(!(l <= cnt && cnt <= r))
        cerr << "String length not within constraints\n";
    assert(l <= cnt && cnt <= r);
    return ret;
}

long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }

// Asserts that the input has been fully consumed.
void readEOF()
{
    char g = getchar();
    if(g != EOF)
    {
        if(g == ' ')
            cerr << "Extra space found where the file should have ended\n";
        else if(g == '\n')
            cerr << "Extra newline found where the file should have ended\n";
        else
            cerr << "File didn't end where expected\n";
    }
    assert(g == EOF);
}

vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}

bool checkStringContents(string &s, char l, char r) {
    for(char x: s) {
        if (x < l || x > r) {
            cerr << "String is not valid\n";
            return false;
        }
    }
    return true;
}

bool isStringBinary(string &s) {
    return checkStringContents(s, '0', '1');
}

bool isStringLowerCase(string &s) {
    return checkStringContents(s, 'a', 'z');
}
bool isStringUpperCase(string &s) {
    return checkStringContents(s, 'A', 'Z');
}

bool isArrayDistinct(vector<int> a) {
    sort(a.begin(), a.end());
    for(int i = 1 ; i < a.size() ; ++i) {
        if (a[i] == a[i-1])
            return false;
    }
    return 1;
}

bool isPermutation(vector<int> &a) {
    int n = a.size();
    vector<int> done(n);
    for(int x: a) {
        if (x <= 0 || x > n || done[x-1]) {
            cerr << "Not a valid permutation\n";
            return false;
        }
        done[x-1]=1;
    }
    return true;
}

// -------------------- Input Checker End --------------------

struct DSU {
    std::vector<int> f, siz;

    DSU() {}
    DSU(int n) {
        init(n);
    }

    void init(int n) {
        f.resize(n);
        std::iota(f.begin(), f.end(), 0);
        siz.assign(n, 1);
    }

    int find(int x) {
        while (x != f[x]) {
            x = f[x] = f[f[x]];
        }
        return x;
    }

    bool same(int x, int y) {
        return find(x) == find(y);
    }

    bool merge(int x, int y) {
        x = find(x);
        y = find(y);
        if (x == y) {
            return false;
        }
        siz[x] += siz[y];
        f[y] = x;
        return true;
    }

    int size(int x) {
        return siz[find(x)];
    }
};

// count1[s] = number of vertices in the subtree of s, including s itself (e is the parent of s).
void dfs(int s, int e, vector<vector<int>> &adj, vector<int> &count1)
{
    vector<int>::iterator u;
    count1[s]=1;
    for (u = adj[s].begin(); u != adj[s].end(); u++)
    {
        if (*u == e)
            continue;
        dfs(*u, s, adj, count1);
        count1[s] += count1[*u];
    }
}

void solve()
{
    int n;
    n = readIntLn(1, 100000); // N <= 1e5 follows from the sum-of-N limit
    vector<vector<int>> adj(n+1); // adj[u] = children of u
    vector<int>indegree(n+1,0);
    DSU d(n+1);
    for(int i=0;i<n-1;i++)
    {
        int u,v;
        u = readIntSp(1, n);
        v = readIntLn(1, n);
        assert(u!=v);
        d.merge(u,v);
        adj[u].push_back(v);
        indegree[v]++;
    }
    // The n-1 edges must connect all n vertices.
    for(int i=1;i<=n;i++){
        assert(d.same(1, i));
    }
    // The root is the only vertex with no incoming edge.
    int root;
    for(int i=1;i<=n;i++)
    {
        if(indegree[i]==0){
            root=i;
            break;
        }
    }
    // BFS from the root: depth[root] = 1.
    vector<int>depth(n+1);
    depth[root]=1;
    queue<int>q;
    q.push(root);
    while(!q.empty())
    {
        int node=q.front();
        q.pop();
        for(int child: adj[node])
        {
            depth[child]=depth[node]+1;
            q.push(child);
        }
    }
    vector<int>count1(n+1,0);
    dfs(root, 0, adj, count1);
    // Vertex i can occupy positions depth[i] .. n - count1[i] + 1.
    vector<pair<int,int>>mxmn(n+1);
    for(int i=1;i<=n;i++)
    {
        mxmn[i]={depth[i],n-count1[i]+1};
    }
    // Difference array followed by prefix sums.
    vector<int>ans(n+1,0);
    for(int i=1;i<=n;i++)
    {
        ans[mxmn[i].first]++;
        if(mxmn[i].second!=n)ans[mxmn[i].second+1]--;
    }
    for(int i=1;i<=n;i++)
    {
        ans[i]+=ans[i-1];
    }
    for(int i=1;i<=n;i++)
    {
        cout<<ans[i]<<" ";
    }
    cout<<endl;
}

int main()
{
    int t;
    t = readIntLn(1, 100000); // T <= 1e5 follows from the sum-of-N limit
    while(t--)
    {
        solve();
    }
    readEOF();
}
```

Tester's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Reads one token (up to the next space or newline).
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            assert(!isspace(buffer[pos]));
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    // Upper bounds below are implied by the sum-of-N limit asserted at the end.
    int tt = in.readInt(1, 100000);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 100000);
        in.readEoln();
        sn += n;
        vector<vector<int>> g(n);  // g[x] = children of x (0-indexed)
        vector<int> deg(n);        // deg[y] = number of incoming edges of y
        for (int i = 0; i < n - 1; i++) {
            int x = in.readInt(1, n);
            in.readSpace();
            int y = in.readInt(1, n);
            in.readEoln();
            x--;
            y--;
            assert(x != y);
            g[x].emplace_back(y);
            deg[y]++;
        }
        // sz[v] = subtree size including v; dep[v] = depth with the root at 0.
        vector<int> sz(n, 1);
        vector<int> dep(n);
        function<void(int)> Dfs = [&](int v) {
            for (int to : g[v]) {
                dep[to] = dep[v] + 1;
                Dfs(to);
                sz[v] += sz[to];
            }
        };
        int r = -1;
        for (int i = 0; i < n; i++) {
            if (deg[i] == 0) {
                r = i;
            }
        }
        assert(r != -1);
        Dfs(r);
        assert(sz[r] == n);
        // Vertex i covers 0-indexed positions dep[i] .. n - sz[i].
        vector<int> ans(n + 1);
        for (int i = 0; i < n; i++) {
            ans[dep[i]] += 1;
            ans[n - sz[i] + 1] -= 1;
        }
        for (int i = 0; i < n; i++) {
            ans[i + 1] += ans[i];
            cout << ans[i] << " ";
        }
        cout << '\n';
    }
    assert(sn <= 1e5);
    in.readEof();
    return 0;
}
```

Editorialist's code (C++)
```cpp
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<vector<int>> adj(n);  // adj[u] = children of u (0-indexed)
        vector<int> mark(n, 1);      // mark[v] = 1 while v has no incoming edge
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            --u, --v;
            adj[u].push_back(v);
            mark[v] = 0;
        }
        int root = 0;
        while (!mark[root]) ++root;

        // dep[u]: depth with the root at 1; subsz[u]: subtree size excluding u.
        vector<int> dep(n, 1), subsz(n);
        auto dfs = [&] (const auto &self, int u) -> void {
            for (int v : adj[u]) {
                dep[v] = 1 + dep[u];
                self(self, v);
                subsz[u] += subsz[v] + 1;
            }
        };
        dfs(dfs, root);
        // 0-indexed positions dep[i]-1 .. n-subsz[i]-1; difference array, then prefix sums.
        vector<int> ans(n+1);
        for (int i = 0; i < n; ++i) {
            ++ans[dep[i]-1];
            --ans[n-subsz[i]];
        }
        for (int i = 1; i < n; ++i) ans[i] += ans[i-1];
        for (int i = 0; i < n; ++i) cout << ans[i] << ' ';
        cout << '\n';
    }
}
```


I got scared by the fact that finding the number of possible topsorts is #P-complete ;(

Runtime error in test #3? Can anyone help? SHeeeer

```python
from collections import defaultdict
for _ in range(int(input())):
    n=int(input())
    root={x for x in range(1,n+1)}
    g=defaultdict(list)
    for _ in range(n-1):
        u,v=map(int,input().split())
        g[u].append(v)
        root.remove(v)
    # print(root.pop())
    root=root.pop()
    # print(root)
    levels={}
    below={}
    def dfs(node,c):
        levels[node]=c
        belows=0
        for ch in g[node]:
            belows+=dfs(ch,c+1)
            belows+=1
        below[node]=belows
        return belows
    dfs(root,0)
    arr=[0]* (n+1)
    for node in range(1,n+1):
        arr[levels[node]]+=1
        arr[n-below[node]]-=1
    final=[]
    sums=0
    for ele in arr:
        sums+=ele
        final.append(sums)
    final.pop()
    print(*final)
```