# SWTCHNG - Editorial

Setters: Harsh Sharma, Jaydeep Macchi, Ronels Macwan

Testers: Manan Grover, Samarth Gupta

Difficulty
Medium - Hard

Prerequisites
Dynamic programming, enumerating permutations (e.g. with `next_permutation`)

Problem
You have to eat the rasgullas in an order that maximizes the sum of the tastes of the rasgullas at the moment you eat them. Each rasgulla has taste 0 initially. When you eat the rasgulla at coordinate $i$, the taste of every rasgulla in the range $[i-d[i], i+d[i]]$ increases by $c[i]$. What is the maximum possible total taste you can achieve?

Quick Explanation
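Notice that $d[i]$ is small (at most 7), so a rasgulla only interacts with rasgullas at distance at most 7. Hence only the relative eating order of nearby rasgullas matters, and we can brute-force all permutations of a small window. Form a two-dimensional DP where dp[i][mask] is the maximum taste obtainable from the first $i$ rasgullas, given that the relative eating order of the last 7 of them is mask (one of the $7!$ permutations).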

Explanation :
Since every $c[i] \ge 0$, it is always optimal to eat all the rasgullas. The real question is in which order Vedant should eat them to maximize the total happiness.

If $n < 7$, we can brute-force all $n!$ permutations and compute the maximum possible happiness directly.

Consider the case $n \ge 7$. Since $d[i] \le 7$, a rasgulla at index $i$ can affect the taste of at most 7 rasgullas to its left. Hence the only thing that matters for index $i$ is the relative eating order of the 7 elements to its left, together with its own position relative to these 7 elements (one of 8 possible positions).

Let dp[i][mask] denote the maximum possible happiness considering the first $i$ elements, where the relative eating order of the last 7 of them is the one encoded by mask. (Use a mapping between mask values and the $7! = 5040$ permutations.)

Suppose we are computing the answer for the first $i+1$ elements from the answer for the first $i$ elements. Iterate through all possible masks; each one denotes the relative eating order of the rasgullas in the range $[i-6, i]$. The $(i+1)^{th}$ rasgulla can be eaten at any of 8 positions relative to these 7 rasgullas: before all of them, between two consecutive ones, or after all of them. Iterate through all such positions.

dp[i][mask] already holds the answer for the first $i$ elements. We now need the taste contributed when the $(i+1)^{th}$ element is inserted at some position in the order; denote this added value by add(mask, position).

Since $d[i] \le 7$, only the relative order of the $(i+1)^{th}$ element with respect to the elements in $[i-6, i]$ matters. Once we fix its relative position, we know which of these rasgullas are eaten before it and which after it:
- each earlier rasgulla $j$ adds $c[j]$ if the $(i+1)^{th}$ rasgulla lies within distance $d[j]$ of it;
- the $(i+1)^{th}$ rasgulla adds $c[i+1]$ to each later rasgulla within distance $d[i+1]$.

Fixing this position also fixes the relative order of the rasgullas in $[i-5, i+1]$. Compute the mask of that order and call it newMask(mask, position). The transition is

$$dp[i+1][\text{newMask}(mask, position)] = \max\big(dp[i+1][\text{newMask}(mask, position)],\; dp[i][mask] + add(mask, position)\big),$$

and the final answer is the maximum of $dp[n][mask]$ over all masks. This gives roughly $n \cdot 7! \cdot 8$ transitions, each evaluated in $O(8)$ time.

Solutions

Setter’s Code :

#include<bits/stdc++.h>
using namespace std;

#define int long long

// Array capacities and problem limits. Note: `int` is `long long` here because
// of the `#define int long long` above.
// N: capacity for positions (n <= 100 is asserted in main).
const int N = 105;
// M: capacity for window permutations (at most 7! = 5040 are generated).
const int M = 5140;
// Dmax: maximum allowed d[i]; also the size of the tracked window.
const int Dmax = 7;

// n: number of rasgullas; m: number of window permutations; d, c: input arrays.
int n, m, d[N], c[N];
// dp[i][idx]: best total taste for positions 0..i when the eating order of the
// last lastCnt positions is allPermutations[idx].
int dp[N][M];
// lastCnt: window size, min(Dmax, n).
int lastCnt;
// Scratch copy of the permutation currently being processed.
vector<int> permutation;
// allPermutations[idx][k]: relative index (0..lastCnt-1) of the window
// element eaten k-th.
vector<vector<int>> allPermutations;
// Maps permutationIdx(perm) back to perm's index in allPermutations.
unordered_map<int, int> numToIdx;

// Encodes a sequence of small non-negative values (each < 10) as a decimal
// number, e.g. {2, 0, 1} -> 201.
// For sequences of the same length this encoding is injective, which is all
// numToIdx needs: every key stored or looked up is a permutation of
// 0..lastCnt-1 with lastCnt <= Dmax = 7, so the largest code is 6543210.
int permutationIdx(vector<int> &v)
{
    int num = 0;
    for (int value : v)
    {
        num = num * 10 + value;
    }
    return num;
}

// Reads n, d[0..n-1] and c[0..n-1], then runs a DP over the relative eating
// order of a sliding window of the last lastCnt = min(Dmax, n) positions, and
// prints the maximum total taste.
signed main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    assert(n >= 1 && n <= 100);

    for (int i = 0; i < n; i++)
    {
        cin >> d[i];
        assert(d[i] >= 0 && d[i] <= Dmax);
    }

    for (int i = 0; i < n; i++)
    {
        cin >> c[i];
        assert(c[i] >= 0 && c[i] <= 1000);
    }

    // Two rasgullas interact only if their distance is <= Dmax, so it is
    // enough to remember the relative eating order of the last lastCnt ones.
    lastCnt = min(Dmax, n);

    // Enumerate every eating order of the window.
    // allPermutations[idx][k] is the relative index (0..lastCnt-1) of the
    // window element eaten k-th.
    vector<int> currPermutation;
    for (int i = 0; i < lastCnt; i++)
    {
        currPermutation.push_back(i);
    }

    do
    {
        numToIdx[permutationIdx(currPermutation)] = allPermutations.size();
        allPermutations.push_back(currPermutation);
    } while (next_permutation(currPermutation.begin(), currPermutation.end()));

    m = allPermutations.size();

    // Initial condition: total taste of every eating order of positions
    // 0..lastCnt-1.
    for (int idx = 0; idx < m; idx++)
    {
        permutation = allPermutations[idx];

        // x is eaten before y; eating permutation[x] adds c[permutation[x]]
        // to permutation[y] when it lies within distance d[permutation[x]].
        for (int x = 0; x < lastCnt; x++)
        {
            for (int y = x + 1; y < lastCnt; y++)
            {
                if (abs(permutation[x] - permutation[y]) <= d[permutation[x]])
                {
                    dp[lastCnt - 1][idx] += c[permutation[x]];
                }
            }
        }
    }

    // Transition: append `position` to a window of positions
    // [position - lastCnt, position - 1] whose eating order is allPermutations[idx].
    for (int position = lastCnt; position < n; position++)
    {
        for (int idx = 0; idx < m; idx++)
        {
            permutation = allPermutations[idx];

            // myRank = number of window elements eaten before `position`.
            for (int myRank = 0; myRank <= lastCnt; myRank++)
            {
                int curr = dp[position - 1][idx];

                // order[k] = actual position eaten k-th among the window plus
                // the new element.
                vector<int> order(lastCnt + 1, -1);
                order[myRank] = position;

                int cnt = 0;
                for (int x = 0; x <= lastCnt; x++)
                {
                    if (order[x] == -1)
                    {
                        order[x] = position - (lastCnt - permutation[cnt++]);
                    }
                }

                // Elements eaten earlier may add to the new element's taste.
                for (int x = 0; x < myRank; x++)
                {
                    if (abs(order[x] - position) <= d[order[x]])
                    {
                        curr += c[order[x]];
                    }
                }

                // The new element may add to elements eaten after it.
                for (int x = myRank + 1; x <= lastCnt; x++)
                {
                    if (abs(order[x] - position) <= d[position])
                    {
                        curr += c[position];
                    }
                }

                // Next window: drop position - lastCnt (it can no longer interact
                // with anything to the right) and renumber the rest as
                // 0..lastCnt-1.
                vector<int> nextState;
                for (int x = 0; x <= lastCnt; x++)
                {
                    if (position - order[x] > lastCnt - 1) continue;
                    nextState.push_back(lastCnt - 1 - (position - order[x]));
                }

                // .at(): every valid window order was registered above, so a
                // miss is a bug and must not silently insert a new key.
                int nextIdx = numToIdx.at(permutationIdx(nextState));
                dp[position][nextIdx] = max(dp[position][nextIdx], curr);
            }
        }
    }

    // Answer: the best total over every eating order of the final window.
    int ans = 0;
    for (int idx = 0; idx < m; idx++)
    {
        ans = max(ans, dp[n - 1][idx]);
    }
    cout << ans << '\n';

    return 0;
}

Tester’s Code :

#include <bits/stdc++.h>
using namespace std;

// ---- Input validation helpers ---------------------------------------------
// Every reader consumes stdin with getchar() and aborts through assert() as
// soon as the input deviates from the expected format.

// Reads one integer token terminated by exactly the character `endd` and
// checks that it lies in [l, r].
// Rejects:
//   - any stray character;
//   - leading zeros and "-0";
//   - a lone or repeated '-';
//   - an empty token;
//   - tokens with more digits than a long long can safely hold (at most 19
//     digits, and a 19-digit token must start with 1).
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;           // digits read so far
    int fi = -1;           // first digit, -1 until a digit is seen
    bool is_neg = false;
    while (true) {
        int g = getchar();  // keep it an int so EOF stays distinguishable
        if (g == '-') {
            // A single '-' is allowed, and only before the first digit.
            assert(fi == -1 && !is_neg);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);         // no leading zeros
            assert(fi != 0 || is_neg == false);  // no "-0"
            // Check the length *before* accumulating so x can never overflow.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + (g - '0');
        } else if (g == endd) {
            assert(cnt > 0);  // "" or "-" is not a number
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        } else {
            assert(false);
        }
    }
}

// Reads every character up to the terminator `endd` (consumed, not stored)
// and checks that the resulting length lies in [l, r].
// Reaching end of input before `endd` is an error.
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        // Keep getchar()'s int result. Narrowing it to char first makes EOF
        // collide with byte 0xFF. Where char is unsigned, the EOF check could
        // never fire, so truncated input would loop forever.
        int g = getchar();
        assert(g != EOF);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}

// Reads an integer in [l, r] followed by a single space.
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}

// Reads an integer in [l, r] followed by a newline.
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}

// Reads a string whose length is in [l, r], followed by a newline.
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}

// Reads a string whose length is in [l, r], followed by a single space.
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}

// Asserts that the whole input has been consumed.
void readEOF() {
    assert(getchar() == EOF);
}
// dp[n][msk]: best total taste for the first n rasgullas when the eating order
// of the last 7 of them (positions n-7..n-1) is the permutation mp[msk].
int dp[101][5040];
// mp[k]: the k-th permutation of {0..6} in lexicographic order (filled by pre()).
// Entry t is the relative index of the rasgulla eaten at step t.
vector<int> mp[5040];
// rev_mp: inverse of mp, permutation -> its index k.
map<vector<int>, int> rev_mp;
// Fills mp with every permutation of {0, 1, ..., 6} in lexicographic order and
// records the inverse mapping (permutation -> index) in rev_mp.
void pre(){
    vector<int> perm;
    for (int v = 0; v < 7; ++v)
        perm.push_back(v);

    int index = 0;
    while (true) {
        mp[index] = perm;
        rev_mp[perm] = index;
        ++index;
        if (!next_permutation(perm.begin(), perm.end()))
            break;
    }
}
// Total taste collected when the rasgullas at indices 0..6 are eaten in the
// order msk (msk[t] is the index eaten at step t).
// A rasgulla's taste is counted at the moment it is eaten. Eating index p then
// adds c[p] to every index within distance d[p] that appears in msk.
int cal(vector<int> &d, vector<int> &c, vector<int> msk){
    set<int> present(msk.begin(), msk.begin() + 7);
    vector<int> taste(7);
    int total = 0;
    for (int step = 0; step < 7; ++step) {
        const int cur = msk[step];
        total += taste[cur];
        const int lo = cur - d[cur];
        const int hi = cur + d[cur];
        for (int j = lo; j <= hi; ++j) {
            if (present.count(j))
                taste[j] += c[cur];
        }
    }
    return total;
}
int main() {
pre();
int t = 1;
while(t--){
vector<int> d(m), c(m);
for(int i = 0; i < m ; i++){
if(i == m - 1)
else
}
for(int i = 0 ; i < m ; i++){
if(i == m - 1)
else
}
if(m <= 6){ // brute force
vector<int> arr(m);
for(int i = 0; i < m ; i++)
arr[i] = i;
int ans = 0;
do{
vector<int> col(m);
int sweet = 0;
for(int i = 0; i < m ; i++){
sweet += col[arr[i]];
int dis = d[arr[i]];
for(int j = max(0, arr[i] - dis) ; j <= min(m - 1, arr[i] + dis) ; j++)
col[j] += c[arr[i]];
}
ans = max(ans, sweet);
}while(next_permutation(arr.begin(), arr.end()));
cout << ans << '\n';
continue;
}
for(int n = 7 ; n <= m ; n++){
for(int msk = 0 ; msk < 5040 ; msk++){
if(n == 7)
dp[n][msk] = cal(d, c, mp[msk]);
else{
vector<int> get_msk = mp[msk];
vector<int> new_msk(7, 0);
int l = 1, idx = -1;
for(int j = 0; j < 7 ; j++){
if(get_msk[j] == 6){
idx = j;
continue;
}
new_msk[l] = get_msk[j] + 1;
l++;
}
idx++;
for(int pos = 0; pos < 7 ; pos++){
int cont = (pos < idx ? c[n - 8]*(d[n - 8] == 7) : c[n - 1]*(d[n - 1] == 7));
if(idx - pos == 1)
cont = max(c[n - 8]*(d[n - 8] == 7), c[n - 1]*(d[n - 1] == 7));
dp[n][msk] = max(dp[n][msk], dp[n-1][rev_mp[new_msk]] + cont);
if(pos != 6)
swap(new_msk[pos], new_msk[pos + 1]);
}
idx--;
// Fixed part starts
int sweet = 0;
for(int j = 0; j < idx ; j++){
int dis = d[n + get_msk[j] - 7];
if(get_msk[j] + dis >= 6)
sweet += c[n + get_msk[j] - 7];
}
for(int j = idx + 1 ; j < 7 ; j++){
if(get_msk[j] + d[n - 1] >= 6)
sweet += c[n - 1];
}
// Fixed part ends
dp[n][msk] += sweet;
}
}
}
int ans = 0;
for(int msk = 0 ; msk < 5040 ; msk++){
ans = max(ans, dp[m][msk]);
}
cout << ans << '\n';
}