VKMPAIRS - Editorial

PROBLEM LINK:

Practice
Div1
Div2
Div3

Setter: Ansh Gupta
Tester: Aryan Choudhary
Editorialist: Ajit Sharma Kasturi

DIFFICULTY:

EASY

PREREQUISITES:

Modular Multiplicative Inverse, Bitwise xor

PROBLEM:

We are given an array A of N integers A_1, A_2, \dots, A_N, an array B of M integers B_1, B_2, \dots, B_M, and a prime P. We need to find the total number of ordered pairs (i, j) for which all of the following conditions hold:

  • 1 \leq i \leq N

  • 1 \leq j \leq M

  • P divides (A_i \cdot (A_i \oplus B_j) - 1), where \oplus denotes bitwise XOR

  • (A_i \oplus B_j) \lt P
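To make the four conditions concrete, here is a minimal brute-force checker in C++ that tests every pair directly in O(N \cdot M). It is only a sketch for cross-checking a fast solution on small inputs. The name count_pairs_brute is hypothetical, and the code assumes non-negative values and P < 2^{31}, so that intermediate products fit in a long long.

```cpp
#include <cassert>
#include <vector>

// Brute-force reference: checks every pair (i, j) directly in O(N * M).
// Assumes non-negative values and P < 2^31, so (A_i mod P) * x fits in long long.
long long count_pairs_brute(const std::vector<long long>& a,
                            const std::vector<long long>& b, long long p) {
    long long count = 0;
    for (long long ai : a) {
        for (long long bj : b) {
            long long x = ai ^ bj;              // A_i xor B_j
            if (x >= p) continue;               // fourth condition: x < P
            // Third condition: P divides A_i * x - 1, i.e. A_i * x mod P == 1 (P >= 2).
            if ((ai % p) * x % p == 1) ++count;
        }
    }
    return count;
}
```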

EXPLANATION:

  • The key idea of this problem is to iterate over i from 1 to N and find the number of possible j for which the given conditions hold true.

  • Suppose we fix some i with A_i \mod P = 0. Then A_i \cdot (A_i \oplus B_j) - 1 \equiv -1 \pmod P, which is never 0 modulo P since P \geq 2. So the third condition cannot hold, and no j works for this i.

  • Now fix some i with A_i \mod P \neq 0. From the third condition, we have
    (A_i \cdot (A_i \oplus B_j) - 1) \mod P = 0
    \implies (A_i \cdot (A_i \oplus B_j)) \mod P = 1
    \implies (A_i \oplus B_j) \mod P = A_i^{-1} \mod P (multiplying both sides by A_i^{-1})
    \implies A_i \oplus B_j = A_i^{-1} \mod P (by the fourth condition, 0 \leq A_i \oplus B_j \lt P, so reducing modulo P leaves it unchanged)

  • Since P is prime, the modular inverse A_i^{-1} \mod P exists whenever A_i \mod P \neq 0, and by Fermat's little theorem it equals A_i^{P-2} \mod P, which fast exponentiation computes in O(\log P). Because XOR is its own inverse (x \oplus A_i \oplus A_i = x), we get B_j = A_i \oplus (A_i^{-1} \mod P). So we precompute the frequency of every value in B, and for each such i we add the frequency of A_i \oplus (A_i^{-1} \mod P) to the answer.
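The per-i step above can be sketched as follows. This is an illustrative C++ sketch, not the setter's code: mod_pow, mod_inverse_fermat and required_b are hypothetical helper names, and the code assumes P is prime with P < 2^{31} and that all values are non-negative.

```cpp
#include <cassert>

// Fast modular exponentiation: returns base^exp mod p in O(log exp).
// Assumes p < 2^31 so that products of two residues fit in long long.
long long mod_pow(long long base, long long exp, long long p) {
    long long result = 1 % p;
    base %= p;
    while (exp > 0) {
        if (exp & 1) result = result * base % p;
        base = base * base % p;
        exp >>= 1;
    }
    return result;
}

// Fermat's little theorem: for prime p and a % p != 0, a^(p-2) is a^{-1} mod p.
long long mod_inverse_fermat(long long a, long long p) {
    return mod_pow(a, p - 2, p);
}

// Returns the unique B_j value that pairs with a, or -1 if a % p == 0 (no valid j).
long long required_b(long long a, long long p) {
    if (a % p == 0) return -1;
    long long x = mod_inverse_fermat(a, p);  // x = A_i^{-1} mod P, so 0 < x < P
    return a ^ x;                            // XOR is self-inverse: B_j = A_i xor x
}
```

For example, with P = 7 and A_i = 3, Fermat gives 3^{5} \bmod 7 = 5, so the required value is 3 \oplus 5 = 6. Check: 3 \oplus 6 = 5 \lt 7, and 3 \cdot 5 - 1 = 14 is divisible by 7.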

TIME COMPLEXITY:

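O(M + N \log P) expected per test case when the counts of B are stored in a hash map: O(M) to build the map, and O(\log P) per A_i for the modular inverse. With an ordered map instead, building it costs O(M \log M) and each lookup O(\log M), for a total of O(M \log M + N (\log P + \log M)).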
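A sketch of the hash-map variant, under the same assumptions as before (P prime, P < 2^{31}, non-negative values). count_pairs is a hypothetical name. It uses std::unordered_map::find so that lookups of absent keys do not insert zero entries; the solutions below use std::map with operator[], which does insert them.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Returns base^exp mod p; assumes p < 2^31 so products fit in long long.
long long mod_pow(long long base, long long exp, long long p) {
    long long result = 1 % p;
    base %= p;
    while (exp > 0) {
        if (exp & 1) result = result * base % p;
        base = base * base % p;
        exp >>= 1;
    }
    return result;
}

// Counts pairs (i, j) satisfying all conditions, in O(M + N log P) expected time.
long long count_pairs(const std::vector<long long>& a,
                      const std::vector<long long>& b, long long p) {
    std::unordered_map<long long, long long> freq;
    freq.reserve(b.size());
    for (long long bj : b) ++freq[bj];            // frequency of each B value
    long long ans = 0;
    for (long long ai : a) {
        if (ai % p == 0) continue;                // no j can work for this i
        long long need = ai ^ mod_pow(ai, p - 2, p);  // A_i xor A_i^{-1} (Fermat)
        auto it = freq.find(need);                // find() does not insert
        if (it != freq.end()) ans += it->second;
    }
    return ans;
}
```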

SOLUTION:

Editorialist's solution
#include <bits/stdc++.h>
#define ll long long int
using namespace std;

int power(int x, int y, int p) {
      if (y==0) return 1;
      int temp = power(x, y/2, p);
      temp = (1LL * temp * temp) % p;
      if (y&1) {
            temp = (1LL * temp * x) % p;
      }
      return temp;
}

int inverse (int x, int p) {
    return power(x, p-2, p);
}

int main()
{
      int tests;
      cin >> tests;
      while (tests--) {
            int n, m, p;
            cin >> n >> m >> p;
            vector<int> a(n), b(m);
            
            for(int i=0; i<n; i++) {
                  cin >> a[i];
            }
            for(int j=0; j<m; j++) {
                  cin >> b[j];
            }

            map<int, int> cnt;
            ll ans = 0;

            for(int j=0; j<m; j++) {
                 cnt[b[j]]++;
            }

            for(int i=0; i<n; i++) {
                  if (a[i]%p == 0) {
                      continue;
                  }
                  int req = inverse(a[i], p) ^ a[i];
                  ans += cnt[req];
            }

            cout << ans << endl;
      }
      return 0;
}
Setter's solution

#include "bits/stdc++.h"
using namespace std;
#define ll               long long
#define all(x)            (x).begin(),(x).end()
#define test int t; cin>>t; while(t--)
#define noob ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
typedef long double ld;

#define inf 1e17
#define endl '\n'
ll mod;
ll modul = 1e9 + 7;
ll max(ll i , ll j) {
    if (i > j)return i;
    else return j;
}
ll min(ll i , ll j) {
    if (i < j)return i;
    else return j;
}
long long binpow(long long a, long long b) {
    a %= mod;
    long long res = 1;
    while (b > 0) {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
void solve() {
    ll n, m, p;
    cin >> n >> m >> p;
    mod = p;
    map<ll, ll> cnt;
    ll ans = 0;
    vector<ll> a(n), b(m);  // standard C++ instead of variable-length arrays
    for (ll i = 0; i < n; i++) {
        cin >> a[i];
    }
    for (ll i = 0; i < m; i++) {
        cin >> b[i];
        cnt[b[i]]++;
    }
    for (ll i = 0; i < n; i++) {
        if ( (a[i] % p) == 0 )continue;
        ll value_needed = binpow(a[i], mod - 2);
        value_needed = (value_needed ^ a[i]);
        ans = (ans + cnt[value_needed]);
    }
    cout << ans << endl;
}

signed main()
{
    noob
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif

    test
        solve();
}



Tester's solution
def main():
    for _ in range(int(input())):
        n=int(input())
        a=[int(x) for x in input().split()]
        b=[[] for _ in range(n+1)]
        for i,x in enumerate(a):
            if x <=n:
                b[x].append(i)
        def solve(a,x):
            m=len(a)
            if m==0:
                return 0
            ans=0
            for i in range(m-x+1):
                L,R=i,i+x-1
                l=a[L]-(a[L-1] if L else -1)
                r=(a[R+1] if R+1<m else n)-a[R]
                ans+=l*r
            # print(ans,x)
            return ans
        print(sum(solve(v,i)*i for i,v in enumerate(b)))

main()


Please comment below if you have any questions, alternate solutions, or suggestions.


I tried using a binary trie here. We insert all values of B into the trie; then, for a given A_i, all we need is to check whether A_i xor (inverse of A_i) is present in the trie.
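A minimal sketch of the binary-trie lookup described in this comment, assuming all values are below 2^{31} (an assumption, not a stated constraint). BinaryTrie, insert and count_of are hypothetical names. Since only exact matches are queried, the trie does the same job as the map in the editorial, at O(31) per operation.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Binary trie over 31-bit keys that stores how many times each key was inserted.
struct BinaryTrie {
    static constexpr int kBits = 31;         // assumed: 0 <= value < 2^31
    std::vector<std::array<int, 2>> child;   // child[node][bit], -1 if absent
    std::vector<long long> leaf_count;       // multiplicity, nonzero only at leaves

    BinaryTrie() {
        child.push_back({-1, -1});           // node 0 is the root
        leaf_count.push_back(0);
    }

    void insert(long long value) {
        int node = 0;
        for (int bit = kBits - 1; bit >= 0; --bit) {
            int d = static_cast<int>((value >> bit) & 1);
            if (child[node][d] == -1) {
                child[node][d] = static_cast<int>(child.size());
                child.push_back({-1, -1});
                leaf_count.push_back(0);
            }
            node = child[node][d];
        }
        ++leaf_count[node];
    }

    // Number of inserted copies of value (0 if it was never inserted).
    long long count_of(long long value) const {
        int node = 0;
        for (int bit = kBits - 1; bit >= 0; --bit) {
            int d = static_cast<int>((value >> bit) & 1);
            if (child[node][d] == -1) return 0;
            node = child[node][d];
        }
        return leaf_count[node];
    }
};
```

Usage: insert every B_j, then for each A_i with A_i \mod P \neq 0 add count_of(A_i \oplus (A_i^{-1} \mod P)) to the answer.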

For each a[i] with a[i] % p != 0, there exists exactly one value of a[i] ^ b[j] such that a[i] ^ b[j] < p and a[i] * (a[i] ^ b[j]) - 1 is divisible by p. Let a[i] ^ b[j] = x; then we need to solve the equation a[i] * x - p * k = 1, which can be done with the extended Euclidean algorithm.
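A sketch of the extended-Euclid route from this comment, as an alternative to Fermat's little theorem. Unlike Fermat, it only needs \gcd(A_i, P) = 1 rather than P prime. ext_gcd and mod_inverse_ext are hypothetical names; the sketch assumes P \geq 2 and that values fit in a long long.

```cpp
#include <cassert>

// Extended Euclidean algorithm: returns g = gcd(a, b) and sets x, y so that a*x + b*y = g.
long long ext_gcd(long long a, long long b, long long& x, long long& y) {
    if (b == 0) { x = 1; y = 0; return a; }
    long long x1, y1;
    long long g = ext_gcd(b, a % b, x1, y1);
    x = y1;
    y = x1 - (a / b) * y1;
    return g;
}

// Returns the unique x in [0, p) with a*x = 1 (mod p), or -1 if gcd(a, p) != 1.
long long mod_inverse_ext(long long a, long long p) {
    long long x, y;
    long long g = ext_gcd(((a % p) + p) % p, p, x, y);
    if (g != 1) return -1;
    return ((x % p) + p) % p;  // normalize into [0, p)
}
```

The returned x plays the role of a[i] ^ b[j] (with k = -y), so the required b[j] is a[i] ^ x.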


One of the very good questions, upvote++.
Initially I had no idea, then I looked at the editorial's modular-inverse approach. Mind-blowing…