MAXEQUAL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

Simple

PREREQUISITES:

None

PROBLEM:

You’re given an array A of length N, initially filled with zeros.
Process N updates to it. Each update gives you X and Y and asks you to set A_X := Y. The indices X are pairwise distinct, so every position is set exactly once.

After each update, compute f(A) as follows:

  • Replace each 0 in A with any positive integer of your choice.
  • Then, compute the number of pairs (i, j) such that i \lt j and A_i = A_j, i.e., the number of pairs of equal elements in A.
  • f(A) is the maximum possible value of this count.

EXPLANATION:

First, let’s see how we can compute f(A) quickly for a fixed array A.

Let’s ignore the zeros for now, and focus only on the existing non-zero elements.
We want the number of pairs among them that are equal.
Let \text{freq}[x] denote the number of times x appears in the array.
Then, the number of pairs of indices that both contain x is exactly

\frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2}

So, among non-zero indices, the number of equal pairs is

\sum_{x=1}^N \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2}
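The sum above translates directly into a few lines of Python. This is only a sketch: the name `pairs_among_nonzero` is made up for illustration, and a `collections.Counter` stands in for the \text{freq} array.

```python
from collections import Counter

def pairs_among_nonzero(a):
    """Count pairs (i, j), i < j, with a[i] == a[j] != 0."""
    freq = Counter(v for v in a if v != 0)  # freq[x] for every non-zero value x
    # value x contributes freq[x] * (freq[x] - 1) / 2 pairs
    return sum(c * (c - 1) // 2 for c in freq.values())

print(pairs_among_nonzero([1, 1, 2, 0, 1, 2]))  # 4 (three pairs of 1s, one pair of 2s)
```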

Now, we have to think about what to do with the zeros.
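It’s not hard to see that, since we want to maximize the number of equal pairs, it’s optimal to set all the zeros to the same value, say y. Indeed, if the zeros were given two different values whose groups have sizes s \le L, changing one zero from the smaller group to the larger group’s value loses s-1 pairs and gains L pairs, which is a strict improvement.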

If there are k zeros and we set them all to y, the number of new pairs we create is exactly

\frac{k\cdot (k-1)}{2} + k\cdot\text{freq}[y]

The first term comes from pairs within the new copies of y, while the second comes from pairs that involve one new copy and one existing copy.

Since k is fixed and our aim is to maximize this quantity, we should choose whichever y has the maximum \text{freq}[y].
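To see the choice of y concretely, the sketch below lists \frac{k(k-1)}{2} + k\cdot\text{freq}[y] for every candidate y on a small example. The helper `gain_per_choice` is hypothetical, and the made-up key `"unused value"` stands for any value not yet present in the array (so its \text{freq} is 0). The largest gain belongs to the most frequent value.

```python
from collections import Counter

def gain_per_choice(a):
    """For each candidate y, the number of new pairs created by setting all zeros to y."""
    k = a.count(0)
    freq = Counter(v for v in a if v != 0)
    gains = {y: k * (k - 1) // 2 + k * c for y, c in freq.items()}
    # a value not present in the array has freq[y] = 0
    gains["unused value"] = k * (k - 1) // 2
    return gains

print(gain_per_choice([1, 1, 2, 0, 0]))  # {1: 5, 2: 3, 'unused value': 1}
```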

So, we can now compute f(A) for a fixed array A in linear time:

  • First, compute the \text{freq} array of frequencies of non-zero elements.
  • Then, find the maximum element of this array, let it be M.
  • If there are k zeros in the array, f(A) will equal
\frac{k\cdot (k-1)}{2} + k\cdot M + \sum_{x=1}^N \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2}
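A possible rendering of these three steps in Python, assuming 0 marks a cell that is still empty. The name `f_linear` is hypothetical, and a `Counter` replaces the \text{freq} array:

```python
from collections import Counter

def f_linear(a):
    """f(A) for a fixed array a (0 = empty cell), computed in O(len(a))."""
    freq = Counter(v for v in a if v != 0)   # frequencies of non-zero values
    m = max(freq.values(), default=0)        # M, taken as 0 if every entry is zero
    k = a.count(0)                           # number of zeros
    base = sum(c * (c - 1) // 2 for c in freq.values())
    return k * (k - 1) // 2 + k * m + base

print(f_linear([1, 1, 2, 0]))  # 3
print(f_linear([0, 0, 0, 0]))  # 6
```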

Of course, recomputing f(A) from scratch in linear time after every update takes \mathcal{O}(N^2) time in total, which is too slow.
However, observe that upon an update that sets A_X := Y, very little actually changes:

  1. k decreases by 1, since we have one fewer zero (each index is updated exactly once, so A_X was still 0).
  2. \text{freq}[Y] increases by 1.
    • So, only one term changes in the \sum_{x=1}^N \frac{\text{freq}[x]\cdot (\text{freq}[x]-1)}{2} expression.
      Thus, instead of recomputing the whole thing, we can update the sum in \mathcal{O}(1): if \text{freq}[Y] was c before the update, the term for Y grows from \frac{c(c-1)}{2} to \frac{(c+1)c}{2}, i.e. by exactly c.
  3. M, the maximum value of \text{freq}, either remains the same or becomes equal to \text{freq}[Y] (after \text{freq}[Y] has been increased by 1).
    That is, we can simply set M = \max(M, \text{freq}[Y]).
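Putting the three observations together, a minimal sketch of the per-update bookkeeping might look like this. It assumes, as in the problem, that every index X is updated exactly once and that 1 \le Y \le N, so X itself is not needed; the class name `MaxEqualState` is hypothetical. Step 2 is done by adding the old \text{freq}[Y] to the sum, which is exactly how much the term for Y grows.

```python
class MaxEqualState:
    """Keeps k, M and the pair sum up to date under updates A_X := Y."""

    def __init__(self, n):
        self.freq = [0] * (n + 1)  # freq[y] for 1 <= y <= n (assumed range of Y)
        self.k = n                 # number of zeros still in the array
        self.m = 0                 # M = maximum of freq
        self.pairs = 0             # sum of freq[x] * (freq[x] - 1) / 2

    def update(self, y):
        """Apply one update with value y and return f(A) afterwards."""
        self.k -= 1                          # 1. one fewer zero
        self.pairs += self.freq[y]           # 2. term for y grows by the old freq[y]
        self.freq[y] += 1
        self.m = max(self.m, self.freq[y])   # 3. M can only grow to freq[y]
        k = self.k
        return k * (k - 1) // 2 + k * self.m + self.pairs


state = MaxEqualState(4)
print([state.update(y) for y in [1, 1, 2, 3]])  # [6, 6, 3, 1]
```

Each call does a constant amount of work, which is where the overall \mathcal{O}(N) bound comes from.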

So, in constant time, all three of k, M, and the summation can be updated.
This means the answer itself can also be updated in constant time, giving an algorithm that runs in \mathcal{O}(N) time per test case.

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n; cin >> n;

    vector <int> f(n + 1, 0);
    int ans = 0;
    int mx = 0;
    for (int i = 1; i <= n; i++){
        int x, y; cin >> x >> y;
        ans += f[y]++;
        mx = max(mx, f[y]);
        
        int v = mx * (n - i) + (n - i) * (n - i - 1) / 2;

        cout << (ans + v) << " \n"[i == n];
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;

    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

void solve(int n, vector<int>& x, vector<int>& y) {
    vector<int> cnt(n + 1);
    int mx = 0;
    long long sum = 0;
    for (int i = 0; i < n; i++) {
        sum += cnt[y[i]];
        cnt[y[i]] += 1;
        mx = max(mx, cnt[y[i]]);
        long long ans = sum;
        ans -= mx * 1LL * (mx - 1) / 2;
        ans += (mx + n - i - 1) * 1LL * (mx + n - i - 2) / 2;
        cout << ans << " \n"[i == n - 1];
    }
}

////////////////////////////////////////

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e4);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        vector<int> x(n), y(n);
        for (int i = 0; i < n; i++) {
            x[i] = in.readInt(1, n);
            in.readSpace();
            y[i] = in.readInt(1, n);
            in.readEoln();
        }
        assert((int) set<int>(x.begin(), x.end()).size() == n);
        solve(n, x, y);
    }
    cerr << sn << endl;
    assert(sn <= 2e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
for _ in range(int(input())):
    n = int(input())
    freq = [0]*(n+1)
    mxfreq, sm = 0, 0
    for i in range(n):
        x, y = map(int, input().split())
        sm += freq[y]
        freq[y] += 1
        mxfreq = max(mxfreq, freq[y])
        
        rem = n-1-i
        print(sm + rem*mxfreq + rem*(rem-1)//2)

Can you please explain the expected output for the following test case:
4
1 1
2 1
3 2
4 3

Expected output: 6 6 3 1

Initially our array contains all zeros, so it looks like this:
0 0 0 0
After the first operation we have 1 at index 0 (using 0-based indexing):
1 0 0 0
We can change each 0 to any positive integer, so let's change them all to 1. The array becomes 1 1 1 1, and the pairs of distinct indices with equal values are (0,1), (0,2), (0,3), (1,2), (1,3), (2,3), i.e. 6 pairs.

Similarly, after the second operation the array is 1 1 0 0. Again we can convert all the 0s to 1s, which gives the same 6 pairs as before. After the third operation the array is
1 1 2 0
Here it is optimal to change the 0 to 1, giving 1 1 2 1.
Now the pairs of distinct indices with equal values are (0,1), (0,3), (1,3), i.e. only 3 pairs.
Hope this makes it clear what is going on; try to work out the last case yourself.
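For anyone who wants to double-check this walkthrough mechanically, here is a brute-force sketch (the function `brute_answers` is hypothetical, exponential, and only meant for tiny N). It replays the updates and tries every way to fill the zeros with values 1..N. Trying only 1..N is enough here because every Y lies in that range, so the best value for the zeros is always among them.

```python
from collections import Counter
from itertools import product

def brute_answers(n, updates):
    """Replay the updates; after each one, try every assignment of values 1..n to the zeros."""
    a = [0] * n
    out = []
    for x, y in updates:
        a[x - 1] = y                      # X is 1-based in the problem
        zeros = [i for i, v in enumerate(a) if v == 0]
        best = 0
        for choice in product(range(1, n + 1), repeat=len(zeros)):
            b = list(a)
            for i, v in zip(zeros, choice):
                b[i] = v
            best = max(best, sum(c * (c - 1) // 2 for c in Counter(b).values()))
        out.append(best)
    return out

print(brute_answers(4, [(1, 1), (2, 1), (3, 2), (4, 3)]))  # [6, 6, 3, 1]
```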

Thanks a lot…


Well, what’s wrong with my code?

#include <bits/stdc++.h>
using namespace std;

int main(){
    int t,n,X,Y;
    cin >> t;
    while(t--){
        cin >> n;
        vector<int> sma;
        vector<int> arr (n,0);
        vector<int> hash(n+1,0);
        for(int i = 0 ; i < n ; ++i){
            cin >> X >> Y;
            sma.push_back(Y);
        }
        for(int k = 0; k < n ; ++k){
            hash[sma[k]]++;
        }
        int m = 0, ans = 0;
        for(int j = 1; j < n+ 1 ; ++j ){
            if(hash[j] > m){
                ans = ans + m * (m - 1) / 2;
                m = hash[j];
            }
            else{
                ans = ans + hash[j] * (hash[j] - 1) / 2;
            }
        }
        m = m + hash[0];

        int last = m * (m - 1) / 2  + ans ;
        for(int j = 0; j < n - 1; ++j){
            if(find(arr.begin(),arr.end(), sma[j] )== arr.end() ){
                arr[j] = sma[j];
                cout << (((n-j) * ((n-j) - 1)) / 2 ) << " "; 
            }
            else{
                cout << ((n-j+1) * ((n-j+1) - 1)) / 2 << " ";
            }
        }
        cout << last;
        cout << "\n";
    }
    return 0;
}