MASTER - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3

Setter: Ma Zihang
Tester: Manan Grover
Editorialist: Kanhaiya Mohan

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Segment Tree, Set

PROBLEM:

Given a sequence A of length N.
Define f(i, j) to be the number of distinct elements in A_i, A_{i+1}, …, A_j.
You need to process Q queries of the following type:

  • \texttt{1 x y} - Set A_x = y.
  • \texttt{2 k} - Print \sum_{i=1}^k\sum_{j=i}^k f(i,j).

EXPLANATION:

Subtask 1: No type 1 queries.

Considering that we do not need to change any element of the sequence, we can formulate a simple DP solution. Let us look at a particular element A_i. There can be two cases:

  • The first occurrence of A_i is at position i: This means that for 1 \leq j < i, A_j \neq A_i. Then, this element would contribute to the answers of all suffixes ending at position i. In other words, its contribution to all the queries of type \texttt{2 k}, (i \leq k \leq N) would be i.
  • The first occurrence of A_i is not at position i: This means that there exists a position j, (1 \leq j < i), such that A_j = A_i. Here, element A_i is distinct for all the subarrays starting after position j and ending at position i. Thus, the contribution of element A_i to all the queries of type \texttt{2 k}, (i \leq k \leq N) would be i-j.

We can thus precalculate the answer for each index by storing the last position of an element and the answer till the previous index. The complexity of this approach would be O(N).

Subtask 2: Original Constraints.

The contribution of a kind of number to the answer is the total number of subarrays minus the number of subarrays that do not contain that number.
Formally, if there are m occurrences of a number x in the array (of size N), denoted by p_1<p_2<...<p_m, the total contribution of x is given by: \frac{N(N+1)}{2} - \sum_{i=0}^{m}\frac{(p_{i+1}-p_i-1)(p_{i+1}-p_i)}{2} (p_0 = 0 and p_{m+1} = N+1). Here p_{i+1}-p_i-1 is the length g of the gap between two consecutive occurrences, and a gap of length g contains \frac{g(g+1)}{2} subarrays, none of which contains x.

  • We can break the total contribution of x into two parts i.e. \frac{N(N+1)}{2} and - \sum_{i=0}^{m}\frac{(p_{i+1}-p_i-1)(p_{i+1}-p_i)}{2}.
  • Also, observe that \frac{N(N+1)}{2} = 1+2+...+N.

The solution turns out to be: Maintain an array S such that the answer to a query \texttt{2 k} is the prefix sum S_1 + S_2 + … + S_k, and do the following for every kind of number:
Step 1: For every 1 \leq i \leq N, add i to S_i. This is for the first part of the contribution.
Step 2: For every 0 \leq i \leq m, add -1 to S_{p_i+1}, -2 to S_{p_i+2}, …, -(p_{i+1}-p_i-1) to S_{p_{i+1}-1}. This is for the second part of the contribution.

These steps can be done in O(Nlog(N)) using set and segment tree. See setter’s solution for implementation.

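Note that you do not need to support adding an arithmetic sequence to an interval: when steps 1 and 2 are performed together, every position k strictly between p_i and p_{i+1} receives k - (k - p_i) = p_i, and each occurrence position p_i itself receives p_i. Each gap therefore only needs a constant range-add.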

TIME COMPLEXITY:

The time complexity is O((N+Q)log(N)) per test case.

SOLUTION:

Setter's Solution
#include <iostream>
#include <set>
#include <string>

// Upper bound on the sequence length; sizes the node array of SegmentTree.
constexpr int N = 200000;

// Sum segment tree with range-add updates and lazy propagation.
//
// Nodes are stored in heap order: the root is node 1 and the children of
// node x are 2x and 2x+1. Call build(1, lo, hi) before any add/get; build
// resets every node it touches, so the tree may be rebuilt at any time and
// does not rely on static zero-initialisation.
struct SegmentTree {
    struct Node {
        int l, r;        // inclusive index range covered by this node
        long long sum;   // sum of the values in [l, r], own `add` already applied
        long long add;   // increment not yet pushed to the children
    };

    Node tree[4 * N + 1];

    // Recomputes the sum of x from its two children.
    void pushup(int x) {
        tree[x].sum = tree[2 * x].sum + tree[2 * x + 1].sum;
    }

    // Applies "+add on every position" to node x without descending.
    void tag(int x, long long add) {
        tree[x].sum += 1ll * (tree[x].r - tree[x].l + 1) * add;
        tree[x].add += add;
    }

    // Hands the pending increment of x over to its children.
    void pushdown(int x) {
        if (tree[x].add) {
            tag(2 * x, tree[x].add);
            tag(2 * x + 1, tree[x].add);
            tree[x].add = 0;
        }
    }

    // Builds (or rebuilds) the subtree rooted at x over [l, r] with all values 0.
    void build(int x, int l, int r) {
        tree[x].l = l;
        tree[x].r = r;
        // Explicit reset: stale sums/tags from an earlier use must not leak in.
        tree[x].sum = 0;
        tree[x].add = 0;

        if (l != r) {
            int mid = (l + r) / 2;

            build(2 * x, l, mid);
            build(2 * x + 1, mid + 1, r);
            pushup(x);
        }
    }

    // Adds v to every position in [l, r]; an empty range (l > r) is a no-op.
    // x is the node to start from (1 for the whole tree).
    void add(int x, int l, int r, int v) {
        if (l > r) {
            return;
        }

        if (tree[x].l >= l && tree[x].r <= r) {
            tag(x, v);
            return;
        }

        int mid = (tree[x].l + tree[x].r) / 2;
        pushdown(x);

        if (l <= mid) {
            add(2 * x, l, r, v);
        }

        if (r > mid) {
            add(2 * x + 1, l, r, v);
        }

        pushup(x);
    }

    // Returns the sum of the positions in [l, r]; an empty range yields 0.
    // x is the node to start from (1 for the whole tree).
    long long get(int x, int l, int r) {
        // Mirrors the guard in add(): without it an empty range would descend
        // below a leaf and read nodes that were never built.
        if (l > r) {
            return 0;
        }

        if (tree[x].l >= l && tree[x].r <= r) {
            return tree[x].sum;
        }

        int mid = (tree[x].l + tree[x].r) / 2;
        long long ans = 0;
        pushdown(x);

        if (l <= mid) {
            ans += get(2 * x, l, r);
        }

        if (r > mid) {
            ans += get(2 * x + 1, l, r);
        }

        return ans;
    }
};

// a[i] is the current value stored at position i (1-based); main() fills it
// from the input and overwrites a[x] on every type-1 query.
int a[N + 1];
// each[v] is the ordered set of positions currently holding value v. main()
// also inserts the sentinels 0 and n + 1 for every v in 1..n, which is what
// lets remove()/add() step to a predecessor and dereference a successor.
std::set<int> each[N + 1];
// Segment tree over positions 1..n; main() answers a type-2 query k by
// summing entries 1..k.
SegmentTree seg;

// Forward declarations of the two helpers that keep `each` and `seg` in sync
// when a single position changes its value.
void remove(int, int);
void add(int, int);

// Drops position `pos` from the occurrence set of value `x` and updates the
// segment tree accordingly: for every index from `pos` up to just before the
// next occurrence of `x`, the nearest occurrence at or before that index
// moves back from `pos` to the previous occurrence, so those entries shrink
// by (pos - previous).
// Relies on the 0 and n + 1 sentinels that main() stores in each[x].
void remove(int x, int pos) {
    std::set<int> &positions = each[x];
    positions.erase(pos);

    // With `pos` gone, lower_bound(pos) is the following occurrence and the
    // element right before it is the preceding one.
    const auto next_it = positions.lower_bound(pos);
    auto prev_it = next_it;
    --prev_it;

    const int delta = *prev_it - pos;
    seg.add(1, pos, *next_it - 1, delta);
}

// Registers position `pos` as a new occurrence of value `x` and updates the
// segment tree accordingly: for every index from `pos` up to just before the
// next occurrence of `x`, the nearest occurrence at or before that index
// advances from the previous occurrence to `pos`, so those entries grow by
// (pos - previous).
// Relies on the 0 and n + 1 sentinels that main() stores in each[x].
void add(int x, int pos) {
    std::set<int> &positions = each[x];

    // `pos` is not in the set yet, so lower_bound(pos) is the following
    // occurrence and the element right before it is the preceding one.
    const auto next_it = positions.lower_bound(pos);
    auto prev_it = next_it;
    --prev_it;

    const int delta = pos - *prev_it;
    seg.add(1, pos, *next_it - 1, delta);
    positions.insert(pos);
}

// Setter's driver: reads the sequence, seeds the segment tree so that entry k
// equals the sum over all values of the last position <= k holding that value
// (0 if there is none), then serves the queries. A type-2 query k prints the
// sum of entries 1..k.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
    std::cout.tie(0);

    int n, q;
    std::cin >> n >> q;

    for (int pos = 1; pos <= n; ++pos) {
        std::cin >> a[pos];
        each[a[pos]].insert(pos);
    }

    seg.build(1, 1, n);

    for (int value = 1; value <= n; ++value) {
        std::set<int> &positions = each[value];
        // Sentinels bracket the real occurrences so every gap has two ends.
        positions.insert(0);
        positions.insert(n + 1);

        auto prev_it = positions.begin();
        auto cur_it = prev_it;
        ++cur_it;

        for (; cur_it != positions.end(); ++cur_it, ++prev_it) {
            const int prev_pos = *prev_it;
            const int cur_pos = *cur_it;

            // Inside the gap the latest occurrence seen so far is prev_pos.
            seg.add(1, prev_pos + 1, cur_pos - 1, prev_pos);

            // At a real occurrence the latest occurrence is the position itself.
            if (cur_pos <= n) {
                seg.add(1, cur_pos, cur_pos, cur_pos);
            }
        }
    }

    for (int query = 0; query < q; ++query) {
        int op;
        std::cin >> op;

        if (op != 1) {
            int k;
            std::cin >> k;

            std::cout << seg.get(1, 1, k) << '\n';
            continue;
        }

        int x, v;
        std::cin >> x >> v;

        // Detach the old value first, then attach the new one.
        remove(a[x], x);
        add(v, x);

        a[x] = v;
    }

    return 0;
}
Tester's Solution
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Tester's shorthand macros. Most are unused by the code visible in this file.
// NOTE(review): the one- and two-letter names (I, A, B, N, fi, pi, ...)
// replace every matching identifier token that follows, so new code in this
// translation unit must avoid those names.

// Loop helpers: asc walks i over [a, n) upwards, dsc walks i from n-1 down to
// a; forw/bacw iterate a container forwards/backwards with iterator `it`.
#define asc(i,a,n) for(I i=a;i<n;i++)
#define dsc(i,a,n) for(I i=n-1;i>=a;i--)
#define forw(it,x) for(A it=(x).begin();it!=(x).end();it++)
#define bacw(it,x) for(A it=(x).rbegin();it!=(x).rend();it++)
// Member and algorithm name abbreviations.
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define lb(x) lower_bound(x)
#define ub(x) upper_bound(x)
#define fbo(x) find_by_order(x)
#define ook(x) order_of_key(x)
#define all(x) (x).begin(),(x).end()
#define sz(x) (I)((x).size())
#define clr(x) (x).clear()
// Single-letter aliases for built-in types.
#define U unsigned
#define I long long int
#define S string
#define C char
#define D long double
#define A auto
#define B bool
// Container aliases.
#define CM(x) complex<x>
#define V(x) vector<x>
#define P(x,y) pair<x,y>
#define OS(x) set<x>
#define US(x) unordered_set<x>
#define OMS(x) multiset<x>
#define UMS(x) unordered_multiset<x>
#define OM(x,y) map<x,y>
#define UM(x,y) unordered_map<x,y>
#define OMM(x,y) multimap<x,y>
#define UMM(x,y) unordered_multimap<x,y>
#define BS(x) bitset<x>
#define L(x) list<x>
#define Q(x) queue<x>
// Order-statistics trees from pb_ds. The comparator is always less<I>,
// whatever key type x is passed in.
#define PBS(x) tree<x,null_type,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define PBM(x,y) tree<x,y,less<I>,rb_tree_tag,tree_order_statistics_node_update>
// Constants. pi re-evaluates acos(-1) at every use. main() uses N - 1 as the
// upper bound for n and q and N as an array size.
#define pi (D)acos(-1)
#define md 1000000007
#define N 200001
// Expands to randGen(rng), so it needs objects with those names in scope.
#define rnd randGen(rng)
// Sum segment tree over 0-based positions with range add, range sum and
// point assignment.
//
// Storage is an implicit perfect binary tree in `segarr`: the root is node 0,
// the children of node x are 2x+1 and 2x+2, and the leaf for position p is
// node p + m - 1, where m is the number of leaves.
//
// Invariant: node.sum already contains every update recorded strictly below
// the node, but neither the node's own `lazy` nor that of its ancestors. The
// true sum of a node is therefore sum + (lazy of the node and all ancestors)
// * length.
class seg {
public:
    struct node {
        long long sum = 0;           // see the class invariant
        long long lazy = 0;          // pending increment for every position in [lft, rgt]
        long long lft = 0, rgt = 0;  // inclusive range covered by the node
    };

    long long m;                // number of leaves: smallest power of two >= max(n, 1)
    std::vector<node> segarr;   // 2 * m - 1 nodes

    // Combines two adjacent, already-pushed results into their union.
    node merge(node a, node b) {
        node ans;
        ans.lft = std::min(a.lft, b.lft);
        ans.rgt = std::max(a.rgt, b.rgt);
        ans.sum = a.sum + b.sum;
        ans.lazy = 0;
        return ans;
    }

    // Sets the payload of a leaf: value `a` when `f` is true, otherwise 0.
    void make(node &temp, long long a, bool f) {
        temp.sum = f ? a : 0;
        temp.lazy = 0;
    }

    // Builds the tree over arr[0..n-1]; padding leaves up to m hold 0.
    // arr is only read at indices below n.
    seg(long long n, long long arr[]) {
        // Integer doubling instead of pow/log2 keeps the leaf count exact.
        m = 1;
        while (m < n) {
            m *= 2;
        }
        segarr.assign(static_cast<std::size_t>(2 * m - 1), node());

        for (long long i = 0; i < m; i++) {
            node &leaf = segarr[i + m - 1];
            if (i < n) {
                make(leaf, arr[i], true);
            } else {
                make(leaf, 0, false);
            }
            leaf.lft = i;
            leaf.rgt = i;
        }
        for (long long i = m - 2; i >= 0; i--) {
            segarr[i] = merge(segarr[2 * i + 1], segarr[2 * i + 2]);
        }
    }

    // Folds the pending increment of x into its sum and forwards it to the
    // children (if any).
    void push(long long x) {
        node &cur = segarr[x];
        cur.sum += cur.lazy * (cur.rgt - cur.lft + 1);
        if (2 * x + 1 < static_cast<long long>(segarr.size())) {
            segarr[2 * x + 1].lazy += cur.lazy;
            segarr[2 * x + 2].lazy += cur.lazy;
        }
        cur.lazy = 0;
    }

    // Sum over [l, r] within the subtree of x. Requires [l, r] to intersect
    // the range of x; every node visited then also intersects [l, r], so a
    // leaf that is reached is always fully covered.
    node query_help(long long l, long long r, long long x) {
        push(x);
        if (segarr[x].lft >= l && segarr[x].rgt <= r) {
            return segarr[x];
        }
        if (l > segarr[2 * x + 1].rgt) {
            return query_help(l, r, 2 * x + 2);
        }
        if (r < segarr[2 * x + 2].lft) {
            return query_help(l, r, 2 * x + 1);
        }
        return merge(query_help(l, r, 2 * x + 1), query_help(l, r, 2 * x + 2));
    }

    // Returns a node whose `sum` is the sum over positions l..r. The range is
    // clipped to [0, m-1]; an empty range yields a node with sum 0.
    node query(long long l, long long r) {
        l = std::max(l, 0LL);
        r = std::min(r, m - 1);
        if (l > r) {
            return node();
        }
        return query_help(l, r, 0);
    }

    // Recomputes the sum of x and of all its ancestors from their children,
    // taking the children's pending increments into account.
    void update_help(long long x) {
        segarr[x].sum = pending_total(2 * x + 1) + pending_total(2 * x + 2);
        if (x != 0) {
            update_help((x - 1) / 2);
        }
    }

    // Assigns `temp` to position x. Positions outside [0, m-1] are ignored.
    void update(long long x, long long temp) {
        if (x < 0 || x >= m) {
            return;
        }
        long long y = x + m - 1;
        // Flush every pending increment on the root-to-leaf path first;
        // otherwise the assignment would be shifted by it later, or the
        // increment would be lost when the ancestors are recomputed.
        push_path(y);
        make(segarr[y], temp, true);
        if (y != 0) {
            update_help((y - 1) / 2);
        }
    }

    // Adds `temp` to every position of [l, r] within the subtree of x.
    // Requires [l, r] to intersect the range of x.
    void update_range_help(long long l, long long r, long long x, long long temp) {
        if (segarr[x].lft >= l && segarr[x].rgt <= r) {
            segarr[x].lazy += temp;
            return;
        }
        // Partially covered: account for the overlap here and descend.
        segarr[x].sum += temp * (std::min(r, segarr[x].rgt) - std::max(l, segarr[x].lft) + 1);
        if (l > segarr[2 * x + 1].rgt) {
            update_range_help(l, r, 2 * x + 2, temp);
            return;
        }
        if (r < segarr[2 * x + 2].lft) {
            update_range_help(l, r, 2 * x + 1, temp);
            return;
        }
        update_range_help(l, r, 2 * x + 1, temp);
        update_range_help(l, r, 2 * x + 2, temp);
    }

    // Adds `temp` to every position in [l, r]. The range is clipped to
    // [0, m-1]; an empty range is a no-op.
    void update_range(long long l, long long r, long long temp) {
        l = std::max(l, 0LL);
        r = std::min(r, m - 1);
        if (l > r) {
            return;
        }
        update_range_help(l, r, 0, temp);
    }

private:
    // Sum of node x including its own pending increment.
    long long pending_total(long long x) const {
        const node &cur = segarr[x];
        return cur.sum + cur.lazy * (cur.rgt - cur.lft + 1);
    }

    // Pushes pending increments along the path root -> x, top-down.
    void push_path(long long x) {
        if (x != 0) {
            push_path((x - 1) / 2);
        }
        push(x);
    }
};
// Strict integer reader for input validation. Consumes characters from stdin
// up to and including the terminator `endd` and returns the parsed value.
// Every format rule is enforced with assert:
//   - a '-' may only appear before the first digit;
//   - no leading zeros, and no "-0";
//   - at most 19 digits, and a 19-digit number must start with 0 or 1;
//   - at least one digit before `endd`;
//   - the result must lie in [l, r];
//   - any other character is rejected.
long long readInt(long long l, long long r, char endd) {
    long long value = 0;
    int digits = 0;         // digits consumed so far
    int lead = -1;          // first digit, -1 until one has been read
    bool negative = false;

    for (;;) {
        const char ch = getchar();

        if (ch == '-') {
            assert(lead == -1);
            negative = true;
        } else if (ch >= '0' && ch <= '9') {
            const int digit = ch - '0';
            value = value * 10 + digit;
            if (digits == 0) {
                lead = digit;
            }
            ++digits;

            assert(lead != 0 || digits == 1);
            assert(lead != 0 || negative == false);
            assert(!(digits > 19 || (digits == 19 && lead > 1)));
        } else if (ch == endd) {
            assert(digits > 0);
            const long long result = negative ? -value : value;
            assert(l <= result && result <= r);
            return result;
        } else {
            assert(false);
        }
    }
}
int main(){
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
uniform_int_distribution<I> randGen;
ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
I n,q;
n=readInt(1,N-1,' ');
q=readInt(1,N-1,'\n');
I a[n];
OS(I) mpp[N];
asc(i,0,N){
    mpp[i].insert(-1);
}
asc(i,0,n){
    if(i==n-1){
    a[i]=readInt(1,n,'\n');
    }else{
    a[i]=readInt(1,n,' ');
    }
}
I b[n];
b[0]=1;
mpp[a[0]].insert(0);
asc(i,1,n){
    b[i]=b[i-1];
    A it=mpp[a[i]].end();
    it--;
    b[i]+=i-(*it);
    mpp[a[i]].insert(i);
}
asc(i,0,N){
    mpp[i].insert(n);
}
seg s(n,b);
asc(i,0,q){
    I x,y;
    x=readInt(1,2,' ');
    if(x==2){
    y=readInt(1,n,'\n');
    }else{
    y=readInt(1,n,' ');
    }
    y--;
    if(x==2){
    cout<<s.query(0,y).sum<<"\n";
    }else{
    I z;
    z=readInt(1,n,'\n');
    if(a[y]==z){
        continue;
    }
    A it=mpp[a[y]].find(y);
    it--;
    I temp=(*it)-y;
    it++;
    it++;
    s.update_range(y,(*it)-1,temp);
    it--;
    mpp[a[y]].erase(it);
    it=mpp[z].lb(y);
    it--;
    temp=y-(*it);
    it++;
    s.update_range(y,(*it)-1,temp);
    mpp[z].insert(y);
    a[y]=z;
    }
}
return 0;
}
6 Likes