DELARR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics

PROBLEM:

For an array A, let f(A) = 1 if it can be fully deleted using the following operation (and f(A) = 0 otherwise):

  • Choose two adjacent unequal elements of A and delete them.
    The remaining elements are concatenated together.

Let g(A) denote the sum of f(A) across all non-empty subarrays of A.
Given N and M, compute \sum g(A) across all arrays of length N containing integers from 1 to M.

EXPLANATION:

Our first order of business should be to work out what f(A) is for a single array A; after all, everything else in the task hinges on it.

So, when is it possible to delete all the elements of an array with the given process?
Obviously, if the array has odd length it’s impossible; so we only look at even-length arrays.
Intuitively, if the same element appears too many times then this won’t be possible; and otherwise it’s likely possible.
In particular, notice that each move deletes at most one copy of any given value, and we need to make exactly \frac{N}{2} moves in total for an array of length N. This tells us the following:
Claim: Let A be an array of even length N. f(A) = 1 if and only if every element of A appears at most \frac{N}{2} times.

Proof

If an element appears \gt \frac{N}{2} times, as noted above at least one of its occurrences will remain after \frac{N}{2} moves - meaning the array cannot be empty.

Now, suppose every element appears \leq \frac{N}{2} times.
Let x be an element with maximum frequency (i.e., a mode of A), and let F be this maximum frequency.
Then,

  • If F = \frac{N}{2} and there are two different elements with frequency F, the array consists of only these two values, so any adjacent unequal pair contains one copy of each; delete such a pair.
  • If F = \frac{N}{2} and the mode is unique, delete one copy of x together with an adjacent element different from x (such a pair exists because the array also contains values other than x).
  • If F \lt \frac{N}{2}, delete any two adjacent different elements.

Note that these cases are exhaustive since F \leq\frac{N}{2}; and after each move, we’ve reduced the length of the array by 2 while also ensuring that in the new array, no element has frequency \gt \frac{N-2}{2}.
Repeatedly performing this process will empty the array, as desired.
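
The claim is easy to sanity-check by brute force on small arrays. The sketch below is only an illustration and not part of any of the solutions; the names deletable, claim and check are made up here. It tries every possible sequence of deletions and compares the outcome with the frequency condition.

```python
from collections import Counter
from functools import lru_cache
from itertools import product


@lru_cache(maxsize=None)
def deletable(arr):
    """Brute force on a tuple: try every adjacent unequal pair recursively."""
    if not arr:
        return True
    return any(
        arr[i] != arr[i + 1] and deletable(arr[:i] + arr[i + 2:])
        for i in range(len(arr) - 1)
    )


def claim(arr):
    """Frequency condition: even length, and no value occurs more than half the time."""
    n = len(arr)
    return n % 2 == 0 and all(c <= n // 2 for c in Counter(arr).values())


def check(max_len=6, max_val=3):
    """Compare both predicates on every array up to the given (small) limits."""
    for n in range(1, max_len + 1):
        for arr in product(range(1, max_val + 1), repeat=n):
            if deletable(arr) != claim(arr):
                return False
    return True
```

Calling check() exhausts all arrays of length at most 6 over three values, which is small enough to run instantly.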


Now that we know about f(A), let’s turn to g(A).
g(A) is computed as the sum of f(A) across all subarrays of A; but as noted above it’s enough to only consider even-length subarrays.
We want the sum of g(A) across all possible arrays A of length N with elements from 1 to M.

Observe that the subarrays are pretty independent here, allowing us to use the technique of contribution.
That is, instead of fixing A and trying to compute g(A); we’ll fix a subarray S of A such that f(S) = 1, and then compute the number of arrays A for which S is a subarray - summing this up across all S will give us the same answer.

So, let’s fix the endpoints L and R of the subarray, and try to count the number of arrays A in which A[L\ldots R] adds 1 to the total.
Notice that:

  • The exact values of L and R don’t matter much: only the length (R-L+1).
  • Everything outside the subarray can be freely chosen; and each element has M independent options.

In particular, suppose we knew for each length k the value c_M(k): the number of arrays of length k with elements from 1 to M, such that they can be fully deleted.
Then, the overall answer is just

\sum_{k=1}^{\lfloor N/2 \rfloor} c_M(2k)\cdot (N-2k+1)\cdot M^{(N-2k)}

This is because we can choose any deletable array of length 2k, choose its starting position, and choose the values of all the other elements.
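
To see the contribution argument in action, here is a small sketch that compares the definition of the answer with the sum over lengths. It assumes the frequency condition from the claim to decide whether an array is deletable, counts deletable arrays by plain enumeration, and uses hypothetical function names; it is only meant for tiny N and M.

```python
from collections import Counter
from itertools import product


def is_deletable(arr):
    """Frequency condition from the claim (assumed here rather than re-proved)."""
    n = len(arr)
    return n % 2 == 0 and all(c <= n // 2 for c in Counter(arr).values())


def brute_total(n, m):
    """Sum of g(A) over all m^n arrays, straight from the definition."""
    total = 0
    for arr in product(range(m), repeat=n):
        for l in range(n):
            for r in range(l + 1, n + 1):
                total += is_deletable(arr[l:r])
    return total


def by_contribution(n, m):
    """Per even length: (deletable arrays) * (start positions) * (free outside values)."""
    total = 0
    for length in range(2, n + 1, 2):
        good = sum(is_deletable(a) for a in product(range(m), repeat=length))
        total += good * (n - length + 1) * m ** (n - length)
    return total
```

Both functions should agree for every small (n, m); the second one is the formula above with the count of deletable arrays obtained by enumeration.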

All that remains is to compute the c_M(k) values, and we’ll be done!


Computing c_M(k)

Recall that we have a condition for a deletable array of length 2k: every element must have frequency \leq k.
This is a bit unwieldy to deal with, so let’s do the opposite: we’ll count arrays of length 2k that can’t be deleted entirely, then subtract them from the total number of arrays (which is just M^{2k}).
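Let’s redefine c_M(k) to be this value, i.e., the number of arrays of length k that can’t be deleted entirely. The number of deletable arrays of length 2k, which is what the answer formula needs, is then M^{2k} - c_M(2k).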

For an array that can’t be deleted entirely, some element must have frequency \gt k.
Of course, at most one element can have frequency \gt k in an array of length 2k, so this element is uniquely determined and nothing gets counted twice.
So,

  • Fix the element that has frequency \gt k (M choices here).
  • Fix this frequency that’s \gt k (anything between k+1 and 2k, say x).
  • Fix the x positions that this element appears in (\binom{2k}{x} choices).
  • Finally, fix the values at the other positions.
    With M-1 choices per position, that’s (M-1)^{2k - x} choices in total.

All together, that tells us that

c_M(2k) = \sum_{x = k+1}^{2k} M\cdot \binom{2k}{x} \cdot (M-1)^{2k-x}
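
As a quick check of this count, the following sketch evaluates the sum directly and compares it with enumeration over all arrays. The function names are illustrative only, and exact integers are used instead of the modulus since the inputs are tiny.

```python
from collections import Counter
from itertools import product
from math import comb


def bad_formula(k, m):
    """Arrays of length 2k over m values where some value appears more than k times."""
    return sum(
        m * comb(2 * k, x) * (m - 1) ** (2 * k - x)
        for x in range(k + 1, 2 * k + 1)
    )


def bad_brute(k, m):
    """Same count by enumerating all m^(2k) arrays."""
    return sum(
        max(Counter(a).values()) > k
        for a in product(range(m), repeat=2 * k)
    )
```

For example, with k = 1 and M = 2 both give 2, corresponding to the two constant arrays of length 2.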

This is correct, but of course computing this directly for all 2k will take \mathcal{O}(N^2) time.

Optimizing the computation

Let’s look at the expression we have.

c_M(2k) = \sum_{x = k+1}^{2k} M\cdot \binom{2k}{x} \cdot (M-1)^{2k-x}

M is a constant independent of x here, so for convenience, let’s take it out and multiply it back in at the end.
We’ll also reindex the summation a bit, to get

\sum_{x=0}^{k-1} \binom{2k}{x} (M-1)^x

This uses the fact that \binom{N}{K} = \binom{N}{N-K} (alternately, you can see it as the expression obtained by fixing the positions and values of the small-frequency elements).

Let’s relax the conditions a bit: instead of considering only even-length arrays, we’ll allow for any length - after all, we’re just counting the number of arrays such that some element appears more than half the time.
So, let

f_M(k) = \sum_{x \geq 0,\ 2x \lt k} \binom{k}{x} (M-1)^x

As it turns out, we can compute f_M(k+1) from f_M(k).
There are minor differences when doing this for even and odd k, so I’ll present odd k below and even k will be left as an exercise for the reader.
We’re trying to compute f_M(2k+2) from f_M(2k+1). Observe that:

\begin{align*} f_M(2k+2) &= \sum_{x=0}^{k} \binom{2k+2}{x} (M-1)^x \\ &= \binom{2k+2}{0}(M-1)^0 + \binom{2k+2}{1}(M-1)^1 + \ldots + \binom{2k+2}{k}(M-1)^k \\ &= \left(\binom{2k+1}{0} + \binom{2k+1}{-1} \right)(M-1)^0 + \left(\binom{2k+1}{1} + \binom{2k+1}{0} \right)(M-1)^1 + \ldots + \left(\binom{2k+1}{k} + \binom{2k+1}{k-1} \right)(M-1)^k \\ &= f_M(2k+1) + (M-1)\cdot f_M(2k+1) - \binom{2k+1}{k}(M-1)^{k+1} \end{align*}

The third line arises from simply applying Pascal’s rule to each binomial coefficient.
Note that \binom{x}{-1} = 0 for any x\geq 0.
In the last line, the first terms of the brackets sum to f_M(2k+1). The second terms sum to (M-1) times \sum_{x=0}^{k-1}\binom{2k+1}{x}(M-1)^x, which is f_M(2k+1) without its x = k term \binom{2k+1}{k}(M-1)^k.

In this way, f_M(2k+2) can be computed from f_M(2k+1) in constant time.
A similar computation allows you to get f_M(2k+1) from f_M(2k) in constant time.
So, starting with f_M(2) = 1, you can get every f_M(k) value in \mathcal{O}(N) time in total.
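
Both steps can be sketched as follows, using exact integers rather than the modulus, and hypothetical names f_direct and f_table. The odd step subtracts \binom{2k+1}{k}(M-1)^{k+1}; the even step, worked out the same way via Pascal’s rule, adds \binom{2k}{k}(M-1)^k because one new term enters the range. The table starts from f_M(1) = 1 and assumes limit is at least 1.

```python
from math import comb


def f_direct(n, m):
    """f_M(n): sum over x with 2x < n of C(n, x) * (M-1)^x."""
    return sum(comb(n, x) * (m - 1) ** x for x in range((n + 1) // 2))


def f_table(limit, m):
    """f_M(1..limit) using one constant-time step per length (limit >= 1 assumed)."""
    p = m - 1
    f = [0] * (limit + 1)
    f[1] = 1
    for n in range(1, limit):
        half = n // 2
        if n % 2 == 1:
            # n = 2k+1 -> 2k+2: the shifted sum is missing its x = k term
            f[n + 1] = m * f[n] - comb(n, half) * p ** (half + 1)
        else:
            # n = 2k -> 2k+1: the new term x = k enters the range
            f[n + 1] = m * f[n] + comb(n, half) * p ** half
    return f
```

In an actual submission the binomials would come from precomputed factorials modulo 998244353, as in the solutions below; the comparison with f_direct here is only a correctness check.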

Once these are known, the c_M(2k) values are trivially found (since c_M(2k) = f_M(2k)\cdot M), after which the final answer is easily computed in linear time as well.


If you aren’t comfortable with algebraic manipulation like we did towards the end, there’s a more visual way to get to the result (though the formula ends up the same either way).
I refer you to the editorial of AtCoder ABC 235G for more details on this.

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 1e6 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

ll fact[N], ifact[N];

ll bexp(ll a, ll b) {
    ll res = 1;

    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }

    return res;
}

ll invmod(ll a) {
    return bexp(a, MOD - 2);
}

ll ncr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[r] % MOD * ifact[n - r] % MOD;
}

ll npr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[n - r] % MOD;
}

void precalc(ll n) {
    fact[0] = 1;
    rep1(i, n) fact[i] = fact[i - 1] * i % MOD;

    ifact[n] = invmod(fact[n]);
    rev(i, n - 1, 0) ifact[i] = ifact[i + 1] * (i + 1) % MOD;
}

void solve(int test_case)
{
    ll n,m; cin >> n >> m;
    
    ll even_sub_cnt = 0;
    for(int i = 2; i <= n; i += 2){
        even_sub_cnt += n-i+1;
    }
    even_sub_cnt %= MOD;
    
    ll p = m-1;
    vector<ll> powp(n+5), powm(n+5);
    powp[0] = 1;
    rep1(i,n) powp[i] = powp[i-1]*p%MOD;
    powm[0] = 1;
    rep1(i,n) powm[i] = powm[i-1]*m%MOD;

    ll tot = powm[n]*even_sub_cnt%MOD;
    ll bad = 0;
    ll sum = 0;

    rep1(i,n){
        if(i&1){
            // up allowed (odd i): f(i) = M*f(i-1) + C(i-1, (i-1)/2) * (M-1)^((i-1)/2)
            sum = sum*(p+1);
            ll add = ncr(i,(i+1)/2-1)*powp[(i+1)/2-1];
            ll sub = ncr(i-1,(i+1)/2-2)*powp[(i+1)/2-1];
            sum += add-sub;
            sum = (sum%MOD+MOD)%MOD;
        }
        else{
            // up not allowed (even i): f(i) = M*f(i-1) - C(i-1, i/2-1) * (M-1)^(i/2)
            sum = sum*(p+1);
            ll sub = ncr(i-1,i/2-1)*powp[i/2];
            sum -= sub;
            sum = (sum%MOD+MOD)%MOD;
        }

        if(i&1) conts;

        bad += sum*m%MOD*powm[n-i]%MOD*(n-i+1);
        bad %= MOD;
    }

    ll ans = (tot-bad+MOD)%MOD;
    cout << ans << endl;
}

int main()
{
    precalc(N-1);

    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};


#ifdef LOCAL
#define dbg(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define dbg(...)
#endif

void __print(int32_t x) {cerr << x;}
void __print(int64_t x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(string x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
template<typename T>void __print(complex<T> x) {cerr << '{'; __print(x.real()); cerr << ','; __print(x.imag()); cerr << '}';}

template<typename T>
void __print(const T &x);
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto it = x.begin() ; it != x.end() ; it++) cerr << (f++ ? "," : ""), __print(*it); cerr << "}";}
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}


namespace mint_ns {
template<auto P>
struct Modular {
    using value_type = decltype(P);
    value_type value;

    Modular(long long k = 0) : value(norm(k)) {}

    friend Modular<P>& operator += (      Modular<P>& n, const Modular<P>& m) { n.value += m.value; if (n.value >= P) n.value -= P; return n; }
    friend Modular<P>  operator +  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r += m; }

    friend Modular<P>& operator -= (      Modular<P>& n, const Modular<P>& m) { n.value -= m.value; if (n.value < 0)  n.value += P; return n; }
    friend Modular<P>  operator -  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r -= m; }
    friend Modular<P>  operator -  (const Modular<P>& n)                      { return Modular<P>(-n.value); }

    friend Modular<P>& operator *= (      Modular<P>& n, const Modular<P>& m) { n.value = n.value * 1ll * m.value % P; return n; }
    friend Modular<P>  operator *  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r *= m; }

    friend Modular<P>& operator /= (      Modular<P>& n, const Modular<P>& m) { return n *= m.inv(); }
    friend Modular<P>  operator /  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r /= m; }

    Modular<P>& operator ++ (   ) { return *this += 1; }
    Modular<P>& operator -- (   ) { return *this -= 1; }
    Modular<P>  operator ++ (int) { Modular<P> r = *this; *this += 1; return r; }
    Modular<P>  operator -- (int) { Modular<P> r = *this; *this -= 1; return r; }

    friend bool operator == (const Modular<P>& n, const Modular<P>& m) { return n.value == m.value; }
    friend bool operator != (const Modular<P>& n, const Modular<P>& m) { return n.value != m.value; }

    explicit    operator       int() const { return value; }
    explicit    operator      bool() const { return value; }
    explicit    operator long long() const { return value; }

    constexpr static value_type mod()      { return     P; }

    value_type norm(long long k) {
        if (!(-P <= k && k < P)) k %= P;
        if (k < 0) k += P;
        return k;
    }

    Modular<P> inv() const {
        value_type a = value, b = P, x = 0, y = 1;
        while (a != 0) { value_type k = b / a; b -= k * a; x -= k * y; swap(a, b); swap(x, y); }
        return Modular<P>(x);
    }
    friend void __print(Modular<P> val) {
        cerr << val.value;
    }
};
template<auto P> Modular<P> pow(Modular<P> m, long long p) {
    Modular<P> r(1);
    while (p) {
        if (p & 1) r *= m;
        m *= m;
        p >>= 1;
    }
    return r;
}

template<auto P> ostream& operator << (ostream& o, const Modular<P>& m) { return o << m.value; }
template<auto P> istream& operator >> (istream& i,       Modular<P>& m) { long long k; i >> k; m.value = m.norm(k); return i; }
template<auto P> string   to_string(const Modular<P>& m) { return to_string(m.value); }

}
constexpr int mod = 998244353;
constexpr int maxn = (int)2e6 + 10;
using mod_int = mint_ns::Modular<mod>;
using mi = mod_int;

vector<mi> fct(maxn, 1), invf(maxn, 1);
void calc_fact() {
    for(int i = 1 ; i < maxn ; i++) {
        fct[i] = fct[i - 1] * i;
    }
    invf.back() = mi(1) / fct.back();
    for(int i = maxn - 1 ; i ; i--)
        invf[i - 1] = i * invf[i];
}

mi choose(int n, int r) { // choose r elements out of n elements
    if(r > n)   return mi(0);
    assert(r <= n);
    return fct[n] * invf[r] * invf[n - r];
}

mi place(int n, int r) { // x1 + x2 + ... + xr = n with each xi >= 0 (stars and bars)
    assert(r > 0);
    return choose(n + r - 1, r - 1);
}

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;
    calc_fact();
    int T = input.readInt(1, (int)1e4); input.readEoln();
    int NN = 0, MM = 0;
    while(T-- > 0) {
        int N = input.readInt(1, (int)1e6);     input.readSpace();
        int M = input.readInt(1, (int)1e6);     input.readEoln();
        NN += N, MM += M;

        vector<mod_int> P(N + 1);   P[0] = 1;
        for(int i = 0 ; i < N ; ++i) {
            P[i + 1] = P[i] * M;
        }
        mod_int ss = 1, res = 0;
        for(int i = 2 ; i <= N ; ++i) { // (1, 0) * A ^ 0
            ss = ss + (M - 1) * ss;
            if (i & 1) {
                ss += choose(i - 1, i / 2) * pow(mi(M - 1), i / 2);
            } else {
                ss -= choose(i - 1, i / 2 - 1) * pow(mi(M - 1), i / 2);
            }
            if(i % 2 == 0) {
                res += (N - i + 1) * (P[i] - ss * M) * P[N - i];
            }
        }
        cout << res << '\n';
    }
    assert(NN <= (int)1e6 && MM <= (int)1e6);
    input.readEof();
    return 0;
}
// 1C0 * (M - 1) ^ 0 + 1C0 * (M - 1) ^ 0
Editorialist's code (Python)
mod = 998244353
N = 2* 10**6
fac = [1]*N
for i in range(2, N): fac[i] = fac[i-1] * i % mod
inv = fac[:]
inv[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(N-1)): inv[i] = inv[i+1] * (i+1) % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod

for _ in range(int(input())):
    n, m = map(int, input().split())
    ways = [0]*(n+5)
    ways[2] = 1
    pows = [1]*(n+5)
    for i in range(1, n+5): pows[i] = pows[i-1] * (m-1) % mod
    # ways[i] = sum_{j < i/2} (C(i, j) * (m-1)^j)
    for i in range(2, n):
        if i%2 == 1:
            lst = i//2
            ways[i+1] = ways[i] * (1 + m-1) - C(i, lst)*pows[lst+1]
        else:
            lst = i//2 - 1
            ways[i+1] = ways[i] * (1 + m-1) + pows[lst+1]*(C(i+1, lst+1) - C(i, lst))
        ways[i+1] %= mod
    ans = 0
    for i in range(1, n+5): pows[i] = pows[i-1] * m % mod
    for L in range(2, n+1, 2):
        ans -= ways[L] * m % mod * (n-L+1) % mod * pows[n-L] % mod
        ans += (n-L+1) * pows[n] % mod
    print(ans % mod)