SQUARE_COUNT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Math, sieve of Eratosthenes

PROBLEM:

You’re given an array A.
Count the number of quadruples (L_1, R_1, L_2, R_2) such that:

  • 1 \leq L_1\leq R_1 \lt L_2\leq R_2\leq N
  • \gcd(A_{L_1}, A_{L_1+1}, \ldots, A_{R_1})\times \gcd(A_{L_2}, A_{L_2+1}, \ldots, A_{R_2}) is a perfect square.

That is, count the number of pairs of disjoint subarrays, the product of whose GCDs is a square.

EXPLANATION:

It is a well-known fact that if one end of a subarray (say i) is fixed, there are \mathcal{O}(\log A_i) distinct subarray GCDs among subarrays that start/end at i, since each successive distinct GCD must be a divisor of the previous one (and hence at least halves).
This also tells us that each distinct GCD is attained on a contiguous range of positions of the other endpoint.

Further, they can all also be computed in \mathcal{O}(\log N\log A_i ) time, either using the method here (point 3), or with binary search and some range-gcd structure.

Let \text{st}_i denote the distinct GCDs of subarrays (along with the lengths of the corresponding segments) that start at i, and \text{en}_i denote the same but for subarrays ending at i.

Suppose we fix R_1, the right endpoint of the first subarray.
Let’s also fix the gcd g_1 of that subarray, which as mentioned above gives us some range for L_1.
We want to find the number of subarrays (L_2, R_2) with GCD g_2, such that L_2 \gt R_1 and g_2\cdot g_1 is a square.

Now, there are many possible candidates for g_2; however, there’s a rather nice classification for them all:
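Let s(x) denote the square-free part of x, i.e. the product of the primes that appear in x with an odd exponent.
g_1\cdot g_2 is a square if and only if s(g_1) = s(g_2): a product is a square exactly when every prime appears in it with an even exponent, which happens exactly when g_1 and g_2 have odd exponents on the same set of primes.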

Since A_i \leq 10^7, we’re only dealing with GCDs up to 10^7 as well.
With the help of a sieve, all the s(x) values up to 10^7 can be precomputed.


This allows us to restate our task: with R_1 and g_1 (and hence some range of L_1) fixed, how many subarrays after R_1 are such that s(g_1) = s(g_2)?

This can be computed quickly with the help of a sweep line.
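Let \text{S} be an array such that \text{S}[x] denotes the number of subarrays starting after the current R_1 whose GCD has square-free part x (it is maintained by adding up the lengths of the endpoint ranges described above).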
Initially, this is filled with zeros.
We’ll iterate R_1 from N down to 1.
When at R_1:

  • Process all subarrays ending at R_1, where as noted above there’ll be \mathcal{O}(\log A_{R_1}) distinct values of g_1, each with a corresponding range of L_1.
    For each range, say of length k, add k\cdot \text{S}[s(g_1)] to the answer.
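  • Then, process all subarrays starting at R_1: for each distinct GCD g_2 among them, attained on a range of k right endpoints, add k to \text{S}[s(g_2)], so that these subarrays are available as second subarrays for earlier values of R_1.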

This is enough to get AC.
Since there are multiple test cases, \text{S} must either be maintained as a global array (that is reset appropriately for each test), or as a map.

TIME COMPLEXITY:

\mathcal{O}(N\log N\log\max A) per testcase.

CODE:

Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// Short type aliases used throughout the author's solution.
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
// N bounds the array length (sizes a[], lg2[] and g[]); M bounds the values (sizes key[]).
const int N=1000010;
const int M=10000010;
// Number of sparse-table levels stored per position in g[][].
const int LOGN=25;
// Modulus applied to the answer in main().
const ll  TMD=998244353;
// Not referenced anywhere in the visible code.
const ll  INF=2147483647;
// T: number of test cases still to process; n: length of the current array (both read in main()).
int T,n;
// a[1..n]: current input array (1-indexed).
// key[x]: x with every square factor divided out, filled once by cal_key().
int a[N],key[M];

//-------------------------------------------------
// lg2[len]: level used by getgcd() for a range of length len; filled by init().
int lg2[N];
// g[j][i]: GCD of the (up to) 2^i consecutive elements starting at a[j], clipped at n; built by init().
// NOTE(review): N*LOGN ints is roughly 100 MB with 4-byte int — confirm against the memory limit.
int g[N][LOGN];

// Greatest common divisor by the Euclidean algorithm, written iteratively.
// gcd(x, 0) == x, so gcd(0, 0) == 0.
int gcd(int a,int b)
{
	while(b!=0)
	{
		const int rem=a%b;
		a=b;
		b=rem;
	}
	return a;
}

// Builds the range-GCD sparse table for the current array a[1..n].
//   lg2[len] = floor(log2(len)) for 1 <= len <= n
//   g[j][i]  = GCD of a[j .. min(n, j + 2^i - 1)]
// Must be called after a[] and n are set and before any getgcd() call.
void init()
{
	// Integer recurrence instead of (int)log2(i): the result no longer depends on
	// floating-point rounding at exact powers of two.
	lg2[1]=0;
	for(int i=2;i<=n;i++) lg2[i]=lg2[i>>1]+1;
	for(int i=1;i<=n;i++) g[i][0]=a[i];
	// getgcd() only reads level lg2[len] with len <= n, so levels with 2^i > n
	// are never consulted and need not be rebuilt for every test case.
	for(int i=1;i<LOGN&&(1<<i)<=n;i++)
	{
		const int half=1<<(i-1);
		for(int j=1;j<=n;j++)
		{
			const int p=j+half;
			// When the second half of the window starts past n, the window is clipped.
			g[j][i]=(p>n)?g[j][i-1]:gcd(g[j][i-1],g[p][i-1]);
		}
	}
}

// GCD of a[L..R] (1-indexed, inclusive; requires L <= R), answered from the sparse
// table built by init(): two windows of length 2^level that together cover [L, R].
int getgcd(int L,int R)
{
	const int level=lg2[R-L+1];
	const int tail=R-(1<<level)+1;
	return gcd(g[L][level],g[tail][level]);
}

//-------------------------------------------------

// Fills key[x], for 1 <= x < M, with x stripped of every square factor
// (the square-free part of x). Called once from main() before any test case.
void cal_key()
{
	for(int i=1;i<M;i++) key[i]=i;
	// Only i with i*i < M can divide anything in range.
	for(int i=2;i*i<M;i++)
	{
		const int sq=i*i;
		// key[x] always divides x, so i*i can divide key[x] only when x itself is
		// a multiple of i*i; stepping by i*i skips the multiples that cannot change.
		for(int x=sq;x<M;x+=sq)
		{
			while(key[x]%sq==0) key[x]/=sq;
		}
	}
}

// Distinct GCDs of the subarrays that END at index x.
// Returns (gcd, count) pairs, where count is the number of left endpoints whose
// subarray [left, x] has that GCD; pairs are listed from the nearest left
// endpoints (left == x) towards left == 1.
vector<pair<int,int> > cal_pair(int x)
{
	vector<pair<int,int> > v;
	int cur=x;
	while(cur)
	{
		// GCD of [cur, x] is fixed for this whole binary search, so compute it once.
		const int target=getgcd(cur,x);
		// Invariant: getgcd(hi,x)==target, and lo is 0 or an index with a different GCD.
		int lo=0,hi=cur;
		while(lo+1!=hi)
		{
			const int mid=(lo+hi)>>1;
			if(getgcd(mid,x)==target) hi=mid;
			else lo=mid;
		}
		// Left endpoints hi..cur all give `target`.
		v.push_back(make_pair(target,cur-hi+1));
		cur=hi-1;
	}
	return v;
}

// Distinct GCDs of the subarrays that START at index x.
// Returns (gcd, count) pairs, where count is the number of right endpoints whose
// subarray [x, right] has that GCD; pairs are listed from the nearest right
// endpoints (right == x) towards right == n.
vector<pair<int,int> > cal_pair2(int x)
{
	vector<pair<int,int> > v;
	int cur=x;
	while(cur<=n)
	{
		// GCD of [x, cur] is fixed for this whole binary search, so compute it once.
		const int target=getgcd(x,cur);
		// Invariant: getgcd(x,lo)==target, and hi is n+1 or an index with a different GCD.
		int lo=cur,hi=n+1;
		while(lo+1!=hi)
		{
			const int mid=(lo+hi)>>1;
			if(getgcd(x,mid)==target) lo=mid;
			else hi=mid;
		}
		// Right endpoints cur..lo all give `target`.
		v.push_back(make_pair(target,lo-cur+1));
		cur=lo+1;
	}
	return v;
}

int main()
{
	//freopen("7.in","r",stdin);
	//freopen("7.out","w",stdout);
	
	cal_key(); 
	scanf("%d",&T);
	while(T--)
	{
	    scanf("%d",&n);
		for(int i=1;i<=n;i++) scanf("%d",&a[i]);
		init();
		ll ans=0;
		map<ll,ll> cnt;
		//int cnt[101]={0}; 
		for(int i=1;i<=n;i++)
		{
			vector<pair<int,int> > v=cal_pair(i);
			for(int j=0;j<v.size();j++) cnt[key[v[j].first]]+=v[j].second;
		}
	 	for(int i=1;i<n;i++)
		{
			vector<pair<int,int> > v1=cal_pair(i),v2=cal_pair2(i);
			for(int j=0;j<v2.size();j++) 
				cnt[key[v2[j].first]]-=v2[j].second;
			for(int j=0;j<v1.size();j++)
  				ans=(ans+v1[j].second*cnt[key[v1[j].first]])%TMD;
		}
		printf("%lld\n",ans);
		//
		if(ans<0)
		{
			for(int i=1;i<=n;i++) printf("%d%c",a[i],i==n?'\n':' ');
		}
		// 
	}
	
	//fclose(stdin);
	return 0;
}
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//In modular-arithmetic problems, add a final check: if ans<0 then ans+=mod;
//Incase of close mle change language to c++17 or c++14  
//Check ans for n=1 
#pragma GCC target ("avx2")    
#pragma GCC optimize ("O3")  
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
// Contest-template shorthand. Only IOS, pb, mod, pii, ff, ss and rep are used by main() below.
// NOTE: from this line on every `int` token means long long, which is why the entry point
// is declared as int32_t main().
#define int long long      
// Fast iostream setup plus maximum double precision for cout.
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
// Modulus applied to the answer.
#define mod 998244353ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
// Half-open loop: i runs over x, x+1, ..., y-1.
#define rep(i,x,y) for(int i=x; i<y; i++)    
// NOTE(review): function-like macro named fill; any later call written as fill(a,b) expands to this memset.
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
// N sizes the sieve tables sq[] and pf[]; INF, inf and pi are not referenced by main() below.
const long long N=10000005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
// Least common multiple of a and b.
// Divides by the GCD before multiplying so the intermediate value stays small.
int lcm(int a, int b)
{
    return a/__gcd(a, b)*b;
}
// Modular exponentiation by repeated squaring: returns a^b mod p.
// A zero base short-circuits to 0, including for b == 0.
int power(int a, int b, int p)
{
    if(a==0) return 0;
    int base=a%p;
    int acc=1;
    for(; b>0; b>>=1)
    {
        // Multiply in the current power of the base when this bit of b is set.
        if(b&1) acc=(1ll*acc*base)%p;
        base=(1ll*base*base)%p;
    }
    return acc;
}
// Template-wide pseudo-random generator, seeded from the steady clock at start-up.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Draws one integer uniformly from the closed interval [l, r] using rng.
int getRand(int l, int r)
{
    return uniform_int_distribution<int>(l, r)(rng);
}

// sq[x]: starts as x; the sieve in main() divides out i*i for every prime i, leaving the square-free part of x.
int sq[N];
// pf[x]: sieve flag, 1 until x has been visited as a multiple of some prime;
// main() treats i as a prime exactly when pf[i] is still 1 on reaching i.
int pf[N];

// Hash functor for 64-bit integer keys: runs the key, offset by a per-process
// time-derived constant, through the splitmix64 finalizer.
struct custom_hash {
    // splitmix64 mixing step applied to x.
    static uint64_t splitmix64(uint64_t x) {
        // http://xorshift.di.unimi.it/splitmix64.c
        uint64_t z = x + 0x9e3779b97f4a7c15;
        z ^= z >> 30;
        z *= 0xbf58476d1ce4e5b9;
        z ^= z >> 27;
        z *= 0x94d049bb133111eb;
        z ^= z >> 31;
        return z;
    }

    // The offset is fixed on first use, so hashes are stable within one run
    // but differ between runs.
    size_t operator()(uint64_t x) const {
        static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(x + FIXED_RANDOM);
    }
};
// Two template containers keyed through custom_hash. main() below uses neither of
// them; it declares its own local gp_hash_table instead.
unordered_map<long long, int, custom_hash> safe_map;
                             //or
gp_hash_table<long long, int, custom_hash> safe_hash_table;

int32_t main()
{
    IOS;
    for(int i=0;i<N;i++)
    {
        sq[i]=i;
        pf[i]=1;
    }
    for(int i=2;i<N;i++)
    {
        if(pf[i]){
            int x=i;
            while(x<N){
                pf[x]=0;
                while(sq[x]%(i*i)==0){
                    sq[x]/=(i*i);
                }
                x+=i;
            }
        }
    }
    int t;
    cin>>t;
    while(t--)
    {
        int n;
        cin>>n;
        int a[n];
        rep(i,0,n)
        cin>>a[i];
        vector <pii> pref[n], suf[n];
        suf[0].pb({a[0], 0});
        for(int i=1;i<n;i++)
        {
            suf[i].pb({a[i], i});
            int g=a[i];
            for(auto p:suf[i-1])
            {
                int ng=__gcd(g, p.ff);
                if(ng==g)
                suf[i].back().ss=p.ss;
                else
                suf[i].pb({ng, p.ss});
                g=ng;
            }
        }
        pref[n-1].pb({a[n-1], n-1});
        for(int i=n-2;i>=0;i--)
        {
            pref[i].pb({a[i], i});
            int g=a[i];
            for(auto p:pref[i+1])
            {
                int ng=__gcd(g, p.ff);
                if(ng==g)
                pref[i].back().ss=p.ss;
                else
                pref[i].pb({ng, p.ss});
                g=ng;
            }
        }
        gp_hash_table<long long, int, custom_hash> mp;
        int ans=0;
        for(int i=1;i<n;i++)
        {
            int prev=i;
            for(auto p:suf[i-1])
            {
                mp[sq[p.ff]]+=(prev-p.ss);
                prev=p.ss;
            }
            prev=i-1;
            for(auto p:pref[i])
            {
                ans+=(mp[sq[p.ff]]*(p.ss-prev));
                ans%=mod;
                prev=p.ss;
            }
        }
        cout<<ans<<"\n";
    }
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
// 64-bit Mersenne Twister seeded from the high-resolution clock; main() below never draws from it.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Editorialist's solution driver.
// Sieves square-free parts once, then for each test sweeps the split index from
// right to left: subarrays ending at i are paired with the already-registered
// subarrays starting after i, after which the subarrays starting at i are registered.
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int maxA = 1e7 + 10;
    const int mod = 998244353;

    // spf[v]: smallest prime factor of v; sqfree[v]: square-free part of v.
    vector<int> spf(maxA), sqfree(maxA);
    for (int p = 2; p < maxA; ++p) {
        if (spf[p] != 0) continue;
        for (int m = p; m < maxA; m += p) {
            if (spf[m] == 0) spf[m] = p;
        }
    }
    sqfree[1] = 1;
    for (int v = 2; v < maxA; ++v) {
        const int p = spf[v];
        const int rest = v / p;
        // If p divides v twice, the pair of p's cancels; otherwise p stays in the result.
        sqfree[v] = (spf[rest] == p) ? sqfree[rest / p] : p * sqfree[rest];
    }

    // In a gcd -> maximum-extension map the extension shrinks as the gcd grows, so the
    // number of far endpoints giving exactly it->first is the gap to the next entry
    // (or extension + 1 for the last entry).
    const auto span_len = [](const map<int, int> &m, map<int, int>::const_iterator it) {
        const auto nxt = next(it);
        return nxt != m.end() ? it->second - nxt->second : it->second + 1;
    };

    // lensum[s]: number of registered subarrays whose GCD has square-free part s.
    vector<ll> lensum(maxA);
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n);
        for (int &x : a) cin >> x;
        vector<map<int, int>> end_gcd(n), start_gcd(n);
        ll ans = 0;

        // end_gcd[i]: gcd -> largest number of extra elements to the left of i giving it.
        for (int i = 0; i < n; ++i) {
            map<int, int> &cur = end_gcd[i];
            cur[a[i]] = 0;
            if (i == 0) continue;
            for (const auto &[val, ext] : end_gcd[i-1]) {
                int &slot = cur[gcd(val, a[i])];
                slot = max(slot, ext + 1);
            }
        }

        for (int i = n-1; i >= 0; --i) {
            // Pair subarrays ending at i with everything registered so far.
            const map<int, int> &ends = end_gcd[i];
            for (auto it = ends.cbegin(); it != ends.cend(); ++it) {
                const int len = span_len(ends, it);
                ans += (len * lensum[sqfree[it->first]] % mod) % mod;
            }

            // start_gcd[i]: gcd -> largest number of extra elements to the right of i giving it.
            map<int, int> &starts = start_gcd[i];
            starts[a[i]] = 0;
            if (i+1 < n) {
                for (const auto &[val, ext] : start_gcd[i+1]) {
                    int &slot = starts[gcd(val, a[i])];
                    slot = max(slot, ext + 1);
                }
            }
            // Register the subarrays starting at i for the indices still to come.
            for (auto it = starts.cbegin(); it != starts.cend(); ++it) {
                lensum[sqfree[it->first]] += span_len(starts, it);
            }
        }

        // Clear only the lensum cells this test touched.
        for (const auto &per_start : start_gcd) {
            for (const auto &entry : per_start) lensum[sqfree[entry.first]] = 0;
        }
        cout << ans%mod << '\n';
    }
}
1 Like