CNTGOODARRH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author:
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Dynamic programming

PROBLEM:

An array A of length N is said to be good if it contains a non-cyclic subarray of length \frac{N}{2} such that multiplying all the elements of this subarray by -1 makes the sum of A negative.

You’re given an N\times N matrix M.
Count the number of good arrays of length N such that:

  1. 1 \leq A_i \leq N
  2. M[i][A_i] = 1 for each i.

EXPLANATION:

In the cyclic case, the situation was simple: the array was not good if and only if it was \frac{N}{2}-periodic.
However, that no longer applies here - for example the array [2, 1, 1, 2] is not good now.

Instead, let’s look at the first \frac{N}{2} elements of A.
If their sum is not equal to the sum of the last \frac{N}{2} elements, the array is immediately good: multiplying the larger of the first/second half by -1 will result in a negative sum.
So, the only interesting arrays are those whose first half and second half have the same sum.

Let S denote the sum of the first \frac{N}{2} elements; so that the whole array has sum 2S.

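Observe that such an array (one whose two halves both have sum S) is good if and only if we’re able to find a subarray of length \frac{N}{2} with sum \gt S, since negating a subarray with sum s changes the total from 2S to 2S - 2s. Finding such a subarray will be our goal.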
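As a sanity check on the definition, a direct test of goodness might look like the sketch below. The name `isGood` is hypothetical and not part of the reference solution; it simply slides a window of length \frac{N}{2} and checks whether negating it would make the total negative.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: returns true if negating some contiguous (non-cyclic)
// window of length N/2 makes the total sum negative.
// Assumes a.size() is even.
bool isGood(const std::vector<long long>& a) {
    const int n = static_cast<int>(a.size());
    const int half = n / 2;
    long long total = 0;
    for (long long v : a) total += v;

    long long window = 0;
    for (int i = 0; i < half; i++) window += a[i];  // window starting at index 0

    for (int start = 0; ; start++) {
        // Negating the window turns the sum into total - 2 * window.
        if (total - 2 * window < 0) return true;
        if (start + half >= n) break;             // no further window exists
        window += a[start + half] - a[start];     // slide one step to the right
    }
    return false;
}
```

For instance, this reports [2, 1, 1, 2] as not good and [1, 2, 2, 1] as good.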


Let’s look at the subarray of length \frac{N}{2} starting at i.
The sum of this subarray is, of course, A_i + A_{i+1} + \ldots + A_{i+\frac{N}{2}-1}.
This can be rewritten as

(A_1 + A_2 + \ldots + A_{i+\frac{N}{2}-1}) - (A_1 + A_2 + \ldots + A_{i-1})

which can then further be written as

S + (A_{\frac{N}{2}+1} + A_{\frac{N}{2}+2} + \ldots + A_{\frac{N}{2}+i-1}) - (A_1 + A_2 + \ldots + A_{i-1})

That is, starting with a value of S, we add the i-1 elements starting at index \frac{N}{2}+1, and subtract the first i-1 elements.

Our goal is to make this value \gt S for some index i.
Conversely, the only bad arrays are those for which the above expression always remains \leq S.

Note that this is equivalent to saying that we want (A_{\frac{N}{2}+1} + A_{\frac{N}{2}+2} + \ldots + A_{\frac{N}{2}+i-1}) - (A_1 + A_2 + \ldots + A_{i-1})
to be \leq 0 for every i from 1 to \frac{N}{2}.
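The condition above can be checked with a single running difference. The following is a minimal sketch, assuming a 0-indexed array; the name `isBadByPrefixDifference` is hypothetical. It also rejects arrays whose halves have different sums, since those are always good.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: returns true if the array is bad, i.e. its halves have
// equal sums and the running value
//   (A_{N/2+1} + ... + A_{N/2+i-1}) - (A_1 + ... + A_{i-1})
// never becomes positive. Assumes a.size() is even.
bool isBadByPrefixDifference(const std::vector<long long>& a) {
    const int half = static_cast<int>(a.size()) / 2;
    long long diff = 0;  // second-half prefix minus first-half prefix
    for (int k = 0; k < half; k++) {
        diff += a[half + k] - a[k];
        // diff > 0 means the window starting at 0-indexed position k + 1
        // has sum greater than S, so the array is good.
        if (diff > 0) return false;
    }
    return diff == 0;  // a bad array needs equal half sums
}
```

On [2, 1, 1, 2] the running difference is -1 and then 0, so the array is reported as bad, matching the example from earlier.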

This observation allows us to solve the problem using dynamic programming.


Define dp[i][x] to be the number of ways to choose elements A_1, A_2, \ldots, A_i and A_{\frac{N}{2}+1}, \ldots, A_{\frac{N}{2}+i} such that:

  1. The current value of (A_1 + A_2 + \ldots + A_i) - (A_{\frac{N}{2}+1} + A_{\frac{N}{2}+2} + \ldots + A_{\frac{N}{2}+i}) is x. Note that this is the negative of the expression above, so the condition "\leq 0" becomes "\geq 0".
  2. The value of that difference has never fallen below 0 at any earlier step.

If we’re able to compute this, the number of bad arrays is given by just dp[\frac{N}{2}][0], since we want to ensure that the final difference is 0 (recall that we already derived that the first and last half must have the same sum).
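To cross-check the count of bad arrays on tiny inputs, one could enumerate every allowed array directly. The sketch below is only a testing aid under assumed conventions: the names `enumerate` and `countBadBrute` are hypothetical, and M is assumed to be given as N strings of '0'/'1' characters.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical brute force helper: fills a[pos..] with every allowed value
// and increments `bad` for each complete array that is not good.
static void enumerate(const std::vector<std::string>& m, std::vector<int>& a,
                      int pos, long long& bad) {
    const int n = static_cast<int>(m.size());
    if (pos == n) {
        const int half = n / 2;
        long long total = 0;
        for (int v : a) total += v;
        bool good = false;
        for (int start = 0; start + half <= n && !good; start++) {
            long long window = 0;
            for (int k = 0; k < half; k++) window += a[start + k];
            if (2 * window > total) good = true;  // negating it gives a negative sum
        }
        if (!good) bad++;
        return;
    }
    for (int v = 1; v <= n; v++) {
        if (m[pos][v - 1] != '1') continue;  // value v is not allowed at this index
        a[pos] = v;
        enumerate(m, a, pos + 1, bad);
    }
}

// Exponential in N, so only usable for very small inputs.
long long countBadBrute(const std::vector<std::string>& m) {
    std::vector<int> a(m.size(), 0);
    long long bad = 0;
    enumerate(m, a, 0, bad);
    return bad;
}
```

For N = 2 with every value allowed, the bad arrays are exactly [1, 1] and [2, 2], so the function returns 2.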

How do we compute dp[i][x]?
Transitions in \mathcal{O}(N^2) time are quite easy: there are up to N options for what A_i can be, and up to N options for what A_{i+\frac{N}{2}} can be (only values allowed by M count).
Starting from dp[0][0] = 1, try all pairs of these options, and for a fixed pair:

  • Let d = A_i - A_{i + \frac{N}{2}}
  • Add dp[i-1][x-d] to dp[i][x].
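The pairwise transition can be sketched as a single layer update. This is an illustration only: the name `naiveLayer` is hypothetical, and the rows of M are assumed to be '0'/'1' strings.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical sketch of one naive DP layer. `prev[x]` plays the role of
// dp[i-1][x] and the result plays the role of dp[i][x]. `rowA` and `rowB`
// are the rows of M for positions i and i + N/2.
std::vector<long long> naiveLayer(const std::vector<long long>& prev,
                                  const std::string& rowA,
                                  const std::string& rowB,
                                  long long mod) {
    const int n = static_cast<int>(rowA.size());
    // One step changes the difference by at most n - 1.
    std::vector<long long> cur(prev.size() + n, 0);
    for (int p = 1; p <= n; p++) {          // candidate value for A_i
        if (rowA[p - 1] != '1') continue;
        for (int q = 1; q <= n; q++) {      // candidate value for A_{i+N/2}
            if (rowB[q - 1] != '1') continue;
            const int d = p - q;
            for (int x = 0; x < (int)prev.size(); x++) {
                const int nx = x + d;
                if (nx < 0) continue;       // difference would fall below 0
                cur[nx] = (cur[nx] + prev[x]) % mod;
            }
        }
    }
    return cur;
}
```

Starting from the base layer {1} with both rows equal to "11", the result is {2, 1, 0}: two pairs keep the difference at 0, one raises it to 1, and the pair with d = -1 is discarded.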

This is correct, but unfortunately too slow: there can be \mathcal{O}(N^3) states (since there are N choices of index, and x can be as large as \frac{N^2}{2}), and doing \mathcal{O}(N^2) transitions from each of them results in a complexity of \mathcal{O}(N^5).

However, there’s a simple optimization here.
For a fixed index i, let’s first compute a helper array ct[i][d], which denotes the number of ways of choosing A_i and A_{i+\frac{N}{2}} such that A_i - A_{i+\frac{N}{2}} = d.
Note that this is independent of x, and so needs to be done only once for each index; the overall complexity is hence \mathcal{O}(N^3).

Once the ct array is known, observe that to compute some dp[i][x] we can simply iterate through values of the difference d instead, and for each of them add dp[i-1][x-d] \cdot ct[i][d] to dp[i][x].
This is much faster because there are only \mathcal{O}(N) values of d: after all, the difference of two elements that are both in [1, N], must lie in [-N, N].
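A minimal sketch of this grouping idea follows. The names `differenceCounts` and `groupedLayer` are hypothetical, and the rows of M are again assumed to be '0'/'1' strings; the offset of n keeps the indices of the count table non-negative.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: ways[d + n] = number of pairs (p, q) with
// rowA[p-1] == '1', rowB[q-1] == '1' and p - q == d, for d in [-n, n].
std::vector<long long> differenceCounts(const std::string& rowA,
                                        const std::string& rowB) {
    const int n = static_cast<int>(rowA.size());
    std::vector<long long> ways(2 * n + 1, 0);
    for (int p = 1; p <= n; p++)
        for (int q = 1; q <= n; q++)
            if (rowA[p - 1] == '1' && rowB[q - 1] == '1') ways[p - q + n]++;
    return ways;
}

// One DP layer using the grouped counts: O(N) work per state instead of O(N^2).
std::vector<long long> groupedLayer(const std::vector<long long>& prev,
                                    const std::vector<long long>& ways,
                                    long long mod) {
    const int n = (static_cast<int>(ways.size()) - 1) / 2;
    std::vector<long long> cur(prev.size() + n, 0);
    for (int x = 0; x < (int)prev.size(); x++) {
        if (prev[x] == 0) continue;
        for (int d = -n; d <= n; d++) {
            const int nx = x + d;
            if (nx < 0 || ways[d + n] == 0) continue;  // below 0, or no such pair
            cur[nx] = (cur[nx] + prev[x] * ways[d + n]) % mod;
        }
    }
    return cur;
}
```

With both rows equal to "11", the count table holds 1, 2, 1 for d = -1, 0, 1, and one grouped layer from {1} gives {2, 1, 0}, the same result as trying every pair separately.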

This brings the complexity down to \mathcal{O}(N^4), which is fast enough to pass given that N \leq 100.

TIME COMPLEXITY:

\mathcal{O}(N^4) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

const int mod = 1e9 + 7;

void Solve() 
{
    int n; cin >> n;
    
    vector<string> a(n);
    for (auto &x : a) cin >> x;
    
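    // dp[x] = number of ways so far with (first-half prefix) - (second-half prefix) == x,
    // where the difference has never dropped below 0.
    vector<int> dp(n * n + 1, 0);
    dp[0] = 1;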
    
    for (int i = 0; i < n / 2; i++){
        vector <int> ndp(n * n + 1, 0);
        
        // ways[d + n] = number of valid pairs (A_i, A_{i + n/2}) with A_i - A_{i + n/2} == d.
        vector <int> ways(2 * n + 1, 0);
        for (int x = 1; x <= n; x++){
            for (int y = 1; y <= n; y++){
                if (a[i][x - 1] == '1' && a[i + n / 2][y - 1] == '1'){
                    ways[x - y + n]++;
                }
            }
        }
        
        for (int j = 0; j <= n * i; j++){
            for (int d = -n; d <= n; d++){
                if (j + d >= 0){
                    ndp[j + d] += dp[j] * ways[d + n];
                    ndp[j + d] %= mod;
                }
            }
        }
        
        dp = ndp;
    }
    
    // Count all arrays allowed by M, then subtract the bad ones (final difference 0).
    int total = 1;
    for (int i = 0; i < n; i++){
        int cnt = 0;
        for (int j = 0; j < n; j++){
            cnt += a[i][j] == '1';
        }
        
        total *= cnt;
        total %= mod;
    }
    
    total -= dp[0];
    if (total < 0) total += mod;
    
    cout << total << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}

I am not able to understand what our dp is actually calculating, and how dp[n/2][0] stores the number of arrays such that every subarray sum of the array is less than sum(array)/2.

How does O(N^4) work? At most O(N^3) should work, right?