PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Setter: Utkarsh Gupta
Testers: Jatin Garg, Tejas Pandey
Editorialist: Devendra Singh
DIFFICULTY:
1972
PREREQUISITES:
PROBLEM:
For an array $A$ of length $N$, let $F(A)$ denote the sum of the products of all subarrays of $A$. Formally,
$$F(A) = \sum_{L=1}^{N} \sum_{R=L}^{N} \prod_{i=L}^{R} A_i .$$
For example, let A = [1, 0, 1]; then there are 6 possible subarrays (written below as [L, R] index ranges):
- Subarray [1, 1] has product = 1
- Subarray [1, 2] has product = 0
- Subarray [1, 3] has product = 0
- Subarray [2, 2] has product = 0
- Subarray [2, 3] has product = 0
- Subarray [3, 3] has product = 1
So F(A) = 1+1 = 2.
Given a binary array A, determine the sum of F(A) over all the N! orderings of A modulo 998244353.
Note that orderings here are defined in terms of indices, not elements; which is why every array of length N has N! orderings. For example, the 3! = 6 orderings of A = [1, 0, 1] are:
- [1, 0, 1] corresponding to indices [1, 2, 3]
- [1, 1, 0] corresponding to indices [1, 3, 2]
- [0, 1, 1] corresponding to indices [2, 1, 3]
- [0, 1, 1] corresponding to indices [2, 3, 1]
- [1, 1, 0] corresponding to indices [3, 1, 2]
- [1, 0, 1] corresponding to indices [3, 2, 1]
EXPLANATION:
Since the array consists only of zeroes and ones, the product of a subarray is either 0 or 1, and it is 1 if and only if the subarray consists entirely of ones. Each such all-ones subarray contributes exactly 1 to the answer. Therefore the problem reduces to counting the all-ones subarrays, summed over all $N!$ orderings of the array $A$.
Let $C_1$ denote the number of ones in the array $A$. For each length $len$ from $1$ to $C_1$, the number of all-ones subarrays of length $len$, summed over all $N!$ orderings of $A$, can be counted as follows:
- Choose which $len$ of the $C_1$ ones (by index) form the subarray: $\binom{C_1}{len}$ ways.
- Arrange the chosen ones inside the subarray: $len!$ ways.
- Choose the starting position of the subarray in the array: $N - len + 1$ ways.
- Arrange the remaining $N - len$ elements in the other positions: $(N - len)!$ ways.
The product of these four values, $\binom{C_1}{len} \cdot len! \cdot (N - len + 1) \cdot (N - len)!$, is the total number of all-ones subarrays of length $len$ over all $N!$ orderings of $A$, i.e. their contribution to the answer. Summing this over $len = 1, \dots, C_1$ gives the final answer.
Factorials and inverse factorials can be precomputed once, so that each binomial coefficient is obtained in $O(1)$. For implementation details, please refer to the attached solutions.
TIME COMPLEXITY:
$O(N)$ per test case, after a one-time precomputation of factorials.
SOLUTION:
Setter's solution
//Utkarsh.25dec
#include <bits/stdc++.h>
// Shorthand macros used by the setter's solution.
// `mod` is the answer modulus 998244353; `vl` is a vector of ll.
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 998244353
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
// Binary exponentiation: returns a^b modulo `mod`. b must be non-negative.
ll power(ll a,ll b)
{
	ll result=1;
	ll base=a%mod;
	assert(b>=0);
	while(b)
	{
		if(b&1)
			result=result*base%mod;
		base=base*base%mod;
		b>>=1;
	}
	return result;
}
// Modular inverse via Fermat's little theorem (998244353 is prime).
ll modInverse(ll a)
{
	return power(a,mod-2);
}
// Size of the precomputed factorial tables below (indices 0 .. N-1).
const int N=500023;
// NOTE(review): vis and adj are not referenced anywhere in this solution.
bool vis[N];
vector <int> adj[N];
// Input validator: reads one integer in [l, r] terminated by exactly `endd`.
// It aborts (via assert) on malformed input: stray characters, a repeated or
// misplaced '-', leading zeros, "-0", more than 19 digits, an empty token, or
// a value outside [l, r].
long long readInt(long long l,long long r,char endd){
	long long x=0;
	int cnt=0;      // digits read so far
	int fi=-1;      // first digit, -1 until one has been read
	bool is_neg=false;
	while(true){
		// int, not char: keeps EOF distinct from data bytes and still
		// detectable on targets where plain char is unsigned.
		int g=getchar();
		if(g=='-'){
			// a single sign, and only before the first digit
			assert(fi==-1 && !is_neg);
			is_neg=true;
			continue;
		}
		if('0'<=g && g<='9'){
			if(cnt==0){
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);            // no leading zeros
			assert(fi!=0 || is_neg==false);     // no "-0"
			// Length check BEFORE accumulating, so an over-long number is
			// rejected before x*10 could overflow.
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
			x=x*10+(g-'0');
		} else if(g==endd){
			// An empty token ("" or "-") was previously accepted silently as 0.
			assert(cnt>0);
			if(is_neg){
				x= -x;
			}
			if(!(l <= x && x <= r))
			{
				cerr << l << ' ' << r << ' ' << x << '\n';
				assert(1 == 0);
			}
			return x;
		} else {
			assert(false);
		}
	}
}
// Input validator: reads characters up to (not including) `endd` and asserts
// that the resulting length lies in [l, r]. Hitting EOF first is an error.
string readString(int l,int r,char endd){
	string ret="";
	int cnt=0;
	while(true){
		// int, not char: with a char, EOF is indistinguishable from byte 0xFF
		// and never compares equal to -1 where char is unsigned (infinite loop).
		int g=getchar();
		assert(g!=EOF);
		if(g==endd){
			break;
		}
		cnt++;
		ret+=static_cast<char>(g);
	}
	assert(l<=cnt && cnt<=r);
	return ret;
}
// Terminator-specific wrappers: *Sp expects a trailing space, *Ln a newline.
long long readIntSp(long long l,long long r)
{
	const char terminator=' ';
	return readInt(l,r,terminator);
}
long long readIntLn(long long l,long long r)
{
	const char terminator='\n';
	return readInt(l,r,terminator);
}
string readStringLn(int l,int r)
{
	const char terminator='\n';
	return readString(l,r,terminator);
}
string readStringSp(int l,int r)
{
	const char terminator=' ';
	return readString(l,r,terminator);
}
int sumN=0;     // running total of N over test cases (checked in solve)
ll fact[N];     // fact[i]    = i! mod p
ll invfact[N];  // invfact[i] = (i!)^{-1} mod p
ll inv[N];      // inv[i]     = i^{-1} mod p, for i >= 1
// Fills fact / inv / invfact for every index below N in a single linear pass.
void factorialsComputation()
{
	fact[0]=fact[1]=1;
	inv[0]=inv[1]=1;
	invfact[0]=invfact[1]=1;
	for(int k=2;k<N;++k)
	{
		// k^{-1} = -(p / k) * (p mod k)^{-1}  (mod p), and p mod k < k is already known.
		inv[k]=(mod-mod/k)*inv[mod%k]%mod;
		fact[k]=fact[k-1]*k%mod;
		invfact[k]=inv[k]*invfact[k-1]%mod;
	}
}
// Binomial coefficient C(n, r) mod p from the precomputed tables.
// No bounds checking: the caller must ensure 0 <= r <= n < N.
ll ncr(ll n,ll r)
{
	const ll partial=fact[n]*invfact[r]%mod;
	return partial*invfact[n-r]%mod;
}
void solve()
{
int n=readInt(1,100000,'\n');
sumN+=n;
assert(sumN<=200000);
int A[n+1]={0};
for(int i=1;i<=n;i++)
{
if(i==n)
A[i]=readInt(0,1,'\n');
else
A[i]=readInt(0,1,' ');
}
ll ans=0;
ll cnt0=0,cnt1=0;
for(int i=1;i<=n;i++)
{
if(A[i]==1)
cnt1++;
else
cnt0++;
}
ll tmp[n+1]={0};
for(int len=1;len<=cnt1;len++)
{
tmp[len]=ncr(cnt1,len)*fact[len];
tmp[len]%=mod;
tmp[len]*=fact[n-len];
tmp[len]%=mod;
ans+=(n-len+1)*tmp[len];
ans%=mod;
}
cout<<ans<<'\n';
}
int main()
{
	// Local runs use input.txt / output.txt; presumably the judge defines ONLINE_JUDGE.
	#ifndef ONLINE_JUDGE
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
	#endif
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);
	int T=readInt(1,1000,'\n');
	factorialsComputation();
	for(int tc=0;tc<T;tc++)
		solve();
	// Validator: the whole input must have been consumed.
	assert(getchar()==-1);
	cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester-1's Solution
