CODEBALLS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: yeamin_kaiser
Tester: yash_daga
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

NTT, prefix sums

PROBLEM:

There are N balls on a table. The i-th of them has color C_i.
Answer Q queries:

  • Given L and R, find the number of ways to remove a non-empty subset of balls from the table such that there are between L and R distinct colors among the remaining balls.

EXPLANATION:

Let’s first try tackling a single query.
Instead of looking at the subset of removed balls, let’s look at the subset that remains.
For convenience, let’s also compress the colors to 1, 2, 3, \ldots, N.
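Only the number of balls of each color matters, so in practice the compression step just produces one count per distinct color. The following is a minimal sketch. It assumes the colors arrive as a std::vector<int>, and the function name colorFrequencies is hypothetical. It sorts the colors and measures runs of equal values. Both reference codes below count with a std::map instead, which works equally well.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch (hypothetical helper): returns the number of balls of each distinct
// color. Sorting groups equal colors together, so each run length is one count.
std::vector<int> colorFrequencies(std::vector<int> colors) {
    std::sort(colors.begin(), colors.end());
    std::vector<int> freq;
    for (std::size_t i = 0; i < colors.size();) {
        std::size_t j = i;
        while (j < colors.size() && colors[j] == colors[i]) ++j;
        assert(j > i);  // every run contains at least one ball
        freq.push_back(static_cast<int>(j - i));
        i = j;
    }
    return freq;
}
```

For example, colors {5, 1000000000, 5, 7} give the counts {2, 1, 1}. The original color values are never needed again.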

Let f_x denote the number of balls with color x.
Then, there’s a relatively straightforward (though slow) dynamic programming approach.
Let dp_{i, j} denote the number of ways of choosing a subset of balls with colors among 1, 2, 3, \ldots, i such that there are exactly j distinct colors among them. The base case is dp_{0, 0} = 1 (the empty subset), with dp_{0, j} = 0 for j > 0.
Then, we have:

dp_{i, j} = dp_{i-1, j} + (2^{f_i}-1)\cdot dp_{i-1, j-1}

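This is because there are two choices for color i. The first is to use none of its balls, in which case the other colors must supply all j distinct colors (giving dp_{i-1, j} ways). The second is to use one of the 2^{f_i}-1 non-empty subsets of its balls, in which case the other colors must supply j-1 distinct colors (giving (2^{f_i}-1)\cdot dp_{i-1, j-1} ways).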

For a query (L, R), the answer is simply the sum (dp_{N, L} + dp_{N, L+1} + \ldots + dp_{N, R}).
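Before optimizing, it helps to have the quadratic DP as a reference, for example to cross-check a faster solution on small inputs. The following is a minimal sketch with these assumptions:

- The input has already been reduced to the list of frequencies f_i.
- Values are taken modulo 998244353, the modulus used by both reference codes.
- The names pow2 and distinctColorCounts are hypothetical.

It keeps a single rolling row and iterates j downward, so that dp[j-1] still holds the value from row i-1 when dp[j] is updated.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const std::int64_t MOD = 998244353;  // assumed modulus (same as the reference codes)

// 2^e mod MOD by binary exponentiation.
std::int64_t pow2(std::int64_t e) {
    assert(e >= 0);
    std::int64_t result = 1, base = 2;
    while (e > 0) {
        if (e & 1) result = result * base % MOD;
        base = base * base % MOD;
        e >>= 1;
    }
    return result;
}

// Quadratic reference DP (hypothetical helper): given f_1..f_K, returns dp_K,
// where dp_K[j] = number of subsets of balls with exactly j distinct colors.
std::vector<std::int64_t> distinctColorCounts(const std::vector<int>& freq) {
    const int K = static_cast<int>(freq.size());
    std::vector<std::int64_t> dp(K + 1, 0);
    dp[0] = 1;  // base case: the empty subset has 0 distinct colors
    for (int i = 1; i <= K; ++i) {
        const std::int64_t ways = (pow2(freq[i - 1]) - 1 + MOD) % MOD;  // 2^{f_i} - 1
        for (int j = i; j >= 1; --j)  // downward, so dp[j-1] is still row i-1
            dp[j] = (dp[j] + ways * dp[j - 1]) % MOD;
    }
    return dp;
}
```

With frequencies {2, 1} this returns {1, 4, 3}. A query (L, R) is then the sum of entries L through R of the returned vector. Entries beyond the number of distinct colors are 0.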


Of course, this is too slow: the dynamic programming itself takes \mathcal{O}(N^2) time.
However, it only needs to be done once: the dp table is the same for every query, and each query just asks for the sum of some subarray of dp_N.

So, all we need to do is figure out how to compute the array dp_N fast enough.
For that, we turn to polynomials (specifically, convolution).

Observe that the array dp_i is obtained from the array dp_{i-1} in a rather specific way: dp_{i, j} is a fixed linear combination of dp_{i-1, j} and dp_{i-1, j-1}, with coefficients 1 and 2^{f_i}-1 that do not depend on j.
In such a situation, the dp transition can be modeled using polynomial multiplication.

How?

Specifically, consider two arrays A (of length i) and B (of length 2), where
A = dp_{i-1} and B = [1, 2^{f_i}-1].
For convenience, we treat these arrays as 0-indexed.

Using them, we can define polynomials a and b, where

a(x) = \sum_{j=0}^{i-1} A_j x^j

and b(x) = B_0 + B_1 x (defined the same way as a, except that b has only two terms).

Consider c(x) = a(x)\cdot b(x), the product of these polynomials.
In particular, look at the coefficient of x^j in c. It is exactly A_j\cdot B_0 + A_{j-1}\cdot B_1 (taking A_{-1} = A_i = 0), since a term x^j can only arise as A_j x^j \cdot B_0 or as A_{j-1} x^{j-1} \cdot B_1 x.

By the definitions of arrays A and B, this equals dp_{i-1, j} + (2^{f_i}-1)\cdot dp_{i-1, j-1}, which is exactly dp_{i, j}.
In other words, the coefficients of the product c(x) are exactly the values of the dp_i array!
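To see the correspondence concretely, the sketch below implements both sides:

- dpStep applies the recurrence directly.
- multiplyNaive is a plain schoolbook product.

Both names are hypothetical, and the modulus 998244353 is assumed. For any row A and constant c = 2^{f_i}-1, dpStep(A, c) should equal multiplyNaive(A, {1, c}).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

const std::int64_t MOD = 998244353;  // assumed modulus

// Schoolbook product c = a * b with coefficients mod MOD: c[i + k] += a[i] * b[k].
std::vector<std::int64_t> multiplyNaive(const std::vector<std::int64_t>& a,
                                        const std::vector<std::int64_t>& b) {
    assert(!a.empty() && !b.empty());
    std::vector<std::int64_t> c(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t k = 0; k < b.size(); ++k)
            c[i + k] = (c[i + k] + a[i] * b[k]) % MOD;
    return c;
}

// One DP step taken straight from the recurrence, with A = dp_{i-1} and
// c = 2^{f_i} - 1: next[j] = A[j] + c * A[j-1], treating out-of-range A as 0.
std::vector<std::int64_t> dpStep(const std::vector<std::int64_t>& A, std::int64_t c) {
    std::vector<std::int64_t> next(A.size() + 1, 0);
    for (std::size_t j = 0; j <= A.size(); ++j) {
        if (j < A.size()) next[j] = A[j];
        if (j >= 1) next[j] = (next[j] + c * A[j - 1]) % MOD;
    }
    return next;
}
```

For instance, dpStep({5, 7, 11}, 13) and multiplyNaive({5, 7, 11}, {1, 13}) both give {5, 72, 102, 143}.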

It’s well known that two polynomials of length N can be multiplied in \mathcal{O}(N\log N) time with the FFT. Here we work modulo the NTT-friendly prime 998244353 (the modulus used in both codes below), so we use the Number Theoretic Transform (NTT) instead.
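For reference, here is a compact iterative NTT. It is a sketch written independently of the reference solutions, though it follows the same scheme as the fft/mul routines in the author's code. It assumes the modulus 998244353 = 119\cdot 2^{23}+1 with primitive root 3, so transform lengths up to 2^{23} are supported. The names power, ntt and multiply belong to this sketch only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using u64 = std::uint64_t;
const u64 MOD = 998244353;  // 119 * 2^23 + 1, so lengths up to 2^23 work
const u64 GEN = 3;          // a primitive root modulo MOD

// base^exp mod MOD by binary exponentiation.
u64 power(u64 base, u64 exp) {
    u64 result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// In-place iterative NTT; a.size() must be a power of two (at most 2^23).
void ntt(std::vector<u64>& a, bool invert) {
    const std::size_t n = a.size();
    // Reorder the input into bit-reversed index order.
    for (std::size_t i = 1, j = 0; i < n; ++i) {
        std::size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    // Butterflies: combine transforms of length len/2 into length len.
    for (std::size_t len = 2; len <= n; len <<= 1) {
        u64 wlen = power(GEN, (MOD - 1) / len);   // primitive len-th root of unity
        if (invert) wlen = power(wlen, MOD - 2);  // its inverse for the inverse transform
        for (std::size_t i = 0; i < n; i += len) {
            u64 w = 1;
            for (std::size_t j = 0; j < len / 2; ++j) {
                const u64 u = a[i + j];
                const u64 v = a[i + j + len / 2] * w % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + len / 2] = (u + MOD - v) % MOD;
                w = w * wlen % MOD;
            }
        }
    }
    if (invert) {  // the inverse transform also divides by n
        const u64 nInv = power(n, MOD - 2);
        for (u64& x : a) x = x * nInv % MOD;
    }
}

// Product of two non-empty polynomials whose coefficients lie in [0, MOD).
std::vector<u64> multiply(std::vector<u64> a, std::vector<u64> b) {
    assert(!a.empty() && !b.empty());
    const std::size_t need = a.size() + b.size() - 1;  // length of the product
    std::size_t n = 1;
    while (n < need) n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (std::size_t i = 0; i < n; ++i) a[i] = a[i] * b[i] % MOD;  // pointwise product
    ntt(a, true);
    a.resize(need);
    return a;
}
```

For example, multiply({1, 3}, {1, 1}) returns {1, 4, 3}, the coefficients of (1+3x)(1+x).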

Simply plugging this into our dp transitions gives us a complexity of \mathcal{O}(N^2 \log N), since we do it N times!
…which is somehow worse than the simple \mathcal{O}(N^2) we started with.

To optimize this further, we need to use the fact that we’re actually multiplying polynomials.
Let p_i(x) = 1 + (2^{f_i}-1) x denote the polynomial corresponding to color i.
As noted at the start, we don’t really care about the intermediate dp_{i, j} values: we only want dp_{N}.

That is, we only care about the final result, which is the product
p_1(x)\cdot p_2(x)\cdot\ldots\cdot p_N(x)
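As a small worked example, consider a made-up input (not one of the official tests): three balls where color 1 appears twice and color 2 once, so f_1 = 2 and f_2 = 1.

p_1(x)\cdot p_2(x) = (1 + 3x)(1 + x) = 1 + 4x + 3x^2

Reading off the coefficients gives dp_2 = [1, 4, 3]:

- The empty subset has 0 colors.
- 3 + 1 = 4 non-empty subsets use exactly one color.
- 3\cdot 1 = 3 subsets use both colors.

The coefficients sum to 8 = 2^3, the total number of subsets of three balls. In general, substituting x = 1 gives \prod_i 2^{f_i} = 2^N, which is a cheap sanity check for an implementation.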

This can be computed with divide-and-conquer faster than multiplying them one at a time.
Specifically, let P(L, R) denote the result of multiplying the L-th through the R-th polynomials, so that P(i, i) = p_i(x).
Let M denote the midpoint of the range [L, R].
Then, P(L, R) = P(L, M)\cdot P(M+1, R), where the multiplication is polynomial multiplication.
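The recursion can be sketched as follows. To keep the block short and self-contained, multiplication here is schoolbook, which is only a stand-in. Substitute an NTT-based product (such as mul in the author's code below) to get the bound discussed next. The names Poly, productRange and distinctColorPolynomial are hypothetical, and the modulus 998244353 is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Poly = std::vector<std::int64_t>;
const std::int64_t MOD = 998244353;  // assumed modulus

// Stand-in multiplication (schoolbook, O(|a||b|)). Swap in an NTT-based
// product to obtain the O(N log^2 N) bound.
Poly multiplySlow(const Poly& a, const Poly& b) {
    Poly c(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t k = 0; k < b.size(); ++k)
            c[i + k] = (c[i + k] + a[i] * b[k]) % MOD;
    return c;
}

// Product of polys[lo..hi] (inclusive), splitting at the midpoint so the two
// halves have (nearly) equal degree.
Poly productRange(const std::vector<Poly>& polys, std::size_t lo, std::size_t hi) {
    assert(lo <= hi && hi < polys.size());
    if (lo == hi) return polys[lo];
    const std::size_t mid = lo + (hi - lo) / 2;
    return multiplySlow(productRange(polys, lo, mid), productRange(polys, mid + 1, hi));
}

// Builds p_i(x) = 1 + (2^{f_i} - 1) x for every frequency and multiplies them all.
Poly distinctColorPolynomial(const std::vector<int>& freq) {
    if (freq.empty()) return Poly{1};
    std::vector<Poly> polys;
    for (int f : freq) {
        std::int64_t p = 1;
        for (int t = 0; t < f; ++t) p = p * 2 % MOD;  // 2^f; the f values sum to N
        polys.push_back(Poly{1, (p - 1 + MOD) % MOD});
    }
    return productRange(polys, 0, polys.size() - 1);
}
```

For example, distinctColorPolynomial({2, 1}) yields {1, 4, 3}. The author's calc function performs the same midpoint recursion, storing each partial product in coef[l]. The tester instead keeps the polynomials in a std::set ordered by size and repeatedly multiplies the two shortest.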

With NTT-based multiplication, the divide-and-conquer runs in \mathcal{O}(N\log^2 N) time.
That’s because in our case, the product of the L-th through the R-th polynomials has degree R-L+1.
So at each level of the recursion tree, the total work is \mathcal{O}(N\log N): there may be several polynomial multiplications on that level, but the sum of their degrees is bounded by N.
Splitting at the midpoint ensures that there are \mathcal{O}(\log N) levels, which gives the overall bound.
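The same argument can be written as a recurrence. For simplicity, assume N is a power of two and that an NTT product of degree m costs \mathcal{O}(m\log m):

T(N) = 2\,T(N/2) + \mathcal{O}(N\log N)

Unrolling it, depth d of the recursion performs 2^d multiplications, each producing a polynomial of degree N/2^d. Summing over depths:

\sum_{d=0}^{\log_2 N} 2^d\cdot\mathcal{O}\left(\frac{N}{2^d}\log\frac{N}{2^d}\right) \le \sum_{d=0}^{\log_2 N}\mathcal{O}(N\log N) = \mathcal{O}(N\log^2 N)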


Finally, once dp_N is obtained, answering queries is easy.
As noted at the start, each query is asking for the sum of some subarray of dp_N. This can be answered in constant time with the help of prefix sums.
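The following is a minimal sketch of the query phase. It assumes dp holds the coefficients of the final product (dp[j] for j = 0..K) reduced modulo 998244353. The names buildPrefix and answerQuery are hypothetical. Clamping the indices to the end of the array handles queries whose R exceeds the number of distinct colors. The tester's code does the same with min(r, n-1), while the author's code pads the array with zeros instead.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

const std::int64_t MOD = 998244353;  // assumed modulus

// prefix[j] = (dp[0] + ... + dp[j]) mod MOD, built once in O(N).
std::vector<std::int64_t> buildPrefix(const std::vector<std::int64_t>& dp) {
    std::vector<std::int64_t> prefix(dp.size());
    std::int64_t running = 0;
    for (std::size_t j = 0; j < dp.size(); ++j) {
        running = (running + dp[j]) % MOD;
        prefix[j] = running;
    }
    return prefix;
}

// Answer for one query (L, R) with 1 <= L <= R: dp[L] + ... + dp[R] in O(1).
// Indices past the end of dp contribute 0 (no subset has that many colors).
std::int64_t answerQuery(const std::vector<std::int64_t>& prefix, int L, int R) {
    assert(1 <= L && L <= R && !prefix.empty());
    const int last = static_cast<int>(prefix.size()) - 1;
    const std::int64_t hi = prefix[std::min(R, last)];
    const std::int64_t lo = prefix[std::min(L - 1, last)];
    return (hi - lo + MOD) % MOD;
}
```

With dp = {1, 4, 3}, the query (1, 2) returns 7 and the query (3, 4) returns 0.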

TIME COMPLEXITY:

\mathcal{O}(N\log^2 N + Q) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;

#define ll long long

const int mod = 998244353;
const int root = 15311432;
const int k = 1 << 23;
int root_1;
vector<int> rev;

ll bigmod(ll base, ll exp, ll mod){
    if(!exp) return 1;
    ll ret=bigmod(base,exp/2,mod);
    ret=(ret*ret)%mod;
    if(exp&1) ret=(ret*base)%mod;
    return ret;
}

void pre(int sz){
    root_1 = bigmod(root, mod - 2, mod);
    if (rev.size() == sz) return;
    rev.resize(sz);
    rev[0] = 0;
    int lg_n = __builtin_ctz(sz);
    for (int i = 1; i < sz; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg_n-1));
}

void fft(vector<int> &a, bool inv){
    int n = a.size();
    for (int i = 1; i < n - 1; ++i) if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int len = 2; len <= n; len <<= 1) {
        int wlen = inv ? root_1 : root;
        for (int i = len; i < k; i <<= 1) wlen = 1ll * wlen * wlen % mod;
        for (int st = 0; st < n; st += len){
            int w = 1;
            for (int j = 0; j < len / 2; j++){
                int ev = a[st + j];
                int od = 1ll * a[st + j + len / 2] * w % mod;
                a[st + j] = ev + od < mod ? ev + od : ev + od - mod;
                a[st + j + len / 2] = ev - od >= 0 ? ev - od : ev - od + mod;
                w = 1ll * w * wlen % mod;
            }
        }
    }
    if (inv){
    int n_1 = bigmod(n, mod - 2, mod);
    for (int &x : a) x = 1ll * x * n_1 % mod;
    }
}
vector<int> mul(vector<int> &a, vector<int> &b){
    int n = a.size(), m = b.size(), sz = 1;
    while (sz < n + m - 1) sz <<= 1;
    vector<int> x(sz), y(sz), z(sz);
    for (int i = 0; i < sz; ++i){
        x[i] = i < n ? a[i] : 0;
        y[i] = i < m ? b[i] : 0;
    }
    pre(sz);fft(x, 0);fft(y, 0);
    for (int i = 0; i < sz; ++i) z[i] = 1ll * x[i]* y[i] % mod;
    fft(z, 1);z.resize(n + m - 1);
    return z;
}

vector<vector<int>> coef;

void calc(int l, int r){
    if(l==r) return;
    int m=(l+r)/2;
    calc(l,m);
    calc(m+1,r);
    coef[l]=mul(coef[l],coef[m+1]);
}


int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int n;
    cin>>n;
    map<int,int> mp;
    for(int i=0;i<n;i++){
        int x;
        cin>>x;
        mp[x]++;
    }
    
    int sz=0;
    coef.assign(n,{});
    for(auto it:mp){
        coef[sz].push_back(1);
        coef[sz].push_back((bigmod(2,it.second,mod)-1+mod)%mod);
        sz++;
    }
    int sss=sz;
    if(sz>1) calc(0,sz-1);
    while(coef[0].size()<=n) coef[0].push_back(0);


    vector<ll> s(n+1,0);

    for(int i=0;i<=n;i++){
        s[i]=coef[0][i];
        if(i) s[i]=(s[i]+s[i-1])%mod;
    }


    int q;
    cin>>q;
    while(q--){
        int l,r;
        cin>>l>>r;

        ll ans=s[r];
        ans-=s[l-1];
        ans+=mod;
        ans%=mod;

        cout<<ans<<"\n";
    }
    
}
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//For modulo questions, add a check at the end, i.e. if ans<0 then ans+=mod;
//Incase of close mle change language to c++17 or c++14  
//Check ans for n=1 
#pragma GCC target ("avx2")    
#pragma GCC optimize ("O3")  
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long      
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
// #define mod 998244353ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=300005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
    {
        if(a==0)
        return 0;
        int res=1;
        a%=p;
        while(b>0)
        {
            if(b&1)
            res=(1ll*res*a)%p;
            b>>=1;
            a=(1ll*a*a)%p;
        }
        return res;
    }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}

const int mod=998244353;
// 998244353=1+7*17*2^23 : g=3
// 1004535809=1+479*2^21 : g=3
// 469762049=1+7*2^26 : g=3
// 7340033=1+7*2^20 : g=3
// For below change mult as overflow:
 // 10000093151233=1+3^3*5519*2^26 : g=5
 // 1000000523862017=1+10853*1373*2^26 : g=3
 // 1000000000949747713=1+2^29*3*73*8505229 : g=2
// For rest find primitive root using Shoup's generator algorithm
// root_pw: power of 2 >= maxn, Mod-1=k*root_pw => w = primitive^k 
template<long long Mod,long long root_pw,long long primitive>
struct NTT{
 inline long long powm(long long x,long long pw){
  x%=Mod;
  if(abs(pw)>Mod-1) pw%=(Mod-1);
  if(pw<0) pw+=Mod-1;
  ll res=1;
  while(pw){
   if(pw&1LL) res=(res*x)%Mod;
   pw>>=1;
   x=(x*x)%Mod;}
  return res;}
 inline ll inv(ll x){
     return powm(x,Mod-2); }
 ll root,root_1;
 NTT(){
  root=powm(primitive,(Mod-1)/root_pw);
  root_1=inv(root);}
 void ntt(vector<long long> &a,bool invert){
  int n=a.size();
  for(long long i=1,j=0;i<n;i++){
   long long bit=n>>1;
   for(;j&bit;bit>>=1) j^=bit;
   j^=bit;
   if(i<j) swap(a[i],a[j]);}
  for(long long len=2;len<=n;len<<=1){
   long long wlen= invert ? root_1:root;
   for(long long i=len;i<root_pw;i<<=1) wlen=wlen*wlen%Mod;
   for(long long i=0;i<n;i+=len){
    long long w=1;
    for(long long j=0;j<len/2;j++){
     long long u=a[i+j],v=a[i+j+len/2]*w%Mod;
     a[i+j]= u+v<Mod ? u+v:u+v-Mod;
     a[i+j+len/2]= u-v>=0 ? u-v:u-v+Mod;
     w=w*wlen%Mod;}}}
  if(invert){
   ll n_1=inv(n);
   for(long long &x: a) x=x*n_1%Mod;}}
 vector<long long> multiply(vector<long long> const& a,vector<ll> const& b){
  vector<long long> fa(a.begin(),a.end()),fb(b.begin(),b.end());
  int n=1;
  while(n<a.size()+b.size()) n<<=1;
  point(fa,1,n);
  point(fb,1,n);
  for(int i=0;i<n;++i) fa[i]=fa[i]*fb[i]%Mod;
  coef(fa);
  return fa;}
 void point(vector<long long> &A,bool not_pow=1,int atleast=-1){
  if(not_pow){
   if(atleast==-1){
    atleast=1;
    while(atleast<A.size()) atleast<<=1;}
   A.resize(atleast,0);}
  ntt(A,0);}
 void coef(vector<long long> &A,bool reduce=1){
  ntt(A,1);
  if(reduce) while(A.size() and A.back()==0) A.pop_back(); }
 void point_power(vector<long long> &A,long long k){
  for(long long &x: A) x=powm(x,k);}
 void coef_power(vector<long long> &A,int k){
  while(A.size() and A.back()==0) A.pop_back();
  int n=1;
  while(n<k*A.size()) n<<=1;
  point(A,1,n);
  point_power(A,k);
  coef(A);}
 vector<long long> power(vector<long long> a,ll p){
  while(a.size() and a.back()==0) a.pop_back();
  vector<long long> res;
  res.pb(1);
  while(p){
   if(p&1) res=multiply(res,a);
   a=multiply(a,a);
   p/=2;}
  return res;}};

// Input Checker

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readIntVec(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
			else readEoln();
		}
		return v;
	}

	auto readLongVec(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
			else readEoln();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

int32_t main()
{
    IOS;
    input_checker inp;
    NTT<mod,1<<20,3> ntt;
    int n = inp.readInt(1, 1e5); inp.readEoln();
    vector <int> a1=inp.readIntVec(n, 1, 1e9);
    mii mp;
    for(auto a:a1)
    {
        mp[a]++;
    }
    vi pols[n];
    set <pii> st;
    int cur=0;
    for(auto it:mp)
    {
        pols[cur]={1, power(2, it.ss, mod) - 1};
        st.insert({2, cur++});
    }
    while(st.size()>1)
    {
        pii p1=*st.begin();
        st.erase(p1);
        pii p2=*st.begin();
        st.erase(p2);
        vi temp=ntt.multiply(pols[p1.ss], pols[p2.ss]);
        pols[p1.ss].clear();
        pols[p2.ss]=temp;
        st.insert({temp.size(), p2.ss});
    }
    vi fin=pols[(*st.begin()).ss];
    int ni=n;
    n=fin.size();
    rep(i,1,n)
    {
        fin[i]=(fin[i]+fin[i-1])%mod;
    }
    int q = inp.readInt(1, 1e5); inp.readEoln();
    while(q--)
    {
        int l = inp.readInt(1, ni); inp.readSpace();
        int r = inp.readInt(l, ni); inp.readEoln();
        cout<<(fin[min(r, n-1)] - fin[min(n-1, l-1)] + mod)%mod<<"\n";
    }
    inp.readEof();
}