RANDKNAP - Editorial

PROBLEM LINK:

Division 1
Division 2
Video Editorial

Author: Ildar Gainullin
Tester: Radoslav Dimitrov
Editorialist: Srikkanth

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Meet in the middle

PROBLEM:

You are given an array A of n = 240 integers randomly chosen from [0, M), where M = 998244353.

Process Q queries of the following type:

Given an integer X \in [0, M), find a subset of A whose sum modulo M is equal to X.

EXPLANATION:

There are 240 integers and hence 2^{240} subsets, whereas a subset sum modulo M can take only M = 998244353 distinct values.

Since the numbers are taken at random, there is a good chance that most values of X can be obtained.

The minimum number of integers needed to generate all M values is \lceil \log_2 M \rceil = 30, since k integers have only 2^k subsets and 2^{29} < M. One such example is the array \{2^0, 2^1, \dots, 2^{29}\}. We have 240 = 8 \cdot 30 integers, and the number of arrays whose subset sums cover all of [0, M) increases rapidly with the array size.

Suppose there are S sets of size 30 whose subset sums generate a particular value. Then there are at least S \cdot M^k sets of size 30 + k that generate it (fill the first 30 positions with one of those S sets; the remaining k values can be anything).

Moreover, this is a very loose bound: it ignores the other positions the 30 generating values could occupy among the 30 + k, and the contribution of the other k integers.

So there is a very high probability that all values X can be generated using 240 integers.
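The coverage claim is easy to probe experimentally on a scaled-down modulus. The following is a minimal sketch, not part of any reference solution: `count_reachable_residues` is a hypothetical helper that runs the textbook reachability DP over residues modulo a small m, so you can feed it, for example, 8 \cdot \lceil \log_2 m \rceil random values and count how many residues are covered. It costs O(n \cdot m), so it is far too slow for M = 998244353 itself.

```cpp
#include <vector>

// Number of residues r in [0, m) reachable as (sum of some subset of a) mod m.
// Plain O(|a| * m) reachability DP over residues, so only usable for small m:
// it is meant for experiments with a scaled-down modulus, not for M itself.
int count_reachable_residues(const std::vector<int>& a, int m) {
    std::vector<char> reach(m, 0);
    reach[0] = 1;  // the empty subset has sum 0
    for (int x : a) {
        int step = ((x % m) + m) % m;      // x reduced into [0, m)
        std::vector<char> next = reach;    // subsets that skip x
        for (int r = 0; r < m; ++r)
            if (reach[r]) next[(r + step) % m] = 1;  // subsets that take x
        reach.swap(next);
    }
    int count = 0;
    for (char c : reach) count += c;
    return count;
}
```

For example, `count_reachable_residues({1, 2, 4, 8}, 13)` returns 13, because the subset sums 0..15 hit every residue modulo 13. Random inputs can be generated with `std::mt19937` and `std::uniform_int_distribution`.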

This suggests the following strategy: enumerate some number of subsets and hope there is a reasonable chance that one of them sums to X.

We don’t have enough time to enumerate 2^{30} subsets, but we can use meet in the middle to cover as many subsets as possible.

Let’s split (part of) the array into two halves and enumerate as many subsets as we can in each half.

A reasonable amount is 2^{20} per half, i.e. all subsets of 20 integers. Combining every pair of subsets, one from each half, gives a subset of the 40 integers.

In total we cover 2^{40} subsets, on average about 2^{40}/M \approx 1100 per residue, so we can be confident that there exists a pair with sum equal to X modulo M.
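Enumerating the 2^{20} subset sums of one half does not require 20 additions per mask. Below is a minimal sketch (the helper name `all_subset_sums` is ours; the reference solutions instead loop over the bits of every mask) that builds the table by doubling: after processing element i, the upper half of the table equals the lower half plus a[i], so entry `mask` holds the sum of the elements whose bits are set in `mask`. It assumes every input value is already reduced into [0, M).

```cpp
#include <cstddef>
#include <vector>

const int MOD = 998244353;

// sums[mask] = (sum of a[i] for every set bit i of mask) mod MOD,
// for all 2^k masks, k = a.size(). Assumes each a[i] is in [0, MOD).
std::vector<int> all_subset_sums(const std::vector<int>& a) {
    std::vector<int> sums(1, 0);  // only the empty subset so far
    sums.reserve(std::size_t(1) << a.size());
    for (int x : a) {
        std::size_t half = sums.size();  // = 2^i before adding element i
        for (std::size_t mask = 0; mask < half; ++mask) {
            int s = sums[mask] + x;      // < 2 * MOD, fits in int
            if (s >= MOD) s -= MOD;
            sums.push_back(s);           // stored at index mask | half
        }
    }
    return sums;
}
```

The table for 20 elements has 2^{20} entries and is built with 2^{20} - 1 additions.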

Now our task is: given two lists of subset sums L and R, find a pair, one element from L and one from R, whose sum is X modulo M.
This is a well-known problem and can be done by storing L in a map and iterating through R, or by sorting both
lists and using two pointers (searching for the exact sums X and X + M, since two residues add up to less than 2M).
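A minimal sketch of the hash-map variant, assuming both halves' sums are already reduced modulo M and stored by subset mask; `find_pair` is a hypothetical helper name. For each sum r in the right half it looks up the residue (X - r) mod M among the left-half sums, so one query costs one pass over the right half.

```cpp
#include <unordered_map>
#include <utility>
#include <vector>

const int MOD = 998244353;

// left[i], right[j]: subset sums (already in [0, MOD)) of the two halves,
// indexed by subset mask. Returns {i, j} with (left[i] + right[j]) % MOD == X,
// or {-1, -1} if no such pair exists. Assumes 0 <= X < MOD.
std::pair<int, int> find_pair(const std::vector<int>& left,
                              const std::vector<int>& right, int X) {
    std::unordered_map<int, int> where;  // sum -> first mask with that sum
    where.reserve(left.size() * 2);
    for (int i = 0; i < (int)left.size(); ++i) where.emplace(left[i], i);
    for (int j = 0; j < (int)right.size(); ++j) {
        int need = X - right[j];         // in (-MOD, MOD)
        if (need < 0) need += MOD;       // need = (X - right[j]) mod MOD
        auto it = where.find(need);
        if (it != where.end()) return {it->second, j};
    }
    return {-1, -1};
}
```

`emplace` keeps the first mask seen for a repeated sum, which is enough since any matching subset is a valid answer.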

This alone is only fast enough for the first subtask, since every query X needs another pass over 2^{20} sums.

How to process multiple queries?

Note that the method above does not rely on X being random: it finds (w.h.p.) a subset for any given X.

So we are free to pick specific values of X to precompute.

If we had disjoint subsets with sums 2^0, 2^1, \dots, 2^{29}, we could easily form any number X by taking the union of the subsets that correspond to the set bits in the binary
representation of X.

Recall that only 40 integers (2^{40} subsets) were needed to obtain any given number X with very high probability.

Let’s split the array into blocks of length 40.

This gives six blocks, each of which can generate (w.h.p.) any residue we want.

If we precompute s subsets in each block, picking one subset from each block gives s^6 combinations.

To cover all M residues we need s^6 \ge M, i.e. s \ge M^{\frac{1}{6}} \approx 31.6, so 32 subsets per block are enough.
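The bound on s is a quick calculation:

```latex
31^6 = 887\,503\,681 \;<\; M = 998\,244\,353 \;<\; 32^6 = 2^{30} = 1\,073\,741\,824
```

So 32 is the smallest per-block count with s^6 \ge M, and since 32 = 2^5, the base-32 digits of X are just groups of 5 bits.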

How do we choose the 32 subsets? Think of the number X in base 32, with each
block responsible for one digit.

For block i (i = 0, 1, \dots, 5), we find subsets whose sums are d \cdot 32^i modulo M for d = 0, 1, \dots, 31.
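A minimal sketch of the digit bookkeeping; `base32_digits` and `block_target` are hypothetical names. Because 32 = 2^5, digit i of X is bits 5i..5i+4, and the target for block i and digit d is d \cdot 32^i mod M, which the reference solutions compute as `(x << (gr * 5))` reduced modulo M.

```cpp
#include <array>
#include <cstdint>

const int64_t MOD = 998244353;

// Splits X (0 <= X < MOD < 32^6) into six base-32 digits d[0..5]
// with X = sum of d[i] * 32^i. Digit i is served by block i.
std::array<int, 6> base32_digits(int64_t X) {
    std::array<int, 6> d{};
    for (int i = 0; i < 6; ++i) {
        d[i] = int(X & 31);  // lowest 5 bits = current digit
        X >>= 5;
    }
    return d;
}

// Residue that the precomputed subset for (block i, digit d) must hit.
int64_t block_target(int i, int d) {
    return (int64_t(d) << (5 * i)) % MOD;  // d * 32^i mod MOD
}
```

Summing `block_target(i, d[i])` over the six blocks modulo M gives back X, which is exactly why the union of the six chosen subsets answers the query.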

Now each query takes \mathcal{O}(1): find the base-32 digits d_0, \dots, d_5 of X, fetch the precomputed subset for digit d_i
in block i, and take their union. The blocks are disjoint, so the chosen sums add up to \sum_i d_i \cdot 32^i = X modulo M.
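The union of the six disjoint 40-bit block masks is the answer. The reference solutions print it as four integers, each carrying 60 consecutive elements (elements 0..59, 60..119, and so on); the sketch below assumes that output layout, which we inferred from their shifts by 20 and 40 rather than from the problem statement. `pack_answer` is a hypothetical helper; the setter and tester instead precompute two tables of 2^{15} entries (three digits each) so that every query is two lookups.

```cpp
#include <array>
#include <cstdint>

// block_mask[i]: 40-bit mask of the elements chosen in block i, where bit j
// stands for element 40*i + j. Returns the 240-bit answer as four 60-bit
// words: bit t of word w stands for element 60*w + t.
std::array<uint64_t, 4> pack_answer(const std::array<uint64_t, 6>& block_mask) {
    std::array<uint64_t, 4> word{};
    for (int i = 0; i < 6; ++i) {
        for (int j = 0; j < 40; ++j) {
            if ((block_mask[i] >> j) & 1) {
                int pos = 40 * i + j;  // global index of the element
                word[pos / 60] |= uint64_t(1) << (pos % 60);
            }
        }
    }
    return word;
}
```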

TIME COMPLEXITY:

TIME: O(M^{\frac{1}{6}} 2^{\frac{n}{12}} + Q)
SPACE: O(2 ^ {\frac{n}{12}})

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#define endl '\n'
 
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
 
using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (1 << 20);
const int mod = 998244353;
 
int read_int();
 
int n, q, a[MAXN];
 
void read() {
	n = read_int();
	for(int i = 0; i < n; i++) {
		a[i] = read_int();
		a[i] %= mod;
	}
}
 
pair<uint64_t, uint64_t> prec[6][32];
pair<uint64_t, uint64_t> ans[2][1 << 15];
 
void solve() {
	int l = 0;
 
	clock_t st = clock();
	for(int gr = 0; gr < 6; gr++) {
		unordered_map<int, int> mp;
		const int LHS = 17;
 
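		for(int mask = 0; mask < (1 << LHS); mask++) {
			int s = 0;
			for(int i = l; i < l + LHS; i++) {
				if(mask & (1 << (i - l))) {
					s += a[i];
					if(s >= mod) s -= mod; // keep s in [0, mod) to avoid int overflow
				}
			}
			
			mp[s] = mask; // store only the full subset sum of this mask
		}
 
		prec[gr][0] = {0, 0};
		for(int x = 1; x < 32; x++) {
			for(int mask = 0; mask < (1 << (40 - LHS)); mask++) {
				int s = 0;
				for(int i = l + LHS; i < l + 40; i++) {
					if(mask & (1 << (i - l - LHS))) {
						s += a[i];
						if(s >= mod) s -= mod;
					}
				}
				
				int need = ((x << (gr * 5)) - s + mod) % mod;
				auto it = mp.find(need);
				if(it != mp.end()) {
					// 40-bit block mask: low LHS bits from the map, the rest from mask;
					// split it into low/high 20 bits, as the packing below expects
					uint64_t merged = ((uint64_t)mask << LHS) | (uint64_t)it->second;
					prec[gr][x] = {merged & ((1ULL << 20) - 1), merged >> 20};
					break;
				}
			}
		}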
		
		l += 40;
	}
 
	//cerr << (clock() - st) / (double)CLOCKS_PER_SEC << endl;
 
	for(int x = 0; x < 32; x++) {
		prec[0][x].second <<= 20;
		prec[2][x].first <<= 20;
		prec[3][x].second <<= 20;
		prec[5][x].first <<= 20;
		
		prec[1][x].first <<= 40;
		prec[2][x].second <<= 40;
		prec[4][x].first <<= 40;
		prec[5][x].second <<= 40;
	}
 
	for(int i = 0; i < (1 << 15); i++) {
		int v0 = i & 31, v1 = (i >> 5) & 31, v2 = (i >> 10) & 31;
		ans[0][i].first = prec[0][v0].first | prec[0][v0].second | prec[1][v1].first;  
		ans[0][i].second = prec[1][v1].second | prec[2][v2].first | prec[2][v2].second;  
		ans[1][i].first = prec[3][v0].first | prec[3][v0].second | prec[4][v1].first;  
		ans[1][i].second = prec[4][v1].second | prec[5][v2].first | prec[5][v2].second;  
	}
 
	const int MASK_15 = (1 << 15) - 1;
 
	int q;
	q = read_int();
	while(q--) {
		int v;
		v = read_int();
		int l = v & MASK_15, r = v >> 15;
		cout << ans[0][l].first << " " << ans[0][l].second << " " << ans[1][r].first << " " << ans[1][r].second << endl;	
	}
	
	//cerr << (clock() - st) / (double)CLOCKS_PER_SEC << endl;
}
 
 
int main() {
	//freopen("4.in", "r", stdin);
	//freopen("4.out", "w", stdout);
 
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
 
	read();
	solve();
	return 0;
}
 
const int maxl = 100000;
char buff[maxl];
int ret_int, pos_buff = 0;
 
void next_char() { if(++pos_buff == maxl) fread(buff, 1, maxl, stdin), pos_buff = 0; }
 
int read_int()
{
	ret_int = 0;
	for(; buff[pos_buff] < '0' || buff[pos_buff] > '9'; next_char());
	for(; buff[pos_buff] >= '0' && buff[pos_buff] <= '9'; next_char())
		ret_int = ret_int * 10 + buff[pos_buff] - '0';
	return ret_int;
} 
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'

#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back

using namespace std;
template<class T, class T1> int chkmin(T &x, const T1 &y) { return x > y ? x = y, 1 : 0; }
template<class T, class T1> int chkmax(T &x, const T1 &y) { return x < y ? x = y, 1 : 0; }
const int MAXN = (1 << 20);
const int mod = 998244353;

int read_int();

int n, q, a[MAXN];

void read() {
	n = read_int();
	for(int i = 0; i < n; i++) {
		a[i] = read_int();
		a[i] %= mod;
	}
}

pair<uint64_t, uint64_t> prec[6][32];
pair<uint64_t, uint64_t> ans[2][1 << 15];

void solve() {
	int l = 0;

	const int MASK_20 = (1 << 20) - 1;
	const int LHS = 17;

	//clock_t st = clock();
	for(int gr = 0; gr < 6; gr++) {
		unordered_map<int, int> mp;

		for(int mask = 0; mask < (1 << LHS); mask++) {
			int s = 0;
			for(int i = l; i < l + LHS; i++) {
				if(mask & (1 << (i - l))) {
					s += a[i];
					if(s >= mod) s -= mod;
				}
			}

			mp[s] = mask;
		}

		prec[gr][0] = {0, 0};
		for(int x = 1; x < 32; x++) {
			for(int mask = 0; mask < (1 << (40 - LHS)); mask++) {
				int s = 0;
				for(int i = l + LHS; i < l + 40; i++) {
					if(mask & (1 << (i - l - LHS))) {
						s += a[i];
						if(s >= mod) s -= mod;
					}
				}

				int need = ((x << (gr * 5)) - s + mod) % mod;
				auto it = mp.find(need);
				if(it != mp.end()) {
					uint64_t merged = ((uint64_t)mask << LHS) | it->second;
					prec[gr][x] = {merged & MASK_20, merged >> 20};
					break;
				}
			}
		}

		l += 40;
	}

	//cerr << (clock() - st) / (double)CLOCKS_PER_SEC << endl;

	for(int x = 0; x < 32; x++) {
		prec[0][x].second <<= 20;
		prec[2][x].first <<= 20;
		prec[3][x].second <<= 20;
		prec[5][x].first <<= 20;

		prec[1][x].first <<= 40;
		prec[2][x].second <<= 40;
		prec[4][x].first <<= 40;
		prec[5][x].second <<= 40;
	}

	for(int i = 0; i < (1 << 15); i++) {
		int v0 = i & 31, v1 = (i >> 5) & 31, v2 = (i >> 10) & 31;
		ans[0][i].first = prec[0][v0].first | prec[0][v0].second | prec[1][v1].first;  
		ans[0][i].second = prec[1][v1].second | prec[2][v2].first | prec[2][v2].second;  
		ans[1][i].first = prec[3][v0].first | prec[3][v0].second | prec[4][v1].first;  
		ans[1][i].second = prec[4][v1].second | prec[5][v2].first | prec[5][v2].second;  
	}

	const int MASK_15 = (1 << 15) - 1;

	int q;
	q = read_int();
	while(q--) {
		int v;
		v = read_int();
		int l = v & MASK_15, r = v >> 15;
		cout << ans[0][l].first << " " << ans[0][l].second << " " << ans[1][r].first << " " << ans[1][r].second << endl;	
	}

	//cerr << (clock() - st) / (double)CLOCKS_PER_SEC << endl;
}


int main() {
	//freopen("4.in", "r", stdin);
	//freopen("4.out", "w", stdout);

	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);

	read();
	solve();
	return 0;
}

const int maxl = 100000;
char buff[maxl];
int ret_int, pos_buff = 0;

void next_char() { if(++pos_buff == maxl) fread(buff, 1, maxl, stdin), pos_buff = 0; }

int read_int()
{
	ret_int = 0;
	for(; buff[pos_buff] < '0' || buff[pos_buff] > '9'; next_char());
	for(; buff[pos_buff] >= '0' && buff[pos_buff] <= '9'; next_char())
		ret_int = ret_int * 10 + buff[pos_buff] - '0';
	return ret_int;
}

Editorialist's Implementation
#include<bits/stdc++.h>
 
using namespace std;
 
#define LL long long int
#define FASTIO ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
const int oo = 1e9 + 5;
const LL ooll = (LL)1e18 + 5;
// const int MOD = 1e9 + 7;
const int MOD = 998244353;

mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
#define rand(l, r) uniform_int_distribution<int>(l, r)(rng)

clock_t start = clock();

const int N = 1e5 + 5;

int failed = 0;

vector<pair<int,int>> compute_subsets(vector<int> v, vector<int> need) {
    int n = v.size(), nh = 18;
    vector<pair<int,int>> lef, rig;
    for (int i=0;i<(1<<nh);++i) {
        int sum = 0;
        for (int j=0;j<nh;++j) if (i & (1<<j)) {
            sum += v[j];
            if (sum >= MOD) sum -= MOD;
        }
        lef.push_back({sum, i});
    }

    for (int i=0;i<(1<<nh);++i) {
        int sum = 0;
        for (int j=0;j<nh;++j) if (i & (1<<j)) {
            sum += v[j+nh];
            if (sum >= MOD) sum -= MOD;
        }
        rig.push_back({sum, i});
    }

    sort(lef.begin(), lef.end());
    sort(rig.begin(), rig.end());
    
    vector<pair<int,int>> answer;
    for (auto it : need) {
        int lptr = 0, rptr = (int)rig.size()-1;
        bool ok = false;
        for (lptr = 0; lptr < (int)lef.size(); ++lptr) {
            while (rptr >= 0 && rig[rptr].first + lef[lptr].first > it) {
                --rptr;
            }
            if (rptr < 0) {
                break;
            }
            if (lef[lptr].first + rig[rptr].first == it) {
                ok = true;
                answer.push_back({lef[lptr].second, rig[rptr].second});
                break;
            }
        }
        if (ok) continue;
        rptr = (int)rig.size()-1;
        it += MOD;
        for (lptr = 0; lptr < (int)lef.size(); ++lptr) {
            while (rptr >= 0 && rig[rptr].first + lef[lptr].first > it) {
                --rptr;
            }
            if (rptr < 0) {
                break;
            }
            if (lef[lptr].first + rig[rptr].first == it) {
                ok = true;
                answer.push_back({lef[lptr].second, rig[rptr].second});
                break;
            }
        }
        if (!ok) {
            ++failed;
            return {};
        }
    }

    for (int i=0;i<32;++i) {
        LL mask = answer[i].first + (1LL<<18) * answer[i].second;
        int chk = 0;
        for (int j=0;j<40;++j) if ((mask >> j) & 1) {
            chk += v[j];
            if (chk >= MOD) chk -= MOD;
        }
        assert(chk == need[i]);
    }

    return answer;
}

void solve() {
    vector<vector<int>> A(6, vector<int>(40, 0));
    int n;
    cin >> n;
    for (int i=0;i<6;++i) for (int j=0;j<40;++j) {
        cin >> A[i][j];
        // A[i][j] = rand(0, MOD-1);
    }
    vector< vector< pair<int,int> > > pre_computed_sets(6);
    int go = 1;
    for (int i=0;i<6;++i) {
        vector<int> need(32);
        for (int j=0;j<32;++j) need[j] = (j * 1LL * go) % MOD;
        pre_computed_sets[i] = compute_subsets(A[i], need);
        go = (go * 1LL * 32) % MOD;
    }
    // assert(!failed);

    int q = 300000;
    cin >> q;
    while(q--) {
        int x, go = 1;
        cin >> x;
        // x = rand(0, MOD-1);
        int x_Test = x;
        vector<LL> masks(6);

        // int chk = 0;
        for (int i=0;i<6;++i) {
            int take = x % 32;
            masks[i] = pre_computed_sets[i][take].first + (1LL<<18) * pre_computed_sets[i][take].second;
            
            // int chk = 0;
            // for (int j=0;j<40;++j) if ((1LL<<j) & masks[i]) {
            //     chk += A[i][j];
            //     if (chk >= MOD) chk -= MOD;
            // }
            // assert(chk == (take * 1LL * go) % MOD);
            // go = (go * 32);
            x /= 32;
        }
        // assert(chk == x_Test);
        LL MASK = (1LL<<20) - 1;
        
        // LL m[4];
        // m[0] = (masks[0]) + ((masks[1] & MASK) << 40);
        // m[1] = (masks[1] >> 20) + (masks[2] << 20);
        // m[2] = (masks[3]) + ((masks[4] & MASK) << 40);
        // m[3] = (masks[4] >> 20) + (masks[5] << 20);

        // chk = 0;
        // for (int i=0;i<4;++i) {
        //     for (int j=0;j<60;++j) {
        //         int pos = i * 60 + j;
        //         int ii = pos / 40, jj = pos % 40;
        //         if ((m[i] >> j) & 1) {
        //             chk += A[ii][jj];
        //             if (chk >= MOD) chk -= MOD;
        //         }
        //     }
        // }
        // assert(chk == x_Test);

        cout << (masks[0]) + ((masks[1] & MASK) << 40) << " ";
        cout << (masks[1] >> 20) + (masks[2] << 20) << " ";
        cout << (masks[3]) + ((masks[4] & MASK) << 40) << " ";
        cout << (masks[4] >> 20) + (masks[5] << 20) << "\n";


    }
}

void getinput() {
    cout << 240 << '\n';
    for (int i=0;i<240;++i) cout << rand(0, MOD-1) << " ";
    cout << '\n';
    int q = 100;
    cout << q << '\n';
    for (int i=0;i<q;++i) {
        cout << rand(0, MOD-1) << '\n';
    }
}

int main() {
    FASTIO;
    // getinput();
    // return 0;
    // int T = 100000;
    int T = 1;
    // cin >> T;
    for (int t=1;t<=T;++t) {
        solve();
    }
    cerr << fixed << setprecision(10);
    cerr << "Time: " << (clock() - start) / ((long double) CLOCKS_PER_SEC) << " secs\n"; 
    return 0;
} 

I thought I could solve this using the traditional Meet in the middle method but with n = 240 => 2^120 didn’t seem to be feasible. Now I know. Thanks for the editorial

When you say we can obtain every number with a very high probability, this is kind of unclear. For example, numbers < 4M are more likely to not be representable by a subset sum. Sure, you can say 4M is a minority compared to 1B, but then you say “We have six blocks, each can generate any integer that we want (w.h.p).”, and we rely on this heavily. This seems either wrong, or at least not formulated the right way.

I believe the approach works, but I am still having a hard time understanding it, given the issue I just mentioned.

I think you’re assuming that we should form a subset whose sum is exactly X_i, but we have to form a subset whose sum modulo 998244353 is equal to X_i. Therefore I think the probability for numbers less than 6M would also be the same.

Why does the beginning say 2^{40} subsets, when we can form 2^{240} subsets?

Oh right, wow I misread the problem statement, sorry about my earlier comment.

Hi, first, the numbers are taken modulo a prime, and the 240 numbers are generated at random.

I agree, it’s not completely clear why there exists a subset of 40 numbers that gives any value modulo 998244353.

I tried to provide some reasoning behind this in the first paragraph. First, you need at least 30 numbers to cover all possibilities, and 30 can be sufficient (e.g. powers of two).

Let’s call a set of integers “favourable” if its subset sums cover all values modulo M. As you increase the size of the array from 30 to 30 + k, the number of such sets increases very quickly.

The easiest thing to do is to check by implementing it: on my computer, even after 1000 test cases, I was unable to find a set of 40 numbers whose subset sums missed a given integer X.

Thanks, fixed

Thanks, that makes a lot of sense, appreciate it! My issue was that I misunderstood the problem statement and thought we had to find a subset with exact sum, not by the same modulo. With that the editorial makes perfect sense. Cheers!