PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: theabbie
Tester: tabr
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Divide and conquer
PROBLEM:
You’re given an array A of length N, and integers K and M.
Compute the sum of products of all subarrays of A of length \leq K, modulo M.
EXPLANATION:
While a common idea when dealing with subarrays is to think of prefixes, that unfortunately doesn’t work here: since we’re working modulo M, it’s not really possible to obtain the product of a subarray by “dividing” out one prefix from another (“division” in the realm of modular arithmetic really means multiplication with the inverse; and when a isn’t coprime to M, the inverse of a doesn’t exist).
Instead, we use divide and conquer.
Let f(L, R) denote the answer when considering the range [L, R] of the array.
We want f(1, N).
Let \text{mid} = \left\lfloor \frac{L+R}{2} \right\rfloor denote the midpoint of the range.
Note that any subarray [x, y] within [L, R] must be one of three types:
- y \leq \text{mid}, meaning it lies entirely in the left side.
- x \gt \text{mid}, meaning it lies entirely in the right side.
- x \leq \text{mid} and y\gt \text{mid}, meaning it crosses the middle.
The products of the first and second types of subarrays can be recursively computed by calling f(L, \text{mid}) and f(\text{mid}+1, R), since they’ll lie within one of those ranges.
That leaves only the third type: subarrays that cross the middle.
A subarray [x, y] that crosses the middle can be broken up into subarrays [x, \text{mid}] and [\text{mid}+1, y] - meaning we only really need to care about subarrays ending at \text{mid} and starting at \text{mid}+1 instead.
Let P_{x, \text{mid}} denote the product of the subarray starting at index x and ending at \text{mid}.
We have P_{\text{mid}, \text{mid}} = A_\text{mid}, and otherwise P_{x, \text{mid}} = (A_x \cdot P_{x+1, \text{mid}})\bmod M.
So, all the P_{x, \text{mid}} values can be computed in \mathcal{O}(\text{mid}-L) time.
Similarly, let P_{\text{mid}+1, y} denote the product of the subarray [\text{mid}+1, y].
If we fix an index y\gt \text{mid}, since we’re looking for subarrays of length \leq K, the set of valid x \leq \text{mid} will form some range ending at \text{mid}.
Let l denote the left end of this range.
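We then want to add
P_{\text{mid}+1, y} \cdot \sum_{x=l}^{\text{mid}} P_{x, \text{mid}}
to the answer (here l = \max(L, y-K+1); if l \gt \text{mid}, no valid x exists and nothing is added).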
This can easily be computed in constant time using prefix sums, so all y can be processed in \mathcal{O}(R-\text{mid}) time.
So, recursive calls aside (both of which halve the size of the range being considered), f(L, R) takes an additional \mathcal{O}(R-L) work.
So, if T(N) is the function describing the time complexity of our algorithm, we have
T(N) = 2\,T\left(\frac{N}{2}\right) + \mathcal{O}(N)
which is well-known to reduce to \mathcal{O}(N\log N), and that’s our complexity.
A simple way to see why this is true is to visualize the recursive tree as the array is processed: each time you move to a child, the size of the subarray being considered halves, so after \mathcal{O}(\log N) levels you’ll reach a size-1 array and not branch further.
At each level of the tree, \mathcal{O}(N) work is performed in the ‘merging’ step (since subarrays corresponding to different nodes within the same level are disjoint), so with \mathcal{O}(\log N) levels that’s \mathcal{O}(N\log N) work overall.
TIME COMPLEXITY:
\mathcal{O}(N\log N) per testcase.
CODE:
Author's code (Python)
def solve(arr, k, m):
    """Return the sum of products of all subarrays of ``arr`` with length <= k, modulo m.

    Uses divide and conquer: every subarray lies entirely in the left half,
    entirely in the right half, or crosses the midpoint. Crossing subarrays
    are counted by pairing each left part ending at ``mid - 1`` with a prefix
    sum over the right parts starting at ``mid``.

    Args:
        arr: list of integers.
        k: maximum subarray length to count (expected to be >= 1).
        m: modulus (expected to be >= 1).

    Returns:
        The requested sum reduced modulo ``m``; 0 for an empty ``arr``.
    """
    def f(i, j):
        # Answer for the half-open index range [i, j); requires i < j.
        if i + 1 == j:
            return arr[i] % m
        extra = 0
        mid = (i + j) // 2
        # pref[c] = sum of the products arr[mid..mid+t-1] for t = 1..c (mod m).
        pref = [0]
        rp = 1
        # A crossing subarray takes at least one element from the left,
        # so at most k - 1 elements can come from the right.
        rl = min(j - mid, k - 1)
        for y in range(mid, mid + rl):
            rp *= arr[y]
            rp %= m
            pref.append((pref[-1] + rp) % m)
        lp = 1
        # x runs over left endpoints, leaving room for >= 1 right element.
        for x in range(mid - 1, max(i - 1, mid - k), -1):
            lp *= arr[x]
            lp %= m
            l = mid - x  # number of elements taken from the left half
            extra += lp * pref[min(k - l, j - mid)]
            extra %= m
        return (f(i, mid) + extra + f(mid, j)) % m

    # An empty range has no subarrays; f() itself assumes a non-empty range.
    if not arr:
        return 0
    return f(0, len(arr))
# Driver: the first line holds the number of test cases; each test case is
# a line "n k m" followed by a line with the array elements.
t = int(input())
for _ in range(t):
    n, k, m = map(int, input().split())
    arr = list(map(int, input().split()))
    print(solve(arr, k, m))
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
#define IGNORE_CR
// Strict, whitespace-sensitive reader used to validate input.
//
// The constructor slurps all of std::cin into `buffer`; the read* methods then
// consume it from `pos`, asserting that every token, space and newline is
// exactly where it is expected. Any violation aborts through assert().
struct input_checker {
  std::string buffer;  // everything read from std::cin ('\r' dropped when IGNORE_CR is defined)
  int pos;             // index in `buffer` of the next unread character
  const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
  const std::string number = "0123456789";
  const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
  const std::string lower = "abcdefghijklmnopqrstuvwxyz";

  // Reads std::cin until end of file into `buffer`.
  input_checker() {
    pos = 0;
    while (true) {
      const int c = std::cin.get();
      // Compare against the stream's EOF value rather than a literal -1.
      if (c == std::char_traits<char>::eof()) {
        break;
      }
#ifdef IGNORE_CR
      if (c == '\r') {
        continue;
      }
#endif
      buffer.push_back(static_cast<char>(c));
    }
  }

  // Returns the token starting at `pos`, i.e. the characters up to (not
  // including) the next ' ' or '\n'. Asserts that input remains and that the
  // token contains no other whitespace character.
  std::string readOne() {
    assert(pos < static_cast<int>(buffer.size()));
    std::string res;
    while (pos < static_cast<int>(buffer.size()) && buffer[pos] != ' ' && buffer[pos] != '\n') {
      // isspace requires a value representable as unsigned char; a plain
      // (possibly negative) char would be undefined behaviour.
      assert(!std::isspace(static_cast<unsigned char>(buffer[pos])));
      res += buffer[pos];
      pos++;
    }
    return res;
  }

  // Reads one token and asserts that its length is in [min_len, max_len] and,
  // when `pattern` is non-empty, that every character occurs in `pattern`.
  std::string readString(int min_len, int max_len, const std::string& pattern = "") {
    assert(min_len <= max_len);
    std::string res = readOne();
    assert(min_len <= static_cast<int>(res.size()));
    assert(static_cast<int>(res.size()) <= max_len);
    for (int i = 0; i < static_cast<int>(res.size()); i++) {
      assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
    }
    return res;
  }

  // Reads one token as an int and asserts it lies in [min_val, max_val].
  // std::stoi throws if the token has no leading number or is out of range.
  int readInt(int min_val, int max_val) {
    assert(min_val <= max_val);
    const std::string token = readOne();
    std::size_t used = 0;
    const int res = std::stoi(token, &used);
    // Reject tokens such as "12abc" that stoi would silently truncate.
    assert(used == token.size());
    assert(min_val <= res);
    assert(res <= max_val);
    return res;
  }

  // Reads one token as a long long and asserts it lies in [min_val, max_val].
  long long readLong(long long min_val, long long max_val) {
    assert(min_val <= max_val);
    const std::string token = readOne();
    std::size_t used = 0;
    const long long res = std::stoll(token, &used);
    // Reject tokens with trailing non-numeric characters.
    assert(used == token.size());
    assert(min_val <= res);
    assert(res <= max_val);
    return res;
  }

  // Reads `size` ints separated by single spaces (no trailing space).
  std::vector<int> readInts(int size, int min_val, int max_val) {
    assert(min_val <= max_val);
    std::vector<int> res(size);
    for (int i = 0; i < size; i++) {
      res[i] = readInt(min_val, max_val);
      if (i != size - 1) {
        readSpace();
      }
    }
    return res;
  }

  // Reads `size` long longs separated by single spaces (no trailing space).
  std::vector<long long> readLongs(int size, long long min_val, long long max_val) {
    assert(min_val <= max_val);
    std::vector<long long> res(size);
    for (int i = 0; i < size; i++) {
      res[i] = readLong(min_val, max_val);
      if (i != size - 1) {
        readSpace();
      }
    }
    return res;
  }

  // Consumes exactly one ' '.
  void readSpace() {
    assert(static_cast<int>(buffer.size()) > pos);
    assert(buffer[pos] == ' ');
    pos++;
  }

  // Consumes exactly one '\n'.
  void readEoln() {
    assert(static_cast<int>(buffer.size()) > pos);
    assert(buffer[pos] == '\n');
    pos++;
  }

  // Asserts that the whole buffer has been consumed.
  void readEof() {
    assert(static_cast<int>(buffer.size()) == pos);
  }
};
// Returns, modulo m, the sum of the products of every subarray of a[lo, hi)
// whose length is at most k. Requires lo < hi.
static int bounded_product_sum(const std::vector<int>& a, int k, int m, int lo, int hi) {
  if (hi - lo == 1) {
    return a[lo] % m;
  }
  const int mid = lo + (hi - lo) / 2;

  // Subarrays that stay inside one half are handled recursively.
  const int left_part = bounded_product_sum(a, k, m, lo, mid);
  const int right_part = bounded_product_sum(a, k, m, mid, hi);
  long long total = left_part + right_part;

  // left_sums[c] = sum over t = 1..c of the product a[mid-t .. mid-1], mod m.
  const int left_len = mid - lo;
  std::vector<long long> left_sums(left_len + 1, 0);
  long long left_prod = 1;
  for (int taken = 1; taken <= left_len; ++taken) {
    left_prod = left_prod * a[mid - taken] % m;
    left_sums[taken] = (left_sums[taken - 1] + left_prod) % m;
  }

  // Pair each right prefix a[mid .. mid+len-1] with every left suffix short
  // enough to keep the combined length within k.
  long long right_prod = 1;
  for (int len = 1; len <= hi - mid; ++len) {
    right_prod = right_prod * a[mid + len - 1] % m;
    const int room = std::clamp(k - len, 0, left_len);
    total += right_prod * left_sums[room] % m;
  }
  return static_cast<int>(total % m);
}

// Validates the whole input with input_checker and prints one answer per
// test case.
int main() {
  constexpr int kMaxTests = 100;
  constexpr int kMaxN = 300000;  // bound on each n and on the sum of n
  constexpr int kMaxValue = 1000000000;

  input_checker in;
  int remaining = in.readInt(1, kMaxTests);
  in.readEoln();
  int total_n = 0;
  for (; remaining > 0; --remaining) {
    const int n = in.readInt(1, kMaxN);
    in.readSpace();
    const int k = in.readInt(1, n);
    in.readSpace();
    const int m = in.readInt(1, kMaxValue);
    in.readEoln();
    total_n += n;
    const std::vector<int> a = in.readInts(n, 0, kMaxValue);
    in.readEoln();
    std::cout << bounded_product_sum(a, k, m, 0, n) << '\n';
  }
  // The sum of n is checked once, after all test cases have been processed.
  assert(total_n <= kMaxN);
  in.readEof();
  return 0;
}
Editorialist's code (Python)
def bounded_subarray_product_sum(a, k, m):
    """Return the sum of products of all subarrays of ``a`` with length <= k, modulo m.

    Divide and conquer over half-open ranges [l, r): subarrays inside one half
    are handled recursively; subarrays crossing ``mid`` are counted by pairing
    each left part ending at ``mid - 1`` with a prefix sum over right parts
    starting at ``mid``.

    Args:
        a: list of integers.
        k: maximum subarray length to count (expected to be >= 1).
        m: modulus (expected to be >= 1).

    Returns:
        The requested sum reduced modulo ``m``; 0 for an empty ``a``.
    """
    def solve(l, r):
        if l + 1 == r:
            return a[l] % m
        mid = (l + r) // 2
        res = solve(l, mid) + solve(mid, r)
        # b[i] = sum of the products a[mid..mid+t] for t = 0..i (mod m).
        prod = 1
        b = [0] * min(k, r - mid)
        for i in range(len(b)):
            prod = (prod * a[mid + i]) % m
            b[i] += prod
        for i in range(1, len(b)):
            b[i] = (b[i - 1] + b[i]) % m
        prod = 1
        for i in range(k):
            # i + 1 elements are taken from the left half: a[mid-1-i .. mid-1].
            if mid - 1 - i < l:
                break
            prod = (prod * a[mid - 1 - i]) % m
            # At most k - (i + 1) elements may still come from the right half.
            take = min(k - 1 - i, r - mid)
            if take > 0:
                res += prod * b[take - 1] % m
        return res % m

    # An empty range has no subarrays; solve() itself assumes l < r.
    if not a:
        return 0
    return solve(0, len(a))


def main():
    """Read the test cases from stdin and print one answer per test case."""
    for _ in range(int(input())):
        n, k, m = map(int, input().split())
        a = list(map(int, input().split()))
        # Only the first n values are considered.
        print(bounded_subarray_product_sum(a[:n], k, m))


if __name__ == "__main__":
    main()