PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: weaponzdautist
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Square root decomposition
PROBLEM:
You’re given two arrays L and C, both of length N.
Process the following, online:
- 1 l r: find the sum of L_i across all i between l and r, such that C_i hasn’t occurred to its left within this range.
- 2 i x: set L_i := x
- 3 i x: set C_i := x
EXPLANATION:
There are multiple approaches to solve this task. One of them, relatively easy to implement, is with square root decomposition.
First, let’s compute \text{prev}_i to be the largest index j \lt i such that C_j = C_i.
If no such j exists, we say \text{prev}_i = -1.
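As a minimal sketch of this definition (not taken from any of the solutions below), the array can be filled in one left-to-right pass by remembering the last position of each color. The names `compute_prev` and `last_seen` are hypothetical, indices are 0-based here, and colors are assumed to lie in [1, N].

```cpp
#include <cassert>
#include <vector>

// Returns prv[i] = largest j < i with c[j] == c[i], or -1 if no such j exists.
// Assumption: 0-indexed positions and every color lies in [1, n], n = c.size().
std::vector<int> compute_prev(const std::vector<int>& c) {
    int n = (int)c.size();
    std::vector<int> last_seen(n + 1, -1);  // last index at which each color appeared
    std::vector<int> prv(n, -1);
    for (int i = 0; i < n; ++i) {
        assert(1 <= c[i] && c[i] <= n);
        prv[i] = last_seen[c[i]];
        last_seen[c[i]] = i;
    }
    return prv;
}
```

For instance, the colors 1 2 1 1 2 give the values -1 -1 0 2 1.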
Now, note that:
- To answer a query 1 l r, we essentially want the sum of L_i across all l \leq i \leq r such that \text{prev}_i \lt l.
If \text{prev}_i \geq l, C_i has occurred before in this range.
- Updating L_i doesn’t change the array \text{prev}.
- Updating C_i to x changes at most three indices of \text{prev}.
Specifically, let j\gt i be the index such that \text{prev}_j = i, and k\gt i be the smallest index such that C_k = x.
Then, the \text{prev} values of only indices i, j, k will change (note that j and/or k may also not exist, which is fine).
Finding j and k can be done in \mathcal{O}(\log N) time by storing a sorted list of indices corresponding to each value of C_i, and then binary searching on this list.
For example, you can use std::set for this, since you also need quick insertion/deletion.
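The color update described above could be sketched roughly as follows. This is an illustration under assumptions rather than the code of any solution below: `ColorIndex`, `recolor` and `touched` are hypothetical names, indices are 0-based, and colors are assumed to lie in [1, N].

```cpp
#include <cassert>
#include <iterator>
#include <set>
#include <vector>

struct ColorIndex {
    int n;
    std::vector<int> color, prv;     // prv[i] = previous index with the same color, or -1
    std::vector<std::set<int>> pos;  // pos[v] = sorted indices currently having color v

    explicit ColorIndex(const std::vector<int>& c)
        : n((int)c.size()), color(c), prv(c.size(), -1), pos(c.size() + 1) {
        for (int i = 0; i < n; ++i) {
            assert(1 <= c[i] && c[i] <= n);
            if (!pos[c[i]].empty()) prv[i] = *pos[c[i]].rbegin();
            pos[c[i]].insert(i);
        }
    }

    // Sets color[i] = x and returns the indices whose prv value may have changed.
    std::vector<int> recolor(int i, int x) {
        std::vector<int> touched = {i};
        // j: next occurrence of the old color; it inherits prv[i].
        auto& old_set = pos[color[i]];
        auto j = old_set.upper_bound(i);
        if (j != old_set.end()) { prv[*j] = prv[i]; touched.push_back(*j); }
        old_set.erase(i);
        // k: next occurrence of the new color; its prv becomes i.
        auto& new_set = pos[x];
        auto k = new_set.upper_bound(i);
        if (k != new_set.end()) { prv[*k] = i; touched.push_back(*k); }
        // prv[i]: closest index to the left with color x, if there is one.
        prv[i] = (k == new_set.begin()) ? -1 : *std::prev(k);
        new_set.insert(i);
        color[i] = x;
        return touched;
    }
};
```

Each call performs a constant number of set operations, so it costs \mathcal{O}(\log N); the returned indices tell the caller which blocks need attention afterwards.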
Now, notice that we have a situation where updates are “fast” while queries are “slow”.
So, we can afford to make updates a bit slower if it allows for faster queries, which is where square-root decomposition is often helpful.
Let’s choose a constant B, and break the range [1, N] into blocks of size B.
For each block, we’ll store the list of indices corresponding to it, sorted in increasing order of their \text{prev}_i values.
Now,
- Suppose we get the query [l, r].
This range will fully enclose some of the blocks, and partially intersect at most two of them (at the ends).
The partially intersecting part can be brute-forced: we check at most 2B indices in total.
As for a block that’s fully enclosed, recall that we’ve kept the indices in it sorted by their \text{prev}_i values.
So, we’re looking for the sum of L_i of some prefix of this sorted list (since we only care about those indices with \text{prev}_i \lt l).
Finding the appropriate prefix can be done with binary search in \mathcal{O}(\log B), and we do this once for at most \frac{N}{B} blocks.
- Updates are simple to handle too: as we noted above, at most three indices will change values after an update, so at most three blocks need to be recomputed.
Each block can be recomputed in \mathcal{O}(B\log B) time, since we perform a single sort and then compute prefix sums.
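The block structure described above can be sketched as below. It is a simplified illustration, not the editorialist's exact code: the struct name `Blocks` and its members are hypothetical, indices are 0-based, and the \text{prev} array is assumed to be maintained elsewhere and passed in.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

struct Blocks {
    int n, B;
    std::vector<long long> val;   // val[i] plays the role of L_i
    std::vector<int> prv;         // prv[i] = previous index with the same color, or -1
    std::vector<int> order;       // within each block: indices sorted by prv
    std::vector<long long> pref;  // within each block: prefix sums of val along `order`

    Blocks(std::vector<long long> v, std::vector<int> p, int block_size)
        : n((int)v.size()), B(block_size), val(std::move(v)), prv(std::move(p)),
          order(n), pref(n) {
        assert(B >= 1 && (int)prv.size() == n);
        for (int b = 0; b * B < n; ++b) recalc(b);
    }

    // Re-sorts block b by prv and rebuilds its prefix sums in O(B log B).
    void recalc(int b) {
        int lo = b * B, hi = std::min(n, lo + B);
        for (int i = lo; i < hi; ++i) order[i] = i;
        std::sort(order.begin() + lo, order.begin() + hi,
                  [&](int x, int y) { return prv[x] < prv[y]; });
        long long run = 0;
        for (int i = lo; i < hi; ++i) { run += val[order[i]]; pref[i] = run; }
    }

    // Sum of val[i] over l <= i <= r with prv[i] < l (0-indexed, inclusive).
    long long query(int l, int r) const {
        long long ans = 0;
        int i = l;
        while (i <= r) {
            if (i % B == 0 && i + B - 1 <= r) {
                // Fully enclosed block: binary search for the prefix with prv < l.
                auto first = order.begin() + i, last = first + B;
                auto cut = std::partition_point(first, last,
                                                [&](int x) { return prv[x] < l; });
                if (cut != first) ans += pref[(cut - order.begin()) - 1];
                i += B;
            } else {
                // Partial block at either end: check the index directly.
                if (prv[i] < l) ans += val[i];
                ++i;
            }
        }
        return ans;
    }
};
```

After a point update, the caller would modify `val` or `prv` and call `recalc` on each affected block.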
So, we have a complexity of \mathcal{O}(B\log B + \log N) for updates, and \mathcal{O}(\frac{N}{B}\log B + B) for queries.
Choosing B = \sqrt N makes both parts have a complexity of \mathcal{O}(\sqrt N \log N), which is fast enough for us.
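To spell out the arithmetic behind this choice, here is a short check of the dominant terms with B = \sqrt N, using \log B = \frac{1}{2}\log N:

```latex
\begin{align*}
\text{update: } & B\log B = \sqrt{N}\cdot\tfrac{1}{2}\log N = \mathcal{O}(\sqrt{N}\log N) \\
\text{query: }  & \frac{N}{B}\log B + B = \sqrt{N}\cdot\tfrac{1}{2}\log N + \sqrt{N} = \mathcal{O}(\sqrt{N}\log N)
\end{align*}
```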
It’s possible to perform updates in \mathcal{O}(B) time by utilizing the fact that at most three indices change, so re-sorting the entire block isn’t necessary: instead, we can do something like insertion sort to fix those three indices alone.
This allows us to choose B = \sqrt{N\log N} to marginally improve the time complexity from \mathcal{O}(\sqrt N\log N) to \mathcal{O}(\sqrt{N\log N}) per query, but isn’t necessary to get AC.
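One way this linear-time repair might look is sketched below. None of the solutions in this editorial do this; `fix_block` and its parameters are hypothetical names. It assumes that exactly one index of the block has had its key (its \text{prev} value) changed since the block was last sorted, and would simply be called once per changed index.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// order[lo..hi) is sorted by key[] except possibly for the index `changed`,
// whose key was just modified. Slides that index into its correct position
// and rebuilds the block's prefix sums, all in O(hi - lo).
void fix_block(std::vector<int>& order, std::vector<long long>& pref,
               const std::vector<int>& key, const std::vector<long long>& val,
               int lo, int hi, int changed) {
    int p = lo;
    while (p < hi && order[p] != changed) ++p;
    assert(p < hi);  // `changed` must belong to this block
    // Move left while the left neighbour has a larger key.
    while (p > lo && key[order[p - 1]] > key[order[p]]) {
        std::swap(order[p - 1], order[p]);
        --p;
    }
    // Otherwise move right while the right neighbour has a smaller key.
    while (p + 1 < hi && key[order[p + 1]] < key[order[p]]) {
        std::swap(order[p + 1], order[p]);
        ++p;
    }
    long long run = 0;
    for (int i = lo; i < hi; ++i) { run += val[order[i]]; pref[i] = run; }
}
```

Only one of the two sliding loops does any work on a given call, since the element is out of place in at most one direction.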
In practice, you can just hardcode a reasonable enough value of B (say, something around 500) and be fine.
There also exist solutions that answer each query in \mathcal{O}(\log^2 N) or \mathcal{O}(\log ^3 N), for instance using 2D or persistent structures.
However, these will likely take a bit more effort to code (if you don’t already have a template), and will probably also have a high constant factor.
TIME COMPLEXITY:
\mathcal{O}(N\log N + Q\sqrt N\log N) per testcase.
CODE:
Author's code (C++)
#include<bits/stdc++.h>
#pragma GCC optimize("O3")
using namespace std;
#define ll long long
#define fastio() ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)
#define all(a) a.begin(),a.end()
#define endl "\n"
#define sp " "
#define pb push_back
#define mp make_pair
#define vecvec(type, name, n, m, value) vector<vector<type>> name(n + 1, vector<type> (m + 1, value))
void __print(int x) {cerr << x;}
void __print(long x) {cerr << x;}
void __print(long long x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(const string &x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto &i: x) cerr << (f++ ? "," : ""), __print(i); cerr << "}";}
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}
#ifndef ONLINE_JUDGE
#define debug(x...) {cerr << "[" << #x << "] = ["; _print(x);}
#define reach cerr << "reached" << endl
#else
#define debug(x...)
#define reach
#endif
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int MOD = 1e9+7;
const int64_t inf = 0x3f3f3f3f, INF = 1e18, BIG_MOD = 489133282872437279;
/*--------------------------------------------------------------------------------------------------------------------------------------------------------------------------*/
// #define int int64_t
int ceil_div(int x, int y)
{
return (x + y - 1)/y;
}
const int N = 2e5+5, B = 1500;
struct Fenwick //one indexed
{
vector<int64_t> bit;
int n;
Fenwick() {};
void init(int n) {this->n = n + 1;bit.assign(n + 1, 0);}
// mode 1
int64_t sum(int idx)
{
int64_t ret = 0;
for (++idx; idx > 0; idx -= idx & -idx) ret += bit[idx];
return ret;
}
int64_t sum(int l, int r) {return sum(r) - sum(l - 1);}
void add(int idx, int delta) {for (++idx; idx < n; idx += idx & -idx) bit[idx] += delta;}
// mode 2
void range_add(int l, int r, int val)
{
add(l, val), add(r + 1, -val);
}
int64_t point_query(int idx)
{
int64_t ret = 0;
for (++idx; idx > 0; idx -= idx & -idx) ret += bit[idx];
return ret;
}
};
int n, q;
int a[N], c[N];
int val[N];
set<int> ind[N];
int prv[N], nxt[N];
Fenwick fen[(N + B - 1)/B + 10];
int32_t main()
{
fastio();
cin >> n >> q;
for(int i = 1; i <= n; i ++)
{
cin >> a[i] >> c[i];
ind[c[i]].insert(i);
}
fill(prv, prv + N, 0);
fill(nxt, nxt + N, n + 1);
for(int c = 1; c < N; c ++) if(!ind[c].empty())
{
vector<int> v(all(ind[c]));
for(int i = 0; i < v.size(); i ++)
{
if(i != 0)
prv[v[i]] = v[i - 1];
if(v[i] != v.back())
nxt[v[i]] = v[i + 1];
}
}
for(int b = 1; b <= ceil_div(n, B); b ++)
{
fen[b].init(n);
int l = (b - 1) * B + 1, r = min(n, b * B);
fill(val, val + N, 0);
int64_t score = 0;
for(int i = r; i >= 1; i --)
{
score -= val[c[i]];
val[c[i]] = (l <= i ? a[i] : 0);
score += val[c[i]];
fen[b].range_add(i, i, score);
}
}
int64_t last = 0;
for(int i = 1; i <= q; i ++)
{
int t;
cin >> t;
if(t == 1) //answer query from l to r
{
int64_t l, r;
cin >> l >> r;
l ^= last, r ^= last;
int bl = ceil_div(l, B), br = ceil_div(r, B);
int64_t ans = 0;
if(bl == br)
{
for(int i = l; i <= r; i ++)
if(prv[i] < l)
ans += a[i];
}
else
{
for(int i = l; i <= bl * B; i ++)
if(prv[i] < l)
ans += a[i];
for(int i = (br - 1) * B + 1; i <= r; i ++)
if(prv[i] < l)
ans += a[i];
for(int b = bl + 1; b <= br - 1; b ++)
ans += fen[b].point_query(l);
}
last = ans;
cout << ans << endl;
}
else if(t == 2) //change value of a[j] to y
{
int64_t j, y;
cin >> j >> y;
j ^= last, y ^= last;
int b = ceil_div(j, B);
fen[b].range_add(prv[j] + 1, j, y - a[j]);
a[j] = y;
}
else if(t == 3) //change color of a[j] to d
{
int64_t j, d;
cin >> j >> d;
j ^= last, d ^= last;
ind[c[j]].erase(j);
int b = ceil_div(j, B);
fen[b].range_add(prv[j] + 1, n, -a[j]);
if(nxt[j] != n + 1)
{
int b2 = ceil_div(nxt[j], B);
fen[b2].range_add(prv[j] + 1, j, +a[nxt[j]]);
}
int _prv = prv[j], _nxt = nxt[j];
nxt[_prv] = _nxt, prv[_nxt] = _prv;
c[j] = d;
ind[c[j]].insert(j);
auto it = ind[c[j]].lower_bound(j);
if(it != ind[c[j]].begin())
{
-- it;
prv[j] = *it;
nxt[prv[j]] = j;
++ it;
}
else
prv[j] = 0;
fen[b].range_add(prv[j] + 1, j, +a[j]);
++ it;
if(it != ind[c[j]].end())
{
nxt[j] = *it;
prv[nxt[j]] = j;
int b2 = ceil_div(nxt[j], B);
fen[b2].range_add(prv[j] + 1, j, -a[nxt[j]]);
}
else
nxt[j] = n + 1;
}
}
}
Tester's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>
using namespace std;
// #define IGNORE_CR
struct input_checker {
string buffer;
int pos;
const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
const string number = "0123456789";
const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const string lower = "abcdefghijklmnopqrstuvwxyz";
input_checker() {
pos = 0;
while (true) {
int c = cin.get();
if (c == -1) {
break;
}
#ifdef IGNORE_CR
if (c == '\r') {
continue;
}
#endif
buffer.push_back((char) c);
}
}
string readOne() {
assert(pos < (int) buffer.size());
string res;
while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
assert(!isspace(buffer[pos]));
res += buffer[pos];
pos++;
}
return res;
}
string readString(int min_len, int max_len, const string& pattern = "") {
assert(min_len <= max_len);
string res = readOne();
assert(min_len <= (int) res.size());
assert((int) res.size() <= max_len);
for (int i = 0; i < (int) res.size(); i++) {
assert(pattern.empty() || pattern.find(res[i]) != string::npos);
}
return res;
}
int readInt(int min_val, int max_val) {
assert(min_val <= max_val);
int res = stoi(readOne());
assert(min_val <= res);
assert(res <= max_val);
return res;
}
long long readLong(long long min_val, long long max_val) {
assert(min_val <= max_val);
long long res = stoll(readOne());
assert(min_val <= res);
assert(res <= max_val);
return res;
}
vector<int> readInts(int size, int min_val, int max_val) {
assert(min_val <= max_val);
vector<int> res(size);
for (int i = 0; i < size; i++) {
res[i] = readInt(min_val, max_val);
if (i != size - 1) {
readSpace();
}
}
return res;
}
vector<long long> readLongs(int size, long long min_val, long long max_val) {
assert(min_val <= max_val);
vector<long long> res(size);
for (int i = 0; i < size; i++) {
res[i] = readLong(min_val, max_val);
if (i != size - 1) {
readSpace();
}
}
return res;
}
void readSpace() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == ' ');
pos++;
}
void readEoln() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == '\n');
pos++;
}
void readEof() {
assert((int) buffer.size() == pos);
}
};
template <typename T>
struct fenwick {
int n;
vector<T> node;
fenwick(int _n) : n(_n) {
node.resize(n);
}
void add(int x, T v) {
while (x < n) {
node[x] += v;
x |= (x + 1);
}
}
T get(int x) { // [0, x]
T v = 0;
while (x >= 0) {
v += node[x];
x = (x & (x + 1)) - 1;
}
return v;
}
T get(int x, int y) { // [x, y]
return (get(y) - (x ? get(x - 1) : 0));
}
};
int main() {
input_checker in;
int n = in.readInt(1, 1e5);
in.readSpace();
int q = in.readInt(1, 1e5);
in.readEoln();
vector<int> len(n), col(n);
for (int i = 0; i < n; i++) {
len[i] = in.readInt(1, 1e4);
in.readSpace();
col[i] = in.readInt(1, n);
in.readEoln();
col[i]--;
}
vector<set<int>> at(n);
for (int i = 0; i < n; i++) {
at[col[i]].emplace(i);
}
for (int i = 0; i < n; i++) {
at[i].emplace(-1);
at[i].emplace(n);
}
const int B = 350;
int C = (n + B - 1) / B;
vector<fenwick<int>> f(C, fenwick<int>(n));
map<int, int> mp;
for (int i = n - 1; i >= 0; i--) {
mp[col[i]] = i;
if (i % B == 0) {
int j = i / B;
for (auto [x, y] : mp) {
f[j].add(y, len[y]);
}
}
}
int last = 0;
while (q--) {
int op = in.readInt(1, 3);
in.readSpace();
int x = in.readInt(0, 2e9);
in.readSpace();
int y = in.readInt(0, 2e9);
in.readEoln();
x ^= last;
y ^= last;
if (op == 1) {
assert(1 <= x && x <= y && y <= n);
x--;
y--;
set<int> st;
int ans = 0;
for (int i = x; i <= y; i++) {
if (i % B == 0) {
int j = i / B;
ans += f[j].get(i, y);
for (int k : st) {
int l = *at[k].lower_bound(i);
if (l <= y) {
ans -= len[l];
}
}
break;
}
if (st.count(col[i])) {
continue;
}
st.emplace(col[i]);
ans += len[i];
}
cout << ans << '\n';
last = ans;
} else if (op == 2) {
assert(1 <= x && x <= n);
x--;
assert(1 <= y && y <= 1e4);
int l0 = *prev(at[col[x]].lower_bound(x));
for (int i = l0 / B; i < C && i * B <= x; i++) {
if (l0 < i * B) {
f[i].add(x, y - len[x]);
}
}
len[x] = y;
} else {
assert(1 <= x && x <= n);
x--;
assert(1 <= y && y <= n);
y--;
int l0 = *prev(at[col[x]].lower_bound(x));
int r0 = *at[col[x]].upper_bound(x);
int l1 = *prev(at[y].lower_bound(x));
int r1 = *at[y].upper_bound(x);
for (int i = 0; i < C && i * B <= x; i++) {
if (l0 < i * B) {
f[i].add(x, -len[x]);
if (r0 < n) {
f[i].add(r0, len[r0]);
}
}
if (l1 < i * B) {
f[i].add(x, len[x]);
if (r1 < n) {
f[i].add(r1, -len[r1]);
}
}
}
at[col[x]].erase(x);
col[x] = y;
at[col[x]].emplace(x);
}
}
in.readEof();
return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
int main()
{
ios::sync_with_stdio(false); cin.tie(0);
int n, q; cin >> n >> q;
vector<int> a(n), b(n);
for (int i = 0; i < n; ++i) cin >> a[i] >> b[i];
const int B = 400;
/**
* Divide into blocks of size B
* let prev[i] = index of largest j < i such that b[i] = b[j]
* in each block, keep elements sorted by prev[i] and build prefix sums
* updates:
* - changing value just changes the prefix sum within a block
* - changing color changes the ordering and prefix sum of at most three blocks
* query:
* - for a full block, some prefix sum (find using binary search)
* - for a non-full block, brute
*/
vector<int> jump(n, -1);
vector<set<int>> who(n+1);
for (int i = 0; i < n; ++i) {
if (!who[b[i]].empty()) jump[i] = *who[b[i]].rbegin();
who[b[i]].insert(i);
}
vector<int> block_order(n);
vector<ll> pref(n);
auto recalc = [&] (int block) {
int lo = block*B, hi = min(n, block*B + B);
for (int i = lo; i < hi; ++i) block_order[i] = i;
sort(begin(block_order)+lo, begin(block_order)+hi, [&] (int i, int j) {
return jump[i] < jump[j];
});
for (int i = lo; i < hi; ++i) {
int u = block_order[i];
pref[i] = a[u];
if (i > lo) pref[i] += pref[i-1];
}
};
for (int i = 0; i <= n/B; ++i) recalc(i);
ll last = 0;
while (q--) {
int type; cin >> type;
if (type == 1) {
ll L, R; cin >> L >> R;
L ^= last, R ^= last;
--L, --R;
last = 0;
int low = L;
while (L <= R) {
if (L%B == 0) break;
if (jump[L] < low) last += a[L];
++L;
}
while (L <= R) {
if (R%B == B-1) break;
if (jump[R] < low) last += a[R];
--R;
}
if (L <= R) {
for (int block = L/B; block <= R/B; ++block) {
int lo = block*B, hi = block*B + B;
auto till = lower_bound(begin(block_order)+lo, begin(block_order)+hi, low, [&] (int i, int x) {
return jump[i] < x;
}) - begin(block_order);
if (till > lo) last += pref[till-1];
}
}
cout << last << '\n';
}
else {
ll pos, val; cin >> pos >> val;
pos ^= last, val ^= last;
--pos;
if (type == 2) {
a[pos] = val;
recalc(pos/B);
}
else {
auto it = who[b[pos]].find(pos);
int x = n, y = n;
if (next(it) != end(who[b[pos]])) {
x = *next(it);
jump[x] = jump[pos];
}
who[b[pos]].erase(pos);
who[val].insert(pos);
b[pos] = val;
auto it2 = who[val].find(pos);
if (next(it2) != end(who[val])) {
y = *next(it2);
jump[y] = pos;
}
if (it2 == begin(who[val])) jump[pos] = -1;
else jump[pos] = *prev(it2);
if (x > y) swap(x, y);
recalc(pos/B);
if (x != n and pos/B < x/B) recalc(x/B);
if (y != n and x/B < y/B) recalc(y/B);
}
}
}
}