VIDEOTAPES - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: weaponzdautist
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Square root decomposition

PROBLEM:

You’re given two arrays L and C, both of length N.
Process the following, online:

  • 1 l r: find the sum of L_i across all i between l and r, such that C_i hasn’t occurred to its left within this range.
  • 2 i x: set L_i := x
  • 3 i x: set C_i := x
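To pin down what a type-1 query asks for, here is a minimal brute-force reference, assuming 0-indexed storage with the query's l and r given 1-indexed as in the statement. The name bruteQuery is a hypothetical helper, not taken from any of the solutions below. It is O(r - l + 1) set operations per query, so it is only useful for checking a faster solution on small inputs.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical helper: brute-force answer to "1 l r" (l, r are 1-indexed).
// Sums L_i over l <= i <= r, counting each colour only at its first
// occurrence inside [l, r].
int64_t bruteQuery(const std::vector<int>& L, const std::vector<int>& C,
                   int l, int r) {
    assert(1 <= l && l <= r && r <= (int)L.size());
    std::set<int> seen;  // colours already met inside [l, r]
    int64_t sum = 0;
    for (int i = l; i <= r; ++i) {
        // insert(...).second is true only if C_i was not in the set yet
        if (seen.insert(C[i - 1]).second) sum += L[i - 1];
    }
    return sum;
}
```

For example, with L = [5, 3, 2, 7] and C = [1, 2, 1, 2], the query 1 1 4 only counts positions 1 and 2, giving 8.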

EXPLANATION:

There are multiple approaches to this task. One of them, which is relatively easy to implement, uses square root decomposition.

First, let’s compute \text{prev}_i to be the largest index j \lt i such that C_j = C_i.
If no such j exists, we say \text{prev}_i = -1.
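A sketch of computing \text{prev} in one left-to-right pass, assuming 0-indexed positions so that -1 cleanly means "no earlier occurrence" (as in the editorialist's code). The function name computePrev is hypothetical.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical helper: prev[i] = largest j < i with C[j] == C[i], or -1.
// We remember the latest position seen for every colour.
std::vector<int> computePrev(const std::vector<int>& C) {
    std::vector<int> prev(C.size(), -1);
    std::unordered_map<int, int> last;  // colour -> latest index seen so far
    for (int i = 0; i < (int)C.size(); ++i) {
        auto it = last.find(C[i]);
        if (it != last.end()) prev[i] = it->second;
        last[C[i]] = i;
    }
    return prev;
}
```

For C = [1, 2, 1, 2, 1] this yields prev = [-1, -1, 0, 1, 2].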

Now, note that:

  • To answer a query 1 l r, we essentially want the sum of L_i across all l \leq i \leq r such that \text{prev}_i \lt l.
    If \text{prev}_i \geq l, C_i has occurred before in this range.
  • Updating L_i doesn’t change the array \text{prev}.
  • Updating C_i to x changes at most three indices of \text{prev}.
    Specifically, let j\gt i be the index such that \text{prev}_j = i, and k\gt i be the smallest index such that C_k = x.
    Then, the \text{prev} values of only indices i, j, k will change (note that j and/or k may also not exist, which is fine).
    Finding j and k can be done in \mathcal{O}(\log N) time by storing a sorted list of indices corresponding to each value of C_i, and then binary searching on this list.
    For example, you can use std::set for this, since you also need quick insertion/deletion.
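The colour update described above can be sketched with one std::set of positions per colour. This assumes 0-indexed positions with -1 for "no previous occurrence"; changeColour and the map-of-sets layout are hypothetical choices for illustration. It returns the indices whose \text{prev} may have changed (i, then j, then k from the text), which is exactly what the block structure needs to know in order to decide which blocks to rebuild.

```cpp
#include <cassert>
#include <iterator>
#include <map>
#include <set>
#include <vector>

// Hypothetical helper. pos[c] = sorted positions with colour c.
// Sets C[i] := x and repairs prv[] (0-indexed, -1 = none).
std::vector<int> changeColour(std::vector<int>& C, std::vector<int>& prv,
                              std::map<int, std::set<int>>& pos,
                              int i, int x) {
    std::vector<int> touched = {i};
    // 1) Remove i from its old colour: its successor j inherits prv[i].
    std::set<int>& oldSet = pos[C[i]];
    auto it = oldSet.find(i);
    assert(it != oldSet.end());
    auto nx = std::next(it);
    if (nx != oldSet.end()) {
        prv[*nx] = prv[i];
        touched.push_back(*nx);
    }
    oldSet.erase(it);
    // 2) Insert i into colour x: its predecessor becomes prv[i],
    //    and its successor k (if any) now points back to i.
    C[i] = x;
    std::set<int>& newSet = pos[x];
    auto ins = newSet.insert(i).first;
    prv[i] = (ins == newSet.begin()) ? -1 : *std::prev(ins);
    auto after = std::next(ins);
    if (after != newSet.end()) {
        prv[*after] = i;
        touched.push_back(*after);
    }
    return touched;  // at most three indices
}
```

Each step is a constant number of std::set operations, hence the \mathcal{O}(\log N) cost.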

Now, notice that we have a situation where updates are “fast” while queries are “slow”.
So, we can afford to make updates a bit slower if it allows for faster queries, which is where square-root decomposition is often helpful.

Let’s choose a constant B, and break the range [1, N] into blocks of size B.
For each block, we’ll store the list of indices belonging to it, sorted in increasing order of their \text{prev}_i values, together with prefix sums of L_i taken in that order.
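A minimal sketch of building the data for one block [lo, hi), assuming 0-indexed arrays and the -1 convention for \text{prev}. BlockData and buildBlock are hypothetical names; here pref[t] holds the sum of the first t entries, so pref has one more element than order.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical per-block data: indices sorted by prv, plus prefix sums of L.
struct BlockData {
    std::vector<int> order;     // block indices sorted by prv
    std::vector<int64_t> pref;  // pref[t] = sum of L over order[0..t-1]
};

BlockData buildBlock(const std::vector<int>& L, const std::vector<int>& prv,
                     int lo, int hi) {
    assert(0 <= lo && lo <= hi && hi <= (int)L.size());
    BlockData blk;
    blk.order.resize(hi - lo);
    std::iota(blk.order.begin(), blk.order.end(), lo);  // lo, ..., hi-1
    std::sort(blk.order.begin(), blk.order.end(),
              [&](int i, int j) { return prv[i] < prv[j]; });
    blk.pref.assign(hi - lo + 1, 0);
    for (int t = 0; t < hi - lo; ++t)
        blk.pref[t + 1] = blk.pref[t] + L[blk.order[t]];
    return blk;
}
```

Rebuilding a block this way costs \mathcal{O}(B\log B), dominated by the sort.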

Now,

  • Suppose we get the query [l, r].
    This range will fully enclose some of the blocks, and partially intersect at most two of them (at the ends).
    The partially intersecting parts can be brute-forced: we check at most 2B indices in total.
    As for a block that’s fully enclosed, recall that we’ve kept the indices in it sorted by their \text{prev}_i values.
    So, we’re looking for the sum of L_i of some prefix of this sorted list (since we only care about those indices with \text{prev}_i \lt l).
    Finding the appropriate prefix can be done with binary search in \mathcal{O}(\log B), and we do this once for at most \frac{N}{B} blocks.
  • Updates are simple to handle too: as we noted above, at most three indices will change values after an update, so at most three blocks need to be recomputed.
    Each block can be recomputed in \mathcal{O}(B\log B) time, since we perform a single sort and then compute prefix sums.
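Putting the two bullets together, here is a compact sketch of the whole block structure over a given \text{prev} array, assuming 0-indexed, inclusive query bounds and -1 for "no previous occurrence". PrevBlocks, recalc, setL and setPrev are hypothetical names; the layout (one global order array plus per-block prefix sums) mirrors the idea in the editorialist's code but is not copied from it. After a colour change, the caller feeds in the (at most three) new \text{prev} values via setPrev, which rebuilds the affected block.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical sqrt-decomposition over prv[] (0-indexed, -1 = none).
struct PrevBlocks {
    int n, B;
    std::vector<int> L, prv;
    std::vector<int> order;     // order[lo..hi) = block indices sorted by prv
    std::vector<int64_t> pref;  // pref[t] = sum of L over order[lo..t]

    PrevBlocks(std::vector<int> L_, std::vector<int> prv_, int B_)
        : n((int)L_.size()), B(B_), L(std::move(L_)), prv(std::move(prv_)),
          order(n), pref(n) {
        for (int b = 0; b * B < n; ++b) recalc(b);
    }

    // Re-sort block b and rebuild its prefix sums: O(B log B).
    void recalc(int b) {
        int lo = b * B, hi = std::min(n, lo + B);
        std::iota(order.begin() + lo, order.begin() + hi, lo);
        std::sort(order.begin() + lo, order.begin() + hi,
                  [&](int i, int j) { return prv[i] < prv[j]; });
        for (int t = lo; t < hi; ++t)
            pref[t] = L[order[t]] + (t > lo ? pref[t - 1] : 0);
    }

    // Sum of L_i over l <= i <= r with prv_i < l.
    int64_t query(int l, int r) const {
        assert(0 <= l && l <= r && r < n);
        int64_t sum = 0;
        int i = l;
        // Left partial block: brute force up to the next block boundary.
        for (; i <= r && i % B != 0; ++i)
            if (prv[i] < l) sum += L[i];
        // Full blocks [i, i + B) inside [l, r]: binary search on prv.
        for (; i + B - 1 <= r; i += B) {
            auto first = order.begin() + i, last = first + B;
            auto cut = std::partition_point(first, last,
                                            [&](int k) { return prv[k] < l; });
            if (cut != first) sum += pref[(cut - order.begin()) - 1];
        }
        // Right partial block.
        for (; i <= r; ++i)
            if (prv[i] < l) sum += L[i];
        return sum;
    }

    // Type 2: L_i := x only affects the block containing i.
    void setL(int i, int x) { L[i] = x; recalc(i / B); }

    // One changed prv value after a type-3 update.
    void setPrev(int i, int p) { prv[i] = p; recalc(i / B); }
};
```

B is left as a constructor parameter here; the text suggests roughly \sqrt N, or simply a hardcoded constant.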

So, we have a complexity of \mathcal{O}(\log N + B\log B) for updates (the \log N term is for the std::set operations), and \mathcal{O}(\frac{N}{B}\log B + B) for queries.

Choosing B = \sqrt N makes both parts have a complexity of \mathcal{O}(\sqrt N \log N), which is fast enough for us.
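The choice of B comes from balancing the two costs. A short derivation, ignoring the \mathcal{O}(\log N) set operations (they are dominated either way); the numeric example uses N = 10^5, the upper bound on N checked in the tester's code:

```latex
T_{\text{upd}}(B) = \mathcal{O}(B\log B), \qquad
T_{\text{qry}}(B) = \mathcal{O}\!\left(\frac{N}{B}\log B + B\right)

B\log B = \frac{N}{B}\log B \iff B^2 = N \iff B = \sqrt N

B = \sqrt N:\quad T_{\text{upd}} = T_{\text{qry}} = \mathcal{O}(\sqrt N\log N)
\qquad (\text{e.g. } N = 10^5 \Rightarrow \sqrt N \approx 316)
```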

It’s possible to perform updates in \mathcal{O}(B) time by using the fact that at most three indices change: re-sorting the entire block isn’t necessary, and we can instead do something like insertion sort to move just those three indices into place.
This allows us to choose B = \sqrt{N\log N}, which marginally improves the time complexity from \mathcal{O}(\sqrt N\log N) to \mathcal{O}(\sqrt{N\log N}) per operation, but it isn’t necessary to get AC.
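One way to do the \mathcal{O}(B) repair is sketched below, assuming a block's order array is sorted by \text{prev} except for a single index i whose \text{prev} just changed. fixOne is a hypothetical name. When several indices of the same block change, apply the \text{prev} changes one at a time and call it after each change, so the "sorted except for one element" precondition always holds.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper. order: one block's indices sorted by prv, except
// that prv[i] has just changed. pref[t] = sum of L over order[0..t].
// One insertion-sort pass moves i into place, then prefix sums are rebuilt.
void fixOne(std::vector<int>& order, std::vector<int64_t>& pref,
            const std::vector<int>& L, const std::vector<int>& prv, int i) {
    int p = 0;
    while (p < (int)order.size() && order[p] != i) ++p;  // locate i: O(B)
    assert(p < (int)order.size());
    // Shift left while the previous entry has a larger prv.
    while (p > 0 && prv[order[p - 1]] > prv[order[p]]) {
        std::swap(order[p - 1], order[p]);
        --p;
    }
    // Shift right while the next entry has a smaller prv.
    while (p + 1 < (int)order.size() && prv[order[p + 1]] < prv[order[p]]) {
        std::swap(order[p], order[p + 1]);
        ++p;
    }
    for (int t = 0; t < (int)order.size(); ++t)  // O(B) prefix sums
        pref[t] = L[order[t]] + (t > 0 ? pref[t - 1] : 0);
}
```

Every step is a linear scan, so an update costs \mathcal{O}(B) plus the \mathcal{O}(\log N) set work.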

In practice, you can just hardcode a reasonable enough value of B (say, something around 500) and be fine.

There also exist solutions that answer each query in \mathcal{O}(\log^2 N) or \mathcal{O}(\log ^3 N), for instance using 2D or persistent structures.
However, these will likely take more effort to code (unless you already have a template), and will probably also have a high constant factor.

TIME COMPLEXITY:

\mathcal{O}(N\log N + Q\sqrt N\log N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>

#pragma GCC optimize("O3")

using namespace std;

#define ll long long
#define fastio() ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)

#define all(a) a.begin(),a.end()
#define endl "\n"
#define sp " " 
#define pb push_back
#define mp make_pair
#define vecvec(type, name, n, m, value) vector<vector<type>> name(n + 1, vector<type> (m + 1, value))

void __print(int x) {cerr << x;}
void __print(long x) {cerr << x;}
void __print(long long x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(const string &x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}

template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto &i: x) cerr << (f++ ? "," : ""), __print(i); cerr << "}";}
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}
#ifndef ONLINE_JUDGE
#define debug(x...) {cerr << "[" << #x << "] = ["; _print(x);}
#define reach cerr << "reached" << endl
#else
#define debug(x...)
#define reach 
#endif

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int MOD = 1e9+7;
const int64_t inf = 0x3f3f3f3f, INF = 1e18, BIG_MOD = 489133282872437279;
/*--------------------------------------------------------------------------------------------------------------------------------------------------------------------------*/

// #define int int64_t

int ceil_div(int x, int y)
{
    return (x + y - 1)/y;
}

const int N = 2e5+5, B = 1500;

struct Fenwick          //one indexed
{
    vector<int64_t> bit;
    int n;

    Fenwick() {};
    void init(int n) {this->n = n + 1;bit.assign(n + 1, 0);}

    // mode 1
    int64_t sum(int idx) 
    {
        int64_t ret = 0;
        for (++idx; idx > 0; idx -= idx & -idx) ret += bit[idx];
        return ret;
    }

    int64_t sum(int l, int r)   {return sum(r) - sum(l - 1);}
    void add(int idx, int delta)    {for (++idx; idx < n; idx += idx & -idx) bit[idx] += delta;}
    
    // mode 2
    void range_add(int l, int r, int val) 
    {
        add(l, val), add(r + 1, -val);
    }

    int64_t point_query(int idx) 
    {
        int64_t ret = 0;
        for (++idx; idx > 0; idx -= idx & -idx) ret += bit[idx];
        return ret;
    }
};

int n, q;
int a[N], c[N];

int val[N];
set<int> ind[N];
int prv[N], nxt[N];
Fenwick fen[(N + B - 1)/B  + 10];

int32_t main()
{
    fastio();

    cin >> n >> q;

    for(int i = 1; i <= n; i ++)
    {
        cin >> a[i] >> c[i];
        ind[c[i]].insert(i);
    }

    fill(prv, prv + N, 0);
    fill(nxt, nxt + N, n + 1);

    for(int c = 1; c < N; c ++)  if(!ind[c].empty())
    {
        vector<int> v(all(ind[c]));

        for(int i = 0; i < v.size(); i ++)
        {
            if(i != 0)
                prv[v[i]] = v[i - 1];
            
            if(v[i] != v.back())
                nxt[v[i]] = v[i + 1];
        }
    }

    for(int b = 1; b <= ceil_div(n, B); b ++)
    {
        fen[b].init(n);
        
        int l = (b - 1) * B + 1, r = min(n, b * B);

        fill(val, val + N, 0);

        int64_t score = 0;
        for(int i = r; i >= 1; i --)
        {
            score -= val[c[i]];

            val[c[i]] = (l <= i ? a[i] : 0);
            score += val[c[i]];

            fen[b].range_add(i, i, score);
        }
    }

    int64_t last = 0;

    for(int i = 1; i <= q; i ++)
    {
        int t;
        cin >> t;

        if(t == 1)                                      //answer query from l to r
        {
            int64_t l, r;
            cin >> l >> r;
            l ^= last, r ^= last;

            int bl = ceil_div(l, B), br = ceil_div(r, B);

            int64_t ans = 0;

            if(bl == br)
            {
                for(int i = l; i <= r; i ++)
                    if(prv[i] < l)
                        ans += a[i];
            }
            else
            {
                for(int i = l; i <= bl * B; i ++)
                    if(prv[i] < l)
                        ans += a[i];
                    
                for(int i = (br - 1) * B + 1; i <= r; i ++)
                    if(prv[i] < l)
                        ans += a[i];

                for(int b = bl + 1; b <= br - 1; b ++)
                    ans += fen[b].point_query(l);
            }

            last = ans;
            cout << ans << endl;
        }
        else if(t == 2)                                 //change value of a[j] to y
        {
            int64_t j, y;
            cin >> j >> y;
            j ^= last, y ^= last;

            int b = ceil_div(j, B);
            fen[b].range_add(prv[j] + 1, j, y - a[j]);

            a[j] = y;
        }
        else if(t == 3)                                 //change color of a[j] to d
        {
            int64_t j, d;
            cin >> j >> d;
            j ^= last, d ^= last;

            ind[c[j]].erase(j);

            int b = ceil_div(j, B);
            fen[b].range_add(prv[j] + 1, n, -a[j]);

            if(nxt[j] != n + 1)
            {
                int b2 = ceil_div(nxt[j], B);
                fen[b2].range_add(prv[j] + 1, j, +a[nxt[j]]);
            }
            int _prv = prv[j], _nxt = nxt[j];
            nxt[_prv] = _nxt, prv[_nxt] = _prv;

            c[j] = d;
            ind[c[j]].insert(j);

            auto it = ind[c[j]].lower_bound(j);

            if(it != ind[c[j]].begin())
            {
                -- it;

                prv[j] = *it;
                nxt[prv[j]] = j;

                ++ it;
            }
            else
                prv[j] = 0;
            fen[b].range_add(prv[j] + 1, j, +a[j]);

            ++ it;
            if(it != ind[c[j]].end())
            {
                nxt[j] = *it;
                prv[nxt[j]] = j;

                int b2 = ceil_div(nxt[j], B);
                fen[b2].range_add(prv[j] + 1, j, -a[nxt[j]]);
            }
            else
                nxt[j] = n + 1;
        }
    }
}
Tester's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>
using namespace std;

// #define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <typename T>
struct fenwick {
    int n;
    vector<T> node;

    fenwick(int _n) : n(_n) {
        node.resize(n);
    }

    void add(int x, T v) {
        while (x < n) {
            node[x] += v;
            x |= (x + 1);
        }
    }

    T get(int x) {  // [0, x]
        T v = 0;
        while (x >= 0) {
            v += node[x];
            x = (x & (x + 1)) - 1;
        }
        return v;
    }

    T get(int x, int y) {  // [x, y]
        return (get(y) - (x ? get(x - 1) : 0));
    }
};

int main() {
    input_checker in;
    int n = in.readInt(1, 1e5);
    in.readSpace();
    int q = in.readInt(1, 1e5);
    in.readEoln();
    vector<int> len(n), col(n);
    for (int i = 0; i < n; i++) {
        len[i] = in.readInt(1, 1e4);
        in.readSpace();
        col[i] = in.readInt(1, n);
        in.readEoln();
        col[i]--;
    }
    vector<set<int>> at(n);
    for (int i = 0; i < n; i++) {
        at[col[i]].emplace(i);
    }
    for (int i = 0; i < n; i++) {
        at[i].emplace(-1);
        at[i].emplace(n);
    }
    const int B = 350;
    int C = (n + B - 1) / B;
    vector<fenwick<int>> f(C, fenwick<int>(n));
    map<int, int> mp;
    for (int i = n - 1; i >= 0; i--) {
        mp[col[i]] = i;
        if (i % B == 0) {
            int j = i / B;
            for (auto [x, y] : mp) {
                f[j].add(y, len[y]);
            }
        }
    }
    int last = 0;
    while (q--) {
        int op = in.readInt(1, 3);
        in.readSpace();
        int x = in.readInt(0, 2e9);
        in.readSpace();
        int y = in.readInt(0, 2e9);
        in.readEoln();
        x ^= last;
        y ^= last;
        if (op == 1) {
            assert(1 <= x && x <= y && y <= n);
            x--;
            y--;
            set<int> st;
            int ans = 0;
            for (int i = x; i <= y; i++) {
                if (i % B == 0) {
                    int j = i / B;
                    ans += f[j].get(i, y);
                    for (int k : st) {
                        int l = *at[k].lower_bound(i);
                        if (l <= y) {
                            ans -= len[l];
                        }
                    }
                    break;
                }
                if (st.count(col[i])) {
                    continue;
                }
                st.emplace(col[i]);
                ans += len[i];
            }
            cout << ans << '\n';
            last = ans;
        } else if (op == 2) {
            assert(1 <= x && x <= n);
            x--;
            assert(1 <= y && y <= 1e4);
            int l0 = *prev(at[col[x]].lower_bound(x));
            for (int i = l0 / B; i < C && i * B <= x; i++) {
                if (l0 < i * B) {
                    f[i].add(x, y - len[x]);
                }
            }
            len[x] = y;
        } else {
            assert(1 <= x && x <= n);
            x--;
            assert(1 <= y && y <= n);
            y--;
            int l0 = *prev(at[col[x]].lower_bound(x));
            int r0 = *at[col[x]].upper_bound(x);
            int l1 = *prev(at[y].lower_bound(x));
            int r1 = *at[y].upper_bound(x);
            for (int i = 0; i < C && i * B <= x; i++) {
                if (l0 < i * B) {
                    f[i].add(x, -len[x]);
                    if (r0 < n) {
                        f[i].add(r0, len[r0]);
                    }
                }
                if (l1 < i * B) {
                    f[i].add(x, len[x]);
                    if (r1 < n) {
                        f[i].add(r1, -len[r1]);
                    }
                }
            }
            at[col[x]].erase(x);
            col[x] = y;
            at[col[x]].emplace(x);
        }
    }
    in.readEof();
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    
    int n, q; cin >> n >> q;
    vector<int> a(n), b(n);
    for (int i = 0; i < n; ++i) cin >> a[i] >> b[i];

    const int B = 400;
    /**
     * Divide into blocks of size B
     * let prev[i] = index of largest j < i such that b[i] = b[j]
     * in each block, keep elements sorted by prev[i] and build prefix sums
     * updates:
     * - changing value just changes the prefix sum within a block
     * - changing color changes the ordering and prefix sums of at most three blocks
     * query:
     * - for a full block, some prefix sum (find using binary search)
     * - for a non-full block, brute
     */
    
    vector<int> jump(n, -1);
    vector<set<int>> who(n+1);
    for (int i = 0; i < n; ++i) {
        if (!who[b[i]].empty()) jump[i] = *who[b[i]].rbegin();
        who[b[i]].insert(i);
    }
    
    vector<int> block_order(n);
    vector<ll> pref(n);
    auto recalc = [&] (int block) {
        int lo = block*B, hi = min(n, block*B + B);
        for (int i = lo; i < hi; ++i) block_order[i] = i;
        sort(begin(block_order)+lo, begin(block_order)+hi, [&] (int i, int j) {
            return jump[i] < jump[j];
        });
        for (int i = lo; i < hi; ++i) {
            int u = block_order[i];
            pref[i] = a[u];
            if (i > lo) pref[i] += pref[i-1];
        }
    };
    for (int i = 0; i <= n/B; ++i) recalc(i);
    ll last = 0;
    while (q--) {
        int type; cin >> type;
        if (type == 1) {
            ll L, R; cin >> L >> R;
            L ^= last, R ^= last;
            --L, --R;
            
            last = 0;
            int low = L;
            while (L <= R) {
                if (L%B == 0) break;
                if (jump[L] < low) last += a[L];
                ++L;
            }
            while (L <= R) {
                if (R%B == B-1) break;
                if (jump[R] < low) last += a[R];
                --R;
            }
            if (L <= R) {
                for (int block = L/B; block <= R/B; ++block) {
                    int lo = block*B, hi = block*B + B;
                    auto till = lower_bound(begin(block_order)+lo, begin(block_order)+hi, low, [&] (int i, int x) {
                        return jump[i] < x;
                    }) - begin(block_order);
                    if (till > lo) last += pref[till-1];
                }
            }
            cout << last << '\n';
        }
        else {
            ll pos, val; cin >> pos >> val;
            pos ^= last, val ^= last;
            --pos;

            if (type == 2) {
                a[pos] = val;
                recalc(pos/B);
            }
            else {
                auto it = who[b[pos]].find(pos);
                int x = n, y = n;
                if (next(it) != end(who[b[pos]])) {
                    x = *next(it);
                    jump[x] = jump[pos];
                }
                who[b[pos]].erase(pos);
                
                who[val].insert(pos);
                b[pos] = val;
                auto it2 = who[val].find(pos);
                if (next(it2) != end(who[val])) {
                    y = *next(it2);
                    jump[y] = pos;
                }
                if (it2 == begin(who[val])) jump[pos] = -1;
                else jump[pos] = *prev(it2);
                
                if (x > y) swap(x, y);
                recalc(pos/B);
                if (x != n and pos/B < x/B) recalc(x/B);
                if (y != n and x/B < y/B) recalc(y/B);
            }
        }
    }
}

