SQRSUBMA - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Tester: Trung Nguyen
Editorialist: Taranpreet Singh

DIFFICULTY

Easy-Medium

PREREQUISITES

Sliding window, Observation, and Maths.

PROBLEM

Given an array A of N positive integers and an integer X, construct the N \times N matrix B with B_{i, j} = A_i+A_j. Find the number of square submatrices of B whose sum is X.
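Before optimizing, a brute-force counter is handy for checking answers on small inputs. The sketch below is a minimal C++ version, assuming 0-indexed arrays and small N; the function name bruteForceCount is hypothetical. It builds 2D prefix sums of B and tests every square submatrix in O(N^3) time.

```cpp
#include <cassert>
#include <vector>

// Brute-force reference: counts square submatrices of B (B[i][j] = A[i] + A[j])
// whose sum is X, by checking every square. O(N^3) time, O(N^2) memory,
// so it is only meant for small N.
long long bruteForceCount(const std::vector<long long>& A, long long X) {
    const int n = static_cast<int>(A.size());
    assert(n >= 1);
    // P[i][j] = sum of B over rows [0, i) and columns [0, j).
    std::vector<std::vector<long long>> P(n + 1, std::vector<long long>(n + 1, 0));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            P[i + 1][j + 1] = P[i][j + 1] + P[i + 1][j] - P[i][j] + A[i] + A[j];

    long long count = 0;
    for (int sz = 1; sz <= n; ++sz)
        for (int i = 0; i + sz <= n; ++i)
            for (int j = 0; j + sz <= n; ++j) {
                const long long s = P[i + sz][j + sz] - P[i][j + sz] - P[i + sz][j] + P[i][j];
                if (s == X) ++count;
            }
    return count;
}
```

For example, with A = [1, 2, 3] the matrix is B = [[2, 3, 4], [3, 4, 5], [4, 5, 6]], and exactly two 2 \times 2 squares sum to 16, so bruteForceCount({1, 2, 3}, 16) returns 2.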

QUICK EXPLANATION

  • We only need to consider submatrices with side length sz such that sz | X (sz divides X).
  • For side length sz, count the ordered pairs of length-sz subarrays of A whose sums add up to X/sz. Use a sliding window to get all length-sz subarray sums and a frequency array to count the pairs, in O(N) per value of sz.

EXPLANATION

Consider the submatrix with top-left cell (i, j) and side length sz. Its sum is

\displaystyle \sum_{k = 0}^{sz-1} \sum_{l = 0}^{sz-1} B_{i+k,j+l} = \sum_{k = 0}^{sz-1} \sum_{l = 0}^{sz-1} \left( A_{i+k} + A_{j+l} \right) = \sum_{k = 0}^{sz-1} \sum_{l = 0}^{sz-1} A_{i+k} + \sum_{k = 0}^{sz-1} \sum_{l = 0}^{sz-1} A_{j+l}

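Since A_{i+k} does not depend on l and A_{j+l} does not depend on k, each double sum is just sz times a single sum, so this equals

\displaystyle sz \cdot \Bigg[ \sum_{k = 0}^{sz-1} A_{i+k} + \sum_{k = 0}^{sz-1} A_{j+k} \Bigg]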
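The identity can be spot-checked numerically. A minimal C++ sketch, assuming 0-indexed A; directSquareSum and closedFormSquareSum are hypothetical helper names. The first sums the sz \times sz block of B entry by entry, the second uses the closed form above, and the two should always agree.

```cpp
#include <cassert>
#include <vector>

// Sum of the sz x sz block of B (B[r][c] = A[r] + A[c]) with top-left cell (i, j),
// computed straight from the definition.
long long directSquareSum(const std::vector<long long>& A, int i, int j, int sz) {
    long long s = 0;
    for (int k = 0; k < sz; ++k)
        for (int l = 0; l < sz; ++l)
            s += A[i + k] + A[j + l];
    return s;
}

// Same sum via the closed form sz * (A[i] + ... + A[i+sz-1] + A[j] + ... + A[j+sz-1]).
long long closedFormSquareSum(const std::vector<long long>& A, int i, int j, int sz) {
    assert(sz >= 1 && i + sz <= static_cast<int>(A.size()) &&
           j + sz <= static_cast<int>(A.size()));
    long long rowPart = 0, colPart = 0;
    for (int k = 0; k < sz; ++k) {
        rowPart += A[i + k];
        colPart += A[j + k];
    }
    return static_cast<long long>(sz) * (rowPart + colPart);
}
```

For A = [1, 2, 3], top-left cell (0, 1) and sz = 2, both give 2 \cdot ((1 + 2) + (2 + 3)) = 16.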

Denote the sum of the subarray from the l-th to the r-th element (inclusive) by A_{l, r}. For sub-squares of side length sz we then need X = sz \cdot (A_{i, i+sz-1}+A_{j, j+sz-1}).

Hence, since A_{i, i+sz-1}+A_{j, j+sz-1} is an integer, a sub-square of side length sz can have sum X only if sz | X (sz divides X).

So we can enumerate all divisors of X by trial division up to \sqrt X and try each divisor sz \le N as the side length.
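A minimal sketch of the divisor enumeration in C++ (the helper name divisorsOf is hypothetical). It uses trial division up to \sqrt X, which gives the O(\sqrt X) term of the complexity; divisors larger than N can then simply be skipped.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Divisors of x by trial division up to sqrt(x), O(sqrt(x)) time.
// Each i with i * i <= x and x % i == 0 contributes i and x / i (once if equal).
std::vector<int> divisorsOf(int x) {
    assert(x >= 1);
    std::vector<int> divs;
    for (long long i = 1; i * i <= x; ++i) {
        if (x % i != 0) continue;
        divs.push_back(static_cast<int>(i));
        if (i * i != x) divs.push_back(static_cast<int>(x / i));
    }
    std::sort(divs.begin(), divs.end());
    return divs;
}
```

For X = 60 it returns the 12 divisors 1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60. Sorting is only for readability; the counting step does not depend on the order.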

For a fixed side length sz, the problem becomes: compute the sums of all length-sz subarrays of A, and count the ordered pairs of these sums whose total is X/sz. Pairing a subarray with itself is allowed; it corresponds to a sub-square with top-left cell (i, i).

To get all length-sz subarray sums in O(N) time, use a sliding window of fixed length: add one element at the end and remove one element from the start. Since all A_i are positive, every subarray sum is positive, so only sums \le X/sz can take part in a valid pair and the rest can be discarded. Keep a frequency array freq over the remaining sums (size X+1 is enough, since X/sz \le X). Then iterate over the stored sums: for each sum z, add freq_{X/sz - z} to the answer.
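A minimal C++ sketch of the per-side-length step, assuming 0-indexed A with positive elements and a caller-provided, all-zero frequency array with more than X/sz entries (countPairsForSide is a hypothetical name). It slides a window of length sz, records only sums \le X/sz, adds up the partner frequencies, and then resets just the entries it touched.

```cpp
#include <cassert>
#include <vector>

// Counts ordered pairs of length-sz windows of A whose sums add up to target
// (target = X / sz). Preconditions: A_i >= 1, freq is all zeros and has more
// than target entries. freq is all zeros again on return.
long long countPairsForSide(const std::vector<int>& A, int sz, int target,
                            std::vector<int>& freq) {
    const int n = static_cast<int>(A.size());
    assert(1 <= sz && sz <= n && target >= 0 &&
           static_cast<int>(freq.size()) > target);

    std::vector<int> kept;  // window sums <= target, i.e. the entries we touch
    long long window = 0;
    for (int i = 0; i < n; ++i) {
        window += A[i];                    // extend the window on the right
        if (i >= sz) window -= A[i - sz];  // drop the element that fell off the left
        if (i >= sz - 1 && window <= target) {
            kept.push_back(static_cast<int>(window));
            ++freq[window];
        }
    }

    long long pairs = 0;
    for (int z : kept) pairs += freq[target - z];  // partner sum must be target - z
    for (int z : kept) freq[z] = 0;                // reset only what was touched
    return pairs;
}
```

With A = [1, 2, 3], sz = 2 and X/sz = 8, the window sums are 3 and 5, giving the two ordered pairs (3, 5) and (5, 3).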

Implementation points:

  • Declare the frequency array only once, not once per test case. Allocating and zeroing an array of size X+1 for every test case costs O(X) each time, which may TLE on inputs with the maximum T and the maximum X.
  • When moving to the next value of sz, do not clear the whole frequency array; reset only the entries of the subarray sums you inserted.
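Putting the two implementation points together, a C++ sketch of the whole per-test-case routine might look as follows. It assumes X \le 10^6 and A_i \ge 1 (the bounds asserted in the tester's code), and the names kMaxX, g_freq and countSquareSubmatrices are hypothetical. The frequency buffer is allocated once for the whole run, and each call resets only the entries it incremented, so a test case costs O(\sqrt X + d(X) \cdot N) rather than O(X). Window sums are taken as prefix-sum differences, as in the setter's code, which is equivalent to the sliding window.

```cpp
#include <cassert>
#include <vector>

// Assumed bounds (taken from the tester's asserts): 1 <= X <= 1e6 and A_i >= 1.
const int kMaxX = 1000000;

// Hypothetical shared buffer: allocated once for the whole run and left
// all zeros after every call, so it is never cleared in full.
static std::vector<int> g_freq(kMaxX + 1, 0);

// Counts square submatrices of B (B[i][j] = A[i] + A[j]) whose sum is X.
long long countSquareSubmatrices(const std::vector<int>& A, int X) {
    assert(1 <= X && X <= kMaxX);
    const int n = static_cast<int>(A.size());

    // prefix[r] = A[0] + ... + A[r-1]; window sums are prefix differences.
    std::vector<long long> prefix(n + 1, 0);
    for (int i = 0; i < n; ++i) prefix[i + 1] = prefix[i] + A[i];

    // Candidate side lengths: divisors of X, found by trial division up to sqrt(X).
    std::vector<int> sides;
    for (int i = 1; static_cast<long long>(i) * i <= X; ++i) {
        if (X % i != 0) continue;
        sides.push_back(i);
        if (i != X / i) sides.push_back(X / i);
    }

    std::vector<int> touched;  // window sums inserted into g_freq for the current sz
    long long answer = 0;
    for (int sz : sides) {
        if (sz > n) continue;  // no sz x sz square fits
        const int target = X / sz;
        touched.clear();
        for (int r = sz; r <= n; ++r) {
            const long long s = prefix[r] - prefix[r - sz];  // sum of A[r-sz .. r-1]
            if (s > target) continue;  // cannot be part of a valid pair
            touched.push_back(static_cast<int>(s));
            ++g_freq[s];
        }
        for (int z : touched) answer += g_freq[target - z];  // ordered pairs
        for (int z : touched) g_freq[z] = 0;  // O(N) reset instead of O(X)
    }
    return answer;
}
```

Calling it repeatedly (once per test case) is safe because g_freq is back to all zeros after every call. The answer is kept in a long long: at most d(X) \cdot N^2 pairs are counted, which fits comfortably in 64 bits for these bounds.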

A problem worth trying: DOTTIME.

TIME COMPLEXITY

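The time complexity is O(\sqrt X + d(X) \cdot N) per test case, where d(X) denotes the number of divisors of X. For X \le 10^6 (the bound asserted in the tester's code), d(X) \le 240, the maximum being first reached at X = 720720.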
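To see how large d(X) can get, a short divisor-count sieve computes it for every X up to a bound. A minimal C++ sketch (maxDivisorCount is a hypothetical name); it runs in O(L \log L) for a bound L.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Returns {largest d(x), smallest x attaining it} over 1 <= x <= limit, where d(x)
// is the number of divisors of x. Divisor-count sieve: each d adds 1 to all of
// its multiples, O(limit log limit) time in total.
std::pair<int, int> maxDivisorCount(int limit) {
    assert(limit >= 1);
    std::vector<int> cnt(limit + 1, 0);
    for (int d = 1; d <= limit; ++d)
        for (int m = d; m <= limit; m += d)
            ++cnt[m];
    int best = 0, arg = 0;
    for (int x = 1; x <= limit; ++x)
        if (cnt[x] > best) {  // strict '>' keeps the smallest x on ties
            best = cnt[x];
            arg = x;
        }
    return {best, arg};
}
```

For a bound of 10^6 it returns (240, 720720), so the d(X) \cdot N term stays small in practice.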

SOLUTIONS

Setter's Solution
/**
*
* Author: MARS
* Lang: GNU C++14
*
**/

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;

const int N=100100;
int frq[N*10];
vector<int>v;
ll sum[N];
int a[N];

int main(){
	int t;
	scanf("%d",&t);
	while(t--){
	    int n,x;
	    scanf("%d%d",&n,&x);
	    for(int i=1 ; i<=n ; i++)
	        scanf("%d",&a[i]);


	    for(int i=1 ; i<=n ; i++)
	        sum[i]=sum[i-1]+a[i];


	    for(int i=1 ; i*i<=x ; i++){
	        if(x%i == 0){
	            v.push_back(i);
	            if(i*i != x)
	                v.push_back(x/i);
	        }
	    }

	    ll ans=0;
	    for(auto len:v){
	        if(len > n) continue;
	        int z=x/len;
	        for(int i=len ; i<=n ; i++){
	            ll s=sum[i]-sum[i-len];
	            if(s > z) continue;
	            frq[s]++;
	        }

	        for(int i=len ; i<=n ; i++){
	            ll s=sum[i]-sum[i-len];
	            if(s > z) continue;
	            ans+=frq[z-s];
	        }

	        for(int i=len ; i<=n ; i++){
	            ll s=sum[i]-sum[i-len];
	            if(s > z) continue;
	            frq[s]=0;
	        }
	    }

	    printf("%lld\n",ans);
	    v.clear();
	}
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
 
#define ms(s, n) memset(s, n, sizeof(s))
#define FOR(i, a, b) for (int i = (a); i < (b); ++i)
#define FORd(i, a, b) for (int i = (a) - 1; i >= (b); --i)
#define FORall(it, a) for (__typeof((a).begin()) it = (a).begin(); it != (a).end(); it++)
#define sz(a) int((a).size())
#define present(t, x) (t.find(x) != t.end())
#define all(a) (a).begin(), (a).end()
#define uni(a) (a).erase(unique(all(a)), (a).end())
#define pb push_back
#define pf push_front
#define mp make_pair
#define fi first
#define se second
#define prec(n) fixed<<setprecision(n)
#define bit(n, i) (((n) >> (i)) & 1)
#define bitcount(n) __builtin_popcountll(n)
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pi;
typedef vector<int> vi;
typedef vector<pi> vii;
const int MOD = (int) 1e9 + 7;
const int FFTMOD = 119 << 23 | 1;
const int INF = (int) 1e9 + 23111992;
const ll LINF = (ll) 1e18 + 23111992;
const ld PI = acos((ld) -1);
const ld EPS = 1e-9;
inline ll gcd(ll a, ll b) {ll r; while (b) {r = a % b; a = b; b = r;} return a;}
inline ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
inline ll fpow(ll n, ll k, int p = MOD) {ll r = 1; for (; k; k >>= 1) {if (k & 1) r = r * n % p; n = n * n % p;} return r;}
template<class T> inline int chkmin(T& a, const T& val) {return val < a ? a = val, 1 : 0;}
template<class T> inline int chkmax(T& a, const T& val) {return a < val ? a = val, 1 : 0;}
inline ull isqrt(ull k) {ull r = sqrt(k) + 1; while (r * r > k) r--; return r;}
inline ll icbrt(ll k) {ll r = cbrt(k) + 1; while (r * r * r > k) r--; return r;}
inline void addmod(int& a, int val, int p = MOD) {if ((a = (a + val)) >= p) a -= p;}
inline void submod(int& a, int val, int p = MOD) {if ((a = (a - val)) < 0) a += p;}
inline int mult(int a, int b, int p = MOD) {return (ll) a * b % p;}
inline int inv(int a, int p = MOD) {return fpow(a, p - 2, p);}
inline int sign(ld x) {return x < -EPS ? -1 : x > +EPS;}
inline int sign(ld x, ld y) {return sign(x - y);}
mt19937 mt(chrono::high_resolution_clock::now().time_since_epoch().count());
inline int mrand() {return abs((int) mt());}
inline int mrand(int k) {return abs((int) mt()) % k;}
#define db(x) cerr << "[" << #x << ": " << (x) << "] ";
#define endln cerr << "\n";

#define double long double
namespace FFT {
	const int maxf = 1 << 18; //Up to one million digits
	struct cp {
	    double x, y;
	    cp(double x = 0, double y = 0) : x(x), y(y) {}
	    cp operator + (const cp& rhs) const {
	        return cp(x + rhs.x, y + rhs.y);
	    }
	    cp operator - (const cp& rhs) const {
	        return cp(x - rhs.x, y - rhs.y);
	    }
	    cp operator * (const cp& rhs) const {
	        return cp(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x);
	    }
	    cp operator !() const {
	        return cp(x, -y);
	    }
	} rts[maxf + 1];
	cp fa[maxf], fb[maxf];
	cp fc[maxf], fd[maxf];

	int bitrev[maxf];
	void fftinit() {
	    int k = 0; while ((1 << k) < maxf) k++;
	    bitrev[0] = 0;
	    for (int i = 1; i < maxf; i++) {
	        bitrev[i] = bitrev[i >> 1] >> 1 | ((i & 1) << k - 1);
	    }
	    double PI = acos((double) -1.0);
	    rts[0] = rts[maxf] = cp(1, 0);
	    for (int i = 1; i + i <= maxf; i++) {
	        rts[i] = cp(cos(i * 2 * PI / maxf), sin(i * 2 * PI / maxf));
	    }
	    for (int i = maxf / 2 + 1; i < maxf; i++) {
	        rts[i] = !rts[maxf - i];
	    }
	}
	void dft(cp a[], int n, int sign) {
	    static int isinit;
	    if (!isinit) {
	        isinit = 1;
	        fftinit();
	    }
	    int d = 0; while ((1 << d) * n != maxf) d++;
	    for (int i = 0; i < n; i++) {
	        if (i < (bitrev[i] >> d)) {
	            swap(a[i], a[bitrev[i] >> d]);
	        }
	    }
	    for (int len = 2; len <= n; len <<= 1) {
	        int delta = maxf / len * sign;
	        for (int i = 0; i < n; i += len) {
	            cp *x = a + i,*y = a + i + (len >> 1), *w = sign > 0 ? rts : rts + maxf;
	            for (int k = 0; k + k < len; k++) {
	                cp z = *y * *w;
	                *y = *x - z, *x = *x + z;
	                x++, y++, w += delta;
	            }
	        }
	    }
	    if (sign < 0) {
	        for (int i = 0; i < n; i++) {
	            a[i].x /= n;
	            a[i].y /= n;
	        }
	    }
	}
	void multiply(int a[], int b[], int na, int nb, int c[], int& nc, int base, int dup) {
	    int n = na + nb - 1;
	    while (n != (n & -n)) n += n & -n;
	    for (int i = 0; i < n; i++) fa[i] = fb[i] = cp();
	    static const int magic = 15;
	    for (int i = 0; i < na; i++) fa[i] = cp(a[i] >> magic, a[i] & (1 << magic) - 1);
	    for (int i = 0; i < nb; i++) fb[i] = cp(b[i] >> magic, b[i] & (1 << magic) - 1);
	    dft(fa, n, 1);
	    if (dup) {
	        for (int i = 0; i < n; i++) fb[i] = fa[i];
	    }
	    else {
	        dft(fb, n, 1);
	    }
	    for (int i = 0; i < n; i++) {
	        int j = (n - i) % n;
	        cp x = fa[i] + !fa[j];
	        cp y = fb[i] + !fb[j];
	        cp z = !fa[j] - fa[i];
	        cp t = !fb[j] - fb[i];
	        fc[i] = (x * t + y * z) * cp(0, 0.25);
	        fd[i] = x * y * cp(0, 0.25) + z * t * cp(-0.25, 0);
	    }
	    dft(fc, n, -1), dft(fd, n, -1);
	    nc = 0;
	    long long carry = 0;
	    for (int i = 0; i < n; i++) {
	        long long u = (long long) round(fc[i].x);
	        long long v = (long long) round(fd[i].x);
	        long long w = (long long) round(fd[i].y);
	        long long ncarry = (u / base << 15) + (u % base << 15) / base + v / base + (w / base << 30) + (w % base << 30) / base;
	        assert(0 <= ncarry && ncarry < 2e18);
	        long long t = (u % base << 15) % base + v % base + (w % base << 30) % base + carry;
	        carry = ncarry + t / base;
	        assert(carry < 2e18);
	        c[nc++] = t % base;
	    }
	    while (carry) {
	        c[nc++] = carry % base;
	        carry /= base;
	    }
	}
	vector<int> multiply(vector<int> a, vector<int> b, int base) {
	    static int fa[maxf], fb[maxf], fc[maxf + 5];
	    int na = a.size(), nb = b.size();
	    for (int i = 0; i < na; i++) fa[i] = a[i];
	    for (int i = 0; i < nb; i++) fb[i] = b[i];
	    int nc = 0;
	    multiply(fa, fb, na, nb, fc, nc, base, a == b);
	    vector<int> res(nc);
	    for (int i = 0; i < nc; i++) res[i] = fc[i];
	    while (1 < res.size() && !res.back()) res.pop_back();
	    return res;
	}
}
#undef double

const int base = 10;
const int nblock = 9;
const int blockbase = (int) round(pow(base, nblock));
struct Bignum {
	vector<int> a;
	int sign;
	Bignum() : sign(1) {}
	Bignum(long long v) {*this = v;}
	Bignum(const string& s) {read(s);}
	void operator = (const Bignum& v) {sign = v.sign; a = v.a;}
	void operator = (long long v) {
	    a.clear();
	    sign = 1;
	    if (v < 0)
	        sign = -1, v = -v;
	    for (; v > 0; v = v / blockbase)
	        a.push_back(v % blockbase);
	}
	Bignum operator + (const Bignum& v) const {
	    if (sign == v.sign) {
	        Bignum res = v;
	        for (int i = 0, carry = 0; i < (int) max(a.size(), v.a.size()) || carry; i++) {
	            if (i == (int) res.a.size()) res.a.push_back(0);
	            res.a[i] += carry + (i < (int) a.size() ? a[i] : 0);
	            carry = res.a[i] >= blockbase;
	            if (carry) res.a[i] -= blockbase;
	        }
	        return res;
	    }
	    return *this - (-v);
	}
	Bignum operator - (const Bignum& v) const {
	    if (sign == v.sign) {
	        if (abs() >= v.abs()) {
	            Bignum res = *this;
	            for (int i = 0, carry = 0; i < (int) v.a.size() || carry; i++) {
	                res.a[i] -= carry + (i < (int) v.a.size() ? v.a[i] : 0);
	                carry = res.a[i] < 0;
	                if (carry) res.a[i] += blockbase;
	            }
	            res.trim();
	            return res;
	        }
	        return -(v - *this);
	    }
	    return *this + (-v);
	}
	void operator *= (int v) {
	    if (v < 0) sign = -sign, v = -v;
	    for (int i = 0, carry = 0; i < (int) a.size() || carry; i++) {
	        if (i == (int) a.size()) a.push_back(0);
	        long long cur = a[i] * (long long) v + carry;
	        carry = (int) (cur / blockbase);
	        a[i] = (int) (cur % blockbase);
	    }
	    trim();
	}
	void operator *= (long long v) {
	    if (v >= (long long) blockbase * blockbase) {
	        *this *= Bignum(v);
	        return;  // already multiplied; do not fall through to the split below
	    }
	    int a = v / blockbase;
	    int b = v % blockbase;
	    *this = *this * a * blockbase + *this * b;
	}
	Bignum operator * (int v) const {
	    Bignum res = *this;
	    res *= v;
	    return res;
	}
	Bignum operator * (long long v) const {
	    Bignum res = *this;
	    res *= v;
	    return res;
	}
	friend pair<Bignum, Bignum> divmod(const Bignum& a1, const Bignum& b1) {
	    int norm = blockbase / (b1.a.back() + 1);
	    Bignum a = a1.abs() * norm;
	    Bignum b = b1.abs() * norm;
	    Bignum q, r;
	    q.a.resize(a.a.size());
	    for (int i = a.a.size() - 1; i >= 0; i--) {
	        r *= blockbase;
	        r += a.a[i];
	        int s1 = r.a.size() <= b.a.size() ? 0 : r.a[b.a.size()];
	        int s2 = r.a.size() <= b.a.size() - 1 ? 0 : r.a[b.a.size() - 1];
	        int d = ((long long) blockbase * s1 + s2) / b.a.back();
	        r -= b * d;
	        while (r < 0)
	            r += b, d--;
	        q.a[i] = d;
	    }
	    q.sign = a1.sign * b1.sign;
	    r.sign = a1.sign;
	    q.trim();
	    r.trim();
	    return make_pair(q, r / norm);
	}
	Bignum operator / (const Bignum& v) const {
	    return divmod(*this, v).first;
	}
	Bignum operator % (const Bignum& v) const {
	    return divmod(*this, v).second;
	}
	void operator /= (int v) {
	    if (v < 0) sign = -sign, v = -v;
	    for (int i = (int) a.size() - 1, rem = 0; i >= 0; i--) {
	        long long cur = a[i] + rem * (long long) blockbase;
	        a[i] = (int) (cur / v);
	        rem = (int) (cur % v);
	    }
	    trim();
	}
	void operator /= (long long v) {
	    *this /= Bignum(v);
	}
	Bignum operator / (int v) const {
	    Bignum res = *this;
	    res /= v;
	    return res;
	}
	Bignum operator / (long long v) const {
	    Bignum res = *this;
	    res /= v;
	    return res;
	}
	int operator % (int v) const {
	    if (v < 0) v = -v;
	    int m = 0;
	    for (int i = a.size() - 1; i >= 0; i--) m = (a[i] + m * (long long) blockbase) % v;
	    return m * sign;
	}
	long long operator % (long long v) const {
	    return (*this % Bignum(v)).longValue();
	}
	void operator += (const Bignum& v) {
	    *this = *this + v;
	}
	void operator -= (const Bignum& v) {
	    *this = *this - v;
	}
	void operator *= (const Bignum& v) {
	    *this = *this * v;
	}
	void operator /= (const Bignum& v) {
	    *this = *this / v;
	}
	bool operator < (const Bignum& v) const {
	    if (sign != v.sign) return sign < v.sign;
	    if (a.size() != v.a.size()) return a.size() * sign < v.a.size() * v.sign;
	    for (int i = a.size() - 1; i >= 0; i--) if (a[i] != v.a[i]) return a[i] * sign < v.a[i] * sign;
	    return false;
	}
	bool operator > (const Bignum& v) const {
	    return v < *this;
	}
	bool operator <= (const Bignum& v) const {
	    return !(v < *this);
	}
	bool operator >= (const Bignum& v) const {
	    return !(*this < v);
	}
	bool operator == (const Bignum& v) const {
	    return !(*this < v) && !(v < *this);
	}
	bool operator != (const Bignum& v) const {
	    return *this < v || v < *this;
	}
	void trim() {
	    while (!a.empty() && !a.back()) a.pop_back();
	    if (a.empty()) sign = 1;
	}
	bool isZero() const {
	    return a.empty() || (a.size() == 1 && !a[0]);
	}
	Bignum operator - () const {
	    Bignum res = *this;
	    res.sign = -sign;
	    return res;
	}
	Bignum abs() const {
	    Bignum res = *this;
	    res.sign *= res.sign;
	    return res;
	}
	long long longValue() const {
	    long long res = 0;
	    for (int i = a.size() - 1; i >= 0; i--) res = res * blockbase + a[i];
	    return res * sign;
	}
	friend Bignum gcd(const Bignum& a, const Bignum& b) {
	    return b.isZero() ? a : gcd(b, a % b);
	}
	friend Bignum lcm(const Bignum& a, const Bignum& b) {
	    return a / gcd(a, b) * b;
	}
	void read(const string& s) {
	    sign = 1; a.clear(); int pos = 0;
	    while (pos < (int) s.size() && (s[pos] == '-' || s[pos] == '+')) {if (s[pos] == '-') sign = -sign; pos++;}
	    for (int i = s.size() - 1; i >= pos; i -= nblock) {
	        int x = 0;
	        for (int j = max(pos, i - nblock + 1); j <= i; j++) x = x * base + s[j] - '0';
	        a.push_back(x);
	    }
	    trim();
	}
	friend istream& operator>>(istream& stream, Bignum& v) {
	    string s; stream>>s; v.read(s);
	    return stream;
	}
	friend ostream& operator<<(ostream& stream, const Bignum& v) {
	    if (v.sign == -1) stream << '-';
	    stream<<(v.a.empty() ? 0 : v.a.back());
	    for (int i = (int) v.a.size() - 2; i >= 0; i--) stream << setw(nblock) << setfill('0') << v.a[i];
	    return stream;
	}
	static vector<int> convert_base(const vector<int>& a, int old_digits, int new_digits) {
	    vector<long long> p(max(old_digits, new_digits) + 1);
	    p[0] = 1;
	    for (int i = 1; i < (int) p.size(); i++) p[i] = p[i - 1] * base;
	    vector<int> res;
	    long long cur = 0;
	    int cur_digits = 0;
	    for (int i = 0; i < (int) a.size(); i++) {
	        cur += a[i] * p[cur_digits];
	        cur_digits += old_digits;
	        while (cur_digits >= new_digits) {
	            res.push_back(int(cur % p[new_digits]));
	            cur /= p[new_digits];
	            cur_digits -= new_digits;
	        }
	    }
	    res.push_back((int) cur);
	    while (!res.empty() && !res.back()) res.pop_back();
	    return res;
	}
	static vector<long long> karatsuba(vector<long long>& a, vector<long long>& b) {
	    int n = a.size();
	    vector<long long> res(n << 1);
	    if (n <= 32) {
	        for (int i = 0; i < n; i++)
	            for (int j = 0; j < n; j++)
	                res[i + j] += a[i] * b[j];
	        return res;
	    }
	    int k = n >> 1;
	    vector<long long> a1(a.begin(), a.begin() + k);
	    vector<long long> a2(a.begin() + k, a.end());
	    vector<long long> b1(b.begin(), b.begin() + k);
	    vector<long long> b2(b.begin() + k, b.end());
	    vector<long long> a1b1 = karatsuba(a1, b1);
	    vector<long long> a2b2 = karatsuba(a2, b2);
	    for (int i = 0; i < k; i++) a2[i] += a1[i];
	    for (int i = 0; i < k; i++) b2[i] += b1[i];
	    vector<long long> r = karatsuba(a2, b2);
	    for (int i = 0; i < (int) a1b1.size(); i++) r[i] -= a1b1[i];
	    for (int i = 0; i < (int) a2b2.size(); i++) r[i] -= a2b2[i];
	    for (int i = 0; i < (int) r.size(); i++) res[i + k] += r[i];
	    for (int i = 0; i < (int) a1b1.size(); i++) res[i] += a1b1[i];
	    for (int i = 0; i < (int) a2b2.size(); i++) res[i + n] += a2b2[i];
	    return res;
	}
	Bignum operator * (const Bignum& v) const {
	    if ((max(this->a.size(), v.a.size()) * 9 + 5) / 6 <= 8192) {
	        int r = 6;
	        int t = round(pow(base, r));
	        vector<int> ar = convert_base(this->a, nblock, r);
	        vector<int> br = convert_base(v.a, nblock, r);
	        vector<long long> a(ar.begin(), ar.end());
	        vector<long long> b(br.begin(), br.end());
	        while (a.size() < b.size()) a.push_back(0);
	        while (b.size() < a.size()) b.push_back(0);
	        while (a.size() & (a.size() - 1)) a.push_back(0), b.push_back(0);
	        vector<long long> c = karatsuba(a, b);
	        Bignum res;
	        res.sign = sign * v.sign;
	        long long carry = 0;
	        for (int i = 0; i < (int) c.size(); i++) {
	            long long ncarry = c[i] + carry;
	            res.a.push_back(ncarry % t);
	            carry = ncarry / t;
	        }
	        while (carry) {
	            res.a.push_back(carry % t);
	            carry /= t;
	        }
	        res.a = convert_base(res.a, r, nblock);
	        res.trim();
	        return res;
	    }
	    else {
	        vector<int> c = FFT::multiply(this->a, v.a, blockbase);
	        Bignum res;
	        res.sign = sign * v.sign;
	        int carry = 0;
	        for (int i = 0; i < (int) c.size(); i++) {
	            int ncarry = c[i] + carry;
	            res.a.push_back(ncarry % blockbase);
	            carry = ncarry / blockbase;
	        }
	        assert(!carry);
	        while (carry) {
	            res.a.push_back(carry % blockbase);
	            carry /= blockbase;
	        }
	        res.trim();
	        return res;
	    }
	}
	friend Bignum sqrt(const Bignum& a) {
	    Bignum x0 = a, x1 = (a + 1) / 2;
	    while (x1 < x0) {
	        x0 = x1;
	        x1 = (x1 + a / x1) / 2;
	    }
	    return x0;
	}
	friend Bignum pow(Bignum a, Bignum b) {
	    if (b == Bignum(0)) return Bignum(1);
	    Bignum T = pow(a, b / 2);
	    if (b % 2 == 0) return T * T;
	    return T * T * a;
	}
	friend Bignum pow(Bignum a, int b) {
	    return pow(a, (Bignum(b)));
	}
	friend int log(Bignum a, int n) {
	    int res = 0;
	    while (a > Bignum(1)) {
	        res++;
	        a /= n;
	    }
	    return res;
	}
	template<class T> friend Bignum operator + (const T& v, const Bignum& a) {
	    return a + v;
	}
	template<class T> friend Bignum operator - (const T& v, const Bignum& a) {
	    return -a + v;
	}
	template<class T> friend Bignum operator * (const T& v, const Bignum& a) {
	    return a * v;
	}
	template<class T> friend Bignum operator / (const T& v, const Bignum& a) {
	    return Bignum(v) / a;
	}
	Bignum operator ++() {
	    (*this) += 1;
	    return *this;
	}
	Bignum operator ++(int) {
	    (*this) += 1;
	    return *this - 1;
	}
	Bignum operator --() {
	    (*this) -= 1;
	    return *this;
	}
	Bignum operator --(int) {
	    (*this) -= 1;
	    return *this + 1;
	}
};
 
void chemthan() {
	int test; cin >> test;
	assert(1 <= test && test <= 1e2);
	int sumn = 0;
	while (test--) {
	    int n, x; cin >> n >> x;
	    sumn += n;
	    assert(1 <= n && n <= 1e5);
	    assert(1 <= x && x <= 1e6);
	    assert(1 <= sumn && sumn <= 1e6);
	    vector<long long> a(n + 1);
	    FOR(i, 1, n + 1) {
	        cin >> a[i];
	        assert(1 <= a[i] && a[i] <= 1e6);
	        a[i] += a[i - 1];
	    }
	    vi dvs;
	    for (int i = 1; i * i <= x; i++) if (x % i == 0) {
	        dvs.pb(i);
	        dvs.pb(x / i);
	    }
	    sort(all(dvs)), uni(dvs);
	    static int f[1234567];
	    static int g[1234567];
	    Bignum res = 0;
	    FOR(i, 0, sz(dvs)) {
	        int d = dvs[i];
	        int ptr = 0;
	        FOR(j, 0, n - d + 1) {
	            long long t = a[j + d] - a[j];
	            if (t <= x / d) {
	                if (!f[t]) {
	                    g[ptr++] = t;
	                }
	                f[t]++;
	            }
	        }
	        FOR(j, 0, ptr) {
	            int t = g[j];
	            int y = x / d - t;
	            res += (Bignum) f[t] * f[y];
	        }
	        FOR(j, 0, ptr) {
	            int t = g[j];
	            f[t] = 0;
	        }
	    }
	    cout << res << "\n";
	}
}

int main(int argc, char* argv[]) {
	ios_base::sync_with_stdio(0), cin.tie(0);
	if (argc > 1) {
	    assert(freopen(argv[1], "r", stdin));
	}
	if (argc > 2) {
	    assert(freopen(argv[2], "wb", stdout));
	}
	chemthan();
	cerr << "\nTime elapsed: " << 1000 * clock() / CLOCKS_PER_SEC << "ms\n";
	return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class SQRSUBMA{
	//SOLUTION BEGIN
	int maxX = (int)1e6;
	int[] freq;
	void pre() throws Exception{
	    freq = new int[1+maxX];
	}
	void solve(int TC) throws Exception{
	    int N = ni(), X = ni();
	    int[] A = new int[N];
	    for(int i = 0; i< N; i++)A[i] = ni();
	    int[] toCheck = factorize(X);
	    long ans = 0;
	    int[] sum = new int[N];
	    for(int size:toCheck){
	        if(size > N)continue;
	        int S = X/size;
	        //we need to count pair of subarrays of length size having sum S
	        long tot = 0;
	        for(int i = 0; i< size-1; i++)tot += A[i];
	        int ptr = 0;
	        for(int i = size-1; i< N; i++){
	            tot += A[i];
	            if(tot <= S){
	                sum[ptr++] = (int)tot;
	                freq[(int)tot]++;
	            }
	            tot -= A[i-size+1];
	        }
	        for(int i = 0; i< ptr; i++)ans += freq[S-sum[i]];
	        for(int i = 0; i< ptr; i++)freq[sum[i]]--;
	    }
	    pn(ans);
	}
	int[] factorize(int X){
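	    int sqrt = (int)Math.sqrt(X)+4;
	    // divisors come in pairs (i, X/i) with i <= sqrt(X), so there are at most 2*sqrt(X) of them
	    // (a buffer of size sqrt alone overflows for X = 60, which has 12 divisors)
	    int[] f = new int[2*sqrt];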
	    int c = 0;
	    for(int i = 1; i*i<= X; i++){
	        if(X%i != 0)continue;
	        f[c++] = i;
	        if(i*i != X)f[c++] = X/i;
	    }
	    return Arrays.copyOfRange(f, 0, c);
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new SQRSUBMA().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome as always. 🙂

Can someone help me with my submission? I think I have followed the approach given above; maybe some optimizations are needed. I am using an unordered_map instead of a frequency array, because I thought the sums wouldn’t fit in a frequency array. Why does the size 10*N work in the setter’s solution?

I just want to know: does clearing a map take more time than deleting its elements individually?

I did the same thing, but I’m getting TLE.

Don’t use a map; use an array to keep the frequency of each sum.

A map keeps its keys ordered, so each insertion takes O(log n) time, while an array does it in O(1).

But don’t we need a huge array for that? The sum can go up to 10^{10}, right?

@fastred I have used unordered_map. So on an average it should be O(1) access.

Yes, it can, but do we need sums greater than X/f for a factor f? The minimum f is 1, so do we ever need sums greater than X?

Ohhh that’s right. Thanks!
Also, aren’t unordered_maps quite efficient? Is there a very huge difference between using a hashmap and an array?

The worst-case insertion into an unordered_map is O(n), so with too many collisions the total becomes O(n^2). The TL was tight enough to cut off map solutions, hence the TLE.

Thanks! It passed.

Got it. Man this had some tricky optimizations involved.

Can anyone explain the editorial more clearly?

Yeah, you’re right… I also made the same mistake. You only need to insert a sum into the map if it is less than x/size; otherwise don’t insert it. Then it will pass easily.

Time to :curly_loop: loop: Kill Myself :curly_loop:
:stuck_out_tongue_winking_eye: (don’t take it seriously)

There’s a reason I hate GCC optimize

What does this symbol mean?
Sorry, I am a noob.

Got it! It means sz should be a divisor of X.

@taran_1407
What can the maximum value of d(X) be in the time complexity?

I used binary search for partial points, here: https://www.codechef.com/viewsolution/34821843