CARDSMACHINE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: bernarb01
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Stacks, dynamic programming

PROBLEM:

There are N cards; the i-th has the number A_i written on it.

Answer Q queries of the following form:

  • Remove the last K_i cards.
  • From the remaining cards, your opponent removes some (possibly zero) cards from the top and some from the bottom.
    Then, a machine arranges the cards that are left in non-increasing order and returns the top 2 of them.
    You win if these two cards have equal values written on them.
  • How many of the opponent's possible choices lead to your victory?
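To make the statement concrete, here is a brute-force sketch that follows the process literally for a single query. It is only meant for checking small cases. The function name `literal_query` is hypothetical, the start of the list is treated as the "top" (the orientation does not change the count), and it assumes that at least two cards must remain for the machine to report a pair.

```python
def literal_query(cards, k):
    """Follow the statement literally for one query (a slow checking sketch).

    cards: card values in order; k: number of cards removed from the end.
    Assumption: a win needs at least two cards left for the machine."""
    kept = cards[:len(cards) - k]            # remove the last k cards
    n = len(kept)
    wins = 0
    for top in range(n + 1):                 # opponent removes `top` cards from the top
        for bottom in range(n - top + 1):    # ... and `bottom` cards from the bottom
            rest = kept[top:n - bottom]
            if len(rest) < 2:
                continue
            # the machine sorts in non-increasing order and returns the top two
            first, second = sorted(rest, reverse=True)[:2]
            if first == second:
                wins += 1
    return wins
```

Every pair (top, bottom) leaves a different contiguous block of cards, which is what the rephrasing below relies on.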

EXPLANATION:

In more familiar terms, the card machine receives some subarray of the array A and returns the two largest elements within it.

Deleting the last K_i cards is equivalent to keeping the first N-K_i of them.
So, the query we really want to answer is, “given some prefix of A, how many subarrays within this prefix have their maximum equal to the second maximum?”
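The rephrased query can be checked directly with a quadratic sketch (the name `count_prefix` is illustrative, not taken from the original solutions). The two largest elements are equal exactly when the maximum value appears at least twice, so it suffices to track the running maximum and its frequency while extending each subarray to the right.

```python
def count_prefix(a, length):
    """Count subarrays of a[:length] whose maximum equals the second maximum,
    i.e. whose maximum value occurs at least twice (O(length^2) sketch)."""
    total = 0
    for left in range(length):
        best, freq = a[left], 1              # running maximum and its count
        for right in range(left + 1, length):
            if a[right] > best:
                best, freq = a[right], 1
            elif a[right] == best:
                freq += 1
            if freq >= 2:                    # a[left..right] is a winning subarray
                total += 1
    return total
```

For A = [2, 1, 2, 3, 3, 1], `count_prefix(A, 6)` gives 9 and `count_prefix(A, 5)` gives 5.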

We’ll precompute this answer for every prefix, after which queries can be answered in constant time.


Let b_i denote the number of subarrays ending at index i whose two largest elements are equal.

The maximum of any subarray ending at index i is definitely at least A_i.
Let j \lt i be the largest index such that A_j \gt A_i (if no such index exists, set j = 0, in which case only the first case below applies).
Then,

  • For any index k such that j \lt k \lt i, the subarray A[k\ldots i] has two equal maximums if and only if it contains a second occurrence of A_i.
    This holds because every element between indices j+1 and i is at most A_i, so the maximum of A[k\ldots i] is A_i itself.
  • For any index k such that 1 \leq k \leq j, the subarray A[k\ldots i] has two equal maximums if and only if the subarray A[k\ldots j] has two equal maximums; that is, we can essentially ‘cut off’ the part of the subarray after j.
    • This is because, by the choice of j, all the elements between index j+1 and i are at most A_i \lt A_j, while the maximum of any such subarray is at least A_j.
      So, the elements between index j+1 and i no longer matter: both copies of the maximum must occur at or before index j.
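Both claims can be sanity-checked by brute force. The sketch below uses hypothetical helper names `two_equal_max` and `check_split`, keeps a dummy `a[0]` so that indices match the 1-indexed text, and takes j = 0 when no earlier element exceeds A_i. It raises an `AssertionError` if either case fails for some i and k.

```python
def two_equal_max(a, l, r):
    """True if a[l..r] (1-indexed, inclusive) contains its maximum at least twice."""
    sub = a[l:r + 1]
    return sub.count(max(sub)) >= 2


def check_split(a):
    """Brute-force check of both cases for every i; a[0] is an unused placeholder."""
    n = len(a) - 1
    for i in range(1, n + 1):
        # j = nearest index to the left with a[j] > a[i], or 0 if none
        j = max((t for t in range(1, i) if a[t] > a[i]), default=0)
        # j < k < i: a[i] is the maximum, so we need another copy of a[i]
        for k in range(j + 1, i):
            assert two_equal_max(a, k, i) == (a[i] in a[k:i])
        # 1 <= k <= j: everything after j is smaller than a[j], so cut it off
        for k in range(1, j + 1):
            assert two_equal_max(a, k, i) == two_equal_max(a, k, j)
    return True
```

Calling, for example, `check_split([None, 5, 1, 5, 1, 5, 2, 2, 4])` exercises both cases, including left endpoints that straddle several copies of the maximum.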

The second case is simple to deal with: by definition, there are b_j subarrays ending at j with two equal maximums, and hence b_j valid left endpoints k (none if no such j exists).
As for the first case: let m \lt i be the closest occurrence of A_i to the left of index i, with m = 0 if A_i does not occur before index i.

  • If m \leq j, no subarray A[k\ldots i] with j \lt k \lt i contains a second copy of A_i, so the first case contributes nothing.
  • If m \gt j, we can choose any index k such that j \lt k \leq m to be the left endpoint, giving us m-j choices.
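Putting the two cases together gives b_i = b_j + \max(0, m-j), with b_0 = 0 and j = 0 or m = 0 standing for "no such index". The following sketch applies this recurrence while still finding j and m by plain linear scans, so it runs in O(N^2); it only illustrates the formula, and `b_values_naive` is a hypothetical name.

```python
def b_values_naive(a):
    """Return [b_0, b_1, ..., b_N] via b_i = b_j + max(0, m - j).

    a is 1-indexed (a[0] unused); j and m are found by O(i) scans."""
    n = len(a) - 1
    b = [0] * (n + 1)  # b[0] = 0 covers "no greater element to the left"
    for i in range(1, n + 1):
        # j: nearest index left of i with a[j] > a[i], or 0 if none
        j = max((t for t in range(1, i) if a[t] > a[i]), default=0)
        # m: nearest earlier occurrence of a[i], or 0 if none
        m = max((t for t in range(1, i) if a[t] == a[i]), default=0)
        # left ends k <= j inherit b[j]; left ends j < k <= m add m - j more
        b[i] = b[j] + max(0, m - j)
    return b
```

For A = [2, 1, 2, 3, 3, 1] this returns [0, 0, 0, 1, 0, 4, 4], where the leading 0 is b_0.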

So, all we need to do is find the indices j and m quickly, after which b_i = b_j + \max(0, m-j) can be computed in constant time.

  • The index j is the closest index to the left of i that contains a value \gt A_i.
    Finding this for all indices in linear time is a standard task (the "previous greater element" problem), and can be done using a monotonic stack.
  • The index m is the previous occurrence of A_i, and is easily found by, for example, maintaining a list of occurrences of every element (or just the latest one).
    Alternatively, you can modify the above algorithm slightly to find the closest index to the left of i that contains an element greater than or equal to A_i, and then check whether that value equals A_i: if it does, that index is m (and m \gt j); otherwise, the first case contributes nothing.
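A sketch of both lookups, plus the "greater than or equal" variant, using 1-indexed lists (`a[0]` unused, 0 meaning "no such index"). The function names are illustrative. Each index is pushed onto and popped from a stack at most once, so every pass is linear; the loop at the end checks, on a sample array, that the variant yields the same first-case contribution \max(0, m-j).

```python
def prev_greater(a):
    """j[i] = nearest t < i with a[t] > a[i], or 0 if none."""
    n = len(a) - 1
    j = [0] * (n + 1)
    stack = []  # indices whose values strictly decrease from bottom to top
    for i in range(1, n + 1):
        while stack and a[stack[-1]] <= a[i]:
            stack.pop()  # never an answer again: a[i] is at least as large and closer
        j[i] = stack[-1] if stack else 0
        stack.append(i)
    return j


def prev_occurrence(a):
    """m[i] = previous index holding the value a[i], or 0 if none."""
    last = {}  # value -> latest index seen so far
    m = [0] * len(a)
    for i in range(1, len(a)):
        m[i] = last.get(a[i], 0)
        last[a[i]] = i
    return m


def prev_greater_or_equal(a):
    """p[i] = nearest t < i with a[t] >= a[i], or 0 if none (pop only smaller values)."""
    n = len(a) - 1
    p = [0] * (n + 1)
    stack = []
    for i in range(1, n + 1):
        while stack and a[stack[-1]] < a[i]:
            stack.pop()
        p[i] = stack[-1] if stack else 0
        stack.append(i)
    return p


a = [None, 2, 1, 2, 3, 3, 1]
j, m, p = prev_greater(a), prev_occurrence(a), prev_greater_or_equal(a)
for i in range(1, len(a)):
    # variant: p[i] is the previous occurrence m (with m > j) iff a[p[i]] == a[i]
    alt = p[i] - j[i] if p[i] and a[p[i]] == a[i] else 0
    assert alt == max(0, m[i] - j[i])
```

The variant trades the dictionary for a second stack, which keeps the whole computation free of hashing.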

Once the array b has been computed, the answer for the prefix of length i is just b_1 + b_2 + \ldots + b_i.
So, compute the prefix sum array of b, and for a query that removes K_i cards, print the prefix sum for length N-K_i.
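Combining the pieces, a complete sketch could look as follows. It takes the card values as a plain list and the queries as a list of K values (a simplified interface assumed here; the original solutions read several test cases from standard input) and returns the answer for each query. It runs in O(N + Q), expected, because of the dictionary.

```python
def solve(a, queries):
    """a: card values A_1..A_N as a 0-based list; queries: list of K values.
    Returns, for each K, the number of winning subarrays in the first N - K cards."""
    n = len(a)
    a = [None] + a  # shift to 1-indexed so that index 0 means "no such index"
    b = [0] * (n + 1)
    last = {}       # value -> latest index seen so far (gives m)
    stack = []      # monotonic stack for the previous strictly greater element (gives j)
    for i in range(1, n + 1):
        while stack and a[stack[-1]] <= a[i]:
            stack.pop()
        j = stack[-1] if stack else 0
        m = last.get(a[i], 0)
        b[i] = b[j] + max(0, m - j)
        last[a[i]] = i
        stack.append(i)
    # prefix[i] = b_1 + ... + b_i = answer for the prefix of length i
    prefix = [0] * (n + 1)
    for i in range(1, n + 1):
        prefix[i] = prefix[i - 1] + b[i]
    return [prefix[n - k] for k in queries]
```

For example, `solve([2, 1, 2, 3, 3, 1], [0, 1, 2])` returns `[9, 5, 1]`.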

TIME COMPLEXITY:

\mathcal{O}(N + Q) per testcase.

CODE:

Author's code (C++)
/**
 *    author:  BERNARD B.01
**/
#include <bits/stdc++.h>

using namespace std;

#ifdef B01
#include "deb.h"
#else
#define deb(...)
#endif

int main() {
  ios::sync_with_stdio(false);
  cin.tie(0);
  int tt;
  cin >> tt;
  while (tt--) {
    int n;
    cin >> n;
    vector<int> a(n);
    for (int i = 0; i < n; i++) {
      cin >> a[i];
    }
    vector<int> stk;
    vector<int> prev(n, -1);
    for (int i = 0; i < n; i++) {
      while (!stk.empty() && a[stk.back()] <= a[i]) {
        stk.pop_back();
      }
      if (!stk.empty()) {
        prev[i] = stk.back();
      }
      stk.push_back(i);
    }
    stk.clear();
    vector<int> next(n, n);
    for (int i = n - 1; i >= 0; i--) {
      while (!stk.empty() && a[stk.back()] <= a[i]) {
        stk.pop_back();
      }
      if (!stk.empty()) {
        next[i] = stk.back();
      }
      stk.push_back(i);
    }
    stk.clear();
    vector<int> next_e(n, n);
    for (int i = n - 1; i >= 0; i--) {
      while (!stk.empty() && a[stk.back()] < a[i]) {
        stk.pop_back();
      }
      if (!stk.empty()) {
        next_e[i] = stk.back();
      }
      stk.push_back(i);
    }
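    // ans is a difference array over right endpoints. When the next element
    // >= a[i] has the same value, a[i] is the maximum and (i, next_e[i]) are its
    // last two copies for every subarray with left end in (prev[i], i] and
    // right end in [next_e[i], min(next_e[next_e[i]], next[i])).
    vector<int64_t> ans(n + 1);
    for (int i = 0; i < n; i++) {
      if (next_e[i] < n && a[next_e[i]] == a[i]) {
        int d = i - prev[i];
        ans[next_e[i]] += d;
        ans[min(next_e[next_e[i]], next[i])] -= d;
      }
    }
    // First prefix sum: ans[r] = number of good subarrays ending at r.
    for (int i = 0; i < n; i++) {
      ans[i + 1] += ans[i];
    }
    // Second prefix sum: ans[r] = number of good subarrays inside a[0..r].
    for (int i = 0; i < n; i++) {
      ans[i + 1] += ans[i];
    }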
    int q;
    cin >> q;
    while (q--) {
      int k;
      cin >> k;
      cout << ans[n - k - 1] << '\n';
    }
  }
  return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
	int t;
	cin>>t;
	while(t--){
	    int n;
	    cin>>n;
	    int a[n];
	    for(int i = 0; i < n; i++){
	        cin>>a[i];
	    }
	    vector<int> v;
	    int pr[n];
	    memset(pr, -1, sizeof(pr));
	    for(int i = n - 1; i > -1; i--){
	        if(v.size()){
	            if(a[v.back()] < a[i]){
	                pr[v.back()] = i;
	                v.pop_back();
	                i++;
	            }else{
	                v.push_back(i);
	            }
	        }else{
	            v.push_back(i);
	        }
	    }
	    int dp[n][2];
	    map<int, int> mp;
	    int x = 0;
	    int y = 0;
	    for(int i = 0; i < n; i++){
	        dp[i][0] = x + y;
	        dp[i][1] = 0;
	        auto it = mp.find(a[i]);
	        if(it != mp.end()){
	            if((*it).second > pr[i]){
	                dp[i][1] += (*it).second - pr[i];
	            }
	            (*it).second = i;
	        }else{
	            mp.insert({a[i], i});
	        }
	        if(pr[i] != -1){
	            dp[i][1] += dp[pr[i]][1];
	        }
	        x = dp[i][0];
	        y = dp[i][1];
	    }
	    int q;
	    cin>>q;
	    while(q--){
	        int x;
	        cin>>x;
	        cout<<dp[n - x - 1][0] + dp[n - x - 1][1]<<"\n";
	    }
	}
}
Editorialist's code (Python)
import sys
input = sys.stdin.readline
for _ in range(int(input())):
    n = int(input())
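    # 1-indexed; a[0] is a sentinel larger than every A_i (assumes A_i <= 10^9)
    a = [10**9 + 1] + list(map(int, input().split()))
    dp = [0]*(n+1)   # dp[i] = b_i from the text, dp[0] = 0
    pos = dict()     # value -> latest index where it occurred
    stk = [0]*(n+1)  # monotonic stack of indices; stk[0] is the sentinel
    ptr = 0

    for i in range(1, n+1):
        # j here is the text's m: previous occurrence of a[i] (0 if none)
        j = 0
        if a[i] in pos: j = pos[a[i]]
        pos[a[i]] = i
        # k here is the text's j: nearest index to the left with a[k] > a[i]
        while a[stk[ptr]] <= a[i]: ptr -= 1
        k = stk[ptr]
        dp[i] = max(0, j - k) + dp[k]
        ptr += 1
        stk[ptr] = i
    # prefix sums: dp[i] becomes the answer for the prefix of length i
    for i in range(2, n+1): dp[i] += dp[i-1]

    # a query K keeps the first n - K cards
    for i in range(int(input())):
        print(dp[n-int(input())])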