UNICOLOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Nguyen Anh Quân
Testers: Shubham Anand Jain, Aryan
Editorialist: Nishank Suresh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Interval handling, Graphs

PROBLEM:

N students stand in a line. Each of them wears a shirt in one of m colors. Some students belong to one or more of C clubs, and all students in the same club must wear the same color. Students in a club try to stand together, so each club is given as a collection of disjoint intervals of positions. In how many different ways can the students choose their shirt colors?

QUICK EXPLANATION:

  • Use a sweep line to merge all intervals into maximal disjoint intervals, recording which clubs fall inside each one.
  • Create a graph on C vertices, with an edge between two vertices u and v if there is a student in both clubs. Find the number of connected components in this graph, say x.
  • If there are y students not in any club, the answer is m^{x+y} modulo 998244353.

EXPLANATION:

If a student is in two clubs, say u and v, all students of both clubs must wear the same color. If there is also a student in clubs v and w, then all students in any of u, v, w must wear the same color.
The most natural way to represent this information is in the form of a graph.
Consider a graph with C vertices, with an edge between two vertices u and v if there is a student in both clubs. It is easy to see that if two vertices u and v are in the same connected component, all students of both corresponding clubs must wear the same color.
We are thus free to choose one color for each connected component, and then one color for each student who is not in any club (and is therefore unconstrained).
The problem then reduces to computing this graph, and the number of students who aren’t in any club, fast.
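
For very small inputs, the claim that the answer is m^{x+y} can be sanity-checked by brute force. The sketch below is not taken from any of the reference solutions: the name `bruteForceCount` and its parameter layout are assumptions made for illustration, and it is only usable when m^N is tiny.

```cpp
#include <utility>
#include <vector>

// Hypothetical checker: counts valid colorings by trying all m^n assignments.
// clubs[i] holds the closed, 1-indexed intervals [l, r] of club i.
long long bruteForceCount(int n, int m,
                          const std::vector<std::vector<std::pair<int, int>>>& clubs) {
    std::vector<int> color(n, 0);  // color[s] is the color of student s + 1
    long long valid = 0;
    while (true) {
        bool ok = true;
        for (const auto& club : clubs) {
            int required = -1;  // color of the first member seen in this club
            for (const auto& iv : club) {
                for (int s = iv.first; s <= iv.second; ++s) {
                    if (required == -1) required = color[s - 1];
                    else if (color[s - 1] != required) ok = false;
                }
            }
        }
        if (ok) ++valid;
        // Advance to the next assignment, treating `color` as a base-m counter.
        int i = 0;
        while (i < n && ++color[i] == m) {
            color[i] = 0;
            ++i;
        }
        if (i == n) break;  // the counter wrapped around: every assignment was tried
    }
    return valid;
}
```

With N = 4, m = 3 and two clubs covering [1, 2] and [2, 3], students 1 to 3 are forced to share a color and student 4 is free, so x = 1, y = 1 and the function returns 3^2 = 9.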

The Graph

A first idea would be to take each pair of intervals, and if they intersect, add an edge between the corresponding nodes in the graph.
However, this is too slow: there are up to 1000 \times 100 = 10^5 intervals, and this process is quadratic in the number of intervals.

Instead, we perform a line sweep, and find a maximal set of disjoint intervals.
Given the set of all intervals, sort them in increasing order of left endpoint (while also maintaining which interval corresponds to which club).
Now we ‘sweep’ from left to right by iterating over this sorted set of intervals.
Let the first interval be [l_1, r_1].
If l_2 \leq r_1, we can join [l_1, r_1] and [l_2, r_2] to get the larger interval [l_1, \max(r_1, r_2)] (its left endpoint stays l_1 because l_2 \geq l_1).
This joining process can continue as long as the next interval’s left endpoint is \leq the right endpoint of the current interval.
What happens if, for some i, l_i is larger than the right endpoint of the current interval?
It means that the i-th interval is disjoint from every interval considered so far, and because we are considering them in sorted order, every interval after the i-th is also disjoint from all intervals up to i-1.
So, our current interval is maximal - it cannot be extended further.
Note that, because of how we constructed this maximal interval, there is some ‘path’ between any pair of intervals in it; and hence, the clubs they represent must all belong to the same connected component in the graph.
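
The sweep described above can be sketched as follows. This is a minimal illustration that ignores club labels; the function name `mergeIntervals` is an assumption and does not appear in the reference solutions.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Merges closed intervals [l, r] into maximal disjoint intervals.
// Two intervals are joined only if they share at least one position (l <= reach).
std::vector<std::pair<long long, long long>> mergeIntervals(
        std::vector<std::pair<long long, long long>> iv) {
    std::sort(iv.begin(), iv.end());  // increasing left endpoint
    std::vector<std::pair<long long, long long>> merged;
    for (const auto& cur : iv) {
        if (!merged.empty() && cur.first <= merged.back().second) {
            // Overlaps the current maximal interval: extend its right endpoint.
            merged.back().second = std::max(merged.back().second, cur.second);
        } else {
            // Disjoint from everything so far: start a new maximal interval.
            merged.push_back(cur);
        }
    }
    return merged;
}
```

Note that [1, 3] and [4, 6] stay separate here: they are adjacent but share no student, so they impose no common constraint.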

Now, say the clubs u_1, u_2, \dotsc, u_k are in the current maximal interval.
We would like them all to be in the same connected component of the graph.
However, adding an edge between each pair of them is once again too slow.
Note that we don’t actually care about the exact edges in the graph - we only care about connectivity information.
Thus, it suffices to add enough edges to connect these vertices in the graph - for example, one could add an edge between u_1 and each of u_2, u_3, \dotsc, u_k; or the edges (u_1, u_2), (u_2, u_3), \dotsc, (u_{k-1}, u_k).
Either way, only a linear number of edges are being added now, which is fast enough.
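
As an illustration of the first option (an edge from u_1 to every other club in the same maximal interval), here is a hedged sketch. The struct `ClubInterval` and the function `starEdges` are hypothetical names, and clubs are assumed to be 0-indexed.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Hypothetical record: a closed interval [l, r] belonging to `club`.
struct ClubInterval {
    long long l, r;
    int club;
};

// Returns edges (u_1, u_i) linking each club in a maximal interval to the
// first club seen in that interval. At most one edge is added per input interval.
std::vector<std::pair<int, int>> starEdges(std::vector<ClubInterval> iv) {
    std::sort(iv.begin(), iv.end(),
              [](const ClubInterval& a, const ClubInterval& b) { return a.l < b.l; });
    std::vector<std::pair<int, int>> edges;
    long long reach = 0;  // right endpoint of the current maximal interval
    int first = -1;       // club u_1 of the current maximal interval, -1 if none yet
    for (const auto& cur : iv) {
        if (first != -1 && cur.l <= reach) {
            reach = std::max(reach, cur.r);
            if (cur.club != first) edges.push_back({first, cur.club});
        } else {
            first = cur.club;  // a new maximal interval starts here
            reach = cur.r;
        }
    }
    return edges;
}
```

Duplicate edges can appear when a club has several intervals inside one maximal interval; they do not affect connectivity, so the sketch does not bother removing them.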

Clubless students

We get this information for free from the decomposition into disjoint intervals, from above!
Suppose the disjoint intervals are [L_1, R_1], [L_2, R_2], \dotsc, [L_k, R_k].
Then, any student not in one of these intervals is not in any club. Because the intervals are disjoint, the number of such students is simply N minus their total length, i.e., N - (R_1-L_1+1) - (R_2-L_2+1) - \dotsc - (R_k-L_k+1).
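
In code, this count is a single pass over the merged intervals. The sketch below assumes they are already disjoint and given as closed [L_i, R_i] pairs; `countClubless` is a hypothetical helper name.

```cpp
#include <utility>
#include <vector>

// Number of students covered by no interval, given disjoint closed intervals.
// 64-bit arithmetic is used because n can be as large as 10^9.
long long countClubless(long long n,
                        const std::vector<std::pair<long long, long long>>& merged) {
    long long covered = 0;
    for (const auto& iv : merged) {
        covered += iv.second - iv.first + 1;  // length of [L_i, R_i]
    }
    return n - covered;
}
```

For N = 10 with maximal intervals [2, 3] and [7, 9], this gives 10 - 2 - 3 = 5 clubless students.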

Once we have our graph, all that is left is to find the number of connected components in it, which can be done using any of DFS/BFS/DSU (full explanation at CP-Algorithms).
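
For reference, one way to count components is an iterative DFS over an adjacency list. This is only a sketch with a hypothetical function name, `countComponents`, taking 0-indexed clubs and an edge list; a DSU works equally well.

```cpp
#include <utility>
#include <vector>

// Counts connected components of an undirected graph on vertices 0..c-1.
int countComponents(int c, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(c);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    std::vector<bool> seen(c, false);
    int components = 0;
    for (int start = 0; start < c; ++start) {
        if (seen[start]) continue;
        ++components;  // `start` is the first vertex found in a new component
        std::vector<int> stack;
        stack.push_back(start);
        seen[start] = true;
        while (!stack.empty()) {
            int u = stack.back();
            stack.pop_back();
            for (int v : adj[u]) {
                if (!seen[v]) {
                    seen[v] = true;
                    stack.push_back(v);
                }
            }
        }
    }
    return components;
}
```

An explicit stack is used instead of recursion only to keep the sketch independent of recursion depth; with at most 1000 clubs, a recursive DFS is also fine.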

Finally, knowing x and y, print m^{x+y} \bmod 998244353.
Note that y (the number of students not in any club) can be quite large, up to 10^9, so computing this power requires binary exponentiation (again, a detailed explanation is at CP-Algorithms).
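
A minimal iterative version of binary exponentiation is sketched below. The names `MOD` and `modPow` are chosen for illustration, and the exponent is assumed to be non-negative.

```cpp
#include <cstdint>

const std::int64_t MOD = 998244353;

// Computes base^exp modulo MOD in O(log exp) multiplications.
std::int64_t modPow(std::int64_t base, std::int64_t exp) {
    std::int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;  // this bit of exp is set
        base = base * base % MOD;                   // square for the next bit
        exp >>= 1;
    }
    return result;
}
```

For example, `modPow(2, 10)` returns 1024. Every intermediate product is below MOD^2 < 2^60, so it fits in a signed 64-bit integer.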

TIME COMPLEXITY

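\mathcal{O}(K \log K) per test case, where K \leq 10^5 is the total number of intervals (not the number of students N, which can be as large as 10^9). The sort dominates.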

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
#define int long long
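// One sweep event: an interval endpoint at `location`.
// type 0 = an interval of `club` opens, type 1 = it closes.
struct state{
    int location, type, club;
    // At equal locations, opens sort before closes, so intervals that touch
    // at a single student are still treated as overlapping.
    bool operator < (const state & t)const{
        if(location == t.location && type == t.type) return club < t.club;
        if(location == t.location) return type < t.type;
        return location < t.location;
    }
};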
const int MAXC = 1e3 + 10;
const int mod = 998244353;
int q;
int c, m, n;
vector<state> a;
vector<int> adj[MAXC];
int num[MAXC];
int check[MAXC];
int connected_components;
int Pow(int x, int y){
    if(y == 0)return 1;
    int k = Pow(x, y / 2);
    if(y % 2 == 1)return k * k % mod * x % mod;
    return k * k % mod;
}
void dfs(int x){
    check[x] = 1;
    for(int u: adj[x]){
        if(check[u])continue;
        dfs(u);
    }
}
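void reset(){
    a.clear();
    connected_components = 0;
    // adj and check are indexed by club (1..c), not by student (1..n);
    // n can be as large as 10^9, so looping up to n would run out of bounds.
    for(int i = 1; i <= c; i++){
        adj[i].clear();
        check[i] = 0;
    }
}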
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> q;
    while(q--){
        cin >> c >> n >> m;
        for(int i = 1; i <= c; i++){
            cin >> num[i];
            for(int j = 1; j <= num[i]; j++){
                int l, r;
                cin >> l >> r;
                a.push_back({l, 0, i});
                a.push_back({r, 1, i});
            }
        }
        sort(a.begin(), a.end());
        set<int> s;
        int clubless = n;
        int L = 0;
        int R = 0;
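        for(state u: a){
            if(u.type == 0){
                // An interval opens; if nothing was open, a new maximal interval starts here.
                if(s.empty()){
                    L = u.location;
                }
                s.insert(u.club);
            }
            else{
                // An interval closes: link its club to some club that is still open.
                s.erase(s.find(u.club));
                if(!s.empty()){
                    int current_club = u.club;
                    int other_club = *s.begin();
                    adj[current_club].push_back(other_club);
                    adj[other_club].push_back(current_club);
                }
                // Nothing is open any more: the maximal interval [L, R] is complete.
                if(s.empty()){
                    R = u.location;
                    clubless -= (R - L + 1);
                }
            }
        }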
        for(int i = 1; i <= c; i++){
            if(!check[i]){
                connected_components++;
                dfs(i);
            }
        }
        connected_components += clubless;
        int ans = Pow(m, connected_components);
        cout << ans << '\n';
        reset();
    }
}
Tester's Solution
//By TheOneYouWant
#pragma GCC optimize ("-O2")
#include <bits/stdc++.h>
using namespace std;
#define fastio ios_base::sync_with_stdio(0);cin.tie(0)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(x) x.begin(),x.end()
#define forstl(i,v) for(auto &i: v)
 
const int LIM=2e5+5,MOD=998244353;
 
int link[LIM] = {0};
int sz[LIM] = {0};
 
int find(int x){
    if(x == link[x]) return x;
    return link[x] = find(link[x]);
}
 
void unite(int a, int b){
    a = find(a);
    b = find(b);
    if(a == b) return;
    if(sz[a]<sz[b]) swap(a,b);
    sz[a]+=sz[b];
    link[b] = a;
}
 
long long int fastpow(long long int a, long long int p){
    if(a == 0) return 0;
    if(p == 0) return 1;
    long long int z = fastpow(a, p/2);
    z = (z * z) % MOD;
    if(p % 2) z = (a * z) % MOD;
    return z;
}
 
int main(){
    fastio;
 
    int tests;
    cin>>tests;
 
    while(tests--){
        int c, n, m;
        cin>>c>>n>>m;
 
        for(int i = 0; i < c; i++){
            link[i] = i;
            sz[i] = 1;
        }
 
        vector<tuple<int,int,int>> v;
        vector<pair<int,int>> length[c];
 
        for(int i = 0; i < c; i++){
            int x; cin>>x;
            for(int j = 0; j < x; j++){
                int l, r; 
                cin>>l>>r;
                l--; r--;
                length[i].push_back(make_pair(l, r));
                v.push_back({l, r, i});
            }
        }
        sort(all(v));
 
        pair<int,int> mx1 = mp(-1, -1), mx2 = mp(-1, -1);
        forstl(k, v){
            int l, r, col;
            tie(l, r, col) = k;
            if(mx1.fi >= l && mx1.se != col){
                unite(mx1.se, col);
            }
            else if(mx2.fi >= l && mx2.se != col){
                unite(mx2.se, col);
            }
            int rm = r, cm = col;
            if(rm >= mx1.fi){
                swap(mx1.fi, rm);
                swap(mx1.se, cm);
            }
            if(rm >= mx2.fi && cm != mx1.se){
                mx2.fi = rm;
                mx2.se = cm;
            }
        }
        // created the DSU
        // now need to do interval handling
        vector<int> child[c];
 
        for(int i = 0; i < c; i++){
            int par = find(i);
            child[par].pb(i);
        }
        long long int tot_cov = 0;
        long long int num = 0;
        for(int i = 0; i < c; i++){
            // merge all intervals, and calculate the total size
            set<pair<int,int>> interv;
            long long int tot = 0;
            forstl(r, child[i]){
                // merge intervals of r
                forstl(k, length[r]){
                    int l = k.fi, r = k.se;
                    long long int chg = 0;
                    // remove intervals that intersect with left part first
                    auto it = interv.upper_bound(mp(l, -1));
                    if(it != interv.begin()){
                        it--;
                        if((*it).se >= l){
                            l = min(l, (*it).fi);
                            r = max(r, (*it).se);
                            chg -= ((*it).se - (*it).fi + 1);
                            interv.erase(it);
                        }
                    }
                    // remove the rest now
                    while(true){
                        auto it = interv.upper_bound(mp(l, -1));
                        if(it != interv.end() && (*it).fi <= r){
                            l = min(l, (*it).fi);
                            r = max(r, (*it).se);
                            chg -= ((*it).se - (*it).fi + 1);
                            interv.erase(it);
                            continue;
                        }
                        break;
                    }
                    chg += (r-l+1);
                    tot += chg;
                    interv.insert(mp(l, r));
                }
            }
            tot_cov += tot;
            if(tot > 0) num++;
        }
        long long int fin = num + n - tot_cov;
        cout<<fastpow(m, fin)<<endl;
    }
 
    return 0;
}
Editorialist's Solution
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,mmx,avx,avx2")
using namespace std;
using ll = long long;
 
/*
Standard DSU, path compression + union by size
0-indexed
Source: me
*/
struct DSU {
    vector<int> par;
    DSU(int n = 1): par(n, -1) {}
    int root(int u) {return par[u] < 0 ? u : par[u] = root(par[u]);}
    int size(int u) {return -par[root(u)];}
    bool merge(int u, int v) {
        u = root(u), v = root(v);
        if (u == v) return false;
        if (par[u] > par[v]) swap(u, v);
        par[u] += par[v], par[v] = u;
        return true;
    }
};
 
const int MOD = 998244353;
 
ll mpow(ll a, ll n)
{
    ll r = 1;
    while (n) {
        if (n&1) r = (r*a)%MOD;
        a = (a*a)%MOD;
        n >>= 1;
    }
    return r;
}
 
 
int main()
{
    ios::sync_with_stdio(0); cin.tie(0);
    mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
 
 
    int q; cin >> q;
    while (q--) {
        int c, n, m; cin >> c >> n >> m;
        vector<array<int, 3>> event;
        DSU D(c);
        for (int i = 0; i < c; ++i) {
            int x; cin >> x;
            for (int j = 0; j < x; ++j) {
                int l, r; cin >> l >> r;
                event.push_back({l, r, i});
            }
        }
        sort(begin(event), end(event));
        int rem = n, mx = -1, left = 0, comps = c;
        set<int> cur;
        for (auto [l, r, club] : event) {
            if (l <= mx) {
                mx = max(r, mx);
                cur.insert(club);
            }
            else {
                while (cur.size() > 1) {
                    int u = *cur.begin(); cur.erase(u);
                    int v = *cur.begin();
                    comps -= D.merge(u, v);
                }
                rem -= mx-left+1;
                left = l, mx = r;
                cur.clear(); cur.insert(club);
            }
        }
        while (cur.size() > 1) {
            int u = *cur.begin(); cur.erase(u);
            int v = *cur.begin();
            comps -= D.merge(u, v);
        }
        rem -= mx-left+1;
        cout << mpow(m, rem + comps) << '\n';
    }
}

Thanks for this decent problem! Kudos to the author.

Hey, I used a disjoint set union to get the number of connected components, but it gives TLE on Subtask 2. Can you check what I am missing? Thank you!
https://www.codechef.com/viewsolution/44202978

Hi, I did the algorithm as described, but got WA in whole second set. Where did I go wrong? Thanks.

https://www.codechef.com/viewsolution/44219270

You are creating a graph on the set of students and finding connected components in it. Take a look at the constraints: there can be up to 10^9 students, so that won’t work.
The intended solution is to create a graph on the set of clubs instead, of which there are at most 1000.

@iceknight1093 I have done coordinate compression + DSU. Can you please help me understand why I am getting TLE?
https://www.codechef.com/viewsolution/44210405

Oh yes! But let’s say I merge sets that have the same clubs to find x. How do I find y with a DSU implementation?
Or maybe I’m missing something entirely?

Good question. Had fun solving it!

Can someone please give links to more problems based on intervals?

Your computation of cnt doesn’t look right to me.
Try the following case

1
3 5 2
1
1 5
1
2 2
1
3 3

The answer should be 2, you print 4.

You also print 4 on the test case I shared in my previous comment, when the answer is supposed to be 2. Maybe you can use that to debug?

Assuming you implement things properly, you get y almost for free when computing x - please do read through the editorial, I’ve tried to explain in some detail.

Yeah, thanks very much. I’ve solved that issue, but my new fix still seems to be incorrect. I’ll try to figure it out though.

I tried using DSU on clubs, but I’m getting WA. Could someone suggest test cases so that I can debug my code?
The link to my code is CodeChef: Practical coding for everyone. Can you check what I am missing?

Why am I getting TLE in subtask 2? I’m constructing a graph of size at most 10^5 (C*x) and doing a BFS traversal to find non-intersecting clubs.
The code is here: Solution: 44236788 | CodeChef. Its time complexity is O(c*x log(c*x)).

There were several issues with your code:

  1. There was a useless vi club(n + 1, 0), which is bad when n can be as large as 10^9
  2. Your mul(x, y) function is too slow. while (prod >= MOD) doesn’t really work well because you might have to do \mathcal{O}(MOD) subtractions to get the result down - just take the mod directly instead.
  3. Your power(x, y) function is technically wrong as well, but it gets saved by C++ being smart.
  4. Fixing the above will make your code run fast enough, but it gets WA because you’re missing a case.
Try this input
1
1 4 2
1
1 1

A couple of my latest submissions fix your code, take a look at them if you like.
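
To illustrate point 2, taking the mod directly might look like the sketch below. The name `mulMod` is hypothetical and is not the function from the submission; inputs are assumed to be non-negative and the modulus below 2^31.

```cpp
#include <cstdint>

// Multiplies x and y modulo mod with a single % on the product,
// instead of repeatedly subtracting mod.
std::int64_t mulMod(std::int64_t x, std::int64_t y, std::int64_t mod) {
    // Both factors are reduced below mod first, so the product stays below
    // mod^2, which fits in 64 bits whenever mod < 2^31.
    return (x % mod) * (y % mod) % mod;
}
```

This takes constant time per call, whereas a `while (prod >= MOD)` loop may need a number of subtractions proportional to MOD.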

Those were my mistakes.
Thank you so much! After fixing them I got fully accepted. Thanks for helping.

You forgot to account for the students not in a club, before the first interval.

Yep found it…Thanks for the help. @iceknight1093

https://www.codechef.com/viewsolution/44272106

@iceknight1093 can you please explain why I am getting RE on subtask 2?
I have used a DSU with both path compression and union by rank. After doing the operations, if par[node] is negative it signifies that the node is the root of a connected component, so the count is increased if par[node] < 0.