DISGUST - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Saarang Srinivasan
Testers: Utkarsh Gupta, Hriday
Editorialist: Nishank Suresh

DIFFICULTY:

2903

PREREQUISITES:

Shortest-path algorithms (either Dijkstra or Floyd-Warshall)

PROBLEM:

You have an N\times N matrix A, whose disgust is defined to be \sum_{i=1}^N \sum_{j=1}^N (A_{i, j} - A_{j, i})^2.
You also have M changes, the i-th of which allows you to replace a single occurrence of x_i with y_i for some cost c_i.

Find the minimum possible sum of cost + disgust for A.

EXPLANATION:

Note that each pair of ‘opposite’ cells of A, i.e. A_{i, j} and A_{j, i}, contributes to the answer independently. So, it is enough to solve the problem for a single pair of opposite cells, then apply this solution to every pair.

In other words, the problem can be restated as follows: you have 2 integers x_1 and y_1. You can apply some of the given operations on them to convert them to x_2 and y_2 respectively. Find the minimum possible value of cost + 2\cdot(x_2 - y_2)^2.
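
The factor of 2 appears because the double sum visits each unordered pair of opposite cells twice, once as (i, j) and once as (j, i), while diagonal cells contribute 0. A minimal sketch that checks this on a small matrix; the function names `disgust` and `disgustByPairs` are made up for illustration and do not appear in any of the solutions:

```cpp
#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Disgust exactly as defined in the statement: sum over all ordered (i, j).
long long disgust(const Matrix& a) {
    const int n = static_cast<int>(a.size());
    long long total = 0;
    for (int i = 0; i < n; ++i) {
        assert(a[i].size() == a.size());  // the matrix must be square
        for (int j = 0; j < n; ++j)
            total += (a[i][j] - a[j][i]) * (a[i][j] - a[j][i]);
    }
    return total;
}

// Same quantity, counted once per unordered pair i < j with the factor 2.
long long disgustByPairs(const Matrix& a) {
    const int n = static_cast<int>(a.size());
    long long total = 0;
    for (int i = 0; i < n; ++i)
        for (int j = i + 1; j < n; ++j)
            total += 2 * (a[i][j] - a[j][i]) * (a[i][j] - a[j][i]);
    return total;
}
```

Both functions agree on any square matrix, which is why the rest of the editorial only looks at pairs with i < j.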

There are several different ways to do this, a few of which are detailed below.

Dijkstra

Let f(x, y) denote the minimum value assuming we start with x and y. If we are able to precompute this value for every pair of (x, y), the problem is obviously solved.

Note that f(x, y) can be defined as follows:

  • If no changes are made, the value is 2\cdot (x-y)^2
  • Otherwise, either x or y is changed.
    • Suppose x is changed to z with cost c. Then, we get a value of c + f(z, y)
    • Suppose y is changed to z with cost c. Then, we get a value of c + f(x, z)
  • Clearly, finding the minimum value of this across all z will give us the answer.

This formulation is suspiciously like a shortest path problem: we can think of f(x, y) as the ‘shortest path’ to reach (x, y), and changing either x or y equates to following an edge.
Also, the initial distance to vertex (x, y) is 2\cdot (x-y)^2 and not infinity.

A problem like this can be solved with multi-source Dijkstra. For each 1 \leq x, y \leq N, set the distance of (x, y) to 2\cdot (x-y)^2 and insert it into a priority queue/set. Now, run the standard Dijkstra algorithm; the final distances are exactly the values f(x, y).
Notice that the edges in this case are reversed compared to what is given in the input, i.e., if the input allows you to convert i to j with a cost of c, then you create the edge j\to i with weight c. This is because f(i, y) is computed from f(j, y), so distances must propagate from (j, y) to (i, y).

There are N^2 vertices. If only the cheapest change is kept for each ordered pair of values, each vertex has \mathcal{O}(N) outgoing edges, for a total of \mathcal{O}(N^3) edges. So, this runs in \mathcal{O}(N^3\log N) time.

The tester’s code below implements this.
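
As a shorter illustration than a full solution, here is a minimal sketch of the multi-source Dijkstra described above. It assumes values are numbered 1..N; the `Change` struct and the function name `multiSourceDijkstra` are hypothetical, and a `priority_queue` with stale-entry skipping is used instead of a set, which is one of several reasonable choices.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical record for one allowed change: `from` may become `to` for `cost`.
struct Change { int from, to; long long cost; };

// Returns f (1-indexed) with f[x][y] = minimum of change cost + 2*(x'-y')^2.
std::vector<std::vector<long long>> multiSourceDijkstra(int n, const std::vector<Change>& changes) {
    // Reversed adjacency: the change from -> to becomes the edge to -> from.
    std::vector<std::vector<std::pair<int, long long>>> radj(n + 1);
    for (const Change& c : changes) {
        assert(1 <= c.from && c.from <= n && 1 <= c.to && c.to <= n);
        radj[c.to].push_back({c.from, c.cost});
    }

    using State = std::tuple<long long, int, int>;  // (distance, x, y)
    std::priority_queue<State, std::vector<State>, std::greater<State>> pq;
    std::vector<std::vector<long long>> f(n + 1, std::vector<long long>(n + 1, 0));
    for (int x = 1; x <= n; ++x)
        for (int y = 1; y <= n; ++y) {
            f[x][y] = 2LL * (x - y) * (x - y);  // every vertex starts as a source
            pq.emplace(f[x][y], x, y);
        }

    while (!pq.empty()) {
        State top = pq.top();
        pq.pop();
        long long d = std::get<0>(top);
        int x = std::get<1>(top), y = std::get<2>(top);
        if (d != f[x][y]) continue;    // stale queue entry
        for (const auto& e : radj[x])  // value e.first can be changed into x
            if (d + e.second < f[e.first][y]) {
                f[e.first][y] = d + e.second;
                pq.emplace(f[e.first][y], e.first, y);
            }
        for (const auto& e : radj[y])  // value e.first can be changed into y
            if (d + e.second < f[x][e.first]) {
                f[x][e.first] = d + e.second;
                pq.emplace(f[x][e.first], x, e.first);
            }
    }
    return f;
}
```

The answer would then be the sum of `f[A[i][j]][A[j][i]]` over all i < j.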

Floyd-Warshall

Suppose we start with (x_1, y_1) and end at (x_2, y_2). There are only N different values of x_2, so let’s try each of them and try to find the optimal y_2 once x_2 is fixed.

The cost can be computed as follows:

  • First, the cost to change x_1 to x_2, say this is c(x_1, x_2).
  • Second, the cost to change y_1 to y_2, say c(y_1, y_2)
  • Finally, the value 2\cdot(x_2 - y_2)^2

Computing all the values of c(x, y) is not too hard: c(x, y) is exactly the shortest path from x to y in the graph defined by the input edges. Computing this for every (x, y) pair can be done in \mathcal{O}(N^3) using Floyd-Warshall.
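
A minimal sketch of this step, assuming 1-indexed values and a large finite sentinel for unreachable pairs; the `Change` struct, `kInf` and the function name `changeCosts` are hypothetical names chosen for illustration:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical record for one allowed change: `from` may become `to` for `cost`.
struct Change { int from, to; long long cost; };

// Sentinel for "cannot be converted"; small enough that kInf + kInf does not overflow.
const long long kInf = 1000000000000000000LL;

// Returns c (1-indexed) with c[x][y] = cheapest total cost of turning x into y.
std::vector<std::vector<long long>> changeCosts(int n, const std::vector<Change>& changes) {
    std::vector<std::vector<long long>> c(n + 1, std::vector<long long>(n + 1, kInf));
    for (int v = 1; v <= n; ++v) c[v][v] = 0;  // leaving a value unchanged is free
    for (const Change& e : changes) {
        assert(1 <= e.from && e.from <= n && 1 <= e.to && e.to <= n);
        c[e.from][e.to] = std::min(c[e.from][e.to], e.cost);  // keep the cheapest parallel edge
    }
    // Floyd-Warshall: the intermediate vertex k must be the outermost loop.
    for (int k = 1; k <= n; ++k)
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= n; ++j)
                c[i][j] = std::min(c[i][j], c[i][k] + c[k][j]);
    return c;
}
```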

Now, the above cost can be written as c(x_1, x_2) + 2x_2^2 + (c(y_1, y_2) - 4x_2y_2 + 2y_2^2). The first two terms are constant once x_1 and x_2 are fixed, and the third term is of the form kx_2 + m with k = -4y_2 and m = c(y_1, y_2) + 2y_2^2, which depend only on y_1 and y_2.

Minimizing such an expression can be done with the help of the convex hull trick: build one line container for each value of y_1 (N containers in total), each holding the lines for every y_2 reachable from y_1, and query the container corresponding to y_1 at x_2.
This allows us to compute the answer in \mathcal{O}(N\log N) for a given (x_1, y_1) pair, for a total of \mathcal{O}(N^3\log N), which is fast enough to solve the problem since there are only \mathcal{O}(N^2) such pairs. The editorialist’s code implements this approach.

However, there is also a solution that doesn’t need any fancy data structures.

Consider another function d(x, y), defined as follows:

  • Suppose one of our numbers is fixed to be x, and we now want to change y. What is the minimum possible value of (cost of changing y) + (final disgust)?

By considering every possible value z that y can be turned into, it is easy to see that

d(x, y) = \min_{z=1}^N (c(y, z) + 2\cdot(x-z)^2)

Once c(x, y) has been computed, all values of d(x, y) can be computed directly from this formula in \mathcal{O}(N^3) total.

Now, to answer the query for (x_1, y_1):

  • Fix the value x_2 that x_1 will be turned into. This has cost c(x_1, x_2).
  • Now that x_2 is fixed, we need to convert y_1 to something else. By definition, the minimum cost here is exactly d(x_2, y_1).
  • So, the answer is the minimum of c(x_1, x_2) + d(x_2, y_1) across all 1 \leq x_2 \leq N.

This gives us a solution in \mathcal{O}(N^3).
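
A minimal sketch of these two steps, assuming the matrix c of change costs is already available, is 1-indexed, and uses a large finite sentinel such as 10^{18} for unreachable pairs (so that sums do not overflow). The names `buildD` and `pairAnswer` are hypothetical:

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// d[x][y] = min over z of c[y][z] + 2*(x - z)^2, for all 1 <= x, y <= n.
Matrix buildD(int n, const Matrix& c) {
    Matrix d(n + 1, std::vector<long long>(n + 1, LLONG_MAX));
    for (int x = 1; x <= n; ++x)
        for (int y = 1; y <= n; ++y)
            for (int z = 1; z <= n; ++z)
                d[x][y] = std::min(d[x][y], c[y][z] + 2LL * (x - z) * (x - z));
    return d;
}

// Minimum cost + disgust for the start pair (x1, y1): try every final value x2 of x1.
long long pairAnswer(int n, const Matrix& c, const Matrix& d, int x1, int y1) {
    assert(1 <= x1 && x1 <= n && 1 <= y1 && y1 <= n);
    long long best = LLONG_MAX;
    for (int x2 = 1; x2 <= n; ++x2)
        best = std::min(best, c[x1][x2] + d[x2][y1]);
    return best;
}
```

Since `pairAnswer` is \mathcal{O}(N) and is called once per pair of opposite cells, the whole computation stays within \mathcal{O}(N^3).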

TIME COMPLEXITY:

\mathcal{O}(N^3) per test case.

CODE:

Setter's code (Floyd-Warshall, C++)
#include <bits/stdc++.h>
using namespace std;

void solve_case();

const long long inf = 1e16;
signed main() {
    std::ios::sync_with_stdio(0);
    std::cout.tie(0);
    std::cin.tie(0);
    int tt = 1;
    std::cin >> tt;
    while(tt--) {
        solve_case();
    }
    return 0;
}

void solve_case() {
    int n, m;
    cin >> n >> m;
    vector<vector<int>> a(n, vector<int>(n));
    vector<vector<long long>> dt(n, vector<long long>(n, inf)), f(n, vector<long long>(n, inf));
    for(int i = 0; i < n; i++) {
        dt[i][i] = 0;
        for(int j = 0; j < n; j++)
            cin >> a[i][j], a[i][j]--;
    }
    for(int x, y, z, i = 0; i < m; i++) {
        cin >> x >> y >> z; --x, --y;
        dt[x][y] = min(dt[x][y], (long long) z);
    }
    for(int j = 0; j < n; j++) 
        for(int i = 0; i < n; i++)
            for(int k = 0; k < n; k++)
                dt[i][k] = min(dt[i][k], dt[i][j] + dt[j][k]);

    for(int i = 0; i < n; i++)
        for(int j = 0; j < n; j++)
            for(int k = 0; k < n; k++)
                f[i][j] = min(f[i][j], dt[i][k] + 2 * (k - j) * (k - j)); //twice to unify cost/disgust

    long long ans = 0;
    for(int i = 0; i < n; i++)
        for(int j = i; j < n; j++) {
            long long best = inf;
            for(int k = 0; k < n; k++)
                best = min({best, f[a[i][j]][k] + dt[a[j][i]][k], f[a[j][i]][k] + dt[a[i][j]][k]});
            ans += best;
            //cout << i << ' ' << j << ' ' << a[i][j] + 1 << ' ' << a[j][i] + 1 << '\n';
            //cout << "best: " << best << '\n';
        }
    cout << ans << '\n';
}
Tester's code (Dijkstra, C++)
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <pair<ll,ll>> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int sumN=0,sumM=0;
int maxN=0,maxM=0;
int minN=1000;
ll maxMN=0;
ll sumMN=0;
ll distpairstotal=0;
ll distpairsmax=0;
void solve()
{
    int n,m;
    n=readInt(1,500,' ');
    m=readInt(1,100000,'\n');
    maxN=max(maxN,n);
    minN=min(minN,n);
    maxM=max(maxM,m);
    sumN+=n;
    sumM+=m;
    sumMN+=m*n;
    maxMN=max(maxMN,(ll)m*n);
    assert(sumN<=500);
    assert(sumM<=100000);
    int mat[n+1][n+1];
    set <pair<int,int>> sd;
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++)
        {
            if(j!=n)
                mat[i][j]=readInt(1,n,' ');
            else
                mat[i][j]=readInt(1,n,'\n');
        }
    }
    for(int i=1;i<=n;i++)
        adj[i].clear();
    for(int i=1;i<=m;i++)
    {
        ll x,y,z;
        x=readInt(1,n,' ');
        y=readInt(1,n,' ');
        z=readInt(1,1000000000,'\n');
        adj[y].pb(mp(x,z));
    }
    set <tuple<ll,ll,ll>> s;
    ll dist[n+1][n+1];
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++)
        {
            s.insert(make_tuple(2*(i-j)*(i-j),i,j));
            dist[i][j]=2*(i-j)*(i-j);
        }
    }
    while(!s.empty())
    {
        auto it=s.begin();
        auto t=(*it);
        s.erase(it);
        int i=get<1>(t);
        int j=get<2>(t);
        ll curr=get<0>(t);
        for(auto p:adj[i])
        {
            int k=p.first;
            ll add=p.second;
            if(dist[k][j]>(dist[i][j]+add))
            {
                s.erase(make_tuple(dist[k][j],k,j));
                dist[k][j]=dist[i][j]+add;
                s.insert(make_tuple(dist[k][j],k,j));
            }
        }
        for(auto p:adj[j])
        {
            int k=p.first;
            ll add=p.second;
            if(dist[i][k]>(dist[i][j]+add))
            {
                s.erase(make_tuple(dist[i][k],i,k));
                dist[i][k]=dist[i][j]+add;
                s.insert(make_tuple(dist[i][k],i,k));
            }
        }
    }
    ll ans=0;
    for(int i=1;i<=n;i++)
    {
        for(int j=i+1;j<=n;j++)
        {
            sd.insert(mp(mat[i][j],mat[j][i]));
            ans+=dist[mat[i][j]][mat[j][i]];
        }
    }
    distpairstotal+=sd.size();
    distpairsmax=max(distpairsmax,(ll)sd.size());
    cout<<ans<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,100,'\n');
    int tt=T;
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr<<"Total Tests: "<<tt<<'\n';
    cerr<<"Min N: "<<minN<<'\n';
    cerr<<"Max N: "<<maxN<<'\n';
    cerr<<"Max M: "<<maxM<<'\n';
    cerr<<"Sum of N: "<<sumN<<'\n';
    cerr<<"Sum of M: "<<sumM<<'\n';
    cerr<<"Sum of MN: "<<sumMN<<'\n';
    cerr<<"Max MN: "<<maxMN<<'\n';
    cerr<<"Dist Pairs Total: "<<distpairstotal<<'\n';
    cerr<<"Max Dist Pairs: "<<distpairsmax<<'\n';
}
Editorialist's code (Floyd-Warshall + CHT, C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// CHT; query returns max of lines in set
struct Line {
	mutable ll k, m, p;
	bool operator<(const Line& o) const { return k < o.k; }
	bool operator<(ll x) const { return p < x; }
};
struct LineContainer : multiset<Line, less<>> {
	// (for doubles, use inf = 1/.0, div(a,b) = a/b)
	const ll inf = LLONG_MAX;
	ll div(ll a, ll b) { // floored division
		return a / b - ((a ^ b) < 0 && a % b); }
	bool isect(iterator x, iterator y) {
		if (y == end()) { x->p = inf; return false; }
		if (x->k == y->k) x->p = x->m > y->m ? inf : -inf;
		else x->p = div(y->m - x->m, x->k - y->k);
		return x->p >= y->p;
	}
	void add(ll k, ll m) {
		auto z = insert({k, m, 0}), y = z++, x = y;
		while (isect(y, z)) z = erase(z);
		if (x != begin() && isect(--x, y)) isect(x, y = erase(y));
		while ((y = x) != begin() && (--x)->p >= y->p)
			isect(x, erase(y));
	}
	ll query(ll x) {
		assert(!empty());
		auto l = *lower_bound(x);
		return l.k * x + l.m;
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	const ll inf = 1e18 + 5;
	int t; cin >> t;
	while (t--) {
		int n, m; cin >> n >> m;
		vector a(n*n, 0);
		for (int &x : a) cin >> x;
		vector dist(n+1, vector(n+1, inf));
		for (int i = 1; i <= n; ++i) dist[i][i] = 0;
		for (int i = 0; i < m; ++i) {
			ll x, y, z; cin >> x >> y >> z;
			dist[x][y] = min(dist[x][y], z);
		}
		for (int k = 1; k <= n; ++k) { // vertices are 1-indexed, so intermediates run over 1..n
			for (int i = 1; i <= n; ++i) {
				for (int j = 1; j <= n; ++j) {
					dist[i][j] = min(dist[i][j], dist[i][k] + dist[k][j]);
				}
			}
		}
		
		vector<LineContainer> LC(n+1);
		for (int i = 1; i <= n; ++i) {
			for (int j = 1; j <= n; ++j) {
				if (dist[i][j] > inf/2) continue;
				LC[i].add(4*j, -dist[i][j] - 2*j*j);
			}
		}

		ll ans = 0;
		for (int i = 0; i < n; ++i) {
			for (int j = i+1; j < n; ++j) {
				int x = a[i*n + j], y = a[i + j*n];
				ll cost = inf;
				for (int k = 1; k <= n; ++k) {
					// end at (k, z) with cost:
					// 2k^2 + 2z^2 - 4*k*z + dist[x][k] + dist[y][z]
					// (2k^2 + dist[x][k]) + min(dist[y][z] + 2z^2 - 4*k*z) across all z
					cost = min(cost, 2*k*k + dist[x][k] - LC[y].query(k));
				}
				ans += cost;
			}
		}
		cout << ans << '\n';
	}
}

My idea. Maybe a little easier to understand.

Suppose we have a pair (x, y); let’s rephrase the cost + disgust value.
Replacing x with x' amounts to traversing a directed edge in a graph and paying the weight of that edge. In the end we stop at two nodes (x', y') and pay 2\times(x'-y')^2, which can be treated as the weight of a special edge.
Suppose the optimal result for (x, y) is (x', y'). Instead of thinking of moving from (x, y) to (x', y'), think of a single walk from x to x', then to y', and finally to y. A valid path must look as follows: go from x through zero or more given edges to a final x', take one special edge to y', then go through zero or more reversed edges back to y.
Let’s construct a graph with 2N nodes. For a given tuple (x, y, z), connect x to y with an edge of weight z, and N + y to N + x with an edge of weight z. Finally, for every pair (x, y), connect x to N + y with an edge of weight 2\times(x-y)^2.
Then, for each pair A_{ij} and A_{ji}, the minimum cost of this pair is the shortest distance from A_{ij} to N + A_{ji}. You can use the Floyd algorithm to compute this in O(N^3) time.
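
A rough sketch of this doubled-graph construction, under the assumption that the special edges lead from node x in the first copy to node N + y in the second copy. The `Change` struct and the function name `doubledGraphDistances` are hypothetical:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical record for one allowed change: `from` may become `to` for `cost`.
struct Change { int from, to; long long cost; };

// Returns dist over nodes 1..2n; dist[p][n + q] is the cost + disgust for the start pair (p, q).
std::vector<std::vector<long long>> doubledGraphDistances(int n, const std::vector<Change>& changes) {
    const long long kInf = 1000000000000000000LL;  // kInf + kInf still fits in long long
    const int v = 2 * n;
    std::vector<std::vector<long long>> dist(v + 1, std::vector<long long>(v + 1, kInf));
    for (int i = 1; i <= v; ++i) dist[i][i] = 0;
    for (const Change& c : changes) {
        assert(1 <= c.from && c.from <= n && 1 <= c.to && c.to <= n);
        // First copy keeps the edge direction, second copy reverses it.
        dist[c.from][c.to] = std::min(dist[c.from][c.to], c.cost);
        dist[n + c.to][n + c.from] = std::min(dist[n + c.to][n + c.from], c.cost);
    }
    for (int x = 1; x <= n; ++x)
        for (int y = 1; y <= n; ++y)
            dist[x][n + y] = 2LL * (x - y) * (x - y);  // special edge between the copies
    // Floyd-Warshall over all 2n nodes.
    for (int k = 1; k <= v; ++k)
        for (int i = 1; i <= v; ++i)
            for (int j = 1; j <= v; ++j)
                dist[i][j] = std::min(dist[i][j], dist[i][k] + dist[k][j]);
    return dist;
}
```

Floyd on 2N nodes costs about 8N^3 basic steps, which is still O(N^3).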
