COSTSWAP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic Programming

PROBLEM:

For two binary strings S and T of length N, you can perform the following operations:

  • Choose 1 \leq i \leq j \leq N and swap S_i with S_j.
    This has a cost of A if S_i = 1 after swapping, and a cost of B otherwise.
  • Choose 1 \leq i \leq j \leq N and swap T_i with T_j.
    This has a cost of C if T_i = 1 after swapping, and a cost of D otherwise.

Define f(S, T) to be the minimum cost of making S equal to T, and 0 if it’s impossible to make them equal.
Given N, A, B, C, D, compute the sum of f(S, T) across all 4^N pairs of binary strings of length N.

EXPLANATION:

First, we of course need to figure out how to compute f(S, T) for fixed binary strings S and T.
The available operations only rearrange S or T, so if S and T aren’t rearrangements of each other, the answer is immediately 0.

Now, suppose S and T are indeed rearrangements of each other.
Each index i can be one of four types:

  • S_i = T_i = 0, called a 00-type.
  • S_i = T_i = 1, called a 11-type.
  • S_i = 0, T_i = 1, called a 01-type.
  • S_i = 1, T_i = 0, called a 10-type.

We can make a few observations after classifying indices this way:

  1. S and T will be equal if and only if every index is either a 00 type or a 11 type.
  2. If S and T are rearrangements of each other, the number of 01 types must equal the number of 10 types (otherwise their counts of zeros and ones won’t match).
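
The classification can be sketched in a few lines of Python. This is only an illustration: the helper name classify is made up here and does not appear in any of the solutions below.

```python
from collections import Counter

def classify(S, T):
    """Count index types; the key "01" means S_i = 0 and T_i = 1, and so on."""
    return Counter(s + t for s, t in zip(S, T))

types = classify("0110", "1010")
# One index of each type here, so the 01 and 10 counts agree,
# which matches S and T having the same number of ones.
print(types["01"], types["10"])  # prints: 1 1
```

Since the number of ones in S is (#11 + #10) and in T is (#11 + #01), the two counts agree exactly when S and T are rearrangements of each other.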

Now, indices that are 00-type or 11-type are already matched, so there’s no need to operate on them at all: they can functionally be ignored.
This leaves operating on the 01 and 10 types.

Note that if i and j are both 01 (or both 10) types, performing the operation (i, j) on either string won’t actually change the strings at all, so it’s never optimal to do so.
On the other hand, if i is a 01 type and j is a 10 type, performing the operation (i, j) will turn one of them into 00 and one of them into 11, which is exactly what we want!

Which index becomes 00 and which becomes 11 depends on whether we operate on S or T, but this doesn’t matter for the purpose of making S and T equal, so the only question here is of cost.
Specifically, suppose i \lt j, i is a 01 type and j is a 10 type.
Then, operation (i, j) in S will have a cost of A, while doing it in T will have a cost of D.
Since it doesn’t matter which one is done for equality purposes, and we won’t be touching these indices again, it’s clearly optimal to just choose whichever one has smaller cost: which is \min(A, D).
Similarly, if i \gt j, the optimal cost would be \min(B, C) instead.
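
The four costs can be checked by simulating the swaps directly. The snippet below is a small sketch with arbitrary illustrative cost values; swap_cost is a hypothetical helper, and indices are 0-based in the code.

```python
def swap_cost(s, i, j, cost_if_one, cost_if_zero):
    """Swap s[i] and s[j] (0-based, i < j) and return (new_string, cost).

    The cost depends on the character at the smaller index i after the swap.
    """
    chars = list(s)
    chars[i], chars[j] = chars[j], chars[i]
    cost = cost_if_one if chars[i] == "1" else cost_if_zero
    return "".join(chars), cost

A, B, C, D = 3, 5, 7, 11  # arbitrary illustrative costs

# 01-type first, 10-type second: S = "01", T = "10".
print(swap_cost("01", 0, 1, A, B))  # swap in S: ('10', 3), i.e. cost A
print(swap_cost("10", 0, 1, C, D))  # swap in T: ('01', 11), i.e. cost D

# 10-type first, 01-type second: S = "10", T = "01".
print(swap_cost("10", 0, 1, A, B))  # swap in S: ('01', 5), i.e. cost B
print(swap_cost("01", 0, 1, C, D))  # swap in T: ('10', 7), i.e. cost C
```

In the first arrangement the cheaper option is \min(A, D), and in the second it is \min(B, C), as derived above.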


We now have our setup: there are some 01 indices, and some 10 indices (of the same number).
We can pair a 01 index with a 10 index after it for a cost of \min(A, D), and with a 10 index before it for cost \min(B, C).
Our aim is now to perform this pairing optimally, in order to minimize cost.

This can be done greedily, depending on which of \min(A, D) or \min(B, C) is smaller.
For example, suppose \min(A, D) \leq \min(B, C), so it’s ideal to pair 01 indices with 10 indices after them, if possible.
Then, we have the following algorithm:

  • Iterate through indices from left to right.
  • If the current index is 00 or 11, ignore it.
  • If the current index is 01 type, hold it and don’t pair it yet.
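  • If the current index is 10 type,
    • If we’re holding an unpaired 01 type, pair them for a cost of \min(A, D).
    • Otherwise, keep this index unpaired.
  • In the end, the remaining unpaired 01 and 10 indices are paired with each other at a cost of \min(B, C) per pair (a 10 index is left unpaired only when no 01 index is being held, so every leftover 10 index lies before every leftover 01 index).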
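
The greedy for a single pair of strings might look as follows in Python. This is a minimal sketch rather than any of the reference solutions; the function name f follows the problem statement, while the variable names are chosen here for illustration. The case \min(A, D) > \min(B, C) is handled by exchanging S and T, which turns every 01 index into a 10 index and vice versa.

```python
def f(S, T, A, B, C, D):
    """Minimum cost to make S equal to T using the greedy pairing; 0 if impossible."""
    if S.count("1") != T.count("1"):
        return 0  # not rearrangements of each other
    forward, backward = min(A, D), min(B, C)  # 01 before 10 / 10 before 01
    if forward > backward:
        # Mirror the situation so that the cheaper pairing is again "01 first".
        S, T = T, S
        forward, backward = backward, forward
    held = 0      # unpaired 01-type indices seen so far
    leftover = 0  # 10-type indices that found no held 01 to pair with
    cost = 0
    for s, t in zip(S, T):
        if s == t:
            continue  # 00 or 11: ignore
        if s == "0":
            held += 1  # 01: hold it
        elif held > 0:
            held -= 1  # 10: pair with a held 01
            cost += forward
        else:
            leftover += 1  # 10: nothing to pair with yet
    # At this point held == leftover; each remaining pair uses the other cost.
    return cost + leftover * backward

print(f("01", "10", 3, 5, 7, 11))  # 3
print(f("10", "01", 3, 5, 7, 11))  # 5
```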

Now that we know how to compute the answer for a single pair of strings, we need to extend this to computing the sum of answers across all pairs of strings.
Looking back at the answer-computing algorithm, observe that we processed indices from left to right, and at each stage only two things mattered: the number of unpaired 01 indices and the number of unpaired 10 indices.

Let’s use this information.
Define dp(i, x, y) to be the sum of the pairing costs incurred so far, taken over all ways to choose the first i characters of S and T such that after processing them there are x unpaired 01 indices and y unpaired 10 indices.
Also define ct(i, x, y) to be the number of such choices, i.e. the number of pairs of length-i prefixes that leave x and y unpaired 01, 10 indices. This will be useful later.

Now, we need transitions.
Index i can be one of four types; since 00 and 11 behave identically, this gives three cases to look at.

  1. 00 or 11 type: this index is functionally ignored, so the sum of answers is from the previous indices only, i.e. dp(i-1, x, y) for each of the two types.
  2. 01 type: we hold it and don’t pair it yet.
    The sum of answers still comes from previous indices, but this time with one fewer 01 index, i.e. dp(i-1, x-1, y).
  3. 10 type: we pair it with a 01 type for cost \min(A, D) if possible, otherwise just hold it.
    • The first case corresponds to a cost of dp(i-1, x+1, y) + ct(i-1, x+1, y)\cdot \min(A, D).
      We use x+1 because pairing reduces the number of held 01-s by one; the extra term is there because we pay an additional \min(A, D) for every pair of prefixes that has reached this stage, and there are ct(i-1, x+1, y) of them by definition.
    • The second case corresponds to dp(i-1, x, y-1), but is only valid when x = 0 (since it’s otherwise optimal to pair).

All in all, we obtain

dp(i, x, y) = 2dp(i-1, x, y) + dp(i-1, x-1, y) + dp(i-1, x+1, y) + ct(i-1, x+1, y)\cdot\min(A, D)

with an additional dp(i-1, x, y-1) term if x = 0.

The transitions for ct(i, x, y) can be derived similarly, and will be

ct(i, x, y) = 2ct(i-1, x, y) + ct(i-1, x-1, y) + ct(i-1, x+1, y)

with an additional ct(i-1, x, y-1) if x = 0.

We have \mathcal{O}(N^3) states with \mathcal{O}(1) transitions from each, so this is fast enough for the constraints.
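
As a small hand check of these transitions, consider N = 2. The base case dp(0, 0, 0) = 0, ct(0, 0, 0) = 1 is an assumption made here (the text above does not state it explicitly), terms with a negative argument are taken to be 0, and m is shorthand for \min(A, D):

```latex
\begin{aligned}
ct(1,0,0) &= 2, \qquad ct(1,1,0) = 1, \qquad ct(1,0,1) = 1, \qquad dp(1,\cdot,\cdot) = 0, \\
ct(2,0,0) &= 2\,ct(1,0,0) + ct(1,1,0) = 5, \\
dp(2,0,0) &= 2\,dp(1,0,0) + dp(1,1,0) + ct(1,1,0)\cdot m = m, \\
ct(2,1,1) &= 2\,ct(1,1,1) + ct(1,0,1) + ct(1,2,1) = 1, \\
dp(2,1,1) &= 0.
\end{aligned}
```

This matches a direct enumeration: the state (0, 0) is reached by the four pairs with S = T and by S = 01, T = 10 (which pays m), while the state (1, 1) is reached only by S = 10, T = 01.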


Finally, let’s actually compute the answer.
After processing the entire string, there will be some unpaired 01 and 10 indices.
The numbers of these unpaired indices should be equal; otherwise the strings can’t be rearranged to equal each other and contribute 0 anyway.

This means we’re only interested in states of the form dp(N, x, x) for 0 \leq x \leq N.
Now, for a fixed x, the cost of all operations requiring cost \min(A, D) has been computed already, and is exactly dp(N, x, x).
This leaves operations with a cost of \min(B, C) to perform.
However, since we’re at dp(N, x, x), we know there are exactly x such operations to perform, and this must be done for each of the ct(N, x, x) pairs of strings that end in this state.
So, the final answer is

\sum_{x=0}^N \left(dp(N, x, x) + x\cdot\min(B, C)\cdot ct(N, x, x) \right)

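Note that this assumed \min(A, D) \leq \min(B, C). If the inequality is the other way, simply swap their roles: the situation is symmetric (hold 10 indices and pair each 01 index with a held 10 index), so the counts are unchanged and only the two costs are exchanged.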
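
Putting the pieces together, the whole computation can be sketched in Python as below. This is an illustrative version, not one of the reference solutions: it pushes values forward from each state instead of pulling them as in the recurrence, stores each layer in a dictionary keyed by (x, y), and assumes the answer is wanted modulo 998244353, as in the solutions below. The name total_cost is made up for this sketch, and a dictionary-based version like this is meant for small N only.

```python
MOD = 998244353

def total_cost(N, A, B, C, D):
    """Sum of f(S, T) over all 4^N pairs of strings, modulo MOD."""
    # By symmetry only the smaller and the larger of the two pair costs matter.
    cheap, costly = sorted((min(A, D), min(B, C)))
    # layer maps (x, y) -> (dp, ct); x = held 01s, y = unpaired 10s.
    layer = {(0, 0): (0, 1)}
    for _ in range(N):
        nxt = {}
        for (x, y), (s, w) in layer.items():
            moves = [((x, y), 2 * s, 2 * w),  # 00 or 11: ignored
                     ((x + 1, y), s, w)]      # 01: hold it
            if x > 0:
                moves.append(((x - 1, y), s + w * cheap, w))  # 10: pair now
            else:
                moves.append(((x, y + 1), s, w))              # 10: stays unpaired
            for state, ds, dw in moves:
                ps, pw = nxt.get(state, (0, 0))
                nxt[state] = ((ps + ds) % MOD, (pw + dw) % MOD)
        layer = nxt
    # Only balanced states count; each needs x more operations of the other cost.
    return sum(s + x * costly * w
               for (x, y), (s, w) in layer.items() if x == y) % MOD

print(total_cost(1, 3, 5, 7, 11))  # 0
print(total_cost(2, 3, 5, 7, 11))  # 8
```

For N = 2 the only pairs with a nonzero cost are S = 01, T = 10 (cost \min(A, D) = 3) and S = 10, T = 01 (cost \min(B, C) = 5), which gives the total of 8.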

TIME COMPLEXITY:

\mathcal{O}(N^3) per testcase.

CODE:

Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <bitset>
#include <unordered_map>
#define fastio ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int N=410;
const int LOGN=28;
const ll  TMD=998244353;
const ll  INF=2147483647;
int T;
ll  n,a,b,c,d,mx,mn,ZERO;
//map<ll,ll> num[N][N],dp[N][N];
//unordered_map<ll,ll> num[N][N],dp[N][N];
ll num[2][N][N*2],dp[2][N][N*2];

/*
 * stuff you should look for
 * [Before Submission]
 * array bounds, initialization, int overflow, special cases (like n=1), typo
 * [Still WA]
 * check typo carefully
 * casework mistake
 * special bug
 * stress test
 */

void init()
{
    cin>>n>>a>>b>>c>>d;
    mx=min(a,d);
    mn=min(b,c);
    if(mx<mn) swap(mx,mn);
    for(int i=0;i<=1;i++)
        for(int j=0;j<=n;j++)
            for(int k=0;k<=2*n;k++)
                num[i][j][k]=dp[i][j][k]=0;
                //num[i][j]=unordered_map<ll,ll>(),dp[i][j]=unordered_map<ll,ll>();
    ZERO=n;
}

void solve()
{
    num[1][0][0+ZERO]=2;
    num[1][1][1+ZERO]=1;
    num[1][0][-1+ZERO]=1;
    for(int i=2;i<=n;i++)
    {
        for(int j=0;j<=i;j++)
        {
            for(int k=-i+ZERO;k<=i+ZERO;k++)
            {
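                // j = held 01 indices, k = (#01 - #10) shifted by ZERO
                // 00 or 11 type
                num[i&1][j][k]=(num[(i-1)&1][j][k]*2)%TMD;
                // 01 type, held (k>0 guard: k-1 would be index -1 when i==n; that state is unreachable)
                if(j&&k) num[i&1][j][k]=(num[i&1][j][k]+num[(i-1)&1][j-1][k-1])%TMD;
                // 10 type, paired with a held 01
                num[i&1][j][k]=(num[i&1][j][k]+num[(i-1)&1][j+1][k+1])%TMD;
                // 10 type, left unpaired (only when nothing is held)
                if(!j) num[i&1][j][k]=(num[i&1][j][k]+num[(i-1)&1][j][k+1])%TMD;
                // same transitions for the cost sums
                dp[i&1][j][k]=(dp[(i-1)&1][j][k]*2)%TMD;
                if(j&&k) dp[i&1][j][k]=(dp[i&1][j][k]+dp[(i-1)&1][j-1][k-1])%TMD;
                dp[i&1][j][k]=(dp[i&1][j][k]+dp[(i-1)&1][j+1][k+1]+num[(i-1)&1][j+1][k+1]*mn)%TMD;
                if(!j) dp[i&1][j][k]=(dp[i&1][j][k]+dp[(i-1)&1][j][k+1])%TMD;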
                //
                //printf("num[%d][%d][%d]=%I64d\n",i,j,k,num[i][j][k]);
                //
            }
        }
    }
    ll ans=0;
    for(int j=0;j<=n;j++)
        ans=(ans+dp[n&1][j][0+ZERO]+num[n&1][j][0+ZERO]*mx%TMD*j)%TMD;
    cout<<ans<<'\n';
}

//-------------------------------------------------------

void gen_data()
{
    srand(time(NULL));
}

int bruteforce()
{
    return 0;
}

//-------------------------------------------------------

int main()
{
    fastio;

    cin>>T;
    while(T--)
    {
        init();
        solve();

        /*

        //Stress Test

        gen_data();
        auto ans1=solve(),ans2=bruteforce();
        if(ans1==ans2) printf("OK!\n");
        else
        {
            //Output Data
        }

        */
    }

    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

void solve(int test_case){
    ll n,A,B,C,D; cin >> n >> A >> B >> C >> D;
    amin(A,D), amin(B,C);
    if(A > B) swap(A,B); // assume A is cheaper, other case is symmetric and doesn't matter while counting

    // dp[i][01-10][active 01] = {sum,ways}
    pll dp1[2*n+5][n+5], dp2[2*n+5][n+5];
    memset(dp1,0,sizeof dp1);
    memset(dp2,0,sizeof dp2);
    dp1[n+1][0] = {0,1};

    rep1(i,n){
        memset(dp2,0,sizeof dp2);
        rep1(j,2*n+1){
            rep(k,n+1){
                auto [sum,ways] = dp1[j][k];

                // 00,11
                dp2[j][k].ff += 2*sum;
                dp2[j][k].ss += 2*ways;
                dp2[j][k].ff %= MOD;
                dp2[j][k].ss %= MOD;

                // 01
                dp2[j+1][k+1].ff += sum;
                dp2[j+1][k+1].ss += ways; 
                dp2[j+1][k+1].ff %= MOD;
                dp2[j+1][k+1].ss %= MOD;

                // 10
                if(!k){
                    // cancels out with unpaired 01 (in the future), cost = B
                    dp2[j-1][k].ff += sum+ways*B;
                    dp2[j-1][k].ss += ways;

                    dp2[j-1][k].ff %= MOD;
                    dp2[j-1][k].ss %= MOD;
                }
                else{
                    // cancels out with 01, cost = A
                    dp2[j-1][k-1].ff += sum+ways*A;
                    dp2[j-1][k-1].ss += ways;

                    dp2[j-1][k-1].ff %= MOD;
                    dp2[j-1][k-1].ss %= MOD;
                }
            }
        }

        memcpy(dp1,dp2,sizeof dp1);
    }

    ll ans = 0;
    rep(k,n+1){
        ans += dp1[n+1][k].ff;
        ans %= MOD;
    }

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}
Editorialist's code (PyPy3)
mod = 998244353
dp1 = [ [ [0 for _ in range(405)] for _ in range(405)] for _ in range(405) ]
dp2 = [ [ [0 for _ in range(405)] for _ in range(405)] for _ in range(405) ]

dp1[0][0][0] = 0
dp2[0][0][0] = 1

for n in range(1, 403):
    for i in range(n+1):
        for j in range(n+1-i):
            # dp[n][i][j]: prefix length n, (i, j) unpaired
            ways, val = 0, 0

            # 00 and 11
            ways += 2*dp2[n-1][i][j]
            val += 2*dp1[n-1][i][j]

            ways %= mod
            val %= mod
            # 01
            # 1. match with existing 10
            ways += dp2[n-1][i][j+1]
            val += dp1[n-1][i][j+1] + dp2[n-1][i][j+1]
            # 2. if j = 0, can just add an extra
            if j == 0 and i > 0:
                ways += dp2[n-1][i-1][j]
                val += dp1[n-1][i-1][j]
            
            ways %= mod
            val %= mod
            
            # 10
            if j > 0:
                ways += dp2[n-1][i][j-1]
                val += dp1[n-1][i][j-1]
            
            dp1[n][i][j] = val%mod
            dp2[n][i][j] = ways%mod

for _ in range(int(input())):
    n, a, b, c, d = list(map(int, input().split()))

    x, y = min(a, d), min(b, c)
    if x > y: x, y = y, x
    ans = 0
    for i in range(n+1):
        val, ways = dp1[n][i][i], dp2[n][i][i]
        ans += val*x + ways*i*y
    print(ans % mod)

I think we can solve this problem in O(N^2 \log N) using NTT + Catalan convolution.

My idea:

  • Within a string, it is only ever useful to swap a 0 with a 1.

  • The frequencies of 0s and 1s must be the same in both strings.

  • Let’s consider the positions of the 0s in the strings and try to match all of them (if all 0s are aligned in both strings, the 1s are matched automatically).

  • Let C_1 = \min(A, D) and C_2 = \min(B, C) be the costs of the two types of operations we apply. Let’s assume C_1 < C_2.

  • I fix the number of times C_1 and C_2 are applied to make the two strings equal. Let x be the number of times C_1 is applied and y the number of times C_2 is applied.

  • So 2x + 2y positions hold 01 or 10 pairs, and the remaining N-2x-2y positions have s[i] = t[i]. Now let’s choose these 2x+2y positions out of N, and deal only with these positions from now on…

Visualisation of this problem as a bracket sequence

  • This problem reduces to counting bracket sequences of length 2x+2y whose longest regular bracket subsequence has length 2x. This longest regular bracket subsequence corresponds to the C_1 operations (the cheaper ones, since C_1 < C_2).

  • An opening bracket corresponds to s[i] t[i] = 01 and a closing bracket to 10.

  • This can be visualized as inserting any number of regular bracket sequences, of total length 2x, anywhere into the string ))...)((...( of length 2y (y closing brackets followed by y opening brackets).

  • So I need to distribute regular bracket sequences over 2y + 1 places such that their total length is 2x.

  • This can be solved with NTT: polynomial multiplication of c_0 + c_1x + c_2x^2 + ..., where c_i is the i^{th} Catalan number.

Submission: CodeChef: Practical coding for everyone


Update: CodeChef: Practical coding for everyone
Can be solved in O(N^2) using the Catalan convolution formula.
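
The counting claim in this idea can be sanity-checked by brute force. The sketch below assumes that the formula meant here is the standard closed form for powers of the Catalan generating function, [z^x] C(z)^k = \frac{k}{2x+k}\binom{2x+k}{x} with k = 2y+1; that reading is a guess and is not confirmed by the comment. Both function names are made up for this check.

```python
from itertools import product
from math import comb

def count_by_formula(x, y):
    """Coefficient of z^x in C(z)^(2y+1), via the closed form with k = 2y + 1."""
    k = 2 * y + 1
    return k * comb(2 * x + k, x) // (2 * x + k)

def count_by_brute_force(x, y):
    """Count bracket strings of length 2x+2y with x greedy matches and
    y unmatched brackets of each kind."""
    total = 0
    for seq in product("()", repeat=2 * (x + y)):
        open_held = matched = unmatched_close = 0
        for ch in seq:
            if ch == "(":
                open_held += 1        # like a held 01 index
            elif open_held:
                open_held -= 1        # closing bracket pairs with a held one
                matched += 1
            else:
                unmatched_close += 1  # like an unpaired 10 index
        if matched == x and open_held == y and unmatched_close == y:
            total += 1
    return total

print(count_by_formula(1, 1), count_by_brute_force(1, 1))  # 3 3
```

For x = y = 1 the three sequences are ())(, )()( and )((), i.e. one () inserted into each of the 2y + 1 = 3 gaps of )(.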
