PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Authors: naisheel, jalp1428
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
2275
PREREQUISITES:
DFS, prefix sums
PROBLEM:
You are given a directed tree on N vertices, where every edge points from a parent to its child. For each i = 1, 2, 3, \ldots, N, find the number of distinct vertices that can occur at the i-th position of at least one of its topological orders.
EXPLANATION:
First, find the root of the tree: it’s the only vertex with no incoming edges.
Let’s look at the problem from the opposite direction.
For each vertex u, which positions of a topological order can it appear in?
Observe that wherever u appears in topological order:
- All of its ancestors must appear before it.
- All of its descendants must appear after it.
- Other vertices don’t have any relation to it at all.
In particular, if \text{dep}[u] denotes the depth of u (with the root having a depth of 1) and \text{sub}[u] denotes the number of vertices in the subtree of u (not including u itself),
- u definitely can’t appear at positions 1, 2, 3, \ldots, \text{dep}[u]-1 because of the ancestor condition.
- u definitely can’t appear at positions N, N-1, N-2, \ldots, N-\text{sub}[u]+1 because of the descendant condition.
Finding the \text{dep} and \text{sub} arrays is easily done with a DFS.
So, u can only appear at positions between \text{dep}[u] and N-\text{sub}[u].
It’s not hard to see that it can appear at all these positions!
To place u at some position x (where \text{dep}[u] \leq x \leq N-\text{sub}[u]) we can:
- Place the ancestors of u at positions 1, 2, 3, \ldots, \text{dep}[u]-1 in order from the root downwards.
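- Place the descendants of u at positions N, N-1, N-2, \ldots, N-\text{sub}[u]+1 in some valid order.
This is always possible because, as long as u appears before all of them, its descendants only care about their order among themselves, so they can occupy any \text{sub}[u] positions.
- Place u at position x, which is still empty.
- Fill the remaining empty positions with the vertices that are neither ancestors nor descendants of u, in any order that is valid among themselves.
This is also possible because, as noted above, u has no relations with anything other than its ancestors/descendants, so once they're placed u can be freely placed anywhere between them. Likewise, none of the remaining vertices is an ancestor of a descendant of u, and any of their ancestors that are also ancestors of u already occupy the first positions, so they never conflict with u or with u's descendants.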
Let’s return to solving the actual problem.
Let \text{ans}[i] denote the number of vertices that can appear at position i.
Then, each u increases \text{ans}[i] by 1 for every i such that \text{dep}[u] \leq i \leq N-\text{sub}[u].
In other words, we just want to add 1 on some range!
Doing this in a brute-force fashion would be too slow, with a complexity of \mathcal{O}(N^2).
There are several ways to do it faster though: one of them is to use prefix sums.
For each u, add 1 to \text{ans}[\text{dep}[u]] and subtract 1 from \text{ans}[N-\text{sub}[u]+1]. The array therefore needs indices up to N+1, since \text{sub}[u] can be 0.
Then, take the prefix sums of the \text{ans} array and we’ll have the values we want!
This works because:
- Consider how each addition/subtraction affects the prefix sum array we create.
- When we add 1 to \text{ans}[\text{dep}[u]] and subtract 1 from \text{ans}[N-\text{sub}[u]+1],
- The prefix sum of any position \lt \text{dep}[u] is unaffected.
- The prefix sum of every position between \text{dep}[u] and N-\text{sub}[u] increases by 1.
- The prefix sum of every position after N-\text{sub}[u] is unaffected, since both the +1 and the -1 contribute to it and cancel out.
- In other words, we’ve effectively added 1 to the exact range of the prefix sum array we wanted to!
The overall time complexity is \mathcal{O}(N) since each step (root-finding, DFS, processing updates) is itself linear.
TIME COMPLEXITY
\mathcal{O}(N) per testcase.
CODE:
Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
// -------------------- Input Checker Start --------------------
// This function reads a long long, character by character, and returns it as a whole long long. It makes sure that it lies in the range [l, r], and the character after the long long is endd. l and r should be in [-1e18, 1e18].
// Reads one integer from stdin, character by character, and returns it as a long long.
// The token must be terminated by exactly the character `endd`. On any violation a
// diagnostic is printed to stderr and an assert fails. Checks performed:
//   - at most one '-', and only before the first digit;
//   - at least one digit (an empty token or a lone "-" is rejected);
//   - no leading zeroes, and "-0" is rejected;
//   - at most 19 digits, and if exactly 19 the first digit must be '1';
//   - the value lies in [l, r].
// l and r should be in [-1e18, 1e18].
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1; // cnt: digits read so far; fi: first digit, -1 until one is read
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            // A '-' after digits, or a second '-' (e.g. "--5"), is malformed.
            if(!(fi == -1 && !is_neg))
                cerr << "- in between integer\n";
            assert(fi == -1 && !is_neg);
            is_neg = true; // It's a negative integer
            continue;
        }
        if('0' <= g && g <= '9')
        {
            int digit = g - '0';
            if(cnt == 0)
                fi = digit; // fi is the first digit
            cnt++;
            // There shouldn't be leading zeroes. eg. "02" is not valid and assert will fail here.
            if(!(fi != 0 || cnt == 1))
                cerr << "Leading zeroes found\n";
            assert(fi != 0 || cnt == 1);
            // "-0" is invalid
            if(!(fi != 0 || is_neg == false))
                cerr << "-0 found\n";
            assert(fi != 0 || is_neg == false);
            // The maximum number of digits should be 19, and if it is 19 digits long, then the first digit should be a '1'.
            if(cnt > 19 || (cnt == 19 && fi > 1))
                cerr << "Value greater than 1e18 found\n";
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            // Accumulate only after the length check. Otherwise a 20-digit token
            // would overflow x (undefined behaviour) before the assert could fire.
            x = x * 10 + digit;
        }
        else if(g == endd)
        {
            // Without this check an empty token (or a lone "-") would be read as 0.
            if(cnt == 0)
                cerr << "No digits found\n";
            assert(cnt > 0);
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                // We've reached the end, but the long long isn't in the right range.
                cerr << "Constraint violated: Lower Bound = " << l << " Upper Bound = " << r << " Violating Value = " << x << '\n';
                assert(false);
            }
            return x;
        }
        else if((g == ' ') && (endd == '\n'))
        {
            cerr << "Extra space found. It should instead have been a new line.\n";
            assert(false);
        }
        else if((g == '\n') && (endd == ' '))
        {
            cerr << "A new line found where it should have been a space.\n";
            assert(false);
        }
        else
        {
            cerr << "Something weird has happened.\n";
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the terminator `endd` and returns them.
// Asserts that the input does not end before `endd` is seen, and that the string
// length lies in [l, r].
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        // Keep getchar()'s int result. Narrowing it to char first makes EOF
        // indistinguishable from byte 0xFF. Where char is unsigned, the EOF check
        // would never fire and the loop would spin forever at end of input.
        int g = getchar();
        assert(g != EOF);
        if(static_cast<char>(g) == endd)
            break;
        cnt++;
        ret += static_cast<char>(g);
    }
    if(!(l <= cnt && cnt <= r))
        cerr << "String length not within constraints\n";
    assert(l <= cnt && cnt <= r);
    return ret;
}
// Thin wrappers that fix the terminator character:
// *Sp expects the token to be followed by a single space, *Ln by a newline.
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}
// Asserts that stdin has been fully consumed. If not, prints one hint about the
// first unexpected character to stderr before the assert fails.
void readEOF()
{
    // int, not char: a char cannot distinguish EOF from byte 0xFF. Where char is
    // unsigned, a genuine EOF would never compare equal to EOF and a valid file
    // would be rejected.
    int g = getchar();
    if(g != EOF)
    {
        // Use an else-if chain so exactly one diagnostic is printed. Previously a
        // trailing space also triggered "File didn't end where expected".
        if(g == ' ')
            cerr << "Extra space found where the file shold have ended\n";
        else if(g == '\n')
            cerr << "Extra newline found where the file shold have ended\n";
        else
            cerr << "File didn't end where expected\n";
    }
    assert(g == EOF);
}
// Reads n integers in [l, r] from one line. The first n-1 must each be followed by
// a space and the last by a newline. For n == 0 nothing is read and an empty vector
// is returned; previously this case indexed a[-1].
// NOTE(review): values are stored as int, so l and r are presumably meant to fit in
// int. This is not checked here.
vector<int> readVectorInt(int n, long long l, long long r)
{
    assert(n >= 0);
    vector<int> a(n);
    if(n == 0)
        return a;
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
// Returns true iff every character of s lies in the inclusive range [l, r].
// Otherwise it reports "String is not valid" on stderr and returns false.
bool checkStringContents(string &s, char l, char r) {
    const bool allInRange = all_of(s.begin(), s.end(), [l, r](char ch) {
        return !(ch < l) && !(ch > r);
    });
    if (!allInRange)
        cerr << "String is not valid\n";
    return allInRange;
}
// Convenience predicates for the common alphabets.
bool isStringBinary(string &s) { return checkStringContents(s, '0', '1'); }
bool isStringLowerCase(string &s) { return checkStringContents(s, 'a', 'z'); }
bool isStringUpperCase(string &s) { return checkStringContents(s, 'A', 'Z'); }
// Returns true iff no value occurs more than once in a.
// Takes the vector by value on purpose: it is sorted locally, so the caller's
// copy is left untouched.
bool isArrayDistinct(vector<int> a) {
    sort(a.begin(), a.end());
    // After sorting, equal values are adjacent. adjacent_find replaces the old
    // int-vs-size() index loop (signed/unsigned comparison) and its `return 1`.
    return adjacent_find(a.begin(), a.end()) == a.end();
}
// Returns true iff a contains each of 1..a.size() exactly once. Otherwise it
// prints "Not a valid permutation" to stderr and returns false.
// The parameter is a const reference (it was non-const): the vector is never
// modified, and const vectors and temporaries can now be checked too.
bool isPermutation(const vector<int> &a) {
    const int n = static_cast<int>(a.size());
    vector<char> seen(n, 0);
    for(int x: a) {
        if (x <= 0 || x > n || seen[x-1]) {
            cerr << "Not a valid permutation\n";
            return false;
        }
        seen[x-1] = 1;
    }
    return true;
}
// -------------------- Input Checker End --------------------
// Disjoint-set union over elements 0..n-1 using path halving.
// merge(x, y) always hangs y's root below x's root (no union by size), so the
// surviving leader is always the leader of the first argument.
struct DSU {
    std::vector<int> f, siz; // f: parent links; siz: component size, meaningful at roots only

    DSU() {}
    DSU(int n) { init(n); }

    // Resets to n singleton sets {0}, {1}, ..., {n-1}.
    void init(int n) {
        f.resize(n);
        for (int i = 0; i < n; ++i)
            f[i] = i;
        siz.assign(n, 1);
    }

    // Root of x's set. Each step re-links the visited node to its grandparent.
    int leader(int x) {
        for (; f[x] != x; x = f[x])
            f[x] = f[f[x]];
        return x;
    }

    bool same(int x, int y) { return leader(x) == leader(y); }

    // Unites the sets of x and y; returns false if they were already together.
    bool merge(int x, int y) {
        const int rootX = leader(x);
        const int rootY = leader(y);
        if (rootX == rootY)
            return false;
        siz[rootX] += siz[rootY];
        f[rootY] = rootX;
        return true;
    }

    // Number of elements in x's set.
    int size(int x) { return siz[leader(x)]; }
};
// Fills count1[v] with the number of vertices in v's subtree (including v) for
// every v reachable from s, treating e as s's parent. Any neighbour equal to the
// current vertex's parent is skipped, so adj may hold directed (parent -> child)
// or undirected edges of a tree.
// Uses an explicit stack instead of recursion: on a path-shaped tree with 1e5
// vertices, the recursive version goes 1e5 frames deep and can overflow the stack.
void numberOfNodes(int s, int e,vector<int>&count1,vector<vector<int>>&adj)
{
    // One frame per vertex on the current DFS path.
    struct Frame { int v; int parent; size_t next; };
    vector<Frame> stk;
    stk.push_back({s, e, 0});
    count1[s] = 1;
    while (!stk.empty()) {
        Frame &top = stk.back();
        const int v = top.v;
        if (top.next < adj[v].size()) {
            const int child = adj[v][top.next++];
            if (child == top.parent)
                continue;
            count1[child] = 1;
            stk.push_back({child, v, 0}); // `top` may dangle after this; not used again
        } else {
            // v is finished: fold its subtree size into its parent (the frame below).
            stk.pop_back();
            if (!stk.empty())
                count1[stk.back().v] += count1[v];
        }
    }
}
// Reads and validates one test case: n, then n-1 lines "u v" meaning a directed
// edge u -> v. For every position i = 1..n it prints how many vertices can
// occupy position i in some topological order.
// Vertex u fits exactly the positions [depth[u], n - count1[u] + 1]. Here
// depth[root] = 1 and count1[u] is the size of u's subtree including u.
void solve()
{
    int n;
    n=readIntLn(1,1e5);
    vector<vector<int>>adj(n+1);
    vector<int>indegree(n+1,0);
    DSU d(n+1);
    for(int i=0;i<n-1;i++)
    {
        int u,v;
        u=readIntSp(1,n);
        v=readIntLn(1,n);
        assert(u!=v);
        // The underlying undirected graph must stay acyclic ...
        assert(d.leader(u)!=d.leader(v));
        d.merge(u,v);
        // ... and every vertex may have at most one parent. Without this check,
        // input like "1 3 / 2 3" (two roots sharing a child) was accepted and then
        // wrote past the end of `ans` below.
        assert(indegree[v]==0);
        adj[u].push_back(v);
        indegree[v]++;
    }
    // n-1 acyclic edges in a single component form a tree.
    int rt=d.leader(1);
    for(int i=1;i<=n;i++){
        assert(rt==d.leader(i));
    }
    // With n-1 edges and every in-degree <= 1, exactly one vertex has no parent.
    int root=-1;
    for(int i=1;i<=n;i++)
    {
        if(indegree[i]==0){
            root=i;
            break;
        }
    }
    assert(root!=-1);
    // BFS from the root assigns depths, with the root at depth 1.
    vector<int>depth(n+1);
    depth[root]=1;
    queue<int>q;
    q.push(root);
    while(!q.empty())
    {
        int node=q.front();
        q.pop();
        for(auto child:adj[node])
        {
            depth[child]=depth[node]+1;
            q.push(child);
        }
    }
    vector<int>count1(n+1,0);
    numberOfNodes(root,0,count1,adj);
    // Difference array: +1 at each vertex's first feasible position and -1 just
    // past its last one. Index n+1 exists, so no bounds special case is needed.
    vector<int>ans(n+2,0);
    for(int i=1;i<=n;i++)
    {
        int lo=depth[i];
        int hi=n-count1[i]+1;
        ans[lo]++;
        ans[hi+1]--;
    }
    for(int i=1;i<=n;i++)
    {
        ans[i]+=ans[i-1];
    }
    for(int i=1;i<=n;i++)
    {
        cout<<ans[i]<<" ";
    }
    // '\n' instead of endl: no need to flush after each of up to 3000 test cases.
    cout<<'\n';
}
int main()
{
    // The test-case count must lie in [1, 3000] and sit alone on the first line.
    for (int remaining = readIntLn(1, 3000); remaining > 0; --remaining)
        solve();
    // Nothing may follow the last test case.
    readEOF();
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
// Strict input validator. The whole input is buffered up front and then consumed
// token by token; every format or range violation fails an assert.
struct input_checker {
    string buffer; // entire input
    int pos = 0;   // index of the next unread character in buffer
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Slurps all of stdin into the buffer.
    input_checker() {
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Validates an in-memory input instead of stdin (e.g. to test the checker itself).
    explicit input_checker(string data) : buffer(std::move(data)) {}

    // Returns the characters up to the next ' ' or '\n' (not consumed). Other
    // whitespace inside a token fails an assert.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            assert(!isspace(buffer[pos]));
            pos++;
        }
        return res;
    }

    // Asserts that token is a canonical decimal integer: an optional '-', then at
    // least one digit, no leading zeroes, and not "-0". On their own, stoi/stoll
    // accept "12abc" (as 12), "+5" and "007", and throw rather than assert on "".
    static void assertIntegerToken(const string& token) {
        const size_t first = (!token.empty() && token[0] == '-') ? 1 : 0;
        assert(first < token.size());
        for (size_t i = first; i < token.size(); i++) {
            assert('0' <= token[i] && token[i] <= '9');
        }
        // A leading '0' is only allowed for the bare value "0" (so "-0" is rejected too).
        assert(token[first] != '0' || token.size() == 1);
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        assertIntegerToken(token);
        // Well-formed but out-of-range values make stoi throw, which also aborts.
        int res = stoi(token);
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        assertIntegerToken(token);
        long long res = stoll(token);
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads `size` ints separated by single spaces (no trailing separator consumed).
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Reads `size` long longs separated by single spaces (no trailing separator consumed).
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Validates all test cases and prints, for each position 0..n-1 of a topological
// order, how many vertices can appear there. Vertex v fits positions
// [dep[v], n - sz[v]] (0-indexed), with dep[root] = 0 and sz[v] counting v itself.
int main() {
    input_checker in;
    int tt = in.readInt(1, 3000);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1e5);
        in.readEoln();
        sn += n;
        vector<vector<int>> g(n);
        vector<int> deg(n);
        for (int i = 0; i < n - 1; i++) {
            int x = in.readInt(1, n);
            in.readSpace();
            int y = in.readInt(1, n);
            in.readEoln();
            x--;
            y--;
            assert(x != y);
            // Each vertex has at most one parent. With n-1 edges this leaves exactly
            // one root, and any cycle is unreachable from it (so sz[r] == n fails).
            assert(deg[y] == 0);
            g[x].emplace_back(y);
            deg[y]++;
        }
        int r = -1;
        for (int i = 0; i < n; i++) {
            if (deg[i] == 0) {
                r = i;
            }
        }
        assert(r != -1);
        vector<int> sz(n, 1);
        vector<int> dep(n);
        // Iterative traversal instead of a recursive std::function. A path-shaped
        // tree with n = 1e5 would otherwise recurse 1e5 levels deep and risk a
        // stack overflow. `order` lists every vertex after its parent (BFS order).
        vector<int> order;
        order.reserve(n);
        order.push_back(r);
        for (int k = 0; k < (int) order.size(); k++) {
            int v = order[k];
            for (int to : g[v]) {
                dep[to] = dep[v] + 1;
                order.push_back(to);
            }
        }
        // In reverse BFS order every child's size is final before it is added to its parent.
        for (int k = (int) order.size() - 1; k >= 0; k--) {
            int v = order[k];
            for (int to : g[v]) {
                sz[v] += sz[to];
            }
        }
        assert(sz[r] == n);
        // Difference array over positions, turned into counts by the prefix sum below.
        vector<int> ans(n + 1);
        for (int i = 0; i < n; i++) {
            ans[dep[i]] += 1;
            ans[n - sz[i] + 1] -= 1;
        }
        for (int i = 0; i < n; i++) {
            ans[i + 1] += ans[i];
            cout << ans[i] << " ";
        }
        cout << '\n';
    }
    assert(sn <= 1e5);
    in.readEof();
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
// For each test case (a directed tree), prints how many vertices can appear at
// each position 1..N of a topological order. Vertex u fits exactly the 0-indexed
// slots [dep[u]-1, N-1-subsz[u]], where dep[root] = 1 and subsz[u] counts the
// proper descendants of u. The counts are built with a difference array.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    int t;
    cin >> t;
    while (t--) {
        int n;
        cin >> n;
        // children[u]: heads of the edges leaving u (vertices are 0-indexed internally).
        vector<vector<int>> children(n);
        // hasParent[v] becomes 1 once an edge into v has been read.
        vector<int> hasParent(n, 0);
        for (int e = 0; e + 1 < n; ++e) {
            int u, v;
            cin >> u >> v;
            --u;
            --v;
            children[u].push_back(v);
            hasParent[v] = 1;
        }
        // The root is the first vertex without an incoming edge.
        const int root = int(find(hasParent.begin(), hasParent.end(), 0) - hasParent.begin());

        vector<int> dep(n, 1), subsz(n, 0);
        // BFS order: every vertex appears after its parent.
        vector<int> order;
        order.reserve(n);
        order.push_back(root);
        for (size_t k = 0; k < order.size(); ++k) {
            const int u = order[k];
            for (int v : children[u]) {
                dep[v] = dep[u] + 1;
                order.push_back(v);
            }
        }
        // Walking the BFS order backwards finishes every child before its parent.
        for (auto it = order.rbegin(); it != order.rend(); ++it)
            for (int v : children[*it])
                subsz[*it] += subsz[v] + 1;

        // +1 at the first allowed slot, -1 one past the last allowed slot.
        vector<int> ans(n + 1, 0);
        for (int i = 0; i < n; ++i) {
            ans[dep[i] - 1] += 1;
            ans[n - subsz[i]] -= 1;
        }
        partial_sum(ans.begin(), ans.begin() + n, ans.begin());
        for (int i = 0; i < n; ++i)
            cout << ans[i] << ' ';
        cout << '\n';
    }
}