SWTCHNG - Editorial

Links:
Division 1
Division 2
Division 3

Setters: Harsh Sharma, Jaydeep Macchi, Ronels Macwan

Testers: Manan Grover, Samarth Gupta

Difficulty
Medium - Hard

Prerequisites
Dynamic programming, Bitmasks

Problem
You have to eat the rasgullas in such an order that you maximize the sum of the tastes of all the rasgullas that you eat. Each rasgulla has 0 taste initially. When you eat the rasgulla at coordinate $i$, the taste of every rasgulla in the range $[i-d[i], i+d[i]]$ increases by $c[i]$. A rasgulla's taste is counted at the moment it is eaten, so only rasgullas eaten later benefit from it. What is the maximum possible sum of tastes you can achieve?

Quick Explanation
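Notice that $d[i]$ is small ($d[i] \le 7$), so a rasgulla only interacts with rasgullas at distance at most 7 from it. Therefore, while processing the rasgullas from left to right, we only need to remember the relative eating order of the last 7 of them, and we can brute-force over all $7!$ such orders. Form a two-dimensional dp, where $dp[i][mask]$ is the maximum taste we can get from the first $i$ rasgullas when the relative eating order of the last 7 of them is $mask$.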

Explanation :
It is always optimal to eat all the rasgullas but the problem is in which order Vedant should eat them so that happiness would be maximized.

If n < 7, then we can do brute-force on all possible n! permutations and calculate the maximum possible happiness.

Consider the case $n \ge 7$. Since $d[i] \le 7$, a rasgulla at index $i$ can interact with at most 7 rasgullas to its left. Hence the only thing that matters for index $i$ is the relative order of the last 7 elements to its left and its position with respect to these 7 elements (out of 8 possible positions).

Let $dp[i][mask]$ denote the maximum possible happiness considering the first $i$ elements, where the relative eating order of the last 7 elements is the one encoded by $mask$. (Use any mapping between masks and the $7!$ permutations, e.g. a hash map from a permutation to its index.)

Let’s say we are computing the answer for the first $i+1$ elements from the answer for the first $i$ elements. Iterate through all possible masks, which denote the relative eating order of the rasgullas in the range $[i-6, i]$. There are a total of 8 possibilities for when to eat the $(i+1)^{th}$ rasgulla with respect to these 7 rasgullas: before all of them, between two consecutive ones, or after all of them. Iterate through all such positions for the $(i+1)^{th}$ rasgulla.

Now, $dp[i][mask]$ already holds the answer for the first $i$ elements. We need to add the taste contributed when the $(i+1)^{th}$ element is inserted at a given position in the order; denote this value by $add(mask, position)$. Since $d[i] \le 7$, computing it only requires the element's relative order with respect to the elements $[i-6, i]$. Once we fix the relative position of the $(i+1)^{th}$ rasgulla, we know which rasgullas are eaten before and after it. Each rasgulla eaten before it that has it within reach adds its own $c$ to it. It, in turn, adds $c[i+1]$ to each rasgulla eaten after it that lies within its reach. The sum of these amounts is $add(mask, position)$. Fixing the relative position of the $(i+1)^{th}$ rasgulla also fixes the relative order of the rasgullas $[i-5, i+1]$. Compute the mask for that order and call it $newMask(mask, position)$.

Hence $dp[i+1][newMask(mask, position)] = \max(dp[i][mask] + add(mask, position))$, where the maximum is taken over all pairs $(mask, position)$ that produce this $newMask$, and $add(mask, position)$ is the value added by placing the $(i+1)^{th}$ rasgulla at that position.

Final answer would be :
answer = max(answer, dp[n][mask]) , for all possible 7! masks.

Solutions

Setter’s Code :

#include<bits/stdc++.h>
using namespace std;

#define int long long

const int N = 105;     // capacity for the rasgulla arrays (main asserts n <= 100)
const int M = 5140;    // capacity for permutation indices (must be >= 7! = 5040)
const int Dmax = 7;    // maximum allowed d[i]; also the size of the dp window

// Note: `int` is `long long` here because of the #define above.
int n, m, d[N], c[N];  // n = rasgulla count, m = number of window permutations, d/c = input
int dp[N][M];          // dp[p][idx] = best taste over rasgullas 0..p with last-window order allPermutations[idx]
int lastCnt;           // window size = min(Dmax, n)
vector<int> permutation;                  // scratch copy of the permutation being processed
vector<vector<int>> allPermutations;      // every order of 0..lastCnt-1, in lexicographic order
unordered_map<int, int> numToIdx;         // permutationIdx(order) -> its index in allPermutations

// Packs a sequence into a single integer key by reading its elements as the
// decimal digits of a number (first element = most significant digit).
// The key is unique for same-length sequences whose entries are 0..9; here it
// is only used for window orders whose entries are offsets 0..Dmax-1.
int permutationIdx(vector<int> &v)
{
	int key = 0;
	for (const int digit : v)
		key = key * 10 + digit;
	return key;
}

/*
 * Setter's entry point. Reads n, then d[0..n-1], then c[0..n-1], and prints
 * the maximum total taste over all eating orders.
 *
 * dp[p][idx] holds the best total taste over rasgullas 0..p. It assumes the
 * relative eating order of the last lastCnt of them (p-lastCnt+1 .. p) is
 * allPermutations[idx]: allPermutations[idx][k] = w means the k-th of those
 * rasgullas to be eaten is rasgulla p-lastCnt+1+w.
 *
 * Since d[i] <= Dmax, a rasgulla only interacts with rasgullas at distance
 * <= Dmax. So when rasgulla p is appended, only its rank among the previous
 * lastCnt rasgullas matters.
 */
signed main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);  cout.tie(NULL);

	cin >> n;
	assert(n >= 1 && n <= 100);

	for (int i = 0; i < n; i++)
	{
		cin >> d[i];
		assert(d[i] >= 0 && d[i] <= Dmax);
	}

	for (int i = 0; i < n; i++)
	{
		cin >> c[i];
		assert(c[i] >= 0 && c[i] <= 1000);
	}

	// Window size; with fewer than Dmax rasgullas the window is the whole row.
	lastCnt = min(Dmax, n);

	// Enumerate every order of the window (lexicographically) and index it.
	vector<int> currPermutation;
	for (int i = 0; i < lastCnt; i++)
	{
		currPermutation.push_back(i);
	}

	do
	{
		numToIdx[permutationIdx(currPermutation)] = allPermutations.size();
		allPermutations.push_back(currPermutation);
	} while (next_permutation(currPermutation.begin(), currPermutation.end()));

	m = allPermutations.size();

	// Base case: score every order of the first lastCnt rasgullas directly.
	for (int idx = 0; idx < m; idx++)
	{
		const vector<int> &perm = allPermutations[idx];

		// perm[x] is eaten before perm[y]; it adds c[perm[x]] to perm[y]
		// when perm[y] lies within its reach d[perm[x]].
		for (int x = 0; x < lastCnt; x++)
		{
			for (int y = x + 1; y < lastCnt; y++)
			{
				if (abs(perm[x] - perm[y]) <= d[perm[x]])
				{
					dp[lastCnt - 1][idx] += c[perm[x]];
				}
			}
		}
	}

	// Scratch buffers reused by every transition. `order` replaces the former
	// variable-length array `int order[lastCnt + 1]`, which is not standard C++.
	vector<int> order(lastCnt + 1);
	vector<int> nextState;
	nextState.reserve(lastCnt + 1);

	// Transition: append rasgulla `position` at every rank within the window.
	for (int position = lastCnt; position < n; position++)
	{
		for (int idx = 0; idx < m; idx++)
		{
			const vector<int> &perm = allPermutations[idx];
			for (int myRank = 0; myRank <= lastCnt; myRank++)
			{
				int curr = dp[position - 1][idx];

				// order[k] = absolute index of the k-th rasgulla eaten among
				// position-lastCnt .. position, with `position` at rank myRank.
				fill(order.begin(), order.end(), -1);
				order[myRank] = position;

				int cnt = 0;
				for (int x = 0; x <= lastCnt; x++)
				{
					if (order[x] == -1)
					{
						order[x] = position - (lastCnt - perm[cnt++]);
					}
				}

				// Rasgullas eaten earlier add their c to `position` if it is in their reach.
				for (int x = 0; x < myRank; x++)
				{
					if (abs(order[x] - order[myRank]) <= d[order[x]])
					{
						curr += c[order[x]];
					}
				}

				// `position` adds c[position] to every later-eaten rasgulla in its reach.
				for (int x = myRank + 1; x <= lastCnt; x++)
				{
					if (abs(order[x] - order[myRank]) <= d[order[myRank]])
					{
						curr += c[order[myRank]];
					}
				}

				// The new window is position-lastCnt+1 .. position: drop the oldest
				// rasgulla and re-express the others as offsets into that window.
				nextState.clear();
				for (int x = 0; x <= lastCnt; x++)
				{
					if (position - order[x] > lastCnt - 1) continue;
					nextState.push_back(lastCnt - 1 - (position - order[x]));
				}

				// at() instead of operator[]: a missing key is a logic error and must
				// not be hidden by silently inserting index 0.
				int nextIdx = numToIdx.at(permutationIdx(nextState));
				dp[position][nextIdx] = max(dp[position][nextIdx], curr);
			}
		}
	}

	// Any final window order may be optimal.
	int answer = 0;
	for (int idx = 0; idx < m; idx++)
	{
		answer = max(answer, dp[n - 1][idx]);
	}

	cout << answer << '\n';

	return 0;
}

Tester’s Code :

#include <bits/stdc++.h>
using namespace std;

// Reads one integer token from stdin that must be terminated by `endd`, and
// validates it strictly via assert. It accepts an optional single leading '-'
// and at least one digit. It rejects leading zeros, "-0", more than 19 digits
// (or 19 digits starting above 1), and values outside [l, r]. Any other
// character, including EOF, fails an assert.
long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;          // number of digits read so far
    int fi=-1;          // first digit read, or -1 before any digit
    bool is_neg=false;
    while(true) {
        // Keep getchar()'s int result so EOF stays distinct from real bytes.
        int g=getchar();
        if(g=='-') {
            assert(fi==-1);
            assert(!is_neg);            // reject "--5"
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            // Check the length BEFORE accumulating so x*10 can never overflow.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+(g-'0');
        } else if(g==endd) {
            assert(cnt>0);              // reject empty tokens and a lone "-"
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the delimiter `endd` and
// returns them. Asserts that input does not end before the delimiter and that
// the token length is within [l, r].
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        // getchar() returns int: storing it in a char would make EOF (-1)
        // undetectable where char is unsigned, turning EOF into an endless loop.
        int g=getchar();
        assert(g!=EOF);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
// Fixed-delimiter wrappers around readInt / readString.
// "Sp" variants expect the token to be followed by a space, "Ln" variants by a newline.
long long readIntSp(long long l, long long r)
{
    const char delimiter = ' ';
    return readInt(l, r, delimiter);
}
long long readIntLn(long long l, long long r)
{
    const char delimiter = '\n';
    return readInt(l, r, delimiter);
}
string readStringLn(int l, int r)
{
    const char delimiter = '\n';
    return readString(l, r, delimiter);
}
string readStringSp(int l, int r)
{
    const char delimiter = ' ';
    return readString(l, r, delimiter);
}
 
// Asserts that stdin has been fully consumed (no trailing characters remain).
void readEOF(){
    // Read outside of assert(): with NDEBUG the whole assert expression,
    // including the getchar() call, would otherwise be compiled away.
    const int next=getchar();
    assert(next==EOF);
    (void)next;   // avoid an unused-variable warning when NDEBUG is defined
}
// dp[n][msk]: best total taste over the first n rasgullas, given that the relative
// eating order of the last 7 (rasgullas n-7 .. n-1) is the permutation mp[msk].
int dp[101][5040];
// mp[k]: k-th permutation of 0..6 in lexicographic order (filled by pre()).
vector<int> mp[5040];
// rev_mp: inverse of mp, mapping a permutation back to its index.
map<vector<int>, int> rev_mp;
// Fills mp (index -> permutation) and rev_mp (permutation -> index) with all
// 7! permutations of 0..6, numbered in lexicographic order starting at 0.
void pre(){
    vector<int> perm(7);
    iota(perm.begin(), perm.end(), 0);
    int id = 0;
    do {
        mp[id] = perm;
        rev_mp[perm] = id;
        ++id;
    } while (next_permutation(perm.begin(), perm.end()));
}
// Total taste collected when rasgullas 0..6 are eaten in the order msk[0..6].
// Eating rasgulla `who` adds c[who] to every rasgulla 0..6 within distance
// d[who]; a rasgulla's accumulated taste is collected at the moment it is eaten.
int cal(vector<int> &d, vector<int> &c, vector<int> msk){
    vector<int> gained(7, 0);
    int total = 0;
    for (int step = 0; step < 7; ++step) {
        const int who = msk[step];
        total += gained[who];
        // Only rasgullas 0..6 are tracked, so clamp the reach to that range.
        const int lo = max(0, who - d[who]);
        const int hi = min(6, who + d[who]);
        for (int j = lo; j <= hi; ++j)
            gained[j] += c[who];
    }
    return total;
}
// Tester's entry point (validator + reference solution).
// Reads m, then m values of d, then m values of c, with strict formatting
// checked by the read* helpers. Prints the maximum total taste and then
// asserts that no input remains.
int main() {
	// Build the permutation <-> index tables (mp / rev_mp) used by the DP.
	pre();
	int t = 1;
	while(t--){
	    int m = readIntLn(1, 100);
	    vector<int> d(m), c(m);
	    // d values: space-separated on one line, the last one followed by a newline.
	    for(int i = 0; i < m ; i++){
	        if(i == m - 1)
	            d[i] = readIntLn(0, 7);
	        else
	            d[i] = readIntSp(0, 7);
	    }
	    // c values: same layout as d.
	    for(int i = 0 ; i < m ; i++){
	        if(i == m - 1)
	            c[i] = readIntLn(0, 1000);
	        else
	            c[i] = readIntSp(0, 1000);
	    }
	    // Recurrence: dp[i][msk] = max over (nmsk, pos) of dp[i-1][nmsk] + add(nmsk, pos).
	    if(m <= 6){ // brute force
	        // Try every eating order; col[j] is the taste rasgulla j has accumulated so far.
	        vector<int> arr(m);
	        for(int i = 0; i < m ; i++)
	            arr[i] = i;
	        int ans = 0;
	        do{
	            vector<int> col(m);
	            int sweet = 0;
	            for(int i = 0; i < m ; i++){
	                sweet += col[arr[i]];
	                int dis = d[arr[i]];
	                for(int j = max(0, arr[i] - dis) ; j <= min(m - 1, arr[i] + dis) ; j++)
	                    col[j] += c[arr[i]];
	            }
	            ans = max(ans, sweet);
	        }while(next_permutation(arr.begin(), arr.end()));
	        cout << ans << '\n';
	        continue;
	    }
	    // n = number of rasgullas processed; msk indexes the eating order of
	    // rasgullas n-7 .. n-1 (value v in mp[msk] stands for rasgulla n-7+v).
	    for(int n = 7 ; n <= m ; n++){
    	    for(int msk = 0 ; msk < 5040 ; msk++){
    	        // Base case: score each order of the first 7 rasgullas directly.
    	        if(n == 7)
    	            dp[n][msk] = cal(d, c, mp[msk]);
    	        else{
    	            vector<int> get_msk = mp[msk];
    	            // Build new_msk, an order of rasgullas n-8 .. n-2 (value v stands
    	            // for rasgulla n-8+v) consistent with msk. Rasgulla n-8 (value 0)
    	            // starts in slot 0, and the shared rasgullas n-7 .. n-2 keep their
    	            // order, shifted up by one. idx records where rasgulla n-1
    	            // (value 6) sits in get_msk.
    	            vector<int> new_msk(7, 0);
    	            int l = 1, idx = -1;
    	            for(int j = 0; j < 7 ; j++){
    	                if(get_msk[j] == 6){
    	                    idx = j;
    	                    continue;
    	                }
    	                new_msk[l] = get_msk[j] + 1;
    	                l++;
    	            }
    	            idx++;
    	            // pos = number of shared rasgullas eaten before n-8. The swap at the
    	            // end of each iteration moves n-8 one slot later in new_msk.
    	            for(int pos = 0; pos < 7 ; pos++){
    	                // Pair (n-8, n-1), at distance 7: whichever is eaten first
    	                // contributes its c, but only if its d equals 7.
    	                int cont = (pos < idx ? c[n - 8]*(d[n - 8] == 7) : c[n - 1]*(d[n - 1] == 7));
    	                // n-8 and n-1 sit between the same shared rasgullas, so either
    	                // relative order is achievable; take the better one.
    	                if(idx - pos == 1)
    	                    cont = max(c[n - 8]*(d[n - 8] == 7), c[n - 1]*(d[n - 1] == 7));
    	                dp[n][msk] = max(dp[n][msk], dp[n-1][rev_mp[new_msk]] + cont);
    	                if(pos != 6)
    	                    swap(new_msk[pos], new_msk[pos + 1]);
    	            }
    	            idx--;
    	            // Fixed part starts
    	            // Interactions between rasgulla n-1 and the shared rasgullas n-7 .. n-2
    	            // depend on msk alone. Earlier-eaten ones add their c if their reach
    	            // covers n-1. n-1 adds c[n-1] to each later-eaten one it reaches.
    	            int sweet = 0;
    	            for(int j = 0; j < idx ; j++){
    	                int dis = d[n + get_msk[j] - 7];
    	                if(get_msk[j] + dis >= 6)
    	                    sweet += c[n + get_msk[j] - 7];
    	            }
    	            for(int j = idx + 1 ; j < 7 ; j++){
    	                if(get_msk[j] + d[n - 1] >= 6)
    	                    sweet += c[n - 1];
    	            }
    	            // Fixed part ends
    	            dp[n][msk] += sweet;
    	        }
    	    }
	    }
	    // Any final order of the last 7 rasgullas may be optimal.
	    int ans = 0;
	    for(int msk = 0 ; msk < 5040 ; msk++){
	        ans = max(ans, dp[m][msk]);
	    }
	    cout << ans << '\n';
	}
	readEOF();
}

Time Complexity

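Let $D = 7$ (the maximum possible value of $d_i$). There are $n$ positions, $D!$ masks per position, and $D+1$ insertion positions per mask, and each transition takes $O(D)$ work.
Hence the time complexity is $O(n \cdot D! \cdot D^{2})$.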

1 Like