PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Author: Nguyen Anh Quân
Testers: Shubham Anand Jain, Aryan
Editorialist: Nishank Suresh
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Interval handling, Graphs
PROBLEM:
N students stand in a line. Each of them will wear a shirt which can be in one of m colors. Further, some of the students are in one or more of C clubs, and every student in a club must wear a shirt of the same color. Students in a club try to stand together, and thus each club is represented by a collection of disjoint intervals. In how many different ways can the students choose the colors of their shirts?
QUICK EXPLANATION:
- Use a sweep line to merge all intervals into maximal disjoint intervals, recording which clubs fall inside each one.
- Create a graph on C vertices, with an edge between two vertices u and v if there is a student in both clubs. Find the number of connected components in this graph, say x.
- If there are y students not in any club, the answer is m^{x+y} modulo 998244353.
EXPLANATION:
If a student is in two clubs, say u and v, all students of both clubs must wear the same color. If there is also a student in clubs v and w, then all students in any of u, v, w must wear the same color.
The most natural way to represent this information is in the form of a graph.
Consider a graph with C vertices, with an edge between two vertices u and v if there is a student in both clubs. It is easy to see that if two vertices u and v are in the same connected component, all students of both corresponding clubs must wear the same color.
We are thus free to choose one color for each connected component, and then one color for each student who is not in any club (and is therefore unconstrained).
The problem then reduces to computing this graph, and the number of students who aren’t in any club, fast.
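Before optimizing, the counting claim can be sanity-checked on tiny inputs. The sketch below is a brute force and is not taken from any of the solutions in this editorial: it tries all m^N assignments and counts those in which every club wears a single color. The names `bruteForceCount`, `Interval` and `Club` are made up for this illustration.

```cpp
#include <cassert>
#include <utility>
#include <vector>

using Interval = std::pair<int, int>;   // closed interval of 1-indexed positions
using Club = std::vector<Interval>;     // a club is a collection of intervals

// Counts valid colorings by enumerating all m^n assignments (tiny inputs only).
long long bruteForceCount(int n, int m, const std::vector<Club>& clubs) {
    assert(n >= 1 && m >= 1);
    std::vector<int> color(n + 1, 0);   // color[s] for student s, index 0 unused
    long long valid = 0;
    while (true) {
        bool ok = true;
        for (const Club& club : clubs) {
            int want = -1;              // color of the first club member seen
            for (const auto& [l, r] : club) {
                assert(1 <= l && l <= r && r <= n);
                for (int s = l; s <= r; ++s) {
                    if (want == -1) want = color[s];
                    else if (color[s] != want) ok = false;
                }
            }
        }
        if (ok) ++valid;
        // Advance to the next assignment: a base-m counter over positions 1..n.
        int pos = 1;
        while (pos <= n && ++color[pos] == m) {
            color[pos] = 0;
            ++pos;
        }
        if (pos > n) break;             // counter wrapped around: all assignments seen
    }
    return valid;
}
```

For N = 6, m = 3 and three clubs with intervals [1, 2], [2, 4] and [6, 6], the first two clubs share student 2 and form one component, the third club forms another, and student 5 is in no club. So x = 2, y = 1, and the formula gives 3^3 = 27, which matches what the brute force counts.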
The Graph
A first idea would be to take each pair of intervals, and if they intersect, add an edge between the corresponding nodes in the graph.
However, this is too slow - there are up to 1000 \times 100 = 10^5 intervals, and this process is quadratic in the number of intervals.
Instead, we perform a line sweep, and find a maximal set of disjoint intervals.
Given the set of all intervals, sort them in increasing order of left endpoint (while also maintaining which interval corresponds to which club).
Now we ‘sweep’ from left to right by iterating over this sorted set of intervals.
Let the first interval be [l_1, r_1].
If l_2 \leq r_1, we can join [l_1, r_1] and [l_2, r_2] to get the larger [l_1, \max(r_1, r_2)] (because l_2 \geq l_1).
This joining process can continue as long as the next interval’s left endpoint is \leq the right endpoint of the current interval.
What happens if, for some i, l_i is larger than the right endpoint of the current interval?
It means that the i-th interval is disjoint from every interval considered so far - and because we are considering them in sorted order, every interval after the i-th is also disjoint from all intervals up to i-1.
So, our current interval is maximal - it cannot be extended further.
Note that, because of how we constructed this maximal interval, there is some ‘path’ between any pair of intervals in it; and hence, the clubs they represent must all belong to the same connected component in the graph.
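The merging step of the sweep can be sketched on its own, ignoring clubs for the moment. This is a minimal illustration under the assumption that intervals are closed and given as (left, right) pairs; the function name `mergeIntervals` is hypothetical and does not appear in the solutions.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Merges closed intervals [l, r] into maximal disjoint intervals.
// Two intervals are joined when the next left endpoint is <= the current right endpoint.
std::vector<std::pair<long long, long long>>
mergeIntervals(std::vector<std::pair<long long, long long>> intervals) {
    std::sort(intervals.begin(), intervals.end());   // increasing left endpoint
    std::vector<std::pair<long long, long long>> merged;
    for (const auto& [l, r] : intervals) {
        assert(l <= r);
        if (!merged.empty() && l <= merged.back().second) {
            // Overlaps the current maximal interval: extend its right endpoint.
            merged.back().second = std::max(merged.back().second, r);
        } else {
            // Disjoint from everything seen so far: the previous interval is
            // maximal, and a new one starts here.
            merged.push_back({l, r});
        }
    }
    return merged;
}
```

For example, the intervals [1, 3], [2, 5], [7, 8], [8, 10], [12, 12] merge into [1, 5], [7, 10], [12, 12].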
Now, say the clubs u_1, u_2, \dotsc, u_k are in the current maximal interval.
We would like them all to be in the same connected component of the graph.
However, adding an edge between each pair of them is once again too slow.
Note that we don’t actually care about the exact edges in the graph - we only care about connectivity information.
Thus, it suffices to add enough edges to connect these vertices in the graph - for example, one could add an edge between u_1 and each of u_2, u_3, \dotsc, u_k; or the edges (u_1, u_2), (u_2, u_3), \dotsc, (u_{k-1}, u_k).
Either way, only a linear number of edges are being added now, which is fast enough.
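One way to produce such a linear set of edges during the sweep is sketched below. It uses the first variant (every club in a maximal interval is linked to the club of that interval's first member). The function name `linkEdges` and the {l, r, club} layout are assumptions made for this illustration, with clubs numbered from 0.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <utility>
#include <vector>

// Each item is {l, r, club}. Returns edges between clubs such that all clubs
// appearing in the same maximal merged interval end up connected.
// At most one edge is emitted per input interval.
std::vector<std::pair<int, int>>
linkEdges(std::vector<std::array<long long, 3>> items) {
    std::sort(items.begin(), items.end());   // by left endpoint first
    std::vector<std::pair<int, int>> edges;
    long long right = 0;   // right endpoint of the current maximal interval
    int anchor = -1;       // club of the first interval in the current group
    for (const auto& [l, r, clubId] : items) {
        assert(l <= r);
        const int club = static_cast<int>(clubId);
        if (anchor != -1 && l <= right) {
            // Same maximal interval: extend it and link this club to the anchor.
            right = std::max(right, r);
            if (club != anchor) edges.push_back({anchor, club});
        } else {
            // A new maximal interval starts; its first club becomes the anchor.
            anchor = club;
            right = r;
        }
    }
    return edges;
}
```

The same club may be linked to the anchor more than once if it has several intervals in one group; the duplicates are harmless for connectivity and the total stays linear in the number of intervals.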
Clubless students
We get this information for free from the decomposition into disjoint intervals above.
Suppose the disjoint intervals are [L_1, R_1], [L_2, R_2], \dotsc, [L_k, R_k].
Then, any student not in one of these intervals is not in any club - and because they are disjoint, the number of such students is simply the total length of these intervals, subtracted from N, i.e., N - (R_1-L_1+1) - (R_2-L_2+1) - \dotsc - (R_k-L_k+1).
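As a small sketch of this subtraction, assuming the maximal disjoint intervals have already been computed (the function name `countClubless` is hypothetical):

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Given n students and already-disjoint closed intervals inside [1, n],
// returns how many students lie in none of the intervals.
long long countClubless(
        long long n,
        const std::vector<std::pair<long long, long long>>& disjoint) {
    long long covered = 0;
    for (const auto& [l, r] : disjoint) {
        assert(1 <= l && l <= r && r <= n);
        covered += r - l + 1;   // length of a closed interval
    }
    return n - covered;
}
```

With N = 15 and the disjoint intervals [1, 5], [7, 10], [12, 12], this gives 15 - 5 - 4 - 1 = 5 clubless students. A 64-bit type is used because N can be large.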
Once we have our graph, all that is left is to find the number of connected components in it, which can be done using any of DFS/BFS/DSU (full explanation at CP-Algorithms).
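A minimal DSU-based component count might look like the following. It is a sketch with hypothetical names (`DisjointSets`, `countComponents`), assuming clubs are numbered from 0 and the graph is given as an edge list; it uses path halving only, without union by size, to stay short.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Minimal disjoint-set union with path halving.
struct DisjointSets {
    std::vector<int> parent;
    explicit DisjointSets(int n) : parent(n) {
        std::iota(parent.begin(), parent.end(), 0);   // every vertex is its own root
    }
    int find(int v) {
        while (parent[v] != v) {
            parent[v] = parent[parent[v]];   // point v at its grandparent
            v = parent[v];
        }
        return v;
    }
    // Returns true if u and v were in different sets before the call.
    bool unite(int u, int v) {
        u = find(u);
        v = find(v);
        if (u == v) return false;
        parent[u] = v;
        return true;
    }
};

// Number of connected components of a graph on `vertices` nodes.
int countComponents(int vertices,
                    const std::vector<std::pair<int, int>>& edges) {
    DisjointSets dsu(vertices);
    int components = vertices;   // start with every vertex isolated
    for (const auto& [u, v] : edges) {
        assert(0 <= u && u < vertices && 0 <= v && v < vertices);
        if (dsu.unite(u, v)) --components;   // each successful merge removes one
    }
    return components;
}
```

For 5 clubs and the edges (0, 1), (0, 2), (0, 3), the clubs 0 to 3 form one component and club 4 is on its own, so the count is 2.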
Finally, knowing x and y, print m^{x+y} \bmod 998244353.
Note that y (the number of students not in any club) can be quite large, up to 10^9. So, calculating this power requires binary exponentiation (again, detailed explanation at CP-Algorithms).
TIME COMPLEXITY
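\mathcal{O}(K \log K) per test case, where K is the total number of intervals over all clubs: sorting the intervals dominates, and the final exponentiation adds only \mathcal{O}(\log(N + C)).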
SOLUTIONS:
Setter's Solution
#include<bits/stdc++.h>
using namespace std;
#define int long long
// A sweep event: one endpoint of an interval belonging to a club.
// type 0 = left endpoint, type 1 = right endpoint. At equal locations,
// left endpoints sort first, so intervals that touch count as overlapping.
struct state{
    int location, type, club;
    bool operator < (const state & t)const{
        if(location == t.location && type == t.type) return club < t.club;
        if(location == t.location) return type < t.type;
        return location < t.location;
    }
};
const int MAXC = 1e3 + 10;
const int mod = 998244353;
int q;
int c, m, n;
vector<state> a;
vector<int> adj[MAXC];
int num[MAXC];
int check[MAXC];
int connected_components;
// Binary exponentiation: x^y modulo mod.
int Pow(int x, int y){
    if(y == 0)return 1;
    int k = Pow(x, y / 2);
    if(y % 2 == 1)return k * k % mod * x % mod;
    return k * k % mod;
}
void dfs(int x){
    check[x] = 1;
    for(int u: adj[x]){
        if(check[u])continue;
        dfs(u);
    }
}
void reset(){
    a.clear();
    connected_components = 0;
    // adj and check are indexed by club, so clear 1..c (not 1..n)
    for(int i = 1; i <= c; i++){
        adj[i].clear();
        check[i] = 0;
    }
}
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> q;
    while(q--){
        cin >> c >> n >> m;
        for(int i = 1; i <= c; i++){
            cin >> num[i];
            for(int j = 1; j <= num[i]; j++){
                int l, r;
                cin >> l >> r;
                a.push_back({l, 0, i});
                a.push_back({r, 1, i});
            }
        }
        sort(a.begin(), a.end());
        // s holds the clubs that currently have an open interval
        set<int> s;
        int clubless = n;
        int L = 0;
        int R = 0;
        for(state u: a){
            if(u.type == 0){
                if(s.empty()){
                    // a new maximal interval starts here
                    L = u.location;
                }
                s.insert(u.club);
            }
            else{
                s.erase(s.find(u.club));
                if(!s.empty()){
                    // link the closing club to any club that is still open
                    int current_club = u.club;
                    int other_club = *s.begin();
                    adj[current_club].push_back(other_club);
                    adj[other_club].push_back(current_club);
                }
                if(s.empty()){
                    // the maximal interval [L, R] ends here
                    R = u.location;
                    clubless -= (R - L + 1);
                }
            }
        }
        for(int i = 1; i <= c; i++){
            if(!check[i]){
                connected_components++;
                dfs(i);
            }
        }
        connected_components += clubless;
        int ans = Pow(m, connected_components);
        cout << ans << '\n';
        reset();
    }
}
Tester's Solution
//By TheOneYouWant
#pragma GCC optimize ("-O2")
#include <bits/stdc++.h>
using namespace std;
#define fastio ios_base::sync_with_stdio(0);cin.tie(0)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(x) x.begin(),x.end()
#define forstl(i,v) for(auto &i: v)
const int LIM=2e5+5,MOD=998244353;
int link[LIM] = {0};
int sz[LIM] = {0};
int find(int x){
    if(x == link[x]) return x;
    return link[x] = find(link[x]);
}
void unite(int a, int b){
    a = find(a);
    b = find(b);
    if(a == b) return;
    if(sz[a]<sz[b]) swap(a,b);
    sz[a]+=sz[b];
    link[b] = a;
}
long long int fastpow(long long int a, long long int p){
    if(a == 0) return 0;
    if(p == 0) return 1;
    long long int z = fastpow(a, p/2);
    z = (z * z) % MOD;
    if(p % 2) z = (a * z) % MOD;
    return z;
}
int main(){
    fastio;
    int tests;
    cin>>tests;
    while(tests--){
        int c, n, m;
        cin>>c>>n>>m;
        for(int i = 0; i < c; i++){
            link[i] = i;
            sz[i] = 1;
        }
        vector<tuple<int,int,int>> v;
        vector<pair<int,int>> length[c];
        for(int i = 0; i < c; i++){
            int x; cin>>x;
            for(int j = 0; j < x; j++){
                int l, r;
                cin>>l>>r;
                l--; r--;
                length[i].push_back(make_pair(l, r));
                v.push_back({l, r, i});
            }
        }
        sort(all(v));
        pair<int,int> mx1 = mp(-1, -1), mx2 = mp(-1, -1);
        forstl(k, v){
            int l, r, col;
            tie(l, r, col) = k;
            if(mx1.fi >= l && mx1.se != col){
                unite(mx1.se, col);
            }
            else if(mx2.fi >= l && mx2.se != col){
                unite(mx2.se, col);
            }
            int rm = r, cm = col;
            if(rm >= mx1.fi){
                swap(mx1.fi, rm);
                swap(mx1.se, cm);
            }
            if(rm >= mx2.fi && cm != mx1.se){
                mx2.fi = rm;
                mx2.se = cm;
            }
        }
        // created the DSU
        // now need to do interval handling
        vector<int> child[c];
        for(int i = 0; i < c; i++){
            int par = find(i);
            child[par].pb(i);
        }
        long long int tot_cov = 0;
        long long int num = 0;
        for(int i = 0; i < c; i++){
            // merge all intervals, and calculate the total size
            set<pair<int,int>> interv;
            long long int tot = 0;
            forstl(r, child[i]){
                // merge intervals of r
                forstl(k, length[r]){
                    int l = k.fi, r = k.se;
                    long long int chg = 0;
                    // remove intervals that intersect with left part first
                    auto it = interv.upper_bound(mp(l, -1));
                    if(it != interv.begin()){
                        it--;
                        if((*it).se >= l){
                            l = min(l, (*it).fi);
                            r = max(r, (*it).se);
                            chg -= ((*it).se - (*it).fi + 1);
                            interv.erase(it);
                        }
                    }
                    // remove the rest now
                    while(true){
                        auto it = interv.upper_bound(mp(l, -1));
                        if(it != interv.end() && (*it).fi <= r){
                            l = min(l, (*it).fi);
                            r = max(r, (*it).se);
                            chg -= ((*it).se - (*it).fi + 1);
                            interv.erase(it);
                            continue;
                        }
                        break;
                    }
                    chg += (r-l+1);
                    tot += chg;
                    interv.insert(mp(l, r));
                }
            }
            tot_cov += tot;
            if(tot > 0) num++;
        }
        long long int fin = num + n - tot_cov;
        cout<<fastpow(m, fin)<<endl;
    }
    return 0;
}
Editorialist's Solution
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,mmx,avx,avx2")
using namespace std;
using ll = long long;
/*
    Standard DSU, path compression + union by size
    0-indexed
    Source: me
*/
struct DSU {
    vector<int> par;
    DSU(int n = 1): par(n, -1) {}
    int root(int u) {return par[u] < 0 ? u : par[u] = root(par[u]);}
    int size(int u) {return -par[root(u)];}
    bool merge(int u, int v) {
        u = root(u), v = root(v);
        if (u == v) return false;
        if (par[u] > par[v]) swap(u, v);
        par[u] += par[v], par[v] = u;
        return true;
    }
};
const int MOD = 998244353;
ll mpow(ll a, ll n)
{
    ll r = 1;
    while (n) {
        if (n&1) r = (r*a)%MOD;
        a = (a*a)%MOD;
        n >>= 1;
    }
    return r;
}
int main()
{
    ios::sync_with_stdio(0); cin.tie(0);
    mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
    int q; cin >> q;
    while (q--) {
        int c, n, m; cin >> c >> n >> m;
        vector<array<int, 3>> event;
        DSU D(c);
        for (int i = 0; i < c; ++i) {
            int x; cin >> x;
            for (int j = 0; j < x; ++j) {
                int l, r; cin >> l >> r;
                event.push_back({l, r, i});
            }
        }
        sort(begin(event), end(event));
        int rem = n, mx = -1, left = 0, comps = c;
        set<int> cur;
        for (auto [l, r, club] : event) {
            if (l <= mx) {
                mx = max(r, mx);
                cur.insert(club);
            }
            else {
                while (cur.size() > 1) {
                    int u = *cur.begin(); cur.erase(u);
                    int v = *cur.begin();
                    comps -= D.merge(u, v);
                }
                rem -= mx-left+1;
                left = l, mx = r;
                cur.clear(); cur.insert(club);
            }
        }
        while (cur.size() > 1) {
            int u = *cur.begin(); cur.erase(u);
            int v = *cur.begin();
            comps -= D.merge(u, v);
        }
        rem -= mx-left+1;
        cout << mpow(m, rem + comps) << '\n';
    }
}