MULSUBQ - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: iceknight1093 and satyam_343
Testers: tabr, yash_daga
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Segment trees with lazy propagation

PROBLEM:

An array B is called good if, when sorted, B_i divides B_{i+1} for every i.

You’re given an array A and Q queries on it. For each query (L, R), find the number of good subarrays of [A_L, A_{L+1}, \ldots, A_R].
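To make the definition concrete, here is a minimal brute-force sketch. The names `isGood` and `countGoodBrute` are made up for illustration and do not come from the setter's or tester's code; the approach is far too slow for the real constraints and is only meant for checking small cases by hand.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Returns true if, after sorting, every element divides the next one.
// Assumes all values are positive, as in the problem.
bool isGood(std::vector<int> b) {
    std::sort(b.begin(), b.end());
    for (std::size_t i = 0; i + 1 < b.size(); ++i) {
        if (b[i + 1] % b[i] != 0) return false;
    }
    return true;
}

// Brute force (hypothetical helper): counts good subarrays of a[l..r],
// with l and r 0-indexed and inclusive.
long long countGoodBrute(const std::vector<int>& a, int l, int r) {
    assert(0 <= l && l <= r && r < (int)a.size());
    long long count = 0;
    for (int i = l; i <= r; ++i) {
        for (int j = i; j <= r; ++j) {
            if (isGood(std::vector<int>(a.begin() + i, a.begin() + j + 1))) ++count;
        }
    }
    return count;
}
```

For A = [1, 2, 3] and the query covering the whole array, the good subarrays are [1], [2], [3] and [1, 2], so the count is 4.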

EXPLANATION:

Let’s first analyze the structure of a good array.
Consider a good array B. Without loss of generality, we can assume B is sorted.

Now, for each i, either B_i = B_{i+1} or B_i \lt B_{i+1}.
In the first case, divisibility is trivially satisfied.
In the second case, note that B_{i+1} is at least 2B_i, since it should be a multiple of B_i that’s larger than B_i.

In particular, this means that if B has K distinct elements, the largest of them is at least 2^{K-1}\times B_1, since each time we move to a higher element the value at least doubles.

In our case, we know A_i \leq N \leq 2\cdot 10^5.
This means that any good subarray can have at most 18 distinct elements: 19 distinct elements would force the largest one to be at least 2^{18} = 262144 \gt 2\cdot 10^5.
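The bound can be double-checked with a few lines of code. This is only a sketch of the doubling argument; `maxDistinctValues` is a hypothetical helper name, not something from the reference solutions.

```cpp
#include <cassert>

// Largest K such that 2^(K-1) <= limit, i.e. the maximum number of distinct
// values a good array can contain when every value lies in [1, limit].
int maxDistinctValues(long long limit) {
    assert(limit >= 1);
    int k = 1;
    long long largest = 1;  // smallest possible maximum with k distinct values
    while (largest * 2 <= limit) {
        largest *= 2;
        ++k;
    }
    return k;
}
```

With `limit = 200000` the loop stops once `largest` reaches 2^{17} = 131072, which gives K = 18.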

Solving for [1, N]

For now, let’s ignore queries entirely and focus on counting the number of good subarrays of a given array.

Let’s fix the left endpoint L, and count the number of valid right endpoints R.
Note that if [L, R] is good, then so is [L, R-1]: removing an element from a good array leaves it good, because if a divides b and b divides c, then a divides c, so the neighbors of the removed element in sorted order still divide each other.

In particular, the set of valid R for this L forms a contiguous range starting at L, so it suffices to find the right endpoint of this range.

Since the range can contain at most 18 distinct elements, there’s a rather simple way of doing this:
Start with R = L.
Then, move R to the next position that contains a new element and check if the resulting set of elements satisfies the divisibility condition.
If it does, jump to the next new element and continue; otherwise, the range ends at R-1, just before the element that broke the condition. If no new element remains, the range extends all the way to N.

To do this fast, we need to be able to quickly find the next new element.
This can be done as follows:

  • Let S be a set that contains pairs of (position, element), sorted by position. Initially, S is empty.
  • Iterate i from N down to 1.
  • At i, if A_i exists in S (at some other position), delete it. Then, insert (i, A_i) into S.

S now contains the nearest occurrence of every distinct element that occurs at index \geq i, sorted by position. This is exactly what we want.
Now we can simply iterate across S to find the next distinct element, as detailed at the start of the section.

This way, we do N set insertions and deletions, and then at each index up to 18^2 operations (more specifically, up to \log^2 N operations) to find the appropriate R, for a total of \mathcal{O}(N\log^2 N) time.
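The procedure of this section can be sketched as follows. This is not the setter's exact code: the function name `goodLimits`, the `maxValue` parameter and the neighbor-only divisibility check are assumptions made for illustration. It uses 0-indexed positions, requires C++17, and returns an exclusive end for each left endpoint.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// For each left endpoint i, returns lim[i] such that a[i..j] is good
// exactly for i <= j < lim[i]. Assumes 1 <= a[i] <= maxValue.
std::vector<int> goodLimits(const std::vector<int>& a, int maxValue) {
    int n = a.size();
    std::vector<int> nearest(maxValue + 1, -1);  // nearest[v]: closest position >= i holding v
    std::set<std::pair<int, int>> active;        // (position, value), ordered by position
    std::vector<int> lim(n);
    for (int i = n - 1; i >= 0; --i) {
        assert(1 <= a[i] && a[i] <= maxValue);
        // Keep only the nearest occurrence of a[i] in the set.
        if (nearest[a[i]] != -1) active.erase({nearest[a[i]], a[i]});
        nearest[a[i]] = i;
        active.insert({i, a[i]});

        std::vector<int> chain;  // distinct values seen so far, kept sorted
        lim[i] = n;              // default: every remaining position is fine
        for (auto [pos, val] : active) {
            // val must fit between its sorted neighbors in the chain.
            auto it = std::lower_bound(chain.begin(), chain.end(), val);
            bool ok = true;
            if (it != chain.end() && *it % val != 0) ok = false;
            if (it != chain.begin() && val % *(it - 1) != 0) ok = false;
            if (!ok) {
                lim[i] = pos;  // first position that breaks the condition
                break;
            }
            chain.insert(it, val);
        }
    }
    return lim;
}
```

Checking only the two sorted neighbors is enough because the chain was already valid before the insertion. For A = [1, 2, 4, 3] this gives limits [3, 3, 3, 4]: the value 3 at position 3 breaks every range that starts earlier.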

Answering queries

Of course, applying the above method directly to each query is going to be too slow, since both N and Q are large.

However, notice that we can in fact reuse a lot of information if we solve our queries offline.
That is, first we’ll read all the queries, then at a given index i we’ll answer all queries whose left endpoint is i.

Consider a new array B of length N. Initially, B_i = 0 for every index.

Let’s run the algorithm from the previous section. When at a fixed i, first find its optimal right endpoint r_i as before.
Then, add 1 to the range [i, r_i] of B.

Note that B_j now equals the number of left endpoints x with i \leq x \leq j such that the subarray [x, j] is good, since B_j is incremented exactly once for every processed x whose range [x, r_x] contains j.

This means, to answer a query (i, R), we can simply take the sum B_i + B_{i+1} + \ldots + B_R.
Note that this is valid only when we haven’t yet processed index i-1, which is why we need to solve queries offline.
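The offline sweep can be illustrated with a deliberately naive sketch, where B is a plain array and r_i is found by brute force. The names `goodRange` and `answerOffline` are hypothetical, indices are 0-indexed and inclusive, and C++17 is assumed; the point is only to show the order of operations, not to be fast.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Naive helper: is a[l..r] (inclusive) good?
static bool goodRange(const std::vector<int>& a, int l, int r) {
    std::vector<int> b(a.begin() + l, a.begin() + r + 1);
    std::sort(b.begin(), b.end());
    for (std::size_t k = 0; k + 1 < b.size(); ++k)
        if (b[k + 1] % b[k] != 0) return false;
    return true;
}

// Answers queries (l, r) offline; results are returned in input order.
std::vector<long long> answerOffline(const std::vector<int>& a,
                                     const std::vector<std::pair<int, int>>& queries) {
    int n = a.size();
    // byLeft[l] holds (r, query index) for every query starting at l.
    std::vector<std::vector<std::pair<int, int>>> byLeft(n);
    for (int id = 0; id < (int)queries.size(); ++id)
        byLeft[queries[id].first].push_back({queries[id].second, id});

    std::vector<long long> B(n, 0), ans(queries.size());
    for (int i = n - 1; i >= 0; --i) {
        // Find r_i by extending while the subarray stays good.
        int r = i;
        while (r + 1 < n && goodRange(a, i, r + 1)) ++r;
        // Add 1 on [i, r_i].
        for (int j = i; j <= r; ++j) ++B[j];
        // Answer queries whose left endpoint is i before moving on to i-1.
        for (auto [qr, id] : byLeft[i]) {
            long long s = 0;
            for (int j = i; j <= qr; ++j) s += B[j];
            ans[id] = s;
        }
    }
    return ans;
}
```

For A = [1, 2, 3] and queries (0, 2), (1, 2), (0, 1), this returns 4, 2 and 3.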

Maintaining B is theoretically quite simple: we need to support adding 1 on some range, and getting the sum of some range. This is a standard problem that is solved by segment trees with lazy propagation.
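If a lazy segment tree feels heavy, the same two operations can also be supported by a pair of Fenwick trees. Note that this is a substitution and not what the reference solutions use; the struct name `RangeAddRangeSum` and its interface (half-open ranges [l, r), 0-indexed) are assumptions for the sake of the sketch.

```cpp
#include <cassert>
#include <vector>

// Range add / range sum over half-open ranges, using two Fenwick trees
// built on the difference array D: t1 stores D[k], t2 stores k * D[k].
struct RangeAddRangeSum {
    int n;
    std::vector<long long> t1, t2;

    explicit RangeAddRangeSum(int n) : n(n), t1(n + 1, 0), t2(n + 1, 0) {}

    // Adds v at index i of tree t; an index equal to n is silently ignored.
    void pointAdd(std::vector<long long>& t, int i, long long v) {
        for (++i; i <= n; i += i & -i) t[i] += v;
    }
    // Sum of the stored values at indices [0, i).
    long long prefix(const std::vector<long long>& t, int i) const {
        long long s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s;
    }
    // Adds v to every element in [l, r).
    void add(int l, int r, long long v) {
        assert(0 <= l && l <= r && r <= n);
        pointAdd(t1, l, v);
        pointAdd(t1, r, -v);
        pointAdd(t2, l, v * l);
        pointAdd(t2, r, -v * r);
    }
    // Sum of elements [0, i): sum over k < i of D[k] * (i - k).
    long long prefixSum(int i) const { return prefix(t1, i) * i - prefix(t2, i); }
    // Sum of elements [l, r).
    long long sum(int l, int r) const { return prefixSum(r) - prefixSum(l); }
};
```

A query (i, R) from the text then corresponds to `sum(i, R + 1)` with 0-indexed endpoints, and adding 1 on [i, r_i] corresponds to `add(i, r_i + 1, 1)`.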

TIME COMPLEXITY

\mathcal{O}(N\log^2 N + Q\log N) per test case.

CODE:

Setter's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

struct Node {
	using T = ll;
	T unit = 0;
	T f(T a, T b) { return a+b; }
 
	Node *l = 0, *r = 0;
	int lo, hi;
	T madd = 0;
	T val = unit;
	Node(int _lo,int _hi):lo(_lo),hi(_hi){}
	T query(int L, int R) {
		if (R <= lo || hi <= L) return unit;
		if (L <= lo && hi <= R) return val;
		push();
		return f(l->query(L, R), r->query(L, R));
	}
	void add(int L, int R, T x) {
		if (R <= lo || hi <= L) return;
		if (L <= lo && hi <= R) {
			madd += x;
			val += (hi-lo)*x;
		}
		else {
			push(), l->add(L, R, x), r->add(L, R, x);
			val = f(l->val, r->val);
		}
	}
	void push() {
		if (!l) {
			int mid = lo + (hi - lo)/2;
			l = new Node(lo, mid); r = new Node(mid, hi);
		}
		if (madd)
			l->add(lo,hi,madd), r->add(lo,hi,madd), madd = 0;
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n, q; cin >> n >> q;
		vector<int> a(n);
		for (int &x : a) cin >> x;
		vector<vector<array<int, 2>>> queries(n);
		for (int i = 0; i < q; ++i) {
			int l, r; cin >> l >> r; --l;
			queries[l].push_back({r, i});
		}
		vector<ll> ans(q);

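		Node *seg = new Node(0, n);
		// active holds {nearest position, value} for each distinct value at index >= i,
		// plus a sentinel {n, 0} marking the end of the array.
		set<array<int, 2>> active;
		active.insert({n, 0});
		// next[v] = nearest position >= i holding value v (n if v has not been seen yet).
		vector<int> next(n+1, n);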
		for (int i = n-1; i >= 0; --i) {
			active.erase({next[a[i]], a[i]});
			next[a[i]] = i;
			active.insert({next[a[i]], a[i]});
			
			auto it = active.begin();
			vector<int> cur;
			bool good = true;
			int lim = i+1;
			while (true) {
				auto [pos, val] = *it;
				if (val == 0) {
					lim = n;
					break;
				}
				cur.push_back(val);
				sort(begin(cur), end(cur));
				for (int j = 0; j+1 < (int)size(cur); ++j) good &= cur[j+1] % cur[j] == 0;
				if (!good) {
					lim = pos;
					break;
				}
				++it;
			}
			seg -> add(i, lim, 1);
			for (auto [r, id] : queries[i]) {
				ans[id] = seg -> query(i, r);
			}
		}

		for (auto x : ans) cout << x << '\n';
	}
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string& pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

struct segtree {
    using S = long long;
    using T = pair<S, S>;
    using F = S;

    T e() {
        return make_pair(0, 0);
    }

    F id() {
        return 0;
    }

    T op(T a, T b) {
        return make_pair(a.first + b.first, a.second + b.second);
    }

    T mapping(F f, T x) {
        x.first += x.second * f;
        return x;
    }

    F composition(F f, F g) {
        return f + g;
    }

    int n;
    int size;
    int log_size;
    vector<T> node;
    vector<F> lazy;

    segtree() : segtree(0) {}
    segtree(int _n) {
        build(vector<T>(_n, e()));
    }
    segtree(const vector<T>& v) {
        build(v);
    }

    void build(const vector<T>& v) {
        n = (int) v.size();
        if (n <= 1) {
            log_size = 0;
        } else {
            log_size = 32 - __builtin_clz(n - 1);
        }
        size = 1 << log_size;
        node.resize(2 * size, e());
        lazy.resize(size, id());
        for (int i = 0; i < n; i++) {
            node[i + size] = v[i];
        }
        for (int i = size - 1; i > 0; i--) {
            pull(i);
        }
    }

    void push(int x) {
        node[2 * x] = mapping(lazy[x], node[2 * x]);
        node[2 * x + 1] = mapping(lazy[x], node[2 * x + 1]);
        if (2 * x < size) {
            lazy[2 * x] = composition(lazy[x], lazy[2 * x]);
            lazy[2 * x + 1] = composition(lazy[x], lazy[2 * x + 1]);
        }
        lazy[x] = id();
    }

    void pull(int x) {
        node[x] = op(node[2 * x], node[2 * x + 1]);
    }

    void set(int p, T v) {
        assert(0 <= p && p < n);
        p += size;
        for (int i = log_size; i >= 1; i--) {
            push(p >> i);
        }
        node[p] = v;
        for (int i = 1; i <= log_size; i++) {
            pull(p >> i);
        }
    }

    T get(int p) {
        assert(0 <= p && p < n);
        p += size;
        for (int i = log_size; i >= 1; i--) {
            push(p >> i);
        }
        return node[p];
    }

    T get(int l, int r) {
        assert(0 <= l && l <= r && r <= n);
        l += size;
        r += size;
        for (int i = log_size; i >= 1; i--) {
            if (((l >> i) << i) != l) {
                push(l >> i);
            }
            if (((r >> i) << i) != r) {
                push((r - 1) >> i);
            }
        }
        T vl = e();
        T vr = e();
        while (l < r) {
            if (l & 1) {
                vl = op(vl, node[l++]);
            }
            if (r & 1) {
                vr = op(node[--r], vr);
            }
            l >>= 1;
            r >>= 1;
        }
        return op(vl, vr);
    }

    void apply(int p, F f) {
        assert(0 <= p && p < n);
        p += size;
        for (int i = log_size; i >= 1; i--) {
            push(p >> i);
        }
        node[p] = mapping(f, node[p]);
        for (int i = 1; i <= log_size; i++) {
            pull(p >> i);
        }
    }

    void apply(int l, int r, F f) {
        assert(0 <= l && l <= r && r <= n);
        l += size;
        r += size;
        for (int i = log_size; i >= 1; i--) {
            if (((l >> i) << i) != l) {
                push(l >> i);
            }
            if (((r >> i) << i) != r) {
                push((r - 1) >> i);
            }
        }
        int ll = l;
        int rr = r;
        while (l < r) {
            if (l & 1) {
                node[l] = mapping(f, node[l]);
                if (l < size) {
                    lazy[l] = composition(f, lazy[l]);
                }
                l++;
            }
            if (r & 1) {
                r--;
                node[r] = mapping(f, node[r]);
                if (r < size) {
                    lazy[r] = composition(f, lazy[r]);
                }
            }
            l >>= 1;
            r >>= 1;
        }
        l = ll;
        r = rr;
        for (int i = 1; i <= log_size; i++) {
            if (((l >> i) << i) != l) {
                pull(l >> i);
            }
            if (((r >> i) << i) != r) {
                pull((r - 1) >> i);
            }
        }
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0, sq = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readSpace();
        int q = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        sq += q;
        vector<int> a(n);
        for (int i = 0; i < n; i++) {
            a[i] = in.readInt(1, n);
            (i == n - 1 ? in.readEoln() : in.readSpace());
        }
        vector<vector<pair<int, int>>> b(n);
        for (int i = 0; i < q; i++) {
            int l = in.readInt(1, n);
            in.readSpace();
            int r = in.readInt(l, n);
            in.readEoln();
            b[r - 1].emplace_back(l - 1, i);
        }
        vector<pair<long long, long long>> sinit(n, make_pair(0, 1));
        segtree seg(sinit);
        set<pair<int, int>> st;
        vector<long long> ans(q);
        vector<int> d(n + 1, -1);
        for (int i = 0; i < n; i++) {            
            st.erase(make_pair(-d[a[i]], a[i]));
            d[a[i]] = i;
            st.emplace(-d[a[i]], a[i]);
            int last = -1;
            vector<int> c;
            for (auto p : st) {
                c.emplace_back(p.second);
                sort(c.begin(), c.end());
                int ok = 1;
                for (int j = 0; j < (int) c.size() - 1; j++) {
                    if (c[j + 1] % c[j] != 0) {
                        ok = 0;
                    }
                }
                if (!ok) {
                    last = -p.first;
                    break;
                }
            }
            seg.apply(last + 1, i + 1, 1);
            for (auto [x, y] : b[i]) {
                ans[y] = seg.get(x, i + 1).first;
            }
        }
        for (int i = 0; i < q; i++) {
            cout << ans[i] << '\n';
        }
    }
    assert(max(sn, sq) <= 2e5);
    in.readEof();
    return 0;
}

Why does prefix sums not work here?

I’ll throw that question right back at you: how exactly would you use prefix sums?

The hyperlink on satyam’s name links to your profile @iceknight1093

P.S. I solved this with prefix sum + sliding window

Good catch, I’ve fixed the link.

I solved it using sqrt decomposition + sliding window, I can’t find why last test case is getting runtime error. can someone please help??

here is my submission.

https://www.codechef.com/viewsolution/91486344