MODCIRC - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Anton Trygub
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Dynamic Programming

PROBLEM:

You are given an array A = [A_1, A_2, \ldots, A_N] containing distinct positive integers.

Let B be a permutation of A. Define the value of B to be

\sum_{i=1}^N (B_i \bmod{B_{i+1}})

where B_{N+1} is taken to be B_1.

Find the maximum value across all permutations of A.

EXPLANATION:

For any permutation B we have B_i \bmod B_{i + 1} \le B_i. Since the elements are distinct, B_i \bmod B_{i + 1} < B_i if and only if B_i > B_{i + 1}: when B_i < B_{i + 1} the remainder is B_i itself, and when B_i > B_{i + 1} it is smaller than B_{i + 1} < B_i. We call a pair (B_i, B_{i + 1}) with B_i > B_{i + 1} bad. It reduces the value of the permutation by B_i - (B_i \bmod B_{i + 1}), which we call the cost of the pair. Hence the value of B is the sum of all elements of B minus the total cost of all bad pairs in B, and our goal becomes finding a permutation whose bad pairs have the smallest possible total cost.
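To make the decomposition concrete, here is a small C++ sketch that computes the value of one arrangement in two ways: directly from the definition, and as the sum of the elements minus the cost of every bad pair. The function names permutationValue and valueViaBadPairs are hypothetical, and the code assumes N \ge 2 with distinct positive values, as the argument above does.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using ll = long long;

// Value straight from the definition: sum of b[i] % b[i+1], wrapping around.
// Assumes at least two distinct positive values.
ll permutationValue(const std::vector<ll>& b) {
    const std::size_t n = b.size();
    assert(n >= 2);
    ll total = 0;
    for (std::size_t i = 0; i < n; ++i) total += b[i] % b[(i + 1) % n];
    return total;
}

// Same value as (sum of elements) - (cost of every bad pair), where the pair
// (b[i], b[i+1]) is bad when b[i] > b[i+1] and costs b[i] - b[i] % b[i+1].
ll valueViaBadPairs(const std::vector<ll>& b) {
    const std::size_t n = b.size();
    assert(n >= 2);
    ll sum = 0, cost = 0;
    for (std::size_t i = 0; i < n; ++i) {
        const ll next = b[(i + 1) % n];
        sum += b[i];
        if (b[i] > next) cost += b[i] - b[i] % next;
    }
    return sum - cost;
}
```

For B = [3, 5, 2, 7] both functions give 7: the sum is 17, and the bad pairs (5, 2) and (7, 3) cost 4 and 6.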

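Sort A so that A_1 < A_2 < \dots < A_N, and write [A_l, A_r] (with l < r) for a bad pair whose front element B_i = A_r is followed by the back element B_{i + 1} = A_l. The bad pairs of any permutation satisfy these conditions:

  • A_N is the front element of exactly one bad pair: the array is cyclic, so A_N is always followed by some element, and that element is smaller.
  • A_1 is the back element of exactly one bad pair, since it is always preceded by some larger element.
  • There is a chain of bad pairs [A_1, A_{R_1}], [A_{L_2}, A_{R_2}], \dots, [A_{L_k}, A_N] with L_{i + 1} \le R_i for all 1 \le i < k. To see this, fix any m with 1 \le m < N. The cycle contains elements at most A_m (such as A_1) and elements greater than A_m (such as A_N), so at some point it must step from an element greater than A_m down to an element at most A_m; that step is a bad pair [A_l, A_r] (back element A_l, front element A_r) with l \le m < r. Hence every gap between consecutive sorted elements is covered by some bad pair. Starting from the pair whose back element is A_1, we can repeatedly take a pair covering the gap right after the current front element A_{R_i}: its back element satisfies L_{i + 1} \le R_i and its front element lies beyond A_{R_i}, so after finitely many steps we reach A_N.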
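The covering argument in the last condition can be checked mechanically. The hypothetical helper below takes one arrangement B (read cyclically) and verifies that for every m it contains a bad pair [A_l, A_r] with l \le m < r. It is a brute-force illustration assuming distinct values, not part of any of the solutions below.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using ll = long long;

// Hypothetical helper: for one arrangement b (read cyclically), checks that
// for every gap s[m] < s[m+1] of the sorted values there is a step that goes
// from a value >= s[m+1] down to a value <= s[m]. Such a step is a bad pair
// whose interval [back, front] contains the gap.
bool everyGapCrossedByBadPair(const std::vector<ll>& b) {
    std::vector<ll> s(b);
    std::sort(s.begin(), s.end());
    assert(std::adjacent_find(s.begin(), s.end()) == s.end());  // values must be distinct
    const std::size_t n = b.size();
    for (std::size_t m = 0; m + 1 < n; ++m) {
        bool crossed = false;
        for (std::size_t i = 0; i < n && !crossed; ++i) {
            const ll front = b[i];
            const ll back = b[(i + 1) % n];
            crossed = front >= s[m + 1] && back <= s[m];
        }
        if (!crossed) return false;
    }
    return true;
}
```

Running it over every permutation of a small array (for example with std::next_permutation) should never return false, which is exactly the claim that some chain of bad pairs reaches from A_1 to A_N.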

We can also prove that from any set of bad pairs satisfying these conditions we can build a permutation whose only bad pairs are those pairs. Therefore the problem becomes finding such a set of bad pairs whose total reduction of the value is as small as possible, which we can do with dynamic programming as follows:
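One way to sanity-check this reduction on small inputs is an exhaustive search: the best value over all N! orders should equal the sum of A minus the cheapest valid set of bad pairs. The sketch below (hypothetical name bruteForceMax) simply tries every order, so it is only a testing aid, assuming distinct positive values and N small enough to enumerate.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using ll = long long;

// Exhaustive reference for small N (N! orders): returns the largest sum of
// B_i mod B_{i+1} with B_{N+1} = B_1 over all permutations B of a.
// Assumes distinct positive values; only practical for N up to about 9.
ll bruteForceMax(std::vector<ll> a) {
    assert(!a.empty());
    std::sort(a.begin(), a.end());  // next_permutation must start from the sorted order
    ll best = -1;
    do {
        ll value = 0;
        for (std::size_t i = 0; i < a.size(); ++i)
            value += a[i] % a[(i + 1) % a.size()];
        best = std::max(best, value);
    } while (std::next_permutation(a.begin(), a.end()));
    return best;
}
```

For A = [2, 3, 5, 7] the search returns 11, reached by the increasing order, whose only bad pair is (7, 2) with cost 6.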

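Let dp_i be the smallest total cost of a chain of bad pairs [A_i, A_{R_1}], [A_{L_2}, A_{R_2}], \dots, [A_{L_k}, A_N] satisfying the condition above (that is, a chain that starts at A_i instead of A_1), where a pair [A_l, A_r] costs A_r - (A_r \bmod A_l). Choosing the first pair [A_i, A_j], the rest of the chain must start at some A_k with i < k \le j, so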

dp_i = \min_{j = i + 1}^{N} \left( \min_{k = i + 1}^{j} dp_k + \left( A_j - A_j \bmod A_i \right) \right)

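with dp_N = 0, and the answer is A_1 + A_2 + \dots + A_N - dp_1. If we maintain the inner minimum of dp_k as a running value while j increases, each dp_i takes O(N) time, so the whole DP runs in O(N^2).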
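A direct transcription of the recurrence, as a sketch assuming N \ge 2 and distinct positive values. Indices are 0-based, so dp[N-1] here is dp_N and dp[0] is dp_1; the function name maxCircularModSum is hypothetical. The inner loop keeps the running minimum of dp_k for i < k \le j, which is what makes each dp_i cost O(N).

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <numeric>
#include <vector>

using ll = long long;

// Sketch of the O(N^2) DP: dp[i] = cheapest chain of bad pairs from a[i] up to
// a[n-1], where the pair [a[i], a[j]] costs a[j] - a[j] % a[i].
ll maxCircularModSum(std::vector<ll> a) {
    const int n = static_cast<int>(a.size());
    assert(n >= 2);
    std::sort(a.begin(), a.end());  // a[0] < a[1] < ... < a[n-1]
    const ll inf = std::numeric_limits<ll>::max() / 4;
    std::vector<ll> dp(n, inf);
    dp[n - 1] = 0;  // the chain has already reached the largest element
    for (int i = n - 2; i >= 0; --i) {
        ll bestNext = inf;  // min of dp[k] over i < k <= j
        for (int j = i + 1; j < n; ++j) {
            bestNext = std::min(bestNext, dp[j]);
            const ll cost = a[j] - a[j] % a[i];  // cost of the bad pair [a[i], a[j]]
            dp[i] = std::min(dp[i], bestNext + cost);
        }
    }
    return std::accumulate(a.begin(), a.end(), 0LL) - dp[0];
}
```

On A = [10, 19, 37] it returns 37: the chain [A_1, A_2], [A_2, A_3] costs 10 + 19 = 29, which beats the single pair [A_1, A_3] at cost 30, and 66 - 29 = 37 matches an exhaustive search over both cyclic orders.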

TIME COMPLEXITY:

Time complexity is O(N^2) per test case.

SOLUTION:

Preparer's Solution
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		vector<ll> a(n);
		for (auto& x : a) cin >> x;
		sort(begin(a), end(a));

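		ll ans = accumulate(begin(a), end(a), 0LL);
		// Shortest-path view: dist[v] is the cheapest total cost of bad pairs
		// needed to get from the largest value a[n-1] to a[v].
		// Dense O(N^2) Dijkstra without a heap.
		vector<ll> dist(n, numeric_limits<ll>::max()/3);
		dist[n-1] = 0;
		vector<int> mark(n);
		for (int i = 0; i < n; ++i) {
			// take the unvisited vertex with the smallest tentative distance
			int u = -1;
			ll mndist = numeric_limits<ll>::max();
			for (int j = 0; j < n; ++j) {
				if (mark[j]) continue;
				if (dist[j] < mndist) {
					mndist = dist[j];
					u = j;
				}
			}
			mark[u] = 1;
			for (int j = 0; j < n; ++j) {
				if (mark[j]) continue;
				// stepping down to a smaller a[j] pays the cost of the bad pair
				// (a[u], a[j]); moving to a larger value costs nothing
				ll wt = (j < u ? a[u] - (a[u] % a[j]) : 0);
				if (dist[j] > mndist + wt) {
					dist[j] = mndist + wt;
				}
			}
		}
		// answer = sum of all elements minus the cheapest chain down to a[0]
		cout << ans - dist[0] << '\n';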
	}
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
ll n;
ll a[2001];
ll dp[2001];
void solve(){
	cin >> n;
	ll s=0;
	for(int i=1; i<=n ;i++){
		cin >> a[i];dp[i]=8e18;
		s+=a[i];
	}
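	sort(a+1,a+n+1);reverse(a+1,a+n+1); // a[1] > a[2] > ... > a[n]
	dp[1]=0; // the chain starts at the largest element at no cost
	for(int i=2; i<=n ;i++){
		// here dp[j] (j < i) is the minimum of the computed dp values over positions j..i-1
		for(int j=1; j<i ;j++) dp[i]=min(dp[i],dp[j]+a[j]-a[j]%a[i]);
		// fold dp[i] into those running (suffix) minima
		for(int j=i; j>1 ;j--) dp[j-1]=min(dp[j-1],dp[j]);
	}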
	cout << s-dp[n] << '\n';
}
int main(){
	ios::sync_with_stdio(false);
	int t;cin >> t;while(t--) solve();
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    const long long inf = numeric_limits<long long>::max() / 2;
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<long long> a(n), dp(n, -inf);
        for (int i = 0; i < n; i++) {
            cin >> a[i];
        }
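        sort(a.begin(), a.end(), greater<long long>()); // a[0] is the largest value
        // dp[i] stores minus the cheapest chain cost from a[0] down to a[i]
        dp[0] = 0;
        for (int i = 1; i < n; i++) {
            long long mp = -inf; // best dp over positions j..i-1
            for (int j = i - 1; j >= 0; j--) {
                mp = max(mp, dp[j]);
                // the bad pair (a[j], a[i]) loses a[j] - a[j] % a[i]
                dp[i] = max(dp[i], mp + a[j] % a[i] - a[j]);
            }
        }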
        cout << accumulate(a.begin(), a.end(), 0LL) + dp[n - 1] << '\n';
    }
}