PRODMEX - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Mohammed Ehab
Tester: Istvan Nagy
Editorialist: Taranpreet Singh

DIFFICULTY

Medium-hard

PREREQUISITES

Number Theory, Segment Tree

PROBLEM

Given an array A of length N, answer queries of the following form.

Given an interval L, R, consider all subsequences of the subarray A[L: R] and write down the product of each one. The answer to the query is the smallest positive integer that was not written down. The product of the empty subsequence is 1.
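For testing on small inputs, the definition can be turned directly into a brute force. The sketch below is our own (the name `bruteMex` and the 20-element cap are assumptions, not part of the problem). It enumerates every subsequence of A[L: R] with a bitmask, records the products, and returns the smallest positive integer that never appears.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical brute-force checker: MEX of all subsequence products of
// A[L..R] (1-based, inclusive). Exponential, so only for tiny ranges.
long long bruteMex(const std::vector<long long>& A, int L, int R) {
    const int len = R - L + 1;
    assert(1 <= L && L <= R && R <= (int)A.size() && len <= 20);
    // At most 2^len distinct products exist, so the MEX is at most 2^len + 1.
    // Larger products cannot affect the MEX and are skipped, which also keeps
    // the multiplication from overflowing (assuming values up to 1e9).
    const long long cap = (1LL << len) + 1;
    std::set<long long> products;
    for (int mask = 0; mask < (1 << len); ++mask) {
        long long prod = 1;  // the empty subsequence has product 1
        bool tooBig = false;
        for (int j = 0; j < len && !tooBig; ++j) {
            if ((mask >> j) & 1) {
                prod *= A[L - 1 + j];
                tooBig = prod > cap;
            }
        }
        if (!tooBig) products.insert(prod);
    }
    long long mex = 1;
    while (products.count(mex)) ++mex;
    return mex;
}
```

For example, the subarray [2, 3] produces the products 1, 2, 3 and 6, so its answer is 4.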

QUICK EXPLANATION

  • Only prime powers matter; all other elements can be discarded.
  • Process the queries offline in increasing order of their right end. While the right end pointer is at p, maintain a function f(x) denoting the rightmost position r such that some subsequence of the subarray A[r: p] has product x. It suffices to maintain f(x) for prime powers x only.
  • To answer query L, R (with p = R), find the smallest x such that \displaystyle \min_{y = 2}^{x-1} f(y) \geq L and f(x) < L. If f(x) is stored as the value of the x-th leaf of a segment tree, this x can be found by binary search.

EXPLANATION

Crucial Observation

Only the powers of prime numbers matter. The answer to any query is of the form p^a, where p is prime and a \geq 1 (it is never 1, since the empty subsequence has product 1).
Proof: Assume, for contradiction, that the answer to some query is X > 1 where X is not a prime power. Pick a prime p dividing X and write X = p^a*Y, where a \geq 1, Y > 1 and gcd(p, Y) = 1. By assumption, X doesn’t appear as the product of any subsequence in the query range.

Both p^a and Y are smaller than X, and X is the smallest missing value, so p^a and Y both appear as products of subsequences in the query range. Let S_1 be a subsequence with product p^a and S_2 a subsequence with product Y. An element shared by S_1 and S_2 would divide gcd(p^a, Y) = 1, so it would equal 1; removing such 1s from S_2 leaves its product unchanged, so we may assume S_1 and S_2 have no common element.

Let’s take the union of S_1 and S_2 and compute its product. The product of S_1 \bigcup S_2 is p^a*Y = X, so X is the product of a subsequence, which is a contradiction.

Therefore, the answer to any query is of the form p^a, where p is a prime and a \geq 1.

Upper Bound on Answer

Consider a query on range L, R. The subarray contains at most R-L+1 \leq N elements, so at least one of the first N+1 primes, say q, does not occur in it. A subsequence with product q must contain an element equal to q, so q is never written down and the answer is at most q. In the worst case, the N entries are exactly the first N primes, and the answer equals the (N+1)-th prime. Hence the (N+1)-th prime, which is 1299721 for N = 10^5, is an upper bound on the answer. Let’s call MX = 1299721.
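The bound above can be checked with a plain sieve of Eratosthenes. This is a minimal sketch; the name `kthPrime` and the sieve limit passed by the caller are our own choices, not taken from the official solutions.

```cpp
#include <cassert>
#include <vector>

// Returns the k-th prime (1-based) using a sieve of Eratosthenes up to
// `limit`, or -1 if fewer than k primes are <= limit.
int kthPrime(int k, int limit) {
    assert(k >= 1 && limit >= 2);
    std::vector<bool> composite(limit + 1, false);
    int count = 0;
    for (int i = 2; i <= limit; ++i) {
        if (composite[i]) continue;       // i is not prime
        if (++count == k) return i;       // i is the count-th prime
        for (long long j = 1LL * i * i; j <= limit; j += i) composite[j] = true;
    }
    return -1;
}
```

Calling `kthPrime(100001, 1300000)` reproduces the value of MX used by the setter and the editorialist.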

Handling Prime Powers

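Now we can discard all 1s, all elements greater than MX and all elements that are not prime powers: a subsequence whose product is a prime power p^a \leq MX can only contain 1s and powers of p, and 1s never change a product. Let us fix a right end R. The answer to the query L, R is the smallest prime power x that cannot be formed as the product of a subsequence of the subarray A[L: R].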
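Discarding elements only needs a test telling whether v is a prime power and, if so, of which prime. Below is a minimal sketch built on a smallest-prime-factor sieve, similar in spirit to the editorialist’s spf array; the names `buildSpf` and `primePowerBase` are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Smallest-prime-factor sieve: spf[v] is the smallest prime dividing v (v >= 2).
std::vector<int> buildSpf(int limit) {
    assert(limit >= 2);
    std::vector<int> spf(limit + 1, 0);
    for (int i = 2; i <= limit; ++i)
        if (spf[i] == 0)  // i is prime
            for (int j = i; j <= limit; j += i)
                if (spf[j] == 0) spf[j] = i;
    return spf;
}

// Returns p if v = p^a with a >= 1 and v within the sieve, otherwise 0.
// Elements mapped to 0 (1s, values above the sieve limit, and numbers with
// two or more distinct prime factors) can be discarded.
int primePowerBase(long long v, const std::vector<int>& spf) {
    if (v < 2 || v >= (long long)spf.size()) return 0;
    int x = (int)v;
    const int p = spf[x];
    while (x % p == 0) x /= p;
    return x == 1 ? p : 0;
}
```

With the sieve built up to MX, every array element can be classified in O(\log MX) time.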

To check whether product x can be formed from a subsequence of A[L: R], let’s determine the largest position R' such that some subsequence of A[R': R] has product x. Let’s do this for each prime power x.

Define f_R(x) as the largest position R' such that the subarray A[R': R] has a subsequence with product x, with f_R(x) = -1 if there’s no such R'. Then x can be formed from A[L: R] if and only if f_R(x) \geq L.

Assuming 1-based indexing, we know that f_0(x) = -1 for all x. Let’s figure out a way of computing f_R(x) from f_{R-1}(x). The only new element is A_R; if it was discarded, f_R = f_{R-1}, so assume A_R = p^a.

If x is not a power of p, then f_R(x) = f_{R-1}(x), so only powers of p need updating. The following updates happen:

  • f_R(p^a) = R
  • f_R(p^a*x) = max(f_{R-1}(p^a*x), f_{R-1}(x)) where x = p^b, b \geq 1

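The first update happens only for x = p^a, and the second happens for the higher powers of p. The second rule holds because a subsequence of A[R': R] with product p^a*x either avoids A_R (covered by f_{R-1}(p^a*x)) or uses A_R, in which case the remaining elements form x inside A[R': R-1] (covered by f_{R-1}(x)). Since we don’t care about values above MX, there are at most \log_2(MX) \approx 20 positions where the second update happens, so we can update them one by one. Because every update must read the old values f_{R-1}, process the powers in decreasing order of exponent and set f_R(p^a) = R last, as the editorialist’s solution does.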
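The update rules above translate into a short offline sweep over R. The sketch below is our own simplification, not one of the official solutions: it keeps f_R in a plain array instead of a segment tree, so each query is answered by a linear scan over x, and the names `solveOffline` and `limit` (standing in for MX) are assumptions. Values that are not prime powers get N+1, so they are never reported as missing.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical sketch of the offline sweep over R. A plain array stands in
// for the segment tree, so each query is a linear scan. `limit` plays the
// role of MX and must be at least the (N+1)-th prime. Positions are 1-based.
std::vector<int> solveOffline(const std::vector<long long>& A,
                              const std::vector<std::pair<int, int>>& queries,
                              int limit) {
    assert(limit >= 2);
    const int n = (int)A.size();
    std::vector<int> spf(limit + 1, 0);  // smallest-prime-factor sieve
    for (int i = 2; i <= limit; ++i)
        if (spf[i] == 0)
            for (int j = i; j <= limit; j += i)
                if (spf[j] == 0) spf[j] = i;
    // Returns p if v = p^a with a >= 1 and v <= limit, otherwise 0.
    auto base = [&](long long v) -> int {
        if (v < 2 || v > limit) return 0;
        int x = (int)v, p = spf[x];
        while (x % p == 0) x /= p;
        return x == 1 ? p : 0;
    };
    // f[x] = largest R' such that A[R'..R] has a subsequence with product x.
    // Prime powers start at -1; every other x gets n + 1 so it never counts
    // as missing.
    std::vector<int> f(limit + 1);
    for (int x = 0; x <= limit; ++x) f[x] = base(x) ? -1 : n + 1;
    std::vector<std::vector<int>> byR(n + 1);  // query indices grouped by R
    for (int i = 0; i < (int)queries.size(); ++i) {
        assert(1 <= queries[i].first && queries[i].first <= queries[i].second &&
               queries[i].second <= n);
        byR[queries[i].second].push_back(i);
    }
    std::vector<int> ans(queries.size(), -1);
    for (int R = 1; R <= n; ++R) {
        const long long v = A[R - 1];
        if (int p = base(v)) {
            long long top = v;  // largest power of p that is <= limit
            while (top * p <= limit) top *= p;
            // Decreasing order, so f[x / v] still holds f_{R-1} when read.
            for (long long x = top; x > v; x /= p)
                f[x] = std::max(f[x], f[x / v]);
            f[v] = R;  // set last: A_R alone gives product v
        }
        for (int qi : byR[R]) {
            const int L = queries[qi].first;
            for (int x = 2; x <= limit; ++x)
                if (f[x] < L) { ans[qi] = x; break; }
        }
    }
    return ans;
}
```

Replacing the linear scan with a range-minimum structure over f gives the intended complexity.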

Handling Queries

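Now that we have f_R(x), answering the query on subarray A[L: R] means finding the smallest x such that f_R(x) < L and f_R(y) \geq L \forall 2 \leq y < x. For every y that is not a prime power we simply store a value that is always at least L (the editorialist’s solution uses N), so such y never becomes the answer.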

So far we have pictured f_R(x) as a plain array, but we need a data structure that supports point updates and prefix minimum queries. A segment tree is a perfect fit.

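Hence, let’s build a segment tree whose x-th leaf stores f_R(x) and which answers range minimum queries. Binary searching on the prefix minimum finds the smallest x with f_R(x) < L in O(\log^2 MX) per query; descending the tree directly, as the setter’s find function does (on a mirrored formulation that sweeps the left end), takes O(\log MX). The segment tree needs at most MX leaves.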
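One possible shape for that segment tree, sketched under our own naming (`MinTree`, `set`, `get` and `firstBelow` are hypothetical). Instead of the editorialist’s binary search over prefix minima, it descends from the root to the leftmost leaf in [from, n) whose value is below L, which takes O(\log n) per query.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// Hypothetical min segment tree over leaves 0..n-1 (leaf x stores f_R(x)).
// Supports point assignment and finding the leftmost x >= from with
// f_R(x) < L by descending from the root.
class MinTree {
public:
    MinTree(int n, int init) {
        assert(n >= 1);
        while (size_ < n) size_ <<= 1;
        t_.assign(2 * size_, INT_MAX);  // padding leaves are never "missing"
        for (int i = 0; i < n; ++i) t_[size_ + i] = init;
        for (int i = size_ - 1; i >= 1; --i) t_[i] = std::min(t_[2 * i], t_[2 * i + 1]);
    }
    void set(int x, int value) {
        int i = x + size_;
        t_[i] = value;
        for (i >>= 1; i >= 1; i >>= 1) t_[i] = std::min(t_[2 * i], t_[2 * i + 1]);
    }
    int get(int x) const { return t_[x + size_]; }
    // Leftmost x >= from with f_R(x) < L, or -1 if there is none.
    int firstBelow(int from, int L) const { return descend(1, 0, size_ - 1, from, L); }

private:
    int size_ = 1;
    std::vector<int> t_;
    int descend(int node, int lo, int hi, int from, int L) const {
        if (hi < from || t_[node] >= L) return -1;  // out of range, or nothing below L
        if (lo == hi) return lo;
        const int mid = (lo + hi) / 2;
        const int left = descend(2 * node, lo, mid, from, L);
        return left != -1 ? left : descend(2 * node + 1, mid + 1, hi, from, L);
    }
};
```

In the sweep, the second update rule becomes tree.set(x, std::max(tree.get(x), tree.get(x / A_R))) for each higher power x in decreasing order, followed by tree.set(A_R, R); after processing position R, the query L, R is answered by tree.firstBelow(2, L).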

TIME COMPLEXITY

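Building the segment tree with point updates, as the editorialist’s solution does, costs O(MX \log MX). Each of the N elements triggers at most \log_2(MX) updates of O(\log MX) each, and each of the Q queries takes O(\log^2 MX) with binary search. The total is O(MX \log MX + (N + Q) \log^2 MX) per test case, which is within the bound O(MX \log^2 MX).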

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
#define MX 1299721
vector<int> occ[MX+5];
vector<pair<int,int> > qu[100005];
int a[100005],tree[4*MX+5],p[MX+5],k[MX+5],ans[100005];
void update(int node,int st,int en,int idx,int val)
{
    if (st==en)
    tree[node]=val;
    else
    {
	    int mid=(st+en)/2;
	    if (idx<=mid)
	    update(2*node,st,mid,idx,val);
	    else
	    update(2*node+1,mid+1,en,idx,val);
	    tree[node]=max(tree[2*node],tree[2*node+1]);
    }
}
int find(int node,int st,int en,int r)
{
    if (st==en)
    return st;
    int mid=(st+en)/2;
    if (tree[2*node]>r)
    return find(2*node,st,mid,r);
    return find(2*node+1,mid+1,en,r);
}
int main()
{
    for (int i=2;i<=MX;i++)
    {
	    if (!p[i])
	    {
		    for (int j=i;j<=MX;j+=i)
		    p[j]=i;
	    }
	    int tmp=i;
	    while (tmp%p[i]==0)
	    {
		    k[i]++;
		    tmp/=p[i];
	    }
	    if (tmp!=1)
	    k[i]=0;
    }
    int t;
    scanf("%d",&t);
    while (t--)
    {
	    for (int i=1;i<=MX;i++)
	    occ[i].clear();
	    int n,q;
	    scanf("%d%d",&n,&q);
	    for (int i=1;i<=n;i++)
	    {
		    scanf("%d",&a[i]);
		    if (a[i]<=MX)
		    occ[a[i]].push_back(i);
		    qu[i].clear();
	    }
	    for (int i=0;i<q;i++)
	    {
		    int l,r;
		    scanf("%d%d",&l,&r);
		    qu[l].push_back({r,i});
	    }
	    for (int i=1;i<=MX;i++)
	    update(1,1,MX,i,(k[i]? n+1:0));
	    for (int i=n;i>0;i--)
	    {
		    if (a[i]<=MX && k[a[i]])
		    {
			    set<pair<int,int> > s;
			    int pp=p[a[i]],tmp=lower_bound(occ[pp].begin(),occ[pp].end(),i)-occ[pp].begin(),cur=0;
			    if (tmp!=occ[pp].size())
			    s.insert({occ[pp][tmp],pp});
			    long long mex=pp;
			    while (mex<=MX && !s.empty())
			    {
				    auto mn=*s.begin();
				    s.erase(s.begin());
				    cur=max(cur,mn.first);
				    int tmp=upper_bound(occ[mn.second].begin(),occ[mn.second].end(),mn.first)-occ[mn.second].begin();
				    if (tmp!=occ[mn.second].size())
				    s.insert({occ[mn.second][tmp],mn.second});
				    while (mn.second!=1)
				    {
					    update(1,1,MX,mex,cur);
					    mex*=pp;
					    if (mex>MX)
					    break;
					    int tmp=lower_bound(occ[mex].begin(),occ[mex].end(),i)-occ[mex].begin();
					    if (tmp!=occ[mex].size())
					    s.insert({occ[mex][tmp],mex});
					    mn.second/=pp;
				    }
			    }
		    }
		    for (auto cur:qu[i])
		    ans[cur.second]=find(1,1,MX,cur.first);
	    }
	    for (int i=0;i<q;i++)
	    printf("%d\n",ans[i]);
    }
}
Tester's Solution
#include <iostream>
#include <cassert>
#include <cstdio>
#include <cstdint>
#include <limits>
#include <string>
#include <tuple>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <random>

#ifdef HOME
#define NOMINMAX
    #include <windows.h>
#endif

#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define for1(i, n) for (int i = 1; i <= (int)(n); ++i)
#define ford(i, n) for (int i = (int)(n) - 1; i >= 0; --i)
#define fore(i, a, b) for (int i = (int)(a); i <= (int)(b); ++i)

template<class T> bool umin(T &a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool umax(T &a, T b) { return a < b ? (a = b, true) : false; }

using namespace std;


long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
	    char g = getchar();
	    if (g == '-') {
		    assert(fi == -1);
		    is_neg = true;
		    continue;
	    }
	    if ('0' <= g && g <= '9') {
		    x *= 10;
		    x += g - '0';
		    if (cnt == 0) {
			    fi = g - '0';
		    }
		    cnt++;
		    assert(fi != 0 || cnt == 1);
		    assert(fi != 0 || is_neg == false);

		    assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
	    }
	    else if (g == endd) {
		    assert(cnt > 0);
		    if (is_neg) {
			    x = -x;
		    }
		    assert(l <= x && x <= r);
		    return x;
	    }
	    else {
		    assert(false);
	    }
    }
}

string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
	    char g = getchar();
	    assert(g != -1);
	    if (g == endd) {
		    break;
	    }
	    cnt++;
	    ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}

uint32_t tzCount(uint64_t v)
{
#ifdef WIN32
    return static_cast<uint32_t>(_tzcnt_u64(v));
#else
    return __builtin_ctzll(v);
#endif
}

int main(int argc, char** argv) 
{
#ifdef HOME
    if(IsDebuggerPresent())
    {
	    //freopen("../in.txt", "rb", stdin);
	    freopen("../in.txt", "rb", stdin);
	    
	    freopen("../out.txt", "wb", stdout);
    }
#endif

    int maxP = 1'300'000;
    vector<bool> pr(maxP, true);
    pr[0] = pr[1] = false;
    vector<int> vPrimes;//list of primes
    vector<vector<int> > vPrimePws;//list of prime with powers

    vector<bool> vPrPw(maxP);
    
    vector<int> vPrimesIndex(maxP, -1);//prime index in the vPrimes
    vector<int> vPrimesPwIndex(maxP, -1);//index in the  primePWS
    vector<int> vPrimesPwPos(maxP, -1);

    forn(i, maxP)
    {
	    if(pr[i] == false)
		    continue;
	    vPrimesIndex[i] = vPrimes.size();
	    vPrimes.push_back(i);
	    vPrPw[i] = true;
	    int j = 2 * i;
	    while (j < maxP)
	    {
		    pr[j] = false;
		    j += i;
	    }
	    int64_t k = i;
	    k *= i;
	    if (k < maxP)
	    {
		    int ppi = vPrimePws.size();
		    vPrimesPwIndex[i] = ppi;
		    vPrimesPwPos[i] = 0;
		    vPrimePws.push_back({ i });
		    int bpc = 1;
		    while (k < maxP)
		    {
			    vPrimesPwIndex[k] = ppi;
			    vPrimesPwPos[k] = bpc++;
			    vPrimePws.back().push_back(k);		
			    vPrPw[k] = true;
			    k *= i;
		    }
	    }
    }

    int T = readIntLn(1, 5);
    
    forn(tc, T)
    {
	    const int mPI = vPrimes.size();
	    vector<int> vPO(mPI);
	    vector<vector<int>> vPWO;
	    for (const auto& vppi : vPrimePws)
		    vPWO.push_back(vector<int>(vppi.size()));
	    vector<uint64_t> vFound(vPrimes.size() / 64 + 1);

	    auto addVal = [&](int val) {
		    if (vPrimesIndex[val] != -1)
		    {
			    if (0 == vPO[vPrimesIndex[val]]++)
			    {
				    int idx = vPrimesIndex[val];
				    vFound[idx / 64] |= (1ull << (idx % 64));
			    }
		    }
		    if (vPrimesPwIndex[val] != -1)
		    {
			    vPWO[vPrimesPwIndex[val]][vPrimesPwPos[val]]++;
		    }
	    };

	    auto remVal = [&](int val) {
		    if (vPrimesIndex[val] != -1)
		    {
			    if (0 == --vPO[vPrimesIndex[val]])
			    {
				    int idx = vPrimesIndex[val];
				    vFound[idx / 64] &= ~(1ull << (idx % 64));
			    }
		    }
		    if (vPrimesPwIndex[val] != -1)
		    {
			    vPWO[vPrimesPwIndex[val]][vPrimesPwPos[val]]--;
		    }
	    };
	    
	    int N = readIntSp(1, 100'000);
	    int Q = readIntLn(1, 100'000);
	    vector<int> a(N);
	    int actr = 0;
	    for (auto& ai : a)
	    {
		    if (++actr == N)
			    ai = readIntLn(1, 1'000'000'000);
		    else
			    ai = readIntSp(1, 1'000'000'000);
		    if (ai >= maxP || vPrPw[ai] == false)
			    ai = 1;
	    }

	    vector<tuple<int, int, int, int, int>> vQ(Q);
	    int ctr = 0;
	    for (auto& qi : vQ)
	    {
		    int& actl = get<0>(qi);
		    int& actr = get<1>(qi);
		    int& actp = get<2>(qi);
		    int& actc = get<3>(qi);

		    actl = readIntSp(1, N);
		    actr = readIntLn(actl, N);
		    --actl;
		    --actr;
		    actp = ctr++;
		    actc = actl / 512;
	    }

	    sort(vQ.begin(), vQ.end(), [](auto a, auto b) {
		    const int ar = get<1>(a);
		    const int ac = get<3>(a);
		    const int br = get<1>(b);
		    const int bc = get<3>(b);
		    if (ac != bc)
		    {
			    return ac < bc;
		    }
		    return ar < br;
	    });

	    int lastl = N;
	    int lastr = 0;
	    int lastc = -1;

	    for (auto& actq : vQ)
	    {
		    const int actl = get<0>(actq);
		    const int actr = get<1>(actq);
		    const int actp = get<2>(actq);
		    const int actc = get<3>(actq);

		    if (lastc != actc)
		    {
			    lastc = actc;
			    lastl = actl;
			    lastr = actl - 1;

			    fill(vPO.begin(), vPO.end(), 0);
			    for (auto& vppi : vPWO)
				    fill(vppi.begin(), vppi.end(), 0);
			    fill(vFound.begin(), vFound.end(), 0);
		    }

		    {
			    while (lastl > actl)
			    {
				    --lastl;
				    int actV = a[lastl];
				    addVal(actV);
			    }
			    while (lastr < actr)
			    {
				    ++lastr;
				    int actV = a[lastr];
				    addVal(actV);
			    }
			    while (lastl < actl)
			    {
				    int actV = a[lastl];
				    remVal(actV);
				    ++lastl;
			    }

			    while (lastr > actr)
			    {
				    int actV = a[lastr];
				    remVal(actV);
				    --lastr;
			    }
		    }

		    //find the smallest missing prime
		    int smallestIdx = 0;
		    for (size_t i = 0; i < vFound.size(); ++i)
		    {
			    if (vFound[i] != numeric_limits<uint64_t>::max())
			    {
				    uint64_t val = ~vFound[i];
				    int extra = val > 0 ? tzCount(val) : 0;
				    smallestIdx = 64 * i + extra;
				    break;
			    }
		    }
		    int& smallestMissing = get<4>(actq);
		    smallestMissing = vPrimes[smallestIdx];
		    //find the smallest missing prime pwr
		    for (size_t i = 0; i < vPrimePws.size(); ++i)
		    {
			    const auto& actPw = vPrimePws[i];
			    const auto& actPwo = vPWO[i];
			    if (smallestMissing < actPw[1])
				    break;
			    int maxj = 2;
			    for (size_t j = 2; j < actPw.size(); ++j)
			    {
				    if (smallestMissing < actPw[j])
					    break;
				    else
					    ++maxj;
			    }
			    int dp = 1;
			    for (size_t j = 0; j < maxj; ++j)
			    {
				    int tmp = actPwo[j];
				    int mask = 0;
				    while (tmp && mask < (1 << maxj))
				    {
					    mask |= 1;
					    mask <<= (j+1);
					    --tmp;
				    }
				    if(mask == 0)
					    continue;
				    for (int k = maxj; k >= 0; --k)
				    {
					    if (dp & (1 << k))
					    {
						    dp |= (mask << k);
					    }
				    }
			    }
			    dp = ~dp;
			    int tz = tzCount(dp) - 1;
			    if (tz < actPw.size() && smallestMissing > actPw[tz])
			    {
				    smallestMissing = actPw[tz];
			    }
		    }
	    }
	    sort(vQ.begin(), vQ.end(), [](auto a, auto b) {
		    const int ap = get<2>(a);
		    const int bp = get<2>(b);
		    return ap < bp;
		    });
	    for (const auto& actq : vQ)
	    {
		    const int ar = get<4>(actq);
		    printf("%d\n", ar);
	    }
    }
    
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class Main{
    //SOLUTION BEGIN
    int MAX = 1299721;
    boolean[] primePower;
    int[] spf, exp;
    void pre() throws Exception{
        spf = spf(MAX);
        exp = new int[1+MAX];
        primePower = new boolean[1+MAX];
        Arrays.fill(exp, -1);
        exp[1] = 0;
        spf[1] = 1;
        for(int i = 2; i<= MAX; i++)
            if(spf[i] == i){
                long p = i;
                for(int e = 1; p <= MAX; e++, p*= i){
                    exp[(int)p] = e;
                    primePower[(int)p] = true;
                }
            }
            
        
    }
    void solve(int TC) throws Exception{
        int N = ni(), Q = ni();
        int[] A = new int[N];
        for(int i = 0; i< N; i++){
            A[i] = ni();
            if(A[i] <= 1 || A[i] > MAX || !primePower[A[i]])A[i] = 0;
        }
        int[][] qu = new int[Q][];
        for(int q = 0; q< Q; q++)
            qu[q] = new int[]{q, ni()-1, ni()-1};
        Arrays.sort(qu, (int[] i1, int[] i2) -> Integer.compare(i1[2], i2[2]));
        int q = 0;
        int[] ans = new int[Q];
        SegmentTree min = new SegmentTree(1+MAX);
        for(int i = 1; i<= MAX; i++)if(!primePower[i])min.update(i, N);
        for(int r = 0; r < N; r++){
            if(A[r] != 0){
                int p = spf[A[r]];
                int max = p;
                for(long cur = A[r]; cur <= MAX; cur *= p){
                    max = (int)cur;
                }
                while(max > A[r]){
                    min.update(max, min.query(max/A[r], max/A[r]));
                    max /= p;
                }
                min.update(A[r], r);
            }
            //Answering queries
            while(q < Q && qu[q][2] == r){
                int lo = 1, hi = MAX;
                while(lo < hi){
                    int mid = lo+(hi-lo)/2;
                    if(min.query(2, mid) < qu[q][1])hi = mid;
                    else lo = mid+1;
                }
                ans[qu[q++][0]] = hi;
                
            }
        }
        for(int x:ans)pn(x);
    }
    int[] spf(int max){
        int[] spf = new int[1+max];
        for(int i = 2; i<= max; i++)
            if(spf[i] == 0)
                for(int j = i; j <= max; j += i)
                    if(spf[j] == 0)
                        spf[j] = i;
        return spf;
    }
    class SegmentTree{
        long INF = (long)1e12;
        private long initValue(){return -INF;}
        private long update(long oldValue, long newValue){return Math.max(oldValue, newValue);}
        private long merge(long le, long ri){return Math.min(le, ri);}
        private long initQuery(){return INF;}

        private int m= 1;
        private long[] t;
        public SegmentTree(int n){
            while(m<n)m<<=1;
            t = new long[m<<1];
            Arrays.fill(t, initValue());
        }
        public SegmentTree(long[] a){
            while(m<a.length)m<<=1;
            t = new long[m<<1];
            Arrays.fill(t, initValue());
            for(int i = 0; i< a.length; i++)t[i+m] = a[i];
            for(int i = m-1; i>0; i--)t[i] = merge(t[i<<1], t[i<<1|1]);
        }
        public void update(int i, long val){
            t[i += m]  = update(t[i], val);
            for(i>>=1;i>0;i>>=1)t[i] = merge(t[i<<1], t[i<<1|1]);
        }
        public long query(int l, int r){
            long lans = initQuery(), rans = initQuery();
            for(l+=m,r+=m+1;l<r;l>>=1,r>>=1){
                if((l&1)==1)lans = merge(lans, t[l++]);
                if((r&1)==1)rans = merge(t[--r], rans);
            }
            return merge(lans, rans);
        }
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new Main().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.
