RECNDSUB - Editorial

PROBLEM LINK:

Practice
Contest

Author: Ritesh Gupta
Tester: Taranpreet Singh
Editorialist: Ritesh Gupta

DIFFICULTY:

\cancel{MEDIUM} EASY-MEDIUM

PREREQUISITES:

\cancel{NTT} PATTERN, COMBINATORICS

PROBLEM:

You are given an array A of size N (1 \le N \le 10^5) with |A_i| \le 1, i.e. every element is -1, 0, or 1. For every x from -N to N, count the number of non-empty subsequences whose sum equals x, modulo 163,577,857.

QUICK EXPLANATION:

  • The sum of a subsequence does not depend on the order of its elements, so the only thing that matters is how many -1s, 0s, and 1s the array contains.
  • Suppose these counts are c_{-1}, \space c_0, and c_1 respectively. For each of the three values separately, count the subsequences formed only from that value, grouped by their sum, and write these counts as a polynomial: the exponent is the sum of a subsequence and the coefficient is the number of subsequences with that sum. The answer is then the product of these three polynomials.

EXPLANATION:

OBSERVATION:

  • A zero cannot change the sum of a subsequence. If there are c_0 zeros, there are 2^{c_0} subsequences made only of zeros (including the empty one), and all of them have sum 0.
  • Since the coefficients must be computed modulo a prime, we use NTT instead of FFT. The modulus in the statement is NTT-friendly: 163577857 - 1 = 39 \cdot 2^{22}, so transforms of length up to 2^{22} are possible.
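The NTT-friendliness of the modulus can be checked directly. Below is a small C++ sketch. It is not taken from the setter's or tester's code, and `two_adicity` and `smallest_primitive_root` are hypothetical helper names. It factors MOD - 1 and searches for a primitive root using the standard test: g is a primitive root modulo a prime p iff g^((p-1)/q) != 1 for every prime q dividing p - 1.

```cpp
#include <cstdint>
#include <vector>

// Modulus from the problem statement.
const uint64_t MOD = 163577857;

// Binary exponentiation: base^exp mod MOD (products stay below 2^64).
uint64_t power_mod(uint64_t base, uint64_t exp) {
    uint64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Largest k with 2^k dividing MOD - 1; NTT lengths up to 2^k are supported.
int two_adicity() {
    uint64_t m = MOD - 1;
    int k = 0;
    while (m % 2 == 0) {
        m /= 2;
        ++k;
    }
    return k;
}

// Smallest g with g^((MOD - 1) / q) != 1 for every prime q dividing MOD - 1,
// i.e. the smallest primitive root (assuming MOD is prime).
uint64_t smallest_primitive_root() {
    std::vector<uint64_t> primes;  // distinct prime factors of MOD - 1
    uint64_t m = MOD - 1;
    for (uint64_t d = 2; d * d <= m; ++d) {
        if (m % d == 0) {
            primes.push_back(d);
            while (m % d == 0) m /= d;
        }
    }
    if (m > 1) primes.push_back(m);
    for (uint64_t g = 2; g < MOD; ++g) {
        bool is_root = true;
        for (uint64_t q : primes) {
            if (power_mod(g, (MOD - 1) / q) == 1) {
                is_root = false;
                break;
            }
        }
        if (is_root) return g;
    }
    return 0;  // not reached when MOD is prime
}
```

Here MOD - 1 = 2^{22} \cdot 3 \cdot 13, so the search only has to test three exponents per candidate.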

Let c_1 and c_{-1} be the numbers of 1s and -1s in the given sequence, and define two polynomials:

The first polynomial is A(x) = a_0x^0 + a_1x^1 + a_2x^2 + ... + a_{c_1}x^{c_1}, where a_i is the number of subsequences of the 1s with sum i, that is, a_i = \binom{c_1}{i}.

Similarly, the second polynomial is B(x) = b_0x^{0} + b_1x^{-1} + b_2x^{-2} + ... + b_{c_{-1}}x^{-c_{-1}}, where b_i is the number of subsequences of the -1s with sum -i, that is, b_i = \binom{c_{-1}}{i}.

The product P(x) = A(x) \cdot B(x) describes the sums of subsequences made of both 1s and -1s (including the empty subsequence). It has the form P(x) = p_{-c_{-1}}x^{-c_{-1}} + p_{(-c_{-1}+1)}x^{(-c_{-1}+1)} + ... + p_{-1}x^{-1} + p_0x^{0} + p_1x^1 + ... + p_{(c_1-1)}x^{(c_1-1)} + p_{c_1}x^{c_1}, where p_i is the number of such subsequences with sum i.
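To see the product concretely, here is a naive C++ sketch that multiplies A(x) and B(x) directly in O(c_1 \cdot c_{-1}) time. It is not the setter's code, the helpers `binomial_row` and `multiply_naive` are hypothetical names, and it is only meant for checking small cases. Exponents can be negative, so exponent s is stored at index s + c_{-1}.

```cpp
#include <cstdint>
#include <vector>

const int64_t MOD = 163577857;

// Row n of Pascal's triangle modulo MOD: C(n, 0), ..., C(n, n).
std::vector<int64_t> binomial_row(int n) {
    std::vector<int64_t> row(n + 1, 0);
    row[0] = 1;
    for (int i = 1; i <= n; ++i)
        for (int j = i; j >= 1; --j)  // go right to left so row[j - 1] is still the old value
            row[j] = (row[j] + row[j - 1]) % MOD;
    return row;
}

// Naive product P(x) = A(x) * B(x), where A has a_i = C(c1, i) at exponent +i
// and B has b_j = C(cm, j) at exponent -j (cm stands for c_{-1}).
// The result is offset: result[s + cm] = p_s for s in [-cm, c1].
std::vector<int64_t> multiply_naive(int c1, int cm) {
    std::vector<int64_t> a = binomial_row(c1), b = binomial_row(cm);
    std::vector<int64_t> p(c1 + cm + 1, 0);
    for (int i = 0; i <= c1; ++i)
        for (int j = 0; j <= cm; ++j)
            p[i - j + cm] = (p[i - j + cm] + a[i] * b[j]) % MOD;
    return p;
}
```

For example, multiply_naive(2, 1) gives the coefficients of x^{-1}, x^0, x^1, x^2 as 1, 3, 3, 1. The NTT in the setter's solution computes the same product in O(N \log N).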

Now we account for the zeros to get the final answer over all subsequences. Any subsequence of 1s and -1s can be combined with any of the 2^{c_0} subsequences of zeros (all of which have sum 0) without changing its sum. So every coefficient of P(x) is multiplied by 2^{c_0}:

p_i \leftarrow 2^{c_0} \cdot p_i

This count still includes the empty subsequence, which has sum 0, so subtract 1 from p_0. Finally, print p_{-N}, ..., p_N in the required format, where p_i = 0 for every i outside [-c_{-1}, c_1].
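The last two steps (scaling by 2^{c_0} and removing the empty subsequence) can be sketched as below, together with a brute-force counter over all 2^N - 1 non-empty subsequences that is handy for testing on tiny arrays. Both functions are illustrative with hypothetical names, and `finalize` assumes the offset layout p[s + c_{-1}] = p_s for the coefficients of P(x).

```cpp
#include <cstdint>
#include <vector>

const int64_t MOD = 163577857;

// Scale every p_s by 2^{c0} and subtract 1 at s = 0 (the empty subsequence).
// Input: p[s + cm] = p_s with cm = c_{-1}. Output: ans[x + n] for x in [-n, n].
std::vector<int64_t> finalize(const std::vector<int64_t>& p, int cm, int c0, int n) {
    int64_t pow2 = 1;
    for (int i = 0; i < c0; ++i) pow2 = pow2 * 2 % MOD;
    std::vector<int64_t> ans(2 * n + 1, 0);
    for (int idx = 0; idx < (int)p.size(); ++idx) {
        int s = idx - cm;  // actual sum represented by this coefficient
        ans[s + n] = p[idx] * pow2 % MOD;
    }
    ans[n] = (ans[n] - 1 + MOD) % MOD;
    return ans;
}

// Brute force for tiny n: enumerate every non-empty subsequence by bitmask.
std::vector<int64_t> brute_force(const std::vector<int>& a) {
    int n = a.size();
    std::vector<int64_t> ans(2 * n + 1, 0);
    for (int mask = 1; mask < (1 << n); ++mask) {
        int sum = 0;
        for (int i = 0; i < n; ++i)
            if ((mask >> i) & 1) sum += a[i];
        ans[sum + n] = (ans[sum + n] + 1) % MOD;
    }
    return ans;
}
```

For instance, for the array [1, -1, 0, 1], P(x) = x^{-1} + 3 + 3x + x^2, and both functions give the counts 2, 5, 6, 2 for the sums -1, 0, 1, 2.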

ALTERNATIVE SOLUTION:

OBSERVATION:

  • Looking closely at P(x), we can compute it without NTT: its coefficients follow a pattern, so each p_i has a closed-form formula.

Let c_1 and c_{-1} be the numbers of 1s and -1s in the given sequence, and let c = c_1 + c_{-1}.

Then p_i equals the number of ways to choose c_1 - i items out of c items, i.e. p_i = \binom{c}{c_1 - i}. Once P(x) is known, proceed exactly as in the solution above. See the tester's solution for details.
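A C++ sketch of the closed-form approach follows. The tester's Java code computes the same binomials incrementally and calls a modular inverse at each step. This version uses factorial tables instead, so apart from 2^{c_0} it needs only one modular exponentiation. `solve_closed_form` and `power_mod` are hypothetical names. The function returns the 2N + 1 answers for x = -N, ..., N.

```cpp
#include <cstdint>
#include <vector>

const int64_t MOD = 163577857;

// Binary exponentiation: base^exp mod MOD.
int64_t power_mod(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// ans[x + n] = 2^{c0} * C(c, c1 - x), minus 1 at x = 0, and 0 outside [-c_{-1}, c1].
std::vector<int64_t> solve_closed_form(const std::vector<int>& a) {
    int n = a.size(), c1 = 0, cm = 0, c0 = 0;
    for (int v : a) {
        if (v == 1) ++c1;
        else if (v == -1) ++cm;
        else ++c0;
    }
    int c = c1 + cm;
    // Factorials and inverse factorials up to c.
    std::vector<int64_t> fact(c + 1), inv_fact(c + 1);
    fact[0] = 1;
    for (int i = 1; i <= c; ++i) fact[i] = fact[i - 1] * i % MOD;
    inv_fact[c] = power_mod(fact[c], MOD - 2);  // Fermat's little theorem
    for (int i = c; i >= 1; --i) inv_fact[i - 1] = inv_fact[i] * i % MOD;
    int64_t pow2 = power_mod(2, c0);
    std::vector<int64_t> ans(2 * n + 1, 0);
    for (int s = -cm; s <= c1; ++s) {
        int r = c1 - s;  // always in [0, c]
        int64_t binom = fact[c] * inv_fact[r] % MOD * inv_fact[c - r] % MOD;
        ans[s + n] = binom * pow2 % MOD;
    }
    ans[n] = (ans[n] - 1 + MOD) % MOD;  // drop the empty subsequence
    return ans;
}
```

After the counting pass the work is O(N) plus one O(\log MOD) exponentiation per test case, so no NTT is needed.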

COMPLEXITY:

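TIME: O(N \log N) per test case with NTT; the closed-form approach needs only O(N) after precomputing factorials and inverse factorials
SPACE: O(N)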

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
 
#define int long long
 
using namespace std;
 
const int mod = 163577857, G = 23, MAXN = 1 << 18;
 
int gpow[30], invgpow[30];
int fact[MAXN], invfact[MAXN];
int inv[MAXN];
 
int raise(int number, int exponent) {
    int answer = 1;
    while (exponent) {
        if (exponent & 1) {
            answer = answer * number % mod;
        }
        number = number * number % mod;
        exponent >>= 1;
    }
    return answer;
}
 
void init() {
    fact[0] = 1;
    for (int i = 1; i < MAXN; i++) {
        fact[i] = fact[i - 1] * i % mod;
    }
    invfact[MAXN - 1] = raise(fact[MAXN - 1], mod - 2);
    for (int i = MAXN - 2; i >= 0; i--) {
        invfact[i] = invfact[i + 1] * (i + 1) % mod;
    }
    inv[1] = 1;
    for (int i = 2; i < MAXN; i++) {
        inv[i] = (mod - mod / i) * inv[mod % i] % mod;
    }
 
    int where = (mod - 1) / 2, invg = raise(G, mod - 2);
    int idx = 0;
    while (where % 2 == 0) {
        idx++;
        gpow[idx] = raise(G, where);
        invgpow[idx] = raise(invg, where);
        where /= 2;
    }
}
 
int nCr(int x, int y)
{
    if(y>x)
        return 0;
    int num=fact[x];
    num*=invfact[y];
    num%=mod;
    num*=invfact[x-y];
    num%=mod;
    return num;
}
 
void ntt(int *a, int n, int sign) {
    for (int i = n >> 1, j = 1; j < n; j++) {
        if (i < j) swap(a[i], a[j]);
        int k = n >> 1;
        while (k & i) {
            i ^= k;
            k >>= 1;
        }
        i ^= k;
    }
    for (int l = 2, idx = 1; l <= n; l <<= 1, idx++) {
        int omega = (sign == 1) ? gpow[idx] : invgpow[idx];
        for (int i = 0; i < n; i += l) {
            int value = 1;
            for (int j = i; j < i + (l>>1); j++) {
                int u = a[j], v = a[j + (l>>1)] * value % mod;
                a[j] = (u + v); a[j] = (a[j] >= mod) ? a[j] - mod : a[j];
                a[j + (l>>1)] = (u - v); a[j + (l>>1)] = (a[j + (l>>1)] < 0) ? a[j + (l>>1)] + mod : a[j + (l>>1)];
                value = value * omega % mod;
            }
        }
    }
    if (sign == -1) {
        const int x = raise(n, mod - 2);
        for (int i = 0; i < n; i++) {
            a[i] = a[i] * x % mod;
        }
    }
}
 
void multiply(int* a, int na, int* b, int nb) {
    na++; nb++;
    int n = 1; while (n < na + nb - 1) n <<= 1;
    for (int i = na; i < n; i++) {
        a[i] = 0;
    }
    for (int i = nb; i < n; i++) {
        b[i] = 0;
    }
 
    ntt(a, n, +1); ntt(b, n, +1);
    for (int i = 0; i < n; i++) {
        a[i] = a[i] * b[i] % mod;
    }
    ntt(a, n, -1);
    for (int i = na + nb - 1; i < n; i++) {
        a[i] = 0;
    }
}
 
int a[MAXN],b[MAXN],ans[MAXN];
 
int32_t main() {
    init();
 
    int t;
    cin >> t;
 
    while(t--)
    {
        int n,x;
        cin >> n;
 
        int pos,neg,zero;
        pos = neg = zero = 0;
 
        for(int i=1;i<=n;i++)
        {
            cin >> x;
 
            if(x == 1) pos++;
            else if(x == 0) zero++;
            else neg++;
        }
 
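        // a[i] = C(pos, i): ways to pick i of the 1s (sum +i)
        for(int i=0;i<=pos;i++)
            a[i] = nCr(pos,i);
 
        // b[i] = C(neg, i) = C(neg, neg - i): ways to pick neg - i of the -1s,
        // so exponent i here stands for sum i - neg (avoids negative exponents)
        for(int i=0;i<=neg;i++)
            b[i] = nCr(neg,i);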
 
        multiply(a, pos, b, neg);
 
        for(int i=0;i<=2*n;i++)
        	ans[i] = 0;
 
        // coefficient i of the product has sum i - neg; store it at index n + sum
        for(int i=0;i<=pos+neg;i++)
        	ans[n-neg+i] = a[i];
 
        zero = raise(2, zero);
 
        for(int i=0;i<=2*n;i++)
        {
        	ans[i] = zero * ans[i] %mod;
 
        	if(i == n)
                ans[i] = (ans[i] - 1 + mod)%mod;
 
        	cout << ans[i] << " ";
        }
        cout << endl;
    }
 
    return 0;
}
Tester's Solution
import java.util.*;
import java.io.*; 
import java.text.*;
//Solution Credits: Taranpreet Singh
public class Main{
    //SOLUTION BEGIN
    long MOD = (long)163577857;
    void pre(){}
    void solve(int TC) throws Exception{
        int n = ni();
        int zero = 0, pos = 0, neg = 0;
        for(int i = 0; i< n; i++){
            int x = ni();
            if(x == -1)neg++;
            else if(x == 0)zero++;
            else pos++;
        }
        long F = pow(2, zero);
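        int row = pos+neg;
        // prod = C(row, i+neg) = number of ways to get sum i from the 1s and -1s.
        // It starts at C(row, 0) = 1 for i = -neg and is advanced with
        // C(row, r+1) = C(row, r) * (row - r) / (r + 1).
        long prod = 1;
        for(int i = -n; i<= n; i++){
            long x = 0;
            if(i >= -neg && i <= pos){
            	x = prod;
            	prod = prod*(row-(i+neg))%MOD;
            	prod = (prod*pow(i+neg+1, MOD-2))%MOD;
            	
            }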
            x = (x*F)%MOD;
            if(i == 0)x = (x+MOD-1)%MOD;
            p(x+" ");
        }
        pn("");
    }
    long pow(long a, long p){
        long o = 1;
        for(;p>0;p>>=1){
            if((p&1)==1)o = (o*a)%MOD;
            a = (a*a)%MOD;
        }
        return o;
    }
    //SOLUTION END
    long mod = (long)998244353, IINF = (long)1e17;
    final int MAX = (int)1e3+1, INF = (int)2e9, root = 3;
    DecimalFormat df = new DecimalFormat("0.0000000000000");
    double PI = 3.1415926535897932384626433832792884197169399375105820974944, eps = 1e-8;
    static boolean multipleTC = true, memory = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        int T = (multipleTC)?ni():1;
        //Solution Credits: Taranpreet Singh
        pre();
        for(int i = 1; i<= T; i++)solve(i);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        if(memory)new Thread(null, new Runnable() {public void run(){try{new Main().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
        else new Main().run();
    }
    long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
    int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n(){return in.next();}
    String nln(){return in.nextLine();}
    int ni(){return Integer.parseInt(in.next());}
    long nl(){return Long.parseLong(in.next());}
    double nd(){return Double.parseDouble(in.next());}
 
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
 
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
 
        String next(){
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    e.printStackTrace();
                }
            }
            return st.nextToken();
        }
 
        String nextLine(){
            String str = "";
            try{    
                str = br.readLine();
            }catch (IOException e){
                e.printStackTrace();
            }   
            return str;
        }
    }
}  

Can someone explain more clearly why p_i equals the number of ways of choosing (c_1 - i) items from a bag of c items?

Suppose you have
p=count of 1;
q=count of -1;
z=count of 0;

Suppose you want to make sum k>0 (for k<0 process is similar).
Let’s see the choices we have,
1st choice -> take k 1’s and 0 -1’s number of ways to do this = pCk * qC0;
2nd choice -> take k+1 1’s and 1 -1’s number of ways to do this = [ pC(k+1) ] * [qC1]
3rd choice -> take k+2 1’s and 2 -1’s number of ways to do this = [ pC(k+2) ] * [qC2]

and so on.

This sum turns out to be (p+q)C(q+k), using the binomial expansions of (1+x)^p and (1+1/x)^q.

With each of these choices you can also take any subset of the zeros, so multiply each term by power(2,z) and take the sum modulo m.

For sum zero, the empty subsequence must not be counted, so subtract 1 in that case.

https://www.codechef.com/viewsolution/32367090
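One way to convince yourself of the (p+q)C(q+k) claim is to compare both sides numerically for small p, q and k. Below is a minimal C++ sketch with hypothetical helper names. It uses an exact (non-modular) Pascal's triangle and the convention that C(n, r) = 0 outside 0 <= r <= n.

```cpp
#include <cstdint>
#include <vector>

using Table = std::vector<std::vector<uint64_t>>;

// Exact Pascal's triangle C[n][r] for 0 <= r <= n <= maxn (no modulus, so keep maxn small).
Table pascal(int maxn) {
    Table C(maxn + 1);
    for (int n = 0; n <= maxn; ++n) {
        C[n].assign(n + 1, 1);
        for (int r = 1; r < n; ++r) C[n][r] = C[n - 1][r - 1] + C[n - 1][r];
    }
    return C;
}

// C(n, r) with the usual convention that it is 0 unless 0 <= r <= n.
uint64_t binom(const Table& C, int n, int r) {
    if (r < 0 || r > n) return 0;
    return C[n][r];
}

// Sum over j of C(p, k + j) * C(q, j): take k + j of the 1s and j of the -1s.
uint64_t sum_of_choices(const Table& C, int p, int q, int k) {
    uint64_t total = 0;
    for (int j = 0; j <= q; ++j) total += binom(C, p, k + j) * binom(C, q, j);
    return total;
}
```

The same check also holds for k <= 0, which matches the remark above that the negative case is similar.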

@zacky_901 Could you please elaborate on this? Thanks!

Another (maybe more intuitive) way to derive the answer:

Let cnt_{-1}, cnt_{0} and cnt_{1} represent the number of occurrences of -1, 0, and 1 respectively.

Now suppose you wish to calculate the number of ways to obtain a positive sum k (a similar approach works for a negative sum). Take k of the 1's; then, for every additional 1 you take from the remaining cnt_{1} - k, you must also take a -1 so that the sum stays k. Also, we can add any number of 0's since they do not affect the sum.

This can be expressed as:

ways[k] = 2^{cnt_{0}} \cdot \sum_{i \ge 0} \binom{cnt_{1}}{k + i} \binom{cnt_{-1}}{i}

\phantom{ways[k]} = 2^{cnt_{0}} \cdot \sum_{i \ge 0} \binom{cnt_{1}}{cnt_{1} - k - i} \binom{cnt_{-1}}{i}

Now, we can use Vandermonde’s identity to get:

ways[k] = 2^{cnt_{0}} \cdot \binom{cnt_{1} + cnt_{-1}}{cnt_{1} - k}
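For reference, here is the identity being used, with the substitution written out. This is a worked derivation rather than part of the original comment, and it also shows that the same closed form covers negative sums.

```latex
% Vandermonde's identity (m, n nonnegative integers; a binomial whose lower
% index is out of range is 0, so the sum can run over all integers i):
\sum_{i} \binom{m}{r - i} \binom{n}{i} = \binom{m + n}{r}

% Substituting m = cnt_{1}, n = cnt_{-1}, r = cnt_{1} - k:
\sum_{i \ge 0} \binom{cnt_{1}}{cnt_{1} - k - i} \binom{cnt_{-1}}{i}
    = \binom{cnt_{1} + cnt_{-1}}{cnt_{1} - k}

% Negative sum k = -t (t > 0): swap the roles of 1 and -1, then use
% \binom{N}{r} = \binom{N}{N - r} with N = cnt_{1} + cnt_{-1}:
\sum_{i \ge 0} \binom{cnt_{-1}}{t + i} \binom{cnt_{1}}{i}
    = \binom{N}{cnt_{-1} - t} = \binom{N}{cnt_{1} + t} = \binom{N}{cnt_{1} - k}
```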

Thanks bro. That helped. Wasn’t aware of this identity. Had come up with summation expression but wasn’t able to simplify it. Are you aware of more questions which use this identity?

I cannot recall any questions which use this identity but googling “codechef vandermonde’s identity” revealed that GMEDIAN’s editorial mentioned this identity.

Hey! I have a question. If I am not wrong, i in the summation ranges from 0 to min(cnt_{1} - k, cnt_{-1}). So won't that affect the identity? In the identity, the upper limit needs to be cnt_{1} - k.

@m0nk3ydluffy Just out of curiosity, what made you think of applying the nCr = nCn-r identity? I was stuck at that step: I wanted a closed form for the summation but could not find one, maybe because I had never used Vandermonde’s identity before. After reading the Wikipedia article, the identity feels very intuitive, especially the combinatorial proof given there. So what are your thoughts: what should we do in such cases?

By convention, nCr is zero when r < 0 or r > n, and Vandermonde’s identity holds under that convention. So you can change the range of i in that summation freely: the extra terms are zero, and the identity stays correct :slight_smile:

Oh yeah i got that thanx!

Hey! As @anon49376339 said, we generally define the binomial coefficient \binom{n}{m} to be zero when m is outside the valid range 0 \le m \le n. So we do not have to care about the upper bound in the identity (which is why I did not specify one in the summation). In fact, the identity would remain valid even if we summed over all integers from -\infty to \infty, since the binomial coefficient is defined to be zero for negative m.

Hey! I guess you mean “nCr = nC(n - r) identity”. You can read https://trans4mind.com/personal_development/mathematics/series/summingBinomialCoefficients.htm. It has some interesting tricks which might be helpful in the future.

Yeah! I got that thanx

If anyone was wondering about primitive roots for NTT, I found some to be 18 for 2^21, and 55 for 2^22
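What an NTT of length 2^k needs is a primitive 2^k-th root of unity, i.e. an element of multiplicative order exactly 2^k; that is presumably what is meant above. Below is a minimal C++ sketch for checking such a value modulo 163,577,857. `has_order_pow2` is a hypothetical helper name.

```cpp
#include <cstdint>

const uint64_t MOD = 163577857;

// Binary exponentiation: base^exp mod MOD.
uint64_t power_mod(uint64_t base, uint64_t exp) {
    uint64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// True iff w has multiplicative order exactly 2^k modulo MOD (k >= 1):
// w^(2^k) == 1 but w^(2^(k-1)) != 1.
bool has_order_pow2(uint64_t w, int k) {
    uint64_t x = w % MOD;
    for (int i = 0; i < k - 1; ++i) x = x * x % MOD;  // x = w^(2^(k-1))
    return x != 1 && x * x % MOD == 1;
}
```

For example, has_order_pow2(18, 21) and has_order_pow2(55, 22) test the two values quoted above. The same reasoning explains the setter's `init()`: if G^((mod-1)/2) = mod - 1, then G^((mod-1)/2^k) has order exactly 2^k for every k up to 21.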

@m0nk3ydluffy Yes , my bad !

Can anybody tell me why I am getting WA?
https://www.codechef.com/viewsolution/32375114
I have applied the same approach.

I used the same approach as you, but with modular inverses. Can you tell me how to speed up my solution? It passed but took around 1 second. I stored factorials up to 10^5 and the modular inverses of factorials up to 10^5. I guess the precomputation should be O(N\log(P)), where P is the prime, and finding the answer should be O(N); still, it takes 1000 ms.

How should I improve the complexity? I suppose the computation of the modular inverses is taking a lot of time.

Here is submission: https://www.codechef.com/viewsolution/32380636

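You can precompute in O(n + \log(p)).
Compute the inverse of only the largest factorial with one modular exponentiation, then walk downwards using 1/i! = (i+1) \cdot 1/(i+1)!, which needs no further exponentiations.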

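#include <vector>
using namespace std;
typedef long long ll;
const ll p = 163577857;  // the problem's modulus (prime)

ll modpower(ll base, ll e){  // base^e mod p by binary exponentiation
    ll r=1; base%=p;
    while(e>0){
        if(e&1) r=r*base%p;
        base=base*base%p; e>>=1;
    }
    return r;
}

vector<ll> fact;
vector<ll> invfact;
// fact[i] = i! mod p and invfact[i] = (i!)^{-1} mod p for 0 <= i < n
void computefactorial(int n){
    fact.resize(n);
    invfact.resize(n);
    fact[0]=1;
    for(int i=1;i<n;i++){
        fact[i]=i*fact[i-1];
        fact[i]%=p;
    }
    // one exponentiation (Fermat's little theorem), then 1/i! = (i+1) * 1/(i+1)!
    invfact[n-1]=modpower(fact[n-1],p-2);
    for(int i=n-2;i>=0;i--){
        invfact[i]=(i+1)*invfact[i+1];
        invfact[i]%=p;
    }
}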
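A related trick, used for the `inv` array in the setter's `init()` above, computes the inverses of all of 1, ..., n-1 in O(n) total with no exponentiation at all. It relies on MOD = (MOD / i) \cdot i + (MOD \bmod i), which gives i^{-1} \equiv -(MOD / i) \cdot (MOD \bmod i)^{-1}. Below is a minimal sketch, with `inverse_table` as a hypothetical name and assuming MOD is prime. A table like this could also replace the per-step `pow(..., MOD-2)` call in an incremental binomial loop.

```cpp
#include <cstdint>
#include <vector>

const int64_t MOD = 163577857;

// inv[i] = i^{-1} mod MOD for 1 <= i < n (requires n <= MOD, MOD prime).
// MOD % i < i, so inv[MOD % i] is always computed before inv[i].
std::vector<int64_t> inverse_table(int n) {
    std::vector<int64_t> inv(n, 0);
    if (n > 1) inv[1] = 1;
    for (int i = 2; i < n; ++i)
        inv[i] = (MOD - MOD / i) * inv[MOD % i] % MOD;
    return inv;
}
```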

Multiply those polynomials: (1+x)^p \cdot (1+1/x)^q = (1+x)^{p+q} / x^q.
The coefficient of x^k in this product is the coefficient of x^{q+k} in (1+x)^{p+q}, i.e. \binom{p+q}{q+k}, and it equals the summation of terms above.