CONT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: ro27
Preparation: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

2760

PREREQUISITES:

Stacks, or DSU and multiset/segment trees

PROBLEM:

Given an array A, count the number of its subarrays such that the absolute difference between any adjacent pair of elements inside doesn’t exceed the length of the subarray.

EXPLANATION:

There are various solutions to this task, ranging from \mathcal{O}(N\log N) to even \mathcal{O}(N).
They all rely on the same main observation, however: only the maximum adjacent difference within a subarray really matters, since a subarray is good if and only if this maximum doesn’t exceed its length.

O(N logN)

Let D_i = |A_i - A_{i+1}| denote the i-th adjacent difference. D is an array of length N-1.

Let’s fix L, the length of the subarray; and we’ll try to count the number of good subarrays of length exactly L.
Note that if D_i \gt L, then we aren’t allowed to include both index i and index i+1 in the same length-L subarray.

This breaks the array into several contiguous blocks of elements, such that the adjacent differences within each block don’t exceed L. We’re only allowed to choose subarrays of length L that lie fully within some such block.
For instance, if A = [1, 4, 6, 4, 1, 2, 2, 5, 7] and L = 2, the array is broken up as [1], [4, 6, 4], [1, 2, 2], [5, 7].
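To make the splitting concrete, here is a small Python sketch that reproduces the example above. The function name `split_blocks` is hypothetical, and the sketch assumes a non-empty list. It cuts the array wherever an adjacent difference exceeds L:

```python
def split_blocks(a, length):
    """Split a non-empty list into maximal blocks whose adjacent differences are <= length."""
    blocks = [[a[0]]]
    for prev, cur in zip(a, a[1:]):
        if abs(cur - prev) > length:
            blocks.append([cur])   # D_i > length: prev and cur cannot share a subarray
        else:
            blocks[-1].append(cur)
    return blocks


print(split_blocks([1, 4, 6, 4, 1, 2, 2, 5, 7], 2))
# [[1], [4, 6, 4], [1, 2, 2], [5, 7]]
```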

Suppose there are k blocks, with lengths x_1, x_2, \ldots, x_k.
The number of subarrays of length L in the i-th block is then \max(0, x_i-L+1).
So, the total number of good subarrays of length L is \sum_{i=1}^k \max(0, x_i-L+1).

This gives us a straightforward \mathcal{O}(N^2) solution already.
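Here is a sketch of that \mathcal{O}(N^2) solution in Python, using 0-based lists. The names `count_good_quadratic` and `count_good_naive` are hypothetical: the second is an even slower brute force that checks the definition directly and serves as a reference. For every length, the quadratic version walks the differences once and closes a block at each D_i \gt L:

```python
def count_good_naive(a):
    """Reference brute force: check every subarray against the definition."""
    n, total = len(a), 0
    for l in range(n):
        for r in range(l, n):
            k = r - l + 1
            if all(abs(a[i] - a[i + 1]) <= k for i in range(l, r)):
                total += 1
    return total


def count_good_quadratic(a):
    """For each length L, sum max(0, x - L + 1) over the block lengths x."""
    n = len(a)
    d = [abs(a[i] - a[i + 1]) for i in range(n - 1)]  # adjacent differences D
    total = 0
    for length in range(1, n + 1):
        run = 1                          # length of the block being built
        for di in d:
            if di > length:              # cut here: close the current block
                total += max(0, run - length + 1)
                run = 1
            else:
                run += 1
        total += max(0, run - length + 1)  # close the last block
    return total


example = [1, 4, 6, 4, 1, 2, 2, 5, 7]
print(count_good_quadratic(example), count_good_naive(example))  # 42 42
```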

To optimize it, let’s see how things change when moving from L to L+1.
Say we’ve processed length L, and have a bunch of disjoint segments.
When moving to L+1, these segments don’t change all that much: in fact, the only real change is that some of them merge together!

Specifically, the only change is that for each i such that D_i = L+1, the segments ending at i and starting at i+1 merge together.
By maintaining some data structure that allows for fast merges such as a DSU (though since we’re dealing with intervals, a set + binary search works too), these merges can be simulated quickly.
Note that there are only N-1 merges at most, so simulating them all is perfectly fine.

We now need to figure out how to update the contribution of the segment lengths to the answer.
Keep a multiset S of lengths of the currently active segments.
This is easy to maintain while merging: two lengths are deleted and their sum is inserted.
Now, we want the sum of \max(0, x-L+1) for all x\in S; or rather, the sum of x-L+1 for all x\in S such that x \geq L.

Notice that \sum (x-L+1) splits as \sum(x) - \sum (L-1).
The first summation is the sum of everything in S that’s \geq L.
The second one is L-1 times the count of the number of things in S that are \geq L.
There are a couple of ways to find this information quickly:

  • The ‘standard’ way is to use a segment tree built on the values 1 to N — keep both the sum of elements and the count of elements in range, and query the appropriate suffix.
    Updates change only three indices in the segment tree, so direct point updates are fine.
  • There’s also a ‘bruteforce’ approach.
    Maintain the multiset S of lengths, as mentioned above.
    Iterate over its elements in descending order and add x-L+1 to the answer for each one.
    However, break out the instant you reach an element \lt L.
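The ‘standard’ variant can be sketched in Python as shown below. This sketch swaps the segment tree for a Fenwick tree (binary indexed tree), which supports the same point updates and prefix sums. A suffix query over lengths \geq L is then the total minus a prefix. The name `count_good_fenwick` is hypothetical. Indices are 0-based, so the merge at position i joins a[i] and a[i+1]:

```python
def count_good_fenwick(a):
    n = len(a)
    parent = list(range(n))
    size = [1] * n

    def find(x):
        # DSU find with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # Two Fenwick trees over segment lengths 1..n: how many segments have each
    # length, and the total length of those segments.
    cnt = [0] * (n + 1)
    tot = [0] * (n + 1)

    def update(length, sign):
        i = length
        while i <= n:
            cnt[i] += sign
            tot[i] += sign * length
            i += i & -i

    def prefix(i):
        # (count, sum) over stored segment lengths <= i
        c = s = 0
        while i > 0:
            c += cnt[i]
            s += tot[i]
            i -= i & -i
        return c, s

    for _ in range(n):
        update(1, +1)      # initially every element is its own segment
    segments = n           # current number of segments; their lengths always sum to n

    # merges[d] holds the positions i with D_i = d; differences above n never matter
    merges = [[] for _ in range(n + 1)]
    for i in range(n - 1):
        d = abs(a[i] - a[i + 1])
        if d <= n:
            merges[d].append(i)

    ans = 0
    for length in range(n + 1):
        for i in merges[length]:
            ra, rb = find(i), find(i + 1)
            update(size[ra], -1)   # two lengths leave the structure...
            update(size[rb], -1)
            parent[rb] = ra
            size[ra] += size[rb]
            update(size[ra], +1)   # ...and their sum enters it
            segments -= 1
        if length == 0:
            continue               # D_i = 0 merges happen before length 1 is counted
        c, s = prefix(length - 1)
        big_count, big_sum = segments - c, n - s   # segments of length >= `length`
        ans += big_sum - (length - 1) * big_count
    return ans


print(count_good_fenwick([1, 4, 6, 4, 1, 2, 2, 5, 7]))  # 42
```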

The second approach might seem quadratic overall, but it isn’t!
Since the segments are disjoint and their total length is N, there can be at most \left\lfloor \frac{N}{L} \right\rfloor segments of length \geq L.
So, as long as we break out early, the total number of segments we iterate across is bounded by \sum_L \left\lfloor \frac{N}{L} \right\rfloor, which is well-known to be \mathcal{O}(N\log N).
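Python has no built-in sorted multiset, so the sketch below adapts the bruteforce idea rather than copying it. The name `count_good_filtered` is hypothetical, and indices are 0-based. Instead of a sorted multiset, the sketch keeps a set of DSU roots that may still have length \geq L and re-filters it at every length. After filtering, at most \left\lfloor \frac{N}{L} \right\rfloor roots survive. The next filter only looks at those survivors plus the roots touched by that step's merges, so the same harmonic-sum bound applies:

```python
def count_good_filtered(a):
    n = len(a)
    parent = list(range(n))
    size = [1] * n

    def find(x):
        # DSU find with path halving (roots are never modified, so parent[r] == r
        # still identifies roots afterwards)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    merges = [[] for _ in range(n + 1)]   # merges[d]: positions i with D_i = d <= n
    for i in range(n - 1):
        d = abs(a[i] - a[i + 1])
        if d <= n:
            merges[d].append(i)

    big = set(range(n))   # candidate roots of segments with length >= current length
    ans = 0
    for length in range(n + 1):
        for i in merges[length]:
            ra, rb = find(i), find(i + 1)
            parent[rb] = ra
            size[ra] += size[rb]
            big.add(ra)    # the merged segment may have reached the current length
        if length == 0:
            continue
        # drop absorbed roots and segments that are now too short
        big = {r for r in big if parent[r] == r and size[r] >= length}
        ans += sum(size[r] - length + 1 for r in big)
    return ans


print(count_good_filtered([1, 4, 6, 4, 1, 2, 2, 5, 7]))  # 42
```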

O(N)

We use a similar observation to the \mathcal{O}(N\log N) version: for a fixed L, the array is broken up into several segments whose adjacent differences are no more than L.

As noted in the \mathcal{O}(N\log N) solution, the only way these segments can change in the future is for adjacent segments to merge.

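So, let’s fix a segment [L, R] with length K = R-L+1 (in this section L denotes the left endpoint, not a length), and figure out the ‘time’ during which it’s alive (i.e., before it merges). Here, ‘time’ x means the length threshold: the segments at time x are exactly the blocks obtained for length x in the previous solution.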
Let this time interval be [T_1, T_2].
This segment then contributes \max(0, K-x+1) to \text{ans}[x] for each T_1 \leq x \leq T_2.
For x \gt K the contribution is 0, so we only need to deal with x \leq K: we have to add K-x+1 to \text{ans}[x] for the range T_1 \leq x \leq \min(T_2, K).
Here, \text{ans}[x] denotes the number of good subarrays of length exactly x.

This can be done by splitting up the sum into two parts: adding K+1 to a range of x, and performing \text{ans}[x] \to \text{ans}[x] - x for a range of x.
The first one is a range-add update, and it’s well-known that Q range-add updates can be performed offline in \mathcal{O}(N+Q) time with the help of prefix sums (difference arrays).
\text{ans}[x] \to \text{ans}[x] - x for a range of x can also be performed offline similarly: just keep track of the number of times this update needs to be done for a particular x (which turns it into adding 1 to a range of a second counter), and subtract x times that count at the end.
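Both offline updates can be sketched in Python with two difference arrays. The function `apply_offline` and the (lo, hi, K) tuple format are hypothetical. Ranges are inclusive, with 0 \leq lo \leq hi \leq n:

```python
def apply_offline(n, updates):
    """updates: (lo, hi, K) meaning ans[x] += K - x + 1 for every lo <= x <= hi."""
    const = [0] * (n + 2)   # difference array for the "+ (K + 1)" part
    coef = [0] * (n + 2)    # difference array counting how often "- x" applies
    for lo, hi, k in updates:
        const[lo] += k + 1
        const[hi + 1] -= k + 1
        coef[lo] += 1
        coef[hi + 1] -= 1
    ans = [0] * (n + 1)
    c = m = 0
    for x in range(n + 1):
        c += const[x]        # running sum: total constant added at x
        m += coef[x]         # running sum: number of "- x" updates covering x
        ans[x] = c - m * x
    return ans


print(apply_offline(5, [(2, 4, 5), (1, 3, 3)]))  # [0, 3, 6, 4, 2, 0]
```

Each update costs \mathcal{O}(1), and one final pass recovers all values, so Q updates take \mathcal{O}(N+Q) in total.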

Hence, if we’ve found a segment and the times associated with it, we know how to solve the task.
To actually find these, observe what a segment looks like.
If [L, R] is the segment, then:

  • Let M be the maximum adjacent difference within the segment.
  • Then, the left border should satisfy L = 1 or |A_L - A_{L-1}| \gt M (otherwise the segment could be extended to the left).
    Similarly, the right should satisfy R = N or |A_R - A_{R+1}| \gt M.
  • The times are also easy: this segment first exists at time T_1 = M (before M, the segments would’ve been even smaller), and exists until T_2 = \min(|A_L - A_{L-1}|, |A_R - A_{R+1}|)-1, since at time T_2+1 it merges with one of its neighbors. A missing neighbor (L = 1 or R = N) counts as an infinitely large difference.
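The conditions above can be checked with a direct scan. The sketch below is for illustration only: it costs \mathcal{O}(N) per index, which the stack-based approach replaces. The name `segment_for` is hypothetical, and indices are 0-based, with D_j = |a[j] - a[j+1]|. Starting from a difference D_i, the function grows the segment while differences stay \leq M = D_i and reads the lifetime off the two border differences:

```python
import math


def segment_for(a, i):
    """Return (start, end, t1, t2): a[start..end] is the segment whose maximum
    difference is D_i, alive for lengths t1 <= x <= t2 (0-based indices)."""
    d = [abs(a[j] - a[j + 1]) for j in range(len(a) - 1)]
    m = d[i]
    lo = i
    while lo > 0 and d[lo - 1] <= m:           # extend left while differences stay <= M
        lo -= 1
    hi = i
    while hi + 1 < len(d) and d[hi + 1] <= m:  # extend right likewise
        hi += 1
    left_border = d[lo - 1] if lo > 0 else math.inf        # missing neighbor: +infinity
    right_border = d[hi + 1] if hi + 1 < len(d) else math.inf
    return lo, hi + 1, m, min(left_border, right_border) - 1


A = [1, 4, 6, 4, 1, 2, 2, 5, 7]
print(segment_for(A, 1))  # (1, 3, 2, 2): [4, 6, 4] exists only at length 2
print(segment_for(A, 4))  # (4, 6, 1, 2): [1, 2, 2] exists at lengths 1 and 2
```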

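Looking at this from a different angle, if we fix the maximum difference M = |A_i - A_{i+1}| of the segment, the endpoints are in fact uniquely defined: L-1 is the closest index j \lt i with |A_j - A_{j+1}| \gt M (or L = 1 if there is none), and similarly R is the closest index j \gt i with |A_j - A_{j+1}| \gt M (or R = N if there is none). If the value M occurs several times inside the segment, each occurrence yields the same segment, so it should be counted only once, for instance only for its leftmost occurrence.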

In other words, all we need to do is find, for each adjacent difference, the closest greater differences to its left and right.
This is a classical task (the “next greater element problem”) and can be solved in linear time using a stack.
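A Python sketch of the stack pass, with ties handled as described above, follows. The names `nearest_greater` and `maximal_segments` are hypothetical, and indices are 0-based. On the left, the sketch keeps the nearest difference \geq the current one; on the right, it keeps the nearest one that is strictly greater. The editorialist's and tester's loops use the same convention. An index is skipped when its left neighbor is equal, so each segment is reported once, via the leftmost occurrence of its maximum:

```python
def nearest_greater(d):
    """prv[j]: closest k < j with d[k] >= d[j] (else -1);
    nxt[j]: closest k > j with d[k] > d[j] (else len(d))."""
    m = len(d)
    prv, nxt = [-1] * m, [m] * m
    stack = []
    for j in range(m):
        while stack and d[stack[-1]] < d[j]:
            stack.pop()
        if stack:
            prv[j] = stack[-1]
        stack.append(j)
    stack = []
    for j in range(m - 1, -1, -1):
        while stack and d[stack[-1]] <= d[j]:
            stack.pop()
        if stack:
            nxt[j] = stack[-1]
        stack.append(j)
    return prv, nxt


def maximal_segments(a):
    """List (start, end, M) for every segment a[start..end] with maximum difference M."""
    d = [abs(a[j] - a[j + 1]) for j in range(len(a) - 1)]
    prv, nxt = nearest_greater(d)
    out = []
    for j, m in enumerate(d):
        if prv[j] != -1 and d[prv[j]] == m:
            continue   # same segment as an earlier occurrence of m
        out.append((prv[j] + 1, nxt[j], m))
    return out


print(maximal_segments([1, 4, 6, 4, 1, 2, 2, 5, 7]))
# [(0, 8, 3), (1, 3, 2), (4, 6, 1), (5, 6, 0), (7, 8, 2)]
```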

With this, we’ve found all relevant segments in \mathcal{O}(N) time, after which the contribution of each one can be processed via the range-add updates mentioned earlier, also in linear time.
The entire algorithm is thus linear.

TIME COMPLEXITY:

\mathcal{O}(N) or \mathcal{O}(N\log N) per testcase.

CODE:

Preparer's code (C++, N logN)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

int find(int x, vector <int> &root){
    if (x == root[x]) return x;
    return root[x] = find(root[x], root);
}

void Solve() 
{
    int n; cin >> n;
    vector <int> a(n);
    for (auto &x : a) cin >> x;

    vector <int> root(n), sz(n, 1);
    iota(root.begin(), root.end(), 0);

    multiset <int> ms;
    int sum = 0;
    for (int i = 0; i < n; i++) ms.insert(1), sum += 1;

    int L = 0;

    auto unite = [&](int a, int b){
        a = find(a, root);
        b = find(b, root);

        assert(a != b);
        root[b] = a;

        //remove sz[a] and sz[b]
        if (ms.find(sz[a]) != ms.end()){
            sum -= sz[a];
            ms.erase(ms.find(sz[a]));
        }

        if (ms.find(sz[b]) != ms.end()){
            sum -= sz[b];
            ms.erase(ms.find(sz[b]));
        }

        sz[a] += sz[b];

        if (sz[a] >= L){
            ms.insert(sz[a]);
            sum += sz[a];
        }
    };

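    // adj[x] = positions i with |a[i] - a[i-1]| = x (x <= n); they merge when L reaches x.
    // Equal neighbours (difference 0) are merged right away.
    vector<vector<int>> adj(n + 1);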

    for (int i = 1; i < n; i++){
        if (a[i] == a[i - 1]){
            unite(i, i - 1);
        } else {
            int x = abs(a[i] - a[i - 1]);
            if (x <= n) adj[x].push_back(i);
        }
    }

    int ans = 0;

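    for (L = 1; L <= n; L++){
        // segments separated by a difference of exactly L merge now
        for (auto i : adj[L]){
            unite(i, i - 1);
        }

        // discard segments shorter than L (a later merge reinserts the merged length)
        while (ms.size() && *(ms.begin()) < L){
            sum -= *ms.begin();
            ms.erase(ms.begin());
        }

        // every remaining segment of length x contributes x - L + 1 subarrays
        int cnt = ms.size();
        int add = sum - cnt * (L - 1);

        ans += add;
    }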


    cout << ans << "\n";
}   

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++, linear)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

int mod;
struct mi {
    int64_t v; explicit operator int64_t() const { return v % mod; }
    mi() { v = 0; }
    mi(int64_t _v) {
        v = (-mod < _v && _v < mod) ? _v : _v % mod;
        if (v < 0) v += mod;
    }
    friend bool operator==(const mi& a, const mi& b) {
        return a.v == b.v; }
    friend bool operator!=(const mi& a, const mi& b) {
        return !(a == b); }
    friend bool operator<(const mi& a, const mi& b) {
        return a.v < b.v; }

    mi& operator+=(const mi& m) {
        if ((v += m.v) >= mod) v -= mod;
        return *this; }
    mi& operator-=(const mi& m) {
        if ((v -= m.v) < 0) v += mod;
        return *this; }
    mi& operator*=(const mi& m) {
        v = v*m.v%mod; return *this; }
    mi& operator/=(const mi& m) { return (*this) *= inv(m); }
    friend mi pow(mi a, int64_t p) {
        mi ans = 1; assert(p >= 0);
        for (; p; p /= 2, a *= a) if (p&1) ans *= a;
        return ans;
    }
    friend mi inv(const mi& a) { assert(a.v != 0);
        return pow(a,mod-2); }

    mi operator-() const { return mi(-v); }
    mi& operator++() { return *this += 1; }
    mi& operator--() { return *this -= 1; }
    mi operator++(int32_t) { mi temp; temp.v = v++; return temp; }
    mi operator--(int32_t) { mi temp; temp.v = v--; return temp; }
    friend mi operator+(mi a, const mi& b) { return a += b; }
    friend mi operator-(mi a, const mi& b) { return a -= b; }
    friend mi operator*(mi a, const mi& b) { return a *= b; }
    friend mi operator/(mi a, const mi& b) { return a /= b; }
    friend ostream& operator<<(ostream& os, const mi& m) {
        os << m.v; return os;
    }
    friend istream& operator>>(istream& is, mi& m) {
        int64_t x; is >> x;
        m.v = x;
        return is;
    }
    friend void __print(const mi &x) {
        cerr << x.v;
    }
};


bool prime(int s) {
    for(int i = 2 ; i * i <= s ; i++) {
        if(s % i == 0)  return false;
    }
    return true;
}


int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

//    input_checker input;

    int sum_n = 0;
    auto __solve_testcase = [&](int testcase) -> void {
        // int n = input.readInt(1, (int)2e5); input.readEoln();
        // sum_n += n;
        // auto a = input.readInts(n, 1, (int)1e9);    input.readEoln();

        int n;  cin >> n;
        vector<int> a(n);
        for(auto &i : a)    cin >> i;

        --n;
        vector<int> b(n);
        for(int i = 0 ; i < n ; i++)    b[i] = abs(a[i] - a[i + 1]);

        vector<int> nx(n, n), pv(n, -1), Stk;
        for(int i = 0 ;  i < n ; i++) {
            while(!Stk.empty() && b[Stk.back()] < b[i])
                Stk.pop_back();

            if(!Stk.empty())
                pv[i] = Stk.back();

            Stk.push_back(i);
        }

        Stk.clear();
        for(int i = n - 1 ; i >= 0 ; i--) {
            while(!Stk.empty() && b[Stk.back()] <= b[i]) {
                Stk.pop_back();
            }
            if(!Stk.empty())
                nx[i] = Stk.back();
            Stk.push_back(i);
        }

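        // pv[i] / nx[i]: nearest index with b >= b[i] on the left / b > b[i] on the
        // right, so b[i] is the leftmost maximum of exactly the difference ranges
        // [l, r] with pv[i] < l <= i <= r < nx[i].
        // calc(sm, L, R) counts pairs (x, y), 1 <= x <= L, 1 <= y <= R, x + y >= sm;
        // (x, y) corresponds to the subarray of A with x + y elements whose
        // differences are b[i - x + 1 .. i + y - 1]; it is good iff x + y >= b[i].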

        auto sum_natural = [&](int n) -> long long {
            return n * (n + 1ll) / 2;
        };
        auto calc = [&](int sm, int L, int R) -> long long {
            if(sm > L + R)
                return 0ll;

            long long here = 0;
            if(L >= sm - 1) {
                here += (L - sm + 2ll) * R;
                L = sm - 2;
            }

            here += sum_natural(R - sm + L + 1);
            if(R - sm + 1 > 0)
                here -= sum_natural(R - sm + 1);

            return here;
        };

        long long res = n + 1;  // n was decremented: all n + 1 single elements are good
        for(int i = 0 ; i < n ; i++) {
            int dfr = nx[i] - i;
            int dfl = i - pv[i];
            int sm = max(b[i], 2);

            res += calc(sm, dfl, dfr);
        }
        cout << res << '\n';
    };
    
//    int no_of_tests = input.readInt(1, (int)2e4);   input.readEoln();
    int no_of_tests;    cin >> no_of_tests;
    for(int test_no = 1 ; test_no <= no_of_tests ; test_no++)
        __solve_testcase(test_no);
    

    // input.readEof();

    return 0;
}
Editorialist's code (Python, linear)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    difs = [10**9 + 7]
    for i in range(1, n): difs.append(abs(a[i] - a[i-1]))
    difs += [10**9 + 7]

    left, right = [0]*(n+1), [n]*(n+1)
    stk = [0]
    for i in range(1, n):
        while difs[stk[-1]] < difs[i]: stk.pop()
        left[i] = stk[-1]
        stk.append(i)
    stk = [n]
    for i in reversed(range(1, n)):
        while difs[stk[-1]] <= difs[i]: stk.pop()
        right[i] = stk[-1]
        stk.append(i)
    
    p1, p2 = [0]*(n+5), [0]*(n+5)
    for i in range(1, n):
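        # count each segment once: only for the leftmost occurrence of its maximum
        if difs[left[i]] == difs[i]: continue
        # segment a[left[i] .. right[i]-1] (length curlen) is alive for lengths lo <= k < hi
        # and contains curlen - k + 1 subarrays of length k; only k <= curlen matters,
        # so add curlen + 1 over the range (p2) and subtract k over the same range (p1)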
        curlen = right[i] - left[i]
        lo, hi = difs[i], min(difs[left[i]], difs[right[i]])
        hi = min(hi, curlen + 1)
        if lo >= hi: continue
        
        p1[lo] -= 1
        p1[hi] += 1
        p2[lo] += curlen + 1
        p2[hi] -= curlen + 1
    ans = n
    for i in range(n+1):
        if i >= 2:
            ans += i*p1[i]
            ans += p2[i]
        p1[i+1] += p1[i]
        p2[i+1] += p2[i]
    print(ans)

Really liked the second approach, easy but powerful observation!