SMEX - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Simple

PREREQUISITES:

Elementary combinatorics

PROBLEM:

The \text{SMEX} of a sequence is the second-smallest non-negative integer that does not appear in it.
You’re given an array A. Compute the sum of the \text{SMEX} of each of its non-empty subsequences.
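A brute-force reference is handy for checking small cases. The sketch below assumes the elements are non-negative integers, and the helper names smex and brute_force_smex_sum are hypothetical. It computes the SMEX straight from the definition and sums it over every non-empty subsequence. This takes exponential time and is meant only for testing.

```python
from itertools import combinations

def smex(values):
    """Second-smallest non-negative integer that does not appear in `values`."""
    present = set(values)
    # at most len(present) of the integers 0..len(present)+1 occur, so at least two are missing
    missing = [x for x in range(len(present) + 2) if x not in present]
    return missing[1]

def brute_force_smex_sum(a):
    """Sum of SMEX over all non-empty subsequences of a (exponential time, exact integers)."""
    n = len(a)
    total = 0
    for size in range(1, n + 1):
        # subsequences are chosen by position, so equal values at different indices count separately
        for idx in combinations(range(n), size):
            total += smex(a[i] for i in idx)
    return total
```

For example, brute_force_smex_sum([0, 1]) returns 7: [0] and [1] each have SMEX 2, and [0, 1] has SMEX 3.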

EXPLANATION:

When attempting to compute the sum of some quantity across all subsequences of an array, one standard idea is the contribution technique. Fix a value of the quantity, count how many subsequences attain that value, and add up value \times count over all possible values.

That’s exactly what we’ll do to solve this problem.


Let’s fix K, the value of the \text{SMEX}, and try to figure out how many subsequences attain this.

For the \text{SMEX} of a subsequence to equal K,

  1. Almost every integer from 0 to K-1 should be present.
    In particular, there should be exactly one integer M \lt K that’s not present in the subsequence at all; and every integer \lt K other than M should be present.
    This is what will make K be the second missing integer.
  2. K itself shouldn’t be present in the subsequence, of course.
  3. As long as the first two conditions are followed, integers \gt K can be freely present or absent without affecting the \text{SMEX}.
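These conditions can be checked mechanically. The following sketch uses a hypothetical name, smex_is, and decides whether the SMEX of a collection of values equals K using only conditions 1 and 2. It never looks at values above K, and condition 3 guarantees that this is safe.

```python
def smex_is(values, k):
    """True iff the SMEX of `values` equals k, decided via the conditions above."""
    present = set(values)
    # condition 1: exactly one integer M < k is absent
    missing_below = sum(1 for x in range(k) if x not in present)
    # condition 2: k itself is absent
    # condition 3: integers > k are never inspected, since they cannot change the answer
    return missing_below == 1 and k not in present
```

For instance, smex_is([0, 2], 3) is True. In contrast, smex_is([0, 1], 2) is False, because no integer below 2 is missing.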

This shows that the only real constraint is on the integers \lt K, apart from K itself, which must simply be left out.
Specifically, we're counting the ways to choose a subsequence in which all but exactly one of the integers from 0 to K-1 are present.
We can choose any subset of the integers \gt K, and that choice is independent of which integer below K is missing.

We deal with the elements \lt K first.
Perhaps the simplest way to do so is dynamic programming.
Define f(K, i) to be the number of subsequences containing only integers \lt K (the empty subsequence included) such that exactly i of the integers 0, 1, \ldots, K-1 are not present in the subsequence.
We ultimately care about f(K, 1), where exactly one integer below K is missing.

Let c_x denote the number of occurrences of x in the array.

The transitions to compute f(K, 1) are fairly easy:

  • f(K, 1) means exactly one integer less than K must be missing.
    For this, there are two options. Either we include an occurrence of K-1, in which case some smaller number must be missing, or we include no occurrence of K-1, in which case all smaller numbers must be present. This gives:
    f(K, 1) = f(K-1, 1) \cdot (2^{c_{K-1}}-1) + f(K-1, 0)
    because:
    • If K-1 is included, there are 2^{c_{K-1}} - 1 ways to choose a non-empty set of its occurrences.
      Then, there are f(K-1, 1) ways to choose a subsequence of the smaller elements with exactly one value missing.
    • If K-1 is excluded (which can be done in exactly one way: take none of its occurrences), then every integer from 0 to K-2 must be present.
      By definition, there are f(K-1, 0) ways to do this.
  • The transition also uses f(K-1, 0), so we need to maintain f(K, 0) as well.
    That's not too hard: f(K, 0) = f(K-1, 0) \cdot (2^{c_{K-1}} - 1).
    There are 2^{c_{K-1}} - 1 ways to choose a non-empty subset of the occurrences of K-1. By definition, there are f(K-1, 0) ways to choose the smaller elements so that each of 0, 1, \ldots, K-2 appears at least once.

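Start from f(0, 0) = 1 and f(0, 1) = 0: nothing lies below 0, so the only choice is the empty one, and it misses nothing. From there, the recurrence gives all the f(K, 0) and f(K, 1) values in linear time.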
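A direct transcription of the recurrence, as a sketch, follows. It assumes c is a list of occurrence counts with c[x] = c_x and len(c) >= kmax, and the function name f_tables is hypothetical. It uses exact Python integers rather than modular arithmetic so the values are easy to inspect.

```python
def f_tables(c, kmax):
    """Return (f0, f1) with f0[K] = f(K, 0) and f1[K] = f(K, 1) for K = 0..kmax."""
    f0 = [0] * (kmax + 1)
    f1 = [0] * (kmax + 1)
    f0[0] = 1  # nothing lies below 0: only the empty choice, and nothing is missing
    f1[0] = 0
    for k in range(1, kmax + 1):
        take = 2 ** c[k - 1] - 1  # non-empty subsets of the occurrences of K-1
        f0[k] = f0[k - 1] * take
        # K-1 included (one smaller value missing) or excluded (nothing missing below K-1)
        f1[k] = f1[k - 1] * take + f0[k - 1]
    return f0, f1
```

Take the counts of A = [0, 0, 1, 2], i.e. c = [2, 1, 1, 0, 0]. Then f_tables(c, 4) returns f0 = [1, 3, 3, 3, 0] and f1 = [0, 1, 4, 7, 3]. For example, f(3, 1) = 7 splits into 1 way with 0 missing, 3 ways with 1 missing and 3 ways with 2 missing.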

Now, let’s look at a fixed K.

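  • There are f(K, 1) ways to choose the elements \lt K of the subsequence.
  • Every occurrence of K itself must be left out, which can be done in exactly one way.
  • Once these are fixed, the elements \gt K can be freely included (or not).
    There are N - (c_0 + c_1 + \ldots + c_K) elements that are \gt K, giving 2^{N - c_0 - c_1 - \ldots - c_K} options for them.

To simplify this a bit, define p_i = c_0 + c_1 + \ldots + c_i to be the prefix sum array of c.
Then, the number of ways to choose the elements \gt K is just 2^{N - p_K}. So f(K, 1) \cdot 2^{N - p_K} subsequences, possibly including the empty one, have \text{SMEX} equal to K.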
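The weights 2^{N - p_K} can be tabulated with a running prefix sum. This sketch uses a hypothetical function name, free_choice_weights, and assumes the modulus MOD = 998244353 from the reference codes. It precomputes the powers of 2 once, so each weight becomes an O(1) lookup instead of a fresh fast exponentiation.

```python
MOD = 998244353

def free_choice_weights(a, kmax):
    """weights[K] = 2^(N - p_K) mod MOD for K = 0..kmax, with p_K = c_0 + ... + c_K."""
    n = len(a)
    # 2^0 .. 2^N, built once
    pow2 = [1] * (n + 1)
    for i in range(1, n + 1):
        pow2[i] = pow2[i - 1] * 2 % MOD
    c = [0] * (kmax + 1)
    for x in a:
        if x <= kmax:  # larger values are never in p_K for K <= kmax, so they stay "free"
            c[x] += 1
    weights = []
    p = 0
    for k in range(kmax + 1):
        p += c[k]                    # p is now p_K
        weights.append(pow2[n - p])  # 2^(number of elements > K)
    return weights
```

For A = [0, 3, 5] and K up to 4 it returns [4, 4, 4, 2, 2]. For K \le 2 both 3 and 5 are free, and for K \ge 3 only 5 is.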

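The largest possible \text{SMEX} is N+1, obtained when all integers from 0 to N-1 are present. For any K \ge N+2, at least two of the integers 0, \ldots, K-1 are missing, so f(K, 1) = 0.
One more detail: the count above also includes the empty subsequence, because choosing nothing at all is allowed. Its \text{SMEX} is 1, since both 0 and 1 are missing. Only non-empty subsequences count, so we subtract 1.
So, the final answer is:

\left(\sum_{K=0}^{N+1} f(K, 1) \cdot 2^{N - p_K} \cdot K\right) - 1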

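Computing each power of 2 with fast exponentiation gives \mathcal{O}(N\log N) time. Precomputing 2^0, 2^1, \ldots, 2^N once brings this down to \mathcal{O}(N). Both codes below print the answer modulo 998244353.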
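As a small worked check, take A = [0, 1] (an illustrative input, not one of the problem's samples). Then N = 2, c_0 = c_1 = 1, and c_x = 0 otherwise. The recurrence gives f(1, 1) = 0 \cdot 1 + 1 = 1, f(2, 1) = 1 \cdot 1 + 1 = 2 and f(3, 1) = 2 \cdot 0 + 1 = 1, while p_1 = p_2 = p_3 = 2. The K = 0 term vanishes, and the remaining terms are:

1 \cdot 2^{0} \cdot 1 + 2 \cdot 2^{0} \cdot 2 + 1 \cdot 2^{0} \cdot 3 = 1 + 4 + 3 = 8

This sum also counts the empty subsequence, whose \text{SMEX} is 1. Subtracting it leaves 7, which matches direct enumeration: [0] and [1] have \text{SMEX} 2, and [0, 1] has \text{SMEX} 3.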

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.
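Putting the pieces together, here is a compact Python sketch of the whole computation. It is neither the author's nor the editorialist's code. The name smex_sum is hypothetical, and the sketch assumes all A_i are non-negative integers. It keeps only the two current values f(K, 0) and f(K, 1) and precomputes powers of 2. Like the reference codes, it reduces modulo 998244353. It stores counts only up to N+1: larger values never matter for K \le N+1 except through N - p_K, where they are correctly counted as free elements.

```python
MOD = 998244353

def smex_sum(a):
    """Sum of SMEX over all non-empty subsequences of a, modulo MOD."""
    n = len(a)
    # c[x] = occurrences of x for x = 0..n+1; larger values only ever count as "free"
    c = [0] * (n + 2)
    for x in a:
        if x <= n + 1:
            c[x] += 1
    # 2^0 .. 2^n, so every power of 2 below is an O(1) lookup
    pow2 = [1] * (n + 1)
    for i in range(1, n + 1):
        pow2[i] = pow2[i - 1] * 2 % MOD
    f0, f1 = 1, 0  # f(K, 0) and f(K, 1) for the current K, starting at K = 0
    p = 0          # running prefix sum p_K
    ans = 0
    for k in range(n + 2):
        p += c[k]
        # f(K, 1) choices below K, K itself excluded, 2^(N - p_K) choices above K
        ans = (ans + k * f1 % MOD * pow2[n - p]) % MOD
        take = pow2[c[k]] - 1  # non-empty subsets of the occurrences of K
        f0, f1 = f0 * take % MOD, (f1 * take + f0) % MOD
    return (ans - 1) % MOD  # drop the empty subsequence, whose SMEX is 1
```

For example, smex_sum([0, 1]) returns 7 and smex_sum([0, 0]) returns 6, since each of the three non-empty subsequences of [0, 0] has SMEX 2.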

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 998244353;

struct mint{
    int x;

    mint (){ x = 0;}
    mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
    mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}

    int val(){
        return x;
    }
    mint &operator++(){
        x++;
        if (x == mod) x = 0;
        return *this;
    }
    mint &operator--(){
        if (x == 0) x = mod;
        x--;
        return *this;
    }
    mint operator++(int32_t){
        mint result = *this;
        ++*this;
        return result;
    }
    
    mint operator--(int32_t){
        mint result = *this;
        --*this;
        return result;
    }
    mint& operator+=(const mint &b){
        x += b.x;
        if (x >= mod) x -= mod;
        return *this;
    }
    mint& operator-=(const mint &b){
        x -= b.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint& operator*=(const mint &b){
        long long z = x;
        z *= b.x;
        z %= mod;
        x = (int)z;
        return *this;
    }
    mint operator+() const {
        return *this;
    }
    mint operator-() const {
        return mint() - *this;
    }
    mint operator/=(const mint &b){
        return *this = *this * b.inv();
    }
    mint power(long long n) const {
        mint ok = *this, r = 1;
        while (n){
            if (n & 1){
                r *= ok;
            }
            ok *= ok;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        return power(mod - 2);
    }
    friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
    friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
    friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
    friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
    friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
    friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
    mint power(mint a, long long n){
        return a.power(n);
    }
    friend ostream &operator<<(ostream &os, const mint &m) {
        os << m.x;
        return os;
    }
    explicit operator bool() const {
        return x != 0;
    }
};

// Remember to check MOD

void Solve() 
{
    int n; cin >> n;
    
    vector <int> f(n + 2);
    for (int i = 1; i <= n; i++){
        int x; cin >> x;
        f[x]++;
    }
    
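    // dp[j] (j = 0, 1): ways to choose elements among the values processed so far with exactly j of those values missing
    vector <mint> dp(3, 0);
    dp[0] = 1;
    mint ans = 0;
    int left = n; // number of elements greater than the current value i
    
    for (int i = 0; i <= n + 1; i++){
        mint w1 = mint(2).power(f[i]) - 1; // include i: non-empty subset of its copies
        mint w0 = 1;                       // exclude i: take none of its copies
        vector <mint> ndp(3, 0);
        left -= f[i];
        
        for (int j = 0; j <= 1; j++){
            ndp[j] += dp[j] * w1;
            ndp[j + 1] += dp[j] * w0;
        }
        
        swap(dp, ndp);
        
        // dp[2] now counts choices where i is the second missing value, i.e. SMEX = i
        ans += dp[2] * i * mint(2).power(left);
    }
    
    ans -= 1; // remove the empty subsequence, whose SMEX is 1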
    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (PyPy3)
mod = 998244353
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    f = [0]*(n+5)
    for x in a: f[x] += 1
    
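    ans = -1  # start at -1 to discount the empty subsequence, whose SMEX is 1
    # prod: product of (2^f[j] - 1) over the values j < i that occur
    # sm: sum of the modular inverses of those factors
    # flag: 1 once some value below i has no occurrences (it is then the forced missing value)
    prod, sm, flag = 1, 0, 0
    taken = 0  # number of elements with value <= i (updated at the top of the loop)
    for i in range(n+5):
        taken += f[i]
        if flag == 1:
            # the missing value is forced and every other value below i is present
            ans += prod * pow(2, n - taken, mod) * i % mod
        else:
            # every value below i occurs, and any one of them can be the missing one
            ans += prod * sm * pow(2, n - taken, mod) * i % mod
        
        if f[i] == 0:
            if flag == 1:
                break  # two values are now always missing, so no larger SMEX is possible
            flag = 1
        else:
            ways = pow(2, f[i], mod) - 1
            prod = prod * ways % mod
            sm += pow(ways, mod-2, mod)
    print(ans % mod)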
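The editorialist's code skips the DP and evaluates f(K, 1) in closed form. As far as can be read from the code, prod is the product of the factors 2^{c_j} - 1 over the values j \lt K that occur, and sm is the sum of their modular inverses. When every value 0, \ldots, K-1 occurs, summing over which value M is the missing one gives

f(K, 1) = \sum_{M=0}^{K-1} \prod_{j \lt K, j \ne M} (2^{c_j} - 1) = \left(\prod_{j=0}^{K-1} (2^{c_j} - 1)\right) \cdot \sum_{M=0}^{K-1} \frac{1}{2^{c_M} - 1}

If instead exactly one value M_0 \lt K never occurs, it is forced to be the missing one. In that case f(K, 1) = \prod_{j \lt K, j \ne M_0} (2^{c_j} - 1), which is prod alone. Once a second value never occurs, f(K, 1) = 0 for every larger K, which is why the loop stops there. The division is done modulo 998244353 via Fermat inverses, pow(ways, mod-2, mod).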