MXFREQ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: yash_agrawal27
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

simple

PREREQUISITES:

Maximum subarray sum

PROBLEM:

You’re given an array A and an integer X.
At most once, you can multiply all the elements of a subarray of A by X.

Find the maximum possible frequency of any element in the resulting array.
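To make the operation concrete, here is a minimal brute-force sketch that tries every subarray (and the option of doing nothing). The function name `brute_force` is hypothetical and not part of the editorial. It is far too slow for the real limits and is meant only as a reference for checking faster solutions on small inputs.

```python
from collections import Counter

def brute_force(a, x):
    """Try every subarray [l, r]; roughly O(N^3), for tiny inputs only."""
    n = len(a)
    best = max(Counter(a).values())  # option: do not apply the operation
    for l in range(n):
        for r in range(l, n):
            # multiply a[l..r] by x and leave the rest untouched
            b = a[:l] + [v * x for v in a[l:r + 1]] + a[r + 1:]
            best = max(best, max(Counter(b).values()))
    return best
```

For example, `brute_force([1, 2, 1], 2)` returns 2, since multiplying only the first element gives `[2, 2, 1]`.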

EXPLANATION:

Let’s fix an integer M, and see what the maximum possible frequency of M we can obtain is.
The final answer is the maximum of this across all integers M.

There are two possibilities for M: either it appears in the initial array, or it does not.
Note that if M doesn’t appear in the initial array, we can ignore it entirely: every occurrence of M we obtain has to come from multiplying \frac M X by X (if this is an integer in the first place), so we could instead do nothing and take the frequency of \frac M X, which is at least as large.

So, only M that already appear in A need to be considered.
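Once again, we have a couple of cases: either M is a multiple of X, or it is not.
(If X = 1, the operation changes nothing and the answer is just the largest initial frequency, so assume X > 1 from here on.)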

If M is not a multiple of X, then it’s impossible to get more copies of M by multiplying any other integer. So, the best we can do here is just \text{freq}[M] itself.

This leaves the case where M is a multiple of X, so we might be able to multiply some copies of \frac M X by X to obtain more copies of M.
However, since we must multiply some subarray by X, it’s possible that we might multiply some occurrences of M by X (and so they’ll no longer be M).

Let’s define a profit P_i for each element of A:

  • If A_i = \frac M X, then P_i = 1 since we obtain one more occurrence by multiplying it.
  • If A_i = M, then P_i = -1 since we lose an occurrence by multiplying it.
  • For any other A_i, P_i = 0 since what happens to it doesn’t matter.

Now, notice that if we multiply the subarray [L, R] by X, our profit is simply
P_L + P_{L+1} + \ldots + P_R.
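We want the maximum possible profit, which is exactly the maximum subarray sum of P (the empty subarray, with profit 0, is allowed because the operation is optional).
The best frequency for this M is then \text{freq}[M] plus that maximum.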

The maximum subarray sum of P can be computed in \mathcal{O}(N) time, and the algorithm to do so is well-known.
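As a rough illustration of the fixed-M step, the sketch below builds the profit values on the fly and runs the standard maximum-subarray scan (Kadane's algorithm) over them. The function name `best_frequency_for` is made up for this example, and the explicit guard for X = 1 is an assumption added here because \frac M X = M in that case.

```python
def best_frequency_for(a, x, m):
    """Best frequency of the value m after at most one subarray multiplication by x."""
    base = a.count(m)
    if x == 1 or m % x != 0:
        return base  # multiplying can never create new copies of m
    source = m // x
    best = cur = 0  # the empty subarray (doing nothing) has profit 0
    for v in a:
        if v == source:
            p = 1    # multiplying this element gains one copy of m
        elif v == m:
            p = -1   # multiplying this element loses one copy of m
        else:
            p = 0    # irrelevant element
        cur = max(0, cur + p)
        best = max(best, cur)
    return base + best
```

On the array `[1, 2, 1]` with X = 2 and M = 2 this gives 2: the profits are 1, -1, 1, so the best subarray has profit 1 on top of the single existing copy of 2.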


Now, we have only solved the problem for a fixed value of M in \mathcal{O}(N) time.
We did observe that only elements of A are candidates for M, but that still leads to a solution that’s \mathcal{O}(N^2) which is too slow.

Let’s look back at the array P, which we were computing the maximum subarray sum of.
Note that each P_i is either -1, 0, or 1.
We can always just ignore the zeros entirely: this won’t affect the sum of any subarray.
This simple optimization, if properly implemented, makes our algorithm fast enough!

Note that after throwing out zeros, the length of P will equal \text{freq}[M] + \text{freq}\left[\frac M X\right].
So, the occurrences of any element will be processed at most twice: once when it’s M, and once when it’s \frac M X.
This means the total amount of work we do becomes just \mathcal{O}(N), across all values of M.

All we need to do is make sure that P is constructed appropriately, i.e. all the zeros are ignored even during construction.
To do that, keep a list of positions for each value; call it \text{positions}[M] for M.
Then, when processing M, merge only the lists \text{positions}[M] and \text{positions}\left[\frac M X\right], which can be done in time linear in the sum of their sizes (recall the merge step of mergesort).

This ensures we obtain the complexity we want.
Note that you’ll have to use a map to store the frequency and position lists since elements can be as large as 10^9 (or use coordinate compression), leading to an extra \log N in the complexity.
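The construction described above might look roughly like the following sketch. The helper names `build_positions` and `compressed_profit` are hypothetical, and a Python dict stands in for the map mentioned in the text. It assumes X > 1, so the two position lists never share an index.

```python
from collections import defaultdict

def build_positions(a):
    """Map each value to the increasing list of indices where it occurs."""
    positions = defaultdict(list)
    for i, v in enumerate(a):
        positions[v].append(i)
    return positions

def compressed_profit(pos_m, pos_src):
    """Merge two sorted index lists into the zero-free profit sequence.

    pos_m   -- indices holding M     (each contributes -1)
    pos_src -- indices holding M / X (each contributes +1)
    """
    out = []
    i = j = 0
    while i < len(pos_m) or j < len(pos_src):
        # take from pos_src when pos_m is exhausted or its next index is later
        take_src = i == len(pos_m) or (j < len(pos_src) and pos_src[j] < pos_m[i])
        if take_src:
            out.append(1)
            j += 1
        else:
            out.append(-1)
            i += 1
    return out

# M = 20 at indices [7, 12], M / X = 5 at indices [8, 10, 13, 14]
print(compressed_profit([7, 12], [8, 10, 13, 14]))  # [-1, 1, 1, -1, 1, 1]
```

Each call costs time proportional to the combined length of the two lists, which is what keeps the total work linear apart from the map lookups.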

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
#define db(x) cout << #x << " = " << x << endl;
#define dbv(vec) cout << #vec << " = ["; for (auto v : vec) cout << " " << v; cout << " ]" << endl;
#define dbm(m, r, c) cout << #m << " =\n"; for(int i=0; i<r; i++) { for(int j=0; j<c; j++) cout << m[i][j] << " "; cout << endl; }
#define ll long long
#define pll pair<ll,ll>
#define vi vector<ll>
#define vii vector<pll>
#define vvi vector<vector<ll>>
#define vvii vector<vector<pll>>
#define rep(i,a,b) for(ll i=a;i<=b;i++)
#define repr(i,a,b) for(ll i=a;i>=b;i--)
#define bpc(x) __builtin_popcountll(x)
#define ed '\n'
#define mod 1000000007

void solve(){
  ll n,x;
  cin >> n >> x;
  vi a(n);
  map<ll,ll> cnt,pf,dp;
  rep(i,0,n-1){
  	cin >> a[i];
  	cnt[a[i]]++;
  }
  ll ans = 0;
  rep(i,0,n-1) ans=max(ans,cnt[a[i]]);
  if(x==1){
    cout << ans << ed;
    return;
  }
  rep(i,0,n-1){
     pf[a[i]]++;
     if(i!=(n-1)){
        if(a[i+1]%x==0) ans=max(ans,cnt[a[i+1]]+pf[a[i+1]/x]-pf[a[i+1]]+dp[a[i+1]]);
        if(a[i]%x==0) dp[a[i]]=max(dp[a[i]],pf[a[i]]-pf[a[i]/x]);
    }else{
      for(auto j:a){
        if(j%x==0) ans=max(ans,cnt[j]+pf[j/x]-pf[j]+dp[j]); 
      }
    }
  }
  cout << ans << ed;
}

int main(){

  ios_base::sync_with_stdio(false);
  cin.tie(NULL);
  #ifndef ONLINE_JUDGE
   freopen("oo11.txt", "r", stdin);
   freopen("ot11.txt", "w", stdout);
  #endif
   ll t;
   cin >> t;
   while(t--){
    solve();
   }
}
Tester's code (C++)
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
// #define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define di(a) int a;cin>>a;
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define sis string s;
#define sin string s;cin>>s;
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;
using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;

int smn = 0;

int maxSum(vi &a){
    int n = a.size();
    rep(i,1,n)a[i] += a[i-1];
    int ans = 0,mx = a.back();
    rrep(i,n-1,0){
        mx = max(mx,a[i]);
        ans = max(ans,mx-a[i]);
    }
    ans = max(ans,mx);
    return ans;
}

void solve()
{
    int n = inp.readInt(1,200'000);
    inp.readSpace();
    smn += n;
    int x = inp.readInt(1,1000'000'000);
    inp.readEoln();
    vi a(n);
    for(int i = 0;i < n;i++){
        a[i] = inp.readInt(1,1000'000'000);
        if(i == n-1)inp.readEoln();
        else inp.readSpace();
    }
    map<int,vector<int>> m;
    repin{
        m[a[i]].push_back(i);
    }
    set<int> s(be(a));
    int ans = 0;
    
    for(auto &e : s){
        int cur = m[e].size();
        if(x > 1 && ((e%x) == 0) && s.count(e/x)){
            vi chk;
            int sz = m[e].size(),sz1 = m[e/x].size();
            vi v = m[e],v1 = m[e/x];
            chk.reserve(sz+sz1);
            int c = 0,c1 = 0;
            while(c != sz || c1 != sz1){
                if(c == sz)chk.pb(1),c1++;
                else if(c1 == sz1)chk.pb(-1),c++;
                else if(v[c] < v1[c1])chk.pb(-1),c++;
                else chk.pb(1),c1++;
            }
            cur += maxSum(chk);
        }
        ans = max(ans,cur);
    }
    cout << ans << "\n";
}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
        int t = inp.readInt(1,10'000);
        inp.readEoln();
        while(t--)
        solve();
        inp.readEof();
        assert(smn <= 200'000);
    return 0;
}
Editorialist's code (PyPy3)
def calc(a):
    ans = cur = 0
    for x in a:
        cur = max(0, cur + x)
        ans = max(ans, cur)
    return ans

for _ in range(int(input())):
    n, x = map(int, input().split())
    a = list(map(int, input().split()))
    
    pos = {}
    for i in range(n):
        if a[i] not in pos: pos[a[i]] = []
        pos[a[i]].append(i)
    
    ans = 0
    for y in pos.keys():
        ans = max(ans, len(pos[y]))
        if x == 1 or x*y not in pos: continue

        b = []
        p = q = 0
        while p < len(pos[y]) or q < len(pos[x*y]):
            if p == len(pos[y]):
                b.append(-1)
                q += 1
                continue
            if q == len(pos[x*y]):
                b.append(1)
                p += 1
                continue
            if pos[y][p] < pos[x*y][q]:
                b.append(1)
                p += 1
            else:
                b.append(-1)
                q += 1
        ans = max(ans, len(pos[x*y]) + calc(b))
    print(ans)

My submission gives WA. Here is my approach (somewhat similar to editorial):
freq(i, A_i) is the frequency of element A_i from 0 to i. If we multiply the range i to j (where A_i = A_j) by X, the new frequency of A_i*X would be:

freq(j, A_i) - freq(i, A_i) + 1 - \left(freq\left(j, A_i * X\right) - freq\left(i, A_i * X\right)\right) + freq(i, A_i * X) + freq(N-1, A_i * X) - freq(j, A_i * X)

which when simplified turns out to be:

freq(j, A_i) + freq(N-1, A_i * X) - 2 \cdot freq(j, A_i * X) + 1 - \underline{\left(freq\left(i, A_i\right) - 2 \cdot freq\left(i, A_i * X\right)\right)}

For a given j, to find the required range (i, j) to multiply by X so that frequency of A_j * X is maximum, we need to minimize the underlined value. This can be done by maintaining a heap (or a set in C++).

Any insight into where the logic is wrong or pointing out corrections in the code would be useful.

1
20 4
19 9 10 18 12 17 15 20 5 9 5 15 20 5 5 12 2 9 13 3
expected output is 5 while your output is 4 in this case …

I am getting TLE and AC for the same code submission.
AC:
https://www.codechef.com/viewsolution/1136825786
https://www.codechef.com/viewsolution/1136826990
TLE:
https://www.codechef.com/viewsolution/1136826025
https://www.codechef.com/viewsolution/1136827065

The running times shown for the TLE submissions are 0.34 sec and 0.49 sec (not 1 sec).
Some problem with the platform?

Thanks @yash_agrawal27 for pointing it out. There was a simple error in the formulation in the OP, which can be corrected to:
freq(j, A_i) - freq(i, A_i) + 1 - \left( freq(j, A_i * X) - freq(i, A_i * X) \right) + freq(N-1, A_i * X)

which can be simplified to:
freq(j, A_i) - freq(j, A_i * X) + freq(N-1, A_i * X) + 1 - \underline{\left( freq(i, A_i) - freq(i, A_i * X) \right)}
where we have to minimize the underlined part.

The submission now runs well.
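For readers who want to try this prefix-count formulation, here is a minimal sketch of it. It is not the poster's actual submission: the function name `max_frequency_prefix` is made up, and it keeps a running minimum per value instead of a heap or set, which should suffice because every earlier index i with A_i = A_j is a valid left end.

```python
from collections import Counter, defaultdict

def max_frequency_prefix(a, x):
    total = Counter(a)              # total[v] plays the role of freq(N-1, v)
    ans = max(total.values())       # option: do nothing
    if x == 1:
        return ans                  # the operation changes nothing
    seen = defaultdict(int)         # seen[v] = freq(j, v) for the current j
    best_start = {}                 # min over i <= j, A_i = v of freq(i, v) - freq(i, v*x)
    for v in a:
        seen[v] += 1
        t = v * x
        key = seen[v] - seen.get(t, 0)   # freq(j, v) - freq(j, v*x)
        if v not in best_start or key < best_start[v]:
            best_start[v] = key
        # corrected formula: key + freq(N-1, v*x) + 1 - (underlined minimum)
        ans = max(ans, key + total.get(t, 0) + 1 - best_start[v])
    return ans
```

On the 20-element test case posted above with X = 4, this returns 5, matching the expected output.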