PROBLEM LINK:
Setter: Sai Suman Chitturi
Tester: Aryan Chaudhary
Editorialist: Rishabh Gupta
DIFFICULTY:
Medium
PREREQUISITES:
DP with Bitmasking
PROBLEM:
It’s Valentine’s Month, and so Chef wants to gift a string to Chefina. He knows that Chefina likes palindromic strings. Chef has a string S consisting of digits from 0 to 9. He wants to convert this to a palindrome by performing zero or more operations, where an operation is defined as follows:
- Pick any digit and replace all of its occurrences with any other digit.
For example, if S=12123, in one move it can be turned into 14143 or 12128, but not 32123.
You are also given N integers cost_1,cost_2,…,cost_N where the cost of performing an operation that replaces X characters is cost_X.
Your task is to help Chef minimize the cost of converting his string to a palindrome.
QUICK EXPLANATION:
Which numbers will convert to the same number at the end?
This can be found using a disjoint set union (DSU): the i^{th} and the (n+1-i)^{th} characters belong to the same group. The groups formed this way are sets of digits that must all finally become one common digit.
How to calculate the answer of a formed group?
Our task is to convert all numbers of a particular group into a single number by doing multiple operations. We can use the bitmasking dp technique to calculate this.
The hidden catch
Transforming two groups into a single common digit can cost less than handling them separately. This happens because the cost array need not be monotonic.
To solve this issue, we can again use dp with bitmasking to choose the best answer.
EXPLANATION:
The digits s[i] and s[n+1-i] must finally become the same digit, and so must all their other occurrences. So, we can form groups of digits that will convert to the same digit. This can be done using a DSU. Since there are only 10 digits, this can be done without any optimizations.
After the group formations, we want to calculate the answer for a particular group to convert them into a single integer. To do this we can use the dp with bitmasking technique. The final step of the conversion will involve two subgroups merging into one. We can iterate through all the subgroups of the group and find the new cost using the cost of the subgroups and the cost of this operation. The relation will look like
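$$dp[mask] = \min\left(dp[mask],\; dp[submask] + dp[mask - submask] + cost[\,cnt(submask)\,]\right)$$
where cnt(submask) is the number of characters of S whose digit belongs to submask (the characters rewritten by the final operation).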
We can pick the operation with the minimum cost. Iterating on all the submasks of all masks takes O(3^{10}) time.
But we are not done yet. Simply adding up the cost of each group (to convert it into a single digit) won't always give our answer. Because the cost function is not monotonic, it may be cheaper to convert two or more groups into one common digit.
Consider S = 1231 and the cost array 9 1 2 4. Converting the two 1's into 2 (cost_2 = 1) gives 2232. Then converting the three 2's into 3 (cost_3 = 2) gives 3333, for a total of 3. Directly converting the single 2 into 3 costs cost_1 = 9.
So we again have to use the dynamic programming with bitmasking approach to update the dp.
In the previous step we calculated the minimum cost to convert every possible mask (hence a group) into a single digit. In this step we update the dp values for masks formed of two or more groups (the groups formed in the DSU step). For such a mask, the cost may also be the sum of the costs of its subgroups, each converted independently.
$$dp[mask] = \min\left(dp[mask],\; dp[submask] + dp[mask - submask]\right)$$
This update applies only when both submask and mask - submask are unions of one or more complete groups (the groups formed in the DSU step). If we update the whole dp array in increasing order of mask, our final answer will be dp[1023], since 1023 is the mask containing all ten digits.
TIME COMPLEXITY:
O(3^{10} + n) for each test case, as the time to iterate through all submasks for all possible masks is 3^{10}, and O(n) computations to count the number of occurrences of a number and form groups among them.
SOLUTION:
Setter's Solution
// Utkarsh's solution coded in Java
import java.util.*;

/**
 * Setter's solution. For each test case reads N, the digit string S and
 * cost_1..cost_N, then prints the minimum total cost of turning S into a
 * palindrome. One operation replaces every occurrence of a digit by another
 * digit and costs cost_X when it rewrites X characters.
 */
class Main {
    static Scanner scan = new Scanner(System.in);
    static final int N = 500023;
    static boolean[] visited = new boolean[N];
    // Undirected graph on digits: an edge joins the two digits found at a pair
    // of mirrored string positions.
    static ArrayList<ArrayList<Integer>> adj = new ArrayList<ArrayList<Integer>>();
    // Vertices reached by the most recent dfs() call.
    static ArrayList<Integer> comp = new ArrayList<Integer>();

    /** Depth-first search from {@code curr}, appending every newly reached vertex to {@code comp}. */
    static void dfs(int curr) {
        visited[curr] = true;
        comp.add(curr);
        for (int next : adj.get(curr)) {
            if (!visited[next]) {
                dfs(next);
            }
        }
    }

    /** Number of set bits in {@code n}; returns 0 for {@code n <= 0}. */
    static int popCount(long n) {
        return n > 0 ? Long.bitCount(n) : 0;
    }

    /** Price of one operation rewriting {@code chars} characters; rewriting nothing is free. */
    private static long opCost(int[] cost, int chars) {
        return chars == 0 ? 0 : cost[chars];
    }

    /**
     * Minimum total cost to make {@code s} a palindrome.
     *
     * @param s    string of digits '0'..'9'
     * @param cost cost[x] is the price of one operation that rewrites x characters;
     *             its length must be at least {@code s.length() + 1} (cost[0] is never charged)
     * @return the minimum total cost
     */
    static long minCost(String s, int[] cost) {
        int n = s.length();
        adj.clear();
        comp.clear();
        for (int i = 0; i < 15; i++) {
            visited[i] = false;
            adj.add(new ArrayList<>());
        }

        long[] cnt = new long[10];
        for (int i = 0; i < n; i++) {
            cnt[s.charAt(i) - '0']++;
        }
        // mskcnt[msk]: number of characters whose digit belongs to msk.
        int[] mskcnt = new int[1 << 10];
        for (int msk = 1; msk < (1 << 10); msk++) {
            for (int d = 0; d <= 9; d++) {
                if ((msk & (1 << d)) != 0) {
                    mskcnt[msk] += cnt[d];
                }
            }
        }

        // dp[msk]: cheapest way to make every digit in msk the same digit. The last
        // operation joins two already-uniform parts by rewriting one of them.
        long[] dp = new long[1 << 10];
        for (int msk = 1; msk < (1 << 10); msk++) {
            if (popCount(msk) == 1) {
                dp[msk] = 0;
                continue;
            }
            dp[msk] = Long.MAX_VALUE;
            // Enumerate proper, non-empty submasks only.
            for (int sub = (msk - 1) & msk; sub != 0; sub = (sub - 1) & msk) {
                int complement = msk ^ sub;
                long step = Math.min(opCost(cost, mskcnt[sub]), opCost(cost, mskcnt[complement]));
                dp[msk] = Math.min(dp[msk], dp[sub] + dp[complement] + step);
            }
        }

        // Mirrored positions must hold equal digits: connect their digits.
        boolean[][] isedge = new boolean[10][10];
        for (int i = 0; i < n; i++) {
            int a = s.charAt(i) - '0';
            int b = s.charAt(n - 1 - i) - '0';
            if (a != b && !isedge[a][b]) {
                adj.get(a).add(b);
                adj.get(b).add(a);
                isedge[a][b] = true;
                isedge[b][a] = true;
            }
        }

        // Each connected component is a group of digits that must become one digit.
        ArrayList<Integer> v = new ArrayList<Integer>();
        for (int d = 0; d <= 9; d++) {
            if (visited[d]) {
                continue;
            }
            comp.clear();
            dfs(d);
            int groupMask = 0;
            for (int x : comp) {
                groupMask |= 1 << x;
            }
            v.add(groupMask);
        }

        // finaldp[msk]: cheapest cost for the groups selected by msk, where any
        // subset of groups may be merged into one common digit (cost need not be monotonic).
        int sz = v.size();
        long[] finaldp = new long[1 << sz];
        for (int msk = 1; msk < (1 << sz); msk++) {
            finaldp[msk] = Long.MAX_VALUE;
            for (int sub = msk; sub != 0; sub = (sub - 1) & msk) {
                int digits = 0;
                for (int i = 0; i < sz; i++) {
                    if ((sub & (1 << i)) != 0) {
                        digits |= v.get(i);
                    }
                }
                finaldp[msk] = Math.min(finaldp[msk], dp[digits] + finaldp[msk ^ sub]);
            }
        }
        return finaldp[(1 << sz) - 1];
    }

    /** Reads one test case from {@code scan} and prints its answer. */
    static void solve() {
        int n = scan.nextInt();
        String s = scan.next();
        int[] cost = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            cost[i] = scan.nextInt();
        }
        System.out.println(minCost(s, cost));
    }

    public static void main(String[] args) {
        int t = scan.nextInt();
        for (; t > 0; t--) {
            solve();
        }
    }
}
Editorialist's Solution
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define dd double
#define endl "\n"
#define pb push_back
#define all(v) v.begin(),v.end()
#define mp make_pair
#define fi first
#define se second
#define vll vector<ll>
#define pll pair<ll,ll>
#define fo(i,n) for(int i=0;i<n;i++)
#define fo1(i,n) for(int i=1;i<=n;i++)
ll mod=1000000007;
ll n,k,t,m,q,flag=0;
// Modular exponentiation by square-and-multiply: returns a^b mod `mod`.
// b must be non-negative (asserted). A negative base is first moved into
// [0, mod), because built-in % keeps the dividend's sign; without that step the
// result could be negative.
ll power(ll a,ll b) {
    assert(b>=0);
    a%=mod;
    if(a<0)a+=mod;
    ll res=1;
    for(;b;b>>=1){
        if(b&1)res=res*a%mod;
        a=a*a%mod;
    }
    return res;
}
// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/tree_policy.hpp>
// using namespace __gnu_pbds;
// #define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(a) -- no. of elements strictly less than a
// s.find_by_order(i) -- itertor to ith element (0 indexed)
ll min(ll a,ll b){if(a>b)return b;else return a;}
ll max(ll a,ll b){if(a>b)return a;else return b;}
// Greatest common divisor of |a| and |b|; gcd(0, 0) == 0.
// Arguments are made non-negative first: with a negative operand, built-in %
// keeps the dividend's sign and a plain Euclid step can cycle forever.
ll gcd(ll a , ll b){
    if(a<0)a=-a;
    if(b<0)b=-b;
    while(b!=0){
        ll r=a%b;
        a=b;
        b=r;
    }
    return a;
}
// Minimum total cost to turn the digit string `s` into a palindrome. One
// operation replaces every occurrence of one digit with another digit and costs
// cost[x] when it rewrites x characters.
//
// Preconditions: `s` holds only '0'..'9' and cost.size() >= s.size() + 1.
// cost[0] is never charged, because merging a digit that does not occur is free.
//
// All sums are 64-bit: with large per-operation costs, the total of up to nine
// merges does not fit in an int, and no hand-picked small "infinity" is needed.
static long long min_palindrome_cost(const std::string& s,
                                     const std::vector<long long>& cost) {
    constexpr int kDigits = 10;
    constexpr int kMasks = 1 << kDigits;
    constexpr int kFull = kMasks - 1;
    const int len = static_cast<int>(s.size());

    // occurrences[d]: how many times digit d appears in s.
    std::array<int, kDigits> occurrences{};
    for (char c : s) ++occurrences[c - '0'];

    // Mirrored positions must end up equal, so their digits (and every other
    // occurrence of those digits) collapse to one value. root[d] labels the
    // group of digit d; relabelling all members is enough for ten elements.
    std::array<int, kDigits> root{};
    std::iota(root.begin(), root.end(), 0);
    for (int i = 0; i < len / 2; ++i) {
        const int from = root[s[i] - '0'];
        const int to = root[s[len - 1 - i] - '0'];
        if (from == to) continue;
        for (int& r : root) {
            if (r == from) r = to;
        }
    }

    // chars_in[mask]: number of characters whose digit belongs to mask.
    std::vector<int> chars_in(kMasks, 0);
    for (int mask = 1; mask < kMasks; ++mask) {
        for (int d = 0; d < kDigits; ++d) {
            if (mask & (1 << d)) chars_in[mask] += occurrences[d];
        }
    }

    // best[mask] (first pass): cheapest way to make every digit in mask the same
    // value. The last operation joins two already-uniform parts by rewriting
    // all characters of one part; enumerating every proper submask covers both
    // directions.
    std::vector<long long> best(kMasks, 0);
    for (int mask = 1; mask < kMasks; ++mask) {
        if ((mask & (mask - 1)) == 0) continue;  // one digit: already uniform
        long long cheapest = std::numeric_limits<long long>::max();
        for (int sub = (mask - 1) & mask; sub != 0; sub = (sub - 1) & mask) {
            const int moved = chars_in[sub];
            const long long op = moved == 0 ? 0 : cost[moved];
            cheapest = std::min(cheapest, best[sub] + best[mask ^ sub] + op);
        }
        best[mask] = cheapest;
    }

    // Second pass: because cost need not be monotonic, merging several groups
    // into one digit can beat solving them separately, and vice versa. A mask
    // that is a union of whole groups may also be split into two such unions
    // that are solved independently.
    std::vector<char> whole_groups(kMasks, 0);
    whole_groups[0] = 1;
    for (int label = 0; label < kDigits; ++label) {
        int group = 0;
        for (int d = 0; d < kDigits; ++d) {
            if (root[d] == label) group |= 1 << d;
        }
        whole_groups[group] = 1;
    }
    for (int mask = 1; mask < kMasks; ++mask) {
        if ((mask & (mask - 1)) == 0) continue;
        for (int sub = (mask - 1) & mask; sub != 0; sub = (sub - 1) & mask) {
            const int rest = mask ^ sub;
            if (whole_groups[sub] && whole_groups[rest]) {
                whole_groups[mask] = 1;
                best[mask] = std::min(best[mask], best[sub] + best[rest]);
            }
        }
    }
    return best[kFull];
}

// Reads T test cases (N, S, cost_1..cost_N) and prints one answer per line.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
#ifdef NOOBxCODER
    std::freopen("input.txt", "r", stdin);
    std::freopen("output.txt", "w", stdout);
#endif
    int tests = 0;
    std::cin >> tests;
    while (tests-- > 0) {
        int len = 0;
        std::string s;
        std::cin >> len >> s;
        // cost[x] for x in 1..len; cost[0] stays 0.
        std::vector<long long> cost(len + 1, 0);
        for (int x = 1; x <= len; ++x) std::cin >> cost[x];
        std::cout << min_palindrome_cost(s, cost) << '\n';
    }
    std::cerr << "Time : " << 1000 * static_cast<double>(std::clock()) / static_cast<double>(CLOCKS_PER_SEC) << "ms\n";
    return 0;
}