FERM_SQUARE - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: piyush_2007
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Sieve of Eratosthenes, XOR hashing

PROBLEM:

The score of an array is defined as the minimum number of perfect squares whose sum equals the product of the array’s elements.
You’re given an array A. Find the sum of scores of all its subarrays.

EXPLANATION:

The main piece of knowledge required to solve this task is Lagrange’s four-square theorem: every non-negative integer can be written as the sum of four squares of integers.

This immediately bounds the score of every subarray by 4.
So, to compute the answer, it suffices to count the subarrays with score 1, 2, 3, and 4.
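As a quick sanity check of this bound, the minimum number of squares can be tabulated by brute force for small values. The sketch below is only a testing aid (the name `min_squares_table` is hypothetical, and real subarray products are far too large for it): a coin-change style DP where s^2 is the last square used.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Brute-force reference (testing aid only): best[v] is the minimum number of
// positive squares summing to v, for 0 <= v <= limit.
std::vector<int> min_squares_table(int limit) {
    std::vector<int> best(limit + 1, 0);  // best[0] = 0: the empty sum
    for (int v = 1; v <= limit; ++v) {
        best[v] = v;  // upper bound: v = 1 + 1 + ... + 1
        for (int s = 1; s * s <= v; ++s)
            best[v] = std::min(best[v], best[v - s * s] + 1);  // s^2 is the last square
    }
    return best;
}
```

For instance, the table gives 4 for 7 (= 4 + 1 + 1 + 1) and 3 for 12 (= 4 + 4 + 4), and no value up to 2000 needs more than 4 squares.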

Let’s work on each of these separately.

Score = 1

A subarray has a score of 1 if and only if its product is a square number.
In terms of prime factorizations, this means every prime factor must occur an even number of times in the product.

Let’s define S_i to be the set of primes that occur an odd number of times in the factorization of
A_1 \times A_2\times\cdots\times A_i, i.e. the product of the prefix ending at i, with S_0 = \emptyset for the empty prefix.

Then, observe that the product A_i\times\cdots\times A_j is a square if and only if S_{i-1} = S_j.
Indeed, the exponent of a prime p in A_i\times\cdots\times A_j equals its exponent in the prefix product up to j minus its exponent in the prefix product up to i-1, so it is even exactly when p lies in both or in neither of S_{i-1} and S_j.

So, assuming we’re able to compute all the sets S_i, all we need to do is count the pairs of equal sets.
However, storing every S_i explicitly isn’t feasible, since that can take quadratic memory.

To get around this, we use hashing - specifically, XOR hashing.
First, map every value from 1 to 10^7 to a random integer in the range [0, 2^{61}).
Now, we’ll represent the set S_i by the bitwise XOR of the (hashed values of the) integers in it.
This allows us to represent each S_i using a single integer, and updating it becomes quite easy as well: S_i can be obtained by taking S_{i-1} and then XOR-ing it with the (hashed values of the) prime factors of A_i that have odd powers.
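A minimal sketch of the prefix hashing, assuming every value fits in a smallest-prime-factor table of size `limit + 1`. The helper names (`build_spf`, `random_hashes`, `prefix_hashes`) are hypothetical, and the fixed seed is only for reproducible testing (the reference solutions seed from the clock). Each prime factor is XOR-ed once per occurrence, so primes with even exponents cancel; the `only_3mod4` flag gives the variant used for the score 2 count.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Smallest-prime-factor table for 0..limit (spf[p] == p for every prime p).
std::vector<int> build_spf(int limit) {
    std::vector<int> spf(limit + 1, 0);
    for (int i = 2; i <= limit; ++i)
        if (spf[i] == 0)  // i is prime
            for (int j = i; j <= limit; j += i)
                if (spf[j] == 0) spf[j] = i;
    return spf;
}

// One random 64-bit value per integer 0..limit (only prime indices are used).
std::vector<std::uint64_t> random_hashes(int limit, std::uint64_t seed) {
    std::mt19937_64 rng(seed);
    std::vector<std::uint64_t> h(limit + 1);
    for (auto& x : h) x = rng();
    return h;
}

// pref[i] = XOR of hash[p] over the primes p with an odd exponent in a[0]*...*a[i-1].
// With only_3mod4 set, primes p with p % 4 != 3 are skipped (the score <= 2 variant).
std::vector<std::uint64_t> prefix_hashes(const std::vector<int>& a,
                                         const std::vector<int>& spf,
                                         const std::vector<std::uint64_t>& hash,
                                         bool only_3mod4) {
    std::vector<std::uint64_t> pref(a.size() + 1, 0);  // pref[0] represents S_0 = {}
    for (std::size_t i = 0; i < a.size(); ++i) {
        pref[i + 1] = pref[i];
        for (int x = a[i]; x > 1; x /= spf[x]) {
            int p = spf[x];
            // Toggle once per prime factor (with multiplicity): even exponents cancel out.
            if (!only_3mod4 || p % 4 == 3) pref[i + 1] ^= hash[p];
        }
    }
    return pref;
}
```

For example, with A = [2, 8] the second prefix hash is 0, since 2 \cdot 8 = 16 is a square.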

Once this is done, simply count the pairs of equal values among the hashes of S_0, S_1, \ldots, S_N, which is a standard task (e.g. using a map from each value to its frequency).
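Counting pairs of equal hashes can be sketched as follows (the name `count_equal_pairs` is hypothetical). Each new value is paired with every earlier equal value, so the input should contain all N + 1 prefix hashes, including the one for the empty prefix.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Number of pairs (i, j), i < j, with v[i] == v[j].
long long count_equal_pairs(const std::vector<std::uint64_t>& v) {
    std::unordered_map<std::uint64_t, long long> seen;  // value -> occurrences so far
    long long pairs = 0;
    for (std::uint64_t h : v) {
        pairs += seen[h];  // pair h with every earlier equal value
        ++seen[h];
    }
    return pairs;
}
```

A plain `std::map`, as in the reference solutions, works equally well; the hash map just avoids the logarithmic factor.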

Score = 2

An integer can be written as the sum of two squares if and only if for every p^k in its prime factorization (p being a prime), either p \not\equiv 3 \pmod 4 or k is even.
This is the classical sum of two squares theorem.
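The characterization can be checked against a direct search for small values. This sketch is for verification only (both function names are hypothetical) and factorizes by trial division.

```cpp
#include <cassert>

// Sum of two squares theorem: v >= 1 equals a^2 + b^2 (a, b >= 0) iff every
// prime p = 3 (mod 4) divides v to an even power.
bool two_squares_by_theorem(long long v) {
    for (long long p = 2; p * p <= v; ++p) {
        int k = 0;
        while (v % p == 0) { v /= p; ++k; }
        if (p % 4 == 3 && k % 2 == 1) return false;
    }
    return v % 4 != 3;  // leftover v is 1 or a prime with exponent 1
}

// Direct search, for comparison on small inputs.
bool two_squares_brute(long long v) {
    for (long long a = 0; a * a <= v; ++a)
        for (long long b = a; a * a + b * b <= v; ++b)
            if (a * a + b * b == v) return true;
    return false;
}
```

For example, 45 = 3^2 \cdot 5 passes (45 = 36 + 9), while 21 = 3 \cdot 7 does not.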

So, our aim is now to count subarrays such that in their product, every prime of the form 4x + 3 appears an even number of times.
Observe that this is basically the same thing as the \text{score} = 1 case, except we’re now restricted to primes of the form 4x + 3 rather than all primes.
This means the exact same algorithm works once all primes not of the form 4x + 3 have been thrown out: that is, use XOR hashing to store information about each prefix, and then count the number of equal pairs of elements.

Note that the count obtained here is in fact the number of subarrays with score at most 2: it also includes every subarray with score 1.
So, we must subtract the count of subarrays with score 1, to get the number of subarrays with score exactly 2.

Score = 3

Once again, the numbers that can be written as the sum of three squares have a simple characterization (Legendre’s three-square theorem): they are exactly the numbers that are not of the form 4^a(8b + 7).
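Similarly, the three-square test can be sketched and confirmed against a brute-force search on small values (both names are hypothetical):

```cpp
#include <cassert>

// Legendre's three-square theorem: v >= 1 is a sum of three squares
// iff v is not of the form 4^a (8b + 7).
bool three_squares_by_theorem(long long v) {
    while (v % 4 == 0) v /= 4;  // strip the 4^a factor
    return v % 8 != 7;
}

// Direct search over a <= b <= c, for comparison on small inputs.
bool three_squares_brute(long long v) {
    for (long long a = 0; a * a <= v; ++a)
        for (long long b = a; a * a + b * b <= v; ++b)
            for (long long c = b; a * a + b * b + c * c <= v; ++c)
                if (a * a + b * b + c * c == v) return true;
    return false;
}
```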

There are several ways to count subarrays with such products.
One relatively straightforward method is divide-and-conquer.

Let f(l, r) denote the number of subarrays within [l, r) satisfying the condition.
Let m = \lfloor\frac{l+r}{2}\rfloor. Then f(l, m) + f(m, r) counts the subarrays that don’t cross the midpoint, so we only need to focus on the subarrays that do cross m.

Essentially, we’re looking to combine some [x, m) with some [m, y) so that the product of [x, y) is not of the form 4^a(8b + 7).

Let’s fix the value of x, and try to count valid y.
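Note that the product of [x, m) doesn’t need to be stored exactly. Writing it as 2^e \cdot o with o odd, it is of the form 4^a(8b + 7) exactly when e is even and o \equiv 7 \pmod 8, so it suffices to store the parity of e together with o modulo 8. Both combine easily across a product (XOR the parities, multiply the odd parts modulo 8), which allows us to avoid working with large numbers.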
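A sketch of this reduced state, assuming positive inputs (the names `Reduced`, `reduce_value`, `combine` and `needs_four` are hypothetical). Keeping the parity of the exponent of 2 separate from the odd part matters: 2 and 10 both become 2 after removing powers of 4 and reducing modulo 8, yet 6 \cdot 2 = 12 is a sum of three squares while 6 \cdot 10 = 60 = 4 \cdot 15 is not.

```cpp
#include <cassert>

// Reduced form of x = 2^e * o (o odd): two = e % 2, odd = o % 8.
struct Reduced {
    int two;
    int odd;
};

Reduced reduce_value(long long x) {
    Reduced r{0, 1};
    while (x % 2 == 0) { x /= 2; r.two ^= 1; }
    r.odd = static_cast<int>(x % 8);
    return r;
}

// Reduced form of a product, computed from the reduced forms of its factors.
Reduced combine(Reduced p, Reduced q) {
    return {p.two ^ q.two, p.odd * q.odd % 8};
}

// True iff the original number has the form 4^a(8b + 7), i.e. needs four squares.
bool needs_four(Reduced r) { return r.two == 0 && r.odd == 7; }
```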

Let p be the reduced form (parity of the exponent of 2, odd part modulo 8) of the product of [x, m).
For each m \lt y \leq r, let q be the reduced form of the product of [m, y); we only need to check that combining p and q does not give an even parity together with an odd part of 7 modulo 8.
Since q can take only 2 \cdot 4 = 8 values, rather than iterating over all y for a fixed x, we can iterate over the possible values of q instead, and for each valid one add the number of y that produce it to the answer (these frequencies can be precomputed once per call).

This runs in \mathcal{O}(8N\log N), which is fast enough.
Once again, this counts the subarrays with score at most 3, so subtract the count of subarrays with score at most 2 to obtain the number of subarrays with score exactly 3.
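Putting the pieces together, the divide-and-conquer count of subarrays with score at most 3 can be sketched as below. The names are hypothetical and this is not one of the reference solutions: it indexes the frequency table by (parity of the exponent of 2, odd part modulo 8) directly, and includes a quadratic brute force for checking small arrays.

```cpp
#include <cassert>
#include <vector>

// Reduced form of x = 2^e * o (o odd): two = e % 2, odd = o % 8.
struct Reduced { int two, odd; };

Reduced reduce_value(long long x) {
    int two = 0;
    while (x % 2 == 0) { x /= 2; two ^= 1; }
    return {two, static_cast<int>(x % 8)};
}

// Subarrays of s[l, r) whose product is NOT of the form 4^a(8b + 7).
long long dnc(const std::vector<Reduced>& s, int l, int r) {
    if (r - l == 1) return !(s[l].two == 0 && s[l].odd == 7);
    int m = (l + r) / 2;
    long long res = dnc(s, l, m) + dnc(s, m, r);

    // freq[t][o]: number of subarrays [m, y] (m <= y < r) whose product reduces to (t, o).
    long long freq[2][8] = {};
    int two = 0, odd = 1;
    for (int y = m; y < r; ++y) {
        two ^= s[y].two;
        odd = odd * s[y].odd % 8;
        ++freq[two][odd];
    }

    // For each left end x, test the product over [x, m - 1] against all 8 right states.
    two = 0;
    odd = 1;
    for (int x = m - 1; x >= l; --x) {
        two ^= s[x].two;
        odd = odd * s[x].odd % 8;
        for (int t = 0; t < 2; ++t)
            for (int o = 1; o < 8; o += 2)
                if ((two ^ t) == 1 || odd * o % 8 != 7) res += freq[t][o];
    }
    return res;
}

long long count_at_most_three(const std::vector<long long>& a) {
    if (a.empty()) return 0;
    std::vector<Reduced> s;
    for (long long x : a) s.push_back(reduce_value(x));
    return dnc(s, 0, static_cast<int>(s.size()));
}

// O(n^2) check with exact products; only for small inputs (no overflow handling).
long long count_at_most_three_brute(const std::vector<long long>& a) {
    long long res = 0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        long long p = 1;
        for (std::size_t j = i; j < a.size(); ++j) {
            p *= a[j];
            long long v = p;
            while (v % 4 == 0) v /= 4;
            if (v % 8 != 7) ++res;
        }
    }
    return res;
}
```

For example, for A = [3, 5] both single elements count but 3 \cdot 5 = 15 does not, giving 2.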

With the above three quantities known, the number of subarrays with score 4 is simply the total number of subarrays minus the sum of the above three values.
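Given the three cumulative counts, the final sum can be assembled as in this short sketch (the function name is hypothetical). The reference solutions print the result modulo 10^9 + 7, so the sketch does the same.

```cpp
#include <cassert>

// le1, le2, le3: numbers of subarrays with score at most 1, 2 and 3.
long long sum_of_scores(long long n, long long le1, long long le2, long long le3) {
    long long total = n * (n + 1) / 2;  // every subarray has score at most 4
    long long exactly1 = le1;
    long long exactly2 = le2 - le1;
    long long exactly3 = le3 - le2;
    long long exactly4 = total - le3;
    return (1 * exactly1 + 2 * exactly2 + 3 * exactly3 + 4 * exactly4) % 1000000007LL;
}
```

For A = [2, 3], the subarrays [2], [3] and [2, 3] have scores 2, 3 and 3, so the cumulative counts are 0, 1, 3 and the sum is 8.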


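Note that everything above assumes that every A_i has already been prime factorized.
Since all elements are \leq 10^7, this can be done quickly by first running a sieve of Eratosthenes that records a prime factor of every integer up to 10^7, and then repeatedly dividing each A_i by its recorded factor.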

TIME COMPLEXITY:

\mathcal{O}(N\log M + N\log N) per testcase, where M = 10^7, plus a one-time sieve up to M.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;

#define int long long

int rand(int l, int r){
     static mt19937 
     rng(chrono::steady_clock::now().time_since_epoch().count());
     uniform_int_distribution<int> ludo(l, r); 
     return ludo(rng);
}

vector<int> pr;
const int N = 10000000;
vector<int> lp(N + 1);
vector<int> val(N + 1);

void gen(){
    for (int i = 0; i <= N; ++i){
        val[i] = rand(0,1e18);
    }
}

void linear_sieve(){
    for (int i = 2; i <= N; ++i){
        if (lp[i] == 0){
            lp[i] = i;
            pr.push_back(i);
        }
        for (int j = 0; i * pr[j] <= N; ++j){
            lp[i * pr[j]] = pr[j];
            if (pr[j] == lp[i]){
                break;
            }
        }
    }
}

int count1(int n, vector<int> a){
    int ans = 0;
    map<int, int> cnt;
    cnt[0]++;
    vector<int> h(n);
    for (int i = 0; i < n; ++i){
        if (i > 0){
            h[i] ^= h[i - 1];
        }
        while (a[i] > 1){
            int c = 0;
            int x = lp[a[i]];
            while (a[i] % x == 0){
                c ^= 1;
                a[i] /= x;
            }
            if (c){
                h[i] ^= val[x];
            }
        }
        ans += cnt[h[i]];
        cnt[h[i]]++;
    }
    return ans;
}

int count2(int n, vector<int> a){
    int ans = 0;
    map<int, int> cnt;
    cnt[0]++;
    vector<int> h(n);
    for (int i = 0; i < n; ++i){
        if (i > 0){
            h[i] ^= h[i - 1];
        }
        while (a[i] > 1){
            int c = 0;
            int x = lp[a[i]];
            while (a[i] % x == 0){
                c ^= 1;
                a[i] /= x;
            }
            if (c and (x % 4 == 3)){
                h[i] ^= val[x];
            }
        }
        ans += cnt[h[i]];
        cnt[h[i]]++;
    }
    return ans;
}

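int count3(int n, vector<int> a){
    // Counts subarrays whose product has the form 4^a(8b + 7), i.e. score exactly 4.
    int ans = 0;
    // c[0]: parity of the exponent of 2 in the prefix product.
    // c[r], r odd: parity of how many prefix elements have odd part = r (mod 8).
    vector<int> c(8);
    map<array<int, 4>, int> cnt;
    cnt[{0, 0, 0, 0}] = 1;  // empty prefix
    for (int i = 0; i < n; ++i){
        while(a[i] % 2 == 0){
            c[0] ^= 1;
            a[i] /= 2;
        }
        c[a[i] % 8] ^= 1;
        // The odd part of a subarray is 7 (mod 8) iff its parities for (3, 5, 7)
        // are (0, 0, 1) or (1, 1, 0), because 3 * 5 = 15 = 7 (mod 8).
        ans += cnt[{c[0], c[3], c[5], c[7] ^ 1}];
        ans += cnt[{c[0], c[3] ^ 1, c[5] ^ 1, c[7]}];
        cnt[{c[0], c[3], c[5], c[7]}]++;
    }
    return ans;
}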

int32_t main(){

    ios_base::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);

    int t;
    cin >> t;
    gen();
    linear_sieve();
    while (t--){
        int n;
        cin >> n;
        vector<int> a(n);
        for (int i = 0; i < n; ++i){
            cin >> a[i];
        }
        int tot = n * (n + 1) / 2;
        int c[4] = {0, count1(n, a), count2(n, a), count3(n, a)};
        int ans = 0;
        ans += c[1];
        ans += 2 * (c[2] - c[1]);
        ans += 3 * (tot - c[3] - c[2]);
        ans += 4 * c[3];
        ans%=1000000007;
        cout << ans << '\n';
    }
    return 0;
}
Tester's code (C++)
//Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;

using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());




const int N = 1000'0000;
vector<int> lp(N+1);
vector<int> pr;
int m[N+1],hsh[N+1],hsh1[N+1];

#define SIEVE

void sieve(){
    for (int i = 0;i <= N;++i)m[i] = rng();
    for (int i=2; i <= N; ++i) {
        if (lp[i] == 0) {
            lp[i] = i;
            pr.push_back(i);
        }
        for (int j = 0; i * pr[j] <= N; ++j) {
            lp[i * pr[j]] = pr[j];
            if (pr[j] == lp[i]) {
                break;
            }
        }
    }
    hsh[1] = 0;
    hsh1[1] = 0;
    for(int i = 2;i <= N;++i){
        hsh[i] = hsh[i/lp[i]]^m[lp[i]];
        if((lp[i]%4) == 3)hsh1[i] = hsh1[i/lp[i]]^m[lp[i]];
        else hsh1[i] = hsh1[i/lp[i]];
    }
}


void solve()
{
    int n;
    cin >> n;

    vi a(n);
    take(a,n);

    vi b = a;
    vi pw(n);
    rep(i,0,n){
        while((b[i]%2) == 0)b[i]/=2,pw[i]++;
    }

    vi num(5);

    vector<vector<int>> cnt(2,vi(4,0));
    rep(i,0,2)rep(j,0,4)cnt[i][j] = 0;
    rrep(i,n-1,0){
        vector<vector<int>> tcnt(2,vi(4,0));
        if(pw[i]&1){
            rep(j,0,4)swap(cnt[0][j],cnt[1][j]);
        }
        tcnt = cnt;
        rep(j,0,4){
            tcnt[0][(((2*j+1)*b[i])%8)/2] = cnt[0][j];
            tcnt[1][(((2*j+1)*b[i])%8)/2] = cnt[1][j];
        }
        cnt = tcnt;
        if(pw[i]&1)cnt[1][(b[i]%8)/2]++;
        else cnt[0][(b[i]%8)/2]++;
        num[4] += cnt[0][3];
    }

    vi pf(n),pf1(n);
    repin{
        pf[i] = hsh[a[i]];
        pf1[i] = hsh1[a[i]];
    }
    rep(i,1,n){
        pf[i] ^= pf[i-1];
        pf1[i] ^= pf1[i-1];
    }
    map<int,int> c,c1;
    c[0]++;c1[0]++;
    repin{
        num[1] += c[pf[i]];
        num[2] += c1[pf1[i]];
        c[pf[i]]++;
        c1[pf1[i]]++;
    }

    num[2] -= num[1];
    num[3] = n*(n+1)/2 - num[4] - num[1] - num[2];
    long long ans = 0;
    rep(i,1,5)ans += i*num[i];
    ans %= M;
    cout << ans << "\n";

}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    int t;
    cin >> t;
    while(t--)
        solve();
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    const int mx = 1e7 + 100;
    vector<int> pfac(mx);
    for (int i = 2; i < mx; ++i) {
        if (pfac[i]) continue;
        for (int j = i; j < mx; j += i)
            pfac[j] = i;
    }
    
    vector<ll> hash_val(mx);
    for (auto &x : hash_val) x = uniform_int_distribution<ll>(0, (1ll << 62) - 1)(RNG);
    auto calc = [&] (auto a) {
        map<ll, int> ct;
        ll res = 0;
        for (auto x : a) {
            res += ct[x];
            ++ct[x];
        }
        return res;
    };
    
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n);
        for (int &x : a) cin >> x;

        // ct[k] = number of subarrays with score at most k
        ll ct[5] = {0, 0, 0, 0, 0};
        
        // ct[4]: every subarray has score at most 4
        ct[4] = 1ll*n*(n+1)/2;
        
        // ct[1]
        vector<ll> pref(n+1);
        for (int i = 0; i < n; ++i) {
            pref[i+1] = pref[i];
            int x = a[i];
            while (x > 1) {
                pref[i+1] ^= hash_val[pfac[x]];
                x /= pfac[x];
            }
        }
        ct[1] = calc(pref);

        // ct[2]
        for (int i = 0; i < n; ++i) {
            pref[i+1] = pref[i];
            int x = a[i];
            while (x > 1) {
                if (pfac[x]%4 == 3) pref[i+1] ^= hash_val[pfac[x]];
                x /= pfac[x];
            }
        }
        ct[2] = calc(pref);

        // ct[3]
        vector<int> b(n), c(n);
        auto dnc = [&] (const auto &self, int L, int R) -> ll {
            if (L+1 == R) {
                return (b[L] != 7 or c[L] != 0);
            }
            int mid = (L + R)/2;
            ll res = self(self, L, mid) + self(self, mid, R);

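            // freq[prod ^ two] over the subarrays [mid, i]: prod is the odd part mod 8
            // (always odd) and two is the parity of the exponent of 2, so the index
            // is odd exactly when two == 0.
            array<int, 8> freq{};
            int prod = 1, two = 0;
            for (int i = mid; i < R; ++i) {
                prod = (prod * b[i]) % 8;
                two ^= c[i];
                ++freq[prod ^ two];
            }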
            
            prod = 1, two = 0;
            for (int i = mid-1; i >= L; --i) {
                prod = (prod * b[i]) % 8;
                two ^= c[i];

                for (int x = 0; x < 8; ++x) {
                    int q = (prod * (2*(x/2) + 1)) % 8;
                    int ctwo = two ^ ((x+1) % 2);

                    if (ctwo or q != 7) res += freq[x];
                }
            }
            return res;
        };
        for (int i = 0; i < n; ++i) {
            int x = a[i];
            while (x%4 == 0) x /= 4;
            if (x%2 == 0) {
                b[i] = (x/2) % 8;
                c[i] = 1;
            }
            else b[i] = x % 8;
        }
        ct[3] = dnc(dnc, 0, n);

        ll ans = 0;
        for (int i = 1; i <= 4; ++i)
            ans += i * (ct[i] - ct[i-1]);
        cout << ans%1000000007 << '\n';
    }
}

Alternatively, you can scan the array once and maintain a frequency table of reduced prefix products: for a prefix product 2^a (8b + c) with c odd, store (a \% 2, c). When looking at a later prefix product 2^{a'}(8b' + c'), the product in between requires 4 squares exactly when a' \equiv a \pmod 2 and c' / c \equiv 7 \pmod 8, which means (a \% 2, c) = (a' \% 2, 7/c' \bmod 8). Since every odd residue is its own inverse modulo 8, 7/c' \equiv 7c' \equiv 8 - c' \pmod 8.

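#include <bits/stdc++.h>
using namespace std;

// Counts subarrays of a whose product needs four squares.
long long count_four_squares(const vector<int>& a) {
  long long count_four = 0;
  int two = 0, eight = 1;  // parity of the exponent of 2, odd part mod 8
  map<pair<int, int>, int> cnt2 = {{{0, 1}, 1}};  // empty prefix
  for (int x : a) {
    while (x % 2 == 0) {  // strip factors of 2
      two ^= 1;
      x /= 2;
    }
    eight = eight * (x % 8) % 8;
    count_four += cnt2[{two, 8 - eight}];  // shortcut for (7/eight) % 8
    ++cnt2[{two, eight}];
  }
  return count_four;
}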