COWSHEDS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Nghia Pham
Tester: Trung Nguyen
Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

Segment Tree, Hashing, Disjoint Set Union and Small to Large.

PROBLEM

There are N nodes numbered from 1 to N. A group is a maximal set of nodes that are reachable from each other, either directly or through other nodes.

Initially, all N nodes are disconnected. For the next Q days, Kuroni adds connections as follows:

  • For each day, choose an interval [L_i, R_i]
  • For each k \in [0, R_i-L_i], add a direct connection between node L_i+k and node R_i-k

Print the number of groups after each day.
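To check a fast solution on small inputs, you can simulate the statement directly. The sketch below uses a hypothetical helper `bruteForce` that is not part of any official solution. It assumes 1-based nodes and intervals, as in the statement, and runs a plain DSU. A group disappears only when `unite` actually joins two different components. The cost is proportional to the total length of all intervals, so this only suits small tests.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Plain DSU with path compression (small inputs only).
struct DSU {
    std::vector<int> parent;
    explicit DSU(int n) : parent(n) { std::iota(parent.begin(), parent.end(), 0); }
    int find(int x) { return parent[x] == x ? x : parent[x] = find(parent[x]); }
    // Returns true only if x and y were in different groups (a real merge).
    bool unite(int x, int y) {
        x = find(x);
        y = find(y);
        if (x == y) return false;
        parent[x] = y;
        return true;
    }
};

// Nodes are 1..n; each query is a 1-based interval [L, R].
std::vector<int> bruteForce(int n, const std::vector<std::pair<int, int>>& queries) {
    DSU dsu(n + 1);  // index 0 is unused
    int groups = n;
    std::vector<int> answers;
    for (const auto& q : queries) {
        int L = q.first, R = q.second;
        assert(1 <= L && L <= R && R <= n);
        // Pairs (L+k, R-k) are symmetric, so visiting up to the middle suffices.
        for (int k = 0; L + k <= R - k; ++k)
            if (dsu.unite(L + k, R - k)) --groups;
        answers.push_back(groups);
    }
    return answers;
}
```

For example, with N = 5 and days [1, 5], [1, 3], [2, 5], the group counts are 3, 2, 1.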

QUICK EXPLANATION

  • Two different groups can be merged at most N-1 times in total, since every such merge reduces the number of groups by one.
  • For each day, we try to find the first position L_i+k in the range [L_i, R_i] that is not connected to node R_i-k. To do this, build segment trees over the array of DSU roots, compare forward and reverse hashes, and binary search for k. Repeat until no such k exists.
  • For each DSU root, maintain the list of nodes having that root. Use the small-to-large trick to merge these lists and to update the hashes stored in the segment trees.

EXPLANATION

The first observation is that, among all connections added over all days, at most N-1 can reduce the number of connected components. We start with N components, and each such connection merges two of them. The remaining connections only add cycles within existing components, so they do not affect the number of groups.

So, for each day, we only need to efficiently find the connections that merge different groups. The total number of such merges over all days is bounded by N-1.

Let us use the Disjoint Set Union to represent the connected components.

For operation [L_i, R_i], we need to find the smallest k such that L_i+k and R_i-k are in different connected components. It helps to think of this in terms of strings. If position p holds the DSU root of node p, the operation reads “Make the substring in the range [L_i, R_i] a palindrome”. (This is only an analogy for visualization.)
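A naive version of this query treats the array of DSU roots as a string and scans from both ends. The sketch below assumes 0-based indexing, and the name `firstMismatch` is hypothetical. It returns the smallest k with root[L+k] != root[R-k], or -1 if the range is already a "palindrome". Each call costs O(R_i - L_i), which is what the hashing approach in the next paragraph avoids.

```cpp
#include <cassert>
#include <vector>

// Returns the smallest k such that positions L + k and R - k hold different
// DSU roots, or -1 if [L, R] already reads the same in both directions.
// Only the first half needs checking because the pairs are symmetric.
int firstMismatch(const std::vector<int>& root, int L, int R) {
    assert(0 <= L && L <= R && R < static_cast<int>(root.size()));
    for (int k = 0; L + k < R - k; ++k)
        if (root[L + k] != root[R - k]) return k;
    return -1;
}
```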

Now, to find the smallest such k, we can use hashing. Compute hashes over the array of DSU roots and over its reverse. The interval [L_i, R_i] corresponds to [LL_i, RR_i] = [N-R_i+1, N-L_i+1] in the reversed array (assuming 1-based indexing), and position LL_i+j of the reversed array holds the value at position R_i-j. This lets us binary search over k. We compare the hash of [L_i, L_i+k-1] in the forward array with the hash of [LL_i, LL_i+k-1] in the reversed array, and find the minimum k for which they differ. Then L_i+k-1 is the smallest position such that L_i+k-1 and R_i-k+1 are in different components.
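Below is a minimal static sketch of this idea. It makes the following assumptions:
- indexing is 0-based, so [L, R] maps to [N-1-R, N-1-L] in the reversed array;
- hashing uses a single modulus 10^9+7 and an arbitrarily chosen base;
- prefix hashes replace the segment trees to keep it short, so point updates on merges are not supported.

The name `PalindromeHasher` is hypothetical. With one modulus, collisions are possible in principle; the editorialist's Java code uses two moduli.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct PalindromeHasher {
    static constexpr std::uint64_t MOD = 1000000007ULL;
    static constexpr std::uint64_t BASE = 911382323ULL;  // arbitrary fixed base (assumption)
    int n;
    std::vector<std::uint64_t> pw, fwd, rev;  // powers, prefix hashes of a, prefix hashes of reversed a

    explicit PalindromeHasher(const std::vector<int>& a)
        : n(static_cast<int>(a.size())), pw(n + 1, 1), fwd(n + 1, 0), rev(n + 1, 0) {
        for (int i = 0; i < n; ++i) {
            assert(a[i] >= 0);
            pw[i + 1] = pw[i] * BASE % MOD;
            fwd[i + 1] = (fwd[i] * BASE + static_cast<std::uint64_t>(a[i])) % MOD;
            rev[i + 1] = (rev[i] * BASE + static_cast<std::uint64_t>(a[n - 1 - i])) % MOD;
        }
    }
    // Hash of the block of length len starting at l in the given prefix array.
    std::uint64_t get(const std::vector<std::uint64_t>& h, int l, int len) const {
        return (h[l + len] + MOD - h[l] * pw[len] % MOD) % MOD;
    }
    // Smallest k with a[L+k] != a[R-k] (up to hash collisions), or -1.
    int firstMismatch(int L, int R) const {
        assert(0 <= L && L <= R && R < n);
        int LL = n - 1 - R;          // rev[LL + j] holds a[R - j]
        int half = (R - L + 1) / 2;  // the middle element never mismatches
        if (get(fwd, L, half) == get(rev, LL, half)) return -1;
        int lo = 1, hi = half;       // smallest prefix length whose hashes differ
        while (lo < hi) {
            int mid = (lo + hi) / 2;
            if (get(fwd, L, mid) != get(rev, LL, mid)) hi = mid;
            else lo = mid + 1;
        }
        return lo - 1;               // mismatch between L+lo-1 and R-lo+1
    }
};
```

The binary search is valid because "prefixes of length m differ" is monotone in m: once a mismatching pair is included, every longer prefix also differs.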

Now we need to handle merge operations. Every time two components merge, we must also update the forward and reverse hashes. So for each DSU root we keep the list of positions in its component. Merges must be done small-to-large, relabeling the positions of the smaller list; otherwise the total relabeling work can reach O(N^2). Small-to-large relabels each position at most O(log(N)) times, so we can afford one point update per relabel on the hashes, which we maintain in segment trees.
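The small-to-large bookkeeping can be sketched as follows. `Components`, `merge` and the `onRelabel` callback are hypothetical names. In a full solution, the callback would perform the two point updates on the forward and reverse hash segment trees, as `change` does in the setter's code. A position is relabeled only when it moves into a component at least as large as its own, so its component at least doubles each time. That bounds the relabels of each position by O(log N).

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

struct Components {
    std::vector<int> label;                 // label[p]: component id stored at position p
    std::vector<std::vector<int>> members;  // members[c]: positions whose label is c

    explicit Components(int n) : label(n), members(n) {
        for (int p = 0; p < n; ++p) {
            label[p] = p;
            members[p] = {p};
        }
    }
    // Merges the components labeled x and y; returns false if they are the same.
    bool merge(int x, int y, const std::function<void(int, int)>& onRelabel) {
        assert(0 <= x && x < static_cast<int>(members.size()));
        assert(0 <= y && y < static_cast<int>(members.size()));
        if (x == y) return false;
        if (members[x].size() > members[y].size()) std::swap(x, y);  // x is the smaller one
        for (int p : members[x]) {
            label[p] = y;
            onRelabel(p, y);  // e.g. point-update both hash trees at p
            members[y].push_back(p);
        }
        members[x].clear();
        members[x].shrink_to_fit();
        return true;
    }
};
```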

This is all we need to do. This was the author’s intended solution.

Tester’s Solution
I really liked the tester’s unique approach to this problem, which goes as follows.

First, create another N nodes numbered from N to 2*N-1 (using 0-based indexing, so the original N nodes are numbered 0 to N-1), and connect node i with node 2*N-i-1. Node R-k is now connected to node 2*N-(R-k)-1 = 2*N-R-1+k. So we can rewrite update [L, R] as: for each k \in [0, R-L], connect node L+k with node 2*N-R-1+k.

So each day we just need to connect each node L_i+k in the range [L_i, R_i] with node 2*N-R_i-1+k. Write D = 2*N-R_i-1-L_i. Then we can represent the operation as [L_i, R_i, D]: connect each node p in the range [L_i, R_i] with node p+D. Also note that applying the same operation more than once changes nothing, so covering an interval with overlapping pieces is harmless.
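The reduction is easy to sanity-check numerically. The sketch below is 0-based and its function names are hypothetical. It verifies that, for every k, the shifted partner L+k+D is exactly the mirror of R-k. So the new edge (L+k, L+k+D), together with the fixed edge (R-k, mirror(R-k)), links L+k with R-k, just as the original operation does.

```cpp
#include <cassert>

// Mirror copy of original node i (0-based) in the 2N-node graph.
int mirror(int n, int i) { return 2 * n - 1 - i; }

// Offset D such that day [L, R] becomes "connect p with p + D for p in [L, R]".
int offsetD(int n, int L, int R) {
    assert(0 <= L && L <= R && R < n);
    return 2 * n - R - 1 - L;
}

// True if, for every k, node L+k is shifted exactly onto the mirror of R-k.
bool reductionHolds(int n, int L, int R) {
    int D = offsetD(n, L, R);
    for (int k = 0; k <= R - L; ++k)
        if (L + k + D != mirror(n, R - k)) return false;
    return true;
}
```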

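Let len = R_i-L_i+1 and let k be the largest integer with 2^k \leq len (so 2^{k+1} > len). Then the operation [L_i, R_i, D] can be written as the two operations [L_i, L_i+2^k-1, D] and [R_i-(2^k-1), R_i, D]. These two pieces may overlap, and together they cover [L_i, R_i] because 2 \cdot 2^k > len.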
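Here is a sketch of the split in the pair form (u, v = u + D) used by the tester's code. `Block` and `splitPow2` are hypothetical names. The sketch picks k as the largest integer with 2^k ≤ len, as in the text. The tester's code instead stops at the smallest k with 2^{k+1} ≥ len, which also gives two equal-length blocks covering the interval.

```cpp
#include <array>
#include <cassert>

// Connect u+j with v+j for every j in [0, 2^level).
struct Block {
    int u, v, level;
};

// Covers [L, R] with offset D by two (possibly overlapping) blocks of length 2^k,
// where k is the largest integer with 2^k <= len.
std::array<Block, 2> splitPow2(int L, int R, int D) {
    assert(L <= R);
    int len = R - L + 1, k = 0;
    while ((2 << k) <= len) ++k;  // grow k while 2^(k+1) <= len
    int s2 = R - (1 << k) + 1;    // start of the right-aligned block
    return {{Block{L, L + D, k}, Block{s2, s2 + D, k}}};
}
```

For instance, splitPow2(2, 6, 10) returns level-2 blocks starting at 2 and 3, which cover [2, 5] and [3, 6].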

Now every operation has an interval length that is a power of two. Call the exponent its level, so an update [L_i, k, D] represents [L_i, L_i+2^k-1, D]. Each update of level k can be split into the two level-(k-1) updates [L_i, L_i+2^{k-1}-1, D] and [L_i+2^{k-1}, L_i+2^k-1, D].

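Now process the levels from the largest k down to 0. At each level, iterate over the days in input order. For each day, split every surviving update of the level above into two updates of this level, and also take the day's own updates of this level. Each level has its own DSU over block starts. An update at level k linking u with v = u+D means node u+j is connected with node v+j for every j \in [0, 2^k), and keeping it joins u and v in this level's DSU. If u and v are already joined there, every connection in the block is already implied by earlier updates, so we do not add it again.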
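One level of this refinement can be sketched as below. It assumes updates are stored per day as (u, v) pairs of block starts, as in the tester's code. The inputs are:
- `upper`: each day's surviving level-(k+1) updates;
- `original`: the day's own level-k updates produced by the power-of-two split;
- `m`: the number of nodes (2*N here).

`refineLevel` and `LevelDSU` are hypothetical names. The DSU uses path halving rather than the recursive path compression of the original solutions.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

using Pairs = std::vector<std::pair<int, int>>;  // (u, v): link u+j with v+j

struct LevelDSU {
    std::vector<int> p;
    explicit LevelDSU(int m) : p(m) { std::iota(p.begin(), p.end(), 0); }
    int find(int x) {
        while (p[x] != x) x = p[x] = p[p[x]];  // path halving
        return x;
    }
    bool unite(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return false;
        p[a] = b;
        return true;
    }
};

// Turns each day's level-(k+1) updates (plus the day's own level-k updates)
// into the level-k updates that are not already implied by earlier ones.
std::vector<Pairs> refineLevel(int m, int k, const std::vector<Pairs>& upper,
                               const std::vector<Pairs>& original) {
    assert(upper.size() == original.size());
    LevelDSU dsu(m);  // fresh DSU over block starts for this level
    std::vector<Pairs> out(upper.size());
    for (std::size_t day = 0; day < upper.size(); ++day) {
        auto tryAdd = [&](int u, int v) {
            if (dsu.unite(u, v)) out[day].push_back({u, v});  // keep only new links
        };
        for (const auto& e : upper[day]) {  // split each bigger block in half
            tryAdd(e.first, e.second);
            tryAdd(e.first + (1 << k), e.second + (1 << k));
        }
        for (const auto& e : original[day]) tryAdd(e.first, e.second);
    }
    return out;
}
```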

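This filtering is the most important part. At each level, every surviving update is a successful union in a DSU over 2*N nodes, so at most 2*N-1 updates survive per level. Each level therefore examines O(Q+N) candidate updates, and there are O(log(N)) levels.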

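At level 0, we are left with O(N) single-pair connections spread across the days. Finally, reset the DSU and connect each node i with its mirror 2*N-i-1. Then replay these connections day by day, decreasing the number of connected components whenever a connection merges two different ones.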
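The final pass might look like the sketch below, which is 0-based and uses the hypothetical name `replayLevel0`. The first N unions only tie each node to its mirror. That leaves N components, so the group counter starts at N and is decremented only by unions made by the surviving level-0 pairs.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// level0[day] holds the surviving single-node pairs (u, v) of that day.
std::vector<int> replayLevel0(int n, const std::vector<std::vector<std::pair<int, int>>>& level0) {
    std::vector<int> p(2 * n);
    std::iota(p.begin(), p.end(), 0);
    auto find = [&](int x) {
        while (p[x] != x) x = p[x] = p[p[x]];  // path halving
        return x;
    };
    auto unite = [&](int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return false;
        p[a] = b;
        return true;
    };
    for (int i = 0; i < n; ++i) unite(i, 2 * n - 1 - i);  // node i and its mirror
    int groups = n;
    std::vector<int> answers;
    for (const auto& day : level0) {
        for (const auto& e : day) {
            assert(0 <= e.first && e.first < 2 * n && 0 <= e.second && e.second < 2 * n);
            if (unite(e.first, e.second)) --groups;  // a real merge of two groups
        }
        answers.push_back(groups);
    }
    return answers;
}
```

For N = 3, the day [1, 3] (1-based) becomes the pairs (0,3), (1,4), (2,5). Only the first one merges, so 2 groups remain.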

TIME COMPLEXITY ANALYSIS

Author’s solution
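O(N*log^2(N) + Q*log(N)). The costs break down as follows:
- each of the at most N-1 merges costs a binary search of O(log(N)) hash queries, each taking O(log(N));
- small to large performs O(N*log(N)) point updates of O(log(N)) each;
- every day needs at least one O(log(N)) hash comparison.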

Tester’s solution
O((Q+N)*log(N)), as each of the O(log(N)) levels is processed with O(Q+N) DSU operations.

SOLUTIONS

Setter's Solution
#include<bits/stdc++.h>

using namespace std;

const int maxn = 500005;

int mod[2];
int basepow[2][maxn];

void prepare_mod(int t, int x, int y){
  basepow[t][0] = 1;
  mod[t] = y;
  for(int i = 1; i < maxn; ++i) {
	basepow[t][i] = 1LL * basepow[t][i - 1] * x % mod[t];
  }
}


struct Hash{
  int type;
  int length;
  int hash_value;
  Hash(){}
  Hash(int x, int y, int z): type(x), length(y), hash_value(z){}
  Hash operator +(const Hash &rhs) const{
	return Hash(type, length + rhs.length, (1LL * hash_value * basepow[type][rhs.length] + rhs.hash_value) % mod[type]);
  }
  bool operator ==(const Hash &rhs) const{
	return (length == rhs.length && hash_value == rhs.hash_value);
  }
};

int n, q, cnt;

struct segtree_hash{
  Hash tree[maxn << 1];
  Hash combine(Hash x, Hash y) {
	return x + y;
  }
  void modify(int p, int val) {
	p += n;
	tree[p] = Hash(0, 1, val);
	for(; p >>= 1; ) {
	  tree[p] = combine(tree[p << 1], tree[p << 1 | 1]);
	}
  }
  Hash query(int l, int r) {
	cnt++;
	Hash resl, resr;
	resl = Hash(0, 0, 0);
	resr = Hash(0, 0, 0);
	for(l += n, r += n + 1; l < r; l >>= 1, r >>= 1) {
	  if(l & 1) resl = combine(resl, tree[l++]);
	  if(r & 1) resr = combine(tree[--r], resr);
	}
	return combine(resl, resr);
  }
}up, down;

int ncomp;
int id[maxn];
vector<int> comp[maxn];

void change(int p, int v) {
  id[p] = v;
  comp[v].push_back(p);
  up.modify(p, v);
  down.modify(n - p - 1, v);
}

void join(int x, int y) {
  if(x == y) return;
  ncomp--;
  if(comp[x].size() > comp[y].size()) swap(x, y);
  for(int v : comp[x]) change(v, y);
}

int main() {
  ios_base::sync_with_stdio(false); cin.tie(NULL);

  prepare_mod(0, maxn, 1000000033);
  prepare_mod(1, maxn + 33, 1000003577);



  cin >> n >> q;
  for(int i = 0; i < n; ++i) id[i] = i;

  mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
  shuffle(id, id + n, rng);


  for(int i = 0; i < n; ++i) {
	change(i, id[i]);
  }

  ncomp = n;

  for(int t = 0; t < q; ++t) {
	int L, R; cin >> L >> R;
	L--; R--;

	while(L < R) {
	  if(up.query(L, R) == down.query(n - R - 1, n - L - 1)) break;
	  int low = 0, high = (R - L) / 2, pos_diff = -1;
	  while(low <= high) {
	    int mid = (low + high) >> 1;
	    auto x = up.query(L, L + mid);
	    auto y = down.query(n - R - 1, n - (R - mid) - 1);

	    if(x == y) {
	      low = mid + 1;
	    }
	    else {
	      pos_diff = mid;
	      high = mid - 1;
	    }
	  }
	  if(pos_diff == -1) break;

	  join(id[L + pos_diff], id[R - pos_diff]);
	  L += pos_diff + 1;
	  R -= pos_diff + 1;
	}
	cout << ncomp << "\n";
  }
  cerr << clock() << endl;

  return 0;
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
 
#define ms(s, n) memset(s, n, sizeof(s))
#define FOR(i, a, b) for (int i = (a); i < (b); ++i)
#define FORd(i, a, b) for (int i = (a) - 1; i >= (b); --i)
#define FORall(it, a) for (__typeof((a).begin()) it = (a).begin(); it != (a).end(); it++)
#define sz(a) int((a).size())
#define present(t, x) (t.find(x) != t.end())
#define all(a) (a).begin(), (a).end()
#define uni(a) (a).erase(unique(all(a)), (a).end())
#define pb push_back
#define pf push_front
#define mp make_pair
#define fi first
#define se second
#define prec(n) fixed<<setprecision(n)
#define bit(n, i) (((n) >> (i)) & 1)
#define bitcount(n) __builtin_popcountll(n)
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pi;
typedef vector<int> vi;
typedef vector<pi> vii;
const int MOD = (int) 1e9 + 7;
const int FFTMOD = 119 << 23 | 1;
const int INF = (int) 1e9 + 23111992;
const ll LINF = (ll) 1e18 + 23111992;
const ld PI = acos((ld) -1);
const ld EPS = 1e-9;
inline ll gcd(ll a, ll b) {ll r; while (b) {r = a % b; a = b; b = r;} return a;}
inline ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
inline ll fpow(ll n, ll k, int p = MOD) {ll r = 1; for (; k; k >>= 1) {if (k & 1) r = r * n % p; n = n * n % p;} return r;}
template<class T> inline int chkmin(T& a, const T& val) {return val < a ? a = val, 1 : 0;}
template<class T> inline int chkmax(T& a, const T& val) {return a < val ? a = val, 1 : 0;}
inline ull isqrt(ull k) {ull r = sqrt(k) + 1; while (r * r > k) r--; return r;}
inline ll icbrt(ll k) {ll r = cbrt(k) + 1; while (r * r * r > k) r--; return r;}
inline void addmod(int& a, int val, int p = MOD) {if ((a = (a + val)) >= p) a -= p;}
inline void submod(int& a, int val, int p = MOD) {if ((a = (a - val)) < 0) a += p;}
inline int mult(int a, int b, int p = MOD) {return (ll) a * b % p;}
inline int inv(int a, int p = MOD) {return fpow(a, p - 2, p);}
inline int sign(ld x) {return x < -EPS ? -1 : x > +EPS;}
inline int sign(ld x, ld y) {return sign(x - y);}
mt19937 mt(chrono::high_resolution_clock::now().time_since_epoch().count());
inline int mrand() {return abs((int) mt());}
inline int mrand(int k) {return abs((int) mt()) % k;}
#define db(x) cerr << "[" << #x << ": " << (x) << "] ";
#define endln cerr << "\n";

void chemthan() {
	int n, q; cin >> n >> q;
	vector<vector<vii>> que(20, vector<vii>(q));
	FOR(i, 0, q) {
	    int u, v; cin >> u >> v; u--, v--;
	    int k = 0;
	    while ((1 << k + 1) < v - u + 1) k++;
	    que[k][i].pb({u, n + n - v - 1});
	    int d = (v - u + 1) - (1 << k);
	    u += d, v -= d;
	    que[k][i].pb({u, n + n - v - 1});
	}
	vi dj(n + n);
	function<int(int)> find = [&] (int u) {
	    return dj[u] == u ? dj[u] : dj[u] = find(dj[u]);
	};
	auto join = [&] (int u, int v) {
	    u = find(u);
	    v = find(v);
	    if (u ^ v) {
	        dj[u] = v;
	        return 1;
	    }
	    return 0;
	};
	FORd(k, 19, 0) {
	    FOR(i, 0, sz(dj)) dj[i] = i;
	    FOR(i, 0, q) {
	        vii nxt;
	        for (auto e : que[k + 1][i]) {
	            int u, v; tie(u, v) = e;
	            if (join(u, v)) {
	                nxt.pb({u, v});
	            }
	            if (join(u + (1 << k), v + (1 << k))) {
	                nxt.pb({u + (1 << k), v + (1 << k)});
	            }
	        }
	        for (auto e : que[k][i]) {
	            int u, v; tie(u, v) = e;
	            if (join(u, v)) {
	                nxt.pb({u, v});
	            }
	        }
	        que[k][i] = nxt;
	    }
	}
	FOR(i, 0, sz(dj)) dj[i] = i;
	FOR(i, 0, n) join(i, n + n - i - 1);
	int res = n;
	FOR(i, 0, q) {
	    for (auto e : que[0][i]) {
	        int u, v; tie(u, v) = e;
	        if (join(u, v)) {
	            res--;
	        }
	    }
	    cout << res << "\n";
	}
}

int main(int argc, char* argv[]) {
	ios_base::sync_with_stdio(0), cin.tie(0);
	if (argc > 1) {
	    assert(freopen(argv[1], "r", stdin));
	}
	if (argc > 2) {
	    assert(freopen(argv[2], "wb", stdout));
	}
	chemthan();
	cerr << "\nTime elapsed: " << 1000 * clock() / CLOCKS_PER_SEC << "ms\n";
	return 0;
}
Editorialist's Solution (using author's idea)
import java.util.*;
import java.io.*;
class COWSHEDS{
	//SOLUTION BEGIN
	int TIMES = 2, MAXN = (int)1e6;
	long[] mod;
	long[] base;
	long[][] basePow;
	void pre() throws Exception{
	    mod = new long[]{(long)1e9+7, (long)1e8+7};
	    base = new long[]{(long)1e6+3, (long)1e6+7};
	    basePow = new long[TIMES][1+MAXN];
	    for(int t = 0; t < TIMES; t++){
	        basePow[t][0] = 1;
	        for(int i = 1; i<= MAXN; i++)basePow[t][i] = basePow[t][i-1]*base[t]%mod[t];
	    }
	}
	void solve(int TC) throws Exception{
	    int N = ni(), Q = ni();
	    LinkedList<Integer>[] list = new LinkedList[1+N];
	    int[] sz = new int[1+N];
	    int[] dsu = new int[1+N];
	    for(int i = 1; i<= N; i++){
	        list[i] = new LinkedList<>();
	        list[i].add(i);
	        sz[i] = 1;dsu[i] = i;
	    }
	    SegTree forward = new SegTree(N), reverse = new SegTree(N);
	    for(int i = 1; i<= N; i++){
	        forward.update(i, i);
	        reverse.update(i, N-i+1);
	    }
	    int ans = N;
	    while(Q-->0){
	        int u = ni(), v = ni();
	        int len = v-u+1;
	        int uu = N-v+1;
	        while(!equalHash(forward.query(u, u+len-1), reverse.query(uu, uu+len-1))){
	            int lo = 1, hi = len;
	            while(lo+1 < hi){
	                int mid = lo+(hi-lo)/2;
	                if(!equalHash(forward.query(u, u+mid-1), reverse.query(uu, uu+mid-1)))hi = mid;
	                else lo = mid;
	            }
	            if(!equalHash(forward.query(u, u+lo-1), reverse.query(uu, uu+lo-1)))hi = lo;
	            union(forward, reverse, sz, dsu, list, u+hi-1, v-hi+1);
	            ans--;
	        }
	        pn(ans);
	    }
	}
	int find(int[] set, int u){return set[u] = (set[u] == u?u:find(set, set[u]));}
	void union(SegTree forward, SegTree reverse, int[] sz, int[] dsu, LinkedList<Integer>[] list, int u, int v){
	    u = find(dsu, u);
	    v = find(dsu, v);
	    if(sz[u] < sz[v]){
	        int tmp = u;
	        u = v;
	        v = tmp;
	    }
	    int N = list.length-1;
	    sz[u] += sz[v];
	    dsu[v] = u;
	    Iterator<Integer> it = list[v].iterator();
	    while(it.hasNext()){
	        int w = it.next();
	        forward.update(w, u);
	        reverse.update(N-w+1, u);
	        list[u].addLast(w);
	    }
	}
	class Hash{
	    int len;
	    long[] hash;
	    public Hash(int len, long[] h){
	        this.len = len;
	        hash = h;
	    }
	}
	Hash combine(Hash le, Hash ri){
	    long[] hash = new long[TIMES];
	    for(int t = 0; t< TIMES; t++)hash[t] = (le.hash[t]*basePow[t][ri.len]+ri.hash[t])%mod[t];
	    return new Hash(le.len+ri.len, hash);
	}
	boolean equalHash(Hash h1, Hash h2){
	    boolean equal = true;
	    for(int t = 0; t< TIMES && equal; t++)equal &= h1.hash[t] == h2.hash[t];
	    return equal;
	}
	class SegTree{
	    int m = 1;
	    Hash[] t;
	    public SegTree(int n){
	        while(m<=n)m<<=1;
	        t = new Hash[m<<1];
	        for(int i = 0; i< m; i++)
	            t[i+m] = new Hash(1, new long[]{i, i});
	        for(int i = m-1; i> 0; i--)t[i] = combine(t[i<<1], t[i<<1|1]);
	    }
	    void update(int pos, int val){
	        t[pos+= m] = new Hash(1, new long[]{val, val});
	        for(pos>>=1; pos > 0; pos>>=1)t[pos] = combine(t[pos<<1], t[pos<<1|1]); 
	    }
	    Hash query(int l, int r){
	        Hash le = new Hash(0, new long[]{0, 0}), ri = new Hash(0, new long[]{0, 0});
	        for(l += m, r += m+1; l< r; l>>=1, r>>=1){
	            if((l&1)==1)le = combine(le, t[l++]);
	            if((r&1)==1)ri = combine(t[--r], ri);
	        }
	        return combine(le, ri);
	    }
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = false;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new COWSHEDS().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}
Editorialist's Solution (using tester's idea)
import java.util.*;
import java.io.*;
class COWSHEDS{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int N = ni(), Q = ni(), B = 20;
	    Queue<int[]>[][] queuedUpdates = new LinkedList[B][Q];
	    for(int b = 0; b< B; b++)for(int i = 0; i< Q; i++)queuedUpdates[b][i] = new LinkedList<>();
	    for(int i = 0; i< Q; i++){
	        int u = ni()-1, v = ni()-1;
	        int k = 0;
	        while(1<<(k+1) < v-u+1)k++;
	        queuedUpdates[k][i].add(new int[]{u, N+N-v-1});
	        int d = (v-u+1)-(1<<k);
	        u+= d;v -= d;
	        queuedUpdates[k][i].add(new int[]{u, N+N-v-1});
	    }
	    int[] dsu = new int[2*N];
	    for(int b = B-2; b>= 0; b--){
	        resetDSU(dsu);
	        for(int i = 0; i< Q; i++){
	            for(int[] up:queuedUpdates[b+1][i]){
	                if(join(dsu, up[0], up[1]))
	                   queuedUpdates[b][i].add(up);
	                if(join(dsu, up[0]+(1<<b), up[1]+(1<<b)))
	                    queuedUpdates[b][i].add(new int[]{up[0]+(1<<b), up[1]+(1<<b)});
	            }
	        }
	    }
	    resetDSU(dsu);
	    for(int i = 0; i< N; i++)join(dsu, i, N+N-i-1);
	    int ans = N;
	    for(int i = 0; i< Q; i++){
	        for(int[] up:queuedUpdates[0][i])
	            if(join(dsu, up[0], up[1]))
	                ans--;
	        pn(ans);
	    }
	}
	void resetDSU(int[] dsu){for(int i = 0; i< dsu.length; i++)dsu[i] = i;}
	int find(int[] dsu, int u){return dsu[u] = (dsu[u] == u?u:find(dsu, dsu[u]));}
	boolean join(int[] dsu, int u, int v){
	    u = find(dsu, u);v = find(dsu, v);
	    if(u != v){
	        dsu[v] = u;
	        return true;
	    }
	    return false;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = false;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new COWSHEDS().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome as always.


Tester's solution is really nice.


I agree! Such a novel idea.

This approach isn’t novel. This problem can be solved in a quite similar manner.


This problem too.


Considering the author’s solution, isn’t the time limit too strict? I was barely able to pass the tests after switching to an iterative segment tree; a recursive segment tree was not fast enough.


And isn’t the tester’s solution the same as this: dacin21_codebook/palindrome_dsu.cpp at master · dacin21/dacin21_codebook · GitHub (it seems the same to me, except that the queries are offline here)?

Yes, it’s the same, except that here we need to print the answer after each query.