ONEPILE - EDITORIAL

PROBLEM LINK:

Contest Div 1
Practice

Setter: Utkarsh Gupta
Tester: Anshu Garg
Editorialist: Keyur Jain

DIFFICULTY

Hard

PREREQUISITES

Dynamic Programming, Cumulative Sum, Game Theory

PROBLEM

Utkarsh is forced to play another game with Ashish.

In this game there are N piles; the i-th pile contains A_i stones. Utkarsh moves first.

In his turn a player can choose any pile with non-zero stones and remove exactly one stone from that pile. The game ends when there is exactly one pile containing non-zero stones and the player who made the last move wins the game.

Now you are given an array B of length N. Utkarsh wants to know the number of arrays A with 1 \leq A_i \leq B_i for every i for which he wins the game, assuming that both players play optimally.

Since the answer can be large, output it modulo 998244353.

EXPLANATION

Let us solve a simpler version of the problem first. Let the array A[N] denote the sizes of the piles and let A_i \geq 2 for every i. We need to decide whether Utkarsh (who always moves first) can win this single instance of the game.
Simulating a few small instances of the game suggests that Utkarsh wins if and only if the sum A_1 + A_2 + \dots + A_N is odd. More formally:

  • If every A_i \geq 2, the player with odd sum (the one who sees an odd total on his turn) always wins
  • A player’s sum parity never changes, since each move removes exactly one stone and the players alternate.
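
The parity claim can be checked on tiny inputs with an exhaustive game search. The sketch below is only an illustration and is not part of the reference solutions; the name moverWins and the memo layout are assumptions made here. It assumes at least two piles are non-empty, and it is exponential, so it is usable only for very small piles.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

// Hypothetical brute-force helper: returns true if the player to move wins
// from `piles` under optimal play. Precondition: >= 2 non-empty piles.
bool moverWins(std::vector<int> piles, std::map<std::vector<int>, bool>& memo) {
    // Canonical form: drop empty piles and sort, so equal states share a key.
    piles.erase(std::remove(piles.begin(), piles.end(), 0), piles.end());
    std::sort(piles.begin(), piles.end());
    assert(piles.size() >= 2);

    auto it = memo.find(piles);
    if (it != memo.end()) return it->second;

    bool win = false;
    for (std::size_t i = 0; i < piles.size() && !win; ++i) {
        std::vector<int> next = piles;
        --next[i];  // remove exactly one stone from pile i
        int nonEmpty = 0;
        for (int x : next) nonEmpty += (x > 0);
        if (nonEmpty == 1) {
            win = true;  // this move ends the game, so the mover wins
        } else if (!moverWins(next, memo)) {
            win = true;  // the opponent is left in a losing state
        }
    }
    return memo[piles] = win;
}
```

For example, with every pile of size at least 2, the search reports (2, 3) and (2, 2, 3) as wins for the first player and (2, 2) and (2, 2, 2) as losses, matching the odd/even sum rule.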

Now let us consider the case where A_i can be 1 as well, i.e. A_i \geq 1. The player with odd sum parity wants to ensure that no pile has A_i = 1, since a state with every A_i \geq 2 is winning for him. Conversely, the player with even sum parity wants to create as many piles of size 1 as he can. Note that neither player will touch the biggest pile, because such a move only takes him further from his goal of maximizing/minimizing the number of piles of size 1.

You can run some more simulations of various cases and arrive at the following observations :

  • (X, 1) is a winning state, since the player to move can remove the single stone and win the game
  • Any state of the form (X, 1, 1, \dots, 1) (at least two 1s) is a winning state for the player with even sum

The states of the form (X, 1, 1, \dots, 1) are special states, since one player chases them and the other avoids them. Let’s denote any special state by S.

Coming back to the problem, for Utkarsh to win the game :

  • If the sum of array is even, he needs to be able to reach state S
  • If the sum of array is odd, he needs to be able to avoid state S

When we analyse the requirements for the state S to be reachable/unreachable, we come to the following conclusion. Formally, for Utkarsh to win :

  • If the sum is odd, SUM(A[N]) - MAX(A[N]) > 2 * N - 4 must hold
  • If the sum is even, SUM(A[N]) - MAX(A[N]) \leq 2 * N - 3 must hold
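
The two conditions translate directly into a constant-size check per array. The following is a minimal sketch of that check, in the same spirit as the check function in the setter's code; the name firstPlayerWins is an assumption introduced here, and the input is assumed to have N \geq 2 piles, each with at least one stone.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical helper: evaluates the closed-form winning condition for the
// first player. Assumes a.size() >= 2 and every a[i] >= 1.
bool firstPlayerWins(const std::vector<int>& a) {
    const long long n = static_cast<long long>(a.size());
    assert(n >= 2);
    const long long sum = std::accumulate(a.begin(), a.end(), 0LL);
    const long long mx = *std::max_element(a.begin(), a.end());
    const long long rest = sum - mx;  // SUM(A[N]) - MAX(A[N])
    if (sum % 2 == 1) return rest > 2 * n - 4;  // odd sum: must avoid state S
    return rest <= 2 * n - 3;                   // even sum: must reach state S
}
```

As a small check, (1, 1, 2) has even sum 4 and rest 2 \leq 3, so the first player wins, while (1, 1, 1) has odd sum 3 and rest 2, which is not greater than 2, so the first player loses.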

The first condition is counted by complement: take the number of all arrays with odd sum and subtract the number of odd-sum arrays where SUM(A[N]) - MAX(A[N]) \leq 2 * N - 4
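
The "all arrays with odd sum" term needs no knapsack: it is enough to track how many prefixes have even and odd sum. Both reference solutions do this with an even/odd pair; the version below is a standalone sketch of the same idea, where the names countOddSumArrays and kMod are assumptions introduced here.

```cpp
#include <cassert>
#include <vector>

constexpr long long kMod = 998244353;

// Hypothetical helper: number of arrays A with 1 <= A_i <= b[i] whose sum is
// odd, modulo 998244353.
long long countOddSumArrays(const std::vector<long long>& b) {
    long long even = 1, odd = 0;  // the empty prefix has sum 0, which is even
    for (long long bi : b) {
        assert(bi >= 1);
        const long long o = ((bi + 1) / 2) % kMod;  // odd values in [1, bi]
        const long long e = (bi / 2) % kMod;        // even values in [1, bi]
        const long long newEven = (even * e + odd * o) % kMod;
        const long long newOdd = (even * o + odd * e) % kMod;
        even = newEven;
        odd = newOdd;
    }
    return odd;
}
```

For B = (2, 3) there are six arrays in total, and exactly three of them, (1, 2), (2, 1) and (2, 3), have an odd sum.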


Let us solve the problem in two parts :

Case 1 : MAX(A[N]) \leq N

We can have a simple dp inspired by knapsack.

DP(i, j, k) => the number of arrays such that

  • i => first i elements are chosen
  • j => max element so far \leq j
  • k => sum of the elements so far = k

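We iterate over the index i, the cap j on the maximum, and the sum k. The sum only needs to be tracked up to 3N: here MAX(A[N]) \leq N, and we only count arrays with SUM(A[N]) - MAX(A[N]) \leq 2 * N - 3, so SUM(A[N]) \leq 3 * N - 3. Using cumulative sums over k, each transition takes O(1), and the number of arrays whose maximum is exactly j is DP(N, j, k) - DP(N, j - 1, k). This gives the solution in O(N^3).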
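
One layer of this DP, for a single fixed cap j, can be sketched as follows. This is an illustration under assumptions rather than the tester's exact code: the name countBySumWithCap and its parameters are introduced here, and the table is rolled over i so that only the sum dimension is stored.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: ways[k] = number of arrays with
// 1 <= A_i <= min(b[i], cap) and total sum k, for 0 <= k <= maxSum,
// modulo 998244353.
std::vector<long long> countBySumWithCap(const std::vector<long long>& b,
                                         long long cap, int maxSum) {
    const long long mod = 998244353;
    assert(cap >= 1 && maxSum >= 0);
    std::vector<long long> ways(maxSum + 1, 0);
    ways[0] = 1;  // empty array, sum 0
    for (long long bi : b) {
        const long long hi = std::min(bi, cap);  // largest allowed value
        // prefix[k] = ways[0] + ... + ways[k - 1]
        std::vector<long long> prefix(maxSum + 2, 0);
        for (int k = 0; k <= maxSum; ++k)
            prefix[k + 1] = (prefix[k] + ways[k]) % mod;
        std::vector<long long> next(maxSum + 1, 0);
        for (int k = 1; k <= maxSum; ++k) {
            // New sum k comes from old sums k - v with v in [1, hi],
            // i.e. old sums in the range [max(0, k - hi), k - 1].
            const long long lo = std::max<long long>(0, k - hi);
            next[k] = (prefix[k] - prefix[lo] + mod) % mod;
        }
        ways = next;
    }
    return ways;
}
```

Calling it with caps j and j - 1 and subtracting the results entry by entry gives the count of arrays whose maximum is exactly j. For B = (2, 3), cap 3 minus cap 2 leaves one array each for sums 4 and 5, namely (1, 3) and (2, 3).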

Case 2 : MAX(A[N]) > N

When elements can be > N, at most one such element can exist if either of our conditions SUM(A[N]) - MAX(A[N]) \leq 2 * N - 4 or SUM(A[N]) - MAX(A[N]) \leq 2 * N - 3 is to hold. All the other elements will necessarily be \leq N.

We can iterate over every index i with B_i > N, assume that A_i > N is the maximum, and count the ways to choose the remaining elements in a similar manner as above (the solutions below build prefix and suffix tables and merge them at i).
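
The claim that at most one element can exceed N follows from a short bound. As a sketch, assume p and q are two distinct indices with A_p > N and A_q > N (these index names are introduced only for illustration). Removing one maximum still leaves at least the smaller of the two, plus N - 2 other elements that are each at least 1:

```latex
\begin{aligned}
\mathrm{SUM}(A) - \mathrm{MAX}(A)
  &\geq \min(A_p, A_q) + \sum_{i \neq p, q} A_i \\
  &\geq (N + 1) + (N - 2) \cdot 1 \\
  &= 2N - 1 > 2N - 3 .
\end{aligned}
```

So neither bound can hold for such an array, and it never contributes to the two correction terms.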

TIME COMPLEXITY

The time complexity is O(N^3)

SOLUTIONS

Setter's Solution
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <bits/stdc++.h>
#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")

using namespace __gnu_pbds;
using namespace std;

using ll = long long;
using ld = long double;

typedef tree<
        int,
        null_type,
        less<int>,
        rb_tree_tag,
        tree_order_statistics_node_update>
        ordered_set;

#define mp make_pair

int MOD =  998244353;

int mul(int a, int b) {
    return (1LL * a * b) % MOD;
}

int add(int a, int b) {
    int s = (a+b);
    if (s>=MOD) s-=MOD;
    return s;
}

int sub(int a, int b) {
    int s = (a+MOD-b);
    if (s>=MOD) s-=MOD;
    return s;
}

int po(int a, ll deg)
{
    if (deg==0) return 1;
    if (deg%2==1) return mul(a, po(a, deg-1));
    int t = po(a, deg/2);
    return mul(t, t);
}

int inv(int n)
{
    return po(n, MOD-2);
}


mt19937 rnd(time(0));


const int LIM = 400005;

vector<int> facs(LIM), invfacs(LIM);

void init()
{
    facs[0] = 1;
    for (int i = 1; i<LIM; i++) facs[i] = mul(facs[i-1], i);
    invfacs[LIM-1] = inv(facs[LIM-1]);
    for (int i = LIM-2; i>=0; i--) invfacs[i] = mul(invfacs[i+1], i+1);
}

int C(int n, int k)
{
    if (n<k) return 0;
    if (n<0 || k<0) return 0;
    return mul(facs[n], mul(invfacs[k], invfacs[n-k]));
}

/*
struct DSU
{
    vector<int> sz;
    vector<int> parent;
    void make_set(int v) {
        parent[v] = v;
        sz[v] = 1;
    }

    int find_set(int v) {
        if (v == parent[v])
            return v;
        return find_set(parent[v]);
    }

    void union_sets(int a, int b) {
        a = find_set(a);
        b = find_set(b);

        if (a != b) {
            if (sz[a] < sz[b])
                swap(a, b);
            parent[b] = a;
            sz[a] += sz[b];
        }
    }

    DSU (int n)
    {
        parent.resize(n);
        sz.resize(n);
        for (int i = 0; i<n; i++) make_set(i);
    }
};*/


bool check(vector<int> b)
{
    vector<int> c(b.begin(), b.end());
    int n = b.size();
    sort(b.begin(), b.end());

    int sum = 0; for (auto it: b) sum+=it;

    if (sum%2==1)
    {
        return (sum-b[n-1]>=2*(n-1)-1);
    }
    else
    {
        return (sum-b[n-1]<=2*(n-1)-1);
    }
}

int ans;

vector<int> searcher;

void dfs(vector<int> &a, int iter)
{
    if (iter==a.size())
    {
        if (check(searcher)) ans++;
    }
    else
    {
        for (int i = 1; i<=a[iter]; i++)
        {
            searcher[iter] = i;
            dfs(a, iter+1);
        }
    }
}

int solve1(vector<int> a)
{
    int n = a.size();
    ans = 0;
    searcher.clear(); searcher.resize(n);
    dfs(a, 0);
    return ans;
}

int solve2(vector<int> a)
{
    int n = a.size();
    int winning = 0;

    int even = 1; int odd = 0;
    for (auto it: a)
    {
        int even1 = 0; int odd1 = 0;
        even1 = add(even1, mul(even, it/2));
        odd1 = add(odd1, mul(odd, it/2));
        even1 = add(even1, mul(odd, (it - it/2)));
        odd1 = add(odd1, mul(even, (it - it/2)));

        even = even1; odd = odd1;
    }

    winning = odd;


    //Let's suppose that we take from 0 to A[i]-1 now

    vector<vector<int>> pref(n, vector<int>(n-1)), suf(n, vector<int>(n-1));

    pref[0][0] = 1;
    for (int i = 0; i<n-1; i++)
    {
        vector<int> part_sum(n);
        for (int j = 0; j<n-1; j++) part_sum[j+1] = add(part_sum[j], pref[i][j]);

        for (int sum = 0; sum<n-1; sum++)
        {
            int lim = min(a[i]-1, n-2); lim = min(lim, sum);
            //[sum-lim, sum];
            pref[i+1][sum] = sub(part_sum[sum+1], part_sum[sum-lim]);
        }
    }

    suf[n-1][0] = 1;
    for (int i = n-1; i>0; i--)
    {
        vector<int> part_sum(n);
        for (int j = 0; j<n-1; j++) part_sum[j+1] = add(part_sum[j], suf[i][j]);

        for (int sum = 0; sum<n-1; sum++)
        {
            int lim = min(a[i]-1, n-2); lim = min(lim, sum);
            //[sum-lim, sum];
            suf[i-1][sum] = sub(part_sum[sum+1], part_sum[sum-lim]);
        }
    }

    for (int i = 0; i<n; i++) if (a[i]>=n)
        {
            //Combining pref[i] and suf[i]
            vector<int> can(n-1);
            for (int l = 0; l<n-1; l++)
                for (int r = 0; r<n-1; r++) if (l+r<n-1) can[l+r] = add(can[l+r], mul(pref[i][l], suf[i][r]));

            int ev = a[i]/2 - (n-1)/2;
            int od = (a[i]-(n-1))-ev;

            for (int sum = 0; sum<n-1; sum++)
            {
                int real_sum = sum+(n-1);

                if (real_sum%2==0)
                {
                    //od:
                    if (real_sum<2*(n-1)-1) winning = sub(winning, mul(od, can[sum]));
                    //ev:
                    if (real_sum<=2*(n-1)-1) winning = add(winning, mul(ev, can[sum]));
                }
                else
                {
                    //ev:
                    if (real_sum<2*(n-1)-1) winning = sub(winning, mul(ev, can[sum]));
                    //od:
                    if (real_sum<=2*(n-1)-1) winning = add(winning, mul(od, can[sum]));
                }
            }
        }
    int S = 2*n;

    vector<vector<int>> tot_sum(n-1, vector<int>(S+1));
    for (int M = 0; M<n-1; M++)
    {
        vector<int> cur(S+1);
        cur[0] = 1;
        for (auto it: a)
        {
            int x = min(it-1, M);
            vector<int> part_sum(S+2);
            for (int i = 0; i<S+1; i++) part_sum[i+1] = add(part_sum[i], cur[i]);

            //range: [sum, sum-lim]

            for (int sum = 0; sum<=S; sum++)
            {
                int lim = min(sum, x);
                cur[sum] = sub(part_sum[sum+1], part_sum[sum-lim]);
            }
        }
        //max sum = M + (n-1)

        tot_sum[M] = cur;
    }

    /*for (int M = 0; M<n-1; M++)
    {
        cout<<M<<": "; for (auto it: tot_sum[M]) cout<<it<<' ';
        cout<<endl;
    }*/


    for (int M = 0; M<(n-1); M++)
    {
        for (int sum = 0; sum<=S; sum++)
        {
            int true_sum = sum+n;
            int ways = tot_sum[M][sum];
            if (M>0) ways = sub(ways, tot_sum[M-1][sum]);

            int part_sum = true_sum - (M+1);

            if (true_sum%2==0)
            {
                if (part_sum<=2*(n-1)-1) winning = add(winning, ways);
            }
            else
            {
                if (part_sum<2*(n-1)-1) winning = sub(winning, ways);
            }
        }
    }
    return winning;
}

void solve()
{
    int n; cin>>n;
    vector<int> a(n); for (int i = 0; i<n; i++) cin>>a[i];
    cout<<solve2(a)<<endl;
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(nullptr);

    int t; cin>>t;
    while (t--) solve();


}
Tester's Solution
#include<bits/stdc++.h>
using namespace std ;

#define ll              long long 
#define pb              push_back
#define all(v)          v.begin(),v.end()
#define sz(a)           (ll)a.size()
#define F               first
#define S               second
#define INF             2000000000000000000
#define popcount(x)     __builtin_popcountll(x)
#define pll             pair<ll,ll>
#define pii             pair<int,int>
#define ld              long double

const int M = 1000000007;
const int MM = 998244353;

template<typename T, typename U> static inline void amin(T &x, U y){ if(y<x) x=y; }
template<typename T, typename U> static inline void amax(T &x, U y){ if(x<y) x=y; }

#ifdef LOCAL
#define debug(...) debug_out(#__VA_ARGS__, __VA_ARGS__)
#else
#define debug(...) 2351
#endif
long long readInt(long long l,long long r,char end){
    long long x = 0;
    int cnt = 0;
    int first =-1;
    bool is_neg = false;
    while(true) {
        char g = getchar();
        if(g == '-') {
            assert(first == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if(cnt == 0) {
                first = g - '0';
            }
            ++cnt;
            assert(first != 0 || cnt == 1);
            assert(first != 0 || is_neg == false);
            
            assert(!(cnt > 19 || (cnt == 19 && first > 1)));
        } 
        else if(g == end) {
            if(is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        } 
        else {
            assert(false);
        }
    }
}
string readString(int l,int r,char end){
    string ret = "";
    int cnt = 0;
    while(true) {
        char g = getchar();
        assert(g != -1);
        if(g == end) {
            break;
        }
        ++cnt;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
    
const int MOD=MM;
struct Mint {
    int val;
 
    Mint(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
        if (v >= MOD)
            v %= MOD;
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        } 
        return x < 0 ? x + m : x;
    } 
    explicit operator int() const {
        return val;
    }
    Mint& operator+=(const Mint &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    Mint& operator-=(const Mint &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
    Mint& operator*=(const Mint &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
    Mint& operator/=(const Mint &other) {
        return *this *= other.inv();
    }
    friend Mint operator+(const Mint &a, const Mint &b) { return Mint(a) += b; }
    friend Mint operator-(const Mint &a, const Mint &b) { return Mint(a) -= b; }
    friend Mint operator*(const Mint &a, const Mint &b) { return Mint(a) *= b; }
    friend Mint operator/(const Mint &a, const Mint &b) { return Mint(a) /= b; }
    Mint& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
    Mint& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
    Mint operator++(int32_t) { Mint before = *this; ++*this; return before; }
    Mint operator--(int32_t) { Mint before = *this; --*this; return before; }
    Mint operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
    bool operator==(const Mint &other) const { return val == other.val; }
    bool operator!=(const Mint &other) const { return val != other.val; }
    Mint inv() const {
        return mod_inv(val);
    }
    Mint power(long long p) const {
        assert(p >= 0);
        Mint a = *this, result = 1;
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
        return result;
    }
    friend ostream& operator << (ostream &stream, const Mint &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, Mint &m) {
        return stream>>m.val;   
    }
};


const int Sz = 503;
int B[Sz], N, sumN = 0;

Mint ans = 0;
Mint dp[2][Sz][3*Sz], sum[2][Sz][3*Sz];

void smol() {
    for(int j=1;j<=N;++j) {
        dp[0][j][0] = 1;
        for(int k=0;k<=3*N;++k) {
            sum[0][j][k] = 1;
        }
    }
    for(int i=1;i<=N;++i) {
        int cur = i & 1, prev = (i - 1) & 1;
        for(int j=1;j<=N;++j) {
            int f = B[i];
            amin(f, j);
            sum[cur][j][0] = 0;
            for(int k=1;k<=3*N;++k) {
                dp[cur][j][k] = sum[prev][j][k-1] - (k - f - 1 >= 0 ? sum[prev][j][k-f-1]: 0);
                sum[cur][j][k] = sum[cur][j][k-1] + dp[cur][j][k];
            }
        }
    }
    Mint odd = 0, even = 1;
    for(int i=1;i<=N;++i) {
        int o = (B[i] + 1)/2;
        int e = B[i] / 2;
        Mint tmp = even;
        even = even * e + odd * o;
        odd = tmp * o + odd * e;
    }
    ans += odd;
    for(int i=1;i<=N;++i) {
        for(int j=1;j<=3*N;++j) {
            Mint f = dp[N & 1][i][j] - dp[N & 1][i-1][j];
            if(j % 2 == 1 && j - i <= 2 * N - 4) {
                ans -= f;
            }
            else if(j % 2 == 0 && j - i <= 2 * N - 3) {
                ans += f;
            }
        }   
    }
}

Mint pref[Sz][2*Sz], suf[Sz][2*Sz], psum[2][2*Sz];

void gr8() {
    pref[0][0] = 1;
    for(int i=0;i<=2*N;++i) {
        psum[0][i] = 1;
    }
    for(int i=1;i<=N;++i) {
        int f = B[i];
        psum[i & 1][0] = 0;
        int prev = (i - 1) & 1;
        for(int j=1;j<=2*N;++j) {
            pref[i][j] = psum[prev][j - 1] - (j - f - 1 >= 0 ? psum[prev][j - f - 1]: 0);
            psum[i & 1][j] = psum[i & 1][j-1] + pref[i][j];
        }
    }
    suf[N+1][0] = 1;
    for(int i=0;i<=2*N;++i) {
        psum[(N + 1) & 1][i] = 1;
    }
    for(int i=N;i>=1;--i) {
        int f = B[i];
        psum[i & 1][0] = 0;
        int prev = (i + 1) & 1;
        for(int j=1;j<=2*N;++j) {
            suf[i][j] = psum[prev][j-1] - (j - f - 1 >= 0 ? psum[prev][j - f - 1]: 0);
            psum[i & 1][j] = psum[i & 1][j-1] + suf[i][j];
        }
    }
    // merge
    for(int i=1;i<=N;++i) {
        if(B[i] <= N) {
            continue;
        }
        psum[0][0] = pref[i-1][0];
        psum[1][0] = 0;
        for(int j=1;j<=2*N;++j) {
            if(j & 1) {
                psum[1][j] = psum[1][j-1] + pref[i-1][j];
                psum[0][j] = psum[0][j-1];
            }
            else {
                psum[0][j] = psum[0][j-1] + pref[i-1][j];
                psum[1][j] = psum[1][j-1];
            }
        }
        int even = B[i] / 2 - N / 2;
        int odd = (B[i] + 1)/2 - (N + 1)/2;
        int cnt[2] = {even, odd};

        for(int j=0;j<=2*N;++j) {
            for(int k=0;k<2;++k) {
                for(int l=0;l<2;++l) {
                    int p = (j + k + l) % 2;
                    if(p == 1) {
                        int rem = 2 * N - 4 - j;
                        if(rem >= 0) {
                            Mint f = suf[i+1][j] * psum[k][rem] * cnt[l];
                            ans -= f;
                        }
                    }
                    else {
                        int rem = 2 * N - 3 - j;
                        if(rem >= 0) {
                            ans += suf[i+1][j] * psum[k][rem] * cnt[l];
                        }
                    }
                }
            }
        }
    }
}

int _runtimeTerror_()
{
    N = readIntLn(2,500);
    for(int i=1;i<=N;++i) {
        if(i == N) {
            B[i] = readIntLn(1,1e9);
        }
        else {
            B[i] = readIntSp(1,1e9);
        }
    }
    ans = 0;
    sumN += N;
    smol();
    gr8();
    cout << ans << "\n";
    return 0;
}

int main()
{
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    #ifdef runSieve
        sieve();
    #endif
    #ifdef NCR
        initialize();
    #endif
    int TESTS = 1;
    TESTS = readIntLn(1,100);
    //cin >> TESTS;
    while(TESTS--)
        _runtimeTerror_();
    assert(sumN <= 500);
    // assert(getchar() == -1);
    return 0;
}

Feel free to share your approach. Suggestions are welcomed as always.