ARREXPAND - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic Programming

PROBLEM:

Given an array A and a parameter X, the following updates are possible:

  • Choose i such that A_i \gt 0 and A_{i+1} = 0.
    Then set A_{i+1} := \gcd(A_i, X).
  • Choose i such that A_i \gt 0 and A_{i-1} = 0.
    Then set A_{i-1} := \gcd(A_i, X).
  • Choose i \lt j \lt k such that A_i \gt 0, A_k \gt 0, and A_{i+1} = \ldots = A_{k-1} = 0.
    Then set A_j := \gcd(A_i, A_j).

Array B is said to be reachable from array A if B can be formed by performing several such operations on A.

Given an array B containing positive integers, count the number of arrays A such that B is reachable from A.

EXPLANATION:

Clearly, any valid array A must be formed by setting some elements of B to zero.
Further, if i and j are consecutive non-zero indices, the operations make it clear that all the values strictly between indices i and j must be generated from A_i and A_j (and X) alone - values outside this segment cannot influence it.

This immediately gives rise to a natural dynamic programming solution.
Let dp_i denote the number of valid ways to fill the prefix A_1, \ldots, A_i such that A_i = B_i (i.e. index i is kept non-zero).
We then have dp_i = \sum_j dp_j, where we consider all those j \lt i such that j is “good”, meaning it’s possible to generate the values between indices i and j using A_i and A_j as starting points.
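This recurrence can be sketched directly as a quadratic-time DP. Here `good(j, i)` is a hypothetical predicate (not defined in this snippet, and the subject of the rest of the editorial) answering whether j and i can be consecutive non-zero indices; `good(-1, i)` and `good(i, n)` stand in for the boundary checks that everything before the first kept index and after the last kept index is producible:

```cpp
#include <bits/stdc++.h>
using namespace std;

// Quadratic sketch of the DP described above. All array-dependent logic is
// delegated to the hypothetical predicate good(j, i); good(-1, i) means
// "i can be the first kept index" and good(i, n) means "i can be the last".
long long count_arrays(int n, function<bool(int,int)> good) {
    const long long MOD = 998244353;
    vector<long long> dp(n, 0);
    for (int i = 0; i < n; i++) {
        for (int j = -1; j < i; j++) {
            if (good(j, i))
                dp[i] = (dp[i] + (j < 0 ? 1 : dp[j])) % MOD;
        }
    }
    long long ans = 0;
    for (int i = 0; i < n; i++)
        if (good(i, n)) ans = (ans + dp[i]) % MOD;
    return ans;
}
```

With a predicate that always returns true, dp_i = 2^i and the total is 2^n - 1, which is a quick sanity check of the recurrence's shape.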

We now have two things to do: figure out when a pair (j, i) is “good”, and then also optimize the above DP to be subquadratic.


First, let’s figure out what values can be generated between A_i and A_j.
Given that our only options are to take pairwise GCDs or GCDs with X, and GCD is associative, there are at most four distinct values that can ever be produced between them, namely:

  • \gcd(A_i, X)
  • \gcd(A_j, X)
  • \gcd(A_i, A_j)
  • \gcd(A_i, A_j, X)

For now, we assume that all these values are distinct.
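As a sanity check of this claim, one can brute-force the closure of \{A_i, A_j\} under pairwise GCDs and GCDs with X, and confirm that nothing outside the four values listed above ever appears. A small self-contained check, not part of the solution itself:

```cpp
#include <bits/stdc++.h>
using namespace std;

// Repeatedly take pairwise gcds and gcds with x, starting from {ai, aj},
// until the set stabilizes. By the associativity argument above, the result
// should never contain anything beyond ai, aj and the four listed values.
set<long long> gcd_closure(long long ai, long long aj, long long x) {
    set<long long> s = {ai, aj};
    bool changed = true;
    while (changed) {
        changed = false;
        vector<long long> snapshot(s.begin(), s.end());
        for (long long a : snapshot) {
            if (s.insert(gcd(a, x)).second) changed = true;
            for (long long b : snapshot)
                if (s.insert(gcd(a, b)).second) changed = true;
        }
    }
    return s;
}
```

For example, with A_i = 12, A_j = 18, X = 8 the closure is \{2, 4, 6, 12, 18\}: exactly the two starting values plus \gcd(12,8)=4, \gcd(18,8)=2, \gcd(12,18)=6, and \gcd(12,18,8)=2.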

Further, these values have to be distributed ‘nicely’ in the [j+1, i-1] segment.
Specifically, this means the following:

  1. All occurrences of \gcd(A_i, X) must form a contiguous subarray attached to A_i.
  2. All occurrences of \gcd(A_j, X) must form a contiguous subarray attached to A_j.

If these two conditions are satisfied, it doesn’t actually matter how the \gcd(A_i, A_j) and \gcd(A_i, A_j, X) values appear in the remaining part (which will itself be a contiguous segment).
This is because, if we take any range [l, r] and want to fill it with \gcd(A_i, A_j) and \gcd(A_i, A_j, X) in some order, we can do the following:

  • First create an occurrence of \gcd(A_i, A_j) in this range anywhere we want, using A_i and A_j.
  • Use this occurrence along with A_i and A_j to create more occurrences of it in the range, since \gcd(A_i, \gcd(A_i, A_j)) = \gcd(A_i, A_j) (and a similar condition holds for A_j).
  • Then, all empty positions can be filled with \gcd(A_i, A_j, X) using the adjacent operation.

There is exactly one edge case to be wary of: it’s impossible for every index from j+1 to i-1 to hold the value \gcd(A_i, A_j, X), since obtaining that value at all takes at least two operations; each position is written at most once, so the first value written in the middle is always one of the other three.

So, we have a reasonable check for when i and j can be consecutive non-zero indices.
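This characterization can be cross-checked against a brute force on tiny arrays: enumerate every array obtainable from A by the three operations (each zero position is written at most once) and test whether B shows up. Exponential, so only useful for validation, not for the actual solution:

```cpp
#include <bits/stdc++.h>
using namespace std;

// BFS over all arrays reachable from A via the three operations.
// Zeros mark positions not yet written; every operation fills one zero.
bool reachable(vector<long long> A, const vector<long long>& B, long long x) {
    int n = A.size();
    set<vector<long long>> seen;
    queue<vector<long long>> q;
    seen.insert(A); q.push(A);
    auto push = [&](const vector<long long>& nxt) {
        if (seen.insert(nxt).second) q.push(nxt);
    };
    while (!q.empty()) {
        vector<long long> cur = q.front(); q.pop();
        if (cur == B) return true;
        for (int i = 0; i < n; i++) {
            if (cur[i] == 0) continue;
            // adjacent operations: fill a neighbouring zero with gcd(cur[i], x)
            if (i + 1 < n && cur[i+1] == 0) {
                auto nxt = cur; nxt[i+1] = gcd(cur[i], x); push(nxt);
            }
            if (i - 1 >= 0 && cur[i-1] == 0) {
                auto nxt = cur; nxt[i-1] = gcd(cur[i], x); push(nxt);
            }
            // middle operation: non-zero endpoints i < k, zeros strictly
            // between them; any one position j in between may be set
            for (int k = i + 2; k < n && cur[k-1] == 0; k++) {
                if (cur[k] == 0) continue;
                for (int j = i + 1; j < k; j++) {
                    auto nxt = cur; nxt[j] = gcd(cur[i], cur[k]); push(nxt);
                }
            }
        }
    }
    return false;
}
```

For instance, with X = 30 the four values for A_i = 12, A_j = 20 are \gcd(12,30)=6, \gcd(20,30)=10, \gcd(12,20)=4, \gcd(12,20,30)=2 (all distinct), and the brute force confirms that [12, 2, 4, 20] is reachable from [12, 0, 0, 20] while [12, 2, 2, 20] is not, matching the edge case above.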


We now need to use this to optimize our DP.

To do this, observe that if we fix i, then there really aren’t that many ‘important’ indices j \lt i that need to be checked.
In particular, note that if j is “good” and also satisfies A_j = A_{j+1}, then every index in the contiguous block of equal values containing j (except possibly its right endpoint) will also be “good”.
So, for a contiguous block of equal elements, only its right endpoint and the element just before it need to be checked.
If that second element is good, then all j in some range will be good (so their dp values can be added in constant time using prefix sums).

Similarly, if there are at least 5 distinct elements strictly between j and i, such a j can never be “good” (the middle can only hold the four values above); this rules out all smaller j too.

Next, we make a couple of observations based on the \gcd(A_i, X) and \gcd(A_j, X) values.
Define two arrays L and R such that:

  • L_i is the smallest integer such that for each k satisfying L_i \le k \lt i, we have A_k = \gcd(A_{k+1}, X).
  • R_i is the largest integer such that for each k satisfying i \lt k \le R_i, we have A_k = \gcd(A_{k-1}, X).
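Both arrays can be computed in linear time with one left-to-right and one right-to-left sweep; a minimal 0-indexed sketch:

```cpp
#include <bits/stdc++.h>
using namespace std;

// L[i]: smallest index such that B[k] == gcd(B[k+1], x) for all k in [L[i], i).
// R[i]: largest index such that B[k] == gcd(B[k-1], x) for all k in (i, R[i]].
// Each sweep extends the previous chain when the adjacent condition holds.
pair<vector<int>, vector<int>> computeLR(const vector<long long>& B, long long x) {
    int n = B.size();
    vector<int> L(n), R(n);
    for (int i = 0; i < n; i++) {
        L[i] = i;
        if (i > 0 && B[i-1] == gcd(B[i], x)) L[i] = L[i-1];
    }
    for (int i = n - 1; i >= 0; i--) {
        R[i] = i;
        if (i + 1 < n && B[i+1] == gcd(B[i], x)) R[i] = R[i+1];
    }
    return {L, R};
}
```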

Now, when checking some pair (j, i), observe that:

  • If R_j \ge L_i - 1, j is good for sure.
  • Otherwise, it’s exactly all elements in the range [R_j+1, L_i-1] that need to equal one of \gcd(A_i, A_j) or \gcd(A_i, A_j, X).

In particular, the first condition there, along with the fact that R_j is a non-decreasing array, tells us that all “large enough” j will automatically be good.

We now have to worry about only those j for which R_j \lt L_i - 1.
Here, we can bring in the fact that the ‘middle’ elements must be \gcd(A_i, A_j) or \gcd(A_i, A_j, X) only.
In particular, A_{L_i - 1} must be one of these values; and the next distinct element must be the other one.

This greatly restricts our choices for what j can be.
The details are left as an exercise to the reader (or see the attached code); but in general a combination of all the above observations results in only \mathcal{O}(1) choices of j being “important” for a fixed i.

Each (j, i) pair can be checked for validity relatively easily: the L and R arrays, together with a sorted list of positions for each value and binary search on it, suffice.
This is because the L and R arrays shrink the range we need to check; for the smaller range we only need to verify that every element is \gcd(A_i, A_j) or \gcd(A_i, A_j, X), which can be done by adding up the counts of those two values in the range and checking whether the total equals the length of the range.
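The “count occurrences of a value in a range” part of this check can be sketched with per-value sorted position lists plus binary search (the name `ValueCounter` is illustrative, not taken from the tester's code):

```cpp
#include <bits/stdc++.h>
using namespace std;

// Per-value sorted position lists: build in O(n log n), then count the
// occurrences of a value v inside any index range [l, r] in O(log n).
struct ValueCounter {
    map<long long, vector<int>> pos;
    ValueCounter(const vector<long long>& B) {
        for (int i = 0; i < (int)B.size(); i++) pos[B[i]].push_back(i);
    }
    // number of indices i in [l, r] with B[i] == v
    int count(long long v, int l, int r) const {
        auto it = pos.find(v);
        if (it == pos.end()) return 0;
        const vector<int>& p = it->second;
        return (int)(upper_bound(p.begin(), p.end(), r) -
                     lower_bound(p.begin(), p.end(), l));
    }
};
```

The range check then becomes: with g1 = \gcd(A_i, A_j) and g2 = \gcd(A_i, A_j, X) distinct, the range [l, r] is valid iff `count(g1, l, r) + count(g2, l, r) == r - l + 1`.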

This allows for a solution in \mathcal{O}(N\log N) time, completing the problem.

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

template<typename T>
struct fenwick {
    int n;
    vector<T> tr;
    int LOG = 0;

    fenwick() {

    }

    fenwick(int n_) {
        n = n_;
        tr = vector<T>(n + 1);
        while((1<<LOG) <= n) LOG++;
    }

    int lsb(int x) {
        return x & -x;
    }

    void pupd(int i, T v) {
        i++;
        for(; i <= n; i += lsb(i)){
            tr[i] += v;
            if(tr[i] >= MOD) tr[i] -= MOD;
        }
    }

    T sum(int i) {
        i++;
        if(i < 0) return 0;
        T res = 0;
        for(; i; i ^= lsb(i)){
            res += tr[i];
            if(res >= MOD) res -= MOD;
        }
        return res;
    }

    T query(int l, int r) {
        if (l > r) return 0;
        T res = sum(r) - sum(l - 1);
        return (res%MOD+MOD)%MOD;
    }

    int lower_bound(T s){
        // first pos with sum >= s
        if(sum(n) < s) return n+1;
        int i = 0;
        rev(bit,LOG-1,0){
            int j = i+(1<<bit);
            if(j > n) conts;
            if(tr[j] < s){
                s -= tr[j];
                i = j;
            }
        }

        return i+1;
    }

    int upper_bound(T s){
        return lower_bound(s+1);
    }
};

void solve(int test_case){
    ll n,X; cin >> n >> X;
    vector<ll> a(n+5);
    rep1(i,n) cin >> a[i];

    // partition into segs
    vector<pll> segs;
    rep1(i,n){
        if(a[i] == a[i-1]) segs.back().ss++;
        else segs.pb({a[i],1});
    }

    // make groups based on segs
    vector<pll> b;
    for(auto [x,c] : segs){
        b.pb({x,1});
        if(c >= 3){
            b.pb({x,c-2});
        }
        if(c >= 2){
            b.pb({x,1});
        }
    }

    // precomp pows of 2
    vector<ll> pow2(n+5);
    pow2[0] = 1;
    rep1(i,n) pow2[i] = pow2[i-1]*2%MOD;

    // find nxt and prev for each position
    ll siz = sz(b);
    vector<ll> nxt(siz), prev(siz);
    vector<ll> dp(siz);

    rep(i,siz){
        ll want = gcd(b[i].ff,X);
        nxt[i] = i+1;
        for(int j = i+1; j < siz; ++j){
            if(b[j].ff == want) nxt[i]++;
            else break;
        }

        prev[i] = i-1;
        rev(j,i-1,0){
            if(b[j].ff == want) prev[i]--;
            else break;
        }
    }

    // precalc critical set for each pref
    vector<pll> critical_i[siz+5];
    map<ll,ll> mp;
    set<pll> st; // (pos,val)

    rep(i,siz){
        ll x = b[i].ff;
        if(mp.count(x)){
            st.erase({mp[x],x});
        }
        mp[x] = i;
        st.insert({i,x});
        ll cnt = 0;
        
        vector<pll> critical;

        for(auto it = st.rbegin(); it != st.rend(); ++it){
            cnt++;
            critical.pb(*it);
            if(cnt >= 4) break;
        }

        critical_i[i] = critical;
    }

    fenwick<ll> bigger_nxt_fenw(siz+5);

    vector<ll> unique_b;
    rep(i,siz) unique_b.pb(b[i].ff);
    sort(all(unique_b));
    unique_b.resize(unique(all(unique_b))-unique_b.begin());
    ll cc_siz = sz(unique_b);
    vector<ll> pos[cc_siz];
    fenwick<ll> fenw_sum[cc_siz], fenw_invalid_sum[cc_siz];

    auto get_ind = [&](ll v){
        return lower_bound(all(unique_b),v)-unique_b.begin();
    };

    rep(i,siz){
        pos[get_ind(b[i].ff)].pb(nxt[i]);
    }

    rep(i,cc_siz){
        sort(all(pos[i]));
        pos[i].resize(unique(all(pos[i]))-pos[i].begin());
        fenw_sum[i] = fenw_invalid_sum[i] = fenwick<ll>(sz(pos[i])+5);
    }

    rep(i,siz){
        if(i){
            bigger_nxt_fenw.pupd(nxt[i-1],dp[i-1]);
            ll cc_ind = get_ind(b[i-1].ff);
            auto &pos_vec = pos[cc_ind];
            ll val_pos = lower_bound(all(pos_vec),nxt[i-1])-pos_vec.begin();
            fenw_sum[cc_ind].pupd(val_pos,dp[i-1]);
            if(nxt[i-1] == i){
                fenw_invalid_sum[cc_ind].pupd(val_pos,dp[i-1]);
            }
        }

        ll ways = 0;

        vector<pll> critical;
        if(prev[i] >= 0) critical = critical_i[prev[i]];

        // add all guys with nxt(j) > prev(i)
        ways += bigger_nxt_fenw.query(prev[i]+1,siz+3);
        critical.pb({-1,0});
        reverse(all(critical));

        // fix val
        rep(ind1,sz(critical)){
            // fix unique set
            ll v = critical[ind1].ss;
            if(!v) conts;
            ll g1 = gcd(b[i].ff,v), g2 = gcd(g1,X);

            rev(ind2,sz(critical)-1,1){
                // check if ok
                if(critical[ind2].ss != g1 and critical[ind2].ss != g2) break;

                // [ind2..sz(critical)-1] --> unique vals in range
                ll l = critical[ind2-1].ff+1, r = critical[ind2].ff;                
                // ways += get_sum(l,r,v);
                
                ll good_sum = 0;

                {
                    ll cc_ind = get_ind(v);
                    auto &pos_vec = pos[cc_ind];
                    ll lx = lower_bound(all(pos_vec),l)-pos_vec.begin();
                    ll rx = upper_bound(all(pos_vec),r)-pos_vec.begin()-1;
                    good_sum = fenw_sum[cc_ind].query(lx,rx);
                }

                ways += good_sum;
                ways %= MOD;

                // subtract out invalid guys i.e nxt[j] = j+1, prev[i] = i-1, only single value in btwn which is equal to gcd(b[i],b[j],x) and not equal to gcd(b[i],b[j])                
                if(prev[i] == i-1 and ind2 == sz(critical)-1 and critical[ind2].ss == g2 and critical[ind2].ss != g1){
                    ll bad_sum = 0;

                    // ways -= invalid_sum(l,r,v);

                    {
                        ll cc_ind = get_ind(v);
                        auto &pos_vec = pos[cc_ind];
                        ll lx = lower_bound(all(pos_vec),l)-pos_vec.begin();
                        ll rx = upper_bound(all(pos_vec),r)-pos_vec.begin()-1;
                        bad_sum = fenw_invalid_sum[cc_ind].query(lx,rx);
                    }
                    
                    ways -= bad_sum;
                    ways = (ways%MOD+MOD)%MOD;
                }
            }
        }

        if(prev[i] < 0) ways++;

        dp[i] = ways*(pow2[b[i].ss]-1+MOD)%MOD;
    }

    // pull from valid guys
    ll ans = 0;

    rep(i,siz){
        ll want = gcd(b[i].ff,X);
        bool good = true;
        for(int j = i+1; j < siz; ++j){
            if(b[j].ff != want){
                good = false;
                break;
            }
        }

        if(good){
            ans += dp[i];
            ans %= MOD;
        }
    }

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}