PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: wuhudsm
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Dynamic Programming
PROBLEM:
For two binary strings S and T, you can perform the following operations:
- Choose 1 \leq i \leq j \leq N and swap S_i with S_j.
  This has a cost of A if S_i = 1 after swapping, and a cost of B otherwise.
- Choose 1 \leq i \leq j \leq N and swap T_i with T_j.
  This has a cost of C if T_i = 1 after swapping, and a cost of D otherwise.
Define f(S, T) to be the minimum cost of making S equal to T, and 0 if it’s impossible to make them equal.
Given N, A, B, C, D, compute the sum of f(S, T) across all 4^N pairs of binary strings of length N.
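To have something to check the reasoning below against, here is a brute-force sketch in Python (the language of the editorialist's code) that follows the statement literally by running Dijkstra's algorithm over all (S, T) states. It assumes 0-indexed strings and non-negative costs. The helper names `op_cost` and `brute_f` are hypothetical and do not come from the reference solutions. It is only practical for very small N.

```python
import heapq


def op_cost(bits, i, j, cost_one, cost_zero):
    """Swap positions i < j (0-indexed) of the binary string bits.

    Returns the new string and the cost, which depends on the character that
    ends up at the smaller index i: cost_one if it is '1', cost_zero otherwise.
    """
    chars = list(bits)
    chars[i], chars[j] = chars[j], chars[i]
    return "".join(chars), (cost_one if chars[i] == "1" else cost_zero)


def brute_f(S, T, A, B, C, D):
    """f(S, T) by Dijkstra over all (S, T) states; only for tiny N."""
    if S.count("1") != T.count("1"):
        return 0  # no sequence of swaps can make them equal
    n = len(S)
    dist = {(S, T): 0}
    heap = [(0, S, T)]
    while heap:
        d, s, t = heapq.heappop(heap)
        if d > dist[(s, t)]:
            continue  # stale heap entry
        if s == t:
            return d  # the first equal state settled is the cheapest one
        for i in range(n):
            for j in range(i + 1, n):  # i == j never changes anything
                new_s, cost_s = op_cost(s, i, j, A, B)  # swap in S
                new_t, cost_t = op_cost(t, i, j, C, D)  # swap in T
                for state, w in (((new_s, t), cost_s), ((s, new_t), cost_t)):
                    if d + w < dist.get(state, float("inf")):
                        dist[state] = d + w
                        heapq.heappush(heap, (d + w, *state))
    return 0  # not reached: equal counts of ones always allow a match
```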
EXPLANATION:
First, we need to figure out how to compute f(S, T) for a fixed pair of binary strings S and T.
The operations only let us rearrange S or T, so if S and T aren’t rearrangements of each other (that is, if they contain different numbers of ones), the answer is immediately 0.
Now, suppose S and T are indeed rearrangements of each other.
Each index i can be one of four types:
- S_i = T_i = 0, called a 00-type.
- S_i = T_i = 1, called a 11-type.
- S_i = 0, T_i = 1, called a 01-type.
- S_i = 1, T_i = 0, called a 10-type.
We can make a few observations after classifying indices this way:
- S and T will be equal if and only if every index is either a 00 type or a 11 type.
- If S and T are rearrangements of each other, the number of 01 types must equal the number of 10 types (otherwise their counts of zeros and ones won’t match).
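A small sketch of this classification, assuming 0-indexed Python strings; `index_types` and `is_rearrangement` are illustrative names. It encodes the second observation: the number of 01 indices equals the number of 10 indices exactly when S and T contain the same number of ones.

```python
from collections import Counter


def index_types(S, T):
    """Type S_i + T_i of every index: '00', '11', '01' (S_i=0, T_i=1) or '10'."""
    return [s + t for s, t in zip(S, T)]


def is_rearrangement(S, T):
    """S and T have the same number of ones exactly when #01 == #10."""
    counts = Counter(index_types(S, T))
    return counts["01"] == counts["10"]
```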
Now, indices that are 00-type or 11-type are already matched, so there’s no need to operate on them at all: they can functionally be ignored.
This leaves operating on the 01 and 10 types.
Note that if i and j are both 01 (or both 10) types, performing the operation (i, j) on either string won’t actually change the strings at all, so it’s never optimal to do so.
On the other hand, if i is a 01 type and j is a 10 type, performing the operation (i, j) will turn one of them into 00 and one of them into 11, which is exactly what we want!
Which of the two indices becomes 00 and which becomes 11 depends on whether we operate on S or on T. This doesn’t matter for making S and T equal, so the only question here is the cost.
Specifically, suppose i \lt j, where i is a 01 type and j is a 10 type.
Then operation (i, j) on S moves S_j = 1 to position i, so it has a cost of A, while doing it on T moves T_j = 0 to position i, so it has a cost of D.
Since it doesn’t matter which one is done for equality purposes, and we won’t be touching these indices again, it’s clearly optimal to choose the cheaper one, for a cost of \min(A, D).
Similarly, if i \gt j, the operation is (j, i): on S it moves S_i = 0 to position j (cost B), and on T it moves T_i = 1 to position j (cost C), so the optimal cost is \min(B, C) instead.
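The case analysis above, written as a hypothetical helper `pair_cost` (0-indexed positions; it is not part of any reference solution):

```python
def pair_cost(i01, i10, A, B, C, D):
    """Cheapest single swap that fixes a 01 index i01 together with a 10 index i10.

    The cost of a swap depends on what lands at the smaller of the two indices.
    """
    if i01 < i10:
        # Swap in S brings S=1 to i01 (cost A); swap in T brings T=0 there (cost D).
        return min(A, D)
    # Swap in S brings S=0 to i10 (cost B); swap in T brings T=1 there (cost C).
    return min(B, C)
```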
We now have our setup: there are some 01 indices, and some 10 indices (of the same number).
We can pair a 01 index with a 10 index after it for a cost of \min(A, D), and with a 10 index before it for cost \min(B, C).
Our aim is now to perform this pairing optimally, in order to minimize cost.
This can be done greedily, depending on which of \min(A, D) or \min(B, C) is smaller.
For example, suppose \min(A, D) \leq \min(B, C), so it’s ideal to pair 01 indices with 10 indices after them whenever possible.
Then, we have the following algorithm:
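- Iterate through indices from left to right.
- If the current index is 00 or 11, ignore it.
- If the current index is 01 type, hold it and don’t pair it yet.
- If the current index is 10 type:
  - If there’s an unpaired 01 index being held, pair them for a cost of \min(A, D).
  - Otherwise, keep this index unpaired.
- In the end, all unpaired indices can be paired with a cost of \min(B, C) each. Every unpaired 10 index was seen while no 01 index was held, so it lies before every 01 index that is still unpaired; and since the total numbers of 01 and 10 indices are equal, so are the numbers of unpaired ones.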
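The greedy scan as a Python sketch. The function name `greedy_f` is hypothetical. To avoid a separate branch, it handles the case \min(A, D) \gt \min(B, C) by swapping the roles of 01 and 10 indices, which is the pairing direction the analysis above makes cheap in that case.

```python
def greedy_f(S, T, A, B, C, D):
    """f(S, T) via the greedy left-to-right pairing (0-indexed strings)."""
    if S.count("1") != T.count("1"):
        return 0  # not rearrangements of each other
    ad, bc = min(A, D), min(B, C)
    # If min(A, D) <= min(B, C), a 01 index pairs cheaply with a later 10 index;
    # otherwise a 10 index pairs cheaply with a later 01 index.
    opener, closer = ("01", "10") if ad <= bc else ("10", "01")
    cheap, expensive = min(ad, bc), max(ad, bc)
    held = 0      # "opener" indices waiting for a partner
    unpaired = 0  # "closer" indices that found nothing to pair with
    cost = 0
    for s, t in zip(S, T):
        kind = s + t
        if kind == opener:
            held += 1
        elif kind == closer:
            if held:
                held -= 1
                cost += cheap
            else:
                unpaired += 1
        # 00 and 11 indices are ignored
    # Every leftover closer precedes every leftover opener, and the counts match.
    assert held == unpaired
    return cost + unpaired * expensive
```

For S = 0110 and T = 1001 with (A, B, C, D) = (5, 7, 3, 2), the first 01 index pairs with the following 10 index for \min(A, D) = 2, and the leftover 10 and 01 indices pair for \min(B, C) = 3, for a total of 5.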
Now that we know how to compute the answer for a single pair of strings, we need to extend this to computing the sum of answers across all pairs of strings.
Looking back at the answer-computing algorithm, observe that we processed indices from left to right, and at each stage only two things mattered: the number of unpaired 01 indices and the number of unpaired 10 indices.
Let’s use this information.
Define dp(i, x, y) to be the sum of answers over all pairs of length-i prefixes that, after being processed, leave x unpaired 01 indices and y unpaired 10 indices. Here the answer of a prefix pair is the cost paid so far, i.e. \min(A, D) times the number of pairings made.
Also define ct(i, x, y) to be the number of such pairs of prefixes. This will be useful later.
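To make these definitions concrete, this sketch computes both quantities by brute force for a small prefix length i: it enumerates all 4^i pairs of prefixes, runs the greedy scan on each (assuming \min(A, D) \leq \min(B, C), with `cheap` standing for \min(A, D)), and buckets the results by (x, y). The names `scan_state` and `brute_tables` are hypothetical, and this is only meant for checking, not for the real constraints.

```python
from collections import defaultdict
from itertools import product


def scan_state(S, T, cheap):
    """Greedy scan; returns (x, y, cost): unpaired 01s, unpaired 10s, cost so far."""
    x = y = cost = 0
    for s, t in zip(S, T):
        if s + t == "01":
            x += 1           # hold it
        elif s + t == "10":
            if x:
                x -= 1       # pair with a held 01
                cost += cheap
            else:
                y += 1       # nothing held: stays unpaired
    return x, y, cost


def brute_tables(i, cheap):
    """dp(i, x, y) and ct(i, x, y) by enumerating all 4^i prefix pairs."""
    dp, ct = defaultdict(int), defaultdict(int)
    for S in map("".join, product("01", repeat=i)):
        for T in map("".join, product("01", repeat=i)):
            x, y, cost = scan_state(S, T, cheap)
            dp[x, y] += cost
            ct[x, y] += 1
    return dp, ct
```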
Now, we need transitions.
There are four possibilities for the type of index i; let’s look at each of them.
- 00 or 11 type: this index is functionally ignored, so the sum of answers comes from the previous indices only. Since there are two such types, this contributes 2\cdot dp(i-1, x, y).
- 01 type: we hold it and don’t pair it yet.
  The sum of answers still comes from previous indices, but this time with one fewer 01 index, i.e. dp(i-1, x-1, y) (valid only when x \geq 1).
- 10 type: we pair it with a held 01 index for cost \min(A, D) if possible; otherwise, we just hold it.
  - The first case contributes dp(i-1, x+1, y) + ct(i-1, x+1, y)\cdot \min(A, D).
    The x+1 is because pairing reduces the number of held 01 indices by one; the extra term is because every pair of prefixes that reached this stage pays an additional \min(A, D), and there are ct(i-1, x+1, y) of them by definition.
  - The second case contributes dp(i-1, x, y-1), but is only valid when x = 0 (otherwise it’s optimal to pair).
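All in all, we obtain
dp(i, x, y) = 2\cdot dp(i-1, x, y) + dp(i-1, x-1, y) + dp(i-1, x+1, y) + ct(i-1, x+1, y)\cdot \min(A, D)
with an additional dp(i-1, x, y-1) term if x = 0 (terms with a negative index are treated as 0).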
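The transitions for ct(i, x, y) can be derived similarly, and will be
ct(i, x, y) = 2\cdot ct(i-1, x, y) + ct(i-1, x-1, y) + ct(i-1, x+1, y)
with an additional ct(i-1, x, y-1) term if x = 0. The base case is ct(0, 0, 0) = 1 and dp(0, 0, 0) = 0.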
We have \mathcal{O}(N^3) states with \mathcal{O}(1) transitions from each, so this is fast enough for the constraints.
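A direct transcription of these recurrences into Python, assuming `cheap` = \min(A, D) \leq \min(B, C) and reducing modulo 998244353 like the reference solutions. It stores full tables for clarity (the author’s code keeps only two layers), and `dp_tables` is an illustrative name. Since every index contributes at most one unpaired index, x + y \leq i, which bounds the loops.

```python
def dp_tables(n, cheap, mod=998244353):
    """Fill dp[i][x][y] and ct[i][x][y] with the pull recurrences (O(n^3))."""
    dp = [[[0] * (n + 2) for _ in range(n + 2)] for _ in range(n + 1)]
    ct = [[[0] * (n + 2) for _ in range(n + 2)] for _ in range(n + 1)]
    ct[0][0][0] = 1  # the single pair of empty prefixes
    for i in range(1, n + 1):
        for x in range(i + 1):
            for y in range(i + 1 - x):
                # 00 or 11: two choices, state unchanged
                d = 2 * dp[i - 1][x][y]
                c = 2 * ct[i - 1][x][y]
                # 01: one more held 01
                if x >= 1:
                    d += dp[i - 1][x - 1][y]
                    c += ct[i - 1][x - 1][y]
                # 10 paired with a held 01: every prefix pair pays `cheap`
                d += dp[i - 1][x + 1][y] + ct[i - 1][x + 1][y] * cheap
                c += ct[i - 1][x + 1][y]
                # 10 left unpaired: only possible when nothing is held
                if x == 0 and y >= 1:
                    d += dp[i - 1][x][y - 1]
                    c += ct[i - 1][x][y - 1]
                dp[i][x][y] = d % mod
                ct[i][x][y] = c % mod
    return dp, ct
```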
Finally, let’s actually compute the answer.
After processing the entire string, there will be some unpaired 01 and 10 indices.
The numbers of these unpaired indices should be equal; otherwise the strings can’t be rearranged to equal each other and contribute 0 anyway.
This means we’re only interested in states of the form dp(N, x, x) for 0 \leq x \leq N.
Now, for a fixed x, the total cost of all operations with cost \min(A, D) has already been computed: it is exactly dp(N, x, x).
This leaves the operations with a cost of \min(B, C).
Since we’re in state (N, x, x), each pair of strings here needs exactly x such operations, and there are ct(N, x, x) such pairs.
So, the final answer is
\sum_{x=0}^{N} \left( dp(N, x, x) + x\cdot ct(N, x, x)\cdot \min(B, C) \right)
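Note that this assumed \min(A, D) \leq \min(B, C). If the inequality goes the other way, simply swap their roles. In fact, exchanging S and T turns every 01 index into a 10 index and vice versa, so the final sum only depends on which of the two values is smaller; this is why the codes below simply use the smaller value for the greedy pairings and the larger one for the leftover pairings.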
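Putting everything together, here is a sketch of the full computation in Python. Unlike a table-based version, it pushes transitions forward from each reachable state (x, y), stored in a dictionary, which is equivalent to the pull recurrences. It uses the smaller of \min(A, D) and \min(B, C) for greedy pairings and the larger for leftovers, relying on the role swap noted above. `solve` and `brute_sum` are illustrative names, and `brute_sum` is only a small-N cross-check that runs the greedy on every pair directly.

```python
from collections import defaultdict
from itertools import product

MOD = 998244353  # the modulus used by all three reference solutions


def solve(n, A, B, C, D):
    """Sum of f(S, T) over all 4^n pairs of length-n strings, modulo MOD."""
    # By the S <-> T symmetry, only the sorted pair of pairing costs matters.
    cheap, expensive = sorted((min(A, D), min(B, C)))
    # (x, y) -> [sum of cheap costs so far, number of prefix pairs]
    cur = {(0, 0): [0, 1]}
    for _ in range(n):
        nxt = defaultdict(lambda: [0, 0])
        for (x, y), (val, ways) in cur.items():
            # (next state, extra cost per prefix pair, number of index types)
            moves = [((x, y), 0, 2),       # 00 or 11: nothing changes
                     ((x + 1, y), 0, 1)]   # 01: hold it
            if x:
                moves.append(((x - 1, y), cheap, 1))  # 10: pair with a held 01
            else:
                moves.append(((x, y + 1), 0, 1))      # 10: nothing to pair with
            for state, extra, mult in moves:
                entry = nxt[state]
                entry[0] = (entry[0] + mult * (val + extra * ways)) % MOD
                entry[1] = (entry[1] + mult * ways) % MOD
        cur = nxt
    # Only balanced leftovers (x == y) can be made equal; each needs x expensive swaps.
    return sum(val + x * ways * expensive
               for (x, y), (val, ways) in cur.items() if x == y) % MOD


def brute_sum(n, A, B, C, D):
    """Cross-check for tiny n: run the greedy on every pair and add up f(S, T)."""
    ad, bc = min(A, D), min(B, C)
    opener, closer = ("01", "10") if ad <= bc else ("10", "01")
    strings = ["".join(p) for p in product("01", repeat=n)]
    total = 0
    for S in strings:
        for T in strings:
            if S.count("1") != T.count("1"):
                continue  # f(S, T) = 0
            held = unpaired = 0
            for s, t in zip(S, T):
                if s + t == opener:
                    held += 1
                elif s + t == closer:
                    if held:
                        held -= 1
                        total += min(ad, bc)
                    else:
                        unpaired += 1
            total += unpaired * max(ad, bc)
    return total % MOD
```

For example, with N = 2 and (A, B, C, D) = (5, 7, 3, 2), only two pairs contribute: S = 01, T = 10 with cost 2, and S = 10, T = 01 with cost 3, so the sum is 5.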
TIME COMPLEXITY:
\mathcal{O}(N^3) per testcase.
CODE:
Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <bitset>
#include <unordered_map>
#define fastio ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int N=410;
const int LOGN=28;
const ll TMD=998244353;
const ll INF=2147483647;
int T;
ll n,a,b,c,d,mx,mn,ZERO;
//map<ll,ll> num[N][N],dp[N][N];
//unordered_map<ll,ll> num[N][N],dp[N][N];
ll num[2][N][N*2],dp[2][N][N*2];
/*
* stuff you should look for
* [Before Submission]
* array bounds, initialization, int overflow, special cases (like n=1), typo
* [Still WA]
* check typo carefully
* casework mistake
* special bug
* stress test
*/
void init()
{
cin>>n>>a>>b>>c>>d;
mx=min(a,d);
mn=min(b,c);
if(mx<mn) swap(mx,mn);
for(int i=0;i<=1;i++)
for(int j=0;j<=n;j++)
for(int k=0;k<=2*n;k++)
num[i][j][k]=dp[i][j][k]=0;
//num[i][j]=unordered_map<ll,ll>(),dp[i][j]=unordered_map<ll,ll>();
ZERO=n;
}
void solve()
{
num[1][0][0+ZERO]=2;
num[1][1][1+ZERO]=1;
num[1][0][-1+ZERO]=1;
for(int i=2;i<=n;i++)
{
for(int j=0;j<=i;j++)
{
for(int k=-i+ZERO;k<=i+ZERO;k++)
{
num[i&1][j][k]=(num[(i-1)&1][j][k]*2)%TMD;
if(j) num[i&1][j][k]=(num[i&1][j][k]+num[(i-1)&1][j-1][k-1])%TMD;
num[i&1][j][k]=(num[i&1][j][k]+num[(i-1)&1][j+1][k+1])%TMD;
if(!j) num[i&1][j][k]=(num[i&1][j][k]+num[(i-1)&1][j][k+1])%TMD;
dp[i&1][j][k]=(dp[(i-1)&1][j][k]*2)%TMD;
if(j) dp[i&1][j][k]=(dp[i&1][j][k]+dp[(i-1)&1][j-1][k-1])%TMD;
dp[i&1][j][k]=(dp[i&1][j][k]+dp[(i-1)&1][j+1][k+1]+num[(i-1)&1][j+1][k+1]*mn)%TMD;
if(!j) dp[i&1][j][k]=(dp[i&1][j][k]+dp[(i-1)&1][j][k+1])%TMD;
//
//printf("num[%d][%d][%d]=%I64d\n",i,j,k,num[i][j][k]);
//
}
}
}
ll ans=0;
for(int j=0;j<=n;j++)
ans=(ans+dp[n&1][j][0+ZERO]+num[n&1][j][0+ZERO]*mx%TMD*j)%TMD;
cout<<ans<<'\n';
}
//-------------------------------------------------------
void gen_data()
{
srand(time(NULL));
}
int bruteforce()
{
return 0;
}
//-------------------------------------------------------
int main()
{
fastio;
cin>>T;
while(T--)
{
init();
solve();
/*
//Stress Test
gen_data();
auto ans1=solve(),ans2=bruteforce();
if(ans1==ans2) printf("OK!\n");
else
{
//Output Data
}
*/
}
return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
a = min(a,b);
}
template<typename T>
void amax(T &a, T b) {
a = max(a,b);
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
const int MOD = 998244353;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
void solve(int test_case){
ll n,A,B,C,D; cin >> n >> A >> B >> C >> D;
amin(A,D), amin(B,C);
if(A > B) swap(A,B); // assume A is cheaper, other case is symmetric and doesn't matter while counting
// dp[i][01-10][active 01] = {sum,ways}
pll dp1[2*n+5][n+5], dp2[2*n+5][n+5];
memset(dp1,0,sizeof dp1);
memset(dp2,0,sizeof dp2);
dp1[n+1][0] = {0,1};
rep1(i,n){
memset(dp2,0,sizeof dp2);
rep1(j,2*n+1){
rep(k,n+1){
auto [sum,ways] = dp1[j][k];
// 00,11
dp2[j][k].ff += 2*sum;
dp2[j][k].ss += 2*ways;
dp2[j][k].ff %= MOD;
dp2[j][k].ss %= MOD;
// 01
dp2[j+1][k+1].ff += sum;
dp2[j+1][k+1].ss += ways;
dp2[j+1][k+1].ff %= MOD;
dp2[j+1][k+1].ss %= MOD;
// 10
if(!k){
// cancels out with unpaired 01 (in the future), cost = B
dp2[j-1][k].ff += sum+ways*B;
dp2[j-1][k].ss += ways;
dp2[j-1][k].ff %= MOD;
dp2[j-1][k].ss %= MOD;
}
else{
// cancels out with 01, cost = A
dp2[j-1][k-1].ff += sum+ways*A;
dp2[j-1][k-1].ss += ways;
dp2[j-1][k-1].ff %= MOD;
dp2[j-1][k-1].ss %= MOD;
}
}
}
memcpy(dp1,dp2,sizeof dp1);
}
ll ans = 0;
rep(k,n+1){
ans += dp1[n+1][k].ff;
ans %= MOD;
}
cout << ans << endl;
}
int main()
{
fastio;
int t = 1;
cin >> t;
rep1(i, t) {
solve(i);
}
return 0;
}
Editorialist's code (PyPy3)
mod = 998244353
dp1 = [ [ [0 for _ in range(405)] for _ in range(405)] for _ in range(405) ]
dp2 = [ [ [0 for _ in range(405)] for _ in range(405)] for _ in range(405) ]
dp1[0][0][0] = 0
dp2[0][0][0] = 1
for n in range(1, 403):
for i in range(n+1):
for j in range(n+1-i):
# dp[n][i][j]: prefix length n, (i, j) unpaired
ways, val = 0, 0
# 00 and 11
ways += 2*dp2[n-1][i][j]
val += 2*dp1[n-1][i][j]
ways %= mod
val %= mod
# 01
# 1. match with existing 10
ways += dp2[n-1][i][j+1]
val += dp1[n-1][i][j+1] + dp2[n-1][i][j+1]
# 2. if j = 0, can just add an extra
if j == 0 and i > 0:
ways += dp2[n-1][i-1][j]
val += dp1[n-1][i-1][j]
ways %= mod
val %= mod
# 10
if j > 0:
ways += dp2[n-1][i][j-1]
val += dp1[n-1][i][j-1]
dp1[n][i][j] = val%mod
dp2[n][i][j] = ways%mod
for _ in range(int(input())):
n, a, b, c, d = list(map(int, input().split()))
x, y = min(a, d), min(b, c)
if x > y: x, y = y, x
ans = 0
for i in range(n+1):
val, ways = dp1[n][i][i], dp2[n][i][i]
ans += val*x + ways*i*y
print(ans % mod)