PROBLEM LINK:
Setter: Ansh Gupta
Tester: Aryan Choudhary
Editorialist: Ajit Sharma Kasturi
DIFFICULTY:
EASY
PREREQUISITES:
Modular Multiplicative Inverse, Bitwise xor
PROBLEM:
We are given an array A of N integers A_1, A_2, \dots, A_N and an array B of M integers B_1, B_2, \dots, B_M. We need to find the total number of ordered pairs (i, j) for which all of the following conditions hold:
- 1 \leq i \leq N
- 1 \leq j \leq M
- P divides (A_i \cdot (A_i \oplus B_j) - 1)
- (A_i \oplus B_j) \lt P
EXPLANATION:
- The key idea of this problem is to iterate over i from 1 to N and, for each i, count the indices j for which all the conditions hold.
- Suppose we fix some i for which A_i \mod P = 0. Then A_i \cdot (A_i \oplus B_j) - 1 \equiv -1 \pmod P, so the third condition would require P to divide -1, which is impossible since P \geq 2. Thus, no j works for such an i.
- Now fix some i for which A_i \mod P \neq 0. From the third condition, we have
\hspace{0.8cm} ( A_i \cdot (A_i \oplus B_j) -1) \mod P = 0
\implies ( A_i \cdot (A_i \oplus B_j) ) \mod P = 1
\implies (A_i \oplus B_j) \mod P = A_i^{-1} \mod P
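\implies (A_i \oplus B_j) = A_i^{-1} \mod P (because the fourth condition gives (A_i \oplus B_j) \lt P, so (A_i \oplus B_j) equals its own remainder modulo P)
- Since P is a prime, the modular inverse of A_i exists whenever A_i \mod P \neq 0, and by Fermat's little theorem it equals A_i^{P-2} \mod P. Because XOR is its own inverse, A_i \oplus B_j = v is equivalent to B_j = A_i \oplus v. Hence the only valid value is B_j = A_i \oplus (A_i^{-1} \mod P). We therefore precompute the frequency of every value in B (with a map or hash map). Then, for each i with A_i \mod P \neq 0, we add the frequency of A_i \oplus (A_i^{-1} \mod P) to the answer.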
TIME COMPLEXITY:
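Building the frequency table of B takes O(M) expected time with a hash map, or O(M \log M) with an ordered map. Each of the N elements of A then needs one modular exponentiation (O(\log P)) and one lookup. The total per test case is therefore O(M + N \log P) expected with a hash map, or O((N + M) \log M + N \log P) with an ordered map such as std::map.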
SOLUTION:
Editorialist's solution
#include <bits/stdc++.h>
#define ll long long int
using namespace std;
// Computes x^y mod p by recursive binary exponentiation:
// x^y = (x^(y/2))^2, times one extra factor of x when y is odd.
// Products are widened to long long before reduction so that two
// operands below p (p < 2^31) cannot overflow.
int power(int x, int y, int p) {
    if (y == 0) {
        return 1;
    }
    const int half = power(x, y / 2, p);
    const int squared = static_cast<int>(1LL * half * half % p);
    return (y & 1) ? static_cast<int>(1LL * squared * x % p) : squared;
}
// Modular inverse of x modulo a prime p, using Fermat's little theorem:
// x^(p-2) * x = x^(p-1) = 1 (mod p) whenever p does not divide x.
int inverse (int x, int p) {
    const int fermat_exponent = p - 2;
    return power(x, fermat_exponent, p);
}
// Reads T test cases. For each one, reads N, M, P, then the N values of A
// and the M values of B, and prints the number of ordered pairs (i, j) with
// P | (A_i * (A_i ^ B_j) - 1) and (A_i ^ B_j) < P.
int main()
{
    // Fast stream I/O: the input may contain many integers.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int tests;
    std::cin >> tests;
    while (tests--) {
        int n, m, p;
        std::cin >> n >> m >> p;

        std::vector<int> a(n);
        for (int& x : a) {
            std::cin >> x;
        }

        // Frequency of every value in B; B itself is not needed afterwards.
        std::map<int, int> cnt;
        for (int j = 0; j < m; j++) {
            int value;
            std::cin >> value;
            cnt[value]++;
        }

        long long ans = 0;
        for (const int x : a) {
            // If P | A_i, then A_i * (...) - 1 = -1 (mod P), never divisible by P.
            if (x % p == 0) {
                continue;
            }
            // The only B_j that works is A_i ^ (A_i^{-1} mod P).
            const int req = inverse(x, p) ^ x;
            // find() instead of operator[] so misses do not insert entries.
            const auto it = cnt.find(req);
            if (it != cnt.end()) {
                ans += it->second;
            }
        }
        // '\n' instead of std::endl: no flush per test case.
        std::cout << ans << '\n';
    }
    return 0;
}
Setter's solution
#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define all(x) (x).begin(),(x).end()
#define test int t; cin>>t; while(t--)
#define noob ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
typedef long double ld;
#define inf 1e17
#define endl '\n'
// Modulus used by binpow(); solve() assigns the current test case's P to it
// before computing any inverse.
ll mod;
// 1e9 + 7; not referenced anywhere in this solution.
ll modul = 1e9 + 7;
// Returns the larger of i and j (j when they compare equal).
long long max(long long i, long long j) {
    return (i > j) ? i : j;
}
// Returns the smaller of i and j (j when they compare equal).
long long min(long long i, long long j) {
    return (i < j) ? i : j;
}
// Returns a^b modulo the global `mod` using square-and-multiply.
// Reads the global `mod`, which solve() sets to P before calling this.
long long binpow(long long a, long long b) {
    long long base = a % mod;
    long long result = 1;
    for (long long exponent = b; exponent > 0; exponent >>= 1) {
        if (exponent & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return result;
}
void solve() {
ll n, m, p;
cin >> n >> m >> p;
mod = p;
map<ll, ll> cnt;
ll ans = 0;
ll a[n], b[m];
for (ll i = 0; i < n; i++) {
cin >> a[i];
}
for (ll i = 0; i < m; i++) {
cin >> b[i];
cnt[b[i]]++;
}
for (ll i = 0; i < n; i++) {
if ( (a[i] % p) == 0 )continue;
ll value_needed = binpow(a[i], mod - 2);
value_needed = (value_needed ^ a[i]);
ans = (ans + cnt[value_needed]);
}
cout << ans << endl;
}
signed main()
{
noob
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
test
solve();
}
Tester's solution
def main():
    """Solve every test case read from standard input.

    Input (whitespace-separated): T, then for each test case N M P,
    followed by the N values of A and the M values of B.
    For each test case, print the number of ordered pairs (i, j) with
    P | (A_i * (A_i ^ B_j) - 1) and (A_i ^ B_j) < P.
    """
    import sys
    from collections import Counter

    tokens = map(int, sys.stdin.read().split())
    results = []
    for _ in range(next(tokens)):
        n, m, p = next(tokens), next(tokens), next(tokens)
        a = [next(tokens) for _ in range(n)]
        # Frequency of each value in B; Counter returns 0 for missing keys.
        freq = Counter(next(tokens) for _ in range(m))
        ans = 0
        for x in a:
            # If P | A_i, then A_i * (...) - 1 = -1 (mod P), never divisible.
            if x % p == 0:
                continue
            # Fermat: pow(x, P-2, P) is x's inverse modulo the prime P;
            # the only matching B_j is that inverse XOR A_i.
            ans += freq[pow(x, p - 2, p) ^ x]
        results.append(str(ans))
    sys.stdout.write("".join(line + "\n" for line in results))
# Script entry point: run the solver on standard input.
main()
Please comment below if you have any questions, alternate solutions, or suggestions.