SWINC - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Segment trees

PROBLEM:

You’re given a permutation P. You can swap one pair of elements in it.
Maximize the number of sorted subarrays.
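A brute force is handy for testing on small inputs. The sketch below is not taken from the author's or tester's code, and `countSorted` and `bruteBest` are illustrative names. It counts sorted subarrays run by run: a maximal increasing run of length len contributes len(len+1)/2, which is also how both reference solutions compute their base answer. It then tries every swap, for O(N^3) total. Both reference solutions never report fewer subarrays than the initial count, so the sketch also allows keeping P unchanged.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Number of sorted (increasing) subarrays: index i is the right end of exactly
// `run` of them, where run is the length of the increasing run ending at i.
long long countSorted(const std::vector<int>& p) {
    long long total = 0, run = 0;
    for (std::size_t i = 0; i < p.size(); ++i) {
        run = (i > 0 && p[i - 1] < p[i]) ? run + 1 : 1;
        total += run;
    }
    return total;
}

// O(N^3) reference answer: best count over all single swaps, or no swap at all.
long long bruteBest(std::vector<int> p) {
    long long best = countSorted(p);  // keeping P unchanged
    for (std::size_t i = 0; i < p.size(); ++i)
        for (std::size_t j = i + 1; j < p.size(); ++j) {
            std::swap(p[i], p[j]);
            best = std::max(best, countSorted(p));
            std::swap(p[i], p[j]);  // undo the swap
        }
    return best;
}
```

For example, `bruteBest({3, 2, 1})` is 6, because swapping the two ends sorts the whole array.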

EXPLANATION:

It is recommended that you read the solution to the easy version first.

In the hard version, we can no longer try every possible swap: we need something a bit faster.

Let’s see what information we actually need to simulate a swap (i, j) quickly.

  1. The number of sorted subarrays initially containing index i and/or j.
    • The number of sorted subarrays initially containing i depends only on i, so for a fixed i it is a constant.
      We only run into issues when P[i\ldots j] is already sorted: then some subarrays contain both i and j, and the value depends on both indices. In every other case, we can independently sum up the values for i and j.
  2. The number of sorted subarrays after the swap.
    • For this, we only need to know how P_j compares to the neighbors of index i, i.e., the values P_{i-1} and P_{i+1}.
    • There are three ranges of values such that the number of new sorted subarrays is the same for every choice of P_j within a range:
      1. P_j \lt \min(P_{i-1}, P_{i+1})
      2. \min(P_{i-1}, P_{i+1}) \lt P_j \lt \max(P_{i-1}, P_{i+1})
      3. P_j \gt \max(P_{i-1}, P_{i+1})
    • Once again, if P[i\ldots j] isn’t sorted after the swap, the values for index i and index j can be found independently: they depend only on which range P_j falls into for index i, and which range P_i falls into for index j.
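Here is a minimal sketch of the two quantities above. It assumes 0-indexed positions, and `Runs`, `contain` and `placeCount` are hypothetical names. Let lt[i] be the length of the increasing run ending at i and rt[i] the length of the run starting at i. A sorted subarray containing i then has lt[i] choices for its left end and rt[i] choices for its right end. When a value v is written at i and nothing else changes, the three ranges reduce to two comparisons: whether v can join the run on its left, and whether it can join the run on its right.

```cpp
#include <cassert>
#include <vector>

// Run lengths of a permutation, 0-indexed:
//   lt[i] = length of the increasing run ending at i,
//   rt[i] = length of the increasing run starting at i.
struct Runs {
    std::vector<int> p;
    std::vector<long long> lt, rt;

    explicit Runs(const std::vector<int>& perm)
        : p(perm), lt(perm.size(), 1), rt(perm.size(), 1) {
        int n = (int)p.size();
        for (int i = 1; i < n; ++i)
            if (p[i - 1] < p[i]) lt[i] = lt[i - 1] + 1;
        for (int i = n - 2; i >= 0; --i)
            if (p[i] < p[i + 1]) rt[i] = rt[i + 1] + 1;
    }

    // Sorted subarrays that currently contain index i.
    long long contain(int i) const { return lt[i] * rt[i]; }

    // Sorted subarrays containing i after writing v at position i, assuming
    // every other position keeps its value (the "independent" case).
    long long placeCount(int i, int v) const {
        int n = (int)p.size();
        long long left = (i > 0 && p[i - 1] < v) ? lt[i - 1] : 0;       // join the run on the left?
        long long right = (i + 1 < n && v < p[i + 1]) ? rt[i + 1] : 0;  // join the run on the right?
        return (left + 1) * (right + 1);
    }
};
```

As a sanity check, `placeCount(i, p[i])` equals `contain(i)` for every i, since writing back the original value changes nothing.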

In either case, the situation is less clear if P[i\ldots j] is sorted before the swap or becomes sorted after it.
Let’s do a bit of wishful thinking and ignore such cases for now, pretending that i and j are always independent.

For a fixed index i, let’s try to find the optimal index j to swap with.
Recall that there are three ranges of elements that we care about, defined by P_{i-1} and P_{i+1}.
Let’s consider each of these ranges separately, and query for the best j within each range.
Note that once the range is fixed, several quantities are fixed too: the cost of removing P_i from index i and the profit of putting P_j into index i are both constants, and the cost of removing P_j from index j depends only on P_j itself.
The only thing we need to know is the profit of putting P_i into position j.
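The grouping for a fixed i can be sketched as follows. It still relies on the wishful independence assumption, and all names are hypothetical. The sketch skips j = i and the two positions adjacent to i, since those always interact. For every other candidate j it computes only the j-side term and buckets it by the range P_j falls into around i. The i-side term is then a constant per bucket. A missing left neighbour is treated as the value N+1 and a missing right neighbour as 0, mirroring the sentinels p[0] = n+1 and p[n+1] = 0 in the editorialist's code. This still costs O(N) per i, so it is only a stepping stone.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdlib>
#include <vector>

using ll = long long;
const ll NEG = -(1LL << 60);

// Best wishful gain of swapping position i with some j, |i - j| >= 2 (0-indexed).
ll bestPartnerWishful(const std::vector<int>& p, int i) {
    int n = (int)p.size();
    std::vector<ll> lt(n, 1), rt(n, 1);
    for (int t = 1; t < n; ++t) if (p[t - 1] < p[t]) lt[t] = lt[t - 1] + 1;
    for (int t = n - 2; t >= 0; --t) if (p[t] < p[t + 1]) rt[t] = rt[t + 1] + 1;

    // Neighbour values of j with sentinels: missing left = n + 1, missing right = 0.
    auto nb = [&](int j) {
        return std::array<int, 2>{j > 0 ? p[j - 1] : n + 1, j + 1 < n ? p[j + 1] : 0};
    };
    // delta(j, k): change in sorted subarrays through j when a range-k value is written at j.
    auto delta = [&](int j, int k) -> ll {
        ll left = j > 0 ? lt[j - 1] : 0, right = j + 1 < n ? rt[j + 1] : 0;
        ll after = k == 0 ? 1 + right                      // below both: extend right only
                 : k == 2 ? 1 + left                       // above both: extend left only
                 : (nb(j)[0] < nb(j)[1] ? (1 + left) * (1 + right) : 1);
        return after - lt[j] * rt[j];
    };
    // rangeOf(j, v): 0 = below both neighbours of j, 1 = strictly between, 2 = above both.
    auto rangeOf = [&](int j, int v) {
        auto [a, b] = nb(j);
        return v < std::min(a, b) ? 0 : v > std::max(a, b) ? 2 : 1;
    };

    std::array<ll, 3> bestTerm = {NEG, NEG, NEG};  // best j-side term per range of p[j]
    for (int j = 0; j < n; ++j) {
        if (std::abs(i - j) <= 1) continue;  // j = i or adjacent: always interacting
        int k = rangeOf(i, p[j]);
        bestTerm[k] = std::max(bestTerm[k], delta(j, rangeOf(j, p[i])));
    }
    ll best = NEG;  // stays NEG if no valid j exists
    for (int k = 0; k < 3; ++k)  // the i-side term depends only on k
        if (bestTerm[k] > NEG) best = std::max(best, delta(i, k) + bestTerm[k]);
    return best;
}
```

For P = (1, 7, 3, 5, 4, 6, 2, 8) and i = 1 (0-indexed, the value 7), this returns 8. Swapping 7 with 2 turns 12 sorted subarrays into 20, and that pair does not interact, so the wishful value is exact there.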

One way to deal with this is to process indices in increasing order of their values.
Let’s process the values x = 1, 2, 3, \ldots, N in ascending order, and try to find the best we can do if the value x is swapped away.
For the current value of x, let a_y denote the profit of placing x into the position containing y, minus the cost of removing y from its position.
Then,

  • We have three ranges to consider. For each one, we already know the net change at x’s position from removing x and placing a value from that range there.
    For the remaining change, simply query for the maximum of a_y across all y in the range.
  • Then, we move to x+1.
    Note that almost all the a_y values remain the same: only a_y for the (positional) neighbors of x and/or x+1 can change.
    This is a constant number of changes per step, so simply perform them directly.

We now have a (wishful) solution that only requires \mathcal{O}(N) point updates and range max queries, each of which is easily done with a segment tree in \mathcal{O}(\log N).
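Putting the pieces together gives the following sketch of the wishful O(N log N) sweep. It is not the editorialist's code: the names are hypothetical, positions are 0-indexed, and missing neighbours use the same N+1 / 0 sentinels. a_y is stored at index y of an iterative max segment tree, similar in spirit to the kactl tree in the editorialist's code. Each value y switches ranges exactly when x passes the smaller, and later the larger, of its two neighbour values. All updates can therefore be scheduled up front as events, at most two per value. At each x, the queries skip the two neighbour values of x. The pair (x, x) stays in the tree but contributes exactly 0.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <utility>
#include <vector>

using ll = long long;
const ll NEG = -(1LL << 60);  // empty-range marker

// Iterative max segment tree over [0, n): point assign, half-open range max.
struct MaxSeg {
    int n;
    std::vector<ll> t;
    explicit MaxSeg(int n_) : n(n_), t(2 * n_, NEG) {}
    void set(int p, ll v) {
        for (t[p += n] = v; p > 1; p >>= 1) t[p >> 1] = std::max(t[p], t[p ^ 1]);
    }
    ll query(int l, int r) const {  // max over [l, r), NEG if empty
        ll res = NEG;
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = std::max(res, t[l++]);
            if (r & 1) res = std::max(res, t[--r]);
        }
        return res;
    }
};

// Best gain if every non-adjacent pair were independent (p: permutation of 1..n).
ll wishfulSweep(const std::vector<int>& p) {
    int n = (int)p.size();
    std::vector<ll> lt(n, 1), rt(n, 1);
    for (int i = 1; i < n; ++i) if (p[i - 1] < p[i]) lt[i] = lt[i - 1] + 1;
    for (int i = n - 2; i >= 0; --i) if (p[i] < p[i + 1]) rt[i] = rt[i + 1] + 1;

    std::vector<int> pos(n + 1), lo(n), hi(n);
    std::vector<std::array<ll, 3>> d(n);  // d[j][k]: change at j when a range-k value is written there
    std::vector<std::vector<std::pair<int, int>>> events(n + 2);  // events[x]: (value y, new range)
    for (int j = 0; j < n; ++j) {
        pos[p[j]] = j;
        int a = j > 0 ? p[j - 1] : n + 1, b = j + 1 < n ? p[j + 1] : 0;  // sentinel neighbours
        ll left = j > 0 ? lt[j - 1] : 0, right = j + 1 < n ? rt[j + 1] : 0, before = lt[j] * rt[j];
        d[j] = {1 + right - before, (a < b ? (1 + left) * (1 + right) : 1) - before, 1 + left - before};
        lo[j] = std::min(a, b);
        hi[j] = std::max(a, b);
        events[lo[j]].push_back({p[j], 1});  // after x = lo, x lies between the neighbours
        events[hi[j]].push_back({p[j], 2});  // after x = hi, x lies above both
    }
    MaxSeg seg(n + 1);  // a_y is stored at index y (index 0 unused)
    for (int j = 0; j < n; ++j) seg.set(p[j], d[j][lo[j] == 0 ? 1 : 0]);  // ranges seen by x = 1

    ll best = 0;
    for (int x = 1; x <= n; ++x) {
        int i = pos[x];
        // Query strictly inside each range, so the neighbour values lo[i], hi[i] are skipped.
        best = std::max(best, d[i][0] + seg.query(1, lo[i]));
        best = std::max(best, d[i][1] + seg.query(lo[i] + 1, hi[i]));
        best = std::max(best, d[i][2] + seg.query(hi[i] + 1, n + 1));
        for (auto [y, k] : events[x]) seg.set(y, d[pos[y]][k]);
    }
    return best;
}
```

For P = (1, 4, 3, 2, 5) this returns 4, although swapping 4 and 2 sorts the whole array and gains 8 (from 7 to 15 sorted subarrays). That pair falls into the "becomes sorted" case, which the wishful version gets wrong.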


Now, we deal with the issue of non-independent swaps.

Call an index i important if P_i \gt P_{i+1} or P_{i-1} \gt P_i (or i = 1 or i = N).
Essentially, an important index is an endpoint of a maximal sorted subarray.

It can be proved that in any optimal swap, at least one of the two indices must be important: if neither is, the number of sorted subarrays cannot increase.
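The claim can be checked exhaustively for small N. The sketch below uses hypothetical names and 0-indexed positions. It enumerates every permutation of size n with std::next_permutation and verifies that no swap of two non-important indices increases the count. This is evidence for small sizes, not a proof.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

long long countSorted(const std::vector<int>& p) {
    long long total = 0, run = 0;
    for (std::size_t i = 0; i < p.size(); ++i) {
        run = (i > 0 && p[i - 1] < p[i]) ? run + 1 : 1;
        total += run;
    }
    return total;
}

// Important (0-indexed): an endpoint of the array or of a maximal increasing run.
bool isImportant(const std::vector<int>& p, int i) {
    int n = (int)p.size();
    return i == 0 || i == n - 1 || p[i - 1] > p[i] || p[i] > p[i + 1];
}

// True if, for every permutation of size n, no swap of two
// non-important indices increases the number of sorted subarrays.
bool claimHoldsFor(int n) {
    std::vector<int> p(n);
    std::iota(p.begin(), p.end(), 1);
    do {
        long long base = countSorted(p);
        for (int i = 0; i < n; ++i)
            for (int j = i + 1; j < n; ++j) {
                if (isImportant(p, i) || isImportant(p, j)) continue;
                std::vector<int> q = p;
                std::swap(q[i], q[j]);
                if (countSorted(q) > base) return false;
            }
    } while (std::next_permutation(p.begin(), p.end()));
    return true;
}
```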

So, let’s go back to our wishful solution, except that we maintain a_y only for values at important indices.
Now, let’s analyze when things can potentially go wrong:

  • If the subarray between the positions of x and y was initially sorted and y is at an important index, then y must be at the important index closest to x (either to the left or to the right).
  • If the subarray between them is sorted after the swap, y must again be at an important index close to x: at most the third one to the left or right.

This allows us to do the following when processing x:

  • First, disable a_y for all important indices close to x: up to three in each direction.
  • Then do the range queries, since independence is now guaranteed.
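If the important indices are kept in a std::set, the indices to disable around position i can be collected with lower_bound and upper_bound. This sketch (hypothetical name `nearbyImportant`) returns up to k important indices strictly on each side of i, plus i itself if it is important. The editorialist's code does something similar with k = 3, though it walks the set slightly differently.

```cpp
#include <cassert>
#include <iterator>
#include <set>
#include <vector>

// Up to k important indices strictly left of i, i itself if important,
// and up to k strictly right of i, in increasing order.
std::vector<int> nearbyImportant(const std::set<int>& imp, int i, int k) {
    auto first = imp.lower_bound(i);  // first important index >= i
    for (int s = 0; s < k && first != imp.begin(); ++s) --first;
    auto last = imp.upper_bound(i);   // first important index > i
    for (int s = 0; s < k && last != imp.end(); ++s) ++last;
    return std::vector<int>(first, last);
}
```

Each x therefore pairs with at most 2k + 1 disabled indices, which is where the O(N) leftover candidate pairs come from.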

That leaves \mathcal{O}(N) candidate pairs (x with each of its disabled indices) that we haven’t handled yet.
However, the easy version already lets us evaluate any single swap in constant time, so just check each of them directly to finish.

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

struct segtree {
    using T = long long;

    T e() {
        return -1e18;
    }

    T op(T x, T y) {
        return max(x, y);
    }

    int n;
    int size;
    vector<T> node;

    segtree() : segtree(0) {}
    segtree(int _n) {
        build(vector<T>(_n, e()));
    }
    segtree(const vector<T> &v) {
        build(v);
    }

    void build(const vector<T> &v) {
        n = (int) v.size();
        if (n <= 1) {
            size = n;
        } else {
            size = 1 << (32 - __builtin_clz(n - 1));
        }
        node.resize(2 * size, e());
        for (int i = 0; i < n; i++) {
            node[i + size] = v[i];
        }
        for (int i = size - 1; i > 0; i--) {
            node[i] = op(node[2 * i], node[2 * i + 1]);
        }
    }

    void set(int p, T v) {
        assert(0 <= p && p < n);
        p += size;
        node[p] = v;
        while (p > 1) {
            p >>= 1;
            node[p] = op(node[2 * p], node[2 * p + 1]);
        }
    }

    T get(int l, int r) {
        assert(0 <= l && l <= r && r <= n);
        T vl = e();
        T vr = e();
        l += size;
        r += size;
        while (l < r) {
            if (l & 1) {
                vl = op(vl, node[l++]);
            }
            if (r & 1) {
                vr = op(node[--r], vr);
            }
            l >>= 1;
            r >>= 1;
        }
        return op(vl, vr);
    }

    T get(int p) {
        assert(0 <= p && p < n);
        return node[p + size];
    }
};

void solve(istringstream cin) {
    int n;
    cin >> n;
    vector<int> p(n);
    for (int i = 0; i < n; i++) {
        cin >> p[i];
        p[i]--;
    }

    vector<int> at(n);
    for (int i = 0; i < n; i++) {
        at[p[i]] = i;
    }

    set<int> st;
    st.emplace(0);
    for (int i = 1; i < n; i++) {
        if (p[i - 1] > p[i]) {
            st.emplace(i);
        }
    }
    st.emplace(n);

    long long ans = 0;
    segtree seg(vector<long long>(n, 0LL));
    vector events(n + 2, vector<pair<int, long long>>());
    for (int x = 0; x < n; x++) {
        for (auto [index, value] : events[x]) {
            seg.set(index, value);
        }

        int l = (at[x] == 0 ? n : p[at[x] - 1]);
        int r = (at[x] == n - 1 ? -1 : p[at[x] + 1]);

        long long lcnt = (at[x] == 0 ? 0 : at[x] - *prev(st.lower_bound(at[x])));
        long long rcnt = (at[x] == n - 1 ? 0 : *st.upper_bound(at[x] + 1) - at[x] - 1);

        if (l > x && x < r) {
            ans = max(ans, seg.get(0, x));
        }
        if (l < x && x < r) {
            ans = max(ans, seg.get(l + 1, x));
            ans = max(ans, seg.get(0, l + 1) - lcnt * (rcnt + 1));
        }
        if (l > x && x > r) {
            ans = max(ans, seg.get(max(r, 0), x));
            ans = max(ans, seg.get(0, max(r, 0)) + rcnt);
        }
        if (l < x && x > r && l < r) {
            ans = max(ans, seg.get(r, x));
            ans = max(ans, seg.get(l + 1, r) + (lcnt + 1) * rcnt);
            ans = max(ans, seg.get(0, l + 1) - lcnt + rcnt);
        }
        if (l < x && x > r && l > r) {
            ans = max(ans, seg.get(l + 1, x));
            ans = max(ans, seg.get(max(r, 0), l + 1) - lcnt);
            ans = max(ans, seg.get(0, max(r, 0)) - lcnt + rcnt);
        }

        if (l < x && x < r) {
            events[r].emplace_back(x, -(lcnt + 1) * rcnt);
        }
        if (l > x && x > r) {
            events[l + 1].emplace_back(x, lcnt);
        }
        if (l > x && x < r && l < r) {
            events[l + 1].emplace_back(x, lcnt * (rcnt + 1));
            events[r].emplace_back(x, lcnt - rcnt);
        }
        if (l > x && x < r && l > r) {
            events[r].emplace_back(x, -rcnt);
            events[l + 1].emplace_back(x, lcnt - rcnt);
        }
    }

    for (int x = 0; x < n; x++) {
        int l = (at[x] == 0 ? n : p[at[x] - 1]);
        int r = (at[x] == n - 1 ? -1 : p[at[x] + 1]);
        if (x < r || r == -1) {
            continue;
        }

        {
            int t = at[r];
            long long cnt0 = at[x] - *prev(st.upper_bound(at[x]));
            long long cnt1 = t - at[x] - 1;
            long long cnt2 = *st.upper_bound(t) - t - 1;
            long long a = 2 * cnt1 + 1 - cnt0 - cnt2;
            bool ok0 = l < p[t];
            bool ok2 = t + 1 < n && x < p[t + 1];
            if (ok0) {
                a += cnt0 * (1 + cnt1 + 1);
            }
            if (ok2) {
                a += (1 + cnt1 + 1) * cnt2;
            }
            if (ok0 && ok2) {
                a += cnt0 * cnt2;
            }
            ans = max(ans, a);
        }

        int t = *st.upper_bound(at[r]);
        if (t < n && p[t] < r && p[t - 1] < x) {
            long long cnt0 = at[x] - *prev(st.upper_bound(at[x]));
            long long cnt1 = t - at[x] - 1;
            long long cnt2 = *st.upper_bound(t) - t - 1;
            long long a = 2 * cnt1 + 1 - cnt0 - cnt2;
            bool ok0 = l < p[t];
            bool ok2 = t + 1 < n && x < p[t + 1];
            if (ok0) {
                a += cnt0 * (1 + cnt1 + 1);
            }
            if (ok2) {
                a += (1 + cnt1 + 1) * cnt2;
            }
            if (ok0 && ok2) {
                a += cnt0 * cnt2;
            }
            ans = max(ans, a);
        }
    }

    for (int i = 1; i + 1 < n; i++) {
        if (p[i] < p[i - 1] && p[i - 1] < p[i + 1]) {
            long long a = i - *prev(st.find(i));
            long long b = *st.upper_bound(i) - i - 1;
            ans = max(ans, a - b);
        }
    }

    for (int i = 1; i + 1 < n; i++) {
        if (p[i - 1] < p[i + 1] && p[i + 1] < p[i]) {
            long long a = *st.upper_bound(i + 1) - i - 1;
            long long b = i - *prev(st.lower_bound(i));
            ans = max(ans, a - b);
        }
    }

    for (int i : st) {
        if (i == n) {
            continue;
        }
        int j = *st.upper_bound(i) - 1;
        if (i == j) {
            continue;
        }
        long long a = 1 - 2 * (j - i);
        if (j + 1 < n && p[i] < p[j + 1]) {
            a += *st.upper_bound(j + 1) - j - 1;
        }
        if (i - 1 >= 0 && p[i - 1] < p[j]) {
            a += i - *prev(st.find(i));
        }
        ans = max(ans, a);
    }

    for (int i : st) {
        if (i < n) {
            int j = *st.upper_bound(i);
            ans += (j - i) * 1LL * (j - i + 1) / 2;
        }
    }
    cout << ans << '\n';
}

////////////////////////////////////////

// #define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string &pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 3e5);
        in.readEoln();
        sn += n;
        auto p = in.readInts(n, 1, n);
        in.readEoln();
        vector<int> a(n + 1);
        for (int i = 0; i < n; i++) {
            assert(a[p[i]] == 0);
            a[p[i]] = 1;
        }
        ostringstream sout;
        sout << n << '\n';
        for (int i = 0; i < n; i++) {
            sout << p[i] << " \n"[i == n - 1];
        }
        solve(istringstream(sout.str()));
    }
    cerr << sn << endl;
    assert(sn <= 3e5);
    in.readEof();
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Point-update Segment Tree
 * Source: kactl
 * Description: Iterative point-update segment tree, ranges are half-open i.e [L, R).
 *              f is any associative function.
 * Time: O(logn) update/query
 */

template<class T, T unit = T()>
struct SegTree {
	T f(T a, T b) { return max(a, b); }
	vector<T> s; int n;
	SegTree(int _n = 0, T def = unit) : s(2*_n, def), n(_n) {}
	void update(int pos, T val) {
		for (s[pos += n] = val; pos /= 2;)
			s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
	}
	T query(int b, int e) {
		T ra = unit, rb = unit;
		for (b += n, e += n; b < e; b /= 2, e /= 2) {
			if (b % 2) ra = f(ra, s[b++]);
			if (e % 2) rb = f(s[--e], rb);
		}
		return f(ra, rb);
	}
};

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    constexpr ll minf = -1e18;
    
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> p(n+2), pos(n+2);
        for (int i = 1; i <= n; ++i) {
            cin >> p[i];
            pos[p[i]] = i;
        }
        
        p[0] = n+1;
        
        vector<int> lt(n+2, 1), rt(n+2, 1);
        lt[0] = rt[n+1] = lt[n+1] = rt[0];
        for (int i = 2; i <= n; ++i) {
            if (p[i] > p[i-1]) lt[i] += lt[i-1];
        }
        for (int i = n-1; i >= 1; --i) {
            if (p[i] < p[i+1]) rt[i] += rt[i+1];
        }

        auto getct = [&] (int i, int x) {
            // Place x at position i
            if (x < min(p[i-1], p[i+1])) return 1ll + rt[i+1];
            if (x > max(p[i-1], p[i+1])) return 1ll + lt[i-1];
            if (p[i-1] < p[i+1]) return 1ll*(lt[i-1] + 1)*(rt[i+1] + 1);
            return 1ll;
        };

        auto solve_oneswap = [&] (int i, int j) {
            if (i > j) swap(i, j);
            if (i == j) return minf;
            
            ll res = -1ll*lt[i]*rt[i] - 1ll*lt[j]*rt[j];
            if (i+rt[i] > j) res += 1ll*lt[i]*rt[j];

            if (i+1 == j) {
                if (p[i] < p[j]) {
                    res += 2;
                    if (p[j] > p[i-1]) res += lt[i-1];
                    if (p[i] < p[j+1]) res += rt[j+1];
                }
                else {
                    int L = i, R = j;
                    if (p[j] > p[i-1]) L = i - lt[i-1];
                    if (p[i] < p[j+1]) R = j + rt[j+1];
                    res += 1ll*(i-L+1)*(R-i+1);
                    res += 1ll*(R-j+1)*(j-i);
                }
                return res;
            }

            if (p[j] < p[i+1] and p[j-1] < p[i] and rt[i+1]+i+1 >= j) {
                int L = i, R = j;
                if (p[j] > p[i-1]) L = i - lt[i-1];
                if (p[i] < p[j+1]) R = j + rt[j+1];
                res += 1ll*(i-L+1)*(R-i+1);
                res += 1ll*(R-j+1)*(j-i);
            }
            else {
                // no interaction
                res += getct(i, p[j]) + getct(j, p[i]);
            }
            return res;
        };

        vector<int> mark(n+2);
        mark[1] = mark[n] = 1;
        for (int i = 1; i < n; ++i) {
            if (p[i] > p[i+1]) mark[i] = mark[i+1] = 1;
        }
        set<int> important;
        for (int i = 1; i <= n; ++i)
            if (mark[i]) important.insert(i);
        
        vector<array<ll, 3>> vals(n+2);
        vector<ll> curval(n+2);
        vector<vector<array<int, 2>>> events(n+2);
        for (int i = 1; i <= n; ++i) {
            ll inc = 1ll*lt[i]*rt[i];

            int L = min(p[i-1], p[i+1]), R = max(p[i-1], p[i+1]);
            auto &cur = vals[p[i]];
            cur[0] = 1 + rt[i+1] - inc;
            cur[2] = 1 + lt[i-1] - inc;
            if (p[i-1] < p[i+1]) cur[1] = 1ll*(1 + lt[i-1])*(1 + rt[i+1]) - inc;
            else cur[1] = 1 - inc;

            events[L].push_back({p[i], 1});
            events[R].push_back({p[i], 2});
        }

        ll ans = 0, add = 0;
        for (int i = 1; i <= n; ++i)
            ans += lt[i];
        
        SegTree<ll, minf> seg(n+1);
        for (int i = 1; i < n; ++i) {
            curval[p[i]] = vals[p[i]][0];
            seg.update(p[i], curval[p[i]]);
        }
        curval[p[n]] = vals[p[n]][1];
        seg.update(p[n], curval[p[n]]);
        
        const int itrs = 3;
        for (int x = 1; x <= n; ++x) {
            int i = pos[x];
            ll inc = 1ll*lt[i]*rt[i];
            int L = min(p[i-1], p[i+1]), R = max(p[i-1], p[i+1]);

            vector<int> removed;
            auto it = important.lower_bound(i);
            for (int j = 0; j < itrs; ++j)
                if (it != begin(important)) --it;
            
            for (int j = 0; j < 2*itrs + 1; ++j) {
                if (it == end(important)) break;
                int y = p[*it];
                seg.update(y, minf);
                removed.push_back(*it);
                it = important.erase(it);
            }

            add = max(add, seg.query(1, L) + 1 + rt[i+1] - inc);
            add = max(add, seg.query(R+1, n+1) + 1 + lt[i-1] - inc);
            if (p[i-1] < p[i+1]) add = max(add, seg.query(L+1, R) + 1ll*(1 + lt[i-1])*(1 + rt[i+1]) - inc);
            else add = max(add, seg.query(L+1, R) + 1 - inc);

            for (int j : removed) {
                add = max(add, solve_oneswap(i, j));
                important.insert(j);
                seg.update(p[j], curval[p[j]]);
            }

            for (auto [y, id] : events[x]) {
                curval[y] = vals[y][id];
                seg.update(y, curval[y]);
            }
        }

        cout << ans + add << '\n';
    }
}