PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Setter: Anton Trygub
Tester: Harris Leung
Editorialist: Trung Dang
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Dynamic Programming
PROBLEM:
You are given an array A = [A_1, A_2, \ldots, A_N] containing distinct positive integers.
Let B be a permutation of A. Define the value of B to be
\sum_{i=1}^N (B_i \bmod{B_{i+1}})
where B_{N+1} is taken to be B_1 (the permutation is considered cyclically).
Find the maximum value across all permutations of A.
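A brute-force reference is useful for checking the faster solutions below on tiny inputs. The following C++ sketch evaluates the cyclic sum directly and tries all N! orders, so it is only practical for very small N. The helper names cyclicModValue and bruteForceMax are illustrative and do not come from the official solutions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Value of a cyclic arrangement B: sum of B[i] mod B[(i + 1) % n].
long long cyclicModValue(const std::vector<long long>& b) {
    assert(!b.empty());
    const std::size_t n = b.size();
    long long total = 0;
    for (std::size_t i = 0; i < n; ++i) {
        total += b[i] % b[(i + 1) % n];  // the last element wraps to b[0]
    }
    return total;
}

// Exhaustive search over all N! orders of A; only feasible for very small N.
long long bruteForceMax(std::vector<long long> a) {
    std::sort(a.begin(), a.end());  // next_permutation starts from sorted order
    long long best = 0;             // every value is non-negative
    do {
        best = std::max(best, cyclicModValue(a));
    } while (std::next_permutation(a.begin(), a.end()));
    return best;
}
```

For A = [2, 3, 5], the order [2, 3, 5] gives 2 + 3 + 1 = 6 and the other cyclic order [2, 5, 3] gives 2 + 2 + 1 = 5, so bruteForceMax returns 6.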
EXPLANATION:
For any permutation B we have B_i \bmod B_{i+1} \le B_i. If B_i < B_{i+1}, then B_i \bmod B_{i+1} = B_i; since the elements are distinct, the only other case is B_i > B_{i+1}, and then B_i \bmod B_{i+1} < B_{i+1} < B_i. We call such a pair (B_i, B_{i+1}) bad: it reduces the value of the permutation by B_i - (B_i \bmod B_{i+1}) = B_{i+1} \lfloor B_i / B_{i+1} \rfloor. Hence the value of B is the sum of all its elements minus the total reduction over its bad pairs. The sum of the elements does not depend on the order, so our goal becomes finding a permutation whose bad pairs have the smallest possible total reduction.
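The decomposition can be checked numerically. This sketch recomputes the value of an arrangement as the total sum minus the reduction of every bad pair. The helpers pairLoss and valueViaLosses are hypothetical names, and the code assumes N >= 2 and distinct elements, as in the argument above.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Reduction caused by x being followed by y: x - (x mod y).
// It is 0 when x < y; when x > y (a bad pair) it equals y * floor(x / y).
long long pairLoss(long long x, long long y) {
    return x - x % y;
}

// Value of the arrangement computed as sum(B) minus the losses of its bad pairs.
// Assumes at least two distinct elements, as in the argument above.
long long valueViaLosses(const std::vector<long long>& b) {
    assert(b.size() >= 2);
    const std::size_t n = b.size();
    long long value = std::accumulate(b.begin(), b.end(), 0LL);
    for (std::size_t i = 0; i < n; ++i) {
        const long long x = b[i];
        const long long y = b[(i + 1) % n];
        if (x > y) {
            value -= pairLoss(x, y);  // only bad pairs reduce the value
        }
    }
    return value;
}
```

For B = [2, 5, 3] the sum is 10, the bad pairs (5, 3) and (3, 2) lose 3 and 2, and valueViaLosses returns 5, matching 2 + 2 + 1 computed directly.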
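Sort A so that A_1 < A_2 < \dots < A_N, and write a bad pair by its values in sorted order as [A_x, A_y] with x < y, where A_x is the back (smaller) element and A_y is the front (larger) element. The bad pairs of any permutation satisfy these conditions:
- A_N is the front element of exactly one bad pair: since the array is cyclic, A_N is followed by some element, and that element is smaller.
- A_1 is the back element of exactly one bad pair: A_1 is preceded by some element, and that element is larger.
- There is a sequence of bad pairs [A_1, A_{R_1}], [A_{L_2}, A_{R_2}], \dots, [A_{L_k}, A_N] with L_{i+1} \le R_i for all 1 \le i < k, i.e. consecutive pairs overlap or share an endpoint. To see why, fix any m < N and split the elements into the low ones (at most A_m) and the high ones (at least A_{m+1}). Both groups are nonempty, so walking around the cycle we must step from a high element directly to a low element at least once, and that step is a bad pair [A_x, A_y] with x \le m < y. So every gap between consecutive sorted values is spanned by some bad pair: start with the pair spanning the gap just above A_1 (its back element must be A_1), then repeatedly take a pair spanning the gap just above the current right end, until the right end is A_N.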
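The covering argument behind the third condition can be sanity-checked by brute force. The sketch below tests, for every gap between consecutive sorted values, whether some bad pair of the arrangement jumps over it; running it over all permutations of a small array should never return false. badPairsCoverAllGaps is a hypothetical helper, and it assumes N >= 2 and distinct values.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// For a cyclic arrangement b (N >= 2, distinct values), check that every gap
// between consecutive sorted values a[m] < a[m + 1] is jumped over by some
// bad pair, i.e. some b[i] > b[i + 1] with b[i + 1] <= a[m] and a[m + 1] <= b[i].
bool badPairsCoverAllGaps(const std::vector<long long>& b) {
    assert(b.size() >= 2);
    std::vector<long long> a(b);
    std::sort(a.begin(), a.end());
    const std::size_t n = b.size();
    for (std::size_t m = 0; m + 1 < n; ++m) {
        bool covered = false;
        for (std::size_t i = 0; i < n && !covered; ++i) {
            const long long x = b[i];            // front element
            const long long y = b[(i + 1) % n];  // back element
            covered = x > y && y <= a[m] && a[m + 1] <= x;
        }
        if (!covered) {
            return false;
        }
    }
    return true;
}
```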
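We can also prove that from any set of bad pairs satisfying these conditions, we can create a permutation whose bad pairs are exactly those. For example, if the pairs are [A_1, A_{c_1}], [A_{c_1}, A_{c_2}], \dots, [A_{c_{k-1}}, A_N] with 1 < c_1 < \dots < c_{k-1} < N, the cyclic order A_1, (the remaining elements in increasing order), A_N, A_{c_{k-1}}, \dots, A_{c_1} has exactly these bad pairs. Therefore the problem becomes finding a set of bad pairs that satisfies the conditions with the smallest total reduction, which we can do with dynamic programming as follows: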
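Let dp_i be the smallest total reduction of a sequence of bad pairs [A_i, A_{R_1}], [A_{L_2}, A_{R_2}], \dots, [A_{L_k}, A_N] that satisfies the condition above, i.e. whose pairs together span every gap from A_i up to A_N. If the first pair is [A_i, A_j] for some j > i, the rest of the sequence starts at some A_k with i < k \le j, so
dp_i = \min_{j = i + 1}^{N} \left( \min_{k = i + 1}^{j} dp_k + (A_j - A_j \bmod A_i) \right)
where dp_N = 0, and the answer is (A_1 + A_2 + \dots + A_N) - dp_1. Maintaining the inner minimum as a running minimum while j increases, this formula is easy to implement in O(N^2).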
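A direct transcription of the recurrence, as a sketch with 0-based indices. The name maxCyclicModSum is hypothetical, and the code assumes N >= 2 and distinct values. The inner minimum over k is kept as a running minimum while j grows, which gives the O(N^2) bound; it is close in spirit to the editorialist's solution below, but uses ascending order and costs instead of negated gains.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <numeric>
#include <vector>

// Sketch of the editorial recurrence (0-based, A sorted increasingly):
//   dp[i] = min over j > i of ( min(dp[i+1..j]) + (A[j] - A[j] % A[i]) ),
//   dp[n-1] = 0, answer = sum(A) - dp[0].
// Assumes N >= 2 and distinct values.
long long maxCyclicModSum(std::vector<long long> a) {
    assert(a.size() >= 2);
    std::sort(a.begin(), a.end());
    const int n = static_cast<int>(a.size());
    const long long total = std::accumulate(a.begin(), a.end(), 0LL);
    const long long inf = std::numeric_limits<long long>::max();
    std::vector<long long> dp(n, inf);
    dp[n - 1] = 0;
    for (int i = n - 2; i >= 0; --i) {
        long long runningMin = inf;  // min(dp[i+1..j]) as j grows
        for (int j = i + 1; j < n; ++j) {
            runningMin = std::min(runningMin, dp[j]);
            const long long loss = a[j] - a[j] % a[i];  // bad pair [A_i, A_j]
            dp[i] = std::min(dp[i], runningMin + loss);
        }
    }
    return total - dp[0];
}
```

For the sorted array [2, 3, 5] the table is dp = [4, 3, 0], so maxCyclicModSum({2, 3, 5}) returns 10 - 4 = 6.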
TIME COMPLEXITY:
Time complexity is O(N^2) per test case.
SOLUTION:
Preparer's Solution
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<ll> a(n);
        for (auto& x : a) cin >> x;
        sort(begin(a), end(a));
        ll ans = accumulate(begin(a), end(a), 0LL);
        // O(N^2) Dijkstra on a complete graph over the sorted indices:
        // going from u down to j < u uses the bad pair [a[j], a[u]] and costs
        // a[u] - a[u] % a[j]; going up (j > u) is free. The shortest path from
        // the largest element (index n-1) to the smallest (index 0) is the
        // minimum total reduction.
        vector<ll> dist(n, numeric_limits<ll>::max()/3);
        dist[n-1] = 0;
        vector<int> mark(n);  // mark[j] = 1 once dist[j] is final
        for (int i = 0; i < n; ++i) {
            // pick the unfinalized vertex with the smallest tentative distance
            int u = -1;
            ll mndist = numeric_limits<ll>::max();
            for (int j = 0; j < n; ++j) {
                if (mark[j]) continue;
                if (dist[j] < mndist) {
                    mndist = dist[j];
                    u = j;
                }
            }
            mark[u] = 1;
            // relax every edge leaving u
            for (int j = 0; j < n; ++j) {
                if (mark[j]) continue;
                ll wt = (j < u ? a[u] - (a[u] % a[j]) : 0);
                if (dist[j] > mndist + wt) {
                    dist[j] = mndist + wt;
                }
            }
        }
        cout << ans - dist[0] << '\n';
    }
}
Tester's Solution
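#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
ll n;
ll a[2001];
ll dp[2001];
void solve(){
    cin >> n;
    ll s=0;
    for(int i=1; i<=n ;i++){
        cin >> a[i];dp[i]=8e18;
        s+=a[i];
    }
    // sort in decreasing order, so a[1] is the largest element
    sort(a+1,a+n+1);reverse(a+1,a+n+1);
    dp[1]=0;
    for(int i=2; i<=n ;i++){
        // add the bad pair [a[i], a[j]] (a[j] > a[i]) after a chain reaching index j
        for(int j=1; j<i ;j++) dp[i]=min(dp[i],dp[j]+a[j]-a[j]%a[i]);
        // suffix minimum: afterwards dp[j] = min(dp[j..i]) for every j <= i,
        // which is the inner minimum over k in the recurrence
        for(int j=i; j>1 ;j--) dp[j-1]=min(dp[j-1],dp[j]);
    }
    cout << s-dp[n] << '\n';
}
int main(){
    ios::sync_with_stdio(false);
    int t;cin >> t;while(t--) solve();
}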
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    const long long inf = numeric_limits<long long>::max() / 2;
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        // dp[i] = -(minimum total reduction of a chain from a[0] down to a[i])
        vector<long long> a(n), dp(n, -inf);
        for (int i = 0; i < n; i++) {
            cin >> a[i];
        }
        sort(a.begin(), a.end(), greater<long long>());
        dp[0] = 0;
        for (int i = 1; i < n; i++) {
            long long mp = -inf;  // max of dp[j..i-1]
            for (int j = i - 1; j >= 0; j--) {
                mp = max(mp, dp[j]);
                // the bad pair [a[i], a[j]] loses a[j] - a[j] % a[i]
                dp[i] = max(dp[i], mp + a[j] % a[i] - a[j]);
            }
        }
        cout << accumulate(a.begin(), a.end(), 0LL) + dp[n - 1] << '\n';
    }
}