# SPOOKYSEQ - Editorial

Author: varshil27
Tester: raysh07
Editorialist: iceknight1093

# DIFFICULTY:

2304

# PREREQUISITES:

DFS/BFS/DSU, combinatorics

# PROBLEM:

N people have M friendships among them. Each person also has a strength A_i.
Friendship is transitive: a friend of a friend is also a friend.

Find the number of orderings S of these people such that, for any two friends i and j with A_i \lt A_j, person i appears before person j in S. Output this count modulo 10^9 + 7.
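For very small inputs, the definition can be checked directly by trying every permutation. The sketch below is only meant for cross-checking a real solution (it is hopeless beyond roughly N = 8); the function name `brute_force_count` and the 0-indexed `edges`/`strength` arguments are illustrative assumptions, not the problem's input format.

```python
from itertools import permutations


def brute_force_count(n, edges, strength):
    """Count valid orders of people 0..n-1 by trying every permutation (tiny n only)."""
    # Union-find over the friendship edges: since friendship is transitive,
    # two people are friends iff they end up in the same set.
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in edges:
        parent[find(u)] = find(v)

    # Every ordered pair (i, j) of friends with A_i < A_j means "i before j".
    constraints = [
        (i, j)
        for i in range(n)
        for j in range(n)
        if i != j and find(i) == find(j) and strength[i] < strength[j]
    ]

    count = 0
    for order in permutations(range(n)):
        pos = [0] * n
        for idx, person in enumerate(order):
            pos[person] = idx
        if all(pos[i] < pos[j] for i, j in constraints):
            count += 1
    return count
```

For example, with three people where only 0 and 1 are friends and strengths are `[1, 2, 3]`, person 0 must precede person 1 and person 2 is free, giving 3 valid orders.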

# EXPLANATION:

The M friendships define an undirected graph among the N people.
Further, because friendship is transitive, it simply groups the people by connected component: i and j are friends with each other if and only if they're in the same connected component.

So, as a first step, find all connected components of the graph.
This can be done in a variety of ways: depth-first search, breadth-first search, or even a DSU.
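As one concrete option, here is a minimal BFS sketch that returns the components as lists of 0-indexed vertices. The name `connected_components` and the edge-list input are assumptions for illustration; BFS is written iteratively, which avoids Python's recursion limit on long paths.

```python
from collections import deque


def connected_components(n, edges):
    """Return the connected components of an undirected graph on vertices 0..n-1."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    seen = [False] * n
    components = []
    for start in range(n):
        if seen[start]:
            continue
        # A BFS started from an unvisited vertex collects exactly one component.
        seen[start] = True
        queue = deque([start])
        comp = []
        while queue:
            u = queue.popleft()
            comp.append(u)
            for v in adj[u]:
                if not seen[v]:
                    seen[v] = True
                    queue.append(v)
        components.append(comp)
    return components
```

For instance, `connected_components(5, [(0, 1), (1, 2), (3, 4)])` returns `[[0, 1, 2], [3, 4]]`, and a vertex with no edges forms a component on its own.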

Now, observe that:

• Two different components don't interfere with each other's orders at all, since every constraint involves two people from the same component.
• Within a single connected component, the people must be ordered by non-decreasing strength.
However, people of equal strength within this component can be ordered among themselves in any way.

Let's first find the number of arrangements for a single component.
This is not too hard: as noted above, the only freedom we have is to permute people who share the same strength.
So, if there are \text{ct}[x] people with strength x in a given component, they can be arranged among themselves in (\text{ct}[x])! ways, the factorial of \text{ct}[x].
So, for a single component, the total number of arrangements is the product of (\text{ct}[x])! across all x that appear in the component.
Once the members of the component are known, this is easy to compute: build a frequency table of the strengths in the component, then multiply together the corresponding factorials.
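A sketch of this per-component count, using exact integers (no modulus) to keep it readable; `component_arrangements` is a hypothetical name, and its argument is assumed to be the list of strengths of the people in one component.

```python
from collections import Counter
from math import factorial


def component_arrangements(strengths):
    """Valid internal orders of one component: product of ct[x]! over strengths x."""
    result = 1
    for ct in Counter(strengths).values():
        result *= factorial(ct)
    return result
```

For strengths `[4, 7, 4, 4, 7]` we have \text{ct}[4] = 3 and \text{ct}[7] = 2, so the result is 3! \cdot 2! = 12.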

Next, we need to think about interactions between components.
Suppose there are k components, with sizes s_1, s_2, \ldots, s_k.
Suppose we've also fixed an order of elements for each of the components.
How many ways do we have to arrange them into an overall order?

Recall that there's no interaction between components at all.
So, the only thing that matters is which positions are chosen by the different components.

Thinking of it differently, we have k types of balls, with s_i balls of the i-th type (so s_1 + s_2 + \ldots + s_k = N). We'd like to find the number of ways to arrange these balls in a line, where balls of the same type aren't distinguished.
That number is just

\frac{N!}{s_1! s_2! \ldots s_k!}

One way of seeing this is as follows:

• There are \binom{N}{s_1} ways to choose which s_1 positions the first type occupies.
• There are \binom{N-s_1}{s_2} ways to choose the positions of the second type.
• There are \binom{N-s_1-s_2}{s_3} ways for the third type, and so on, down to \binom{s_k}{s_k} = 1 way for the k-th type.
Multiplying these together, the factor (N - s_1 - \ldots - s_i)! in each denominator cancels against the numerator of the next binomial coefficient, leaving exactly the expression above.
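A small sketch checking that the two views agree: placing components one at a time with binomial coefficients (this is what the editorialist's loop does with `C(rem, len(comps[i]))`), versus the closed form. Both helper names are hypothetical, and exact integers are used instead of modular arithmetic.

```python
from math import comb, factorial


def merge_ways_binomials(sizes):
    """Place the components one at a time: C(N, s_1) * C(N - s_1, s_2) * ..."""
    remaining = sum(sizes)
    ways = 1
    for s in sizes:
        ways *= comb(remaining, s)
        remaining -= s
    return ways


def merge_ways_multinomial(sizes):
    """The closed form N! / (s_1! s_2! ... s_k!); every partial quotient is exact."""
    ways = factorial(sum(sizes))
    for s in sizes:
        ways //= factorial(s)
    return ways
```

With sizes `[2, 2, 1]` both give \binom{5}{2}\binom{3}{2}\binom{1}{1} = 10 \cdot 3 \cdot 1 = 30 = \frac{5!}{2!\,2!\,1!}.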

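Note that the answer is computed modulo the prime 10^9 + 7, and the formula involves division.
Division can't be done directly under a modulus; instead, multiply by the modular inverse. Because p = 10^9 + 7 is prime, the inverse of a is a^{p-2} \bmod p by Fermat's little theorem, so precomputing factorials and inverse factorials up to N lets every factor be looked up in \mathcal{O}(1).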
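A minimal sketch of the usual precomputation, mirroring the structure of `calcFact()` in the author's code: one Fermat inverse for the largest factorial, then a downward pass using (i-1)!^{-1} = i!^{-1} \cdot i. The names `precompute_factorials` and `merge_ways_mod` are illustrative assumptions.

```python
MOD = 10**9 + 7  # prime modulus used by all the reference solutions


def precompute_factorials(limit, mod=MOD):
    """fact[i] = i! mod p and inv_fact[i] = (i!)^(-1) mod p for 0 <= i <= limit."""
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % mod
    inv_fact = [1] * (limit + 1)
    # Fermat's little theorem: a^(p-2) is the inverse of a modulo a prime p.
    inv_fact[limit] = pow(fact[limit], mod - 2, mod)
    # (i-1)!^(-1) = i!^(-1) * i, so fill the table downwards.
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod
    return fact, inv_fact


def merge_ways_mod(sizes, fact, inv_fact, mod=MOD):
    """N! / (s_1! ... s_k!) modulo p, each division replaced by an inverse factorial."""
    ways = fact[sum(sizes)]
    for s in sizes:
        ways = ways * inv_fact[s] % mod
    return ways
```

The table only has to be built once (up to the maximum N) and can then be shared across all test cases.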

# TIME COMPLEXITY

\mathcal{O}(N\log N) per testcase.

# CODE:

Author's code (C++)
```cpp
//Code by Varshil Kavathiya

#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;

/*
****************************************************************************************************
*/

#define ll          long long
#define ld          long double
#define vll         vector<long long>
#define mll         map<long long,long long>
#define umll        unordered_map<ll,ll,custom_hash>
#define ss          second
#define ff          first
#define bs          binary_search
#define lb          lower_bound
#define ub          upper_bound
#define all(x)      x.begin(), x.end()
#define rep(i,n)    for(ll i=0;i<n;++i)
#define rep1(i,n)   for(ll i=1;i<n;++i)
#define tt          for(ll TT = 1; TT <= tc ; TT++)
#define pb          push_back
#define ppb         pop_back
#define mkp         make_pair
#define sqrt        sqrtl
#define cntSetBits  __builtin_popcountll
#define Tzeros      __builtin_ctzll
#define Lzeros      __builtin_clzll
#define endl        '\n'
#define iendl       '\n', cout<<flush
#define fast        ios_base::sync_with_stdio(false);cin.tie(NULL); cout.tie(NULL);
#define timetaken cerr<<fixed<<setprecision(10); cerr << "time taken : " << (float)clock() / CLOCKS_PER_SEC << " secs" << endl
const ll INF =      8e18;
const ll mod =      1000000007;
ll tc =             1;
const ll N =        200005;
const int dx[4] = { -1, 1, 0, 0}; const int dy[4] = {0, 0, -1, 1};
inline ll power(ll x, unsigned ll y, ll p = LLONG_MAX) {ll res = 1; x = x % p; if (x == 0) {return 0;} while (y > 0) { if (y & 1) {res = (res * x) % p;} y = y >> 1; x = (x * x) % p;} return res;} // CALCULATING POWER IN LOG(Y) TIME COMPLEXITY
inline ll inversePrimeModular(ll a, ll p) {return power(a, p - 2, p);}
ll mod_add(ll a, ll b, ll mod) {a = a % mod; b = b % mod; return (((a + b) % mod) + mod) % mod;}
ll mod_mul(ll a, ll b, ll mod) {a = a % mod; b = b % mod; return (((a * b) % mod) + mod) % mod;}
ll mod_sub(ll a, ll b, ll mod) {a = a % mod; b = b % mod; return (((a - b) % mod) + mod) % mod;}
struct custom_hash {static uint64_t splitmix64(uint64_t x) {x += 0x9e3779b97f4a7c15; x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9; x = (x ^ (x >> 27)) * 0x94d049bb133111eb; return x ^ (x >> 31);} size_t operator()(uint64_t x) const {static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count(); return splitmix64(x + FIXED_RANDOM);}};
ll gcd(ll a, ll b) {if (b > a) {return gcd(b, a);} if (b == 0) {return a;} return gcd(b, a % b);}
ll lcm(ll a, ll b) {return ((a / gcd(a, b)) * b);}
template<class T, class V>istream& operator>>(istream &in, pair<T, V> &a) {in >> a.ff >> a.ss; return in;}
template<class T>istream& operator>>(istream &in, vector<T> &a) {for (auto &i : a) {in >> i;} return in;}
template<class T, class V>ostream& operator<<(ostream &os, pair<T, V> &a) {os << a.ff << " " << a.ss; return os;}
template<class T>ostream& operator<<(ostream &os, vector<T> &a) {for (int i = 0 ; i < a.size() ; i++) {if (i != 0) {os << ' ';} os << a[i];} return os;}
#define ordered_set tree<ll, null_type,less<ll>, rb_tree_tag,tree_order_statistics_node_update>
// ifdef->hide & ifndef->show
#ifndef ONLINE_JUDGE
#include "debug.cpp"
#define dbg(x...) cerr << #x << ": "; __(x)
#else
#define dbg(x...)
#endif
/*
****************************************************************************************************
*/
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);

            assert(!(cnt > 19 || ( cnt == 19 && fi > 1) ));
        } else if (g == endd) {
            if (is_neg) {
                x = -x;
            }

            if (!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}
const ll fac_size = 500005;
vector<ll> fac(fac_size + 1);
vector<ll> inv(fac_size + 1);
void calcFact()
{
    fac[0] = 1;
    for (ll i = 1; i <= fac_size; i++)
    {
        fac[i] = mod_mul(fac[i - 1], i, mod);
    }

    inv[fac_size] = inversePrimeModular(fac[fac_size], mod);

    for (ll i = fac_size; i > 0; i--)
    {
        inv[i - 1] = mod_mul(inv[i], i, mod);
    }
}
// DFS: marks every vertex reachable from i and appends it to store
// (this helper's header, loop and call site were reconstructed from context)
void dfs(ll i, vector<vector<ll>> &adj, vector<ll> &vis, vector<ll> &store)
{
    vis[i] = 1;
    store.pb(i);

    for(auto&x:adj[i])
    {
        if(vis[x]==0)
        {
            dfs(x, adj, vis, store);
        }
    }
}
void solve()
{
    ll n = readInt(1, 2e5, ' ');
    ll m = readInt(0, 2e5, '\n');

    vector<vector<ll>> adj(n);
    vector<ll> store;
    rep(i,m)
    {
        ll a = readInt(1, 2e5, ' ');
        ll b = readInt(1, 2e5, '\n');
        a--;
        b--;
        adj[a].pb(b);
        adj[b].pb(a);
    }
    vector<ll> s(n);
    rep(i,n)
    {
        if(i==n-1)
        {
            s[i] = readInt(1, 1e9, '\n');
        }
        else
        {
            s[i] = readInt(1, 1e9, ' ');
        }
    }
    vector<ll> vis(n);
    ll ans = fac[n];
    rep(i,n)
    {
        if(vis[i]==0)
        {
            dfs(i, adj, vis, store);
            ans *= inv[store.size()];
            ans %= mod;
            map<ll, ll> cnt;
            for(auto&x:store)
            {
                cnt[s[x]]++;
            }
            for(auto&x:cnt)
            {
                ans *= fac[x.ss];
                ans %= mod;
            }
            store.clear();
        }
    }
    cout<<ans<<endl;
}
/*
****************************************************************************************************
*/

int32_t main()
{
    fast;
    cout << setprecision(30);
    calcFact();
    tc = readInt(1, 2e5, '\n'); // number of test cases (line reconstructed; bound assumed from sum of N <= 2e5)
    tt
    {
        // cout << "#Case: " << TT << endl;
        solve();
    }
    timetaken;
    return 0;
}
```

Tester's code (C++)
```cpp
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i + 1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i + 1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

const int facN = 1e6 + 5;
const int mod = 1e9 + 7; // 998244353
int ff[facN], iff[facN];
bool facinit = false;

int power(int x, int y){
    if (y == 0) return 1;

    int v = power(x, y / 2);
    v = 1LL * v * v % mod;

    if (y & 1) return 1LL * v * x % mod;
    else return v;
}

void factorialinit(){
    facinit = true;
    ff[0] = iff[0] = 1;

    for (int i = 1; i < facN; i++){
        ff[i] = 1LL * ff[i - 1] * i % mod;
    }

    iff[facN - 1] = power(ff[facN - 1], mod - 2);
    for (int i = facN - 2; i >= 1; i--){
        iff[i] = 1LL * iff[i + 1] * (i + 1) % mod;
    }
}

int C(int n, int r){
    if (!facinit) factorialinit();

    if (n == r) return 1;

    if (r < 0 || r > n) return 0;
    return 1LL * ff[n] * iff[r] % mod * iff[n - r] % mod;
}

int P(int n, int r){
    if (!facinit) factorialinit();

    assert(0 <= r && r <= n);
    return 1LL * ff[n] * iff[n - r] % mod;
}

int Solutions(int n, int r){
    //solutions to x1 + ... + xn = r
    //xi >= 0

    return C(n + r - 1, n - 1);
}

input_checker inp;
int sum_n = 0, sum_m = 0;

void Solve()
{
    // input-reading lines below were reconstructed from context
    int n = inp.readInt(1, (int)2e5); inp.readSpace();
    int m = inp.readInt(0, (int)2e5); inp.readEoln();

    assert(2 * m <= n * (n - 1));

    sum_n += n;
    sum_m += m;

    assert(sum_n <= (int)2e5);
    assert(sum_m <= (int)2e5);

    vector<vector<int>> adj(n);
    set <pair<int, int>> st;

    for (int i = 1; i <= m; i++){
        int u, v;
        u = inp.readInt(1, n); inp.readSpace();
        v = inp.readInt(1, n); inp.readEoln();

        assert(u != v);
        assert(st.find({u, v}) == st.end());
        st.insert({u, v});
        st.insert({v, u});

        u--;
        v--;

        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    auto a = inp.readInts(n, 1, (int)1e9);
    inp.readEoln();

    int ans = ff[n];

    vector <bool> vis(n, false);
    for (int i = 0; i < n; i++){
        if (!vis[i]){
            queue <int> q;
            q.push(i);
            vis[i] = true;
            vector <int> b;
            b.push_back(a[i]);

            while (!q.empty()){
                int u = q.front(); q.pop();

                for (int v : adj[u]) if (!vis[v]){
                    q.push(v);
                    vis[v] = true;
                    b.push_back(a[v]);
                }
            }

            map <int, int> mp;
            for (auto x : b) mp[x]++;

            for (auto [x, y] : mp){
                ans *= ff[y]; ans %= mod;
            }

            ans *= iff[(int)b.size()];
            ans %= mod;
        }
    }

    cout << ans << "\n";
}

int32_t main()
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);

    t = inp.readInt(1, (int)2e5); // number of test cases (line reconstructed; bound assumed from sum of N <= 2e5)
    inp.readEoln();

    factorialinit();

    for(int i = 1; i <= t; i++)
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }

    inp.readEof();

    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
```

Editorialist's code (Python)
```python
class DisjointSetUnion:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.num_sets = n

    def find(self, a):
        acopy = a
        while a != self.parent[a]:
            a = self.parent[a]
        while acopy != a:
            self.parent[acopy], acopy = a, self.parent[acopy]
        return a

    def union(self, a, b):
        a, b = self.find(a), self.find(b)
        if a != b:
            if self.size[a] < self.size[b]:
                a, b = b, a

            self.num_sets -= 1
            self.parent[b] = a
            self.size[a] += self.size[b]

    def set_size(self, a):
        return self.size[self.find(a)]

    def __len__(self):
        return self.num_sets

mod = 10**9 + 7
maxn = 2*10**5 + 5
fac = [1]*(maxn)
for i in range(2, maxn): fac[i] = fac[i-1] * i % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * pow(fac[r] * fac[n-r], mod-2, mod) % mod

for _ in range(int(input())):
    n, m = map(int, input().split())
    D = DisjointSetUnion(n)
    for i in range(m):
        u, v = map(int, input().split())
        D.union(u-1, v-1)
    a = list(map(int, input().split()))
    comps = [ [] for _ in range(n)]
    for i in range(n): comps[D.find(i)].append(a[i])
    rem = n
    ans = 1
    for i in range(n):
        freq = {}
        for x in comps[i]:
            if x not in freq: freq[x] = 0
            freq[x] += 1
        ans = ans * C(rem, len(comps[i])) % mod
        rem -= len(comps[i])
        for x in freq.values():
            ans = ans * fac[x] % mod
    print(ans)
```


Why am I getting WA?
https://www.codechef.com/viewsolution/1027210120

Your solution is almost perfect; there's just one mistake in the implementation.
On line 368, par[i] is the parent of i in the union-find tree, but it isn't necessarily the leader (root) of i's group.
For example, if par[0] = 1, par[1] = 2 and par[2] = 2, the leader (root) of the group containing 0 is 2, but par[0] is 1.
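The linked submission isn't reproduced here, so this is only a sketch assuming `par` is a plain parent array like the one in the reply; `find` is a hypothetical helper. Grouping people by `find(par, i)` rather than by `par[i]` gives the actual components.

```python
def find(par, x):
    """Follow parent pointers until reaching the root (the leader of x's group)."""
    while par[x] != x:
        x = par[x]
    return x


# The example from the reply: 0 -> 1 -> 2, and 2 is its own parent.
par = [1, 2, 2]
print(par[0])        # 1: only the direct parent
print(find(par, 0))  # 2: the actual leader of the group containing 0
```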


Why is this giving WA? my_code


Your func(n, r) function involves division, which can't be done directly under modulo.
This is mentioned in the last section of the editorial: