COLDIF - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: pols_agyi_pols
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Offline queries, any point-add range-sum data structure

PROBLEM:

You have an array C of length N, and M chosen subarrays (L_i, R_i).
Define f(i, j) to be the number of distinct values that occur in the subarray C[L_i\ldots R_i] but not C[L_j\ldots R_j].

For each i from 1 to M, find D_i = \sum_{j=1}^M f(i, j).

EXPLANATION:

Let’s fix an i, and look at D_i = \sum_{j=1}^M f(i, j).
A different way to obtain the same sum is as follows:

  • For a value x, let f(x) denote the number of ranges (from the M we have) that don’t contain x.
  • Then, D_i equals the sum of f(x) across all distinct x that occur in the range [L_i, R_i].

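This reformulation is what allows us to solve the problem quickly.
It holds because swapping the order of summation makes each distinct x in [L_i, R_i] contribute once for every j whose range does not contain x, which is exactly f(x) times.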
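As a sanity check on this reformulation, here is a small brute-force sketch that computes D both ways. The names `value_sets`, `direct_sum` and `via_value_counts` are made up for illustration, intervals are assumed 1-indexed and inclusive, and this is far too slow for the real constraints.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <utility>
#include <vector>

using Intervals = std::vector<std::pair<int, int>>;  // 1-indexed, inclusive (assumption)

// Distinct values inside each interval, by brute force.
std::vector<std::set<int>> value_sets(const std::vector<int>& c, const Intervals& iv) {
    std::vector<std::set<int>> vals(iv.size());
    for (std::size_t i = 0; i < iv.size(); ++i) {
        assert(1 <= iv[i].first && iv[i].first <= iv[i].second &&
               iv[i].second <= (int)c.size());
        for (int p = iv[i].first; p <= iv[i].second; ++p) vals[i].insert(c[p - 1]);
    }
    return vals;
}

// D_i straight from the definition: sum over j of f(i, j).
std::vector<long long> direct_sum(const std::vector<int>& c, const Intervals& iv) {
    auto vals = value_sets(c, iv);
    std::vector<long long> d(iv.size(), 0);
    for (std::size_t i = 0; i < iv.size(); ++i)
        for (std::size_t j = 0; j < iv.size(); ++j)
            for (int x : vals[i])
                if (!vals[j].count(x)) ++d[i];
    return d;
}

// D_i as the sum of f(x) over the distinct x in interval i,
// where f(x) = number of intervals that do not contain x.
std::vector<long long> via_value_counts(const std::vector<int>& c, const Intervals& iv) {
    auto vals = value_sets(c, iv);
    std::map<int, long long> missing;  // missing[x] = f(x)
    for (int x : c) {
        if (missing.count(x)) continue;
        long long cnt = 0;
        for (const auto& s : vals)
            if (!s.count(x)) ++cnt;
        missing[x] = cnt;
    }
    std::vector<long long> d(iv.size(), 0);
    for (std::size_t i = 0; i < iv.size(); ++i)
        for (int x : vals[i]) d[i] += missing.at(x);
    return d;
}
```

For C = [1, 2, 1, 3] with intervals [1, 2], [3, 4], [2, 3], both functions should give D = [1, 2, 1].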

First, let’s compute all the f(x) values.
Fix a color x.
Let its occurrences in C be at indices i_1, i_2, \ldots, i_k.
For convenience, let i_0 = 0 and i_{k+1} = N+1.

Observe that f(x) equals the number of intervals that don’t contain any of the above indices - meaning they must lie strictly between i_j and i_{j+1} for some 0 \leq j \leq k.
So, we need to answer queries of the form "given l and r, how many of our M intervals [L_i, R_i] satisfy l \leq L_i \leq R_i \leq r?", with l = i_j + 1 and r = i_{j+1} - 1.

This is a rather classical task, and can be solved in \mathcal{O}((N+M)\log N) time offline using a sweepline and a segment tree/Fenwick tree.

How?

First, precompute all ranges that must be queried (which, as noted above, can be obtained by looking at the consecutive occurrences of each x).
Note that a value with k occurrences generates at most k+1 query ranges, so there are at most 2N of them in total.
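A minimal sketch of how these query ranges could be generated, assuming 1-indexed positions and colors in the range 1..N (as the editorialist's code also assumes). The function name `gap_ranges` is hypothetical, and empty gaps are simply skipped since they contain no interval.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// For each color x in 1..n, returns the non-empty gaps [l, r] (1-indexed,
// inclusive) strictly between consecutive occurrences of x, using the
// sentinels i_0 = 0 and i_{k+1} = n + 1. Colors that never occur get no gaps.
std::vector<std::vector<std::pair<int, int>>> gap_ranges(const std::vector<int>& c) {
    int n = (int)c.size();
    std::vector<int> last(n + 1, 0);  // last occurrence seen so far, 0 = none
    std::vector<std::vector<std::pair<int, int>>> gaps(n + 1);
    for (int i = 1; i <= n; ++i) {
        int x = c[i - 1];
        assert(1 <= x && x <= n);  // assumption: colors lie in 1..n
        if (last[x] + 1 <= i - 1) gaps[x].push_back({last[x] + 1, i - 1});
        last[x] = i;
    }
    for (int x = 1; x <= n; ++x) {
        if (last[x] == 0) continue;  // color absent from the array
        if (last[x] + 1 <= n) gaps[x].push_back({last[x] + 1, n});
    }
    return gaps;
}
```

For C = [1, 2, 1, 3] this yields the gaps [2, 2] and [4, 4] for color 1, [1, 1] and [3, 4] for color 2, and [1, 3] for color 3.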

For index i, let Q_i be a list of queries whose left endpoints are i.

Let a_i denote the number of “active” intervals that end at index i.
We’ll sweep over the left endpoint of our queries and maintain the array a as we go.
An interval is considered “active” if it starts at or after our current index.

Initially, all our M intervals are active.
Then, for each i = 1, 2, 3, \ldots, N in order:

  • Answer all queries in Q_i.
    The answer to a query ending at index R is simply (a_i + a_{i+1} + \ldots + a_R).
    That’s because we’ve ensured via our sweep that all intervals that are counted in a start at an index \geq i, so it suffices to count the number of them that end at an index \leq R.
  • Then, deactivate all intervals that start at index i, since they must no longer be considered active for further indices.
    For the interval [i, R], this corresponds to reducing a_R by 1.

So, we need a data structure that allows us to add to a point and query for the sum of a range quickly, which a segment tree/Fenwick tree will do in \mathcal{O}(\log N) per operation.
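The sweep described above can be sketched as follows. This version uses a Fenwick tree rather than the segment trees in the reference solutions below, and the names `Fenwick` and `count_contained` are illustrative, not taken from either solution. Indices are assumed 1-indexed and inclusive, and the structured bindings need C++17.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Point-add, range-sum Fenwick tree over positions 1..n.
struct Fenwick {
    std::vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, long long v) {
        for (; i < (int)t.size(); i += i & -i) t[i] += v;
    }
    long long prefix(int i) const {
        long long s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s;
    }
    long long sum(int l, int r) const { return l > r ? 0 : prefix(r) - prefix(l - 1); }
};

// For each query [l, r], counts the intervals [L, R] with l <= L <= R <= r.
std::vector<long long> count_contained(int n,
                                       const std::vector<std::pair<int, int>>& intervals,
                                       const std::vector<std::pair<int, int>>& queries) {
    std::vector<std::vector<int>> ends_from(n + 1), query_at(n + 1);
    Fenwick a(n);  // a[R] = number of active intervals ending at R
    for (auto [L, R] : intervals) {
        assert(1 <= L && L <= R && R <= n);
        ends_from[L].push_back(R);
        a.add(R, 1);  // initially every interval is active
    }
    for (int q = 0; q < (int)queries.size(); ++q) {
        assert(1 <= queries[q].first && queries[q].first <= n);
        query_at[queries[q].first].push_back(q);
    }
    std::vector<long long> res(queries.size(), 0);
    for (int i = 1; i <= n; ++i) {
        // Every active interval starts at >= i, so only the right end matters.
        for (int q : query_at[i]) res[q] = a.sum(i, queries[q].second);
        // Intervals starting at i are inactive for all later left endpoints.
        for (int R : ends_from[i]) a.add(R, -1);
    }
    return res;
}
```

With N = 4 and intervals [1, 2], [3, 4], [2, 3], the queries [1, 3], [3, 4] and [4, 4] should return 2, 1 and 0 respectively.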

Once all the f(x) values are known, we move on to finding the D_i values.
This part can also be done offline in similar fashion to the first part, in \mathcal{O}((N+M)\log N) time.

How?

As noted at the start, D_i equals the sum of f(x) across all distinct x present in the range.

Just as in the previous part, we answer queries offline in increasing order of their left endpoints (we now have M queries, corresponding to the intervals with us).

To account for distinctness, we use a little trick: only care about the leftmost occurrence of each color, and ignore the others.
Consider an array b, where:

  • If index i is the leftmost occurrence of C_i, set b_i = f(C_i).
  • Otherwise, set b_i = 0.

Now, sweep over the left endpoints of the queries.
As before, when at index i, answer all queries starting there - you may note that all you want is a range sum of the b array.
Once this is done, index i drops out of consideration, since every remaining query starts after it.
The next occurrence of C_i therefore becomes its leftmost relevant one: find the smallest index j\gt i such that C_j = C_i (if one exists), and set b_j = f(C_j) (note that b_j would’ve been 0 just before this).

Once again, all that’s needed is a data structure that supports quick point updates and range sum queries.

The index j can be found quickly in a variety of ways: for example, binary search on the list of positions of C_i, or simply precompute the next occurrence for every index in linear time.
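A sketch of this second sweep, under the same assumptions as before (1-indexed inclusive intervals, colors in 1..N). Here `f` is assumed to be already computed and indexed by color, `jump` holds the precomputed next occurrence, and the names `Fenwick` and `distinct_weight_sums` are hypothetical.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Point-add, range-sum Fenwick tree over positions 1..n.
struct Fenwick {
    std::vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0) {}
    void add(int i, long long v) {
        for (; i < (int)t.size(); i += i & -i) t[i] += v;
    }
    long long prefix(int i) const {
        long long s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s;
    }
    long long sum(int l, int r) const { return l > r ? 0 : prefix(r) - prefix(l - 1); }
};

// For each interval [L, R], sums f[x] over the distinct colors x in C[L..R].
// f must have size n + 1 and is indexed by color (index 0 unused).
std::vector<long long> distinct_weight_sums(const std::vector<int>& c,
                                            const std::vector<long long>& f,
                                            const std::vector<std::pair<int, int>>& intervals) {
    int n = (int)c.size();
    assert((int)f.size() == n + 1);
    // jump[i] = next position with the same color as i, or n + 1 if none.
    std::vector<int> jump(n + 1, n + 1), first(n + 1, n + 1);
    for (int i = n; i >= 1; --i) {
        jump[i] = first[c[i - 1]];
        first[c[i - 1]] = i;
    }
    Fenwick b(n);  // b[i] = f(C_i) if i is the current leftmost occurrence, else 0
    for (int x = 1; x <= n; ++x)
        if (first[x] <= n) b.add(first[x], f[x]);
    std::vector<std::vector<int>> query_at(n + 1);
    for (int q = 0; q < (int)intervals.size(); ++q) query_at[intervals[q].first].push_back(q);
    std::vector<long long> res(intervals.size(), 0);
    for (int i = 1; i <= n; ++i) {
        for (int q : query_at[i]) res[q] = b.sum(i, intervals[q].second);
        // Position i is behind us now; its next occurrence takes over.
        if (jump[i] <= n) b.add(jump[i], f[c[i - 1]]);
    }
    return res;
}
```

Note that b_i is never reset after the sweep passes index i; that is harmless because all later queries start strictly to the right of i.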

TIME COMPLEXITY:

\mathcal{O}((N+M)\log N) per testcase.

CODE:

Tester's code (C++)
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;

using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

struct segtree{
 
    int size;
    vector<int> operations;
    vector<int> values;
 
    int NEUTRAL_ELEMENT = 0;
 
    int modify_op(int a,int b){
        return a+b;
    }
 
    int calc_op(int a,int b){
        return a+b;
    }
 
    void apply_mod_op(int &a,int b){
        a=modify_op(a,b);
    }
 
    void init(int n){
        size=1;
        while(size<n)size*=2;
        operations.assign(2*size,0LL);
        values.assign(2*size,0LL);
    }
 
    // void propogate(int x, int lx, int rx){
    //     if(rx-lx==1)return;
    //     values[2*x+1]=merge(values[2*x+1],values[x]);
    //     values[2*x+2]=merge(values[2*x+2],values[x]);
    //     values[x]=NEUTRAL_ELEMENT;
    // }
    void build(vi &a,int x,int lx,int rx){
        if(rx-lx==1){
            if(lx<a.size()){
                values[x]=a[lx];
            }
            return;
        }
        int m = (lx+rx)/2;
        build(a,2*x+1,lx,m);
        build(a,2*x+2,m,rx);
        values[x]=(values[2*x+1]+values[2*x+2]);
    }
 
    void build(vi &a){
        build(a,0,0,size);
    }
 
    void modify(int l,int r,int v,int x,int lx,int rx){
        // propogate(x,lx,rx);
        if(lx>=r || l>=rx)return;
        if(lx>=l && rx<=r){
        apply_mod_op(operations[x],v);
        apply_mod_op(values[x],v*(rx-lx));
        return;}
        int m = (lx+rx)/2;
        modify(l,r,v,2*x+1,lx,m);
        modify(l,r,v,2*x+2,m,rx);
        values[x]=calc_op(values[2*x+1],values[2*x+2]);
        apply_mod_op(values[x],operations[x]*(rx-lx));
    }
 
    void modify(int l,int r,int v){
        return modify(l,r,v,0,0,size);
    }
 
    int calc(int l,int r,int x,int lx,int rx){
        if(lx>=r || l>=rx)return NEUTRAL_ELEMENT;
        if(lx>=l && rx<=r){return values[x];}
        int m = (lx+rx)/2;
        int m1=calc(l,r,2*x+1,lx,m);
        int m2=calc(l,r,2*x+2,m,rx);
        auto res = calc_op(m1,m2);
        apply_mod_op(res,operations[x]*(min(r,rx)-max(l,lx)));
        return res;
    }
 
    int calc(int l,int r){
        return calc(l,r,0,0,size);
    }
 
};

 
void solve()
{
    int n,m;
    cin >> n >> m;
    vi a(n);
    take(a,n);
    for(auto &x : a)x--;
    vi l(m),r(m);
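    vector<vi> v(n);   // v[r]: left endpoints of intervals ending at r
    vi c(n);           // c[l]: number of intervals starting at l
    vector<vi> p(n);   // p[x]: positions of value x (reversed below)
    vector<vi> pp(n);  // pp[x]: second copy of p[x] for the final sweep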
    repin{
        p[a[i]].pb(i);
    }
    repin{
        if(p[i].size() == 0)continue;
        reverse(be(p[i]));
        pp[i] = p[i];
    }
    // vector<array<int,3>> q;
    vector<vector<array<int,2>>> v1(n); // v1[l]: {right endpoint, interval id}
    rep(i,0,m){
        cin >> l[i] >> r[i];
        l[i]--;
        r[i]--;
        v[r[i]].pb(l[i]);
        c[l[i]]++;
        v1[l[i]].pb({r[i],i});
    }
    // cout << "hi\n";return;
    // give(c,n);cout << "\n";
    vi b(n);
    segtree st;
    st.init(n);
    segtree st1;
    st1.init(n);
    segtree st2;
    st2.init(n);
    vi cur(n);
    vi vis(n);
    repin{
        st.modify(i,i+1,c[i]);
        int res = 0;
        if(vis[a[i]])res = p[a[i]].back()+1,p[a[i]].pob(); 
        vis[a[i]] = true;
        b[i] = st.calc(res,i+1);
        for(auto x : v[i]){
            st.modify(x,x+1,-1);
        }
    }
    vi col(n);
    repin{
        col[a[i]] += b[i];
    }
    
    // cout << "hi\n";return;
    
    vi ans(m);
    repin{
        if(pp[i].size()){
            st1.modify(pp[i].back(),pp[i].back()+1,col[i]);
            st2.modify(pp[i].back(),pp[i].back()+1,1);
        }
    }
    // give(c,n);cout << "\n";
    // give(col,n);cout << "\n";
    // give(b,n);cout << "\n";
    // repin cout << st1.calc(i,i+1) << " ";cout << "\n";
    // repin cout << st2.calc(i,i+1) << " ";cout << "\n";
    repin{
        for(auto [e,z] : v1[i]){
            ans[z] = st2.calc(i,e+1)*m - st1.calc(i,e+1);
        }
        pp[a[i]].pob();
        if(pp[a[i]].size()){
            st1.modify(pp[a[i]].back(),pp[a[i]].back()+1,col[a[i]]);
            st2.modify(pp[a[i]].back(),pp[a[i]].back()+1,1);
        }
    }
    give(ans,m);cout << "\n";
}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    int t;
    cin >> t;
    while(t--)
        solve();
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Point-update Segment Tree
 * Source: kactl
 * Description: Iterative point-update segment tree, ranges are half-open i.e [L, R).
 *              f is any associative function.
 * Time: O(logn) update/query
 */

template<class T, T unit = T()>
struct SegTree {
	T f(T a, T b) { return a+b; }
	vector<T> s; int n;
	SegTree(int _n = 0, T def = unit) : s(2*_n, def), n(_n) {}
	void update(int pos, T val) {
        pos += n;
        s[pos] += val;
        while (pos /= 2)
			s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
	}
	T query(int b, int e) {
		T ra = unit, rb = unit;
		for (b += n, e += n; b < e; b /= 2, e /= 2) {
			if (b % 2) ra = f(ra, s[b++]);
			if (e % 2) rb = f(s[--e], rb);
		}
		return f(ra, rb);
	}
};

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;
        vector<int> c(n);
        for (int &x : c) cin >> x;
        vector<vector<int>> positions(n+1);
        for (int i = 0; i < n; ++i)
            positions[c[i]].push_back(i);
        vector<int> next(n+1, n), jump(n);
        for (int i = n-1; i >= 0; --i) {
            jump[i] = next[c[i]];
            next[c[i]] = i;
        }

        vector<array<int, 2>> intervals(m);
        vector<vector<int>> from_here(n);
        for (auto &[x, y] : intervals) {
            cin >> x >> y;
            --x, --y;
            from_here[x].push_back(y);
        }
        
        vector<int> dont_have(n+1);
        vector<vector<array<int, 2>>> queries(n+1);
        for (int i = 1; i <= n; ++i) {
            if (positions[i].empty()) continue;

            int prv = -1;
            for (auto x : positions[i]) {
                queries[prv+1].push_back({x, i});
                prv = x;
            }
            queries[prv+1].push_back({n, i});
        }

        SegTree<ll> seg(n);
        for (auto [x, y] : intervals) seg.update(y, 1);

        for (int L = 0; L < n; ++L) {
            for (auto [R, id] : queries[L])
                dont_have[id] += seg.query(L, R);
            for (auto R : from_here[L]) seg.update(R, -1);
        }

        for (auto &tmp : queries) tmp.clear();
        for (int i = 0; i < m; ++i)
            queries[intervals[i][0]].push_back({intervals[i][1], i});
        
        for (int i = 1; i <= n; ++i) {
            if (positions[i].size()) seg.update(positions[i][0], dont_have[i]);
        }

        vector<ll> ans(m);
        for (int L = 0; L < n; ++L) {
            for (auto [R, id] : queries[L]) {
                ans[id] = seg.query(L, R+1);
            }
            if (jump[L] < n) seg.update(jump[L], dont_have[c[L]]);
        }

        for (auto x : ans) cout << x << ' ';
        cout << '\n';
    }
}

As an alternative, the second part is also a classical exercise on Mo’s algorithm. The benefit of bashing it with Mo is that you have less thinking to do during the contest.

Here’s a link to my implementation. It does not contain any constant factor optimizations so it is somewhat slow at 1.3 seconds.

Yes, several sqrt solutions are possible here, and if implemented reasonably they will likely pass.
I wasn’t able to cleanly separate them, so I elected to make the time limit kinda loose rather than make constraints higher and risk bad implementations with intended complexity failing.
(Though I do think it’s a bit unfortunate, since the problem was conceptualized as more of an educational task on sweepline + segtree but allowing sqrt sidesteps that.)