MINOP343 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Stacks, prefix sums

PROBLEM:

For an array B and integer K, define \text{score}(B, K) to be the minimum number of the following operations required to make every element of B zero:

  • Choose a subsequence of B such that adjacent chosen positions are at least K apart, and subtract 1 from each chosen element.

Given an array A of length N and an integer K, find the sum of \text{score}(A[L\ldots R], K) across all subarrays A[L\ldots R] of A, modulo 998244353.

EXPLANATION:

Our first order of business should be to figure out what exactly \text{score}(B, K) is, for a fixed B and K.

Answer

If B has length \leq K, the answer is simply the sum of B. Any two positions of B are less than K apart, so each move can subtract 1 from only a single element.

When |B|\gt K, we can use the above observation to obtain, at the very least, a lower bound on the answer.
Consider some subarray of B of length K, say [B_i, B_{i+1}, \ldots, B_{i+K-1}].
Any move we make can reduce the sum of this subarray by at most 1, since we can choose at most one element from it.
So, we definitely need at least B_i + B_{i+1} + \ldots + B_{i+K-1} moves.

In particular, let M denote the maximum sum of some subarray of length K.
Then, we definitely need \geq M moves to make all of B into zeros.
As it turns out, M moves is also sufficient!

Proof

First, if M = 0 the claim is trivially true, so we work with M\gt 0.

We’ll prove the following.
Claim: There exists an operation such that every window with sum M has an element chosen from it.
Note that this ensures the maximum reduces by 1, so repeating it M times will make everything zero, as required.

Proof:
Let L_1 denote the left endpoint of the leftmost K-subarray with sum M.
Let i_1 denote the index of the rightmost non-zero element of the range [L_1, L_1+K-1].
We reduce B_{i_1} by one.
This automatically takes care of any K-subarray that includes i_1; so let’s look at ones that start after it.

Let L_2 denote the left endpoint of the leftmost K-subarray with sum M such that L_2 \gt i_1 (if one exists).
If L_2 \geq i_1 + K, then choosing i_1 doesn’t restrict us in the subarray starting at L_2, so simply apply the same argument to the suffix starting from L_2 instead.
This leaves the case where L_2 \lt i_1 + K, meaning there’s some “overlap”.

Let d = L_2 - i_1 be the distance between them.
Note that we’re able to choose only from the last d elements of [L_2, L_2+K-1]: anything else will be too close to i_1.
Now, observe that since [L_2, L_2+K-1] is a subarray with maximum sum, its last d elements must have a sum that's \geq the sum B_{i_1} + B_{i_1 + 1} + \ldots + B_{L_2 - 1}. If not, replacing those last d elements with these d elements would give the K-subarray [i_1, i_1+K-1] with sum strictly larger than M, a contradiction.
Since B_{i_1}\gt 0, the latter sum is positive, so there exists a non-zero element among the last d elements of [L_2, L_2+K-1] (say at index i_2); choose any such element to include in the operation.
Now, replace L_1 and i_1 with L_2 and i_2, and repeat the process till every subarray with sum M is covered.
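To sanity-check the claim on small inputs, here is a minimal brute-force sketch (not taken from any of the reference solutions). `score_bfs` runs a breadth-first search over array states, where each edge is one operation, and `score_formula` is the closed form derived above: the sum of B when |B| \leq K, otherwise the maximum K-window sum. The function names are hypothetical, and the search assumes an element is never decreased below zero, since a negative element could never return to zero.

```python
from collections import deque
from itertools import combinations, product


def score_formula(B, K):
    """Closed form: sum(B) if |B| <= K, else the maximum sum of a length-K window."""
    if len(B) <= K:
        return sum(B)
    return max(sum(B[i:i + K]) for i in range(len(B) - K + 1))


def moves(state, K):
    """Yield every valid operation: a non-empty set of positive positions
    whose consecutive chosen indices are at least K apart."""
    pos = [i for i, x in enumerate(state) if x > 0]
    for r in range(1, len(pos) + 1):
        for chosen in combinations(pos, r):  # tuples come out in increasing order
            if all(chosen[j + 1] - chosen[j] >= K for j in range(r - 1)):
                yield chosen


def score_bfs(B, K):
    """Minimum number of operations needed to make B all zeros, by BFS over states."""
    start = tuple(B)
    dist = {start: 0}
    queue = deque([start])
    while queue:
        s = queue.popleft()
        if not any(s):
            return dist[s]
        for chosen in moves(s, K):
            t = list(s)
            for i in chosen:
                t[i] -= 1
            t = tuple(t)
            if t not in dist:
                dist[t] = dist[s] + 1
                queue.append(t)


# Exhaustively compare both on all arrays of length <= 4 with values in {0, 1, 2}.
for n in range(1, 5):
    for K in range(1, n + 1):
        for B in product(range(3), repeat=n):
            assert score_bfs(B, K) == score_formula(B, K), (B, K)
```

The search is exponential, so it is only useful for tiny cases like these; the point is to gain confidence that the maximum K-window sum really is the answer before building the fast solution on top of it.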


We now know what we want to find.
For each subarray A[L\ldots R] of A,

  • If R-L+1 \leq K, add to the answer the sum A_L + A_{L+1} + \ldots + A_R.
  • Otherwise, find the maximum sum of a length-K subarray in this range, and add its sum to the answer.

These can be done separately.

Short subarrays

This is the easy part.
We want to find the sum of all subarrays of length \leq K.
That’s relatively straightforward with the help of prefix sums.

Let P_i = A_1 + A_2 + \ldots + A_i denote the i-th prefix sum of A.
Then, when considering all subarrays ending at index i, we want to find the sum

(P_i - P_{i-1}) + (P_i - P_{i-2}) + \ldots + (P_i - P_{i-K}) \\ = K\cdot P_i - (P_{i-1} + P_{i-2} + \ldots + P_{i-K})

The latter is just a range sum over the prefix sum array, and so can be found quickly by building a prefix sum array over P (or, since the length K is fixed, a 2-pointer method works too).

A small detail to keep in mind: when i \lt K, there are fewer than K terms in this summation, so the multiplier for P_i should be adjusted accordingly (it’ll be i\cdot P_i).
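A minimal sketch of this computation, using the "prefix sums over P" variant rather than the sliding sum in the editorialist's code. The function name `short_subarray_sum` and its parameter `m` (the maximum subarray length to include) are hypothetical. It uses exact Python integers, whereas the reference solutions reduce modulo 998244353. The second prefix array `Q` returns P_{i-c} + \ldots + P_{i-1} in O(1), and `c = min(i, m)` handles the i \lt K adjustment.

```python
def short_subarray_sum(A, m):
    """Sum of A[L..R] over all subarrays with length <= m, in O(N).

    For each right end i (1-indexed), the c = min(i, m) subarrays ending at i
    contribute c * P[i] - (P[i-c] + ... + P[i-1]).
    """
    n = len(A)
    P = [0] * (n + 1)          # P[i] = A_1 + ... + A_i
    for i in range(1, n + 1):
        P[i] = P[i - 1] + A[i - 1]
    Q = [0] * (n + 2)          # Q[j] = P[0] + ... + P[j-1]
    for j in range(n + 1):
        Q[j + 1] = Q[j] + P[j]
    total = 0
    for i in range(1, n + 1):
        c = min(i, m)          # number of subarrays ending at i with length <= m
        total += c * P[i] - (Q[i] - Q[i - c])
    return total
```

For example, `short_subarray_sum([1, 2, 3], 2)` adds up [1], [2], [3], [1, 2] and [2, 3], giving 14. With `m = K` this is exactly the sum described above; with `m = K - 1` it leaves out the length-K subarrays, as the editorialist's code does.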

Long subarrays

Now we want to find the sum, across all subarrays of length \gt K, of the maximum sum of a length-K window inside each of them.

There are a couple of different ways to do this; I’ll present one that I think is relatively simple.
Let S_i = A_i + A_{i+1} + \ldots + A_{i+K-1} denote the sum of the K-window starting at i.
Note that S will be an array of length N-K+1.

For each i = 1, 2, \ldots, N-K+1, let’s count the number of subarrays for which S_i is the last maximal K-window, that is, the rightmost K-window attaining the maximum.
If this is to be the case for some subarray [L, R],

  • L \leq i and R \geq i+K-1 should hold, of course.
  • Let j_1 \lt i be the rightmost index such that S_{j_1} \gt S_i (or j_1 = 0 if there is no such index).
    Then, [L, R] cannot include index j_1 (if it does, it’ll include this entire window and S_i won’t be maximal); so L \gt j_1 must hold.
  • Let j_2 \gt i be the leftmost index such that S_{j_2} \geq S_i (or j_2 = N-K+2 if there is no such index).
    By similar reasoning, [L, R] cannot include index j_2 + K - 1 (otherwise S_i would not be the last maximum); so R \lt j_2 + K - 1.
  • This gives us the following bounds:
    • j_1 \lt L \leq i
    • i+K-1 \leq R \lt j_2 + K - 1

Within these bounds, any choice of L and R will do, so the number of subarrays with S_i being the last maximum is exactly (i-j_1) \cdot \big((j_2+K-1) - (i+K-1)\big) = (i-j_1)\cdot(j_2-i).
Add this product multiplied by S_i to the answer and move on.

Observe that any subarray counted this way has length \geq K, since L \leq i and R\geq i+K-1, so we don’t need to worry about that condition.
On the other hand, you do need to ensure that subarrays of length K aren’t counted both here and as “short” subarrays: if your implementation counts them in both places, subtract them once at the end (the editorialist’s code below instead restricts the short part to lengths \lt K).

Note that j_1 and j_2 are essentially next/previous greater elements, and computing them for all elements of S in linear time using a stack is a well-known task - see here for instance.
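Here is a minimal sketch of the long-subarray part, written with 0-indexing, so "no such index" is represented by -1 for j_1 and by N-K+1 for j_2. The first stack finds the previous strictly greater window and the second the next greater-or-equal one, which is the tie-breaking that makes S_i the last maximum. The function name is hypothetical, and exact integers are used instead of arithmetic modulo 998244353. Adding the sum of all subarrays of length \lt K to the returned value gives the final answer.

```python
def long_subarray_sum(A, K):
    """Sum over all subarrays of length >= K of their maximum K-window sum, in O(N)."""
    n = len(A)
    m = n - K + 1                      # number of K-windows
    if m <= 0:
        return 0
    P = [0] * (n + 1)
    for i, x in enumerate(A):
        P[i + 1] = P[i] + x
    S = [P[i + K] - P[i] for i in range(m)]   # S[i] = A[i] + ... + A[i+K-1]

    # prv[i]: rightmost j < i with S[j] > S[i], or -1 if none.
    prv, stack = [-1] * m, []
    for i in range(m):
        while stack and S[stack[-1]] <= S[i]:
            stack.pop()
        if stack:
            prv[i] = stack[-1]
        stack.append(i)

    # nxt[i]: leftmost j > i with S[j] >= S[i], or m if none.
    nxt, stack = [m] * m, []
    for i in range(m - 1, -1, -1):
        while stack and S[stack[-1]] < S[i]:
            stack.pop()
        if stack:
            nxt[i] = stack[-1]
        stack.append(i)

    # L ranges over (prv[i], i] and R over [i+K-1, nxt[i]+K-2].
    return sum(S[i] * (i - prv[i]) * (nxt[i] - i) for i in range(m))
```

For `A = [1, 2, 3]` and `K = 2` the windows are S = [3, 5]; the subarrays [1, 2], [2, 3] and [1, 2, 3] contribute 3 + 5 + 5 = 13, which is what the function returns.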

Both parts are done in linear time, so we’re done!

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
const ll INF_ADD=1e18;
#define pb push_back                  
#define mp make_pair          
#define nline "\n"                            
#define f first                                            
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>      
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=500500;
ll sum_val(ll l){
    l=(l*(l+1))/2;
    l%=MOD; 
    return l;
}  
void solve(){        
    ll n,k; cin>>n>>k;    
    vector<ll> a(n+5),pref(n+5,0); 
    ll ans=0; 
    vector<ll> pref_pref(n+5,0);    
    vector<pair<ll,ll>> sum_order; 
    for(ll i=1;i<=n;i++){    
        cin>>a[i];
        pref[i]=pref[i-1]+a[i];
        if(i>=k){
            sum_order.push_back({pref[i]-pref[i-k],i-k+1});
        }
        pref_pref[i]=pref_pref[i-1]+pref[i];
        pref_pref[i]%=MOD;
        ans+=(pref[i]%MOD)*min(k-1,i);
        ans%=MOD;
        ans-=pref_pref[i-1]-pref_pref[max(0ll,i-(k-1)-1)];
        ans%=MOD;
    }
    sort(all(sum_order)); 
    reverse(all(sum_order));  
    set<ll> track;
    auto getv=[&](ll p,ll status){
        ll pos=-1;
        if((*track.begin()==p) and (status==-1)){
            return pos;
        }
        if((*track.rbegin()==p) and (status==1)){
            return pos;
        }
        auto it=track.find(p);
        if(status==-1){
            it--;
        }
        else{
            it++;  
        }
        return *it;
    };
    auto sub_count=[&](ll l,ll r){
        if(min(l,r)==-1){
            return 0ll; 
        }
        r+=k-1;
        r=n+1-r;
        ll now=(l*r)%MOD;
        return now; 
    };
    for(auto it:sum_order){
        ll pos=it.s;
        track.insert(pos);
        ll val=it.f;
        val%=MOD;
        ll l=getv(pos,-1),r=getv(pos,1);
        ll cur=sub_count(pos,pos)-sub_count(l,pos)-sub_count(pos,r)+sub_count(l,r);
        cur%=MOD;
        ans=(ans+val*cur)%MOD;
    }
    ans=(ans+MOD)%MOD;
    cout<<ans<<nline;
    return;    
}                                           
int main()                                                                               
{     
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                               
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;                 
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

struct sparse {
    using T = pair<long long, int>;
    int n;
    int h;
    vector<vector<T>> table;

    T op(T x, T y) {
        return max(x, y);
    }

    sparse(const vector<T>& v) {
        n = (int) v.size();
        h = 32 - __builtin_clz(n);
        table.resize(h);
        table[0] = v;
        for (int j = 1; j < h; j++) {
            table[j].resize(n - (1 << j) + 1);
            for (int i = 0; i <= n - (1 << j); i++) {
                table[j][i] = op(table[j - 1][i], table[j - 1][i + (1 << (j - 1))]);
            }
        }
    }

    T get(int l, int r) {
        assert(0 <= l && l < r && r <= n);
        int k = 31 - __builtin_clz(r - l);
        return op(table[k][l], table[k][r - (1 << k)]);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1e6);
        sn += n;
        in.readSpace();
        int k = in.readInt(1, n - 1);
        in.readEoln();
        auto a = in.readInts(n, 0, 1e9);
        in.readEoln();
        mint ans = 0;
        {
            vector<long long> pref(n + 1);
            for (int i = 0; i < n; i++) {
                pref[i + 1] = pref[i] + a[i];
            }
            vector<pair<long long, int>> t;
            for (int i = 0; i + k <= n; i++) {
                t.emplace_back(pref[i + k] - pref[i], i);
            }
            int sz = (int) t.size();
            sparse sp(t);
            function<void(int, int)> Dfs = [&](int l, int r) {
                if (l >= r) {
                    return;
                }
                auto x = sp.get(l, r);
                ans += x.first * mint(r - x.second) * (x.second - l + 1);
                Dfs(l, x.second);
                Dfs(x.second + 1, r);
            };
            Dfs(0, sz);
        }
        vector<mint> f(n + 1);
        for (int i = 1; i < k; i++) {
            f[0]++;
            f[i]--;
            f[n - i + 1]--;
        }
        for (int i = 0; i < n; i++) {
            f[i + 1] += f[i];
        }
        for (int i = 0; i < n; i++) {
            f[i + 1] += f[i];
        }
        for (int i = 0; i < n; i++) {
            ans += f[i] * a[i];
        }
        cout << ans << '\n';
    }
    in.readEof();
    assert(sn <= 1e6);
    return 0;
}
Editorialist's code (Python)
mod = 998244353
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    pref = [0]
    for x in a: pref.append(pref[-1] + x)
    windows = []
    nxt, prv = [n]*(n+1), [-1]*(n+1)
    for i in range(k, n+1): windows.append(pref[i] - pref[i-k])
    stk = []
    for i in range(len(windows)):
        while len(stk) > 0:
            u = stk[-1]
            if windows[u] <= windows[i]: stk.pop()
            else: break
        if len(stk) > 0: prv[i] = stk[-1]
        stk.append(i)
    stk = []
    for i in reversed(range(len(windows))):
        while len(stk) > 0:
            u = stk[-1]
            if windows[u] < windows[i]: stk.pop()
            else: break
        if len(stk) > 0: nxt[i] = stk[-1]
        stk.append(i)
    
    ans = 0
    for i in range(len(windows)):
        L, R = i, i+k-1
        if prv[i] == -1: L = 0
        else: L = prv[i]+1
        if nxt[i] == n: R = n-1
        else: R = nxt[i]+k-2
        ans += (i-L+1) * (R-(i+k-1)+1) * windows[i] % mod
    
    # Lengths < k
    sm, ct = 0, 1
    for i in range(1, n+1):
        if k == 1: break
        # (pref[i] - pref[i-1]) + (pref[i] - pref[i-2]) + ... + (pref[i] - pref[i-k+1])
        ans += (pref[i] * ct - sm) % mod
        sm = (sm + pref[i]) % mod
        if ct < k-1: ct += 1
        else: sm = (sm - pref[i-k+1]) % mod
    print(ans % mod)

My solution to this problem is different from the editorial’s.
I iterate over r = 1 to n in increasing order, and in fact we can maintain the right endpoints of the required operations.

Submissions -
O(n^2) - CodeChef: Practical coding for everyone
Optimised O(nlogn) - CodeChef: Practical coding for everyone

Right, that works too.

Once you observe that the answer for a subarray is its largest K-window sum, the problem reduces to something quite similar to “given an array A, find the sum of \max(A[L\ldots R]) across all subarrays of A”, and most solutions of that can be adapted to this problem too.

One way is described in the editorial and another is yours. The tester’s solution uses recursion: find the maximum of the range (using some RMQ structure), compute its contribution, then recursively solve the left and right parts.
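A minimal sketch of that recursive idea applied to the window sums S. It relies on the fact that subarrays of A of length \geq K correspond one-to-one with subarrays of S (a subarray [L, R] of A contains exactly the windows starting at L, \ldots, R-K+1), so the long part equals the sum of maxima over all subarrays of S. For simplicity, this sketch finds each range maximum with a linear scan instead of an RMQ structure, so it is O(N^2) in the worst case rather than the tester's O(N \log N), and an explicit stack replaces recursion. The function name is hypothetical.

```python
def long_part_by_recursion(A, K):
    """Sum of max over all subarrays of S (the K-window sums), by divide and conquer."""
    n = len(A)
    P = [0] * (n + 1)
    for i, x in enumerate(A):
        P[i + 1] = P[i] + x
    S = [P[i + K] - P[i] for i in range(n - K + 1)]   # empty when K > n
    total = 0
    ranges = [(0, len(S))]             # half-open ranges [l, r) still to process
    while ranges:
        l, r = ranges.pop()
        if l >= r:
            continue
        x = max(range(l, r), key=lambda j: S[j])   # position of a maximum in [l, r)
        # Subarrays of S inside [l, r) containing x: (x - l + 1) left ends times (r - x) right ends.
        total += S[x] * (x - l + 1) * (r - x)
        ranges.append((l, x))
        ranges.append((x + 1, r))
    return total
```

Any tie-breaking among equal maxima works here, because every subarray of S inside [l, r) either contains the chosen x or lies entirely on one side of it, so each subarray is counted exactly once.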