SUMOFPROD2-Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Utkarsh Gupta
Testers: Jatin Garg, Tejas Pandey
Editorialist: Devendra Singh

DIFFICULTY:

1972

PREREQUISITES:

Binomial Coefficient

PROBLEM:

For an array A of length N, let F(A) denote the sum of the product of all the subarrays of A. Formally,

F(A) = \sum_{L=1}^N \sum_{R=L}^N \left (\prod_{i=L}^R A_i\right )

For example, let A = [1, 0, 1], then there are 6 possible subarrays:

  • Subarray [1, 1] has product = 1
  • Subarray [1, 2] has product = 0
  • Subarray [1, 3] has product = 0
  • Subarray [2, 2] has product = 0
  • Subarray [2, 3] has product = 0
  • Subarray [3, 3] has product = 1

So F(A) = 1+1 = 2.

Given a binary array A, determine the sum of F(A) over all the N! orderings of A modulo 998244353.

Note that orderings here are defined in terms of indices, not elements; which is why every array of length N has N! orderings. For example, the 3! = 6 orderings of A = [1, 0, 1] are:

  • [1, 0, 1] corresponding to indices [1, 2, 3]
  • [1, 1, 0] corresponding to indices [1, 3, 2]
  • [0, 1, 1] corresponding to indices [2, 1, 3]
  • [0, 1, 1] corresponding to indices [2, 3, 1]
  • [1, 1, 0] corresponding to indices [3, 1, 2]
  • [1, 0, 1] corresponding to indices [3, 2, 1]

EXPLANATION:

Since the array consists only of zeroes and ones, the product of any subarray is either 0 or 1, and it is 1 if and only if the subarray consists entirely of ones. Each such all-ones subarray contributes exactly 1 to the answer. Therefore, the problem reduces to counting the all-ones subarrays, summed over all N! orderings of A.
Let C_1 be the number of ones in A. For each length len from 1 to C_1, we can count the all-ones subarrays of length len over all N! orderings of A using combinatorics:

  • Choose which len of the C_1 ones (by index) fill the subarray: ^{C_1}C_{len} ways.
  • Arrange these len ones inside the subarray: len! ways.
  • Choose the starting position of the subarray: N-len+1 ways.
  • Arrange the remaining N-len elements in the other positions: (N-len)! ways.

The product of these four values is the number of all-ones subarrays of length len over all N! orderings of A, i.e. the contribution of length len to the answer. Summing over len = 1 to C_1 gives the final answer:

\sum_{len=1}^{C_1} {}^{C_1}C_{len} \cdot len! \cdot (N-len+1) \cdot (N-len)!
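As a quick sanity check of the formula, here is the arithmetic for the sample A = [1, 0, 1] (N = 3, C_1 = 2). The total agrees with the six orderings listed in the problem statement, whose F values are 2, 3, 3, 3, 3, 2.

```latex
\begin{aligned}
len = 1 &: \; {}^{2}C_{1} \cdot 1! \cdot (3-1+1) \cdot (3-1)! = 2 \cdot 1 \cdot 3 \cdot 2 = 12 \\
len = 2 &: \; {}^{2}C_{2} \cdot 2! \cdot (3-2+1) \cdot (3-2)! = 1 \cdot 2 \cdot 2 \cdot 1 = 4 \\
\text{answer} &= 12 + 4 = 16
\end{aligned}
```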
Factorials and inverse factorials can be precomputed once, so that each binomial coefficient is obtained in O(1). For implementation details, please refer to the attached solutions.

TIME COMPLEXITY:

O(N) per test case, after a one-time precomputation of factorials and inverse factorials.

SOLUTION:

Setter's solution
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 998244353
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int sumN=0;
ll fact[N];
ll invfact[N];
ll inv[N];
void factorialsComputation()
{
    inv[0]=inv[1]=1;
    fact[0]=fact[1]=1;
    invfact[0]=invfact[1]=1;
    for(int i=2;i<N;i++)
    {
        inv[i]=(inv[mod%i]*(mod-mod/i))%mod;
        fact[i]=(fact[i-1]*i)%mod;
        invfact[i]=(invfact[i-1]*inv[i])%mod;
    }
}
ll ncr(ll n,ll r)
{
    ll ans=fact[n]*invfact[r];
    ans%=mod;
    ans*=invfact[n-r];
    ans%=mod;
    return ans;
}
void solve()
{
    int n=readInt(1,100000,'\n');
    sumN+=n;
    assert(sumN<=200000);
    int A[n+1]={0};
    for(int i=1;i<=n;i++)
    {
        if(i==n)
            A[i]=readInt(0,1,'\n');
        else
            A[i]=readInt(0,1,' ');
    }
    ll ans=0;
    ll cnt0=0,cnt1=0;
    for(int i=1;i<=n;i++)
    {
        if(A[i]==1)
            cnt1++;
        else
            cnt0++;
    }
    ll tmp[n+1]={0};
    for(int len=1;len<=cnt1;len++)
    {
        tmp[len]=ncr(cnt1,len)*fact[len];
        tmp[len]%=mod;
        tmp[len]*=fact[n-len];
        tmp[len]%=mod;
        ans+=(n-len+1)*tmp[len];
        ans%=mod;
    }
    cout<<ans<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,1000,'\n');
    factorialsComputation();
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester-1's Solution
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}

// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------
long long sum_n = 0;

const int MOD = 998244353;
 
struct mod_int {
    int val;
 
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
 
        if (v >= MOD)
            v %= MOD;
 
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
 
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
 
        return x < 0 ? x + m : x;
    }
 
    explicit operator int() const {
        return val;
    }
 
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
 
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
 
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
 
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
 
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
 
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
 
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
 
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
 
    mod_int inv() const {
        return mod_inv(val);
    }
 
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, mod_int &m) {
        return stream>>m.val;   
    }
};
#define NCR
const int N = 1e5 + 5;
mod_int fact[N],inv[N];
void init(int n=N){
	fact[0]=inv[0]=inv[1]=1;
	rep(i,1,N)fact[i]=i*fact[i-1];
	rep(i,2,N)inv[i]=fact[i].inv();
}
mod_int C(int n,int r){
	if(r>n || r<0)return 0;
	return fact[n]*inv[n-r]*inv[r];
}

// (len!)*(n - len)!
int solve(){
 		int n = readIntLn(1,1e5);
 		sum_n += n; // accumulate N so the final assert(sum_n <= 2e5) actually checks the constraint
 		auto a = readVectorInt(n,0,1);

 		int cnt = count(all(a),1);

 		// C(n - 1,cnt - 1) + C(n - 2,cnt - 2) .... 
 		vector<mod_int> pref(cnt + 1);
 		mod_int ans = 0;
 		for(int i = 1; i <= cnt; i++){
 			pref[i] = pref[i - 1] + C(cnt,i)*fact[i]*fact[n - i];
 			ans += pref[i];
 		}
 		ans += (n - cnt)*pref[cnt];

		cout << ans << endl; 		
 		


 return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = readIntLn(1,1000);
    while(t--){
        solve();
    }
    assert(sum_n <= 2e5);
    return 0;
}
 
Tester-2's Solution
#include <bits/stdc++.h>
#define pb push_back
#define mp make_pair
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}

const int MAXT = 1000;
const int MAXN = 100000;
const int MAXA = 1;
const int SUMN = 200000;
int sumN = 0;

#define ll long long int
#define mod 998244353
#define N 200007

ll mpow(ll a, ll b) {
	ll res = 1;
	while(b) {
		if(b&1) res *= a, res %= mod;
		a *= a;
		a %= mod;
		b >>= 1;
	}
	return res;
}

ll fact[N];
ll invfact[N];
ll inv[N];
void pre() {
    inv[0]=inv[1]=1;
    fact[0]=fact[1]=1;
    invfact[0]=invfact[1]=1;
    for(int i=2;i<N;i++) {
        inv[i]=(inv[mod%i]*(mod-mod/i))%mod;
        fact[i]=(fact[i-1]*i)%mod;
        invfact[i]=(invfact[i-1]*inv[i])%mod;
    }
}
ll comb(ll n,ll r) {
    ll ans=fact[n]*invfact[r];
    ans%=mod;
    ans*=invfact[n-r];
    ans%=mod;
    return ans;
}

void solve()
{
	long long int n = readInt(1, MAXN, '\n');
	sumN += n;
	assert(sumN <= SUMN);
	int a[n];
	for(int i = 0; i< n - 1; i++) a[i] = readInt(0, MAXA, ' ');
	a[n - 1] = readInt(0, MAXA, '\n');
	int c[2] = {0, 0};
		for(int i = 0; i < n; i++) c[a[i]]++;
		c[1] = n - c[0];
		if(c[0] < 2) {
			if(c[0]) {
				ll ans = 0;
				for(ll i = 0; i <= c[1]; i++) {
					ll x = (((i*(i + 1))/2)%mod + (((c[1] - i)*((c[1] - i) + 1))/2)%mod)%mod;
					ll val = (comb(c[1], i)*fact[i])%mod;
					val *= fact[c[1] - i];
					val %= mod;
					val *= x;
					val %= mod;
					ans += val;
					ans %= mod;
				}
				cout << ans << "\n";
			}
			else cout << (((n*(n + 1))/2)%mod*fact[n])%mod << "\n";
			return;
		}
		ll ans = 0;
		for(ll i = 1; i <= c[1]; i++) {
			ll val = (i*(i + 1)/2)%mod;
			ll grps = (comb(c[1], i)*fact[i])%mod;
			grps *= (comb(c[0], 2)*2)%mod;
			grps %= mod;
			grps *= fact[n - 2 - i];
			grps %= mod;
			grps *= (n - 1 - i);
			grps %= mod;
			grps *= val;
			grps %= mod;
			ans += grps;
			ans %= mod;
			ll g2 = fact[n - 1 - i];
			g2 *= (comb(c[1], i)*fact[i])%mod;
			g2 %= mod;
			g2 *= c[0];
			g2 %= mod;
			g2 *= val;
			g2 %= mod;
			ans += g2;
			ans %= mod;
			ans += g2;
			ans %= mod;
		}
		cout << ans << "\n";
}
int main()
{
	pre();
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,MAXT,'\n');
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist's Solution
#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define pb push_back
#define all(_obj) _obj.begin(), _obj.end()
#define F first
#define S second
#define pll pair<ll, ll>
#define vll vector<ll>
ll INF = 1e18;
const int N = 2e5 + 11, mod = 998244353;
ll max(ll a, ll b) { return ((a > b) ? a : b); }
ll min(ll a, ll b) { return ((a > b) ? b : a); }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
ll factorial[N], inverse_factorial[N], NumInverse[N];
long long binomial_coefficient(int n, int k)
{
    return factorial[n] * inverse_factorial[k] % mod * inverse_factorial[n - k] % mod;
}
void sol(void)
{
    ll ans = 0;
    int n, cnt1 = 0;
    cin >> n;
    vll v(n);
    for (int i = 0; i < n; i++)
        cin >> v[i], cnt1 += v[i];
    for (int i = 1; i <= cnt1; i++)
    {
        ans += binomial_coefficient(cnt1, i) * factorial[i] % mod * (n - i + 1) % mod * factorial[n - i];
        ans %= mod;
    }
    cout << ans << '\n';
    return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);
    NumInverse[0] = NumInverse[1] = 1;
    factorial[0] = factorial[1] = 1;
    inverse_factorial[0] = inverse_factorial[1] = 1;
    for (int i = 2; i < N; i++)
    {
        NumInverse[i] = NumInverse[mod % i] * (mod - mod / i) % mod;
        factorial[i] = factorial[i - 1] * i % mod;
        inverse_factorial[i] = (NumInverse[i] * inverse_factorial[i - 1]) % mod;
    }
    int test = 1;
    cin>>test;
    while (test--)
        sol();
}



I got a one-line solution, but I can't prove it.
Let p be the number of 1s in the array, and q be the number of 0s.
Then the answer is (q + 1) \times \binom{N+1}{q+2} \times p! \times q!.
With precalculated factorials, each case takes O(1) time once p and q are counted.
My submission


Here is the correct link for his single line solution

Can you please explain how you got this formula?

The formula simplifies to \displaystyle \frac{p\,(N+1)!}{q+2}. (Solution)
After I got the recurrence formulas right, I slowly became aware that there would be a cascade of simplifications from applying the hockey-stick identity several times. I checked the table of values produced by the recurrences, divided by p! q!, and saw binomial coefficients.
I suppose you got the solution in the same or a similar way.


Proof of the formula :)
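One way to derive the formula (a sketch, which may differ from the linked proof), writing p = C_1 for the number of ones and q = N - p for the number of zeros: start from the editorial's sum, rewrite each term as a binomial coefficient using (N+1-len) - (q+1) = p - len, substitute m = N+1-len, and apply the hockey-stick identity \sum_{m=r}^{N} \binom{m}{r} = \binom{N+1}{r+1}. The steps assume p \ge 1; for p = 0 the sum and both closed forms equal 0.

```latex
\begin{aligned}
\text{Ans} &= \sum_{len=1}^{p} \binom{p}{len}\, len!\,(N-len+1)\,(N-len)!
            = p! \sum_{len=1}^{p} \frac{(N+1-len)!}{(p-len)!} \\
           &= p!\,(q+1)! \sum_{len=1}^{p} \binom{N+1-len}{q+1}
            = p!\,(q+1)! \sum_{m=q+1}^{N} \binom{m}{q+1} \\
           &= p!\,(q+1)! \binom{N+1}{q+2}
            = (q+1)\binom{N+1}{q+2}\, p!\, q!
            = \frac{p\,(N+1)!}{q+2}.
\end{aligned}
```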


@devendra7700

  • Starting positions in the array for this subarray are N−len+1
    I can't understand this formula, N - len + 1.

The following is an example.

It says that if the array is [10, 20, 30, 40, 50] with N = 5 and we are counting subarrays of length len = 2, then there are N - len + 1 = 4 such subarrays, because their starting positions are:

  • 1 → 10 for subarray [10, 20]
  • 2 → 20 for subarray [20, 30]
  • 3 → 30 for subarray [30, 40]
  • 4 → 40 for subarray [40, 50].

#include <bits/stdc++.h>

#define MOD 998244353

using namespace std;

long long fac[1000010], inv[1000010], finv[1000010];

long long C(long long x, long long y){
    // also reject y < 0: solve() calls C(c1+c0-2, c0-1), which has y = -1 when c0 == 0
    if(x < 0 || y < 0 || y > x) return 0;
    return fac[x] * finv[y] % MOD * finv[x-y] % MOD;
}

#define int long long

void solve(){
    int n,x,c0=0,c1=0,t=0;
    cin>>n;

    for(int i=1; i<=n; ++i){
        cin>>x;
        c0 += x == 0;
        c1 += x == 1;
    }

    for(int i=0; i<=c1; ++i){
        t = (t+i*C(c1 + c0 -i, c0)) % MOD;
    }

    cout<<(((t*(c0+1) - C(c1+c0-2, c0-1)) % MOD + MOD ) % MOD + C(c1+c0-2, c0-1)) * fac[c1] %  MOD * fac[c0] % MOD <<endl;
}

signed main(){
    fac[0] = inv[0] = inv[1] = finv[0] = finv[1] = 1;
    
    for(long long i =1; i<=1000000; ++i){
        fac[i] = (fac[i-1] * i) % MOD;
    }
    
    for(long long i = 2; i<=1000000; ++i){
        inv[i] = MOD - MOD/i * inv[MOD % i] % MOD;
    }
    
    for(long long i = 2; i<=1000000; ++i){
        finv[i] = finv[i-1] * inv[i] % MOD;
        
    }
    
    int t; cin>>t;
    while(t--){
        solve();
    }
}

In the above solution, I am not able to figure out what the inv[] and finv[] arrays represent.

inv[i] is the residue such that inv[i] * i % MOD = 1, while finv[i] is the residue such that finv[i] * fac[i] % MOD = 1. In other words, they are the modular inverses (reciprocals modulo MOD) of i and fac[i].

See Modular Inverse - Algorithms for Competitive Programming and Factorial modulo p - Algorithms for Competitive Programming