// Jai Shree Ram
#include<bits/stdc++.h>
using namespace std;
// Shorthand macros for tester-1's solution.
// NOTE: from here on `int` means `long long` (hence `signed main` below), and
// the identifiers x / y expand to first / second, so any later variable named
// x or y is silently (but consistently) renamed.
// `endl` is a plain "\n" (no flush). `hell` (1e9+7) is not used here; answers
// are taken modulo MOD.
#define rep(i,a,n) for(int i=a;i<n;i++)
#define ll long long
#define int long long
#define pb push_back
#define all(v) v.begin(),v.end()
#define endl "\n"
#define x first
#define y second
#define gcd(a,b) __gcd(a,b)
#define mem1(a) memset(a,-1,sizeof(a))
#define mem0(a) memset(a,0,sizeof(a))
#define sz(a) (int)a.size()
#define pii pair<int,int>
#define hell 1000000007
#define elasped_time 1.0 * clock() / CLOCKS_PER_SEC
// Stream helpers so a std::pair is read/written as "first second".
template<typename T1,typename T2>
istream& operator>>(istream& in,pair<T1,T2> &a)
{
    return in >> a.first >> a.second;
}
template<typename T1,typename T2>
ostream& operator<<(ostream& out,pair<T1,T2> a)
{
    return out << a.first << " " << a.second;
}
// chmax: raise a to b if b is larger; returns the (possibly updated) a by value.
template<typename T,typename T1>
T maxs(T &a,T1 b)
{
    if(b > a)
        a = b;
    return a;
}
// chmin: lower a to b if b is smaller; returns the (possibly updated) a by value.
template<typename T,typename T1>
T mins(T &a,T1 b)
{
    if(b < a)
        a = b;
    return a;
}
// -------------------- Input Checker Start --------------------
// Input validator: reads one integer in [l, r] terminated by exactly `endd`.
// It aborts (via assert) on malformed input: stray characters, a repeated or
// misplaced '-', leading zeros, "-0", more than 19 digits, an empty token, or
// a value outside [l, r].
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;   // digits read so far / first digit (-1 until seen)
    bool is_neg = false;
    while(true)
    {
        // int, not char: keeps EOF distinct from data bytes and still
        // detectable on targets where plain char is unsigned.
        int g = getchar();
        if(g == '-')
        {
            // a single sign, and only before the first digit
            assert(fi == -1 && !is_neg);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);            // no leading zeros
            assert(fi != 0 || is_neg == false);     // no "-0"
            // Length check BEFORE accumulating, so an over-long number is
            // rejected before x * 10 could overflow.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + (g - '0');
        }
        else if(g == endd)
        {
            // An empty token ("" or "-") was previously accepted silently as 0.
            assert(cnt > 0);
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
// Input validator: reads characters up to (not including) `endd` and asserts
// that the resulting length lies in [l, r]. Hitting EOF first is an error.
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        // int, not char: with a char, EOF is indistinguishable from byte 0xFF
        // and never compares equal to -1 where char is unsigned (infinite loop).
        int g = getchar();
        assert(g != EOF);
        if(g == endd)
            break;
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
// Terminator-specific wrappers: *Sp expects a trailing space, *Ln a newline.
long long readIntSp(long long l, long long r)
{
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r)
{
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r)
{
    return readString(l, r, '\n');
}
string readStringSp(int l, int r)
{
    return readString(l, r, ' ');
}
// Validator: the input must end right after the last expected token.
void readEOF()
{
    const bool at_end = (getchar() == EOF);
    assert(at_end);
}
// Reads n integers in [l, r] on one line: space-separated, the last one
// newline-terminated. Returns them in input order.
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a;
    if(n <= 0)
        return a;   // nothing to read; the original wrote a[n - 1] out of bounds here
    a.resize(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
// -------------------- Input Checker End --------------------
long long sum_n = 0;        // total of n over all test cases; main() asserts sum_n <= 2e5
const int MOD = 998244353;
// Integer modulo MOD, always stored in canonical form 0 <= val < MOD.
struct mod_int {
    int val;
    // Reduces any long long (including negatives) into [0, MOD).
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
        if (v >= MOD)
            v %= MOD;
        val = v;
    }
    // Inverse of a modulo m by the extended Euclidean algorithm (a, m coprime).
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
        return x < 0 ? x + m : x;
    }
    explicit operator int() const {
        return val;
    }
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
    // 64-bit x modulo m.
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
#if !defined(_WIN32) || defined(_WIN64)
        return x % m;
#else
        // 32-bit Windows only. This x86 `divl` asm used to follow an unconditional
        // return and was therefore still compiled (and rejected by the assembler)
        // on non-x86 targets; the #else keeps it out of those builds.
        unsigned x_high = x >> 32, x_low = (unsigned) x;
        unsigned quot, rem;
        asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
        return rem;
#endif
    }
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
    mod_int inv() const {
        return mod_inv(val);
    }
    // Binary exponentiation; p must be non-negative.
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
        while (p > 0) {
            if (p & 1)
                result *= a;
            a *= a;
            p >>= 1;
        }
        return result;
    }
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    // Reads through the constructor so negative or >= MOD input is reduced;
    // previously the raw value was stored in val, breaking the invariant.
    // On a failed read, m is left unchanged.
    friend istream& operator >> (istream &stream, mod_int &m) {
        long long v;
        if (stream >> v)
            m = mod_int(v);
        return stream;
    }
};
// Enables the init() call guarded by #ifdef NCR in main().
#define NCR
// Table size: entries 0 .. N-1 are precomputed.
const int N = 1e5 + 5;
// fact[i] = i!. Despite its name, inv[i] holds the inverse FACTORIAL (i!)^{-1}
// (see init()), which is what C() needs.
mod_int fact[N],inv[N];
// Precomputes fact[i] = i! and inv[i] = (i!)^{-1} for 0 <= i < n.
// The parameter n was previously ignored (the loops always ran up to N); it
// is now honoured, clamped to the table size. The default is unchanged.
void init(int n=N){
    if(n > N) n = N;
    if(n < 1) n = 1;
    fact[0] = 1;
    for(int i = 1; i < n; i++)
        fact[i] = fact[i - 1] * i;
    // One modular inversion, then walk down using (i-1)!^{-1} = i!^{-1} * i,
    // instead of a separate inversion for every entry.
    inv[n - 1] = fact[n - 1].inv();
    for(int i = n - 1; i > 0; i--)
        inv[i - 1] = inv[i] * i;
}
// Binomial coefficient C(n, r); 0 when r lies outside [0, n].
// Uses fact[] and the inverse-factorial table inv[].
mod_int C(int n,int r){
    if(r < 0 || n < r)
        return 0;
    return fact[n] * inv[r] * inv[n - r];
}
// (len!)*(n - len)!
int solve(){
int n = readIntLn(1,1e5);
auto a = readVectorInt(n,0,1);
int cnt = count(all(a),1);
// C(n - 1,cnt - 1) + C(n - 2,cnt - 2) ....
vector<mod_int> pref(cnt + 1);
mod_int ans = 0;
for(int i = 1; i <= cnt; i++){
pref[i] = pref[i - 1] + C(cnt,i)*fact[i]*fact[n - i];
ans += pref[i];
}
ans += (n - cnt)*pref[cnt];
cout << ans << endl;
return 0;
}
signed main(){
    // Untied, unsynchronised streams for output; input goes through the getchar()-based checkers.
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
#ifdef SIEVE
    sieve();
#endif
#ifdef NCR
    init();
#endif
    int t = readIntLn(1,1000);
    for(int tc = 0; tc < t; tc++)
        solve();
    // Constraint: the total n over all test cases.
    assert(sum_n <= 2e5);
    return 0;
}
Tester-2's Solution
#include <bits/stdc++.h>
// Shorthand macros for tester-2's solution. `vl` refers to `ll`, which is
// #defined further down; that is fine because macros expand at their point of use.
#define pb push_back
#define mp make_pair
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
// Input validator: reads one integer in [l, r] terminated by exactly `endd`.
// It aborts (via assert) on malformed input: stray characters, a repeated or
// misplaced '-', leading zeros, "-0", more than 19 digits, an empty token, or
// a value outside [l, r].
long long readInt(long long l,long long r,char endd){
	long long x=0;
	int cnt=0;      // digits read so far
	int fi=-1;      // first digit, -1 until one has been read
	bool is_neg=false;
	while(true){
		// int, not char: keeps EOF distinct from data bytes and still
		// detectable on targets where plain char is unsigned.
		int g=getchar();
		if(g=='-'){
			// a single sign, and only before the first digit
			assert(fi==-1 && !is_neg);
			is_neg=true;
			continue;
		}
		if('0'<=g && g<='9'){
			if(cnt==0){
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);            // no leading zeros
			assert(fi!=0 || is_neg==false);     // no "-0"
			// Length check BEFORE accumulating, so an over-long number is
			// rejected before x*10 could overflow.
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
			x=x*10+(g-'0');
		} else if(g==endd){
			// An empty token ("" or "-") was previously accepted silently as 0.
			assert(cnt>0);
			if(is_neg){
				x= -x;
			}
			if(!(l <= x && x <= r))
			{
				cerr << l << ' ' << r << ' ' << x << '\n';
				assert(1 == 0);
			}
			return x;
		} else {
			assert(false);
		}
	}
}
// Input validator: reads characters up to (not including) `endd` and asserts
// that the resulting length lies in [l, r]. Hitting EOF first is an error.
string readString(int l,int r,char endd){
	string ret="";
	int cnt=0;
	while(true){
		// int, not char: with a char, EOF is indistinguishable from byte 0xFF
		// and never compares equal to -1 where char is unsigned (infinite loop).
		int g=getchar();
		assert(g!=EOF);
		if(g==endd){
			break;
		}
		cnt++;
		ret+=static_cast<char>(g);
	}
	assert(l<=cnt && cnt<=r);
	return ret;
}
// Terminator-specific wrappers: *Sp expects a trailing space, *Ln a newline.
long long readIntSp(long long l,long long r){ return readInt(l,r,' '); }
long long readIntLn(long long l,long long r){ return readInt(l,r,'\n'); }
string readStringLn(int l,int r){ return readString(l,r,'\n'); }
string readStringSp(int l,int r){ return readString(l,r,' '); }
// Input limits enforced by the validator in solve() / main().
const int MAXT = 1000;     // max number of test cases
const int MAXN = 100000;   // max n per test case
const int MAXA = 1;        // array values lie in [0, MAXA]
const int SUMN = 200000;   // max total n over all test cases
int sumN = 0;              // running total of n, checked against SUMN in solve()
#define ll long long int
#define mod 998244353
// Size of the factorial tables (indices 0 .. N-1).
#define N 200007
// a^b modulo `mod` by binary exponentiation (not called by this solution).
// b must be non-negative: a negative b used to loop forever, because b >>= 1
// never reaches 0. The base is reduced first so that a * a cannot overflow
// for large or negative a.
ll mpow(ll a, ll b) {
    assert(b >= 0);
    a %= mod;
    if (a < 0) a += mod;
    ll res = 1;
    while(b) {
        if(b&1) res *= a, res %= mod;
        a *= a;
        a %= mod;
        b >>= 1;
    }
    return res;
}
ll fact[N];     // fact[i]    = i! mod p
ll invfact[N];  // invfact[i] = (i!)^{-1} mod p
ll inv[N];      // inv[i]     = i^{-1} mod p, for i >= 1
// Linear-time precomputation of the factorial / inverse tables.
void pre() {
    fact[0] = fact[1] = 1;
    inv[0] = inv[1] = 1;
    invfact[0] = invfact[1] = 1;
    for (int k = 2; k < N; ++k) {
        // k^{-1} = -(p / k) * (p mod k)^{-1}  (mod p)
        inv[k] = (mod - mod / k) * inv[mod % k] % mod;
        fact[k] = fact[k - 1] * k % mod;
        invfact[k] = inv[k] * invfact[k - 1] % mod;
    }
}
// C(n, r) mod p as n! * (r!)^{-1} * ((n-r)!)^{-1}; the caller must ensure 0 <= r <= n < N.
ll comb(ll n,ll r) {
    const ll head = fact[n] * invfact[r] % mod;
    return head * invfact[n - r] % mod;
}
// Reads one test case and prints the sum of F over all n! orderings, modulo `mod`.
// Every all-ones subarray lies inside exactly one maximal run of ones, and a
// maximal run of length i contains T(i) = i*(i+1)/2 of them, so the answer is
// sum over i of (number of (ordering, maximal run of length i) pairs) * T(i).
void solve()
{
    long long int n = readInt(1, MAXN, '\n');
    sumN += n;
    assert(sumN <= SUMN);
    // std::vector instead of a variable-length array (a compiler extension).
    vector<int> a(n);
    for(int i = 0; i< n - 1; i++) a[i] = readInt(0, MAXA, ' ');
    a[n - 1] = readInt(0, MAXA, '\n');
    // c[0] = number of zeros, c[1] = number of ones (values are validated to be 0/1,
    // so c[0] + c[1] == n).
    int c[2] = {0, 0};
    for(int i = 0; i < n; i++) c[a[i]]++;
    if(c[0] < 2) {
        if(c[0]) {
            // Exactly one zero, with i ones before it and c[1]-i after it: the two runs
            // contribute T(i) + T(c[1]-i). For each i there are C(c1,i)*i!*(c1-i)!
            // orderings (the zero's position is fixed by i).
            ll ans = 0;
            for(ll i = 0; i <= c[1]; i++) {
                ll x = (((i*(i + 1))/2)%mod + (((c[1] - i)*((c[1] - i) + 1))/2)%mod)%mod;
                ll val = (comb(c[1], i)*fact[i])%mod;
                val *= fact[c[1] - i];
                val %= mod;
                val *= x;
                val %= mod;
                ans += val;
                ans %= mod;
            }
            cout << ans << "\n";
        }
        // No zeros: every ordering is all ones, F = n(n+1)/2, and there are n! orderings.
        else cout << (((n*(n + 1))/2)%mod*fact[n])%mod << "\n";
        return;
    }
    // At least two zeros.
    ll ans = 0;
    for(ll i = 1; i <= c[1]; i++) {
        ll val = (i*(i + 1)/2)%mod;
        // Interior run "0 1^i 0": ordered ones C(c1,i)*i!, an ordered pair of bounding
        // zeros C(c0,2)*2, then the glued block plus the other n-i-2 elements permute
        // in (n-i-1)! = (n-2-i)! * (n-1-i) ways.
        ll grps = (comb(c[1], i)*fact[i])%mod;
        grps *= (comb(c[0], 2)*2)%mod;
        grps %= mod;
        grps *= fact[n - 2 - i];
        grps %= mod;
        grps *= (n - 1 - i);
        grps %= mod;
        grps *= val;
        grps %= mod;
        ans += grps;
        ans %= mod;
        // Run at one end of the array ("1^i 0" at the start): ordered ones, c0 choices
        // for the bounding zero, and (n-1-i)! arrangements of the remaining elements.
        // By symmetry the same count holds at the other end, hence it is added twice.
        ll g2 = fact[n - 1 - i];
        g2 *= (comb(c[1], i)*fact[i])%mod;
        g2 %= mod;
        g2 *= c[0];
        g2 %= mod;
        g2 *= val;
        g2 %= mod;
        ans += g2;
        ans %= mod;
        ans += g2;
        ans %= mod;
    }
    cout << ans << "\n";
}
int main()
{
	// Tables first, then stream setup, then the test cases.
	pre();
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);
	int T=readInt(1,MAXT,'\n');
	for(int tc=0;tc<T;++tc)
		solve();
	// Validator: the whole input must have been consumed.
	assert(getchar()==-1);
	cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist's Solution
