VKMPAIRS - Editorial



Setter: Ansh Gupta
Tester: Aryan Choudhary
Editorialist: Ajit Sharma Kasturi




Prerequisites: Modular Multiplicative Inverse, Bitwise XOR


We are given an array A of N integers A_1, A_2, \dots, A_N and an array B of M integers B_1, B_2, \dots, B_M. We need to count the ordered pairs (i, j) for which all of the following conditions hold:

  • 1 \leq i \leq N

  • 1 \leq j \leq M

  • P divides ( A_i \cdot (A_i \oplus B_j) -1)

  • (A_i \oplus B_j) \lt P


  • The key idea of this problem is to iterate over i from 1 to N and find the number of possible j for which the given conditions hold true.

  • Suppose we fix some i for which A_i \mod P = 0. Then A_i \cdot (A_i \oplus B_j) \mod P = 0 for every j, so the third condition would require -1 \mod P = 0, which is impossible since P \geq 2. Thus, no valid j exists in this case.

  • Let us fix some i for which A_i \mod P \neq 0 . Now, from the third condition, we have
    ( A_i \cdot (A_i \oplus B_j) -1) \mod P = 0
    \implies ( A_i \cdot (A_i \oplus B_j) ) \mod P = 1
    \implies (A_i \oplus B_j) \mod P = A_i^{-1} \mod P
    \implies (A_i \oplus B_j) = A_i^{-1} \mod P (by the fourth condition, since 0 \leq A_i \oplus B_j \lt P)

  • Since P is a prime, the modular inverse of A_i always exists when A_i \mod P \neq 0, and it can be computed with the help of Fermat's little theorem as A_i^{P-2} \mod P. XORing both sides of the last equation with A_i gives B_j = (A_i \oplus (A_i^{-1} \mod P)). So we precompute how many times each value occurs in B and, for every such i, add the count of this required value to the answer.
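
The steps above can be sketched as a small C++ helper. This is only an illustration under stated assumptions: the names mod_pow and count_pairs are hypothetical (they do not appear in the solutions below), all values are assumed to be non-negative, and P is assumed to be a prime below 2^31 so that intermediate products fit in 64 bits. For example, with P = 7 and A_i = 3 the inverse is 5 (since 3 \cdot 5 = 15 \equiv 1 \pmod 7), so the only matching value is B_j = 3 \oplus 5 = 6.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Returns (base^exp) mod p using binary exponentiation.
// Assumption: p < 2^31, so every product below fits in 64 bits.
std::int64_t mod_pow(std::int64_t base, std::int64_t exp, std::int64_t p) {
    std::int64_t result = 1 % p;
    base %= p;
    while (exp > 0) {
        if (exp & 1) result = result * base % p;
        base = base * base % p;
        exp >>= 1;
    }
    return result;
}

// Counts ordered pairs (i, j) such that (a[i] xor b[j]) < p and
// p divides a[i] * (a[i] xor b[j]) - 1.
// Assumptions: p is prime and all values are non-negative.
std::int64_t count_pairs(const std::vector<std::int64_t>& a,
                         const std::vector<std::int64_t>& b,
                         std::int64_t p) {
    assert(p >= 2);

    // Frequency of every value in b (hash map, expected O(1) per access).
    std::unordered_map<std::int64_t, std::int64_t> freq;
    for (std::int64_t x : b) ++freq[x];

    std::int64_t total = 0;
    for (std::int64_t x : a) {
        if (x % p == 0) continue;                 // no inverse, so no valid j
        std::int64_t inv = mod_pow(x, p - 2, p);  // Fermat's little theorem
        auto it = freq.find(x ^ inv);             // the only b[j] that can match
        if (it != freq.end()) total += it->second;
    }
    return total;
}
```

Calling count_pairs({3, 3}, {6, 6, 1}, 7) counts each of the two copies of 3 against each of the two copies of 6. Using find instead of operator[] avoids inserting new keys into the map during the lookups.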


The time complexity is O(N \log P) per test case for computing the modular inverses. If the counts of B are stored in an ordered map instead of a hash map, building and querying the map adds O((N + M) \log M), for a total of O(N \log P + (N + M) \log M) per test case.


Editorialist's solution
#include <bits/stdc++.h>
#define ll long long int
using namespace std;

// Binary exponentiation: returns x^y mod p.
int power(int x, int y, int p) {
      if (y==0) return 1;
      int temp = power(x, y/2, p);
      temp = (1LL * temp * temp) % p;
      if (y&1) {
            temp = (1LL * temp * x) % p;
      }
      return temp;
}

// Fermat's little theorem: x^(p-2) is the inverse of x modulo a prime p.
int inverse (int x, int p) {
    return power(x, p-2, p);
}

int main()
{
      int tests;
      cin >> tests;
      while (tests--) {
            int n, m, p;
            cin >> n >> m >> p;
            vector<int> a(n), b(m);
            for(int i=0; i<n; i++) {
                  cin >> a[i];
            }
            for(int j=0; j<m; j++) {
                  cin >> b[j];
            }

            map<int, int> cnt;
            ll ans = 0;

            // Precompute how many times each value occurs in b.
            for(int j=0; j<m; j++) {
                  cnt[b[j]]++;
            }

            for(int i=0; i<n; i++) {
                  // No valid j exists when p divides a[i].
                  if (a[i]%p == 0) {
                        continue;
                  }
                  // The only b[j] that can pair with a[i].
                  int req = inverse(a[i], p) ^ a[i];
                  ans += cnt[req];
            }

            cout << ans << endl;
      }
      return 0;
}

Setter's solution

#include "bits/stdc++.h"
using namespace std;
#define ll               long long
#define all(x)            (x).begin(),(x).end()
#define test int t; cin>>t; while(t--)
#define noob ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
typedef long double ld;

#define inf 1e17
#define endl '\n'
ll mod;
ll modul = 1e9 + 7;
ll max(ll i , ll j) {
    if (i > j)return i;
    else return j;
}
ll min(ll i , ll j) {
    if (i < j)return i;
    else return j;
}
// Binary exponentiation: returns a^b modulo the global `mod`.
long long binpow(long long a, long long b) {
    a %= mod;
    long long res = 1;
    while (b > 0) {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
void solve() {
    ll n, m, p;
    cin >> n >> m >> p;
    mod = p;
    map<ll, ll> cnt;
    ll ans = 0;
    vector<ll> a(n), b(m);
    for (ll i = 0; i < n; i++) {
        cin >> a[i];
    }
    for (ll i = 0; i < m; i++) {
        cin >> b[i];
        cnt[b[i]]++;  // frequency of each value in b
    }
    for (ll i = 0; i < n; i++) {
        // No valid j exists when p divides a[i].
        if ( (a[i] % p) == 0 )continue;
        ll value_needed = binpow(a[i], mod - 2);
        value_needed = (value_needed ^ a[i]);
        ans = (ans + cnt[value_needed]);
    }
    cout << ans << endl;
}

signed main()
{
    noob
    // File redirection for local runs only; the ONLINE_JUDGE guard is an assumption.
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif
    test solve();
    return 0;
}


Tester's solution
def main():
    for _ in range(int(input())):
        a=[int(x) for x in input().split()]
        b=[[] for _ in range(n+1)]
        for i,x in enumerate(a):
            if x <=n:
        def solve(a,x):
            if m==0:
                return 0
            for i in range(m-x+1):
                l=a[L]-(a[L-1] if L else -1)
                r=(a[R+1] if R+1<m else n)-a[R]
            # print(ans,x)
            return ans
        print(sum(solve(v,i)*i for i,v in enumerate(b)))


Please comment below if you have any questions, alternate solutions, or suggestions. :slight_smile:


I tried using binary trie here. We insert mod inverse all values of B in this trie. then for given ai, all we need is to see if Ai xor of inverse Ai present in trie or not.
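
One possible reading of the trie idea above, as a hedged sketch: store the values of B in a binary trie and, for each A_i, look up the single value A_i xor inverse(A_i). The class name BinaryTrie, its methods, and the fixed bit width are assumptions made for illustration, not the commenter's actual code. Each leaf keeps a count rather than a yes/no flag, because repeated values in B contribute separate pairs.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Binary trie over fixed-width non-negative integers. Each leaf stores a
// multiplicity so that duplicate values are counted, not just detected.
class BinaryTrie {
public:
    // bits: number of bits per key; every key must be below 2^bits.
    explicit BinaryTrie(int bits) : bits_(bits), nodes_(1) {
        assert(bits >= 1 && bits <= 62);
    }

    // Inserts one occurrence of value.
    void insert(std::uint64_t value) {
        assert((value >> bits_) == 0);
        int cur = 0;
        for (int bit = bits_ - 1; bit >= 0; --bit) {
            int side = static_cast<int>((value >> bit) & 1);
            if (nodes_[cur].child[side] == -1) {
                nodes_[cur].child[side] = static_cast<int>(nodes_.size());
                nodes_.push_back(Node{});
            }
            cur = nodes_[cur].child[side];
        }
        ++nodes_[cur].count;
    }

    // Returns how many times value was inserted (0 if absent or too wide).
    std::int64_t count(std::uint64_t value) const {
        if ((value >> bits_) != 0) return 0;
        int cur = 0;
        for (int bit = bits_ - 1; bit >= 0; --bit) {
            int side = static_cast<int>((value >> bit) & 1);
            cur = nodes_[cur].child[side];
            if (cur == -1) return 0;
        }
        return nodes_[cur].count;
    }

private:
    struct Node {
        int child[2] = {-1, -1};  // indices into nodes_, -1 when missing
        std::int64_t count = 0;   // occurrences of the key ending here
    };
    int bits_;
    std::vector<Node> nodes_;
};
```

Since only exact lookups are needed here, the trie plays the same role as the map in the editorial; each insert and query costs O(bits).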

For each a[i], there exists only one value of a[i] ^ b[j] such that a[i] ^ b[j] < p and a[i] * (a[i] ^ b[j]) - 1 is divisible by p. Let a[i] ^ b[j] = x; then we need to solve the equation a[i] * x - p * k = 1. This can be solved using the extended Euclidean algorithm.
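
The extended Euclidean approach mentioned above could look roughly like the following sketch. The function names ext_gcd and mod_inverse are hypothetical, and the code assumes p \geq 2 and values that fit comfortably in 64 bits. Unlike Fermat's little theorem, this method does not require p to be prime; it reports failure whenever gcd(a, p) \neq 1.

```cpp
#include <cassert>
#include <cstdint>

// Returns gcd(a, b) and sets x, y so that a*x + b*y == gcd(a, b).
std::int64_t ext_gcd(std::int64_t a, std::int64_t b,
                     std::int64_t& x, std::int64_t& y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    std::int64_t x1 = 0, y1 = 0;
    std::int64_t g = ext_gcd(b, a % b, x1, y1);
    // Back-substitute: b*x1 + (a % b)*y1 == g, with a % b == a - (a / b) * b.
    x = y1;
    y = x1 - (a / b) * y1;
    return g;
}

// Returns the unique x in [0, p) with a*x ≡ 1 (mod p), or -1 if no inverse exists.
std::int64_t mod_inverse(std::int64_t a, std::int64_t p) {
    assert(p >= 2);
    std::int64_t x = 0, y = 0;
    std::int64_t g = ext_gcd(((a % p) + p) % p, p, x, y);
    if (g != 1) return -1;      // a and p are not coprime
    return ((x % p) + p) % p;   // normalise into [0, p)
}
```

Here x corresponds to a[i] ^ b[j] in the equation a[i] * x - p * k = 1, so the required b[j] is a[i] ^ mod_inverse(a[i], p) whenever the inverse exists.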


A very good question, upvote++.
Initially I had no idea; then I had a look at the inverse-mod approach in the editorial. Mind-blowing…