QUICKEXIT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

DFS, Dynamic Programming

PROBLEM:

You are given a tree with one person standing at each vertex. Person i has strength P_i, and all P_i are distinct.

The following process repeats:

  • The person at the root leaves the tree.
  • Then, as long as there’s an empty vertex with at least one non-empty child, the strongest person among its children moves one step up into it.

You’re at vertex N with strength K.
Find the minimum possible time at which you can leave the tree, and the number of arrangements that achieve this minimum.
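To make the process concrete, here is a minimal brute-force simulator for a single fixed arrangement. It is only a sketch under one reading of the statement: the function name `leaveTime`, the 1-indexed `parent` array with `parent[1] == 0`, and the convention that the root's person leaves at time 1 are assumptions made for illustration, not part of the original problem.

```cpp
#include <cassert>
#include <vector>

// Simulates the process for one arrangement and returns the round in which
// the person who starts at vertex `start` leaves the tree.
// Assumptions (illustrative only): vertices are 1-indexed, parent[1] == 0
// marks the root, index 0 is unused, and every strength is positive.
int leaveTime(const std::vector<int>& parent, const std::vector<int>& strength, int start) {
    int n = static_cast<int>(parent.size()) - 1;
    assert(static_cast<int>(strength.size()) == n + 1);

    std::vector<std::vector<int>> children(n + 1);
    for (int v = 2; v <= n; ++v) children[parent[v]].push_back(v);

    std::vector<int> occupant(strength);  // occupant[v] == 0 means "empty"
    occupant[0] = 0;                      // sentinel used when picking the best child
    const int me = strength[start];

    for (int round = 1; round <= n; ++round) {
        if (occupant[1] == me) return round;  // we stand at the root and leave now
        occupant[1] = 0;                      // the root's person leaves
        int cur = 1;
        while (true) {
            // Pull the strongest person among the children of the empty vertex.
            int best = 0;
            for (int c : children[cur])
                if (occupant[c] > occupant[best]) best = c;
            if (best == 0) break;  // no non-empty child, the cascade stops
            occupant[cur] = occupant[best];
            occupant[best] = 0;
            cur = best;
        }
    }
    return -1;  // not reached for valid input
}
```

For a root with two children where you stand at vertex 3, `leaveTime({0, 0, 1, 1}, {0, 3, 1, 2}, 3)` corresponds to your sibling being weaker than you, so you move up right after the root leaves.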

EXPLANATION:

This editorial continues from the editorial of the easy version.

Recall that v_i were the vertices on the path from 1 to N, and s was the total number of siblings of the path vertices (the “path-siblings”).
We had two cases: for when s \leq K-1 and otherwise.

We’ll do the counting for those cases separately as well.
In contrast to simply finding the answer, the simpler case for counting is in fact s \gt K-1, i.e., when there’s extra delay.

Case 1: s \gt K-1

Solution

Recall that in this case, the situation was as follows: of the s siblings, the K-1 of them with the largest subtree sizes received values from 1 to K-1, and we couldn’t do anything about the rest.

When counting assignments, we need to ensure that this is preserved, i.e., the K-1 largest subtrees are assigned the values 1 through K-1.

Suppose the subtree sizes are X_1 \geq X_2 \geq \ldots \geq X_s.

Then, any subtree with size \gt X_{K-1} must receive a value from 1 to K-1.
Subtrees with size smaller than X_{K-1} will never receive a small value.
As for subtrees with size equal to X_{K-1}, some of them will receive a small value and some of them won’t, but it doesn’t really matter which ones do, since they’re all functionally equivalent.

Suppose there are y_1 subtrees with size X_{K-1}, and in the initial arrangement, y_2 of them received small values.
Then, we can choose any y_2 of these y_1 subtrees, giving us \binom{y_1}{y_2} choices in total.

After assigning the vertices that receive small values, the small values themselves can be permuted among themselves, and the large values can be permuted among themselves, so the answer is simply

\binom{y_1}{y_2} \cdot (K-1)! \cdot (N-K)!
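The count above can be sketched in a few lines. The function names (`countCase1`, `binomialMod`, `factorialMod`) and the input convention (a list of path-sibling subtree sizes) are assumptions made for illustration; the modulus 998244353 is the one used in the solutions further down.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

const long long MOD = 998244353;

long long factorialMod(int n) {
    long long r = 1;
    for (int i = 2; i <= n; ++i) r = r * i % MOD;
    return r;
}

// C(n, k) mod MOD via one rolling row of Pascal's triangle (fine for a sketch).
long long binomialMod(int n, int k) {
    if (k < 0 || k > n) return 0;
    std::vector<long long> row(k + 1, 0);
    row[0] = 1;
    for (int i = 1; i <= n; ++i)
        for (int j = std::min(i, k); j >= 1; --j)
            row[j] = (row[j] + row[j - 1]) % MOD;
    return row[k];
}

// sizes: subtree sizes of the path-siblings (any order), K >= 2, N = number of vertices.
// Returns C(y1, y2) * (K-1)! * (N-K)! mod MOD.
long long countCase1(std::vector<int> sizes, int K, int N) {
    assert(K >= 2 && static_cast<int>(sizes.size()) >= K - 1 && N >= K);
    std::sort(sizes.begin(), sizes.end(), std::greater<int>());
    int threshold = sizes[K - 2];  // X_{K-1} in the text (0-indexed here)
    int y1 = static_cast<int>(std::count(sizes.begin(), sizes.end(), threshold));
    int y2 = static_cast<int>(std::count(sizes.begin(), sizes.begin() + (K - 1), threshold));
    return binomialMod(y1, y2) * factorialMod(K - 1) % MOD * factorialMod(N - K) % MOD;
}
```

For example, with sizes {3, 2, 2, 1} and K = 3, the threshold is 2, so y_1 = 2 and y_2 = 1, and exactly one of the two size-2 subtrees is chosen to receive a small value.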

Case 2: s \leq K-1

Solution

Recall that in this case, we need to ensure that every vertex on the 1\to N path never gets delayed.
So, for each vertex on the path, it must have a larger value than all the path-siblings at or above its own level.

This gives us the motivation for what to count: we can try to place values while walking up the path from N to 1, and when we’re at vertex u, what matters is:

  1. The current minimum value placed on the path from u to N (since we must ensure all higher path-siblings obtain a value less than this minimum).
  2. The number of available elements that are less than this minimum.
  3. The number of available elements that are greater than this minimum.

Let f(u, x, y) denote the number of ways of assigning elements to path and path-sibling vertices from N till u, such that there are x “small” and y “large” elements available.
f(u, x, y) assumes that the value of u has been assigned, but the siblings of u don’t have values yet.

First, all the siblings of u must receive “small” values.
If there are s_u siblings, there are \binom{x}{s_u} ways to choose them and s_u! ways to arrange them.

Next, let v be the parent of u. Let’s try to assign the value of v.
We have two cases:

  1. v receives a “large” value.
    • Here, it doesn’t matter which of the large values v receives, so there are y options.
    • Since v and the siblings of u are satisfied, we can move up to v now.
    • The state for v is now f(v, x - s_u, y - 1), since we used s_u small values and 1 large value.
    • Since there are y choices for the large element, we add \binom{x}{s_u} \cdot s_u! \cdot y \cdot f(u, x, y) to f(v, x - s_u, y - 1).
  2. v receives a “small” value.
    • Now, this value will be the minimum on the path, so it does matter which value v receives.
    • Suppose v receives the m-th smallest among the remaining x - s_u small values.
    • Then, there will be m-1 small values and x+y-s_u-m large values moving forward, since lowering the minimum turns some previously small values into large ones.
    • So, for each m from 1 to x - s_u, we add \binom{x}{s_u} \cdot s_u! \cdot f(u, x, y) to f(v, m-1, x+y-m-s_u).

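Summing f(1, x, y) across all x, y counts the assignments to the path and path-sibling vertices. Every other vertex can then receive the leftover values in any order, so the final answer is this sum multiplied by R!, where R is the number of vertices that are neither on the path nor path-siblings (this is the fac[remain] factor in the author's code).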
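The transitions can be written directly as a slow reference implementation, which is handy for checking a faster version on small inputs. This is only a sketch: `stepNaive`, `countCase2` and the `siblingCounts` input (the number of siblings of each non-root path vertex, listed from N upwards) are names and conventions assumed here. The array index is x, and y is recovered as (number of unassigned values) - x. The last step multiplies by the factorial of the number of still-unassigned vertices, as both reference solutions below do.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// x * (x-1) * ... * (x-s+1) mod MOD, which equals C(x, s) * s!.
long long fallingFactorial(int x, int s) {
    long long r = 1;
    for (int i = 0; i < s; ++i) r = r * (x - i) % MOD;
    return r;
}

// One level of the DP: f[x] = f(u, x, total - x), result g[x'] = f(v, x', newTotal - x').
std::vector<long long> stepNaive(const std::vector<long long>& f, int su) {
    int total = static_cast<int>(f.size()) - 1;  // unassigned values before this step
    int newTotal = total - su - 1;               // su siblings and v get assigned
    assert(newTotal >= 0);
    std::vector<long long> g(newTotal + 1, 0);
    for (int x = su; x <= total; ++x) {
        long long y = total - x;
        long long w = f[x] * fallingFactorial(x, su) % MOD;
        // Case 1: v receives one of the y large values.
        if (y >= 1) g[x - su] = (g[x - su] + w * y) % MOD;
        // Case 2: v receives the m-th smallest remaining small value.
        for (int m = 1; m <= x - su; ++m) g[m - 1] = (g[m - 1] + w) % MOD;
    }
    return g;
}

long long countCase2(int N, int K, const std::vector<int>& siblingCounts) {
    std::vector<long long> f(N, 0);  // N-1 values are unassigned once K sits at vertex N
    f[K - 1] = 1;                    // base state f(N, K-1, N-K) = 1
    for (int su : siblingCounts) f = stepNaive(f, su);
    long long sum = 0;
    for (long long v : f) sum = (sum + v) % MOD;
    int remaining = static_cast<int>(f.size()) - 1;  // vertices off the path and its siblings
    for (int i = 2; i <= remaining; ++i) sum = sum * i % MOD;
    return sum;
}
```

As a sanity check, on the path 1-2-3 with K = 2 nobody can ever delay you, and the sketch returns 2 = (N-1)!, the number of ways to arrange the other two people.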

Now, our function f has a three-parameter state, and many of those states have \mathcal{O}(N) transitions, so it appears that we have an overall complexity of \mathcal{O}(N^4).

However, note that the only base state is f(N, K-1,N-K) = 1.
This, in combination with the way the transitions are, means that in fact only \mathcal{O}(N^2) states will have non-zero values at all.

To see which ones: note that from f(u, x, y), we always move to some f(v, x', y') with x'+y' = x+y-s_u-1.
Since we start with a single non-zero state, the only non-zero states overall are the ones where x+y is some constant based on the level of u — in particular, x+y will equal N minus the number of vertices whose values have been assigned; which makes sense since any non-assigned value must be either large or small.

We now have a complexity of \mathcal{O}(N^3).
In particular, the state f(u, x, y) can be written as simply f(u, x), with y being implicit.
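For reference, the two transitions can be restated with y implicit. Here T_u is notation introduced only for this restatement (it does not appear in the original): it denotes the constant value of x + y at vertex u.

```latex
% T_u := x + y at vertex u (notation assumed for this restatement)
f(v,\; x - s_u) \mathrel{+}= \binom{x}{s_u} \cdot s_u! \cdot (T_u - x) \cdot f(u, x)
f(v,\; m - 1) \mathrel{+}= \binom{x}{s_u} \cdot s_u! \cdot f(u, x), \qquad 1 \le m \le x - s_u
```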

Now, notice that when we have \mathcal{O}(N) transitions, they are of the form “add f(u, x) to f(v, m-1) for all m \leq x - s_u” (up to the common factor \binom{x}{s_u} \cdot s_u!).
There are \mathcal{O}(N) such range updates from each level to the next, and all of them can be processed in \mathcal{O}(N) total using a difference array, which cuts another factor of N from the complexity.
(Note that depending on the direction you compute the DP, you might want a range sum instead, which is doable quickly with prefix sums).
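One way to realise this, sketched here in the "range sum" direction, is to walk x downwards while keeping a running sum of the weights already seen; it resembles the loop in the tester's solution below. The name `stepFast` and the array layout (index x, with y implicit) are assumptions for illustration. For brevity the weight \binom{x}{s_u} \cdot s_u! is recomputed as a falling factorial each time, which costs \mathcal{O}(s_u) per state; precomputed factorials would make it constant time.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// x * (x-1) * ... * (x-s+1) mod MOD, which equals C(x, s) * s!.
long long fallingFactorial(int x, int s) {
    long long r = 1;
    for (int i = 0; i < s; ++i) r = r * (x - i) % MOD;
    return r;
}

// One level of the DP with the inner loop over m replaced by a running sum.
// f[x] = f(u, x) with y = total - x; the result is indexed by the new x.
std::vector<long long> stepFast(const std::vector<long long>& f, int su) {
    int total = static_cast<int>(f.size()) - 1;
    int newTotal = total - su - 1;
    assert(newTotal >= 0);
    std::vector<long long> g(newTotal + 1, 0);
    long long suffix = 0;  // sum of weights w over all larger x visited so far
    for (int x = total; x >= su; --x) {
        long long w = f[x] * fallingFactorial(x, su) % MOD;
        int i = x - su;  // target index of the "large value" transition
        if (i <= newTotal) {
            long long y = total - x;
            // Every larger x reaches index i through the "small value" case,
            // and x itself reaches it through the "large value" case.
            g[i] = (suffix + w * y) % MOD;
        }
        suffix = (suffix + w) % MOD;
    }
    return g;
}
```

Each entry of g is written exactly once, so apart from the weight computation a level costs time linear in the number of states.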

We’re now down to a complexity of \mathcal{O}(N^2) which is fast enough.

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Author's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

const int MAXN = 3005;
const int MOD = 998244353;
int C[MAXN][MAXN];
int fac[MAXN];

int dp[MAXN][MAXN];
int pref[MAXN][MAXN];

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    for (int i = 0; i < MAXN; ++i) {
        C[i][0] = 1;
        fac[i] = max(i, 1);
        if (i > 0) fac[i] = 1LL * fac[i-1] * fac[i] % MOD;
        for (int j = 1; j <= i; ++j)
            C[i][j] = (C[i-1][j] + C[i-1][j-1]) % MOD;
    }
 
    int t; cin >> t;
    while (t--) {
        int n, x; cin >> n >> x;

        vector adj(n+1, vector<int>());
        for (int i = 1; i < n; ++i) {
            int u, v; cin >> u >> v;
            adj[u].push_back(v);
            adj[v].push_back(u);
        }

        vector<int> par(n+1), children(n+1), subsz(n+1);
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            par[u] = p;
            subsz[u] = 1;
            for (int v : adj[u]) if (v != p) {
                self(self, v, u);
                subsz[u] += subsz[v];
                ++children[u];
            }
        };
        dfs(dfs, 1, 0);

        if (x == 1) {
            cout << n - subsz[n] + 1 << ' ' << fac[n-1] << '\n';
            continue;
        }

        auto solve = [&] () {
            vector<int> onpath(n+1);
            int important = 0;
            {
                int cur = n;
                while (cur > 1) {
                    onpath[cur] = 1;
                    important += children[par[cur]] - 1;
                    cur = par[cur];
                }
                onpath[cur] = 1;
            }

            if (important >= x - 1) {
                // Distribute small values to the important nodes with largest subtree sizes
                vector<int> subsizes;
                for (int i = 2; i < n; ++i) {
                    if (!onpath[i] and onpath[par[i]] and par[i] != n) {
                        subsizes.push_back(subsz[i]);
                    }
                }
                sort(rbegin(subsizes), rend(subsizes));

                int delay = accumulate(begin(subsizes) + x-1, end(subsizes), 0);
                cout << accumulate(begin(onpath), end(onpath), 0) + delay << ' ';

                int sp = subsizes[x-2];
                int have = count(begin(subsizes), end(subsizes), sp);
                int want = count(begin(subsizes), begin(subsizes) + (x-1), sp);
                
                int ways = C[have][want];
                ways = 1LL * ways * fac[x - 1] % MOD;
                ways = 1LL * ways * fac[n - x] % MOD;
                cout << ways << '\n';
                return;
            }

            // Other case: there's enough small things
            // dp[u][i][j] -> I've processed path till u, there are i small and j big things remaining
            // Optimization: i+j is a constant when u is fixed, so store only dp[u][i]
            
            auto rec = [&] (const auto &self, int u, int small, int large) -> int {
                if (small < 0) return 0;
                // if (large < 0) return dp[u][small] = 0;
                if (u == 1) return dp[u][small] = 1;
                if (dp[u][small] != -1) return dp[u][small];
                auto &res = dp[u][small];
                res = 0;

                int siblings = children[par[u]] - 1;
                if (siblings > small) return 0;
                int to_siblings = 1LL * C[small][siblings] * fac[siblings] % MOD;
                int above = 0;

                // Give parent something > mine
                above = 1LL * large * self(self, par[u], small - siblings, large - 1) % MOD;

                // Give parent something < mine
                // Optimization: we want the sum of dp[par[u]][i - 1] across all i such that:
                // 1 <= i <= small - siblings
                // So, pref[par[u]][small - siblings - 1]
                
                self(self, par[u], small - siblings - 1, large);
                if (small > siblings) {
                    if (pref[par[u]][small - siblings - 1] == -1) pref[par[u]][small - siblings - 1] = 0;
                    if (par[u] > 1) above = (above + pref[par[u]][small - siblings - 1]) % MOD;
                    else above = (above + small - siblings) % MOD;
                }
                res = 1LL * to_siblings * above % MOD;
                if (large < 0) res = 0;
                
                self(self, u, small - 1, large + 1);
                pref[u][small] = res;
                if (small > 0 and pref[u][small - 1] != -1) pref[u][small] = (pref[u][small] + pref[u][small - 1]) % MOD;
                return res;
            };

            for (int i = 1; i <= n; ++i) for (int j = 0; j <= n; ++j)
                dp[i][j] = pref[i][j] = -1;
            
            int ways = rec(rec, n, x-1, n-x);
            int remain = n - accumulate(begin(onpath), end(onpath), 0) - important;
            ways = 1LL * ways * fac[remain] % MOD;
            cout << accumulate(begin(onpath), end(onpath), 0) << ' ' << ways << '\n';
        };

        solve();
    }
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

void solve(istringstream cin) {
    int n, k;
    cin >> n >> k;
    vector<vector<int>> g(n);
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        cin >> x >> y;
        x--;
        y--;
        g[x].emplace_back(y);
        g[y].emplace_back(x);
    }
    vector<int> p(n, -1), sz(n, 1), dep(n);
    vector<vector<int>> c(n);
    function<void(int)> dfs = [&](int v) {
        for (int to : g[v]) {
            if (to == p[v]) {
                continue;
            }
            p[to] = v;
            c[v].emplace_back(to);
            dep[to] = dep[v] + 1;
            dfs(to);
            sz[v] += sz[to];
        }
    };
    dfs(0);
    int cnt = 0;
    {
        int v = p[n - 1];
        while (v != -1) {
            cnt += int(c[v].size()) - 1;
            v = p[v];
        }
    }
    C(n, 0);
    if (k - 1 <= cnt) {
        vector<int> t;
        {
            int v = n - 1;
            while (p[v] != -1) {
                for (int u : c[p[v]]) {
                    if (u != v) {
                        t.emplace_back(sz[u]);
                    }
                }
                v = p[v];
            }
        }
        sort(t.rbegin(), t.rend());
        int d = dep[n - 1] + 1;
        int c0 = 0, c1 = 0;
        if (k == 1) {
            d = n - sz[n - 1] + 1;
        } else {
            for (int i = 0; i < cnt; i++) {
                if (i >= k - 1) {
                    d += t[i];
                }
                if (t[i] == t[k - 2]) {
                    c0++;
                    if (i <= k - 2) {
                        c1++;
                    }
                }
            }
        }
        cout << d << " " << C(c0, c1) * fact[k - 1] * fact[n - k] << '\n';
    } else {
        int v = n - 1;
        vector<mint> dp(k);
        dp[k - 1] = 1;
        int e = 1;
        while (v != 0) {
            vector<mint> new_dp(k);
            int d = int(c[p[v]].size()) - 1;
            e += d + 1;
            mint sum = 0;
            for (int i = k - 1 - d; i >= 0; i--) {
                new_dp[i] = sum;
                new_dp[i] += dp[i + d] * C(i + d, d) * fact[d] * max(0, n - e + 1 - i);
                sum += dp[i + d] * C(i + d, d) * fact[d];
            }
            v = p[v];
            swap(dp, new_dp);
        }
        cout << dep[n - 1] + 1 << " " << accumulate(dp.begin(), dp.end(), mint(0)) * fact[n - e] << '\n';
    }
}

////////////////////////////////////////

// #define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

struct dsu {
    vector<int> p;
    vector<int> sz;
    int n;

    dsu(int _n) : n(_n) {
        p = vector<int>(n);
        iota(p.begin(), p.end(), 0);
        sz = vector<int>(n, 1);
    }

    inline int get(int x) {
        if (p[x] == x) {
            return x;
        } else {
            return p[x] = get(p[x]);
        }
    }

    inline bool unite(int x, int y) {
        x = get(x);
        y = get(y);
        if (x == y) {
            return false;
        }
        p[x] = y;
        sz[y] += sz[x];
        return true;
    }

    inline bool same(int x, int y) {
        return (get(x) == get(y));
    }

    inline int size(int x) {
        return sz[get(x)];
    }

    inline bool root(int x) {
        return (x == get(x));
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1000);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 3000);
        in.readSpace();
        int k = in.readInt(1, n);
        in.readEoln();
        sn += n;
        vector<int> u(n - 1), v(n - 1);
        dsu uf(n);
        for (int i = 0; i < n - 1; i++) {
            u[i] = in.readInt(1, n);
            in.readSpace();
            v[i] = in.readInt(1, n);
            in.readEoln();
            assert(uf.unite(u[i] - 1, v[i] - 1));
        }
        ostringstream sout;
        sout << n << " " << k << '\n';
        for (int i = 0; i < n - 1; i++) {
            sout << u[i] << " " << v[i] << '\n';
        }
        solve(istringstream(sout.str()));
    }
    cerr << sn << endl;
    assert(sn <= 3000);
    in.readEof();
    return 0;
}