MSUB121 - Editorial

PROBLEM LINK:

Practice
Div1
Div2
Div3

Setter: Ankush Chhabra
Testers: Lavish Gupta, Tejas Pandey
Editorialist: Ajit Sharma Kasturi

DIFFICULTY:

EASY-MEDIUM

PREREQUISITES:

Dynamic Programming

PROBLEM:

We are given an array A of N integers and Q ordered pairs of the form (x_i, y_i). We need to find the length of the longest magical subsequence of A and print the indices of any one such subsequence. A subsequence (A_{i_1}, A_{i_2}, \dots, A_{i_k}) of A is said to be magical if, for each 1 \leq j \lt k, the pair (A_{i_j}, A_{i_{j+1}}) is equal to one of the Q given pairs.

EXPLANATION:

An O(N \cdot Q) solution:

Let us first forget about the time constraints and solve the problem using dynamic programming. For this, we need a state.

Let dp_i be the length of the longest magical subsequence ending at index i, initialized to 1. Then dp_i is computed as follows: for every value x such that (x, A_i) is one of the Q ordered pairs and x has already occurred, update dp_i = max(dp_i, dp_{recent_x} + 1), where recent is a helper array and recent_x is the most recent index j \lt i with A_j = x. Using only the most recent occurrence is enough, because any subsequence ending at an earlier occurrence of x can end at the latest occurrence instead, so dp never decreases between occurrences of the same value. The array recent is updated continuously as we iterate over i from 1 to N.

The computation of the dp values is fairly straightforward. The only problem is the time complexity: in the worst case, this approach takes O(N \cdot Q). Let us try to optimise it using a square root decomposition idea.
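The recurrence above can be sketched as follows. This is a minimal illustration that only computes the length; the function name longestMagicalNaive and the helper map preds are names chosen for this sketch, and hash maps are used instead of fixed-size arrays so that no bound on the values is assumed.

```cpp
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>

// Length of the longest magical subsequence using the straightforward DP.
// a: the array; pairs: the allowed ordered pairs (x, y).
int longestMagicalNaive(const std::vector<int>& a,
                        const std::vector<std::pair<int, int>>& pairs) {
    // preds[y] lists every x such that (x, y) is an allowed pair.
    std::unordered_map<int, std::vector<int>> preds;
    for (const auto& p : pairs) preds[p.second].push_back(p.first);

    std::unordered_map<int, int> recent;  // value -> latest index seen so far
    std::vector<int> dp(a.size(), 1);
    int best = 0;
    for (int i = 0; i < (int)a.size(); ++i) {
        auto it = preds.find(a[i]);
        if (it != preds.end()) {
            for (int x : it->second) {
                auto r = recent.find(x);
                // Extend the best subsequence ending at the latest occurrence of x.
                if (r != recent.end()) dp[i] = std::max(dp[i], dp[r->second] + 1);
            }
        }
        recent[a[i]] = i;  // updated after the lookup so that index i never extends itself
        best = std::max(best, dp[i]);
    }
    return best;
}
```

The inner loop runs count(A_i) times for index i, which is what makes the worst case O(N \cdot Q).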

An O(N \sqrt Q) solution:

Let us define count(y) as the number of distinct values x for which (x, y) is one of the Q ordered pairs.

Let us also define two groups, big and small. Any number y in big has count(y) \gt \sqrt Q, and any number y in small has 1 \leq count(y) \leq \sqrt Q. These two groups can be computed in a straightforward manner. The big group contains fewer than \sqrt Q values, i.e. O(\sqrt Q): the counts of all values sum to at most Q, and each big value contributes more than \sqrt Q to that sum.
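As a rough sketch of the grouping step, the function below collects the big values. The name bigValues is hypothetical, and the sketch assumes the threshold is computed from the number of distinct pairs; for an integer count, count(y) \gt \sqrt Q is the same as count(y) \gt floor(\sqrt Q), so an integer square root avoids floating-point issues.

```cpp
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Returns every y whose count(y) exceeds floor(sqrt(Q)), where Q is the
// number of distinct pairs (an assumption of this sketch).
std::unordered_set<int> bigValues(const std::vector<std::pair<int, int>>& pairs) {
    // Deduplicate so that a repeated pair is not counted twice.
    std::set<std::pair<int, int>> distinct(pairs.begin(), pairs.end());
    long long q = (long long)distinct.size();

    long long block = 0;  // floor(sqrt(q)), computed with integers only
    while ((block + 1) * (block + 1) <= q) ++block;

    std::unordered_map<int, long long> count;  // y -> number of distinct x with (x, y)
    for (const auto& p : distinct) ++count[p.second];

    std::unordered_set<int> big;
    for (const auto& kv : count)
        if (kv.second > block) big.insert(kv.first);
    return big;
}
```

Every value that appears as a right-hand side but is not returned here belongs to the small group.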

To reduce the time complexity, we define another array dpVal, where dpVal_x is the best length a magical subsequence would have if it ended with the value x at the current position. Note that this array is different from dp, since here I am talking about the actual values instead of the indices. Also, this array is maintained only for the values in the big group; the reason will become clear shortly.

The algorithm now proceeds as follows:

Initialize dpVal_x = 1 for every value x belonging to the big group.

Let us now iterate over i from 1 to N and for every iteration i:

  • Initialize dp_i to 1.

  • If the value at the current index A_i belongs to the small group, then update dp_i using the steps of the O(N \cdot Q) solution above. This takes O(\sqrt Q) time because of the definition of the small group.

  • If the value at the current index A_i belongs to the big group, then set dp_i = dpVal_{A_i}.

  • At the end of this iteration, we also need to update dpVal for the values in the big group. For every value y of the big group, if (A_i, y) is one of the Q ordered pairs, update dpVal_y as dpVal_y = max(dpVal_y, dp_i + 1). This takes O(\sqrt Q) time since there are at most O(\sqrt Q) values in the big group.
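The steps above can be sketched as a single function that returns only the length. This is an illustrative sketch rather than the reference implementation: the names longestMagicalSqrt, preds and bigSucc are chosen here, hash maps replace the fixed-size arrays, and the threshold is assumed to be floor(\sqrt Q) over the distinct pairs.

```cpp
#include <algorithm>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

int longestMagicalSqrt(const std::vector<int>& a,
                       const std::vector<std::pair<int, int>>& pairs) {
    std::set<std::pair<int, int>> distinct(pairs.begin(), pairs.end());
    long long q = (long long)distinct.size();
    long long block = 0;  // floor(sqrt(q))
    while ((block + 1) * (block + 1) <= q) ++block;

    // preds[y] lists every x such that (x, y) is an allowed pair.
    std::unordered_map<int, std::vector<int>> preds;
    for (const auto& p : distinct) preds[p.second].push_back(p.first);

    // dpVal holds an entry only for big values; being a key means "big".
    std::unordered_map<int, int> dpVal;
    for (const auto& kv : preds)
        if ((long long)kv.second.size() > block) dpVal[kv.first] = 1;

    // bigSucc[x] lists every big y such that (x, y) is an allowed pair.
    std::unordered_map<int, std::vector<int>> bigSucc;
    for (const auto& p : distinct)
        if (dpVal.count(p.second)) bigSucc[p.first].push_back(p.second);

    std::unordered_map<int, int> recent;  // value -> latest index seen so far
    std::vector<int> dp(a.size(), 1);
    int best = 0;
    for (int i = 0; i < (int)a.size(); ++i) {
        auto big = dpVal.find(a[i]);
        if (big != dpVal.end()) {
            dp[i] = big->second;  // big value: the answer was pushed here earlier
        } else {
            auto it = preds.find(a[i]);
            if (it != preds.end()) {
                for (int x : it->second) {  // small value: pull from predecessors
                    auto r = recent.find(x);
                    if (r != recent.end()) dp[i] = std::max(dp[i], dp[r->second] + 1);
                }
            }
        }
        recent[a[i]] = i;
        auto s = bigSucc.find(a[i]);
        if (s != bigSucc.end())
            for (int y : s->second) dpVal[y] = std::max(dpVal[y], dp[i] + 1);
        best = std::max(best, dp[i]);
    }
    return best;
}
```

Small values pull from at most \sqrt Q predecessors, while big values are filled in by pushes from their predecessors, and each index pushes to fewer than \sqrt Q big values.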

Thus, by splitting the right-hand side values of the Q pairs into small and big groups and cleverly utilizing them, we are able to reduce the time complexity of our solution from O(N \cdot Q) to O(N \sqrt Q). Note that I haven’t talked about constructing one such subsequence, which can be done easily by storing the previous index whenever a dp or dpVal value is improved.
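As a small sketch of the reconstruction, suppose the dp pass has also filled a predecessor array, where -1 marks the start of a subsequence. The function name reconstruct and the parameter name prev are illustrative; the output uses 1-based indices.

```cpp
#include <algorithm>
#include <vector>

// Returns the 1-based indices of one longest subsequence, given the dp
// lengths and the predecessor links recorded while computing them.
std::vector<int> reconstruct(const std::vector<int>& dp, const std::vector<int>& prev) {
    int cur = -1;
    int best = 0;
    for (int i = 0; i < (int)dp.size(); ++i) {
        if (dp[i] > best) {  // pick an index where the maximum length is reached
            best = dp[i];
            cur = i;
        }
    }
    std::vector<int> seq;
    while (cur != -1) {  // walk the links backwards to the first element
        seq.push_back(cur + 1);
        cur = prev[cur];
    }
    std::reverse(seq.begin(), seq.end());
    return seq;
}
```

For example, with dp = {1, 2, 1, 3} and prev = {-1, 0, -1, 1}, the walk starts at index 3 and yields the indices 1 2 4.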

You can refer to the code below for further understanding.

TIME COMPLEXITY:

O(N \sqrt Q) or O(N \sqrt Q + N \log N) depending on the implementation.
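One way to see the O(N \sqrt Q) bound, written out as a rough derivation: here cost(i) denotes the work done at index i, a symbol introduced only for this sketch, and preprocessing the pairs is left out.

```latex
\begin{align*}
\text{cost}(i) &\le \underbrace{\sqrt{Q}}_{\text{pull when } A_i \text{ is small}}
                 + \underbrace{|\text{big}|}_{\text{push to big values}}
                 \lt 2\sqrt{Q}, \\
\sum_{i=1}^{N} \text{cost}(i) &\lt 2N\sqrt{Q} = O(N\sqrt{Q}).
\end{align*}
```

When A_i is big, the pull step is a single lookup of dpVal_{A_i}, so the same bound still holds.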

SOLUTION:

Editorialist's solution
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 1e6 + 5;

vector<int> lvals[MAXN]; // For a value l, lists every r such that (l, r) is one of the Q pairs
vector<int> rvals[MAXN]; // For a value r, lists every l such that (l, r) is one of the Q pairs
vector<int> largeVals[MAXN]; // For a value l, lists every r in the large group such that (l, r) is one of the Q pairs
vector<bool> checkBig(MAXN);
vector<int> recent(MAXN, -1);
vector<int> dpVal(MAXN);
vector<int> prevIndForVal(MAXN, -1); // Used for constructing the final subsequence

int main()
{
      int tests;
      cin >> tests;
      while (tests--)
      {
            int n;
            cin >> n;

            vector<int> a(n);
            set<int> rnums, lnums;

            for (int i = 0; i < n; i++)
            {
                  cin >> a[i];
            }

            int q;
            cin >> q;
            
            set<pair<int,int>> pairs;

            for (int i = 0; i < q; i++) {
                  int x, y;
                  cin >> x >> y;
                  pairs.insert({x, y});
            }
            
            for (pair<int,int> p: pairs) {
                int x = p.first;
                int y = p.second;
                rvals[y].push_back(x);
                lvals[x].push_back(y);
                rnums.insert(y);
                lnums.insert(x);
            }

            vector<int> dp(n, 1);
            vector<int> prev(n, -1); // Used for constructing the final subsequence
            
            int BLOCK = sqrt(q + 0.5);

            for (int x : rnums)
            {
                  if ((int)rvals[x].size() > BLOCK)
                  {
                        checkBig[x] = true;
                        dpVal[x] = 1;
                  }
            }
            
            for (int x: lnums) {
                for (int r: lvals[x]) {
                    if (checkBig[r]) {
                        largeVals[x].push_back(r);
                    }
                }
            }

            for (int i = 0; i < n; i++)
            {
                  int x = a[i];

                  // For big category number
                  if ((int)rvals[x].size() > BLOCK)
                  {
                        dp[i] = dpVal[x];
                        prev[i] = prevIndForVal[x];
                  }
                  // For small category number
                  else
                  {
                        for (int l : rvals[x])
                        {
                              if (recent[l] != -1 && dp[i] < dp[recent[l]] + 1)
                              {
                                    int ind = recent[l];
                                    dp[i] = dp[ind] + 1;
                                    prev[i] = ind;
                              }
                        }
                  }

                  recent[x] = i;

                  // Update all dpVal values for big category numbers
                  for (int r : largeVals[x])
                  {
                        if (dpVal[r] < dp[i] + 1)
                        {
                              dpVal[r] = dp[i] + 1;
                              prevIndForVal[r] = i;
                        }
                  }
            }

            int cur = -1;
            int ans = 0;

            for (int i = 0; i < n; i++)
            {
                  if (ans < dp[i])
                  {
                        ans = dp[i];
                        cur = i;
                  }
            }

            cout << ans << endl;
            vector<int> seq;

            while (cur != -1)
            {
                  seq.push_back(cur + 1);
                  cur = prev[cur];
            }

            reverse(seq.begin(), seq.end());

            for (int x : seq)
            {
                  cout << x << " ";
            }
            cout << endl;

            // Reset the global vectors

            for (int x : a)
            {
                  recent[x] = -1;
            }

            for (int x : rnums)
            {
                  checkBig[x] = false;
                  rvals[x].clear();
                  dpVal[x] = 0;
                  prevIndForVal[x] = -1;
            }
            
            for (int x : lnums)
            {
                  lvals[x].clear();
                  largeVals[x].clear();
            }
      }
      return 0;
}
Setter's solution
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define ff first
#define ss second
#define ii insert
#define mem(l,r) memset(l,r,sizeof(l))
#define sorta(a,n) sort(a+1,a+1+n)
#define sortv(v) sort(v.begin(),v.end())
#define revs(s) reverse(s.begin(),s.end())
#define fastio ios::sync_with_stdio(false), cin.tie(NULL),cout.tie(NULL);
const int N=1e5+5;
const int mod=1e9+7;
const ll int_max=1e18;
const ll int_min=-1e18;
#define rep(i,j,k) for(ll i=j;i<=k;i++)
#define repr(i,j,k) for(ll i=j;i>=k;i--)
const long double PI = acos(-1);
using namespace std;
ll sumn=0,sumq=0;
void solve()
{
     int n;
     cin>>n;
     int a[n+2];
     assert(n>=1 && n<=100000);
     sumn+=n;
     vector<vector<int>>from_l(n+2);
     vector<vector<int>>from_r(n+2);
     vector<set<int>>act_as_l(n+2);
     ll unique=0;
     map<ll,ll>vis;
    rep(i,1,n)
     {
        cin>>a[i];
        if(vis[a[i]]==0)
        {
            unique++;
            vis[a[i]]=unique;//replacing large values with 1 to n
        }
        a[i]=vis[a[i]];
        assert(a[i]>=1 && a[i]<=1000000);
     }  
     int q;
     cin>>q;
     assert(q>=1 && q<=100000);
     sumq+=q;
     vector<int>l(q+2),r(q+2);
     vector<int>freqr(100001,0);
     vector<int>freql(100001,0);
     rep(i,1,q)
     {
        cin>>l[i]>>r[i];
        assert(l[i]>=1 && l[i]<=1000001);
        assert(r[i]>=1 && r[i]<=1000001);
        l[i]=vis[l[i]];
        r[i]=vis[r[i]];
        if(l[i]==0 || r[i]==0) // a mapped id of 0 means the value never occurs in a
        {
            continue;
        } 
        freqr[r[i]]++;
        freql[l[i]]++;
     }
     int sq=sqrt(q);
    rep(i,1,q)
    {
        if(freql[l[i]]>freqr[r[i]])
        {
            from_r[r[i]].pb(l[i]);
        }
        else
        {
            from_l[l[i]].pb(r[i]);
        }
        act_as_l[l[i]].ii(r[i]);
    }
     int dp[n+2];
     mem(dp,0);
     int dpUpd[n+2][2];
     mem(dpUpd,0);
     int ans=0;
     rep(i,1,n)
     {
        dp[i]=dpUpd[a[i]][1];//replacing it from its previous l value when a[i] was assumed to be its future value
        int mx=dpUpd[a[i]][0];//answer taking from the previous values is stored in mx
        for(int j:from_r[a[i]])
        {  
            mx=max(mx,dpUpd[j][0]+1);
        }
        dp[i]=max(dp[i],mx);
        dp[i]=max(dp[i],1);
        for(int j:from_l[a[i]])//for storing the future values
        {
            dpUpd[j][1]=max(dp[i]+1,dpUpd[j][1]);//where i is the index and j is the element
        }
        dpUpd[a[i]][0]=dp[i];
        dpUpd[a[i]][1]=max(dpUpd[a[i]][1],dpUpd[a[i]][0]);
        ans=max(ans,dp[i]);
     }
     assert(ans>=2);
     cout<<ans<<'\n';
     int last=n;
     vector<int>subsequence;  
     int want=ans;
     int prev;
    while(want>=1 && last>=1)
    {
        if(dp[last]==want)
        {
            if(want==ans){
            prev=last;
            want--;
            subsequence.pb(last);
            }
            else
            {
                if(act_as_l[a[last]].find(a[prev])!=act_as_l[a[last]].end())
                {
                    want--;
                    subsequence.pb(last);
                    prev=last;
                }
            }
        }
        last--;
    }
    reverse(subsequence.begin(),subsequence.end());
    for(int i:subsequence)
    {
        cout<<i<<' ';
    }
    cout<<'\n';
    assert(sumn<=200000 && sumq<=200000);
}
int main() {
    fastio
    int t;
    cin>>t;
    assert(t>=1 && t<=50);
    while(t--)
    {
        solve();   
    }
}

Tester's solution
#include <bits/stdc++.h>
#define ll long long int
using namespace std;


/*
------------------------Input Checker----------------------------------
*/

long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}


/*
------------------------Main code starts here----------------------------------
*/

const int MAX_T = 50;
const int MAX_N = 100000;
const int MAX_A = 1000000;
const int MAX_Q = 100000;
const int MAX_X = 1000000;
const int MAX_Y = 1000000;
const int SUM_N = 200000;
const int SUM_Q = 200000;

#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)

int sum_n=0;
int sum_q=0;

int bans[MAX_X + 1], lst[MAX_X + 1];

void solve()
{
    int n = readIntLn(1, MAX_N);
    sum_n += n;
    assert(sum_n <= SUM_N);
    int a[n];
    for(int i = 0; i < n - 1; i++) a[i] = readIntSp(1, MAX_A), bans[a[i]] = lst[a[i]] = 0;
    a[n - 1] = readIntLn(1, MAX_A);
    bans[a[n - 1]] = lst[a[n - 1]] = 0;
    int q = readIntLn(1, MAX_Q);
    sum_q += q;
    assert(sum_q <= SUM_Q);
    map<int, vector<int>> ed;
    set<pair<int, int>> pts;
    vector<int> big;
    int lim = sqrt(q) + 100;
    for(int i = 0; i < q; i++) {
        int x = readIntSp(1, MAX_X);
        int y = readIntLn(1, MAX_Y);
        ed[y].push_back(x);
        pts.insert({x, y});
        if(ed[y].size() == lim) big.push_back(y);
    }
    int ans[n], par[n];
    for(int i = 0; i < n; i++) {
        if(ed[a[i]].size() >= lim) {
            if(!bans[a[i]]) {
                ans[i] = 1;
                par[i] = -1;
            } else {
                int val = bans[a[i]];
                par[i] = lst[val] - 1;
                ans[i] = ans[par[i]] + 1;
            }
        } else {
            ans[i] = 1;
            par[i] = -1;
            for(auto it: ed[a[i]]) {
                if(!lst[it]) continue;
                int pos = lst[it] - 1;
                if(ans[pos] >= ans[i]) {
                    ans[i] = ans[pos] + 1;
                    par[i] = pos;
                }
            }
        }
        for(int j = 0; j < big.size(); j++) {
            if(pts.find({a[i], big[j]}) != pts.end()) {
                if(!bans[big[j]] || ans[i] > ans[lst[bans[big[j]]] - 1]) {
                    bans[big[j]] = a[i];
                }
            }
        }
        lst[a[i]] = i + 1;
    }
    int res = 0;
    for(int i = 1; i < n; i++)
        if(ans[i] > ans[res]) res = i;
    cout << ans[res] << "\n";
    int st = res;
    vector<int> r;
    while(st != -1) {
        r.push_back(st + 1);
        st = par[st];
    }
    reverse(r.begin(), r.end());
    assert(r.size() == ans[res]);
    for(int i = 0; i < ans[res]; i++) cout << r[i] << " ";
    cout << "\n";
}

signed main()
{
    fast;
    #ifndef ONLINE_JUDGE
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #endif


    int t = readIntLn(1, MAX_T);

    for(int i=1;i<=t;i++)
    {
        solve();
    }

    assert(getchar() == -1);
}


Please comment below if you have any questions, alternate solutions, or suggestions.

My solution was similar except for a minor difference.

For every pair (x, y):
update dp values the first way at a[i] = y if count(x) > count(y),
and
update dp values the second way at a[i] = x if count(x) < count(y).
