PROBLEM LINK:
Author: Ritesh Gupta
Tester: Taranpreet Singh
Editorialist: Ritesh Gupta
DIFFICULTY:
\cancel{MEDIUM} EASY-MEDIUM
PREREQUISITES:
\cancel{NTT} PATTERN, COMBINATORICS
PROBLEM:
You are given an array A of size N (1 \le N \le 10^5) whose elements satisfy |A_i| \le 1. For every x from -N to N, count the number of non-empty subsequences whose sum equals x, modulo 163\,577\,857.
QUICK EXPLANATION:
- We only care about the sum of each subsequence, so the order of the elements does not matter. The only things that matter are the counts of -1, \space 0, and 1.
- Suppose the counts of -1, \space 0, and 1 are c_{-1}, \space c_0, and c_1 respectively. For each of the three values separately, count the subsequences (using only that value) for every possible sum. Represent the result as a polynomial in which the exponent is the sum and the coefficient is the number of subsequences with that sum. The answer is then the product of these three polynomials.
EXPLANATION:
OBSERVATION:
- Zeros cannot change the sum of a subsequence. If the sequence contains c_0 zeros, then 2^{c_0} subsequences can be formed from these zeros alone (including the empty one), and all of them have sum 0.
- The coefficients must be computed modulo a prime, so we use NTT instead of FFT. The modulus given in the problem, 163\,577\,857 = 39 \cdot 2^{22} + 1, is NTT-friendly.
Let the counts of 1 and -1 in the sequence be c_1 and c_{-1}, and define two polynomials:
The first polynomial is A(x) = a_0x^0 + a_1x^1 + a_2x^2 + ... + a_{c_1}x^{c_1}, where a_i = \binom{c_1}{i} is the number of subsequences of the 1s with sum i.
Similarly, the second polynomial is B(x) = b_0x^{0} + b_1x^{-1} + b_2x^{-2} + ... + b_{c_{-1}}x^{-c_{-1}}, where b_i = \binom{c_{-1}}{i} is the number of subsequences of the -1s with sum -i.
The product P(x) = A(x) \cdot B(x) describes the sums of subsequences formed from both the 1s and the -1s. It has the form P(x) = p_{-c_{-1}}x^{-c_{-1}} + p_{(-c_{-1}+1)}x^{(-c_{-1}+1)} + ... + p_{-1}x^{-1} + p_0x^{0} + p_1x^1 + ... + p_{(c_1-1)}x^{(c_1-1)} + p_{c_1}x^{c_1}, where p_i is the number of such subsequences with sum i.
Now we compute the final answer, which takes all subsequences into account. Any subsequence of the \pm 1 elements can be combined with any of the 2^{c_0} subsets of the zeros without changing its sum. Therefore every coefficient of P(x) must be multiplied by 2^{c_0}:
p_i = 2^{c_0} * p_i
This count includes the empty subsequence, so subtract 1 from p_0. Then print the answers in the required format.
ALTERNATIVE SOLUTION:
OBSERVATION:
- If we look closely, we can compute the polynomial P(x) without NTT. Its coefficients follow a pattern, so each p_i has a closed-form formula.
Let the counts of 1 and -1 be c_1 and c_{-1}, and let c = c_1 + c_{-1}.
By Vandermonde's identity, p_i = \sum_j \binom{c_1}{i+j}\binom{c_{-1}}{j} = \binom{c}{c_1 - i}. In other words, p_i is the number of ways to choose c_1 - i items out of c. After computing P(x) this way, proceed exactly as in the solution above. See the tester's solution for details.
COMPLEXITY:
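TIME: O(N \log N) for the NTT solution. The alternative solution is O(N \log \text{mod}) as written, because it computes one modular inverse per coefficient; with precomputed inverses it is O(N).
SPACE: O(N)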
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Prime modulus for all arithmetic: 163577857 = 39 * 2^22 + 1, i.e.
// mod - 1 has a large power-of-two factor, which makes it NTT-friendly.
const int mod = 163577857;
// Generator used by init() to derive the NTT roots of unity
// (assumed to be a primitive root modulo `mod`).
const int G = 23;
// Capacity of every lookup / work array below (2^18 entries).
const int MAXN = 1 << 18;

// gpow[k] and invgpow[k] (k >= 1) hold G^((mod-1)/2^k) and its inverse,
// the roots used by ntt() for block length 2^k; filled in by init().
int gpow[30];
int invgpow[30];

// Factorials and inverse factorials modulo `mod`; filled in by init().
int fact[MAXN];
int invfact[MAXN];

// Modular inverses of 1 .. MAXN-1; filled in by init()
// (not read anywhere else in this solution).
int inv[MAXN];
int raise(int number, int exponent) {
int answer = 1;
while (exponent) {
if (exponent & 1) {
answer = answer * number % mod;
}
number = number * number % mod;
exponent >>= 1;
}
return answer;
}
// One-time precomputation of the global tables:
//   fact[k]    = k! mod `mod`                         (0 <= k < MAXN)
//   invfact[k] = (k!)^{-1} mod `mod`                  (0 <= k < MAXN)
//   inv[k]     = k^{-1} mod `mod`                     (1 <= k < MAXN)
//   gpow[lvl]  = G^((mod-1)/2^lvl), invgpow[lvl] = (G^{-1})^((mod-1)/2^lvl)
//                for lvl = 1, 2, ... while (mod-1)/2^lvl is still even
//                (lvl = 1..21 for mod - 1 = 39 * 2^22).
void init() {
    // Factorials, built forward.
    fact[0] = 1;
    for (int k = 1; k < MAXN; ++k)
        fact[k] = fact[k - 1] * k % mod;

    // Inverse factorials, built backward from a single Fermat inverse:
    // (k-1)!^{-1} = k!^{-1} * k.
    invfact[MAXN - 1] = raise(fact[MAXN - 1], mod - 2);
    for (int k = MAXN - 1; k > 0; --k)
        invfact[k - 1] = invfact[k] * k % mod;

    // Linear-time inverses: k^{-1} = -(mod / k) * (mod % k)^{-1} (mod `mod`).
    inv[1] = 1;
    for (int k = 2; k < MAXN; ++k)
        inv[k] = (mod - mod / k) * inv[mod % k] % mod;

    // Roots of unity for ntt(): halve the exponent once per level.
    const int invG = raise(G, mod - 2);
    int level = 0;
    for (int step = (mod - 1) / 2; step % 2 == 0; step /= 2) {
        ++level;
        gpow[level] = raise(G, step);
        invgpow[level] = raise(invG, step);
    }
}
// Binomial coefficient C(x, y) modulo `mod`, read from the factorial tables
// filled in by init().
// Returns 0 when y < 0 or y > x (there is no way to choose). Previously a
// negative y indexed invfact out of bounds.
// Assumes x < MAXN so that the table lookups stay in range.
int nCr(int x, int y)
{
    if (y < 0 || y > x)
        return 0;
    return fact[x] * invfact[y] % mod * invfact[x - y] % mod;
}
// In-place iterative radix-2 number-theoretic transform of a[0..n) modulo `mod`.
// sign == +1: forward transform, using the roots stored in gpow[].
// sign == -1: inverse transform, using invgpow[] and then multiplying every
//             entry by n^{-1}.
// Preconditions: init() has already run. n is a power of two with
// log2(n) <= 21, because init() fills gpow/invgpow only for levels 1..21.
// Entries of a are expected to lie in [0, mod) already, since each
// butterfly applies only a single conditional correction.
void ntt(int *a, int n, int sign) {
// Bit-reversal permutation: i always holds the bit-reversed index of j,
// and each pair is swapped exactly once (when i < j).
for (int i = n >> 1, j = 1; j < n; j++) {
if (i < j) swap(a[i], a[j]);
// Advance i to bitrev(j + 1) by "incrementing" it from the most
// significant bit downward: clear set bits (the carry) until the first
// clear bit, then set that bit.
int k = n >> 1;
while (k & i) {
i ^= k;
k >>= 1;
}
i ^= k;
}
// Butterfly stages. For block length l = 2^idx, omega = G^((mod-1)/l) is an
// l-th root of unity (primitive if G is a primitive root), or its inverse.
for (int l = 2, idx = 1; l <= n; l <<= 1, idx++) {
int omega = (sign == 1) ? gpow[idx] : invgpow[idx];
for (int i = 0; i < n; i += l) {
// value = omega^(j - i), the twiddle factor for the current butterfly.
int value = 1;
for (int j = i; j < i + (l>>1); j++) {
int u = a[j], v = a[j + (l>>1)] * value % mod;
// With u, v in [0, mod): u + v < 2*mod and u - v > -mod, so one
// conditional subtract/add brings each result back into [0, mod).
a[j] = (u + v); a[j] = (a[j] >= mod) ? a[j] - mod : a[j];
a[j + (l>>1)] = (u - v); a[j + (l>>1)] = (a[j + (l>>1)] < 0) ? a[j + (l>>1)] + mod : a[j + (l>>1)];
value = value * omega % mod;
}
}
}
// Inverse transform: scale by n^{-1}, computed as n^(mod-2)
// (Fermat's little theorem; assumes `mod` is prime).
if (sign == -1) {
const int x = raise(n, mod - 2);
for (int i = 0; i < n; i++) {
a[i] = a[i] * x % mod;
}
}
}
// Multiplies polynomial a (degree na) by polynomial b (degree nb) modulo
// `mod` and stores the product (degree na + nb) in a[0 .. na + nb].
// Both arrays are used as NTT work space. Each must have room for the padded
// transform length (the smallest power of two >= na + nb + 1), and b is
// left holding its own forward transform.
void multiply(int* a, int na, int* b, int nb) {
    const int lenA = na + 1;              // coefficient count of a
    const int lenB = nb + 1;              // coefficient count of b
    const int productLen = lenA + lenB - 1;

    int size = 1;
    while (size < productLen)
        size *= 2;

    // Zero-pad both inputs up to the transform length.
    for (int i = size - 1; i >= lenA; --i) a[i] = 0;
    for (int i = size - 1; i >= lenB; --i) b[i] = 0;

    // Pointwise product in the transformed domain, then transform back.
    ntt(a, size, +1);
    ntt(b, size, +1);
    for (int i = 0; i < size; ++i)
        a[i] = a[i] * b[i] % mod;
    ntt(a, size, -1);

    // Clear every slot past the last coefficient of the product.
    for (int i = size - 1; i >= productLen; --i) a[i] = 0;
}
// Coefficient buffers for multiply() and the output row of one test case.
int a[MAXN], b[MAXN], ans[MAXN];

// Reads T test cases, each consisting of N followed by N values.
// A value of 1 is counted as positive, 0 as zero, and anything else as -1.
// For every sum s = -N..N, prints the number of non-empty subsequences with
// that sum modulo `mod` (stored in ans[N + s]), space-separated on one line.
int32_t main() {
    // Up to 1e5 numbers per test: unsync/untie iostreams for fast bulk I/O.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    init();
    int t;
    std::cin >> t;
    while (t--)
    {
        int n;
        std::cin >> n;
        int pos = 0, neg = 0, zero = 0;
        for (int i = 0; i < n; i++)
        {
            int x;
            std::cin >> x;
            if (x == 1) pos++;
            else if (x == 0) zero++;
            else neg++;
        }

        // (1 + x)^pos * (1 + x)^neg. Since C(neg, j) = C(neg, neg - j),
        // coefficient k of this product counts the subsequences of the +-1
        // elements whose sum is k - neg.
        for (int i = 0; i <= pos; i++)
            a[i] = nCr(pos, i);
        for (int i = 0; i <= neg; i++)
            b[i] = nCr(neg, i);
        multiply(a, pos, b, neg);

        // Shift so that ans[n + s] holds the count for sum s.
        for (int i = 0; i <= 2 * n; i++)
            ans[i] = 0;
        for (int i = 0; i <= pos + neg; i++)
            ans[n - neg + i] = a[i];

        // Each +-1 subsequence can be combined with any of the 2^zero
        // subsets of zeros without changing its sum.
        const int zeroFactor = raise(2, zero);
        for (int i = 0; i <= 2 * n; i++)
        {
            ans[i] = zeroFactor * ans[i] % mod;
            if (i == n)
                ans[i] = (ans[i] - 1 + mod) % mod;  // drop the empty subsequence
            std::cout << ans[i] << " ";
        }
        // '\n' instead of endl: avoid flushing per test; cout is flushed at exit.
        std::cout << '\n';
    }
    return 0;
}
Tester's Solution
import java.util.*;
import java.io.*;
import java.text.*;
//Solution Credits: Taranpreet Singh
public class Main{
    //SOLUTION BEGIN
    // Prime modulus required by the problem: 163577857 = 39 * 2^22 + 1.
    long MOD = (long)163577857;
    void pre(){}
    /**
     * Solves one test case.
     * Reads N and then N values. A value of -1 counts as negative, 0 as zero,
     * and anything else as positive. Prints on one line, space-separated, the
     * number of non-empty subsequences with sum i modulo MOD for every
     * i = -N..N.
     * With pos ones, neg minus-ones and zero zeros, the count for sum i is
     * 2^zero * C(pos + neg, neg + i) (Vandermonde's identity), minus 1 at i = 0
     * for the empty subsequence.
     */
    void solve(int TC) throws Exception{
        int n = ni();
        int zero = 0, pos = 0, neg = 0;
        for(int i = 0; i< n; i++){
            int x = ni();
            if(x == -1)neg++;
            else if(x == 0)zero++;
            else pos++;
        }
        long F = pow(2, zero);
        int row = pos+neg;
        // prod walks along row `row` of Pascal's triangle: prod == C(row, i + neg).
        long prod = 1;
        for(int i = -n; i<= n; i++){
            long x = 0;
            if(i >= -neg && i <= pos){
                x = prod;
                // C(row, k + 1) = C(row, k) * (row - k) / (k + 1), with k = i + neg.
                prod = prod*(row-(i+neg))%MOD;
                prod = (prod*pow(i+neg+1, MOD-2))%MOD;
            }
            x = (x*F)%MOD;
            if(i == 0)x = (x+MOD-1)%MOD;
            p(x+" ");
        }
        pn("");
    }
    /** Returns a^p modulo MOD by binary exponentiation (p >= 0; a >= 0 assumed). */
    long pow(long a, long p){
        long o = 1;
        for(;p>0;p>>=1){
            if((p&1)==1)o = (o*a)%MOD;
            a = (a*a)%MOD;
        }
        return o;
    }
    //SOLUTION END
    long mod = (long)998244353, IINF = (long)1e17;
    final int MAX = (int)1e3+1, INF = (int)2e9, root = 3;
    DecimalFormat df = new DecimalFormat("0.0000000000000");
    // Math.PI: the previous literal had wrong digits past double precision.
    double PI = Math.PI, eps = 1e-8;
    static boolean multipleTC = true, memory = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        int T = (multipleTC)?ni():1;
        //Solution Credits: Taranpreet Singh
        pre();
        for(int i = 1; i<= T; i++)solve(i);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        if(memory)new Thread(null, new Runnable() {public void run(){try{new Main().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
        else new Main().run();
    }
    long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
    int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
    /** Number of set bits in n. */
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n(){return in.next();}
    String nln(){return in.nextLine();}
    int ni(){return Integer.parseInt(in.next());}
    long nl(){return Long.parseLong(in.next());}
    double nd(){return Double.parseDouble(in.next());}
    /** Whitespace-token reader over a BufferedReader. */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        /**
         * Returns the next whitespace-separated token.
         * Throws NoSuchElementException at end of input and UncheckedIOException
         * on a read error. Previously an IOException was printed and the loop
         * retried forever, and EOF caused a NullPointerException.
         */
        String next(){
            while (st == null || !st.hasMoreElements()){
                String line;
                try{
                    line = br.readLine();
                }catch (IOException e){
                    throw new UncheckedIOException(e);
                }
                if (line == null) throw new NoSuchElementException("unexpected end of input");
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }
        /** Returns the next raw line, or null at end of input. */
        String nextLine(){
            try{
                return br.readLine();
            }catch (IOException e){
                throw new UncheckedIOException(e);
            }
        }
    }
}