#include "bits/stdc++.h"
using namespace std;
// Shorthand macros for the editorialist's solution. F and S expand to
// first / second, so those identifiers cannot be used for anything else below.
#define ll long long
#define pb push_back
#define all(_obj) _obj.begin(), _obj.end()
#define F first
#define S second
#define pll pair<ll, ll>
#define vll vector<ll>
ll INF = 1e18;  // not used in this solution
// Table size and answer modulus.
const int N = 2e5 + 11, mod = 998244353;
// ll-typed max/min helpers.
ll max(ll a, ll b) { return (b < a) ? a : b; }
ll min(ll a, ll b) { return (b < a) ? b : a; }
// Seeded random generator; not used in this solution.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// factorial[i] = i!, inverse_factorial[i] = (i!)^{-1}, NumInverse[i] = i^{-1} (mod p), filled in main().
ll factorial[N], inverse_factorial[N], NumInverse[N];
// C(n, k) mod p from the precomputed tables; requires 0 <= k <= n < N.
long long binomial_coefficient(int n, int k)
{
    const long long top = factorial[n] * inverse_factorial[k] % mod;
    return top * inverse_factorial[n - k] % mod;
}
void sol(void)
{
ll ans = 0;
int n, cnt1 = 0;
cin >> n;
vll v(n);
for (int i = 0; i < n; i++)
cin >> v[i], cnt1 += v[i];
for (int i = 1; i <= cnt1; i++)
{
ans += binomial_coefficient(cnt1, i) * factorial[i] % mod * (n - i + 1) % mod * factorial[n - i];
ans %= mod;
}
cout << ans << '\n';
return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);
    // Precompute i!, i^{-1} and (i!)^{-1} modulo `mod` for every i < N.
    factorial[0] = factorial[1] = 1;
    NumInverse[0] = NumInverse[1] = 1;
    inverse_factorial[0] = inverse_factorial[1] = 1;
    for (int k = 2; k < N; ++k)
    {
        // k^{-1} = -(p / k) * (p mod k)^{-1}  (mod p)
        NumInverse[k] = (mod - mod / k) * NumInverse[mod % k] % mod;
        factorial[k] = factorial[k - 1] * k % mod;
        inverse_factorial[k] = inverse_factorial[k - 1] * NumInverse[k] % mod;
    }
    int test = 1;
    cin >> test;
    while (test--)
        sol();
}