COMBIISFUN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

3210

PREREQUISITES:

Combinatorics, counting connected graphs

PROBLEM:

Given N, find the number of pairs of arrays A and B of length N such that:

  • 1 \leq A_i, B_i \leq N
  • A_i \neq B_i
  • There exists a way to choose either A_i or B_i for each i such that all chosen values are distinct, i.e., form a permutation.

EXPLANATION:

Our first aim should be to characterize when a pair of arrays is ‘good’, i.e., when it’s possible to choose elements such that the final result is a permutation of \{1, 2, \ldots, N\}.

Consider a graph G on N vertices and N edges, where for each i we create an edge between A_i and B_i.
Claim: The pair (A, B) is good if and only if every component of G contains a cycle.
Equivalently, every component of G contains as many edges as it does vertices.

Proof

Suppose some connected component of G is a tree.
For each edge, we can choose only one of its endpoints.
This component has k vertices but only k-1 edges; so some value will be unchosen no matter what.
This means the final sequence will never be a permutation.

So, every component of G must contain a cycle.
For a component with a cycle, it’s easy to see that choosing every vertex is possible: pick some vertex u on a cycle, then build a spanning tree of the component rooted at u that excludes one of the cycle edges adjacent to it.
For each edge of the spanning tree, pick the child vertex - everything other than u will be chosen.
Finally, use the previously excluded edge to pick u.
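
The criterion above can be checked directly with a disjoint-set union that tracks the number of vertices and edges in each component. The following is a minimal sketch, not taken from either solution below; the name isGoodPair is hypothetical, and it assumes A and B hold values in 1..N with A_i \neq B_i.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical helper: returns true iff every component of the graph with
// edges (A[i], B[i]) has at least as many edges as vertices, i.e. contains a cycle.
// Assumes values are 1-based and A[i] != B[i].
bool isGoodPair(const std::vector<int>& A, const std::vector<int>& B) {
    assert(A.size() == B.size());
    const int n = static_cast<int>(A.size());
    std::vector<int> parent(n + 1), vertices(n + 1, 1), edges(n + 1, 0);
    std::iota(parent.begin(), parent.end(), 0);
    // Find with path halving.
    auto find = [&](int v) {
        while (parent[v] != v) {
            parent[v] = parent[parent[v]];
            v = parent[v];
        }
        return v;
    };
    for (int i = 0; i < n; ++i) {
        int a = find(A[i]), b = find(B[i]);
        if (a != b) {  // merge component b into component a
            parent[b] = a;
            vertices[a] += vertices[b];
            edges[a] += edges[b];
        }
        ++edges[a];  // the new edge belongs to the (merged) component
    }
    for (int v = 1; v <= n; ++v)
        if (find(v) == v && edges[v] < vertices[v]) return false;  // tree component
    return true;
}
```

For example, A = (1, 1, 1), B = (2, 2, 2) is rejected because vertex 3 forms a tree component on its own, while A = (1, 2, 3), B = (2, 3, 1) is accepted.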


So, instead of counting pairs of arrays, we can instead count graphs that satisfy this structure.

A preliminary idea is as follows:

  • Let f(N) denote the answer for N.
  • Fix the size of the component containing N, say k.
    Then, there are \binom{N-1}{k-1} ways to choose which elements other than N go into the component; and \binom{N}{k} ways to choose which positions of the array they represent.
    Next, there are f(N-k) ways to create the rest of the elements.
    Finally, let g(k) be the number of ways to arrange the k values we’ve chosen into a connected component.
    Putting everything together, we obtain
f(N) = \sum_{k=1}^N \binom{N-1}{k-1}\binom{N}{k}\cdot f(N-k)\cdot g(k)

So, if we knew all the g(k) values, the problem would be solved in \mathcal{O}(N^2) time easily.
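
As a rough illustration of this recurrence, here is a sketch that takes the g values as input and fills in f. The function name countGoodPairs and the use of Pascal's triangle for the binomials are choices made for this example only; it assumes the modulus is small enough that the product of two residues fits in a 64-bit integer.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper implementing
//   f[i] = sum_{k=1..i} C(i-1, k-1) * C(i, k) * f[i-k] * g[k]  (mod `mod`), f[0] = 1.
// `g` must hold g[0..n]; the result holds f[0..n].
std::vector<int64_t> countGoodPairs(const std::vector<int64_t>& g, int64_t mod) {
    assert(!g.empty() && mod >= 2);
    const int n = static_cast<int>(g.size()) - 1;
    // Binomial coefficients modulo `mod` via Pascal's triangle.
    std::vector<std::vector<int64_t>> C(n + 1, std::vector<int64_t>(n + 1, 0));
    for (int i = 0; i <= n; ++i) {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % mod;
    }
    std::vector<int64_t> f(n + 1, 0);
    f[0] = 1;
    for (int i = 1; i <= n; ++i) {
        for (int k = 1; k <= i; ++k) {  // k = size of the component containing i
            int64_t term = C[i - 1][k - 1] * C[i][k] % mod;
            term = term * f[i - k] % mod;
            term = term * (g[k] % mod) % mod;
            f[i] = (f[i] + term) % mod;
        }
    }
    return f;
}
```

With g = (0, 0, 4, 192, 14976) for sizes 0 through 4, this yields f(2) = 4, f(3) = 192 and f(4) = 15264.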
Let’s focus on computing the values of g now.


g(N) denotes the number of connected labelled graphs with N nodes and N edges (with a couple extra factors specific to this problem for orderings, which we’ll come back to at the end).
With N vertices and N edges, the graph will look like a single cycle with some trees hanging off of it.

Let’s fix the size of the cycle, say x.
A_i = B_i is not allowed, meaning self-loops aren’t allowed. So, we have x \geq 2.
There are \binom{N}{x} ways to choose which vertices go into the cycle.
There are (x-1)! ways to arrange these vertices into a cycle; this is further divided by 2 if x \geq 3 because we don’t distinguish between a cycle and its reverse.

Finally, we have to arrange all the remaining vertices.
It turns out that there are exactly x\cdot N^{N-x-1} ways to do this.
This formula follows from the fact that we have with us N-x+1 components: one cycle of size x and N-x single vertices; and we want to add edges between them to connect them.
In general, the number of ways of adding edges between connected components of sizes s_1, s_2, \ldots, s_k to connect them is

s_1s_2\ldots s_k\cdot N^{k-2}

A proof of this fact can be found at the bottom of this page.
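
As a small sanity check of this formula (a worked example added for illustration), take N = 4 and a cycle on x = 2 vertices \{a, b\}, leaving two single vertices c and d. There are k = 3 components of sizes 2, 1, 1:

```latex
\[
s_1 s_2 s_3 \cdot N^{k-2} = (2 \cdot 1 \cdot 1) \cdot 4^{3-2} = 8 = x \cdot N^{N-x-1} = 2 \cdot 4^{1}
\]
```

Enumerating by hand gives the same count: c and d each attach to a or b (4 ways), c attaches to a or b and d attaches to c (2 ways), or d attaches to a or b and c attaches to d (2 ways), for 8 in total.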

Finally, there are a couple more things to consider:

  • Each ‘edge’ represents the unordered pair (A_i, B_i), but we’re counting ordered pairs.
    There are two choices for each edge and N edges; so we get a factor of 2^N.
  • Next, when computing f(N) we accounted for choosing the positions we’re assigning to the connected component, but not their order.
    Accounting for that gives us an extra N! term.

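Putting it together and reducing a couple of terms, we have

g(N) = \sum_{x=2}^N \binom{N}{x}x!\cdot N^{N-x-1} \cdot N! \cdot 2^{N-1}

The factor \frac{1}{2} that turns 2^N into 2^{N-1} comes from the cycle reversal when x \geq 3. For x = 2 the cycle is a pair of parallel edges, which are interchangeable, so the N! orderings count each arrangement twice and the same \frac{1}{2} appears.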

Once again, it’s easy to compute this in \mathcal{O}(N) for a fixed N, and so \mathcal{O}(N^2) for all N.
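
The sum for g can be sketched as below. This is an illustrative version rather than either of the solutions that follow: the name componentWays is hypothetical, the sum is taken over x \geq 2 (self-loops are excluded), and the x = k term x! \cdot k^{-1} is written as (k-1)! so that no modular inverse is needed. It assumes the modulus is small enough that the product of two residues fits in a 64-bit integer.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns g[0..n] modulo `mod`, where
//   g(k) = k! * 2^(k-1) * sum_{x=2..k} C(k, x) * x! * k^(k-x-1).
std::vector<int64_t> componentWays(int n, int64_t mod) {
    assert(n >= 0 && mod >= 2);
    // Binomials, factorials and powers of two modulo `mod`.
    std::vector<std::vector<int64_t>> C(n + 1, std::vector<int64_t>(n + 1, 0));
    std::vector<int64_t> fact(n + 1, 1), pow2(n + 1, 1);
    for (int i = 0; i <= n; ++i) {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % mod;
        if (i > 0) {
            fact[i] = fact[i - 1] * i % mod;
            pow2[i] = pow2[i - 1] * 2 % mod;
        }
    }
    std::vector<int64_t> g(n + 1, 0);
    for (int k = 2; k <= n; ++k) {
        // powK[e] = k^e modulo `mod` for e = 0..k.
        std::vector<int64_t> powK(k + 1, 1);
        for (int e = 1; e <= k; ++e) powK[e] = powK[e - 1] * k % mod;
        int64_t sum = 0;
        for (int x = 2; x <= k; ++x) {  // x = length of the unique cycle
            int64_t term = (x < k)
                ? C[k][x] * fact[x] % mod * powK[k - x - 1] % mod
                : fact[k - 1];  // x = k: k! / k = (k-1)!
            sum = (sum + term) % mod;
        }
        g[k] = sum * fact[k] % mod * pow2[k - 1] % mod;
    }
    return g;
}
```

For small sizes this gives g(1) = 0, g(2) = 4, g(3) = 192 and g(4) = 14976.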

Once all the g(N) values are known, all the f(N) values can be found in \mathcal{O}(N^2) as described earlier, and we’re done.
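
To validate the formulas on tiny inputs, one option is to count good pairs straight from the definition. The sketch below is only a testing aid (the name bruteForceCount is hypothetical, and values are stored 0-based internally); it enumerates every pair of arrays and every way of choosing A_i or B_i, so it is limited to N \leq 4.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical testing aid: counts pairs (A, B) with A[i] != B[i] for which some
// choice of A[i] or B[i] per position yields a permutation. Exponential time.
int64_t bruteForceCount(int n) {
    assert(1 <= n && n <= 4);
    // All ordered pairs (a, b) with a != b, using 0-based values.
    std::vector<std::pair<int, int>> options;
    for (int a = 0; a < n; ++a)
        for (int b = 0; b < n; ++b)
            if (a != b) options.push_back({a, b});
    if (options.empty()) return 0;  // n == 1: A_1 != B_1 is impossible
    std::vector<int> pick(n, 0);    // pick[i] = index of the pair used at position i
    int64_t good = 0;
    while (true) {
        bool ok = false;
        for (int mask = 0; mask < (1 << n) && !ok; ++mask) {
            int used = 0;  // bitmask of chosen values
            for (int i = 0; i < n; ++i) {
                const std::pair<int, int>& p = options[pick[i]];
                used |= 1 << (((mask >> i) & 1) ? p.first : p.second);
            }
            ok = (used == (1 << n) - 1);  // n picks covering n values are distinct
        }
        good += ok ? 1 : 0;
        // Advance `pick` like a mixed-radix counter.
        int pos = 0;
        while (pos < n && ++pick[pos] == static_cast<int>(options.size())) pick[pos++] = 0;
        if (pos == n) break;
    }
    return good;
}
```

This gives 0, 4, 192 and 15264 for N = 1, 2, 3, 4. For N = 3, for instance, 24 of the 6^3 = 216 pairs are bad: they leave one value isolated and place all three edges on the other two.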

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

int mod;
struct mi {
    int64_t v; explicit operator int64_t() const { return v % mod; }
    mi() { v = 0; }
    mi(int64_t _v) {
        v = (-mod < _v && _v < mod) ? _v : _v % mod;
        if (v < 0) v += mod;
    }
    friend bool operator==(const mi& a, const mi& b) {
        return a.v == b.v; }
    friend bool operator!=(const mi& a, const mi& b) {
        return !(a == b); }
    friend bool operator<(const mi& a, const mi& b) {
        return a.v < b.v; }

    mi& operator+=(const mi& m) {
        if ((v += m.v) >= mod) v -= mod;
        return *this; }
    mi& operator-=(const mi& m) {
        if ((v -= m.v) < 0) v += mod;
        return *this; }
    mi& operator*=(const mi& m) {
        v = v*m.v%mod; return *this; }
    mi& operator/=(const mi& m) { return (*this) *= inv(m); }
    friend mi pow(mi a, int64_t p) {
        mi ans = 1; assert(p >= 0);
        for (; p; p /= 2, a *= a) if (p&1) ans *= a;
        return ans;
    }
    friend mi inv(const mi& a) { assert(a.v != 0);
        return pow(a,mod-2); }

    mi operator-() const { return mi(-v); }
    mi& operator++() { return *this += 1; }
    mi& operator--() { return *this -= 1; }
    mi operator++(int32_t) { mi temp; temp.v = v++; return temp; }
    mi operator--(int32_t) { mi temp; temp.v = v--; return temp; }
    friend mi operator+(mi a, const mi& b) { return a += b; }
    friend mi operator-(mi a, const mi& b) { return a -= b; }
    friend mi operator*(mi a, const mi& b) { return a *= b; }
    friend mi operator/(mi a, const mi& b) { return a /= b; }
    friend ostream& operator<<(ostream& os, const mi& m) {
        os << m.v; return os;
    }
    friend istream& operator>>(istream& is, mi& m) {
        int64_t x; is >> x;
        m.v = x;
        return is;
    }
    friend void __print(const mi &x) {
        cerr << x.v;
    }
};


bool prime(int s) {
    for(int i = 2 ; i * i <= s ; i++) {
        if(s % i == 0)  return false;
    }
    return true;
}


int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;

    int n = input.readInt(1, 2000);    input.readSpace();
    mod = input.readInt((int)1e8, (int)1e9);  input.readEoln();

    assert(prime(mod));
//    int n;  cin >> n >> mod;

    constexpr int maxn = 10000;
    vector<mi> fct(maxn, 1), invf(maxn, 1);
    auto calc_fact = [&]() {
        for(int i = 1 ; i < maxn ; i++) {
            fct[i] = fct[i - 1] * i;
        }
        invf.back() = mi(1) / fct.back();
        for(int i = maxn - 1 ; i ; i--)
            invf[i - 1] = i * invf[i];
    };

    auto choose = [&](int n, int r) { // choose r elements out of n elements
        if(r > n)   return mi(0);
        assert(r <= n);
        return fct[n] * invf[r] * invf[n - r];
    };

    auto place = [&](int n, int r) { // number of solutions to x1 + x2 + ... + xr = n with each xi >= 0 (stars and bars)
        assert(r > 0);
        return choose(n + r - 1, r - 1);
    };

    calc_fact();

    vector<mi> g(n + 1), p2(n + 1);
    vector<mi> p(n + 1);    p[0] = 1;

    p2[0] = 1;
    for(int i = 0 ; i < n ; i++)    p2[i + 1] = p2[i] + p2[i];

    for(int i = 1 ; i <= n ; ++i) {
        for(int j = 1 ; j <= i ; j++)
            p[j] = p[j - 1] * i;
        for(int j = 2 ; j <= i ; ++j) {
            mi here = choose(i, j) * fct[j - 1];
            if(i > j)  here *= j * p[i - j - 1];
            g[i] += here * p2[i - 1] * fct[i];
        }
    }

    vector<mi> dp(n + 1), invs(n + 1);
    for(int i = 1;  i <= n ; i++)   invs[i] = mi(1) / i;
    dp[0] = 1;


    for(int i = 1 ; i <= n ; i++) {
        for(int j = 1 ; j <= i ; j++) {
            dp[i] += dp[i - j] * g[j] * choose(i - 1, j - 1) * choose(i, j);
        }
    }

    cout << dp[n] << '\n';

    input.readEof();

    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;


// constexpr int mod = 1e9 + 7; // 1000000007
// constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
// using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int n; cin >> n >> mod;
    vector C(n+5, vector(n+5, Zp(0))), pows(n+5, vector(n+5, Zp(1)));
    vector fac(n+5, Zp(1));
    for (int i = 0; i < n+5; ++i) C[i][0] = 1;
    for (int i = 1; i < n+5; ++i) {
        fac[i] = i*fac[i-1];
        for (int j = 1; j < n+5; ++j)
            C[i][j] = C[i-1][j] + C[i-1][j-1];
    }
    for (int p = 0; p < n+5; ++p) for (int k = 1; k < n+5; ++k)
        pows[p][k] = pows[p][k-1] * p;

    vector<Zp> comp_ways(n+1);
    for (int k = 1; k <= n; ++k) {
        for (int x = 2; x <= k; ++x) {
            Zp ways = C[k][x] * fac[x-1];
            ways *= x;
            if (x < k) ways *= pows[k][k-x-1];
            else ways /= k;
            comp_ways[k] += ways * fac[k] * pows[2][k-1];
        }
    }

    vector<Zp> dp(n+1, 0);
    dp[0] = 1;
    for (int i = 1; i <= n; ++i) {
        for (int k = 1; k <= i; ++k) {
            dp[i] += dp[i-k] * comp_ways[k] * C[i-1][k-1] * C[i][k];
        }
    }
    cout << dp[n] << '\n';
}