KPRODSUM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: theabbie
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Divide and conquer

PROBLEM:

You’re given an array A of length N, and integers K and M.
Compute the sum of products of all subarrays of A of length \leq K, modulo M.
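For reference, the definition translates directly into an \mathcal{O}(NK) brute force, which is handy for checking the faster solutions below on small inputs. This is a minimal sketch (0-indexed; the name brute_force is only illustrative), not part of the official solutions:

```python
def brute_force(A, K, M):
    """Sum of products of all subarrays of length <= K, modulo M, in O(N*K)."""
    total = 0
    for x in range(len(A)):
        prod = 1
        # extend the subarray starting at x one element at a time, up to length K
        for y in range(x, min(len(A), x + K)):
            prod = prod * A[y] % M
            total = (total + prod) % M
    return total

print(brute_force([1, 2, 3], 2, 1000))  # 1 + 2 + 3 + 1*2 + 2*3 = 14
```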

EXPLANATION:

A common idea when dealing with subarrays is to work with prefixes, but that unfortunately doesn’t work here: since we’re working modulo M, we can’t obtain the product of a subarray by “dividing” one prefix product by another. “Division” in modular arithmetic really means multiplication by a modular inverse, and an element a has an inverse modulo M only when a is coprime to M (in particular, A_i = 0 never has one).
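A minimal illustration, with M = 4 and A = [2, 3] chosen just for the example: the prefix products modulo 4 are 1, 2, 2, and trying to recover the product of the subarray [3] by “dividing” the third by the second leaves two candidates, because 2 has no inverse modulo 4.

```python
M = 4
A = [2, 3]

# prefix[i] = product of A[0..i-1] modulo M
prefix = [1]
for a in A:
    prefix.append(prefix[-1] * a % M)
print(prefix)  # [1, 2, 2]

# The true product of the subarray [A[1]] is 3, but "dividing" prefix[2] by prefix[1]
# means solving prefix[1] * q == prefix[2] (mod M), which has two solutions here.
candidates = [q for q in range(M) if prefix[1] * q % M == prefix[2]]
print(candidates)  # [1, 3]
```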

Instead, we use divide and conquer.
Let f(L, R) denote the answer when considering the range [L, R] of the array.
We want f(1, N).

Let \text{mid} = \left\lfloor \frac{L+R}{2} \right\rfloor denote the midpoint of the range.
Note that any subarray [x, y] within [L, R] must be one of three types:

  1. y \leq \text{mid}, meaning it lies entirely in the left side.
  2. x \gt \text{mid}, meaning it lies entirely in the right side.
  3. x \leq \text{mid} and y\gt \text{mid}, meaning it crosses the middle.

The sums of products of the first and second types of subarrays can be computed recursively by calling f(L, \text{mid}) and f(\text{mid}+1, R), since each such subarray lies entirely within one of those ranges.
That leaves only the third type: subarrays that cross the middle.

A subarray [x, y] that crosses the middle can be split into the subarrays [x, \text{mid}] and [\text{mid}+1, y], and its product is the product of these two parts. So we only really need to care about subarrays ending at \text{mid} and subarrays starting at \text{mid}+1.

Let P_{x, \text{mid}} denote the product of the subarray starting at index x and ending at \text{mid}.
We have P_{\text{mid}, \text{mid}} = A_\text{mid}, and otherwise P_{x, \text{mid}} = (A_x \cdot P_{x+1, \text{mid}})\bmod M.
So, all the P_{x, \text{mid}} values can be computed in \mathcal{O}(\text{mid}-L) time.
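A sketch of this recurrence in Python, assuming 0-indexed, inclusive bounds; the helper name left_products and the dict return value are illustrative choices, not taken from the reference codes:

```python
def left_products(A, L, mid, M):
    """Return {x: P_{x,mid}} for L <= x <= mid, where P_{x,mid} = A[x] * ... * A[mid] mod M."""
    P = {}
    prod = 1
    for x in range(mid, L - 1, -1):
        prod = prod * A[x] % M  # P_{x,mid} = A_x * P_{x+1,mid}; first step gives A_mid mod M
        P[x] = prod
    return P

print(left_products([2, 3, 5, 7], 0, 2, 100))  # {2: 5, 1: 15, 0: 30}
```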
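Similarly, let P_{\text{mid}+1, y} denote the product of the subarray [\text{mid}+1, y]. We have P_{\text{mid}+1, \text{mid}+1} = A_{\text{mid}+1}, and otherwise P_{\text{mid}+1, y} = (P_{\text{mid}+1, y-1} \cdot A_y)\bmod M, so these values take \mathcal{O}(R-\text{mid}) time.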

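If we fix an index y\gt \text{mid}, then since we’re looking for subarrays of length \leq K, the valid x \leq \text{mid} are exactly those with x \geq L and y - x + 1 \leq K, and they form a range ending at \text{mid}.
Let l = \max(L, y-K+1) denote the left end of this range; if l \gt \text{mid}, no valid subarray crosses the middle and ends at y.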
We then want to add

P_{\text{mid}+1, y} \cdot (P_{l, \text{mid}} + P_{l+1, \text{mid}} + \ldots + P_{\text{mid}, \text{mid}})

to the answer.

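Precompute the sums S_x = P_{x, \text{mid}} + P_{x+1, \text{mid}} + \ldots + P_{\text{mid}, \text{mid}} for all L \leq x \leq \text{mid} (prefix sums taken leftwards from \text{mid}). The bracketed sum is then just S_l, a constant-time lookup, so all y can be processed in \mathcal{O}(R-\text{mid}) time.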
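Putting the crossing step together, here's a sketch assuming 0-indexed, inclusive bounds L \leq \text{mid} \lt R (cross_sum is a hypothetical helper name). It builds S_x while walking left from \text{mid}, then walks right from \text{mid}+1, maintaining P_{\text{mid}+1, y} and looking up S_l for l = \max(L, y-K+1):

```python
def cross_sum(A, L, mid, R, K, M):
    """Sum (mod M) of products of subarrays [x, y] with L <= x <= mid < y <= R and y - x + 1 <= K."""
    # S[x] = P_{x,mid} + P_{x+1,mid} + ... + P_{mid,mid}, built from mid leftwards
    S = {}
    prod, running = 1, 0
    for x in range(mid, L - 1, -1):
        prod = prod * A[x] % M
        running = (running + prod) % M
        S[x] = running

    total = 0
    right = 1  # P_{mid+1, y}, extended one element at a time
    for y in range(mid + 1, R + 1):
        right = right * A[y] % M
        l = max(L, y - K + 1)  # leftmost start that keeps the length <= K
        if l > mid:            # no valid start; larger y would only be longer
            break
        total = (total + right * S[l]) % M
    return total

# crossing subarrays of [2, 3, 5, 7] with L=0, mid=1, R=3, K=3: 2*3*5 + 3*5 + 3*5*7
print(cross_sum([2, 3, 5, 7], 0, 1, 3, 3, 1000))  # 150
```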

So, recursive calls aside (both of which halve the size of the range being considered), f(L, R) takes an additional \mathcal{O}(R-L) work.

So, if T(N) is the function describing the time complexity of our algorithm, we have

T(N) = 2T(\frac{N}{2}) + \mathcal{O}(N)

which is well known to solve to \mathcal{O}(N\log N), and that’s our complexity.

A simple way to see why this is true is to visualize the recursion tree as the array is processed: each time you move to a child, the size of the range being considered halves, so after \mathcal{O}(\log N) levels you reach a size-1 range and stop branching.
At each level of the tree, \mathcal{O}(N) work is performed in the ‘merging’ step (since subarrays corresponding to different nodes within the same level are disjoint), so with \mathcal{O}(\log N) levels that’s \mathcal{O}(N\log N) work overall.
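To see the bound concretely, this sketch (merge_work is a hypothetical name) counts one unit of work per element of [L, R] in each merge step, using the same midpoint split as f. For power-of-two N the count comes out to exactly N\log_2 N:

```python
def merge_work(L, R):
    """Count the elements touched by the merge steps of f(L, R) on an inclusive range."""
    if L == R:
        return 0  # base case: a single element, no merge step
    mid = (L + R) // 2
    return (R - L + 1) + merge_work(L, mid) + merge_work(mid + 1, R)

for n in (8, 1024):
    print(n, merge_work(0, n - 1))  # 8 24, then 1024 10240
```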

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (Python)
def solve(arr, k, m):
    def f(i, j):
        if i + 1 == j:
            return arr[i] % m
        extra = 0
        mid = (i + j) // 2
        pref = [0]
        rp = 1
        rl = min(j - mid, k - 1)
        for y in range(mid, mid + rl):
            rp *= arr[y]
            rp %= m
            pref.append((pref[-1] + rp) % m)
        lp = 1
        for x in range(mid - 1, max(i - 1, mid - k), -1):
            lp *= arr[x]
            lp %= m
            l = mid - x
            extra += lp * pref[min(k - l, j - mid)]
            extra %= m
        return (f(i, mid) + extra + f(mid, j)) % m
    return f(0, len(arr))

t = int(input())

for _ in range(t):
    n, k, m = map(int, input().split())
    arr = list(map(int, input().split()))
    print(solve(arr, k, m))
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e2);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 3e5);
        in.readSpace();
        int k = in.readInt(1, n);
        in.readSpace();
        int m = in.readInt(1, 1e9);
        in.readEoln();
        sn += n;
        auto a = in.readInts(n, 0, 1e9);
        in.readEoln();
        function<int(int, int)> Rec = [&](int l, int r) {
            if (l + 1 == r) {
                return a[l] % m;
            }
            int x = (l + r) >> 1;
            long long res = Rec(l, x) + Rec(x, r);
            vector<long long> t;
            t.emplace_back(1);
            for (int i = x - 1; i >= l; i--) {
                t.emplace_back(t.back() * a[i] % m);
            }
            t[0] = 0;
            for (int i = 1; i < (int) t.size(); i++) {
                t[i] = (t[i] + t[i - 1]) % m;
            }
            long long u = 1;
            for (int i = x; i < r; i++) {
                u *= a[i];
                u %= m;
                int j = clamp(k - (i - x + 1), 0, (int) t.size() - 1);
                res += u * t[j] % m;
            }
            return (int) (res % m);
        };
        cout << Rec(0, n) << '\n';
    }
    assert(sn <= 3e5);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
for _ in range(int(input())):
    n, k, m = map(int, input().split())
    a = list(map(int, input().split()))
    def solve(l, r):
        if l+1 == r: return a[l]%m
        mid = (l+r)//2
        res = solve(l, mid) + solve(mid, r)
        prod = 1
        b = [0]*min(k, r-mid)
        for i in range(len(b)):
            prod = (prod * a[mid+i]) % m
            b[i] += prod
        for i in range(1, len(b)):
            b[i] = (b[i-1] + b[i]) % m
        
        prod = 1
        for i in range(k):
            if mid-1-i < l: break
            prod = (prod * a[mid-1-i]) % m
            take = min(k-1-i, r-mid)
            if take > 0: res += prod * b[take-1] % m
        return res % m
    print(solve(0, n))

Alternatively, use a sliding window. Every subarray product ending at i gets multiplied by a[i+1] when the subarray is extended, so it’s sufficient to store their sum. Additionally, there’s one new subarray of length 1, and the one old subarray of length k (which would grow to length k+1) needs to be subtracted. Now use your favorite data structure to support static range modular product queries.
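A sketch of this sliding window in Python, assuming a sparse table as the “favorite data structure” (my choice for illustration; the function name is hypothetical). One caveat: the usual O(1) sparse-table query covers a range with two overlapping blocks, which only works for idempotent operations like min or gcd. A product would count the overlap twice, so this sketch splits each query into disjoint power-of-two blocks instead, costing O(log K) per query.

```python
def sum_products_sliding(A, K, M):
    n = len(A)
    # table[j][i] = product of A[i .. i + 2^j - 1] mod M
    table = [[a % M for a in A]]
    j = 1
    while (1 << j) <= n:
        prev, half = table[-1], 1 << (j - 1)
        table.append([prev[i] * prev[i + half] % M for i in range(n - (1 << j) + 1)])
        j += 1

    def range_prod(l, length):
        """Product of A[l .. l + length - 1] mod M, using disjoint power-of-two blocks."""
        res = 1 % M
        bit = 0
        while length:
            if length & 1:
                res = res * table[bit][l] % M
                l += 1 << bit
            length >>= 1
            bit += 1
        return res

    ans = window = 0  # window: sum of products of subarrays ending at i-1 with length <= K
    for i, a in enumerate(A):
        if i >= K:
            # drop A[i-K..i-1], which would become length K+1 after extending by a
            window = (window - range_prod(i - K, K)) % M
        window = (window * a + a) % M  # extend every subarray by a, then add [a] itself
        ans = (ans + window) % M
    return ans

print(sum_products_sliding([2, 3, 5, 7], 3, 1000))  # 208
```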


Nice solution, I ended up using the sliding window approach.

BTW, I believe there is a minor typo in the editorial: in the definition of P_{x, \text{mid}}, it is written as the product ending at M, but it should be mid.


I’m afraid to ask, but what exactly would that data structure be? I couldn’t figure it out for the life of me.

I used a segment tree from the AtCoder Library (ACL). People also reported using a sparse table. The actual code is rather short:

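#include <iostream>
#include <vector>
#include <atcoder/modint>
#include <atcoder/segtree>
using namespace std;
using namespace atcoder;

using mint = modint;              // modulus chosen at runtime via mint::set_mod(m)
using S = mint;
S op(S l, S r) { return l * r; }  // segment tree combines by multiplication mod m
S e() { return 1; }               // identity element for multiplication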

int main() {
  cin.tie(0)->sync_with_stdio(0);

  int t; cin >> t; while (t--) {
    int n, k, m;
    cin >> n >> k >> m;
    mint::set_mod(m);

    vector<int> a(n);
    for (auto& ai : a) {
      cin >> ai;
    }

    segtree<S, op, e> seg(n);
    for (int i = 0; i < n; ++i) {
      seg.set(i, a[i]);
    }

    mint ans = 0, sum_prod = 0;
    for (int i = 0; i < n; ++i) {
      if (i >= k) {
        sum_prod -= seg.prod(i - k, i);
      }
      sum_prod *= a[i];
      sum_prod += a[i];
      ans += sum_prod;
    }
    cout << ans.val() << '\n';
  }
}

Another Interesting Approach

ans = summationOf(S[i])
Where S[i] = A[i] * (1 + S[i + 1] - subarrayProduct[i+1 … i+K]), with S[N] = 0 and the subtracted product taken as 0 when i + K > N - 1

say A = [a0 , a1 , a2 , a3 , a4 , a5 , a6]
and K = 4

sum of products with a0 : s0 = (a0) + (a0 * a1) + (a0 * a1 * a2) + (a0 * a1 * a2 * a3)
sum of products with a1 : s1 = (a1) + (a1 * a2) + (a1 * a2 * a3) + (a1 * a2 * a3 * a4)
sum of products with a2 : s2 = (a2) + (a2 * a3) + (a2 * a3 * a4) + (a2 * a3 * a4 * a5)
sum of products with a3 : s3 = (a3) + (a3 * a4) + (a3 * a4 * a5) + (a3 * a4 * a5 * a6)
sum of products with a4 : s4 = (a4) + (a4 * a5) + (a4 * a5 * a6)
sum of products with a5 : s5 = (a5) + (a5 * a6)
sum of products with a6 : s6 = (a6)

Now let's try to simplify these sums

s6 = (a6)

s5 = (a5) + (a5 * a6)
s5 = (a5) * (1 + a6)
s5 = (a5) * (1 + s6)

s4 = (a4) + (a4 * a5) + (a4 * a5 * a6)
s4 = (a4) * (1 + (a5) + (a5 * a6))
s4 = (a4) * (1 + s5)

s3 = (a3) + (a3 * a4) + (a3 * a4 * a5) + (a3 * a4 * a5 * a6)
s3 = (a3) * (1 + (a4) + (a4 * a5) + (a4 * a5 * a6))
s3 = (a3) * (1 + s4)

From here on, extending the previous suffix sum would produce subarrays of length > k,
so we first subtract the length-k subarray from that suffix sum (do a dry run for better understanding)

s2 = (a2) + (a2 * a3) + (a2 * a3 * a4) + (a2 * a3 * a4 * a5)
s2 = (a2) * (1 + (a3) + (a3 * a4) + (a3 * a4 * a5))
s2 = (a2) * (1 + (s3 - a3 * a4 * a5 * a6))
s2 = (a2) * (1 + s3 - subarrayProduct(3 , 6))

s1 = (a1) * (1 + (s2 - a2 * a3 * a4 * a5))
s1 = (a1) * (1 + s2 - subarrayProduct(2 , 5))

s0 = (a0) * (1 + (s1 - a1 * a2 * a3 * a4))
s0 = (a0) * (1 + s1 - subarrayProduct(1 , 4))
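A quick way to convince yourself the recurrence is right is to compare it with the direct sums on a small array. This sketch uses plain Python integers (no modulus) so the comparison is exact, and computes the subtracted window product naively in O(K), which is the part the segment tree in the implementation speeds up. The function names are hypothetical.

```python
def suffix_sums_direct(A, K):
    """s[i] = sum of products of A[i..j] for i <= j <= i + K - 1 (exact, no modulus)."""
    n = len(A)
    s = []
    for i in range(n):
        total, prod = 0, 1
        for j in range(i, min(n, i + K)):
            prod *= A[j]
            total += prod
        s.append(total)
    return s

def suffix_sums_recurrence(A, K):
    """Same values via s[i] = A[i] * (1 + s[i+1] - prod(A[i+1 .. i+K]))."""
    n = len(A)
    s = [0] * (n + 1)  # s[n] = 0 for the empty suffix
    for i in range(n - 1, -1, -1):
        drop = 0
        if i + K <= n - 1:  # the length-K subarray starting at i+1 exists
            drop = 1
            for j in range(i + 1, i + K + 1):
                drop *= A[j]
        s[i] = A[i] * (1 + s[i + 1] - drop)
    return s[:n]

A = [2, 3, 5, 7, 11, 13, 17]
print(suffix_sums_direct(A, 4) == suffix_sums_recurrence(A, 4))  # True
```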

Implementation -

Here we cannot get subarray products from prefix/suffix products, because that needs division: A[i] can be 0 (or share a factor with M), so the modular inverse may not exist, and the exact products are far too large to store without reducing modulo M. So I'm using a segment tree to get subarray products in O(log n); we can also use sparse tables.
Instead of creating an array for S[i], I'm using a variable 'S' and keep updating it.
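A side note, not what this reply does: every product query here has the same length K, and for fixed-length windows a block prefix/suffix trick gives all window products in O(N) total with no division. Split the array into blocks of K elements; a window starting at a block boundary is one whole block, and any other window straddles exactly two blocks (a suffix of one and a prefix of the next). This is a sketch under that assumption; window_products is a hypothetical name, and its entry i should equal, modulo M, what seg.query(1, 0, n - 1, i, i + k - 1) returns in the code below.

```python
def window_products(A, K, M):
    """out[i] = A[i] * ... * A[i+K-1] mod M for every i with i + K <= len(A), no inverses."""
    n = len(A)
    pre = [0] * n  # pre[j] = product from the start of j's block up to j
    suf = [0] * n  # suf[i] = product from i up to the end of i's block
    for j in range(n):
        pre[j] = A[j] % M if j % K == 0 else pre[j - 1] * A[j] % M
    for i in range(n - 1, -1, -1):
        block_end = i % K == K - 1 or i == n - 1
        suf[i] = A[i] % M if block_end else suf[i + 1] * A[i] % M
    out = []
    for i in range(n - K + 1):
        # aligned windows are a whole block; otherwise they straddle exactly two blocks
        out.append(suf[i] if i % K == 0 else suf[i] * pre[i + K - 1] % M)
    return out

print(window_products([2, 3, 5, 7, 11], 2, 1000))  # [6, 15, 35, 77]
```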

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
ll mod;

class SegmentTree{
public:
      vector<ll> tree;

      SegmentTree(vector<ll> &nums){
            int N = nums.size();
            tree.resize(N << 2 | 1);
            build(1 , 0 , N - 1 , nums);
      }

      ll build(int node , int low , int high , vector<ll> &nums){
            if(low == high) return tree[node] = nums[low];
            int mid = (low + high) >> 1;
            return tree[node] = (build(node << 1 , low , mid , nums) * build(node << 1 | 1 , mid + 1 , high , nums)) % mod;
      }

      ll query(int node , int low , int high , int ql , int qr){
            if(low >  qr || high <  ql) return 1;
            if(low >= ql && high <= qr) return tree[node];
            int mid = (low + high) >> 1;
            return (query(node << 1 , low , mid , ql , qr) * query(node << 1 | 1 , mid + 1 , high , ql , qr)) % mod;
      }
};

int main(){
    int t ; cin >> t;

    while(t--){
        ll n , k ; cin >> n >> k >> mod;

        vector<ll> nums(n);
        for(ll &num : nums) cin >> num;

        SegmentTree seg(nums);

        ll S = 0 , ans = 0;

        for(int i = n - 1 ; i >= 0 ; i--){
            S = (nums[i] * (1 + S)) % mod;
            ans = (ans + S) % mod;
            if(i + k - 1 < n) S = ((S - seg.query(1 , 0 , n - 1 , i , i + k - 1)) % mod + mod) % mod;
        }

        cout << ans % mod << endl; 
    }
    
    return 0;
}



Good and intuitive approach. Started with the same kind of approach but couldn’t derive the subtraction part.

Didn’t get the editorial solution tho

Fixed, thanks for noticing.

It has a recurrence relation as well

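Here st.query(l, r) is a range-product query (modulo mod) on a segment tree. For the subtraction to remove exactly the length-(k+1) subarray ending at i, it must include both ends, i.e. return a[l] * … * a[r]. The snippet also relies on template helpers that aren't shown (vi, SegmentTree, reading a vector with cin >> a, a global mod), and presumably on int being 64-bit, since a[i] * (dp[i - 1] + 1) can approach 10^18.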

void solve()
{
    int n, k; cin >> n >> k >> mod;
    vi a(n); cin >> a;

    SegmentTree<int> st(a, 1ll);

    vi dp(n);
    dp[0] = a[0];
    for(int i = 1; i < n; ++i)
    {
        dp[i] = (a[i] * (dp[i - 1] + 1)) % mod;
        if(i >= k) {
            dp[i] -= st.query(i - k, i);
            dp[i] %= mod;
            dp[i] += mod;
            dp[i] %= mod;
        }
    }

    int ans = 0;
    for(auto it : dp) {
        ans += it;
        ans %= mod;
    }

    cout << ans << endl;
}