TWRUP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Dynamic programming

PROBLEM:

A tower defense game has N waves. The i-th wave has A_i enemies, each with H_i health, and you have S_i turns to clear it.
You start with one tower with power P=1.
During each wave:

  • Every tower locks on to one enemy and deals P damage to it per turn.
    Different towers cannot lock on to the same enemy, however.
  • When an enemy dies, the tower that was targeting it can switch to a new enemy.

After clearing a wave, you can increase either the number of towers or their power by 1.
Decide if it’s possible to clear all N waves.

EXPLANATION:

Suppose we have K towers, each with power P.
Let’s decide whether we can clear a wave with A enemies, each with H health, in S turns.

A single tower does P damage per turn, and so needs \left\lceil \frac{H}{P} \right\rceil turns to kill a single enemy.
This means, in S turns, one tower can kill a maximum of

\left\lfloor \frac{S}{\left\lceil \frac{H}{P} \right\rceil} \right\rfloor

enemies.

Denote this quantity by c. The towers never share a target, so they work independently and together defeat at most K\cdot c enemies; this is also achievable, by giving each tower c of them.
We thus need K\cdot c \ge A to be able to clear this wave.

Note that this is a simple \mathcal{O}(1) check once we have all the values.


To solve the actual problem, we use dynamic programming.

A simple starting point is to define dp(i, K, P) to be true if we can clear the first i waves and have K towers each with power P afterwards; and false if we cannot.

Transitions are simple: dp(i, K, P) is true if dp(i-1, K-1, P) or dp(i-1, K, P-1) is true and the corresponding state (K-1 towers of power P, or K towers of power P-1) passes our constant-time check for wave i.

Naturally, this solution is much too slow, having \mathcal{O}(N^3) states.
Our task is now to optimize this.

One immediate optimization is to note that we don’t need to keep both K and P in the DP state.
Since we increase exactly one of them after each wave, their sum is the same no matter which upgrades were chosen: it starts at 1+1 = 2 and grows by 1 per wave, so K+P = i+2 after clearing i waves.
This lets us reduce the states to just dp(i, P), recovering K = i+2-P whenever we need it.
The complexity immediately drops to \mathcal{O}(N^2).


We require one more optimization.
For that, we utilize properties of the check formula itself.

Recall that one tower can defeat \left\lfloor \frac{S}{\left\lceil \frac{H}{P} \right\rceil} \right\rfloor enemies in the given turns.
While this quantity varies with P, it can’t take too many distinct values: in particular, \left\lceil \frac{H}{P} \right\rceil takes only \mathcal{O}(\sqrt H) distinct values (treat P \le \sqrt H and P \gt \sqrt H separately to see why).
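Spelling out that case split:

```latex
\begin{aligned}
P \le \sqrt H &: \text{ there are at most } \lfloor \sqrt H \rfloor \text{ such } P, \text{ hence at most that many values;} \\
P > \sqrt H &: \ \frac{H}{P} < \sqrt H \implies 1 \le \left\lceil \frac{H}{P} \right\rceil \le \left\lceil \sqrt H \right\rceil; \\
\text{total} &: \ \left| \left\{ \left\lceil \frac{H}{P} \right\rceil : P \ge 1 \right\} \right| \le \lfloor \sqrt H \rfloor + \lceil \sqrt H \rceil = \mathcal{O}(\sqrt H).
\end{aligned}
```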

Further, for a fixed value of \left\lceil \frac{H}{P} \right\rceil, the set of P that attain this will be a contiguous range.
Finding this range can be done with some simple algebra.
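For reference, the algebra for a fixed value t = \left\lceil \frac{H}{P} \right\rceil (t is just a name for that value) works out as follows, using that P is a positive integer:

```latex
\left\lceil \frac{H}{P} \right\rceil = t
\iff t-1 < \frac{H}{P} \le t
\iff \left\lceil \frac{H}{t} \right\rceil \le P \le \left\lfloor \frac{H-1}{t-1} \right\rfloor \qquad (t \ge 2),
\qquad\qquad
\left\lceil \frac{H}{P} \right\rceil = 1 \iff P \ge H.
```

For example, with H = 10 the blocks are P \in [1,1], [2,2], [3,3], [4,4], [5,9] and [10, \infty), for t = 10, 5, 4, 3, 2, 1. Starting at P = 1 and jumping past each block's right end enumerates all of them in \mathcal{O}(\sqrt H) steps.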

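Recall that clearing the wave requires K\cdot c \ge A, where c = \left\lfloor \frac{S}{\left\lceil \frac{H}{P} \right\rceil} \right\rfloor.
Within a range of P sharing the same value of \left\lceil \frac{H}{P} \right\rceil, c is a constant. Since K+P is the same for every state at a given wave, K is determined by P, so the condition rearranges into another inequality on P; in particular, it gives an upper bound on P (a larger P means fewer towers).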
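As a worked version of that rearrangement, write T for the common value of K+P at this wave (T is a name introduced here) and assume c \ge 1; if c = 0, no P in the block works:

```latex
K \cdot c \ge A
\iff K \ge \left\lceil \frac{A}{c} \right\rceil
\iff T - P \ge \left\lceil \frac{A}{c} \right\rceil
\iff P \le T - \left\lceil \frac{A}{c} \right\rceil .
```

Intersecting this bound with the block's own range of P gives the valid P for that block.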

Thus, for a fixed value of \left\lceil \frac{H}{P}\right\rceil, the values of P that allow the wave to be cleared form an interval (possibly empty).


We can use this by maintaining all currently reachable values of P as a sorted list of disjoint intervals.
When moving to the next wave, this list is updated by intersecting it with the intervals of valid P for the new wave.
If the existing list has M intervals, this takes \mathcal{O}(M + \sqrt H) time with a two-pointer sweep, since we’re just intersecting two sorted lists of disjoint intervals.

We also need to account for the upgrade at the end of each wave.
Each surviving value of P can either stay the same (a tower is added) or increase by 1 (power is added), so a surviving interval [L, R] becomes [L, R+1].
To keep the intervals disjoint, some of them may need to be merged afterwards, but since they’re already sorted this also takes \mathcal{O}(\text{intervals}) time.

Finally, since H \le N, there are \mathcal{O}(\sqrt N) intervals of P active at any point in time.
This makes the overall complexity \mathcal{O}(N\sqrt N).
With a low enough constant factor, \mathcal{O}(N\sqrt N\log N) implementations may also pass.

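Alternatively, the relatively small constraints allow the set of valid states to be stored in a bitset.
This ends up being about as fast as the \mathcal{O}(N\sqrt N) implementation despite being \mathcal{O}(\frac{N^2}{64}), so if implemented properly it will likely receive AC as well (and it’s a bit simpler to implement than the interval merging described above).
The tester’s code below takes this route. It indexes the bitset by the number of towers K rather than by P (equivalent, since K+P = i+2 before the 0-indexed wave i), and groups values of K by \left\lceil \frac{A}{K} \right\rceil instead of grouping P by \left\lceil \frac{H}{P} \right\rceil.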

TIME COMPLEXITY:

\mathcal{O}(N \sqrt N) per testcase.

CODE:

Tester's code (C++, bitsets)
#include<bits/stdc++.h>
// #define int long long
#define ull unsigned long long
using namespace std;

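void solve() {
    int n; cin >> n;

    // Bitsets are indexed by K, the number of towers. Before wave i (0-indexed),
    // K + P = i + 2, so the power is P = i + 2 - K.
    int sz = (n+63)/64 + 1;

    vector<ull> dp(sz, 0), good(sz, 0);
    dp[0] = 1ULL << 1;  // initially K = 1 (and P = 1)

    // good |= (dp restricted to the bit range [l, r])
    auto addRange = [&](int l, int r) {
        if (l > r) return;
        int wl = l/64, wr = r/64;
        int bl = l%64, br = r%64;
        ull comp = ~0ULL;
        if(wl == wr){
            ull mask = (comp << bl) & (comp >> (63 - br));
            good[wl] |= dp[wl] & mask;
        }
        else{
            good[wl] |= dp[wl] & (comp << bl);
            for(int w = wl + 1; w < wr; w++){
                good[w] |= dp[w];
            }
            good[wr] |= dp[wr] & (comp >> (63 - br));
        }
    };

    bool f = 1;  // all waves so far can be cleared
    for (int i = 0; i < n; i++) {
        int A, H, S;
        cin >> A >> H >> S;

        if(f == 0)continue;  // keep reading the remaining input

        fill(good.begin(), good.end(), 0ULL);

        // Walk over blocks of K sharing e = ceil(A/K), the most enemies any single
        // tower must kill when the A enemies are split as evenly as possible.
        for(int k = 1; k <= i+1;){
            int e = (A+k-1)/k;
            int hi;  // largest K (at most i+1) with ceil(A/K) == e
            if(e == 1){
                hi = i+1;
            }
            else{
                hi = min(i+1,(A-1)/(e-1));
            }
            if(e <= S){
                int a = S / e;            // allowed turns per kill
                int p = (H + a - 1) / a;  // ceil(H/P) <= a  <=>  P >= ceil(H/a)
                int mx = i+2 - p;         // P = i+2-K >= p  <=>  K <= i+2-p
                addRange(k, min(hi, mx));
            }
            k = hi + 1;
        }

        // If no reachable state clears this wave, the answer is No.
        bool b = 0;
        for(ull x : good){
            if (x){b = 1;break;}
        }
        if(!b){
            f = 0;continue;
        }
        // Upgrade: either P grows (K stays) or K grows by 1,
        // so dp = good | (good << 1) across the whole word array.
        ull temp = 0;
        for(int w = 0; w < sz; w++){
            ull shifted = (good[w] << 1) | temp;
            temp = good[w] >> 63;
            dp[w] = good[w] | shifted;
        }
    }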

    if(f) cout << "Yes\n";
    else cout << "No\n";
}

signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    int t; cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}