GCDSORT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Satyam
Testers: Takuki Kurokawa, Utkarsh Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

2542

PREREQUISITES:

A bit of observation; (optional) a range GCD data structure

PROBLEM:

You are given an array A of N positive integers. In one move, you can pick an index i (1 \leq i \lt N) and set A_{i+1} = \gcd(A_i, A_{i+1}).
Is it possible to sort A in non-decreasing order by performing this move any number of times (possibly zero)?

EXPLANATION:

The solution to this task hinges on the following fact:

Let B_i be the final value at index i after some operations have been applied. Then, B_i = \gcd(A_i, A_{i-1}, A_{i-2}, \ldots, A_j) for some j \leq i.

Proof

This can be proved inductively.

For i = 1, the result is obvious since no operation can change A_1, so B_1 = A_1 always.
Now, consider i \gt 1.
If no operation ever changes position i, then B_i = A_i.
Otherwise, each operation on position i replaces the value there with its gcd with the value currently at position i-1.
Unrolling this, B_i = \gcd(A_i, v_1, v_2, \ldots, v_t), where v_1, \ldots, v_t are the values at position i-1 at the moments those operations were applied.

By the inductive hypothesis (which holds after any number of operations, not only at the end), each v_s is the gcd of some range [j_s, i-1].
All of these ranges end at i-1, so \gcd(A_i, v_1, \ldots, v_t) is simply the gcd of the range [\min_s j_s, i], which is a range gcd ending at position i, proving our claim.
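To sanity-check the claim on small inputs, here is a minimal exhaustive sketch (0-indexed; `explore_all_states` is a hypothetical helper, not part of the setter's or testers' code). It explores every array reachable from a tiny input by breadth-first search, asserts that every entry equals some range gcd of the original array ending at that index, and reports whether any reachable array is sorted. The number of reachable arrays can grow exponentially with N, so this is only meant for a handful of elements, for example as a reference when testing faster solutions.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <queue>
#include <set>
#include <vector>

// Exhaustive search over every array reachable from `a` (0-indexed moves:
// a[k+1] = gcd(a[k], a[k+1])). Checks the claim on each reachable array and
// returns whether any of them is sorted. Only suitable for tiny inputs.
bool explore_all_states(const std::vector<long long>& a) {
    const int n = static_cast<int>(a.size());
    // allowed[i] = { gcd(a[j..i]) : 0 <= j <= i }, over the ORIGINAL array.
    std::vector<std::set<long long>> allowed(n);
    for (int i = 0; i < n; ++i) {
        long long g = 0;  // gcd(0, x) == x
        for (int j = i; j >= 0; --j) {
            g = std::gcd(g, a[j]);
            allowed[i].insert(g);
        }
    }
    std::set<std::vector<long long>> seen;
    std::queue<std::vector<long long>> todo;
    seen.insert(a);
    todo.push(a);
    bool sortable = false;
    while (!todo.empty()) {
        const std::vector<long long> cur = todo.front();
        todo.pop();
        // The claim: every entry is a range gcd of the original array ending there.
        for (int i = 0; i < n; ++i) assert(allowed[i].count(cur[i]) == 1);
        if (std::is_sorted(cur.begin(), cur.end())) sortable = true;
        for (int k = 0; k + 1 < n; ++k) {
            std::vector<long long> nxt = cur;
            nxt[k + 1] = std::gcd(nxt[k], nxt[k + 1]);
            if (seen.insert(nxt).second) todo.push(nxt);
        }
    }
    return sortable;
}
```

For instance, `explore_all_states({5, 90, 36, 18})` returns `true`, while `explore_all_states({3, 2})` returns `false`, since the only other reachable array is [3, 1].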

Now, how do we use this fact?

Consider the following:

  • There is no point in changing the value of A_N: it might as well stay as large as possible.
  • Now let’s look at A_{N-1}. We want A_{N-1} \leq A_N, and to achieve this we may have to perform some operations.
  • From our earlier claim, the only thing we can do is replace A_{N-1} with some range gcd that ends at N-1.
  • Now it’s clear what we should do: choose this range gcd to be as large as possible while still being \leq A_N, since a larger A_{N-1} only makes the condition A_{N-2} \leq A_{N-1} easier to satisfy.

This will form the crux of the solution:

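  • Iterate over the indices in decreasing order, from i = N-1 down to 1. When at position i, we will perform operations to ensure A_i \leq A_{i+1}, where A_{i+1} is the value already fixed in the previous step.
  • To do this, find the largest index j \leq i such that \gcd(A_j, A_{j+1}, \ldots, A_i) \leq A_{i+1}, and perform the operations at indices j, j+1, \ldots, i-1, in that order; this sets A_i to exactly that gcd. The range gcd can only shrink as j decreases, so the largest such j gives the largest valid value. If no such j exists, the answer is NO.
  • Continue on to i-1.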

If the array is sorted at the end of this, we are done. Otherwise, there is no way to sort the array.
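The greedy can be sketched directly in terms of values, as below (0-indexed; `greedy_sortable_naive` is a name made up for this sketch). Instead of applying moves, it scans gcds of the original array, relying on the claim above that every final value is a range gcd of the original A. The inner scan makes it quadratic in the worst case (for example on [2, 4, 4, ..., 4, 3]), so this is for illustration rather than submission.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Value-only version of the greedy (0-indexed), quadratic in the worst case.
// `bound` is the value already fixed for position i+1. For position i we take
// gcd(a[j..i]) for the largest j where it is <= bound; since the gcd can only
// shrink as j decreases, this is the largest admissible value.
bool greedy_sortable_naive(const std::vector<long long>& a) {
    const int n = static_cast<int>(a.size());
    if (n == 0) return true;
    long long bound = a[n - 1];  // A_N is never worth decreasing
    for (int i = n - 2; i >= 0; --i) {
        long long g = 0;  // gcd(0, x) == x
        bool found = false;
        for (int j = i; j >= 0; --j) {
            assert(a[j] >= 1);  // the problem guarantees positive values
            g = std::gcd(g, a[j]);
            if (g <= bound) {
                found = true;
                break;
            }
        }
        if (!found) return false;  // even gcd(a[0..i]) exceeds the bound
        bound = g;                 // value fixed for position i
    }
    return true;
}
```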

Now for the implementation: there’s an easy way that seems slow at first, and a less easy way using data structures whose runtime is easy to prove.

I like easy implementations!

You can just brute force!
That is, at each index i, start at j = i and keep decreasing j while maintaining the running \gcd, until the \gcd becomes \leq A_{i+1} (if this never happens, even at j = 1, the answer is NO).
Then, perform the operations at indices j, j+1, \ldots, i-1 in order, and continue on to i-1.
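A minimal sketch of this brute force, assuming 0-indexed arrays of `long long` (`brute_force_moves` is an illustrative name): it performs the moves in place on a copy and also records them, so the answer can be double-checked by replaying the moves. The `assert` inside encodes the fact, used in the complexity argument, that every move performed here actually changes the value it writes to.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <optional>
#include <vector>

// Brute force, 0-indexed. A move at k means a[k+1] = gcd(a[k], a[k+1]).
// Returns the moves performed, in order, or std::nullopt if some position
// cannot be brought down to the value fixed to its right.
std::optional<std::vector<int>> brute_force_moves(std::vector<long long> a) {
    const int n = static_cast<int>(a.size());
    std::vector<int> moves;
    for (int i = n - 2; i >= 0; --i) {
        if (a[i] <= a[i + 1]) continue;  // nothing to do at this position
        // Walk j left, keeping g = gcd(a[j..i]) over the current values.
        long long g = a[i];
        int j = i;
        while (g > a[i + 1] && j > 0) {
            --j;
            g = std::gcd(g, a[j]);
        }
        if (g > a[i + 1]) return std::nullopt;  // even gcd(a[0..i]) is too big
        // Perform the moves at j, j+1, ..., i-1; afterwards a[i] == g.
        for (int k = j; k < i; ++k) {
            const long long before = a[k + 1];
            a[k + 1] = std::gcd(a[k], a[k + 1]);
            assert(a[k + 1] != before);  // every move strictly shrinks a value
            moves.push_back(k);
        }
    }
    return moves;
}
```

On [5, 90, 36, 18] this returns the moves [1, 0]; replaying them on the input gives [5, 5, 18, 18].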

This might look like \mathcal{O}(N^2) at first glance, but it isn’t: it amortizes to \mathcal{O}(N\log M), where M is the maximum element of the array (here M \leq 10^6).

Proof

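Note that the value at a given index i can change at most \log_2 A_i times: each time it changes, it turns into a proper divisor of itself, and so at least halves.

Every operation performed by the brute force changes the value at the index it writes to. Here, let A denote the values just before the operations for index i are performed, and suppose the operation writing to index k+1 (with j \leq k \lt i) did not change it. Then A_{k+1} divides A_j, \ldots, A_k, so \gcd(A_j, \ldots, A_i) = \gcd(A_{k+1}, \ldots, A_i), which is a multiple of \gcd(A_{j+1}, \ldots, A_i) \gt A_{i+1}. That contradicts \gcd(A_j, \ldots, A_i) \leq A_{i+1}: there would have been no point in moving the left pointer past this position, so we would not have done so.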

So, the total number of operations we perform is bounded above by \mathcal{O}(N\log M).
Now, note that each such operation is a gcd operation and hence itself takes \mathcal{O}(\log M) time, making our total complexity \mathcal{O}(N\log^2 M).

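However, this bound is loose. Computing \gcd(g, x) with the Euclidean algorithm takes \mathcal{O}(1 + \log(g / \gcd(g, x))) steps, and along a running gcd these terms telescope, so taking the running gcd of N integers costs only \mathcal{O}(N + \log M) rather than \mathcal{O}(N\log M), as can be seen in this blog.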
A similar reasoning applies here, making our complexity \mathcal{O}(N\log M + \log M) = \mathcal{O}(N\log M).

I like obvious proofs!

Let p denote the position of the left pointer. Let’s bruteforce in \mathcal{O}(N) to find the position of p for index N-1.

Now, note that A_{N-2} is currently \gcd(A_p, A_{p+1}, \ldots, A_{N-2}).
In particular, the optimal position for N-2 is always going to be \leq p, so it’s enough to start the pointer for N-2 at p.

The only issue is that it’s easy to update the gcd of a bunch of elements when you add a new one, but not when you delete one. That is, knowing \gcd(A_p, A_{p+1}, \ldots, A_{N-1}) doesn’t really give us enough information to compute \gcd(A_p, A_{p+1}, \ldots, A_{N-2}), which is what we need to start the pointer at p.

However, note that this is just a range GCD query, which can be answered in a number of ways: using a segment tree, a sparse table, or even by precomputing, for each right endpoint, the \mathcal{O}(\log M) distinct GCDs of subarrays ending there and binary searching on this list.

The two-pointer part clearly performs \mathcal{O}(N) steps, and each range GCD query costs anywhere from \mathcal{O}(1) to \mathcal{O}(\log N\log M) depending on the implementation, so this is certainly fast enough.
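One possible reading of the last option above, sketched under the assumption of 0-indexed `long long` input (`sortable_via_gcd_lists` is a hypothetical name, and this is not the testers' sparse-table approach): for every right endpoint i, precompute the distinct values of \gcd(A_j, \ldots, A_i), then walk from right to left and binary search for the largest value that does not exceed the value fixed for position i+1. We believe no explicit pointer is needed here: any range gcd starting to the right of the current pointer is a multiple of the value just fixed for position i+1, so it either exceeds that bound or equals a value that is also attainable at or left of the pointer.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <numeric>
#include <utility>
#include <vector>

// vals[i] holds the distinct values of gcd(a[j..i]) over 0 <= j <= i, sorted
// ascending. Each list has at most about log2(max a) + 1 entries, because each
// smaller distinct value is a proper divisor (so at most half) of the next one.
bool sortable_via_gcd_lists(const std::vector<long long>& a) {
    const int n = static_cast<int>(a.size());
    if (n == 0) return true;
    std::vector<std::vector<long long>> vals(n);
    for (int i = 0; i < n; ++i) {
        std::vector<long long> cur{a[i]};  // the range [i, i]
        if (i > 0)
            for (long long g : vals[i - 1]) cur.push_back(std::gcd(g, a[i]));
        std::sort(cur.begin(), cur.end());
        cur.erase(std::unique(cur.begin(), cur.end()), cur.end());
        vals[i] = std::move(cur);
    }
    long long bound = a[n - 1];  // A_N stays as it is
    for (int i = n - 2; i >= 0; --i) {
        // Largest range gcd ending at i that does not exceed the bound.
        auto it = std::upper_bound(vals[i].begin(), vals[i].end(), bound);
        if (it == vals[i].begin()) return false;  // every candidate is too big
        bound = *std::prev(it);
    }
    return true;
}
```

Building the lists takes \mathcal{O}(N\log M) gcd calls, and each binary search is over a list of \mathcal{O}(\log M) values.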

TIME COMPLEXITY:

\mathcal{O}(N\log M) per test case.

CODE:

Setter's code (C++) (Data structures)
#pragma GCC optimize("O3")
#pragma GCC target("popcnt")
#pragma GCC target("avx,avx2,fma")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long  
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;  
#define pb push_back               
#define mp make_pair        
#define nline "\n"                           
#define f first                                          
#define s second                                               
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()   
#define vl vector<ll>         
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else         
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(int x){cerr<<x;} 
void _print(char x){cerr<<x;} 
void _print(string x){cerr<<x;}     
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); 
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;   
const ll MAX=200200;
ll track[20][MAX];
vector<ll> use(MAX,1);     
vector<ll> anot(MAX,0);   
ll getv(ll l,ll r){      
    ll len=r-l+1;
    len=use[len];  
    ll x=anot[len];
    return __gcd(track[x][l],track[x][r-len+1]);
}
void solve(){       
    ll n; cin>>n;
    vector<ll> a(n+5);          
    for(ll i=1;i<=n;i++){
        cin>>a[i];   
    }        
    for(ll i=0;i<20;i++){
        track[i][n+1]=0;
    }   
    for(ll i=n;i>=1;i--){
        track[0][i]=a[i];    
        for(ll j=1;j<20;j++){
            ll len=1LL<<(j-1); // level j merges two level-(j-1) blocks of length 2^(j-1); use[j] is not 2^(j-1) for j>=3
            ll r=min(i+len,n+1);
            track[j][i]=__gcd(track[j-1][i],track[j-1][r]); 
        }
    }
    ll l=n-1;
    for(ll i=n-1;i>=1;i--){  
        l=min(l,i);
        a[i]=__gcd(a[i],getv(l,i));      
        if(a[i]<=a[i+1]){
            continue;  
        }
        ll found=0;  
        while(l>=2){
            l--;  
            ll comp=getv(l,i);  
            comp=__gcd(comp,a[i]); 
            if(comp<=a[i+1]){
                a[i]=__gcd(a[i],comp);   
                found=1;
                break; 
            }
        }
        if(found==0){
            cout<<"NO\n";
            return;
        }
    }
    cout<<"YES\n";
    return;       
}                              
int main()                                                                               
{ 
    ios_base::sync_with_stdio(false);                           
    cin.tie(NULL);                         
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                                
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                        
    #endif         
    ll test_cases=1;                   
    cin>>test_cases;
    ll cur=2,pos=1;
    while(cur<MAX){
        use[cur]=cur;
        anot[cur]=pos; 
        cur*=2; pos++; 
    }
    for(ll i=2;i<MAX;i++){
        use[i]=max(use[i],use[i-1]); 
    }
    while(test_cases--){ 
        solve();       
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
} 
Tester's code (C++) (Data structures)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string& pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

struct sparse {
    using T = int;
    int n;
    int h;
    vector<vector<T>> table;

    T op(T x, T y) {
        return gcd(x, y);
    }

    sparse(const vector<T>& v) {
        n = (int) v.size();
        h = 32 - __builtin_clz(n);
        table.resize(h);
        table[0] = v;
        for (int j = 1; j < h; j++) {
            table[j].resize(n - (1 << j) + 1);
            for (int i = 0; i <= n - (1 << j); i++) {
                table[j][i] = op(table[j - 1][i], table[j - 1][i + (1 << (j - 1))]);
            }
        }
    }

    T get(int l, int r) {
        assert(0 <= l && l < r && r <= n);
        int k = 31 - __builtin_clz(r - l);
        return op(table[k][l], table[k][r - (1 << k)]);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 3e4);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        vector<int> a(n);
        for (int i = 0; i < n; i++) {
            a[i] = in.readInt(1, 1e6);
            (i == n - 1 ? in.readEoln() : in.readSpace());
        }
        string ans = "YES";
        sparse sp(a);
        for (int i = n - 2, j = n - 2; i >= 0; i--) {
            j = min(j, i);
            while (j > 0 && sp.get(j, i + 1) > a[i + 1]) {
                j--;
            }
            a[i] = sp.get(j, i + 1);
            if (a[i] > a[i + 1]) {
                ans = "NO";
                break;
            }
        }
        cout << ans << endl;
    }
    assert(sn <= 2e5);
    in.readEof();
    return 0;
}
Tester's code (C++) ('Brute force')
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
void solve()
{
    int n;
    cin>>n;
    vector<int> A(n+1,0); // an initialized variable-length array is only a compiler extension
    for(int i=1;i<=n;i++)
        cin>>A[i];
    for(int i=n-1;i>=1;i--)
    {   
        if(A[i]<=A[i+1])
            continue;
        int curr=A[i];
        int j;
        for(j=i-1;j>=1;j--)
        {
            curr=(__gcd(curr,A[j]));
            if(curr<=A[i+1])
                break;
        }
        if(curr>A[i+1])   
        {
            cout<<"NO\n";
            return;  
        }   
        for(int k=j+1;k<=i;k++)  
            A[k]=(__gcd(A[k],A[k-1]));
    }
    for(int i=1;i<n;i++)
    {   
        if(A[i]>A[i+1])
        {  
            cout<<"NO\n";
            return;
        }  
    }
    cout<<"YES\n";
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=1;
    cin>>T;
    while(T--)
        solve();
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		vector<int> a(n);
		for (int &x : a) cin >> x;
		for (int i = n-2; i >= 0; --i) { // a[i+1] must exist; a[n-1] never needs to change
			int g = a[i];
			for (int j = i; j >= 0; --j) {
				g = gcd(g, a[j]);
				if (g > a[i+1]) continue;
				for (int x = j; x < i; ++x) a[x+1] = gcd(a[x], a[x+1]);
				break;
			}
		}
		if (is_sorted(begin(a), end(a))) cout << "Yes\n";
		else cout << "No\n";
	}
}

What is the error in (or a counterexample to) this greedy algorithm?

For each i = 1, 2, \ldots: if A_{i+1} = k \cdot A_i for some integer k, set A_{i+1} = A_i.


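Consider [5, 90, 36, 18]. The greedy only replaces 90 by 5, leaving [5, 5, 36, 18], which is unsorted. However, applying the move at i = 2 and then at i = 1 gives [5, 90, 18, 18] and then [5, 5, 18, 18], which is sorted.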

This greedy was in fact the original proposed solution (and the task was placed at an easier position) till we realized it was actually wrong.
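The counterexample can be checked mechanically. The sketch below is 0-indexed, and both function names are made up for illustration: `divisibility_greedy` is our reading of the proposed greedy (only ever applying a move when A_{i+1} is a multiple of A_i), and `apply_move` performs a single move.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Our reading of the proposed (incorrect) greedy, 0-indexed: scan left to
// right and, whenever a[i+1] is a multiple of a[i], replace a[i+1] by a[i]
// (this is exactly the move at i, since gcd(a[i], a[i+1]) == a[i] then).
// No other moves are ever tried.
bool divisibility_greedy(std::vector<long long> a) {
    for (std::size_t i = 0; i + 1 < a.size(); ++i) {
        assert(a[i] >= 1);  // values are positive in this problem
        if (a[i + 1] % a[i] == 0) a[i + 1] = a[i];
    }
    return std::is_sorted(a.begin(), a.end());
}

// One move at (0-indexed) position k: a[k+1] = gcd(a[k], a[k+1]).
void apply_move(std::vector<long long>& a, std::size_t k) {
    a[k + 1] = std::gcd(a[k], a[k + 1]);
}
```

Calling `apply_move` at 0-indexed positions 1 and then 0 corresponds to the 1-indexed moves at i = 2 and i = 1.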


I also did something similar to this, with an extra condition, and it passed. Is this a correct solution, or are the tests weak? :frowning:


  int ele;
  for(i=1;i<=n;){
      ele=a[i];
      j=i+1;
      while(j<=n && a[j]%ele==0)
      {
          j++;
      }
      if(j<=n && a[j]<a[i]){
          cout<<"NO\n";
          return;
      }
      if(j<=n && gcd(a[j],a[j-1])>=a[i]){
          a[j]=gcd(a[j],a[j-1]);
      }
      i=j;
  }
 
  cout<<"YES\n";


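Looks like the tests are weak. Counterexample:

[6, 210, 2310, 385, 35]

The answer here is YES ([6, 6, 6, 35, 35] is reachable), but this code prints NO.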

Indeed, it does look like the tests were weak :frowning:

Luckily, I believe it didn’t affect the actual contest much: we looked through several of the AC solutions submitted during the contest, and they were legitimate solves.