LENGTHX - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: amrharb
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Sorting, binary search, segment tree/Fenwick tree

PROBLEM:

Given an array A and an integer X, define f(L, R) to be the number of pairs (i, j) such that
L \leq i \lt j \leq R and (A_i + A_j) has exactly X digits in its decimal representation.

Find the sum of f(L, R) across all 1 \leq L \leq R \leq N.
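To pin down the definition, here is a brute-force sketch that evaluates the sum of f(L, R) literally. It is meant only as a reference for tiny inputs (it runs in O(N^4)). The helper names `num_digits` and `brute_force` are hypothetical, and the sketch assumes every A_i is a positive integer.

```python
def num_digits(v):
    # Number of decimal digits of a positive integer v.
    return len(str(v))


def brute_force(a, x):
    # Literally sums f(L, R) over all subarrays (0-indexed bounds here).
    n = len(a)
    total = 0
    for lo in range(n):
        for hi in range(lo, n):
            # f(lo, hi): pairs i < j inside [lo, hi] whose sum has exactly x digits.
            for i in range(lo, hi + 1):
                for j in range(i + 1, hi + 1):
                    if num_digits(a[i] + a[j]) == x:
                        total += 1
    return total


print(brute_force([5, 1, 5], 1))  # 4
```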

EXPLANATION:

First, note that a positive integer has exactly X digits in its decimal representation if and only if it lies in the range
\left[10^{X-1}, 10^X\right).
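The range test can be cross-checked against the length of the number's decimal string. A small sketch, where `has_x_digits` is a hypothetical helper and v is assumed to be a positive integer:

```python
def has_x_digits(v, x):
    # A positive integer v has exactly x digits iff 10^(x-1) <= v < 10^x.
    return 10 ** (x - 1) <= v < 10 ** x


# Cross-check the range test against the length of the decimal string.
for v in range(1, 20000):
    for x in range(1, 7):
        assert has_x_digits(v, x) == (len(str(v)) == x)

print(has_x_digits(99, 2), has_x_digits(100, 2))  # True False
```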

Let’s call a pair (l,r) (where 1 \leq l \lt r \leq N) good if (A_l + A_r) lies in this range.
Our task is to compute, summed over every subarray, the number of good pairs that subarray contains.

Observe that if the pair (l, r) is good, it’ll add 1 to the answer for every subarray that contains it.
The number of such subarrays is exactly l\times (N-r+1), since the subarray should start at some index \leq l and end at some index \geq r.
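The contribution argument already gives an O(N^2) method: weight each good pair (l, r) by l \times (N-r+1) instead of enumerating subarrays. A minimal sketch (the name `by_contribution` is hypothetical; l and r are 1-indexed to match the formula, and the A_i are assumed positive):

```python
def by_contribution(a, x):
    # Sums l * (N - r + 1) over all good pairs (l, r), 1-indexed.
    n = len(a)
    lo, hi = 10 ** (x - 1), 10 ** x
    total = 0
    for l in range(1, n + 1):
        for r in range(l + 1, n + 1):
            if lo <= a[l - 1] + a[r - 1] < hi:
                # Subarrays [s, e] with 1 <= s <= l and r <= e <= N all contain the pair.
                total += l * (n - r + 1)
    return total


print(by_contribution([5, 1, 5], 1))  # 4
```

Here the pairs (1, 2) and (2, 3) are good (both sums are 6) and contribute 1 \times 2 + 2 \times 1 = 4.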


For a fixed index l (and hence a fixed value A_l), the values A_r for which (A_l + A_r) has X digits form a contiguous range: A_r must lie between 10^{X-1} - A_l and 10^X - 1 - A_l, inclusive.
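As a quick illustration, the bounds can be computed directly. `partner_bounds` is a hypothetical helper name; the lower bound may be negative, in which case (for positive A_r) it simply never binds.

```python
def partner_bounds(a_l, x):
    # Inclusive range [low, high] of values A_r such that A_l + A_r has exactly x digits.
    low = 10 ** (x - 1) - a_l
    high = 10 ** x - 1 - a_l
    return low, high


print(partner_bounds(5, 1))  # (-4, 4)
print(partner_bounds(5, 2))  # (5, 94)
```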

Consider an array b such that b_i denotes the index of the i-th smallest element of A, i.e. A_{b_i} is the i-th smallest element of A (ties can be broken arbitrarily).
Then the valid values of A_r correspond to a contiguous range of positions in b, though we must only count positions whose index b_i is \gt l.
Finding this range is easy: since both the lower and upper bounds on the value are known, two binary searches suffice.
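A sketch of the two binary searches using Python's `bisect` module on the sorted values. The names `b`, `vals` and `partner_range` are illustrative (the code is 0-indexed), and the example array is the small test from the comments below.

```python
import bisect

a = [5, 1, 5]
x = 1
# b[p] = 0-indexed position of the p-th smallest element; sorted() is stable,
# so equal values stay in index order.
b = sorted(range(len(a)), key=lambda i: a[i])
vals = [a[i] for i in b]  # values in ascending order


def partner_range(a_l):
    # Half-open range [L, R) of positions in b whose values lie in
    # [10^(x-1) - a_l, 10^x - 1 - a_l].
    L = bisect.bisect_left(vals, 10 ** (x - 1) - a_l)
    R = bisect.bisect_right(vals, 10 ** x - 1 - a_l)
    return L, R


print(b, partner_range(5), partner_range(1))  # [1, 0, 2] (0, 1) (0, 3)
```

Note that these ranges still include positions whose index is \leq l; excluding them is the job of the sweep.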

Note that once the range is known, we only need the sum of (N-r+1) over all valid indices r \gt l within this range; multiplying it by l gives the contribution of all good pairs with l as their left endpoint.

To count only indices \gt l, we use a sweep-line approach along with a segment tree (or Fenwick tree). That is,

  • Build a segment tree over the positions of b, where position i holds the value (N - b_i + 1).
    Initially, all values are present in the segment tree.
  • Enumerate l from 1 to N.
    First, find the index i such that b_i = l, and set the value at this position to 0 (it was N-l+1 before this).
  • Then, find the range of positions that needs to be queried using binary search.
    Query the sum over this range, multiply it by l, and add it to the answer.
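Putting the steps together, here is a minimal Python sketch of the sweep. It swaps in a hand-written Fenwick tree for the segment tree (both support point update and range sum); the names `Fenwick` and `solve` are hypothetical, indices are 0-indexed internally, and the A_i are assumed positive.

```python
import bisect


class Fenwick:
    # Point update and prefix sum over positions 0..n-1 (1-based internally).
    def __init__(self, n):
        self.n = n
        self.t = [0] * (n + 1)

    def add(self, i, delta):
        i += 1
        while i <= self.n:
            self.t[i] += delta
            i += i & -i

    def prefix(self, i):
        # Sum of positions 0..i-1.
        s = 0
        while i > 0:
            s += self.t[i]
            i -= i & -i
        return s


def solve(a, x):
    n = len(a)
    b = sorted(range(n), key=lambda i: a[i])  # b[p]: index of the p-th smallest value
    vals = [a[i] for i in b]
    pos = [0] * n                             # pos[i] = p such that b[p] = i
    for p, i in enumerate(b):
        pos[i] = p
    fw = Fenwick(n)
    for p, i in enumerate(b):
        fw.add(p, n - i)                      # N - r + 1 with r = i + 1 (1-indexed)
    lo, hi = 10 ** (x - 1), 10 ** x - 1
    ans = 0
    for i in range(n):                        # left endpoint l = i + 1
        fw.add(pos[i], -(n - i))              # zero out l itself; all indices <= l are now 0
        L = bisect.bisect_left(vals, lo - a[i])
        R = bisect.bisect_right(vals, hi - a[i])
        if L < R:
            ans += (i + 1) * (fw.prefix(R) - fw.prefix(L))
    return ans


print(solve([5, 1, 5], 1))  # 4
```

Each l costs two binary searches plus O(\log N) Fenwick operations, which matches the overall \mathcal{O}(N\log N) bound.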

The important step here is the second one, where we set the value that was initially N-l+1 to 0.
This ensures that, at any point in time, the segment tree holds nonzero values only at positions corresponding to indices \gt l (everything else has been set to 0), so a plain range-sum query gives exactly the sum of (N-r+1) over valid r \gt l.

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>

#define int long long

using namespace std;


class segTree {
private:
    vector<int> seg;
    int n;

    int merge(int a, int b) {
        return (a + b) ;
    }

    int get(int node, int start, int end, int l, int r) {
        if (r < start || l > end) return 0;
        if (l <= start && end <= r) return seg[node];
        int mid = (start + end) / 2;
        return merge(get(2 * node, start, mid, l, r), get(2 * node + 1, mid + 1, end, l, r));
    }

    void update(int node, int start, int end, int idx, int val) {
        if (start == end) {
            (seg[node] += val);
            return;
        }
        int mid = (start + end) / 2;
        if (idx <= mid) update(2 * node, start, mid, idx, val);
        else update(2 * node + 1, mid + 1, end, idx, val);
        seg[node] = merge(seg[2 * node], seg[2 * node + 1]);
    }

public:
    segTree(int _n) : n(_n) {
        seg.resize(n << 2);
    }

    int query(int l, int r) {
        return get(1, 0, n - 1, l, r);
    }

    void update(int idx, int val) {
        update(1, 0, n - 1, idx, val);
    }
};

signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    vector<long long> pw(20);
    for (long long i = 0, k = 1; i <= 18; i++, k *= 10)pw[i] = k;
    pw[19] = LLONG_MAX;
    int t = 1;
    cin >> t;
    while (t--) {
        int n, x;
        cin >> n >> x;
        vector<long long> v(n);
        set<long long> nums;
        for (auto &it: v) {
            cin >> it;
            nums.insert(it);
        }
        vector<long long> sorted(nums.begin(), nums.end());
        segTree st(sorted.size());
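        long long ans = 0;
        // Sweep i as the RIGHT endpoint of a pair. The tree is indexed by compressed
        // value and holds, for every earlier index j < i, the weight (j + 1) at v[j]'s slot.
        for (int i = 0; i < n; i++) {
            // Compressed-value range of partners v[j] with pw[x-1] <= v[i] + v[j] <= pw[x] - 1.
            int l = lower_bound(sorted.begin(), sorted.end(), max(0LL, pw[x - 1] - v[i])) - sorted.begin();
            int r = upper_bound(sorted.begin(), sorted.end(), max(0LL, pw[x] - v[i] - 1)) - sorted.begin() - 1;
            // Each good pair (j, i) lies in (j + 1) * (n - i) subarrays (0-indexed).
            (ans += 1LL * (n - i) * st.query(l, r));
            // Insert v[i] with weight i + 1 so that later right endpoints can pair with it.
            int idx = lower_bound(sorted.begin(), sorted.end(), v[i]) - sorted.begin();
            st.update(idx, i + 1);
        }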
        cout << ans << '\n';
    }
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define N 1000001
#define md LLONG_MAX
int t[4 * N];
void build(int a[], int v, int tl, int tr) {
    if (tl == tr) {
        t[v] = a[tl];
    } else {
        int tm = (tl + tr) / 2;
        build(a, v*2, tl, tm);
        build(a, v*2+1, tm+1, tr);
        t[v] = (t[v*2] + t[v*2 + 1]) % md;
    }
}

void update(int v, int tl, int tr, int pos, int addend) {
    if (tl == tr) {
        t[v] += addend;
        t[v] %= md;
    } else {
        int tm = (tl + tr) / 2;
        if (pos <= tm)
            update(v*2, tl, tm, pos, addend);
        else
            update(v*2+1, tm+1, tr, pos, addend);
        t[v] = (t[v*2] + t[v*2+1]) % md;
    }
}

int query(int v, int tl, int tr, int l, int r) {
    if (l > r)
        return 0;
    if (l == tl && tr == r)
        return t[v];
    int tm = (tl + tr) / 2;
    return (query(v*2, tl, tm, l, min(r, tm)) + query(v*2+1, tm+1, tr, max(l, tm+1), r)) % md;
}
int32_t main() {
	int t;
	cin>>t;
	while(t--){
	    int n,x;
	    cin>>n>>x;
	    int a[n];
	    map<int, int> mp;
	    for(int i = 0; i < n; i++){
	        cin>>a[i];
	        mp[a[i]] = 0;
	    }
	    int cnt = 0;
	    for(auto &it: mp){
	        it.second = cnt++;
	    }
	    int mn, mx;
	    if(x == 19){
	        mn = 1;
	        for(int i = 0; i < x - 1; i++){
	            mn *= 10;
	        }
	        mx = LLONG_MAX;
	    }else{
	        mx = 1;
	        for(int i = 0; i < x; i++){
    	        mx *= 10;
    	    }
    	    mn = mx / 10;
    	    mx--;
	    }
	    int m = mp.size();
	    int b[m] = {};
	    build(b, 1, 0, m - 1);
	    int ans = 0;
	    for(int i = 0; i < n; i++){
          auto it = mp.lower_bound(mn - a[i]);
          int start = m;
          int end = -1;
          if(it != mp.end()){
              start = (*it).second;
          }
          it = mp.upper_bound(mx - a[i]);
          if(it != mp.begin()){
              it--;
              end = (*it).second;
          }
          ans += query(1, 0, m - 1, start, end) * (n - i);
          ans %= md;
	      update(1, 0, m - 1, mp[a[i]], i + 1);
	    }
	    cout<<ans<<"\n";
	}
	return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7

# https://github.com/cheran-senthil/PyRival/blob/master/pyrival/data_structures/FenwickTree.py
class FenwickTree:
    def __init__(self, x):
        """transform list into BIT"""
        self.bit = x
        for i in range(len(x)):
            j = i | (i + 1)
            if j < len(x):
                x[j] += x[i]

    def update(self, idx, x):
        """updates bit[idx] += x"""
        while idx < len(self.bit):
            self.bit[idx] += x
            idx |= idx + 1

    def query(self, end):
        """calc sum(bit[:end])"""
        x = 0
        while end:
            x += self.bit[end - 1]
            end &= end - 1
        return x

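import sys
import bisect
input = sys.stdin.readline
for _ in range(int(input())):
    n, x = map(int, input().split())
    a = list(map(int, input().split()))
    # indices[p] = 0-indexed position of the p-th smallest element (the array b from the editorial)
    indices = list(range(n))
    indices.sort(key = lambda k: a[k])
    # pos[i] = p such that indices[p] = i
    pos = [0]*n
    for p in range(n): pos[indices[p]] = p

    # Initial value at position p is N - b_p + 1 = n - indices[p] (indices are 0-indexed)
    b = [n - k for k in indices]
    F = FenwickTree(b)

    lo = 10**(x-1)
    hi = 10**x
    ans = 0
    for i in range(n):
        # Zero out the entry of index i itself: it held n - i, so add i - n
        F.update(pos[i], i - n)
        # Positions [L, R) in sorted order whose values v satisfy lo <= a[i] + v < hi
        # (bisect's key= argument requires Python 3.10+)
        L = bisect.bisect_left(indices, lo - a[i], key=lambda k: a[k])
        R = bisect.bisect_left(indices, hi - a[i], key=lambda k: a[k])
        if L == R: continue
        ans += (i+1) * (F.query(R) - F.query(L))
    print(ans)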

Can someone help with a testcase where the assert condition is getting triggered? Otherwise the solution seems correct to me. Thanks!


I stress-tested your solution.

WA on the following test:
1
3 1
5 1 5 
Your answer is:
3
Correct answer is:
4

Can anyone point out which test case gives a wrong answer (I get WA on task 1) for this solution? My approach is similar (though not exactly the same) to the editorial's. Thanks!

You are right, the test case is wrong.

All solutions are wrong.

Why does all the code take the answer modulo 10^9+7?

This will work: CodeChef: Practical coding for everyone

@iceknight1093, Please look into this issue.

Edit: I also used a similar approach in this submission: 1070912917

The problem initially had higher constraints, and the answer didn’t fit in a 64-bit integer - so the mod was necessary; this is also when I wrote the editorial.
When the constraints were lowered I forgot to update the linked code, sorry about that.
Testcases are correct though.