REMK - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

Simple

PREREQUISITES:

Prefix sums or priority queues/multisets

PROBLEM:

For a sequence B of length M, define f(B) = \sum_{i=1}^{M-1} \left( B_i + B_{i+1}\right).

You’re given an array A and an integer K.
Find the maximum possible value of f(C) across all subsequences C of A of length K.

EXPLANATION:

First, we rewrite f(B) slightly.

f(B) = \sum_{i=1}^{M-1} \left( B_i + B_{i+1}\right) = B_1 + B_M + 2\cdot \sum_{i=2}^{M-1} B_i

That is, the ‘border’ elements are added once, and everything else is added twice.
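To sanity-check this identity, here is a small Python sketch that evaluates both forms on one sequence. The names `f` and `f_rewritten` are illustrative and do not come from the original code. The rewritten form only applies for M ≥ 2, which always holds here since K ≥ 2. For M < 2 the original sum is empty, so the sketch returns 0.

```python
def f(b):
    """f(B) as defined: the sum of B_i + B_{i+1} over adjacent pairs."""
    return sum(b[i] + b[i + 1] for i in range(len(b) - 1))


def f_rewritten(b):
    """Border form: B_1 + B_M + 2 * (sum of the interior elements)."""
    if len(b) < 2:
        return 0  # the original sum is empty when M < 2
    return b[0] + b[-1] + 2 * sum(b[1:-1])


example = [3, 1, 4, 1, 5]
print(f(example), f_rewritten(example))  # prints "20 20"
```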

Using this, we immediately obtain a (very slow) solution in \mathcal{O}(N^3) or so:

  • Suppose we fix indices i and j (1 \leq i \lt j \leq N) to be the leftmost and rightmost elements of the chosen subsequence.
    These are the ‘border’ elements, and so will contribute A_i + A_j to the sum.
  • Then, we need to choose K-2 more elements from between indices i and j.
    From the rewritten form above, each chosen middle element contributes twice its value to the sum.
    So it’s optimal to choose the largest K-2 elements in this range, which can be found, for example, by sorting all the elements of the range.
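Here is a direct Python transcription of this brute force. The function name `brute_force` and the 0-based indexing are assumptions made for illustration. Because every range is sorted, the running time is roughly \mathcal{O}(N^3\log N), so it is only useful for cross-checking faster solutions on small inputs.

```python
def brute_force(a, k):
    """Slow reference: fix the leftmost index i and the rightmost index j of
    the subsequence, then take the k-2 largest values strictly between them."""
    n = len(a)
    best = 0
    for i in range(n):
        for j in range(i + 1, n):
            middle = a[i + 1:j]  # slicing makes a copy, so sorting is safe
            if len(middle) < k - 2:
                continue  # not enough elements between i and j
            middle.sort(reverse=True)
            # borders count once, the k-2 largest middle values count twice
            best = max(best, a[i] + a[j] + 2 * sum(middle[:k - 2]))
    return best


print(brute_force([1, 5, 2, 4, 3], 3))  # 16, from the subsequence [5, 4, 3]
```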

The key observation for optimizing this further is that we don’t really need to check all pairs (i, j).
Indeed, note that:

  • If there’s an index x \lt i such that A_x = A_i (meaning i isn’t the leftmost occurrence of A_i), the pair (x, j) gives an answer at least as large as the pair (i, j).
    This is because the sum of the borders remains the same, while (x, j) has strictly more choices for elements in the middle.
  • The same observation applies to the right end too: if j isn’t the rightmost occurrence of A_j, it doesn’t need to be considered.

In other words, we only care about those pairs (i, j) for which i is the leftmost occurrence of A_i, and j is the rightmost occurrence of A_j.
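The candidate endpoints can be collected in a single pass. Below is a minimal sketch that assumes 0-based indices. The helper name `candidate_endpoints` is hypothetical.

```python
def candidate_endpoints(a):
    """Return (lefts, rights): indices that are the first occurrence of their
    value, and indices that are the last occurrence of their value."""
    first, last = {}, {}
    for idx, v in enumerate(a):
        first.setdefault(v, idx)  # keeps only the earliest index of v
        last[v] = idx             # overwritten until it holds the latest index of v
    return sorted(first.values()), sorted(last.values())


print(candidate_endpoints([2, 3, 2, 5, 3]))  # ([0, 1, 3], [2, 3, 4])
```

With at most 50 distinct values, each list has at most 50 entries, so at most 50 \cdot 50 pairs remain.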

Now, note the constraint on the elements of A: they’re all \leq 50.
So, there are at most 50 choices of i, and at most 50 choices of j.

The only remaining “slow” part of the algorithm is finding the largest K-2 numbers between i and j.
For this, we can once again use the fact that the values are \leq 50: it’s optimal to take as many occurrences of 50 as we can, then as many of 49, then 48, and so on.
For a fixed number x, the number of its occurrences between i and j can be found in \mathcal{O}(1) time using prefix sums (store a separate prefix sum array for each x from 1 to 50).
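Putting these pieces together, the prefix-sum approach could be sketched in Python as follows. This is not the editorialist's code below, which uses the binary-search variant. The name `solve_prefix` is hypothetical, and `MAXV = 50` simply encodes the stated bound on the values. Building `cnt` costs \mathcal{O}(N\cdot M). The pair loops cost \mathcal{O}(M^3).

```python
MAXV = 50  # values of A are assumed to lie in [1, 50], as stated in the problem


def solve_prefix(a, k):
    # cnt[p][x] = number of occurrences of value x among a[0:p]
    cnt = [[0] * (MAXV + 1)]
    for v in a:
        row = cnt[-1][:]
        row[v] += 1
        cnt.append(row)

    first, last = {}, {}
    for idx, v in enumerate(a):
        first.setdefault(v, idx)
        last[v] = idx

    best = 0
    for i in first.values():          # leftmost occurrences only
        for j in last.values():       # rightmost occurrences only
            if j - i - 1 < k - 2:
                continue  # too few middle elements (this also skips j <= i)
            need = k - 2
            total = a[i] + a[j]
            # greedily take the largest values strictly between i and j
            for x in range(MAXV, 0, -1):
                if need == 0:
                    break
                available = cnt[j][x] - cnt[i + 1][x]  # occurrences of x in a[i+1:j]
                take = min(available, need)
                total += 2 * take * x
                need -= take
            best = max(best, total)
    return best


print(solve_prefix([2, 3, 2, 5, 3], 3))  # 16, from the subsequence [3, 5, 3]
```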

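This brings the complexity down to \mathcal{O}(N\cdot M + M^3) per testcase, where M = 50 for us: building the prefix sums takes \mathcal{O}(N\cdot M), and each of the at most M^2 pairs needs \mathcal{O}(M) work for the greedy. (The tester's code reduces the first term to \mathcal{O}(N) by storing prefix counts only at the at most 2M candidate endpoints.)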
There are at most 100 test cases, so this is fast enough to pass.


There are other solutions too:

  • The M^3 part can be optimized to M^2 \log M with the help of binary search, although this isn’t needed to get AC.
  • Instead of fixing both i and j, you can fix only i, and then iterate j from i+1 to N, while maintaining the best answer ending at the current j.
    For this, you’ll need a data structure that stores the largest K-2 elements seen so far, and updates that when a new element is inserted - a priority queue or multiset can do this.
    The complexity here is \mathcal{O}(NM\log N).
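Here is a Python sketch of the priority-queue variant. It uses `heapq` as a min-heap that holds the K-2 largest middle values seen so far. The idea mirrors the author's C++ code below, but the name `solve_heap` is hypothetical. As in the author's code, only the first occurrence of each value is used as a left border, which is where the factor M in the complexity comes from.

```python
import heapq


def solve_heap(a, k):
    n = len(a)
    seen = set()
    best = 0
    for i in range(n):
        if a[i] in seen:
            continue  # only the leftmost occurrence of each value can be a left border
        seen.add(a[i])
        heap = []     # min-heap holding the k-2 largest values of a[i+1:j]
        kept_sum = 0  # sum of the values currently in the heap
        for j in range(i + 1, n):
            # use a[j] as the right border once k-2 middle elements are available
            if len(heap) == k - 2:
                best = max(best, a[i] + a[j] + 2 * kept_sum)
            # a[j] becomes a candidate middle element for larger j
            heapq.heappush(heap, a[j])
            kept_sum += a[j]
            if len(heap) > k - 2:
                kept_sum -= heapq.heappop(heap)  # evict the smallest kept value
    return best


print(solve_heap([1, 5, 2, 4, 3], 3))  # 16
```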

Most solutions with reasonable complexities should pass.

TIME COMPLEXITY:

\mathcal{O}(N\cdot M + M^3) with prefix sums, or \mathcal{O}(NM\log N) with a priority queue, per testcase, where M = 50.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n, k; cin >> n >> k;

    vector <int> a(n), f(51, -1);
    for (int i = 0; i < n; i++){
        cin >> a[i];
        if (f[a[i]] == -1){
            f[a[i]] = i;
        }
    }

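    int ans = 0;

    // Only the first occurrence of each value can be the left border.
    for (int i = 0; i < n; i++) if (f[a[i]] == i){
        // Min-heap (values stored negated) of the k-2 largest elements in a[i+1..j-1];
        // sum holds their total.
        priority_queue <int> pq;
        int sum = 0;
        for (int j = i + 1; j < n; j++){
            // Use a[j] as the right border once k-2 middle elements are available.
            if ((int)pq.size() == k - 2)
                ans = max(ans, a[j] + 2 * sum + a[i]);

            // a[j] becomes a candidate middle element for larger j.
            sum += a[j];
            pq.push(-a[j]);
            if ((int)pq.size() > k - 2){
                // Evict the smallest kept value.
                sum -= -pq.top();
                pq.pop();
            }
        }
    }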

    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

#ifdef LOCAL
struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
#else

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
      string X; cin >> X;
      return X;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
    }

    void readEoln() {
    }

    void readEof() {
    }
};
#endif

int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  input_checker inp;
  int T = inp.readInt(1, (int)1e5), NN = 0; inp.readEoln();
  while(T-- > 0) {
    int N = inp.readInt(1, (int)2e5); inp.readSpace();
    int K = inp.readInt(2, N);  inp.readEoln();
    NN += N;

    vector<int> A = inp.readInts(N, 1, 50); inp.readEoln();

    vector<int> first(51, -1), last(51, -1);
    for(int i = 0 ; i < N ; ++i)  last[A[i]] = i;
    for(int i = N - 1 ; i >= 0 ; --i) first[A[i]] = i;

    vector<vector<int>> pf(N);
    vector<bool> good(N);
    for(int i = 1 ; i <= 50 ; ++i) {
      if(first[i] != -1)  good[first[i]] = 1;
      if(last[i] != -1) good[last[i]] = 1;
    }
    vector<int> cnt(51);
    for(int i = 0 ; i < N ; ++i) {
      cnt[A[i]]++;
      if(good[i])
        pf[i] = cnt;
    }

    int64_t res = 0;
    for(int f = 1 ; f <= 50 ; ++f) if(first[f] != -1) for(int l = 1 ; l <= 50 ; ++l) if(last[l] != -1) {
      if(last[l] - first[f] - 1 < K - 2)
        continue;

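      int fin = last[l], ini = first[f];
      // pf[p][x] counts x in A[0..p], so the difference counts A[ini+1..fin].
      for(int x = 1 ; x <= 50 ; ++x)
        cnt[x] = pf[fin][x] - pf[ini][x];
      // Exclude the right border A[fin] itself, leaving only the strict interior.
      cnt[A[fin]]--;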
      int64_t ans = f + l;
      int rem = K - 2;
      for(int x = 50 ; x >= 1 ; --x) {
        int here = min(cnt[x], rem);
        rem -= here;
        ans += 2 * here * x;
      }
      res = max(res, ans);
    }

    cout << res << '\n';
  }
  assert(NN <= (int)2e6);
  inp.readEof();
  
  return 0;
}

Editorialist's code (Python)
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = [0] + list(map(int, input().split()))
    lt, rt = [n+1]*51, [0]*51
    for i in range(1, n+1):
        lt[a[i]] = min(i, lt[a[i]])
        rt[a[i]] = max(i, rt[a[i]])
    
    pref = [ [0 for _ in range(52)] for _ in range(n+1)]
    pref2 = [ [0 for _ in range(52)] for _ in range(n+1)]
    for i in range(1, n+1):
        pref[i] = pref[i-1][:]
        pref2[i] = pref2[i-1][:]
        pref[i][a[i]] += 1
        pref2[i][a[i]] += a[i]
    
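    # Turn per-value counts into "at least j" counts: afterwards
    # pref[i][j] = number of elements >= j in a[1..i], and
    # pref2[i][j] = sum of those elements.
    for j in reversed(range(51)):
        for i in range(n+1):
            pref[i][j] += pref[i][j+1]
            pref2[i][j] += pref2[i][j+1]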
    
    ans = 0
    for i in range(1, 51):
        for j in range(1, 51):
            L, R = lt[i], rt[j]
            if R-L+1 < k: continue
            sm = i + j
            if k == 2:
                ans = max(ans, sm)
                continue
            
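            # Binary search for the largest threshold lo such that at least k-2
            # elements strictly between L and R are >= lo.
            lo, hi = 1, 50
            while lo < hi:
                mid = (lo + hi + 1)//2
                ct = pref[R-1][mid] - pref[L][mid]
                if ct >= k-2: lo = mid
                else: hi = mid-1
            
            # Take every element > lo, then fill the remaining slots with copies of lo.
            ct = pref[R-1][lo+1] - pref[L][lo+1]
            rem = k-2-ct
            sm += 2*lo*rem + 2*(pref2[R-1][lo+1] - pref2[L][lo+1])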
            ans = max(ans, sm)
    print(ans)

https://www.codechef.com/viewsolution/1080199742

Interesting DP solution; it works even without the constraint on a[i].

Very nice problem, which made me realize the importance of considering the constraints.

My segment tree solution: CodeChef: Practical coding for everyone

Can you add some explanation, like how this works?
