DEDSTER - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: khaab_2004, iceknight1093, apoorv_me
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

2943

PREREQUISITES:

Math

PROBLEM:

There are N racers. The i-th of them will travel A_i meters at a speed of B_i.
If a racer finishes at time t, their reported time is \left\lceil t\right\rceil.

For each D = 1, 2, 3, \ldots, K, compute the following:

  • If Spooky travels a distance of D, find the expected (across all non-empty subsets of racers) minimum speed required to not lose to any of the racers.

EXPLANATION:

Let T_i denote the reported time of the i-th racer.
This is a constant independent of D, and equals \left\lceil \frac{A_i}{B_i} \right\rceil.
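
As a small aside, T_i can be computed without floating point. The sketch below uses the same (a + b - 1) / b idiom that the reference solutions use; the function name and the sample racers are made up for illustration.

```python
def reported_time(a: int, b: int) -> int:
    """Return ceil(a / b) for positive integers, using integer arithmetic only."""
    return (a + b - 1) // b


# Hypothetical sample racers as (A_i, B_i) pairs, chosen only for illustration.
racers = [(10, 3), (9, 3), (1, 5)]
times = [reported_time(a, b) for a, b in racers]
```

Avoiding float division matters here because A_i and B_i can be large, and a rounding error near an integer boundary would change the reported time by one.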

Let’s try to solve the problem for a fixed D first.
Suppose we also fix a subset S of the racers.
Let M = \min_{i\in S} T_i. To not lose to any of them, Spooky needs to travel fast enough that her reported time is not strictly larger than M.
So, her speed V should satisfy \left\lceil \frac{D}{V} \right\rceil \leq M.

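Since M is an integer, \left\lceil \frac{D}{V} \right\rceil \leq M holds exactly when \frac{D}{V} \leq M, that is, when V \geq \frac{D}{M}.
So, Spooky’s minimum required (integer) speed is exactly V = \left\lceil \frac{D}{M}\right\rceil.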
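
The claim can be sanity-checked by brute force. The following is a minimal sketch, assuming speeds are positive integers (as the formula suggests); both function names are illustrative only.

```python
def min_speed_bruteforce(d: int, m: int) -> int:
    """Smallest positive integer v with ceil(d / v) <= m, found by linear search."""
    v = 1
    while (d + v - 1) // v > m:
        v += 1  # terminates: v = d gives ceil(d / v) = 1 <= m
    return v


def min_speed_formula(d: int, m: int) -> int:
    """Closed form from the editorial: ceil(d / m)."""
    return (d + m - 1) // m


# Compare the two on a small grid of (D, M) values.
all_match = all(
    min_speed_bruteforce(d, m) == min_speed_formula(d, m)
    for d in range(1, 60)
    for m in range(1, 60)
)
```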
To find the expectation of this across all subsets, we can instead find the sum of minimum speeds across all subsets, then divide by 2^N - 1.
To compute this sum, we use the contribution technique.

Let’s fix a speed V - we know an optimal speed must look like \left\lceil \frac{D}{M}\right\rceil, so clearly V lies between 1 and D.
The number of subsets that require a speed of V can be counted as follows:

  • Count the number of indices i such that T_i \geq \left\lceil \frac{D}{V} \right\rceil, let it be x_V.
    This is the number of elements in some suffix of the sorted array T, which can be found easily (say, with binary search).
  • Also count the number of indices i such that T_i \geq \left\lceil \frac{D}{V-1} \right\rceil, say y_V (for V = 1, take y_V = 0).
  • Then, exactly 2^{x_V} - 2^{y_V} subsets of the racers require a speed of V.
    This is because 2^{x_V} subsets of them are valid with a speed of V; however 2^{y_V} of them are valid with a speed of V-1 as well so we remove them from consideration.
    (The empty set is included in both, and so cancels out and is not present in this count).

So, the overall sum is just

\sum_{V=1}^D V\cdot (2^{x_V}- 2^{y_V})

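For a fixed D, this can be computed in \mathcal{O}(D) if each count is looked up in \mathcal{O}(1) (for instance, from precomputed suffix counts of T, with values above K clamped to K), giving us a solution in \mathcal{O}(N + K^2) overall.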
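
Here is a minimal sketch of the fixed-D computation, checked against a brute force over all non-empty subsets. It works with exact integers rather than modular arithmetic, counts x_V as the number of T_i that are at least \left\lceil \frac{D}{V} \right\rceil, and treats y_1 as 0. The function names are illustrative and do not come from the reference solutions.

```python
from bisect import bisect_left
from itertools import combinations


def speed_sum_fixed_d(times, d):
    """Sum of minimum speeds over all non-empty subsets, via contributions of each V."""
    ts = sorted(times)
    n = len(ts)

    def count_at_least(threshold):
        # Number of reported times >= threshold (a suffix of the sorted array).
        return n - bisect_left(ts, threshold)

    total = 0
    for v in range(1, d + 1):
        x = count_at_least((d + v - 1) // v)
        # For v = 1 there is no smaller speed, so only the empty set is removed.
        y = count_at_least((d + v - 2) // (v - 1)) if v > 1 else 0
        total += v * (2**x - 2**y)
    return total


def speed_sum_bruteforce(times, d):
    """Same sum, enumerating every non-empty subset explicitly (exponential)."""
    total = 0
    for r in range(1, len(times) + 1):
        for subset in combinations(times, r):
            m = min(subset)
            total += (d + m - 1) // m
    return total
```

The binary search makes this version \mathcal{O}(D \log N) per value of D; it is meant only to illustrate the counting argument on small inputs.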


Now that we have a correct solution, let’s attempt to speed it up.

Instead of fixing a D and trying to compute the answer for it alone, we’ll compute the answer for all D simultaneously.
Let \text{ans}[D] denote the answer for D. Initially, this is 0 for all D.

Let’s fix M to be the finish time of the fastest person in the subset.
That is, we’re freely allowed to choose any subset of people whose finish times are all \geq M, as long as at least one of them has a finish time of exactly M.
Suppose there are H_M such subsets (counting them is simple: once again, count the number of subsets with elements \geq M and subtract the number of those with elements \gt M).
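
The count H_M can be sketched as follows, using exact integers and a hypothetical helper name. A useful check is that the values of H_M over all M add up to 2^N - 1, since every non-empty subset has exactly one minimum.

```python
def subsets_with_min(times, m):
    """Number of subsets of `times` whose minimum reported time is exactly m."""
    at_least = sum(1 for t in times if t >= m)  # elements allowed in the subset
    greater = sum(1 for t in times if t > m)    # elements that are not the minimum
    # Subsets using only elements >= m, minus those that avoid m entirely.
    return 2**at_least - 2**greater


# Hypothetical reported times, chosen only for illustration.
sample_times = [4, 3, 1]
counts = {m: subsets_with_min(sample_times, m) for m in range(1, 5)}
```

In a real solution the two counts would come from a frequency array with suffix sums rather than from a linear scan per M.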

Note that for any such subset:

  • For D = 1, 2, 3, \ldots, M, Spooky can drive at a speed of 1.
  • For D = M+1, M+2, M+3, \ldots, 2M, Spooky can drive at a speed of 2.
  • For D = 2M+1, 2M+2, 2M+3, \ldots, 3M, Spooky can drive at a speed of 3.
    \vdots
  • For D = iM+1, iM+2, \ldots, (i+1)M, Spooky can drive at a speed of i+1.

So, what we’d like to do is, for each i\geq 0, add (i+1)\cdot H_M to the values of \text{ans}[D] such that iM \lt D \leq (i+1)M.
Range additions like this can be processed offline quickly with the help of prefix sums: increase \text{ans}[iM+1] by (i+1)\cdot H_M and reduce \text{ans}[(i+1)M+1] by the same value (the decrement can be skipped or clipped when its index exceeds K), and after everything is done, compute the prefix sums of \text{ans}.
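
The range-add trick can be sketched in isolation. The helper names below are hypothetical, and the values of M, H_M and K are made up for the example; the added value for the i-th range is (i+1)\cdot H_M, matching the speed needed on that range.

```python
def updates_for_min(m, h, k):
    """Ranges (lo, hi, value) for one fixed minimum m with h subsets, clipped to k."""
    updates = []
    i = 0
    while i * m + 1 <= k:
        lo = i * m + 1
        hi = min((i + 1) * m, k)
        updates.append((lo, hi, (i + 1) * h))
        i += 1
    return updates


def apply_range_adds(size, updates):
    """Apply all range additions offline with a difference array and one prefix sum."""
    diff = [0] * (size + 2)
    for lo, hi, value in updates:
        diff[lo] += value
        diff[hi + 1] -= value
    result = []
    running = 0
    for d in range(1, size + 1):
        running += diff[d]
        result.append(running)
    return result  # result[d - 1] is the total added at position d


# Example with M = 2, H_M = 3, K = 5.
example = apply_range_adds(5, updates_for_min(2, 3, 5))
```

Each update costs \mathcal{O}(1) regardless of the range length, which is what makes the total work depend only on the number of ranges.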

Note that for a fixed M, we only have \left\lceil \frac{K}{M} \right\rceil range-add updates because we don’t care about i for which i\cdot M \geq K (such a range would start beyond K).
This gives us an overall complexity of \displaystyle\mathcal{O}\left(\sum_{M=1}^K \frac{K}{M}\right) = \mathcal{O}(K\log K).

Remember to divide all answers by 2^N-1 before printing them.
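
A minimal sketch of that final division, assuming the answers are reported modulo the prime 998244353 (the modulus used in the reference code below) and that 2^N - 1 is not divisible by it. The function name is illustrative only.

```python
MOD = 998244353  # assumption: the prime modulus taken from the reference code


def divide_by_subset_count(value, n):
    """Return value / (2^n - 1) modulo MOD, using Fermat's little theorem."""
    denom = (pow(2, n, MOD) - 1) % MOD
    inverse = pow(denom, MOD - 2, MOD)  # denom^(MOD-2) is the inverse of denom mod a prime
    return value * inverse % MOD
```

The inverse depends only on N, so it should be computed once per test case and reused for every D.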

TIME COMPLEXITY:

\mathcal{O}(N + K\log K) per testcase.

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
constexpr int mod = 998244353;

int add(int a, const int &b) {
    a += b;
    if(a >= mod)    a -= mod;
    return a;
}

int power(int a, int b) {
    int res = 1;
    while(b > 0) {
        if(b & 1)   res = res * 1ll * a % mod;
        a = a * 1ll * a % mod;
        b >>= 1;
    }
    return res;
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
	int T;  cin >> T;
	while(T--) {
	    int N, K;   cin >> N >> K;
	    vector<int> a(N);   for(auto &i : a)    cin >> i;
	    vector<int> b(N);   for(auto &i : b)    cin >> i;
	    vector<int> c;
	    
	    for(int i = 0 ; i < N ; i++) {
	       c.push_back((a[i] + b[i] - 1) / b[i]);
	    }
	    sort(c.begin(), c.end());
	    for(auto &i : c)    i = min(i, K);
	    vector<int64_t> p(N + 1), ans(K + 1);
	    p[0] = 1;
	    for(int i = 0 ; i < N ; i++) {
	        p[i + 1] = p[i] + p[i];
	        if(p[i + 1] >= mod)     p[i + 1] -= mod;
	    }
	    for(int i = N - 1 ; i >= 0 ; i--) if(i == 0 || c[i] != c[i - 1]) {
	        int add_here = 0, x = c[i];
	        for(int j = i ; j < N ; j++) {
	            if(c[j] != x)    break;
	            add_here = add(add_here, p[N - j - 1]);
	           // cout << j << " ";
	        }
	        for(int j = 1 ; j <= K ; j += x) {
	            ans[j] = add(ans[j], add_here);
	        }
	    }
	    
	    for(int i = 1 ; i <= K ; i++) {
	        ans[i] = add(ans[i], ans[i - 1]);
	    }
	    int inverse = power(p[N] - 1, mod - 2);
	    for(int i = 1 ; i <= K ; i++)    cout << (ans[i] * 1ll * inverse % mod) << " \n"[i == K];
	}
	return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
//#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct input_checker {
    string buffer;
    int pos;
 
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
 
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
 
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
 
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
 
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
 
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
 
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
 
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

//input_checker inp;
int sum_n = 0, sum_k = 0;
const int mod = 998244353;

int power(int x, int y){
    if (y == 0) return 1;
    
    int v = power(x, y / 2);
    v = 1LL * v * v % mod;
    
    if (y & 1) return (1LL * v * x) % mod;
    else return v;
}

void Solve() 
{
    // int n, k; 
    // n = inp.readInt(1, (int)3e5); inp.readSpace();
    // k = inp.readInt(1, (int)3e5); inp.readEoln();
    // sum_n += n; sum_k += k; assert(sum_n <= (int)3e5); assert(sum_k <= (int)3e5);

    // auto a = inp.readInts(n, 1, (int)1e9); inp.readEoln();
    // auto b = inp.readInts(n, 1, (int)1e9); inp.readEoln();
    int n, k;
    cin >> n >> k;
    vector <int> a(n), b(n);
    for (auto &x : a) cin >> x;
    for (auto &x : b) cin >> x;
    
    vector <int> f(k + 2, 0); //frequency of times 
    for (int i = 0; i < n; i++){
    	f[min(k + 1, (a[i] + b[i] - 1) / b[i])]++;
    }

    vector <int> p2(n + 1, 1);
    for (int i = 1; i <= n; i++){
    	p2[i] = p2[i - 1] * 2 % mod;
    }
    
    vector <int> diff(k + 1, 0);

    int A = 0, B = 0;
    for (int i = k + 1; i >= 1; i--){
        B = f[i];
        for (int d = 1; d <= k; d += i){
            diff[d] += 1LL * p2[A] * (p2[B] - 1) % mod;
            if (diff[d] >= mod) diff[d] -= mod;
        }
        
        A += B;
    }
    
    int inv = power(p2[n] - 1, mod - 2);
    
    for (int i = 1; i <= k; i++){
        diff[i] += diff[i - 1];
        if (diff[i] >= mod) diff[i] -= mod;
     //   cout << diff[i] << " ";
        cout << (1LL * diff[i] * inv) % mod << " \n"[i == k];
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;

    // t = inp.readInt(1, (int)1e5);
    // inp.readEoln();

    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    
   // inp.readEof();
    
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (Python)
mod = 998244353
maxn = 3*10**5 + 100
pow2 = [1]*maxn
for i in range(1, maxn): pow2[i] = pow2[i-1] * 2 % mod

for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    for i in range(n): a[i] = (a[i] + b[i] - 1) // b[i]
    freq = [0]*(k+2)
    for x in a: freq[min(x, k)] += 1
    for x in reversed(range(k)): freq[x] += freq[x+1]
    
    ans = [0]*(k+2)
    for m in range(1, k+1):
        subs = (pow2[freq[m]] - pow2[freq[m+1]]) % mod
        for i in range(k+1):
            lo, hi = i*m+1, min(i*m+m+1, k+1)
            if lo > k: break
            ans[lo] += subs * (i+1) % mod
            ans[hi] -= subs * (i+1) % mod
    den = pow(pow2[n] - 1, mod-2, mod)
    for i in range(1, k+1): ans[i] = (ans[i] + ans[i-1]) % mod
    for i in range(1, k+1): ans[i] = ans[i] * den % mod
    
    print(*ans[1:k+1])