P149 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming

PROBLEM:

You are given an array A of length N.
You start at index 1.
Each second, the following happens:

  • If you’re standing at index x, you can either move to index x+1, or choose an index i such that i \leq x, and set A_i = 0.
  • Then, add sum(A) to your score.

Find the minimum possible final score.
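
For very small arrays, the statement can be checked directly with an exhaustive search. The sketch below is only an illustration and is not taken from any of the reference solutions: the name `brute_force_min_score` is hypothetical, positions are 0-based, and it assumes that the process stops once every entry is zero and that all values are positive (as in the tester's constraints).

```python
from functools import lru_cache

def brute_force_min_score(a):
    """Exhaustive search over (position, set of non-zero indices). Small n only."""
    n = len(a)

    def total(mask):
        # Sum of the values whose bit is still set in `mask`.
        return sum(a[i] for i in range(n) if (mask >> i) & 1)

    @lru_cache(maxsize=None)
    def best(x, alive):
        # x: current 0-based position; alive: bitmask of indices not yet zeroed.
        if alive == 0:
            return 0  # assumption: nothing more is added once the array is all zero
        options = []
        if x + 1 < n:
            # Move right: the array is unchanged, so the whole remaining sum is added.
            options.append(total(alive) + best(x + 1, alive))
        for i in range(x + 1):
            if (alive >> i) & 1:
                # Zero out index i <= x, then add the sum of what is left.
                nxt = alive ^ (1 << i)
                options.append(total(nxt) + best(x, nxt))
        return min(options)

    start = sum(1 << i for i in range(n) if a[i] > 0)
    return best(0, start)
```

For example, with A = [1, 10] the best plan under these assumptions is to move first, then zero the 10, then zero the 1, for a score of 11 + 1 + 0 = 12.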

EXPLANATION:

Consider some index i.
A_i is added to the score once for every second in which one of the following holds:

  • The magician hasn’t yet reached index i, or
  • The magician is moving beyond i, but hasn’t yet set A_i to 0, or
  • The magician sets some A_j to 0, but hasn’t yet set A_i to 0.

The first one among these is trivial: the magician needs i-1 steps to reach index i for the first time.
So, we add (i-1)\cdot A_i to the answer for every i.
Now, we only need to worry about what happens once the magician is already at index i or beyond.
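
As a small illustration of this fixed part of the answer, here is a sketch with a 0-indexed list; the helper name `arrival_cost` is hypothetical and not part of the reference solutions.

```python
def arrival_cost(a):
    # With 0-based indexing, the 1-based factor (i - 1) is simply the list index.
    return sum(idx * value for idx, value in enumerate(a))
```

For A = [1, 10] this gives 0*1 + 1*10 = 10.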

We also make the following observation: whenever the magician chooses to set some A_i to 0, it’s always optimal to pick the largest remaining value among the indices reached so far (if there are multiple maximums, let’s break ties by choosing the leftmost among them).


Let dp_i denote the answer for the suffix starting from index i.
To compute this, let’s try fixing k \geq i to be the index where the magician sets A_i = 0.
Then,

  • A cost of (k-i)\cdot A_i is incurred by the movement.
  • For every index j such that i \lt j \leq k and A_j \gt A_i, it’s optimal to delete this index before we delete A_i.
    So, each such index “delays” the deletion by one step.
    If there are x such indices, we incur an additional cost of x\cdot A_i.
    • x is easily maintained by just processing k in increasing order.
  • For every other index (including those past k), A_i will be deleted before them, so they will all get delayed instead.
    This incurs an additional cost of S, where S is the sum of values of all these indices.
    • S is also easily maintained if k is processed in increasing order, with the additional help of suffix sums on array A.

Let M be the minimum cost across all k \geq i.
We then simply have dp_i = M + dp_{i+1}.
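
The three bullet points above can be sketched for a single index as follows. This is a minimal illustration in the spirit of the editorialist's Python code, using 0-based indices; the function name `deletion_cost` and its return shape are assumptions made here for readability.

```python
def deletion_cost(a, i):
    """Return (M, best_k) for index i, trying every deletion position k >= i."""
    n = len(a)
    # suf[j] = a[j] + a[j+1] + ... + a[n-1], with suf[n] = 0.
    suf = [0] * (n + 1)
    for j in range(n - 1, -1, -1):
        suf[j] = suf[j + 1] + a[j]

    larger = 0       # x: number of j in (i, k] with a[j] > a[i]
    smaller_sum = 0  # part of S coming from j in (i, k] with a[j] <= a[i]
    # k = i: no movement, and every later index is delayed by this deletion.
    best, best_k = suf[i + 1], i
    for k in range(i + 1, n):
        if a[k] > a[i]:
            larger += 1            # a[k] is deleted first and delays a[i]
        else:
            smaller_sum += a[k]    # a[i] is deleted first and delays a[k]
        cost = (k - i + larger) * a[i] + smaller_sum + suf[k + 1]
        if cost < best:
            best, best_k = cost, k
    return best, best_k
```

For A = [1, 10], index 0 gets M = 2 at k = 1 and index 1 gets M = 0; adding the arrival term 10 gives a total of 12.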

Note that adding dp_{i+1} unchanged requires a proof of correctness: we’re essentially claiming that the optimal deletion index we find for i will not conflict with the optimal solution that already exists for the suffix. (Recall that when fixing k, we assumed that every delay was made optimally, that is, the smaller value is the one delayed by the larger one.)

Proof

Consider two indices i \lt j.
Let A_i \leq A_j, and their optimal deletion positions found via the given algorithm be r_i and r_j.
We’ll prove that there’s no conflict between them - that is, either r_i \lt j, or r_i \geq r_j.

Consider some index k between j and r_j.
When at index i, comparing the costs between choosing k and r_j as the deletion index:

  • Indices \leq k and \gt r_j contribute the same cost to both, so they can be ignored.
  • For k \lt x \leq r_j, we get a cost of A_x when looking at k, and \min(A_x, A_i) when looking at r_j.
    In particular, if A_x \leq A_i then the contribution to both is the same, so again we can ignore such indices.
  • When looking at r_j, we get an extra (r_j - k)\cdot A_i from the movement costs.

So, if there are m_1 indices between k+1 and r_j containing values \gt A_i, and the sum of values at these indices is s_1, our claim will be satisfied iff we can show that
(r_j-k+m_1)\cdot A_i \leq s_1

To that end, let’s perform the same analysis with indices k and r_j, but for the index j this time.
A similar argument tells us that we want to compare the quantities (r_j-k+m_2)\cdot A_j and s_2, where:

  • m_2 is the number of values at indices between k+1 and r_j that are \gt A_j.
  • s_2 is the sum of the above values.

Now, we already know r_j is optimal for index j. This means
(r_j-k+m_2)\cdot A_j \leq s_2

Since A_j \geq A_i, this tells us that (r_j-k+m_2)\cdot A_i \leq s_2.

From here, observe that we reach (r_j-k+m_1)\cdot A_i in the LHS by adding one copy of A_i for each element that’s \gt A_i but \leq A_j.
Meanwhile, in the RHS we reach s_1 from s_2 by adding the value of each element that’s \gt A_i but \leq A_j.
So, an equal number of terms are added to both sides; but each term added to the LHS is a copy of A_i, while the corresponding term added to the RHS is a value that’s \gt A_i.
This maintains the inequality, and so we obtain (r_j-k+m_1)\cdot A_i \leq s_1 as desired.
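
Written out as a chain, the last step looks roughly like this. The set T is notation introduced here only for illustration: it denotes the indices x with k \lt x \leq r_j and A_i \lt A_x \leq A_j, so that |T| = m_1 - m_2.

```latex
\begin{aligned}
(r_j - k + m_2)\cdot A_i &\leq (r_j - k + m_2)\cdot A_j \leq s_2, \\
(m_1 - m_2)\cdot A_i = \sum_{x \in T} A_i &\leq \sum_{x \in T} A_x = s_1 - s_2, \\
\Longrightarrow \quad (r_j - k + m_1)\cdot A_i &\leq s_1 .
\end{aligned}
```

The third line is just the sum of the first two.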

A similar proof can be made for the case when i \lt j and A_i \gt A_j.

The end result is that the optimal deletion points found by our algorithm don’t conflict, as we wanted.

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long

int main() {
	ll tt=1;
    cin>>tt;
    while(tt--){
        ll n;
        cin>>n;
        ll a[n+1];
        ll ans=0;
        for(int i=1;i<=n;i++){
            cin>>a[i];
            ans+=a[i]*(i-1);
        }
        ll suf[n+2]={};
        for(int i=n;i>=1;i--){
            suf[i]=suf[i+1]+a[i];
        }
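        ll cnt; // best (minimum) extra cost found so far for index i
        ll sum; // cost accumulated over indices i+1..j if a[i] is deleted at j
        for(int i=1;i<=n;i++){
            cnt=suf[i+1];sum=0; // deleting a[i] immediately delays everything after it
            for(int j=i+1;j<=n;j++){
                // one step of movement (a[i]) plus whichever of a[i], a[j] gets delayed
                sum+=a[i]+min(a[i],a[j]);
                cnt=min(cnt,sum+suf[j+1]);
            }
            ans+=cnt;
        }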
        cout<<ans<<"\n";
    }
    return 0;
}

Tester's code (C++)
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define di(a) int a;cin>>a;
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define sis string s;
#define sin string s;cin>>s;
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;
using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;

int smn = 0;

 
void solve()
{
	// All of stdin has already been buffered by `inp`, so read through the checker only.
	int n = inp.readInt(1,5000);
	smn += n;
	inp.readEoln();
	vi a(n);
	// take(a,n);
	repin{
		a[i] = inp.readInt(1,1000'000'000);
		if(i == n-1)inp.readEoln();
		else inp.readSpace();
	}
	vector<int> dp(n,INF);
	int ans = 0;
	repin{
		ans += i*a[i];
	}
	vector<int> sf = a;
	rrep(i,n-2,0)sf[i] += sf[i+1];
	repin{
		int cnt = 0;
		int sm = 0;
		rep(j,i,n){
			if(a[j] > a[i])cnt++;
			else if(j != i)sm += a[j];
			dp[i] = min(dp[i],(j+1<n?sf[j+1]:0) + sm + cnt*a[i] + (j-i)*a[i]);
		}
	}

	repin{
		ans += dp[i];
	}

	cout << ans << "\n";
}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
        // int t; cin >> t; while(t--)
        int t = inp.readInt(1,1000);inp.readEoln();
        while(t--)
        solve();
        assert(smn <= 5000);
        inp.readEof();
    return 0;
}

Editorialist's code (Python)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    suf = a[:] + [0] 
    dp = [0]*(n+1)

    for i in reversed(range(n)):
        suf[i] = suf[i+1] + a[i]
        more, lesssm = 0, 0
        dp[i] = suf[i+1]
        for k in range(i+1, n):
            if a[k] > a[i]: more += 1
            else: lesssm += a[k]
            dp[i] = min(dp[i], (k-i+more)*a[i] + lesssm + suf[k+1])
        dp[i] += dp[i+1]
    print(dp[0] + sum(i*a[i] for i in range(n)))

Hey @iceknight1093, do you make videos for editorials? Your explanations are well written, but videos help me understand more easily! Or do you know someone who makes videos for hard questions like this?

I don’t make videos myself, and I also don’t really watch them so I don’t know if anyone makes proper video editorials for hard problems.
I’m sure you can find plenty of screencasts and quicker solution discussions though - I know aryanc403 makes them whenever he’s able to participate.

Yes, got it. Btw, aryanc403 hasn’t made CodeChef videos for the last 2 contests. 🥲