DISTNEIGH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Satyam
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

2774

PREREQUISITES:

Combinatorics, Stars and bars

PROBLEM:

Given N, x, y, find the number of arrays P of length N satisfying the following, modulo 998244353:

  • P contains exactly x ones, y twos, and z = N-x-y threes
  • P_1 = P_N = 1
  • No two adjacent elements of P are equal.
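To sanity-check the counting argument on small inputs, a brute-force counter helps. The following is a minimal sketch (the names `bruteCount` and `extendArray` are hypothetical, not taken from the reference codes) that enumerates the arrays directly. It runs in exponential time, so it is only meant for tiny N.

```cpp
#include <cassert>
#include <vector>

// Hypothetical brute-force helper: builds every array over {1, 2, 3} that
// starts with 1 and has no two equal neighbours, using exactly the remaining
// counts left[1], left[2], left[3], and counts those that also end in 1.
static void extendArray(std::vector<int>& p, int n, int left[4], long long& count) {
    if (static_cast<int>(p.size()) == n) {
        if (p.back() == 1) ++count;  // all values are used up here; need P_N = 1
        return;
    }
    for (int v = 1; v <= 3; ++v) {
        if (left[v] == 0) continue;                 // no copies of v remain
        if (p.empty() && v != 1) continue;          // P_1 must be 1
        if (!p.empty() && p.back() == v) continue;  // neighbours must differ
        --left[v];
        p.push_back(v);
        extendArray(p, n, left, count);
        p.pop_back();
        ++left[v];
    }
}

long long bruteCount(int n, int x, int y) {
    assert(n >= 1 && n <= 20);  // keep the search tiny
    int z = n - x - y;
    if (x < 0 || y < 0 || z < 0) return 0;
    int left[4] = {0, x, y, z};
    std::vector<int> p;
    long long count = 0;
    extendArray(p, n, left, count);
    return count;
}
```

For example, with N = 6, x = 3, y = 2 it finds the four arrays 1 2 1 2 3 1, 1 2 1 3 2 1, 1 2 3 1 2 1 and 1 3 2 1 2 1.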

EXPLANATION:

This is pretty much a purely combinatorial task.

The main observation is as follows:
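Consider two consecutive ones in any valid array P, i.e. two occurrences of 1 with no other 1 between them. The subarray between them is nonempty (otherwise two ones would be adjacent) and must consist of alternating twos and threes.
In particular, once the length of this subarray is known, its first element fixes the entire subarray.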
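A small helper makes this observation concrete: given the first element (2 or 3) and the length, the alternating gap is forced. `makeGap` is a hypothetical name used only for illustration.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: the gap between two consecutive ones alternates
// between 2 and 3, so its first element and its length determine it.
std::vector<int> makeGap(int first, int len) {
    assert((first == 2 || first == 3) && len >= 1);
    std::vector<int> gap(len);
    for (int i = 0; i < len; ++i) {
        gap[i] = (i % 2 == 0) ? first : 5 - first;  // 5 - 2 = 3 and 5 - 3 = 2
    }
    return gap;
}
```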

Now, note that any subarray between two consecutive ones has one of four forms:

  • Even length, and starting with either 2 or 3
  • Odd length, and starting with either 2 or 3

The even-length subarrays are functionally equivalent, since each of them contains an equal number of twos and threes.
However, the odd-length subarrays aren't: the starting element appears exactly once more than the other element. Let's try to use this.

Suppose there are k_1 odd length subarrays that start with a 2, k_2 odd length subarrays that start with a 3, and k_3 even length subarrays. Then, we have the following equations:

  • k_1 + k_2 + k_3 = x-1, since every 1 except the last has one of these subarrays immediately following it
  • k_1 - k_2 = y - z, since each odd-length subarray starting with 2 contributes one extra 2, and each odd-length subarray starting with 3 contributes one extra 3.
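As a check of both equations, the sketch below splits a valid array at its ones and tallies k_1, k_2, k_3. `classifyGaps` is a hypothetical helper, and it assumes its input is already a valid array.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: for a valid P (starts and ends with 1, no equal
// neighbours), returns {k_1, k_2, k_3}:
//   k_1 = odd-length gaps starting with 2,
//   k_2 = odd-length gaps starting with 3,
//   k_3 = even-length gaps.
std::array<int, 3> classifyGaps(const std::vector<int>& p) {
    assert(!p.empty() && p.front() == 1 && p.back() == 1);
    std::array<int, 3> k = {0, 0, 0};
    std::size_t i = 1;
    while (i < p.size()) {
        std::size_t start = i;
        while (p[i] != 1) ++i;        // p.back() == 1 guarantees this stops
        std::size_t len = i - start;  // length of the gap before this 1
        assert(len >= 1);             // two ones are never adjacent
        if (len % 2 == 0) ++k[2];
        else if (p[start] == 2) ++k[0];
        else ++k[1];
        ++i;                          // step past this 1
    }
    return k;
}
```

On P = 1 2 1 3 2 1 3 2 3 1 (x = 4, y = z = 3) it returns k_1 = k_2 = k_3 = 1, so k_1 + k_2 + k_3 = 3 = x-1 and k_1 - k_2 = 0 = y - z.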

Let’s fix the value of k_1 (it can be anything from 0 to x-1).
Note that the above equations then determine k_2 and k_3 uniquely: k_2 = k_1 - (y - z) and k_3 = x - 1 - k_1 - k_2.
If either k_2 or k_3 is invalid (i.e., negative or \gt x-1), ignore this value of k_1.
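The enumeration of k_1 can be sketched as follows (`validTriples` is a hypothetical name). Only x candidate values of k_1 are tried, which is why the final sum has O(N) terms.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Sketch: for each k_1 in [0, x-1], the equations
//   k_1 + k_2 + k_3 = x - 1   and   k_1 - k_2 = y - z
// fix k_2 and k_3; keep only triples with every entry in [0, x-1].
std::vector<std::array<int, 3>> validTriples(int n, int x, int y) {
    int z = n - x - y;
    assert(x >= 2 && y >= 0 && z >= 0);
    std::vector<std::array<int, 3>> out;
    for (int k1 = 0; k1 <= x - 1; ++k1) {
        int k2 = k1 - (y - z);
        int k3 = (x - 1) - k1 - k2;
        if (k2 < 0 || k2 > x - 1 || k3 < 0 || k3 > x - 1) continue;
        out.push_back({k1, k2, k3});
    }
    return out;
}
```

For N = 7, x = 3, y = 2 (so z = 2) it yields (0, 0, 2) and (1, 1, 0).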

So, suppose we know k_1, k_2, k_3. Let’s try to count the number of arrays with these parameters.

That can be done in a few steps.

  • First, there are x-1 subarrays. Let’s fix which ones of them are even length and which are odd: this gives us \binom{x-1}{k_3} choices.
  • Among the odd-length ones, let’s fix which ones start with 2 and which ones start with 3: this gives us \binom{k_1 + k_2}{k_1} choices.
  • Each even-length subarray can start with either 2 or 3, so we have 2^{k_3} choices there.
  • Finally, we need to count the number of ways to choose lengths for these subarrays. That is an application of the stars-and-bars technique.
How?

There are exactly x-1 subarrays. Suppose the i-th of them has length a_i.

The total length of the subarrays must equal the number of twos and threes, i.e., y+z.
So,

a_1 + a_2 + \ldots + a_{x-1} = y+z

However, there are some more constraints:

  • Exactly k_1 + k_2 of these values are odd. Note that we have already fixed which ones are odd with a binomial coefficient, so without loss of generality we can assume a_1, a_2, \ldots, a_{k_1 + k_2} are odd and the rest are even.
  • The even values must be strictly positive (since if they had length 0, two ones would be adjacent).

Let’s write the odd values as a_i = 2b_i + 1, and the even values as a_i = 2 + 2b_i. Plugging this into the first equation,

(2b_1 + 1) + \ldots + (2b_{k_1 + k_2} + 1) + (2 + 2b_{k_1 + k_2 + 1}) + \ldots + (2 + 2b_{x-1}) = y + z \\ \implies 2b_1 + 2b_2 + \ldots + 2b_{x-1} = y+z-(k_1+k_2)-2k_3 \\ \implies b_1 + b_2 + \ldots + b_{x-1} = \frac{y+z-(k_1+k_2)-2k_3}{2}

The b_i have no constraint, other than the fact that they have to be non-negative integers.
Any solution to this equation gives us a valid solution to the original equation by reversing the a_i \to b_i process, so it’s enough to count the number of solutions to this equation.
Do note that if y+z-(k_1+k_2)-2k_3 is odd or negative, then the equation has no solution.

Counting the number of solutions here is a direct application of stars and bars, and is simply \binom{M+x-2}{M}, where M = \frac{y+z-(k_1+k_2)-2k_3}{2}.
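To convince yourself of the stars-and-bars count, you can compare it against a direct DP for small sizes. This is a sketch with hypothetical helper names (`smallBinom`, `countSolutionsDP`); here m plays the role of x-1, so \binom{M+m-1}{M} is the same as \binom{M+x-2}{M}.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Exact binomial for small arguments. After step i, r equals C(n - k + i, i),
// so every division is exact.
std::uint64_t smallBinom(int n, int k) {
    if (k < 0 || k > n) return 0;
    std::uint64_t r = 1;
    for (int i = 1; i <= k; ++i) r = r * (n - k + i) / i;
    return r;
}

// Counts non-negative integer solutions of b_1 + ... + b_m = M.
// ways[s] = number of ways the variables processed so far can sum to s.
std::uint64_t countSolutionsDP(int m, int M) {
    assert(m >= 1 && M >= 0);
    std::vector<std::uint64_t> ways(M + 1, 0);
    ways[0] = 1;  // with zero variables only the sum 0 is reachable
    for (int v = 0; v < m; ++v) {
        // In-place prefix sums: the new variable may take any value >= 0.
        for (int s = 1; s <= M; ++s) ways[s] += ways[s - 1];
    }
    return ways[M];
}
```

For instance, b_1 + b_2 + b_3 = 2 has 6 solutions, and \binom{2+3-1}{2} = 6 as well.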

So, the final solution is to simply add up

\binom{x-1}{k_3}\binom{k_1+k_2}{k_1}\binom{M+x-2}{M}2^{k_3}

where M = \frac{y+z-(k_1+k_2)-2k_3}{2}, over all valid k_1.
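For small inputs, the whole sum can be evaluated exactly (without the modulus) and compared with a brute-force enumeration. This is a sketch with hypothetical names (`formulaCount`, `binomExact`); the reference codes below perform the same computation modulo 998244353. Each term counts a disjoint set of valid arrays, so no term exceeds the final answer, which keeps 64-bit arithmetic safe for the small N the assertion allows.

```cpp
#include <cassert>
#include <cstdint>

// Exact C(n, k) for small arguments; 0 outside 0 <= k <= n.
static std::uint64_t binomExact(int n, int k) {
    if (n < 0 || k < 0 || k > n) return 0;
    std::uint64_t r = 1;
    for (int i = 1; i <= k; ++i) r = r * (n - k + i) / i;
    return r;
}

// Sum over k_1 of C(x-1, k_3) * C(k_1+k_2, k_1) * C(M+x-2, M) * 2^{k_3}.
std::uint64_t formulaCount(int n, int x, int y) {
    int z = n - x - y;
    assert(x >= 2 && y >= 0 && z >= 0 && n <= 40);
    std::uint64_t total = 0;
    for (int k1 = 0; k1 <= x - 1; ++k1) {
        int k2 = k1 - (y - z);
        int k3 = (x - 1) - k1 - k2;
        if (k2 < 0 || k3 < 0) continue;
        int rest = (y + z) - (k1 + k2) - 2 * k3;  // equals 2M when M exists
        if (rest < 0 || rest % 2 != 0) continue;
        int M = rest / 2;
        total += binomExact(x - 1, k3) * binomExact(k1 + k2, k1) *
                 binomExact(M + x - 2, M) * (std::uint64_t{1} << k3);
    }
    return total;
}
```

For N = 7, x = 3, y = 2 the triples (0, 0, 2) and (1, 1, 0) contribute 4 each, for a total of 8.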

Each binomial coefficient can be computed in \mathcal{O}(1) if factorials and their inverses are precomputed, and the power of 2 can be computed in \mathcal{O}(\log N) using binary exponentiation (or all required powers of 2 can be precomputed, since we only need powers \leq N anyway).
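One way to do this precomputation in linear time is sketched below. Instead of one modular exponentiation per inverse factorial, it inverts only the largest factorial and walks downward using invFact[i-1] = invFact[i] \cdot i, so the table costs O(N + \log \text{mod}). The struct name `Combinatorics` is hypothetical; the modulus 998244353 matches the reference codes.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct Combinatorics {
    static constexpr std::int64_t MOD = 998244353;
    std::vector<std::int64_t> fact, invFact, pow2;

    // Binary exponentiation a^e mod MOD.
    static std::int64_t power(std::int64_t a, std::int64_t e) {
        std::int64_t r = 1;
        a %= MOD;
        while (e > 0) {
            if (e & 1) r = r * a % MOD;
            a = a * a % MOD;
            e >>= 1;
        }
        return r;
    }

    explicit Combinatorics(int maxN)
        : fact(maxN + 1), invFact(maxN + 1), pow2(maxN + 1) {
        assert(maxN >= 0);
        fact[0] = 1;
        pow2[0] = 1;
        for (int i = 1; i <= maxN; ++i) {
            fact[i] = fact[i - 1] * i % MOD;
            pow2[i] = pow2[i - 1] * 2 % MOD;
        }
        invFact[maxN] = power(fact[maxN], MOD - 2);  // Fermat: MOD is prime
        for (int i = maxN; i >= 1; --i) invFact[i - 1] = invFact[i] * i % MOD;
    }

    // C(n, k) mod MOD, and 0 when k < 0 or k > n.
    std::int64_t binom(int n, int k) const {
        if (n < 0 || k < 0 || k > n) return 0;
        assert(n < static_cast<int>(fact.size()));
        return fact[n] * invFact[k] % MOD * invFact[n - k] % MOD;
    }
};
```

With this table, every factor of a term in the sum, including 2^{k_3} via pow2[k_3], is an O(1) lookup, which would bring the per-term cost down to O(1).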

This gives us a solution in \mathcal{O}(N\log N).

TIME COMPLEXITY:

\mathcal{O}(N\log N) per test case.

CODE:

Setter's code (C++)
// #pragma GCC optimize("O3")
// #pragma GCC target("popcnt")
// #pragma GCC target("avx,avx2,fma")
// #pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
#define ll long long  
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;    
#define pb push_back                    
#define mp make_pair          
#define nline "\n"                           
#define f first                                            
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  
      
const ll MOD=998244353; 
  
        
const ll MAX=5000300; 
  
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){  
    ll ans=1;
    a%=MOD;  
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}  
ll inverse(ll a,ll MOD){   
    return binpow(a,MOD-2,MOD);
} 
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if((a<0)||(a<b)||(b<0))
        return 0;   
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;  
}
void solve(){ 
    ll n,a,b; cin>>n>>a>>b;    
    ll ans=0;    
    ll c=n-a-b;  
    
    assert(a>=2);
    for(ll i=b-c;i<a;i++){      
        ll diff=b-c;     
        if((diff&1)!=(i&1)){   
            continue;        
        }    
        ll l=(i+diff)/2,r=(i-diff)/2;    
        ll now=nCr(i,l,MOD);   
        ll lft=b+c-i;  
        if(lft&1){  
            continue;   
        }      
        lft/=2;         
        ll ext=a-1-i; 
        now=(now*nCr(lft-ext+a-2,a-2,MOD))%MOD;  
        now=(now*binpow(2,ext,MOD))%MOD;
        now=(now*nCr(a-1,i,MOD))%MOD;
        ans+=now; ans%=MOD;  
    }  
    cout<<ans;
    return;                     
}                                      
int main()                                                                           
{  
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                           
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                              
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif         
    ll test_cases=1;               
    //cin>>test_cases;
    precompute(MOD); 
    while(test_cases--){     
        solve();       
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string& pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    input_checker in;
    int n = in.readInt(3, 2e6);
    in.readSpace();
    int x = in.readInt(2, n - 1);
    in.readSpace();
    int y = in.readInt(1, n - 2);
    in.readEoln();
    assert(x + y <= n);
    int z = n - x - y;
    mint ans = 0;
    if (x >= 2) {
        for (int i = 1; i <= min(x - 1, y); i++) {
            int same = x - 1 - i + y - i;
            int dif = x + y - 1 - same;
            if (same <= z) {
                ans += C(x - 1, i) * C(y - 1, i - 1) * C(dif, z - same);
            }
        }
    }
    cout << ans << endl;
    in.readEof();
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

const int mod = 998244353;

const int MAXN = 3e6 + 5;
ll fac[MAXN], invf[MAXN];

ll modpow(ll a, ll n) {
	ll r = 1;
	while (n) {
		if (n & 1) r = (r * a) % mod;
		a = (a * a) % mod;
		n /= 2;
	}
	return r;
}

ll C(int n, int r) {
	if (n < 0 or r < 0 or n < r) return 0; // C(n, r) = 0 outside 0 <= r <= n
	ll ret = (fac[n] * invf[r]) % mod;
	return (ret * invf[n-r]) % mod;
}

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	fac[0] = invf[0] = 1;
	for (int i = 1; i < MAXN; ++i) {
		fac[i] = (i * fac[i-1]) % mod;
		invf[i] = modpow(fac[i], mod-2);
	}

	int n, x, y; cin >> n >> x >> y;
	int z = n - x - y;
	ll ans = 0;
	for (int i = 0; i <= n; ++i) {
		int j = i + z - y;
		int k = x-1 - i - j;
		// i odd starting with 2, j odd starting with 3, k even
		if (min({i, j, k}) < 0) continue;
		if (max({i, j, k}) > x-1) continue;
		int targetsum = n - x - i - j - 2*k;
		if (targetsum < 0 or (targetsum & 1)) continue;
		targetsum /= 2;

		ll add = C(x-1, k) * C(i+j, i); add %= mod;
		add *= C(targetsum + x - 2, targetsum); add %= mod;
		add *= modpow(2, k); add %= mod;
		ans += add;
	}
	cout << ans%mod << '\n';
}

I solved it with a slightly different method.
The idea is to imagine the array with all 3's removed. What remains consists of blocks of 1's and 2's. Since a[1] = 1 and a[n] = 1, the array looks like this: 11..1 \; 22..2 \; 11..1 \; 22..2 \; 11..1. Clearly the number of blocks is odd, and blocks of 1's and 2's alternate. Now imagine adding the 3's back. To “fix” a block of size s you need to use 3 exactly s-1 times. For example, to fix 11111 you need to use 3 four times (to get 131313131).
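This reduction can be illustrated with two small helpers (hypothetical names, not taken from the commenter's submission): one deletes the 3's and reports the block lengths, the other counts how many 3's the blocks need internally.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: delete every 3 from p and return the lengths of the
// maximal runs ("blocks") of equal values that remain.
std::vector<int> blocksWithoutThrees(const std::vector<int>& p) {
    std::vector<int> blocks;
    int prev = 0;  // last non-3 value seen (0 = none yet)
    for (int v : p) {
        if (v == 3) continue;
        if (v == prev) ++blocks.back();
        else blocks.push_back(1);
        prev = v;
    }
    return blocks;
}

// A block of size s needs exactly s - 1 threes inside it, so the blocks
// together need (sum of sizes) - (number of blocks) threes.
int threesInsideBlocks(const std::vector<int>& blocks) {
    int need = 0;
    for (int s : blocks) {
        assert(s >= 1);
        need += s - 1;
    }
    return need;
}
```

For P = 1 3 1 2 3 2 1 the blocks are 2, 2, 1 (so k = 3) and they need 1 + 1 + 0 = 2 threes, which is all of the N-x-y = 2 threes, so none is left for the k-1 = 2 borders.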

So, if there are k blocks, fixing them takes exactly x+y-k threes. That leaves (n-x-y) - (x+y-k) = n+k-2x-2y additional threes, and they must go into the k-1 borders between adjacent blocks of 1's and 2's, at most one per border (two 3's cannot be adjacent). So we have to select n+k-2x-2y of the k-1 borders, and there are \binom{k-1}{n+k-2x-2y} ways to do so!

Now, to form the blocks, note that the total size of the blocks of 1's is exactly x; let w_{i} be the length of the i^{th} block of 1's. There are \lceil k/2 \rceil blocks of 1's and \lfloor k/2 \rfloor blocks of 2's, so we need the number of positive integral solutions of

w_{1} + w_{2} + \ldots + w_{\lceil k/2 \rceil} = x.
This is simply \binom{x-1}{\lceil k/2 \rceil - 1}. Similarly, the number of ways of forming the blocks of 2's is \binom{y-1}{\lfloor k/2 \rfloor - 1}.

So finally, the answer is the sum of \binom{x-1}{\lceil k/2 \rceil - 1}\binom{y-1}{\lfloor k/2 \rfloor - 1}\binom{k-1}{n+k-2x-2y} over all odd k from 1 to n.
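This formula can also be evaluated directly for small inputs. The sketch below assumes y ≥ 1 (as the tester's input checker enforces), so the k = 1 term, which has no block of 2's, vanishes through \binom{y-1}{-1} = 0. The names `blockFormulaCount` and `choose` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Exact C(n, k) for small arguments; 0 when n < 0, k < 0 or k > n.
static std::uint64_t choose(int n, int k) {
    if (n < 0 || k < 0 || k > n) return 0;
    std::uint64_t r = 1;
    for (int i = 1; i <= k; ++i) r = r * (n - k + i) / i;
    return r;
}

// Sum over odd k of C(x-1, ceil(k/2)-1) * C(y-1, floor(k/2)-1) * C(k-1, n+k-2x-2y).
std::uint64_t blockFormulaCount(int n, int x, int y) {
    assert(x >= 1 && y >= 1 && x + y <= n && n <= 40);
    std::uint64_t total = 0;
    for (int k = 1; k <= n; k += 2) {
        int onesBlocks = (k + 1) / 2;  // ceil(k/2)
        int twosBlocks = k / 2;        // floor(k/2)
        total += choose(x - 1, onesBlocks - 1) * choose(y - 1, twosBlocks - 1) *
                 choose(k - 1, n + k - 2 * x - 2 * y);
    }
    return total;
}
```

It gives 4, 8 and 8 for (N, x, y) = (6, 3, 2), (7, 3, 2) and (8, 3, 2), matching the editorial's sum on these inputs.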

Link for my submission: CodeChef: Practical coding for everyone
