GUNDIS - Editorial

Problem statement
Contest source

Author : Prince Raaz
Editorialist : Raghav Agarwal
Tester : Miten Shah

DIFFICULTY:

Easy Medium

PREREQUISITES:

DP

PROBLEM:

Given an array, divide it into contiguous subarrays so that the product of the numbers of distinct elements in the subarrays is maximized, and output this maximum product modulo $10^9+7$.

EXPLANATION:

Observation:

If a subarray contains $k \ge 4$ distinct elements, we can always cut it into two parts that each contain at least $2$ distinct elements. Take the shortest prefix with $2$ distinct values; the rest still contains the $k-2 \ge 2$ values that do not appear in that prefix. Every distinct value of the original subarray appears in at least one part, so $d_1 + d_2 \ge k$. Because $d_1, d_2 \ge 2$, we have $(d_1-2)(d_2-2) \ge 0$, i.e. $d_1 d_2 \ge 2(d_1 + d_2) - 4 \ge 2k - 4 \ge k$. So splitting such a subarray never decreases the product.

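Hence there is an optimal division in which every subarray has at most $3$ distinct elements. Subarrays with a single distinct element only contribute a factor of $1$, so the only factors that matter are $2$ and $3$.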

Dynamic programming

If we cut off the last $i$ elements of the array as one subarray, what remains is the same problem on the first $n-i$ elements, so we can use dynamic programming.

  • State - $dp[x]$ = the maximum product achievable for the prefix formed by the first $x$ elements of the array (the $x^{\text{th}}$ element belongs to the last subarray of the division).

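  • Transition - $dp[x] = \max_{1 \le i \le x} \big( dp[x-i] \cdot u(x-i+1, x) \big)$, where $u(l, r)$ is the number of distinct elements among positions $l..r$.
    By the observation above, we only need to extend $i$ while the last $i$ elements contain at most $3$ distinct values; the distinct values can be tracked with a std::set. Every useful factor is $2$ or $3$, so the product has the form $2^a \cdot 3^b$ and is far too large to store directly. The solutions therefore store the exponents $(a, b)$ and compare products through logarithms.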

SOLUTION :

c++ Solution (Setter's)
#include <bits/stdc++.h>
// Fast iostream setup: stop syncing C++ streams with C stdio and untie cin/cout
// so that input is not preceded by an output flush.
#define fast_io std::ios::sync_with_stdio(false), cin.tie(NULL), cout.tie(NULL)

using namespace std;

// Shorthand types used throughout the setter's solution.
using ll = long long int;
using pll = pair<ll, ll>;
using matrix = vector<vector<ll>>;

// Modulus for the printed answer.
constexpr int mod = 1000000007;
// 1 / log2(3): scale factor that turns a count of factors 2 into an equivalent
// count of factors 3 on a logarithmic scale (used by comp).
const long double ratio32 = 1.0 / log2(3);

// The product does not fit in any built-in numeric type, but it always has the form
// 2^x * 3^y, so a pll stores only the exponents, as (number of 3s, number of 2s).
// Since log2(3^a * 2^b) = a * log2(3) + b, dividing by the positive constant log2(3)
// keeps the order, and products can be compared through a + b * ratio32
// (ratio32 = 1 / log2(3)). Returns true when a represents the smaller product.
bool comp(pll a, pll b) {
        const long double lhs = a.first + a.second * ratio32;
        const long double rhs = b.first + b.second * ratio32;
        return lhs < rhs;
}

int sum = 0;
void ca() {
        ll n;
        cin >> n;
        sum += n;
        ll a[n];
        for (int i = 0; i < n; i++)
                cin >> a[i];

        // compress array
        // [7 7 7 7 7] 2 2 2 3
        // 7 7 2 2 3


        // Since for each dp state we will iterate backwards till we find 3 unique elements, this could take $O(n^2)$ worst case time when there are large number of identical elements in array. To correct this, we will *compress* the array. 

// Consider a subarray consisting of only identical elements like - $[7, 7, 7, 7, 7]$. In optimal division, this subarray will either be part of one larger subarray or it will be broken in two and divided between two adjacent subarrays. In first case it will provide $7$ to large subarray, and in second case a $7$ each to adjacent subarrays. If we replaced the subarray of $7s$ with just $[7, 7]$ then the situation will still be identical, it could be  part of a single larger subarray or two adjacent subarrays and provide them with a $7$ each. That is for any subarray of identical elements larger in size than $3$, we can replace it with just two of those identical elements.

// Doing this *compression* ensures that every $3$ elements contains atleast $2$ unique elements so our worst case is handled. 
        vector<ll> arr;
        arr.push_back(0);
        for (int i = 0; i < n && i < 2; i++)
                arr.push_back(a[i]);
        for (int i = 2; i < n; i++) {
                int m = arr.size();
                if (!(a[i] == arr[m - 1] && a[i] == arr[m - 2]))
                        arr.push_back(a[i]);
        }
        n = arr.size() - 1; // -1 for 1 indexing

        vector<pll> dp(n + 1);
        map<int, int> seen;

        dp[0] = {0, 0}; // number of 3, number of 2

        for (int i = 1; i <= n; i++) {
                dp[i] = dp[i - 1]; // Take no element
                set<ll> seen;
                // we could go backwards beyond 6 in case of 1, 2, 1, 2, 1, 2, 1, 2 ...
                // however making a subarray larger than 6 elements is suboptimal, so we check it
                for (int j = i; j > max(0, i - 6); j--) {
                        seen.insert(arr[j]);
                        if (seen.size() == 2)
                                dp[i] = max(dp[i],
                                            (pll){dp[j - 1].first,
                                                  dp[j - 1].second + 1},
                                            comp);
                        else if (seen.size() == 3) {
                                dp[i] = max(dp[i],
                                            (pll){dp[j - 1].first + 1,
                                                  dp[j - 1].second},
                                            comp);
                                break;
                        }
                }
        }
        ll ans = 1;
        for (int i = 0; i < dp[n].second; i++)
                (ans *= 2) %= mod;
        for (int i = 0; i < dp[n].first; i++)
                (ans *= 3) %= mod;
        cout << (ans) << endl;
}

// Reads the number of test cases and runs ca() once for each of them.
int main() {
        fast_io;
        int t;
        cin >> t;
        for (; t--;) {
                ca();
        }
        return 0;
}
C++ Tester's solution
// created by mtnshh

#include<bits/stdc++.h>
using namespace std;
// Short names for types and standard functions/members. These are object-like macros,
// i.e. plain textual substitutions applied to every later token with the same name.
#define ll long long int
#define pb push_back
#define rb pop_back
#define ti tuple<int, int, int>
#define pii pair<int, int>
#define pli pair<ll, int>
#define pll pair<ll, ll>
#define mp make_pair
#define mt make_tuple
 
// rep: i runs upward over [a, b); repb: i runs downward from a to b inclusive.
// Both declare the loop counter as ll.
#define rep(i,a,b) for(ll i=a;i<b;i++)
#define repb(i,a,b) for(ll i=a;i>=b;i--)
 
// Debug-printing helpers. Note that errA expands to two statements: only the first
// cout is inside the for loop, and the trailing cout<<endl always runs once afterwards
// (so it must not be used as the unbraced body of an if/for).
#define err() cout<<"--------------------------"<<endl; 
#define errA(A) for(auto i:A)   cout<<i<<" ";cout<<endl;
#define err1(a) cout<<#a<<" "<<a<<endl
#define err2(a,b) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<endl
#define err3(a,b,c) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<endl
#define err4(a,b,c,d) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<" "<<#d<<" "<<d<<endl

// Range and pair-member shorthands: every later token `ft` / `sd` becomes `first` / `second`.
#define all(A)  A.begin(),A.end()
#define allr(A)    A.rbegin(),A.rend()
#define ft first
#define sd second

// Container type shorthands.
#define V vector<ll>
#define S set<ll>
#define VV vector<V>
#define Vpll vector<pll>
 
// endl is replaced by "\n", so `cout << ... << endl` below writes a newline without flushing.
#define endl "\n"

// Strict input validator: reads one integer token from stdin, character by character,
// and returns it. Aborts (via assert) unless the token is an optional single '-'
// followed by at least one digit, has no leading zeros, is not "-0", has at most
// 19 digits (the first being 1 when there are 19), is immediately followed by `endd`,
// and lies in [l, r].
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit of the token, -1 until a digit is read
    bool is_neg=false;
    while(true){
        // Keep getchar()'s result as an int so EOF stays distinguishable from byte values.
        const int g=getchar();
        assert(g!=EOF);
        if(g=='-'){
            // A sign is allowed only once and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);          // no leading zeros
            assert(fi!=0 || is_neg==false);   // no "-0"
            // Check the length before accumulating so x can never overflow
            // (signed overflow is undefined behaviour).
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x*=10;
            x+=g-'0';
        } else if(static_cast<char>(g)==endd){
            // Reject an empty token or a lone '-' instead of silently reading 0.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Strict input validator: reads characters from stdin up to (not including) the
// terminator `endd` and returns them. Aborts (via assert) if EOF is reached first or
// if the number of characters read is outside [l, r].
std::string readString(int l,int r,char endd){
    std::string ret;
    while(true){
        // getchar() returns an int so that EOF can be told apart from every byte value.
        // Storing it in a char first breaks the EOF check on platforms where char is
        // unsigned (the comparison with -1 is then never true and the loop never ends).
        const int g=getchar();
        assert(g!=EOF);
        const char ch=static_cast<char>(g);
        if(ch==endd){
            break;
        }
        ret+=ch;
    }
    const int cnt=static_cast<int>(ret.size());
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
// Capacity of the global arrays below; indices 0..n are used for one test case.
constexpr ll N = 500005;
// Modulus for the printed answer (also the default modulus of power()).
constexpr ll M = 1000000007;

// dp[i]: best (count of factors 2, count of factors 3) for the prefix A[1..i].
pll dp[N];
// A[1..n]: the current test case; A[0] is used as a sentinel by solve().
ll A[N];
// B[i][0], B[i][1]: candidate start indices computed by solve() (0 means "none").
ll B[N][2];

// Returns ln(2^{p.ft} * 3^{p.sd}) = p.ft * ln 2 + p.sd * ln 3, so that two huge
// products of 2s and 3s can be compared through their logarithms instead of their values.
long double f(pll p){
    return p.ft * log(2) + p.sd * log(3);
}

// Returns whichever of the (count of 2s, count of 3s) pairs a and b represents the
// larger product according to f(); when f(a) is not strictly larger, b is returned.
pll max(pll a, pll b){
    return f(a) > f(b) ? a : b;
}

// Returns a^n mod m by binary exponentiation. Requires n >= 0 and m >= 1.
// The base is first reduced into [0, m), so every intermediate product is below m^2 and
// cannot overflow a long long for any m up to about 3e9, whatever the incoming value
// (or sign) of a.
ll power(ll a,ll n,ll m=M){
    assert(n >= 0 && m >= 1);
    a %= m;
    if(a < 0) a += m;
    ll ans = 1 % m;     // 1 % m so that power(x, 0, 1) is 0 rather than 1
    while(n > 0){
        if(n & 1) ans = ans * a % m;
        a = a * a % m;
        n >>= 1;
    }
    return ans;
}

// m: value -> index of its most recent occurrence in A (rebuilt by solve() per test case).
map<ll,ll> m{};
// s: most-recent-occurrence indices, ordered from largest to smallest (rebuilt by solve()).
set<ll,greater<ll>> s{};

// Solves one test case of size n and prints the answer.
//  1. Reads A[1..n], each value validated to lie in [0, 1e9].
//  2. For every i >= 2, takes the (at most) three largest indices of s, i.e. the latest
//     occurrences of the values seen so far. B[i][0] is the first of them whose value
//     differs from A[i] and B[i][1] the second one. A[B[i][0]..i] is then the shortest
//     suffix ending at i with 2 distinct values, and A[B[i][1]..i] the shortest with 3.
//     Index 0 stands for "no such suffix".
//  3. dp[i] = best (count of factors 2, count of factors 3) for the prefix A[1..i],
//     chosen with max(pll, pll). The answer 2^ft * 3^sd is printed modulo M.
void solve(ll n){
    m.clear();
    s.clear();
	rep(i,1,n+1)	A[i] = (i != n) ? readIntSp(0, 1e9) : readIntLn(0, 1e9);
	B[1][0] = 0; B[1][1] = 0;
	m[A[1]] = 1;
    // Sentinel: values are read from [0, 1e9], so A[0] never equals any A[i];
    // index 0 can therefore be used to mean "no candidate".
    A[0] = -1;
    s.insert(1);
    rep(i,2,n+1){
        // Pop the three largest latest-occurrence indices (missing ones default to 0).
        ll a = 0, b = 0, c = 0;
        if(s.size() > 0)    a = *s.begin(),s.erase(s.begin());
        if(s.size() > 0)    b = *s.begin(),s.erase(s.begin());
        if(s.size() > 0)    c = *s.begin(),s.erase(s.begin()); 
        // Keep the first two candidates whose value differs from A[i].
        ll cnt = 0;
        if(cnt < 2 and A[a] != A[i])    B[i][cnt] = a, cnt++;
        if(cnt < 2 and A[b] != A[i])    B[i][cnt] = b, cnt++;
        if(cnt < 2 and A[c] != A[i])    B[i][cnt] = c, cnt++;
        // Put the candidates back. Index 0 may be inserted here and then stays in s;
        // as the smallest element it is only popped when fewer than three real indices exist.
        s.insert(a);s.insert(b);s.insert(c);
        // i becomes the latest occurrence of A[i]; drop the previous one from s.
        if(m.count(A[i])){
            ll x = m[A[i]];
            s.erase(x);
        }
        s.insert(i);
        m[A[i]] = i;
    }
    dp[0] = {0, 0};
	rep(i,1,n+1){
		// A[i] on its own contributes a factor of 1.
		dp[i] = dp[i-1];
		// Suffix with 2 distinct values: one more factor 2 (ft).
		if(B[i][0] != 0){
			ll tmp = B[i][0];
			dp[i] = max(dp[i], {dp[tmp-1].ft + 1, dp[tmp-1].sd});
		}
		// Suffix with 3 distinct values: one more factor 3 (sd).
		if(B[i][1] != 0){
			ll tmp = B[i][1];
			dp[i] = max(dp[i], {dp[tmp-1].ft, dp[tmp-1].sd + 1});
		}
	}
    // Both power() results are already reduced modulo M, so their product fits in a
    // long long before the final % M.
    cout << (power(2, dp[n].ft) * power(3, dp[n].sd)) % M << endl;
	return;
}

int main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
	ll no = readIntLn(1, 1e5);
    ll sum_n = 0;
	while(no--){
        ll n = readIntLn(2, 2e5);
        sum_n += n;
		solve(n);
	}
    assert(sum_n <= 2e5);
    assert(getchar()==-1);
    return 0;
}

2 Likes

Can someone please tell me what is wrong with my code
here