MNMXPRPAR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: yuvraj3004
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Segment trees

PROBLEM:

For an array A, define f(A) to be the sum of \min(S)\times \max(T) across all ways to partition the array into an ordered pair of non-empty sets (S, T).

You’re given an array A, along with Q point updates to it that replace a value at a given index.
Compute the value of f(A) before any update, as well as after each update.

EXPLANATION:

As a first step, we of course try to figure out what f(A) is for a fixed array A.

Let’s sort the array A, so that A_1 \leq A_2 \leq \ldots \leq A_N. This doesn’t change f(A), since f only depends on the multiset of values.
For a partition (S, T), note that \min(S) is now the element at the smallest index in S, while \max(T) is the element at the largest index in T.

Let’s fix i to be the smallest index in S, and compute the ‘contribution’ of A_i to the answer.
There are two possibilities now:

  1. Every index in T is smaller than i (this needs i \geq 2, so that T is non-empty). Since i is the smallest index in S, this forces S = \{i, i+1, \ldots, N\} and T = \{1, 2, \ldots, i-1\}, so \max(T) = A_{i-1}.
    There’s exactly one way for this to happen, and it has a value of A_i\cdot A_{i-1}.
  2. Some index in T is larger than i. Let’s fix j to be the largest index in T, so that \max(T) = A_j. Then, it can be observed that:
    • All indices \lt i must go into T (since i is the smallest index in S), and all indices \gt j must go into S (since j is the largest index in T).
    • Indices in [i+1, j-1] can go to either set: it doesn’t matter.
    • So, there are 2^{j-i-1} such partitions, each with value A_i\cdot A_j.
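The case analysis can be checked mechanically, in the same spirit as the commented-out brute force at the end of the tester's code. The sketch below is not part of either reference solution. The function names countPairs and caseAnalysisHolds are hypothetical, and positions are 0-indexed. It enumerates every ordered partition of n sorted positions. It then counts how often each pair occurs, where i is the smallest index in S and j is the largest index in T.

```cpp
#include <cassert>
#include <vector>

// cnt[i][j] = number of ordered partitions (S, T) of the sorted positions
// 0..n-1, both non-empty, whose smallest index in S is i and largest index in T is j.
std::vector<std::vector<long long>> countPairs(int n) {
    std::vector<std::vector<long long>> cnt(n, std::vector<long long>(n, 0));
    for (int mask = 1; mask < (1 << n) - 1; mask++) {  // bit b set => position b in S
        int i = -1, j = -1;
        for (int b = 0; b < n; b++) {
            if (mask >> b & 1) {
                if (i == -1) i = b;  // first position of S is its minimum
            } else {
                j = b;               // last position of T seen is its maximum
            }
        }
        cnt[i][j]++;
    }
    return cnt;
}

// Checks both cases: j = i - 1 occurs exactly once (case 1),
// j > i occurs 2^(j-i-1) times (case 2), and no other pair occurs.
bool caseAnalysisHolds(int maxN) {
    for (int n = 1; n <= maxN; n++) {
        std::vector<std::vector<long long>> cnt = countPairs(n);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++) {
                long long expected = 0;
                if (j > i) expected = 1LL << (j - i - 1);
                else if (j == i - 1) expected = 1;
                if (cnt[i][j] != expected) return false;
            }
    }
    return true;
}
```

Calling caseAnalysisHolds(12) covers all 2^{12} - 2 partitions for N = 12, which is fast.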

This means we quite simply have

f(A) = \sum_{i=1}^{N-1} A_i\cdot A_{i+1} + \sum_{i=1}^N \sum_{j=i+1}^N A_iA_j\cdot 2^{j-i-1}
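As a quick worked check of the formula, take N = 3 with sorted values A = (a, b, c) and use the exponent 2^{j-i-1} from case 2. Consider the six ordered partitions, with S = \{a\}, \{b\}, \{c\}, \{a,b\}, \{a,c\}, \{b,c\} respectively. They give \min(S)\cdot\max(T) = ac, bc, cb, ac, ab, ba, for a total of 2ab + 2bc + 2ac. The formula agrees:

```latex
\begin{aligned}
\sum_{i=1}^{2} A_iA_{i+1} &= ab + bc \\
\sum_{i=1}^{3}\sum_{j=i+1}^{3} A_iA_j\cdot 2^{j-i-1} &= ab\cdot 2^{0} + ac\cdot 2^{1} + bc\cdot 2^{0} = ab + 2ac + bc \\
f(A) &= 2ab + 2bc + 2ac
\end{aligned}
```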

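This can be computed in linear time (given that A is sorted) by separating the i and j terms of the nested summation, using A_iA_j\cdot 2^{j-i-1} = \frac12 (A_i\cdot 2^{-i})(A_j\cdot 2^{j}). All that’s needed is the prefix sum of A_i\cdot 2^{-i} (say in array \text{pref}) and the suffix sum of A_j\cdot 2^{j} (say in array \text{suf}). The reference solutions work modulo 998244353, where 2^{-i} means the i-th power of the modular inverse of 2.
With these arrays defined, the second summation becomes \frac12 \sum_{i=1}^{N-1} A_i\cdot 2^{-i}\cdot \text{suf}_{i+1}, or equivalently \frac12 \sum_{j=2}^{N} A_j\cdot 2^{j}\cdot \text{pref}_{j-1}.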
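Here is a sketch of the static computation. It assumes non-negative values and the modulus 998244353 used by the reference solutions. The names fStatic, fBrute and power are hypothetical. Instead of storing \text{pref}, it keeps a running prefix sum of A_i\cdot 2^{-i} and adds A_j\cdot 2^j\cdot \text{pref}_{j-1} as it goes. For verification, it is paired with an exponential brute force written straight from the definition of f.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

const long long MOD = 998244353;  // modulus used by the reference solutions

long long power(long long b, long long e) {  // b^e mod MOD
    long long r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// f(A) mod MOD for a fixed array of non-negative values: sort, then use
// A_i A_j 2^(j-i-1) = (1/2) (A_i 2^(-i)) (A_j 2^j) with a running prefix sum.
long long fStatic(std::vector<long long> a) {
    std::sort(a.begin(), a.end());
    const long long inv2 = power(2, MOD - 2);
    long long adjacent = 0;     // first summation: sum of A_j * A_{j+1}
    long long nested = 0;       // second summation, still missing the factor 1/2
    long long pref = 0;         // sum of A_i * 2^(-i) over i < j
    long long pw = 1, ipw = 1;  // 2^j and 2^(-j) mod MOD
    for (size_t j = 0; j < a.size(); j++) {
        long long aj = a[j] % MOD;
        if (j > 0) adjacent = (adjacent + a[j - 1] % MOD * aj) % MOD;
        nested = (nested + aj * pw % MOD * pref) % MOD;
        pref = (pref + aj * ipw) % MOD;
        pw = pw * 2 % MOD;
        ipw = ipw * inv2 % MOD;
    }
    return (adjacent + nested * inv2) % MOD;
}

// Exponential brute force straight from the definition, for small inputs.
long long fBrute(const std::vector<long long>& a) {
    int n = a.size();
    long long total = 0;
    for (int mask = 1; mask < (1 << n) - 1; mask++) {  // bit b set => A_b in S
        bool hasS = false, hasT = false;
        long long mn = 0, mx = 0;
        for (int b = 0; b < n; b++) {
            if (mask >> b & 1) {
                if (!hasS || a[b] < mn) mn = a[b];
                hasS = true;
            } else {
                if (!hasT || a[b] > mx) mx = a[b];
                hasT = true;
            }
        }
        total = (total + mn % MOD * (mx % MOD)) % MOD;
    }
    return total;
}
```

Because the sketch sorts its input, the whole call is O(N \log N). The loop after sorting is linear, as described above.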


Now, we need to process point updates to A.
At first glance, this seems pretty hard to do. A point update changes the sorted order of A, which in turn changes the index (and hence the power of 2) attached to many elements. So the \text{pref} and \text{suf} arrays that the answer is computed from can change almost entirely.

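One way to deal with this is to process updates offline.
That is, first read all the updates, at which point you know all N+Q values you’ll be dealing with. The updates are still applied one at a time, in the given order, and each one persists. Reading them in advance only tells us which values will ever appear.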

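Consider all these N+Q values in sorted order, with ties broken arbitrarily so that every value (initial or from an update) owns its own slot.
At any point in time, exactly N of them are “active”. Each update deactivates the slot holding the old value at that index and activates the slot of the new value.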

Observe that as long as we can handle deactivated elements properly, this setup is much nicer for us: we don’t have to deal with changing orders anymore!
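Here is a minimal sketch of this offline bookkeeping, written for illustration rather than taken from either reference solution. The names assignSlots, replay, slot and cur are hypothetical. Each of the N+Q values gets a fixed slot by sorting (value, id) pairs. The updates are then replayed in their original order. Because cur[p] remembers which slot currently holds A_p, the previous contents of the array are always known, even though all values were sorted up front. The author's code tracks the same information in its idk array. Reading the active slots from left to right always yields the current array in sorted order.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Value id k < N is the initial A[k]; id N + t is the value written by update t.
// slot[id] = position of that value among all N + Q values in sorted order
// (ties broken by id), so every value that ever appears owns a fixed slot.
std::vector<int> assignSlots(const std::vector<long long>& a,
                             const std::vector<std::pair<int, long long>>& upd) {
    int n = a.size(), q = upd.size();
    std::vector<std::pair<long long, int>> all;
    for (int k = 0; k < n; k++) all.push_back({a[k], k});
    for (int t = 0; t < q; t++) all.push_back({upd[t].second, n + t});
    std::sort(all.begin(), all.end());
    std::vector<int> slot(n + q);
    for (int s = 0; s < n + q; s++) slot[all[s].second] = s;
    return slot;
}

// Replays the updates (index p, new value v) in their original order and
// returns the active values read in slot order, initially and after each update.
std::vector<std::vector<long long>> replay(std::vector<long long> a,
                                          const std::vector<std::pair<int, long long>>& upd) {
    int n = a.size(), q = upd.size();
    std::vector<int> slot = assignSlots(a, upd);
    std::vector<long long> value(n + q, 0);
    std::vector<bool> active(n + q, false);
    std::vector<int> cur(n);  // cur[p] = slot currently holding A[p]
    for (int k = 0; k < n; k++) {
        cur[k] = slot[k];
        value[slot[k]] = a[k];
        active[slot[k]] = true;
    }
    auto snapshot = [&]() {
        std::vector<long long> out;
        for (int s = 0; s < n + q; s++)
            if (active[s]) out.push_back(value[s]);
        return out;
    };
    std::vector<std::vector<long long>> states{snapshot()};
    for (int t = 0; t < q; t++) {
        int p = upd[t].first;
        long long v = upd[t].second;
        active[cur[p]] = false;  // deactivate the old value of A[p]
        cur[p] = slot[n + t];    // this update's value has its own slot
        value[cur[p]] = v;
        active[cur[p]] = true;   // activate the new value
        states.push_back(snapshot());
    }
    return states;
}
```

For example, take A = (5, 1, 4) with updates (0, 2), (1, 7), (0, 6), using 0-indexed positions. The states are (1, 4, 5), (1, 2, 4), (2, 4, 7) and (4, 6, 7), each already sorted without any re-sorting.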

Activation/deactivation updates can now be processed by everyone’s favorite data structure: a segment tree!

In each node of the segment tree, maintain the following information:

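  • The answer for the node, that being the value of f(A) when restricted to the active elements in this range alone.
  • Enough information to be able to combine two nodes.
    • Looking at the expression for the answer, this means you’ll need to store the leftmost and rightmost active elements in the range, as well as the number of active elements.
    • You’ll also need the sum of A_i\cdot 2^{k_i} over the range, where k_i is the number of active elements in the range before the i-th one. You’ll also need the same sum with k_i counting the active elements after the i-th one. When merging a left node with a right node, the left node's "after" sum times the right node's "before" sum gives exactly the cross terms A_iA_j\cdot 2^{j-i-1}. The left node's rightmost element times the right node's leftmost element gives the one new adjacent term.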

You only need to keep this information and combine it carefully.
Deactivating an element can be done by setting its leaf to the identity, that is, a node representing an empty range.
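Here is a sketch of such a node and its merge, modelled on the op function in the tester's code but with different field names. The fields cnt, lo, hi, head, tail, res and the MinMaxTree type are hypothetical names chosen for this sketch. The structure is a plain bottom-up segment tree over the N+Q slots with point assignment and a root query. Values are assumed non-negative, so v % MOD is already a valid residue.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// Summary of the active elements inside one segment-tree range.
// A default-constructed Node is the identity (no active elements).
struct Node {
    int cnt = 0;         // number of active elements in the range
    long long lo = 0;    // leftmost (smallest) active value, mod MOD
    long long hi = 0;    // rightmost (largest) active value, mod MOD
    long long head = 0;  // sum of A * 2^(active elements before it in the range)
    long long tail = 0;  // sum of A * 2^(active elements after it in the range)
    long long res = 0;   // f restricted to the active elements of the range
};

struct MinMaxTree {
    int size = 1;
    std::vector<Node> t;
    std::vector<long long> pw;  // pw[k] = 2^k mod MOD

    explicit MinMaxTree(int slots) {
        while (size < slots) size *= 2;
        t.assign(2 * size, Node());
        pw.assign(slots + 1, 1);
        for (int k = 1; k <= slots; k++) pw[k] = pw[k - 1] * 2 % MOD;
    }

    Node combine(const Node& l, const Node& r) const {
        if (l.cnt == 0) return r;
        if (r.cnt == 0) return l;
        Node c;
        c.cnt = l.cnt + r.cnt;
        c.lo = l.lo;
        c.hi = r.hi;
        c.head = (l.head + pw[l.cnt] * r.head) % MOD;
        c.tail = (l.tail * pw[r.cnt] + r.tail) % MOD;
        // Cross pairs (x in l, y in r) get weight 2^(elements strictly between),
        // and the boundary neighbours l.hi, r.lo add one new adjacent product.
        c.res = (l.res + r.res + l.tail * r.head + l.hi * r.lo) % MOD;
        return c;
    }

    // Activate slot s with value v (on = true), or reset it to the identity.
    void set(int s, long long v, bool on) {
        Node leaf;
        if (on) {
            long long m = v % MOD;  // assumes v >= 0
            leaf.cnt = 1;
            leaf.lo = leaf.hi = leaf.head = leaf.tail = m;
        }
        int p = s + size;
        t[p] = leaf;
        for (p /= 2; p >= 1; p /= 2) t[p] = combine(t[2 * p], t[2 * p + 1]);
    }

    long long answer() const { return t[1].res; }  // f of all active values
};
```

With slots assigned offline, the initial array takes N calls to set(slot, value, true). Each update is then a call to set(oldSlot, 0, false) followed by set(newSlot, v, true), after which answer() gives f(A). Each call touches O(\log(N+Q)) nodes.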

TIME COMPLEXITY:

\mathcal{O}((N + Q)\log (N+Q)) per testcase.

CODE:

Author's code (C++)
#include "bits/stdc++.h"
using namespace std;
 
#define ll long long
#define vi vector<ll>
#define all(a) a.begin(), a.end()
#define fi first
#define se second
 
 
const ll MOD = 998244353;
struct mint {
    int v; explicit operator int() const { return v;} 
    mint():v(0) {}
    mint(ll _v):v(int(_v%MOD)) { v += (v<0)*MOD; }
};
mint& operator+=(mint& a, mint b) { 
    if ((a.v += b.v) >= MOD) a.v -= MOD; 
    return a; }
mint& operator-=(mint& a, mint b) { 
    if ((a.v -= b.v) < 0) a.v += MOD; 
    return a; }
mint operator+(mint a, mint b) { return a += b; }
mint operator-(mint a, mint b) { return a -= b; }
mint operator*(mint a, mint b) { return mint((ll)a.v*b.v); }
mint& operator*=(mint& a, mint b) { return a = a*b; }
mint pow(mint a, ll p) { assert(p >= 0);
    return p==0?1:pow(a*a,p/2)*(p&1?a:1); }
mint inv(mint a) { assert(a.v != 0); return pow(a,MOD-2); }
mint operator/(mint a, mint b) { return a*inv(b); }
bool operator==(mint a, mint b) { return a.v == b.v; }
using vm = vector<mint>;
 
 
struct node{
    mint s1, s2, m1, m2, lazy1, lazy2;
    node(){
        s1 = 0, s2 = 0;
        m1 = 1, m2 = 1;
        lazy1 = 1, lazy2 = 1;
    }
};
 
struct segtree{
    
    ll n;
    vector<node> nodes;
 
    node comb(node a, node b) {
        node c;
        c.s1 = a.s1 + b.s1;
        c.s2 = a.s2 + b.s2;
        return c;
    } 
 
    segtree(ll n){
        this->n =  n;
        nodes.resize(4*n);
    }
 
    void push(ll k, ll l, ll r){
        ll tm = (l+r)/2;
        for(int i=0; i<2; i++){
            nodes[2*k+i].s1 *= nodes[k].lazy1;
            nodes[2*k+i].m1 *= nodes[k].lazy1;
            nodes[2*k+i].lazy1 *= nodes[k].lazy1;
            nodes[2*k+i].s2 *= nodes[k].lazy2;
            nodes[2*k+i].m2 *= nodes[k].lazy2;
            nodes[2*k+i].lazy2 *= nodes[k].lazy2;
        }
        nodes[k].lazy1 = nodes[k].lazy2 = 1;
    }
 
    node query(ll k, ll tl, ll tr, ll l, ll r){
        if (l > r)
            return node();
        if (l == tl && tr == r)
            return nodes[k];
        push(k, tl, tr);
        ll tm = (tl + tr) / 2;
        return comb(query(k*2, tl, tm, l, min(r, tm)), query(k*2+1, tm+1, tr, max(l, tm+1), r));
    }
 
    node q(ll l, ll r){
        return query(1, 0, n-1, l, r);
    }
 
    void update(ll k, ll tl, ll tr, ll l, ll r, ll val1, ll val2){
        if (l > r) 
            return;
        if (l == tl && tr == r) {
            nodes[k].s1 *= val1;
            nodes[k].m1 *= val1;
            nodes[k].lazy1 *= val1;
            nodes[k].s2 *= val2;
            nodes[k].m2 *= val2;
            nodes[k].lazy2 *= val2;
        }
        else{
            push(k, tl, tr);
            ll tm = (tl + tr) / 2;
            update(k*2, tl, tm, l, min(r, tm), val1, val2);
            update(k*2+1, tm+1, tr, max(l, tm+1), r, val1, val2);
            nodes[k].s1 = nodes[k*2].s1 + nodes[k*2+1].s1;
            nodes[k].s2 = nodes[k*2].s2 + nodes[k*2+1].s2;
        }
    }
 
    void u(ll l, ll r, ll val1, ll val2){
        update(1, 0, n-1, l, r, val1, val2);
    }
 
    void update2(ll k, ll tl, ll tr, ll x, ll val){
        if(tl==tr){
            nodes[k].s1 += val * nodes[k].m1;
            nodes[k].s2 += val * nodes[k].m2;
        }
        else{
            push(k, tl, tr);
            ll tm = (tl + tr) / 2;
            if(x<=tm)   update2(k*2, tl, tm, x, val);
            else    update2(k*2+1, tm+1, tr, x, val);
            nodes[k].s1 = nodes[k*2].s1 + nodes[k*2+1].s1;
            nodes[k].s2 = nodes[k*2].s2 + nodes[k*2+1].s2;
        }
    }
 
    void u2(ll x, ll val){
        update2(1, 0, n-1, x, val);
    }
 
};
 
 
signed main(){
 
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
 
    ll t;
    cin>>t;
 
    ll i2 = inv(2).v;
 
    while(t--){
 
        ll n, q;
        cin>>n>>q;
 
        vector<pair<ll, ll>> z(n+q+1);
        multiset<ll> s;
 
        vi v(n+1);
        for(int i=1; i<=n; i++){
            cin>>v[i];
            z[i] = {v[i], i};
        }
 
        vector<array<ll, 2>> queries(q+1);
        for(int i=1; i<=q; i++){
            for(int j=0; j<2; j++){
                cin>>queries[i][j];
            }
            z[n+i] = {queries[i][1], i+n};
        }
 
        sort(all(z));
 
        vi idx(n+q+1);
        for(int i=1; i<=n+q; i++){
            idx[z[i].se] = i;
        }
 
        mint ans = 0, pr = 0;
        segtree seg(n+q+1);
 
        auto contribution = [&](ll pos){
            mint res = 0;
            node n1 = seg.q(1, pos-1);
            node n2 = seg.q(pos, pos);
            node n3 = seg.q(pos+1, n+q);
            res += n2.s1 * n1.s2 + n2.s2 * n3.s1 + (n1.s2 * n3.s1)/2;
            return res/2;
        };
 
        auto add = [&](ll pos, ll val){
            seg.u(pos+1, n+q, 2, i2);
            seg.u2(pos, val);
            ans += contribution(pos);
            auto it = s.upper_bound(val);
            if(it != s.end()){
                pr += val * (*it);
            }
            if(it != s.begin()){
                auto it2 = it;
                it2--;
                pr += val * (*it2);
                if(it != s.end()){
                    pr -= (*it) * (*it2);
                }
            }
            s.insert(val);
        };
 
        auto remove = [&](ll pos, ll val){
            ans -= contribution(pos);
            seg.u(pos+1, n+q, i2, 2);
            seg.u2(pos, val);
            val = -val;
            auto it = s.upper_bound(val);
            auto it2 = it;
            it2--;
            if(it != s.end()){
                ans -= (*it) * val;
            }
            if(it2 != s.begin()){
                it2--;
                ans -= val * (*it2);
                if(it != s.end()){
                    ans += (*it) * (*it2);
                }
            }
            s.erase(s.find(val));
        };
 
        vector<pair<ll, ll>> idk(n+1);
        for(int i=1; i<=n; i++){
            add(idx[i], v[i]);
            idk[i] = {v[i], idx[i]};
        }
 
        cout<<(ans+pr).v<<"\n";
        for(int i=1; i<=q; i++){
            ll pos = queries[i][0];
            remove(idk[pos].se, -idk[pos].fi);
            add(idx[n+i], queries[i][1]);
            idk[pos] = {queries[i][1], idx[n+i]};
            cout<<(ans+pr).v<<"\n";
        }
 
    }
 
    return 0;
}

Tester's code (C++)
#include<bits/stdc++.h>

#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>


#ifdef _MSC_VER
#include <intrin.h>
#endif

#if __cplusplus >= 202002L
#include <bit>
#endif

namespace atcoder {

namespace internal {

#if __cplusplus >= 202002L

using std::bit_ceil;

#else

unsigned int bit_ceil(unsigned int n) {
    unsigned int x = 1;
    while (x < (unsigned int)(n)) x *= 2;
    return x;
}

#endif

int countr_zero(unsigned int n) {
#ifdef _MSC_VER
    unsigned long index;
    _BitScanForward(&index, n);
    return index;
#else
    return __builtin_ctz(n);
#endif
}

constexpr int countr_zero_constexpr(unsigned int n) {
    int x = 0;
    while (!(n & (1 << x))) x++;
    return x;
}

}  // namespace internal

}  // namespace atcoder


namespace atcoder {

#if __cplusplus >= 201703L

template <class S, auto op, auto e> struct segtree {
    static_assert(std::is_convertible_v<decltype(op), std::function<S(S, S)>>,
                  "op must work as S(S, S)");
    static_assert(std::is_convertible_v<decltype(e), std::function<S()>>,
                  "e must work as S()");

#else

template <class S, S (*op)(S, S), S (*e)()> struct segtree {

#endif

  public:
    segtree() : segtree(0) {}
    explicit segtree(int n) : segtree(std::vector<S>(n, e())) {}
    explicit segtree(const std::vector<S>& v) : _n(int(v.size())) {
        size = (int)internal::bit_ceil((unsigned int)(_n));
        log = internal::countr_zero((unsigned int)size);
        d = std::vector<S>(2 * size, e());
        for (int i = 0; i < _n; i++) d[size + i] = v[i];
        for (int i = size - 1; i >= 1; i--) {
            update(i);
        }
    }

    void set(int p, S x) {
        assert(0 <= p && p < _n);
        p += size;
        d[p] = x;
        for (int i = 1; i <= log; i++) update(p >> i);
    }

    S get(int p) const {
        assert(0 <= p && p < _n);
        return d[p + size];
    }

    S prod(int l, int r) const {
        assert(0 <= l && l <= r && r <= _n);
        S sml = e(), smr = e();
        l += size;
        r += size;

        while (l < r) {
            if (l & 1) sml = op(sml, d[l++]);
            if (r & 1) smr = op(d[--r], smr);
            l >>= 1;
            r >>= 1;
        }
        return op(sml, smr);
    }

    S all_prod() const { return d[1]; }

    template <bool (*f)(S)> int max_right(int l) const {
        return max_right(l, [](S x) { return f(x); });
    }
    template <class F> int max_right(int l, F f) const {
        assert(0 <= l && l <= _n);
        assert(f(e()));
        if (l == _n) return _n;
        l += size;
        S sm = e();
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(op(sm, d[l]))) {
                while (l < size) {
                    l = (2 * l);
                    if (f(op(sm, d[l]))) {
                        sm = op(sm, d[l]);
                        l++;
                    }
                }
                return l - size;
            }
            sm = op(sm, d[l]);
            l++;
        } while ((l & -l) != l);
        return _n;
    }

    template <bool (*f)(S)> int min_left(int r) const {
        return min_left(r, [](S x) { return f(x); });
    }
    template <class F> int min_left(int r, F f) const {
        assert(0 <= r && r <= _n);
        assert(f(e()));
        if (r == 0) return 0;
        r += size;
        S sm = e();
        do {
            r--;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(op(d[r], sm))) {
                while (r < size) {
                    r = (2 * r + 1);
                    if (f(op(d[r], sm))) {
                        sm = op(d[r], sm);
                        r--;
                    }
                }
                return r + 1 - size;
            }
            sm = op(d[r], sm);
        } while ((r & -r) != r);
        return 0;
    }

  private:
    int _n, size, log;
    std::vector<S> d;

    void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
};

}  // namespace atcoder


using namespace std;
using namespace atcoder;

#define mod 998244353

using pl=pair<long long,long long>;

#define SZ 1048576
long long p2[SZ];

typedef struct{
  long long hd;
  long long tl;
  long long hc;
  long long tc;
  long long len;
  long long res;
}S;

S op(S l,S r){
  S res;
  res.hd=(l.hd+p2[l.len]*r.hd)%mod;
  res.tl=(l.tl*p2[r.len]+r.tl)%mod;

  if(l.hc==-1){res.hc=r.hc;}
  else{res.hc=l.hc;}
  if(r.tc==-1){res.tc=l.tc;}
  else{res.tc=r.tc;}

  res.len=(l.len+r.len);
  res.res=(l.res+r.res)%mod;
  res.res+=(l.tl*r.hd); res.res%=mod;
  if(l.tc!=-1 && r.hc!=-1){
    res.res+=(l.tc*r.hc); res.res%=mod;
  }
  return res;
}
S e(){
  return {0,0,-1,-1,0,0};
}
S cvrt(long long x){
  return {x,x,x,x,1,0};
}

int main(){
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  p2[0]=1;
  for(long long i=1;i<SZ;i++){ p2[i]=(p2[i-1]*2)%mod; }

  int t;
  cin >> t;
  while(t>0){
    t--;
    long long n,q;
    cin >> n >> q;

    vector<long long> a(n);
    for(auto &nx : a){cin >> nx;}
    vector<long long> sv=a;

    vector<pl> ql(q);
    for(auto &nx : ql){
      cin >> nx.first >> nx.second;
      nx.first--;
      sv.push_back(nx.second);
    }
    sort(sv.begin(),sv.end());

    long long l=sv.size();
    set<pl> s0,s1;
    for(long long i=0;i<l;i++){
      s0.insert({sv[i],i});
    }
    segtree<S,op,e> seg(l);
    for(auto &nx : a){
      auto it=s0.lower_bound({nx,-1});
      s1.insert(*it);
      seg.set((*it).second,cvrt(nx));
      s0.erase(it);
    }
    cout << seg.all_prod().res << "\n";
    for(auto &nx : ql){
      long long i=nx.first;
      long long v=nx.second;

      // erase a[i]
      {
        auto it=s1.lower_bound({a[i],-1});
        s0.insert(*it);
        seg.set((*it).second,e());
        s1.erase(it);
      }

      a[i]=v;

      // add v
      {
        auto it=s0.lower_bound({v,-1});
        s1.insert(*it);
        seg.set((*it).second,cvrt(v));
        s0.erase(it);
      }
      
      cout << seg.all_prod().res << "\n";
    }
  }
  return 0;
}

// #include<bits/stdc++.h>
//
// using namespace std;
//
// int main(){
//   for(int n=1;n<=8;n++){
//     vector<vector<int>> bk(n,vector<int>(n,0));
//     for(int i=1;i<(1<<n)-1;i++){
//       int x=1e9,y=-1e9;
//       for(int j=0;j<n;j++){
//         if(i&(1<<j)){x=min(x,j);}
//         else{y=max(y,j);}
//       }
//       if(x>y){swap(x,y);}
//       bk[x][y]++;
//     }
//
//     for(auto &nx : bk){
//       for(auto &ny : nx){cout << ny << " ";}
//       cout << "\n";
//     }cout << "\n";
//   }
//   return 0;
// }


Can someone explain the offline part of handling the queries, where we activate and deactivate elements? I am unable to follow the approach.

If the changes persist (meaning the array does not go back to its original state after each update), how can we do this using offline queries?
Sorting the queries might reorder them, and then, before activating or deactivating any element, we would not know the previous contents of the array.