GOOD_NESS - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: shubham_grg
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

SOS DP

PROBLEM:

You have two arrays A and B of length N.
For a fixed parameter X, a set of indices is said to be X-good if the bitwise OR of the corresponding elements of A doesn’t exceed X.
The goodness of an X-good set of indices is the sum of the corresponding values in B.

For each X from 1 to K, find the sum of goodness across all possible X-good subsets.

EXPLANATION:

Let’s redefine things to be a bit stricter: we say a subset of indices is X-good if the bitwise OR of the corresponding elements in A equals X.
If we’re able to compute the sum of goodness for all X under this definition, the original problem’s answer is obtained by just taking prefix sums.
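As a sanity check for this two-step plan, here is a minimal brute-force sketch in Python. The name `brute_force` and the sample arrays are made up for illustration; it skips the modulus, assumes every A_i is at most K, and is only usable for tiny N.

```python
from itertools import combinations

def brute_force(a, b, k):
    """Slow reference: enumerate every non-empty subset of indices."""
    n = len(a)
    # exact[x] = sum of goodness over subsets whose OR equals x (strict definition)
    exact = [0] * (k + 1)
    for size in range(1, n + 1):
        for subset in combinations(range(n), size):
            or_value = 0
            for i in subset:
                or_value |= a[i]
            if or_value <= k:
                exact[or_value] += sum(b[i] for i in subset)
    # Prefix sums turn "OR equals x" into "OR does not exceed x".
    answers = []
    running = 0
    for x in range(k + 1):
        running += exact[x]
        if x >= 1:
            answers.append(running)
    return answers

print(brute_force([1, 2, 3], [10, 20, 30], 3))  # [10, 30, 240]
```

In this tiny instance the only subsets with OR equal to 1 or 2 are the singletons, and everything else has OR equal to 3.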

Now, we fix X and see what we get.
For a set \{i_1, i_2, \ldots, i_k\} of indices to be X-good, every A_{i_j} must be a submask of X, since any bit set in some A_{i_j} is also set in the OR.

So, let’s take every index i such that A_i is a submask of X.
Any subset of these indices will give something that’s Y-good for some submask Y of X.
This observation gives us a means of tackling the task: we can compute the sum of goodness across all subsets of these indices, then later subtract out the sums corresponding to strict submasks of X.

We now have two things to do quickly: compute the sum of goodness across all subsets of the chosen indices, and, once we have those, subtract the contributions of strict submasks.


To compute the sum of goodness across all subsets, we use the following observation: given a (multi-)set S = \{x_1, x_2, \ldots, x_N\}, the sum of sums of all its sub-(multi-)sets is given simply by

2^{N-1} \cdot \left(x_1 + x_2 + \ldots + x_N \right)

This is because there are exactly 2^{N-1} subsets containing each x_i.
(This logic fails for the N = 0 case, but the formula works out there nonetheless since the sum of elements is 0.)
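The identity is easy to check numerically. The following sketch compares a direct enumeration against the closed form on a small multiset; the function names and the sample values are illustrative only.

```python
from itertools import combinations

def sum_of_subset_sums(values):
    """Add up sum(chosen) over every sub-multiset, chosen by position."""
    total = 0
    for size in range(len(values) + 1):
        for chosen in combinations(values, size):
            total += sum(chosen)
    return total

def closed_form(values):
    """2^(N-1) times the sum of the elements, with the empty case handled separately."""
    if not values:
        return 0
    return 2 ** (len(values) - 1) * sum(values)

values = [3, 1, 4, 1]
print(sum_of_subset_sums(values), closed_form(values))  # 72 72
```

Note that `combinations` picks by position, so the two equal elements are treated as distinct, which is exactly the multiset behaviour we want.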

So, suppose we’re able to compute, for each X, the values s_X and c_X, where:

  • c_X is the number of indices i such that A_i is a submask of X.
  • s_X is the sum of B_i across all indices i such that A_i is a submask of X.

The value we’re looking for then is exactly s_X \cdot 2^{c_X - 1}.
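Straight from these definitions, the tables can be filled in \mathcal{O}(NK) time. The sketch below does exactly that and is meant only as a slow reference; `naive_tables` is a hypothetical helper name, and the modulus 10^9 + 7 is the one used by the solutions in the CODE section.

```python
MOD = 10**9 + 7

def naive_tables(a, b, k):
    """Compute c_X, s_X and s_X * 2^(c_X - 1) for X = 0..k by definition."""
    c = [0] * (k + 1)
    s = [0] * (k + 1)
    for x in range(k + 1):
        for ai, bi in zip(a, b):
            if (ai & x) == ai:  # ai is a submask of x
                c[x] += 1
                s[x] += bi
    # When c[x] is 0 there are no usable indices, so the value is 0.
    t = [s[x] * pow(2, c[x] - 1, MOD) % MOD if c[x] > 0 else 0
         for x in range(k + 1)]
    return c, s, t

c, s, t = naive_tables([1, 2, 3], [10, 20, 30], 3)
print(c)  # [0, 1, 1, 3]
print(s)  # [0, 10, 20, 60]
print(t)  # [0, 10, 20, 240]
```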

Computing c_X and s_X for every X from 0 to K is a classical problem, and can be done in \mathcal{O}(K\log K) time with the help of SOS DP.
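For reference, a minimal sketch of the sum-over-submasks transform looks like this. The helper name `subset_sums` and the sample data are assumptions for illustration; the table size is padded to a power of two, as the tester's and editorialist's solutions do.

```python
def subset_sums(values, bits):
    """Return f where f[x] is the sum of values[y] over all submasks y of x."""
    f = values[:]
    for bit in range(bits):
        for mask in range(1 << bits):
            if mask >> bit & 1:
                # Fold in the masks that differ from `mask` only by this bit.
                f[mask] += f[mask ^ (1 << bit)]
    return f

a = [1, 2, 3]
b = [10, 20, 30]
bits = 2  # enough bits to cover every value up to K = 3

freq = [0] * (1 << bits)
weight = [0] * (1 << bits)
for ai, bi in zip(a, b):
    freq[ai] += 1
    weight[ai] += bi

print(subset_sums(freq, bits))    # c_X: [0, 1, 1, 3]
print(subset_sums(weight, bits))  # s_X: [0, 10, 20, 60]
```

Each of the `bits` passes touches every mask once, which is where the K\log K factor comes from.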


Now that we know c_X and s_X, define T_X := s_X \cdot 2^{c_X - 1} (with T_X = 0 when c_X = 0) to be the sum of goodness of all subsets that are Y-good for some submask Y of X.

Let \text{ans}_X denote the sum of goodness of all subsets that are X-good.
We know all the T_X values, and we want to get \text{ans}_X values from them.

This can also be done using SOS DP!
Notice that T_X is just the sum of \text{ans}_Y over all submasks Y of X, so to get the \text{ans} array, all you need to do is reverse the process.
(One rather nice way of looking at SOS DP is as multidimensional prefix sums, which lets you derive modifications to it - like working with supersets rather than subsets or inverting the process - quite easily. See this blog.)
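Reversing the process amounts to running the same loops with a subtraction in place of the addition. The sketch below shows the forward and inverse transforms side by side; the function names are hypothetical, and the T values are those of the tiny instance A = [1, 2, 3], B = [10, 20, 30], K = 3.

```python
MOD = 10**9 + 7

def subset_sums(values, bits):
    """Forward transform: f[x] = sum of values[y] over submasks y of x."""
    f = values[:]
    for bit in range(bits):
        for mask in range(1 << bits):
            if mask >> bit & 1:
                f[mask] += f[mask ^ (1 << bit)]
    return f

def inverse_subset_sums(f, bits):
    """Inverse transform: recover values from their sum-over-submasks table."""
    g = f[:]
    for bit in range(bits):
        for mask in range(1 << bits):
            if mask >> bit & 1:
                g[mask] -= g[mask ^ (1 << bit)]
    return g

t = [0, 10, 20, 240]  # T_X for X = 0..3 in the tiny instance
exact = [v % MOD for v in inverse_subset_sums(t, 2)]
print(exact)  # [0, 10, 20, 210]

# Prefix sums over X = 0..K, reported for X = 1..K.
answers, running = [], 0
for x, value in enumerate(exact):
    running = (running + value) % MOD
    if x >= 1:
        answers.append(running)
print(answers)  # [10, 30, 240]
```

Because each pass handles one bit independently, the order of the passes does not matter for either direction.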

Finally, as noted at the start, take the prefix sums of \text{ans} to get the actual answer.

TIME COMPLEXITY:

\mathcal{O}(N + K\log K) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;

typedef long long int ll;
 
#define int                 ll
#define fast                ios::sync_with_stdio(0),cin.tie(0), cout.tie(0);
#define endl                "\n"

template<typename T> istream& operator>>(istream& is,  vector<T>  &v){ for(auto& i : v) is >> i; return is;}
template<typename T> ostream& operator<<(ostream& os,  vector<T>  v){ for(auto& i : v) os << i << ' '; return os;}

const int MOD= 1e9+7, inf=INT_MAX, inff=INT_MIN;
const int N=(1e6)+5;
ll expo(ll a, ll b)   {ll res = 1; a%=MOD; while (b > 0) {if (b & 1)res = (res * a) % MOD; a = (a * a) % MOD; b = b >> 1;} return res;}

// 21 columns: the loops below write temp[mask][i+1] for i up to 19, i.e. index 20.
vector<vector<int>>temp(N, vector<int>(21));
vector<int>cnt(N), sum(N), exact(N), ans(N);

void Solve()
{
    int n, k; cin>>n>>k;
    vector<int> a(n), b(n); cin>>a>>b;

    for(int i=0; i<=k; i++) 
    {
        cnt[i]=sum[i]=exact[i]=ans[i]=0;
        for(int j=0; j<20; j++) temp[i][j]=0;
    }

    for(int i=0; i<n; i++)
    {
        cnt[a[i]]++;
        sum[a[i]]+=b[i];
    }

    for(int i=0; i<20; i++)
    {
        for(int mask=0; mask<=k; mask++)
        {
            if((mask>>i)&1)
            {
                cnt[mask]+=cnt[mask^(1<<i)];
                sum[mask]+=sum[mask^(1<<i)];
                sum[mask]%=MOD;
            }
        }
    }

    int bits=20;

    for(int mask=1; mask<=k; mask++)
    {
        for(int i=0; i<bits; i++)
        {
            temp[mask][i+1]=temp[mask][i-1+1];
            if((mask>>i)&1)
            {
                temp[mask][i+1]+=temp[mask^(1<<i)][i-1+1];
                temp[mask][i+1]%=MOD;
            }
        }
        int ways=(cnt[mask]?(sum[mask]*expo(2, cnt[mask]-1)):0);
        exact[mask]=(ways%MOD)-temp[mask][bits];
        exact[mask]=(exact[mask]+MOD)%MOD;
        temp[mask][0]=exact[mask];

        for(int i=0; i<bits; i++)
        {
            temp[mask][i+1]=temp[mask][i-1+1];
            if((mask>>i)&1)
            {
                temp[mask][i+1]+=temp[mask^(1<<i)][i-1+1];
                temp[mask][i+1]%=MOD;
            }
        }
    }

    int tot=0;
    for(int i=1; i<=k; i++)
    {
        tot+=exact[i];
        ans[i-1]=(tot%=MOD);
    }

    for(int i=0; i<k; i++)
    {
        cout<<ans[i]<<(i==k-1?"\n":" ");
    }
}   

signed main()
{ 
    fast
    
    int T=1;
    cin >> T;
   
    while (T--)
    {
        Solve();  
    } 
        #ifndef ONLINE_JUDGE
    cerr<<"\ntime taken : "<<(float)clock()/CLOCKS_PER_SEC<<" secs"<<"\n";
    #endif
 
    return 0;
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

// Count sum of B, and count of A in the submask of i 
// Then, we can get sum of good subsequences with OR = submask of i 
// Inverse SOS dp to get back OR = exactly i 
// prefix sums

int n, k;
const int N = 2e5 + 69;
const int K = 2e6 + 69;
const int mod = 1e9 + 7;
int sum[K], cnt[K], f[K], dp[K];

void Solve() 
{
    cin >> n >> k;
    int og = k;
    int pp = 1;
    int h = 0;
    while (pp <= k){
        pp *= 2;
        h++;
    }
    
    k = pp;
    
    vector <int> a(n), b(n);
    for (auto &x : a) cin >> x;
    for (auto &x : b) cin >> x;
    
    for (int i = 0; i < k; i++){
        cnt[i] = f[i] = dp[i] = sum[i] = 0;
    }
    
    for (int i = 0; i < n; i++){
        cnt[a[i]]++;
        sum[a[i]] += b[i];
        if (sum[a[i]] >= mod) sum[a[i]] -= mod;
    }
    
    for (int i = 0; i < k; i++){
        f[i] = cnt[i];
    }
    for (int i = 0; i < h; i++) for (int mask = 0; mask < k; mask++){
        if (mask >> i & 1){
            f[mask] += f[mask ^ (1 << i)];
            if (f[mask] >= mod) f[mask] -= mod;
        }
    }
    
    for (int i = 0; i < k; i++){
        cnt[i] = f[i];
        f[i] = sum[i];
    }
    
    for (int i = 0; i < h; i++) for (int mask = 0; mask < k; mask++){
        if (mask >> i & 1){
            f[mask] += f[mask ^ (1 << i)];
            if (f[mask] >= mod) f[mask] -= mod;
        }
    }
    
    for (int i = 0; i < k; i++){
        sum[i] = f[i];
    }
    
    vector <int> p2(n + 1, 1);
    for (int i = 1; i <= n; i++) p2[i] = p2[i - 1] * 2 % mod;
    
    for (int i = 0; i < k; i++){
        if (cnt[i])
        dp[i] = sum[i] * p2[cnt[i] - 1] % mod;
        else dp[i] = 0;
        
        f[i] = dp[i];
    }
    
    for (int i = h - 1; i >= 0; i--) for (int mask = k - 1; mask >= 0; mask--){
        if (mask >> i & 1){
            f[mask] -= f[mask ^ (1 << i)];
            if (f[mask] < 0) f[mask] += mod;
        }
    }
    
    for (int i = 1; i <= og; i++){
        f[i] += f[i - 1];
        f[i] %= mod;
        
        cout << f[i] << " \n"[i == og];
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}

Editorialist's code (Python)
import sys
input = sys.stdin.readline
mod = 10**9 + 7

for test in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))

    dim, m = 0, 1
    while m <= k:
        m *= 2
        dim += 1
    
    freq, base_val = [0]*m, [0]*m
    for i in range(n):
        freq[a[i]] += 1
        base_val[a[i]] += b[i]

    sub_ct, sub_sm = freq[:], base_val[:]
    for i in range(dim):
        for mask in range(m):
            if mask & 2**i:
                sub_sm[mask] += sub_sm[mask ^ 2**i]
                sub_ct[mask] += sub_ct[mask ^ 2**i]
    
    all_sm = [0]*m
    for i in range(m):
        if sub_ct[i] > 0:
            all_sm[i] = sub_sm[i] * pow(2, sub_ct[i]-1, mod) % mod
    
    ans = all_sm[:]
    for i in range(dim):
        for mask in range(m):
            if mask & 2**i:
                ans[mask] -= ans[mask ^ 2**i]
    
    for i in range(m):
        if i > 0: ans[i] += ans[i-1]
        ans[i] %= mod
    print(*ans[1:k+1])

I just noticed that Goodness Over Good shortens to GOG, similar to SOS. Was it an intentional hint? xD

Unfortunately not - the original proposal had the same name but didn’t use SOS DP in its solution at all :slightly_smiling_face:
