PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: pols_agyi_pols
Tester: kingmessi
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
None
PROBLEM:
You’re given an array A containing N integers.
Count the number of ordered tuples (i, j, k, l) of pairwise distinct indices such that the values A_i\oplus A_j, A_j\oplus A_k, A_k\oplus A_l, A_l\oplus A_i, in this order, can be the sides of a rectangle with positive area.
EXPLANATION:
It is recommended that you read the solution to the easy version first.
Now that A can contain duplicates, the only new issue is that some side lengths may become zero — the easy-version solution never had to account for that.
Rather than trying to modify our solution to account for that midway, it’s easier to just count everything as in the easy version, and then subtract out the ‘bad’ tuples.
Let’s analyze when a side length of 0 can occur.
Recall that x\oplus y = 0 \iff x = y.
So a zero side means that some two adjacent elements of the tuple have equal values.
Looking at cases:
- Suppose all four elements are equal.
Then there’s no valid way of reordering the indices - all XORs will always be 0.
Such a tuple has been counted 24 times, but should not be counted at all.
- Suppose A_i = A_j.
Then, since A_i\oplus A_j = A_k \oplus A_l, we must also have A_k\oplus A_l = 0 meaning A_k = A_l.
So, we have two pairs of equal elements - say (x, x, y, y) (we assume x \neq y, since otherwise it goes back to the first case).
A tuple of the form (x, x, y, y) has been counted 24 times, but only 8 of those orderings are actually valid: the patterns (x, y, x, y) and (y, x, y, x), each with 2 \cdot 2 = 4 ways to assign the actual indices.
Now, let’s subtract these tuples out.
Let f_x denote the number of times element x appears in A.
Then,
- For each element x, there are \displaystyle\binom{f_x}{4} ways to choose four distinct indices that all hold the value x.
Each of these should be subtracted 24 times, so subtract \displaystyle24\cdot \binom{f_x}{4} from the answer.
This takes \mathcal{O}(N) time overall, since we only care about the distinct values x that appear in A.
- For each pair of distinct values x \lt y, there are \displaystyle\binom{f_x}{2} \cdot \binom{f_y}{2} ways of choosing four indices, two holding the value x and two holding the value y.
Each of these tuples has been counted 24 times, but should be counted only 8 times - so subtract 16\cdot\displaystyle\binom{f_x}{2} \cdot \binom{f_y}{2} from the answer.
This takes \mathcal{O}(N^2) time, since we iterate only over pairs of distinct values that actually appear in A, and there are at most N such values.
TIME COMPLEXITY:
\mathcal{O}(N^2) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
// cnt[v]  : number of index pairs (p, q) with A_p != A_q and A_p ^ A_q == v.
// cnt2[v] : number of occurrences of value v in the current test case.
// Both tables have 2000005 entries, so values and pairwise XORs must stay below
// that (assumes A_i <= 10^6 as in the problem constraints). They are cleared
// after every test case.
long long cnt[2000005];
long long cnt2[2000005];

// Reads T test cases; for each, prints the number of ordered index tuples whose
// XOR side lengths form a rectangle with positive area.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    long long tt = 1;
    cin >> tt;
    while (tt--) {
        long long n;
        cin >> n;
        // std::vector instead of a variable-length array (not standard C++).
        vector<long long> a(n);
        for (auto &value : a) {
            cin >> value;
            cnt2[value]++;
        }
        // Distinct values in increasing order.
        set<long long> s(a.begin(), a.end());
        vector<long long> diff(s.begin(), s.end());
        const long long m = diff.size();

        long long ans = 0;
        long long fact = 0;      // sum of C(f, 2) over the values processed so far
        long long sum = 0;       // sum over value pairs x < y of C(f_x, 2) * C(f_y, 2)
        vector<long long> pre;   // XOR values touched in cnt during this test case
        for (long long i = 0; i < m; i++) {
            const long long fi = cnt2[diff[i]];
            const long long pairsI = fi * (fi - 1) / 2;
            sum += fact * pairsI;
            fact += pairsI;
            for (long long j = i + 1; j < m; j++) {
                const long long x = diff[i] ^ diff[j];
                const long long w = fi * cnt2[diff[j]];  // index pairs for values {diff[i], diff[j]}
                if (cnt[x] == 0) {
                    pre.push_back(x);
                }
                // Pairs of index pairs from the same value pair either share an
                // index or form (v, v, w, w); drop them here, re-add the latter via `sum`.
                ans -= w * (w - 1) / 2;
                cnt[x] += w;
            }
        }
        for (long long x : pre) {
            ans += cnt[x] * (cnt[x] - 1) / 2;
        }
        ans += sum;
        ans *= 8;
        cout << ans << "\n";

        // Clear only the entries this test case touched.
        for (long long x : pre) {
            cnt[x] = 0;
        }
        for (long long v : diff) {
            cnt2[v] = 0;
        }
    }
}
Tester's code (C++)
// Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
// ---- Shorthand macros used throughout the tester's solution ----
// NOTE: `#define int long long` textually rewrites every later `int` token to
// `long long`, which is why this file's entry point is declared `signed main`.
#define ll long long
#define int long long
// Loop helpers: rep counts up over [a, b); rrep counts down from a to b inclusive;
// repin iterates i over [0, n) and requires a variable named `n` in scope.
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
// Switches cout to fixed-point output with i digits after the decimal point.
#define precise(i) cout<<fixed<<setprecision(i)
// Container shorthands (element type is `long long` because of the int macro).
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
// Read / print the first n elements of a (printing writes a space after each element).
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define db double
// Expands to a container's begin/end iterator pair.
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
// Bit helpers. NOTE(review): bpc uses the 64-bit popcount but btz uses the 32-bit
// __builtin_ctz, so btz would truncate a `long long` argument — confirm before use.
#define bpc(x) __builtin_popcountll(x)
#define btz(x) __builtin_ctz(x)
using namespace std;
using namespace __gnu_pbds;
// GNU pb_ds red-black trees with order statistics (find_by_order / order_of_key).
// Despite its name, ordered_multiset stores pairs; presumably duplicates are emulated
// by pairing each value with a unique second component — verify at use sites.
typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;
// Large "infinity" sentinel (10^18) and the two customary prime moduli.
constexpr long long INF = 1'000'000'000'000'000'000LL;
constexpr long long M = 1'000'000'007LL;
constexpr long long MM = 998'244'353LL;
// Computes N^M by binary exponentiation (no modular reduction; the caller must
// ensure the result fits). Returns 1 for M <= 0 when N != 0. As before, N == 0
// always yields 0 — including 0^0.
// The base is squared only while higher exponent bits remain, so the final,
// unused squaring (which could overflow: undefined behaviour for signed integers)
// is never performed.
int power( int N, int M){
    if(N == 0) return 0;
    int base = N, result = 1;
    while(M > 0){
        if(M & 1) result *= base;
        M >>= 1;
        if(M > 0) base *= base;
    }
    return result;
}
// Strict input validator. The constructor slurps all of stdin into `buffer`;
// the read* methods then consume it token by token, asserting the exact
// whitespace layout (single spaces, '\n' line ends, nothing before EOF) and
// the allowed value ranges.
struct input_checker {
    string buffer;  // entire contents of stdin
    int pos;        // index of the next unread character of `buffer`

    // Character classes usable as readString's `pattern` argument.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Index of the first whitespace character at or after `pos` (or the buffer end).
    int nextDelimiter() {
        int now = pos;
        // Cast to unsigned char: isspace on a negative char is undefined behaviour.
        while (now < (int) buffer.size() && !isspace((unsigned char) buffer[now])) {
            now++;
        }
        return now;
    }

    // Consumes and returns the (possibly empty) run of non-whitespace characters
    // at `pos`; the following delimiter is left unread.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token of length in [minl, maxl] whose characters all occur in
    // `pattern` (any characters are accepted when `pattern` is empty).
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Parses a token that must be an optional '-' followed by one or more decimal
    // digits and nothing else. Plain stoi/stoll would silently accept tokens such
    // as "12abc" (stopping at the first non-digit) or "+12".
    long long parseInteger(const string &token) {
        size_t first = (!token.empty() && token[0] == '-') ? 1 : 0;
        assert(first < token.size());
        for (size_t i = first; i < token.size(); i++) {
            assert(isdigit((unsigned char) token[i]));
        }
        size_t used = 0;
        long long value = stoll(token, &used);
        assert(used == token.size());
        return value;
    }

    // Reads an integer in [minv, maxv]. Parsed as long long so that the bounds
    // check, not a stoi out_of_range exception, decides the result.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        long long res = parseInteger(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return static_cast<int>(res);
    }

    // Reads a long long in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = parseInteger(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n single-space-separated integers in [minv, maxv] (no trailing space).
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads n single-space-separated long longs in [minv, maxv] (no trailing space).
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' ' character.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n' character.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;
// fr[v]: occurrences of value v in the current test case (values are at most 10^6).
int fr[2000001];
// frx[v]: number of index pairs (p, q) with A_p != A_q and A_p ^ A_q == v.
int frx[2000001];
// Number of ways to pick 2 items out of n, i.e. C(n, 2).
int comb(int n){
    const int product = n * (n - 1);
    return product / 2;
}
// Running sum of N over all test cases, checked against the global limit at exit.
int smn = 0;
// Solves one test case: reads N and the array through the strict checker and
// prints the number of ordered index tuples whose XOR side lengths form a
// rectangle with positive area. Uses the global tables fr/frx as scratch space
// and leaves them zeroed on return. (`int` is `long long` throughout this file.)
void solve()
{
    int n = inp.readInt(1,5000);
    smn += n;
    inp.readEoln();
    vector<int> a(n);
    for(int i = 0; i < n; i++){
        a[i] = inp.readInt(0,1000'000);
        if(i == n-1) inp.readEoln();
        else inp.readSpace();
    }

    // Distinct values in increasing order (all different, since they come from a set).
    set<int> s(a.begin(), a.end());
    vector<int> b(s.begin(), s.end());
    for(int x : a) fr[x]++;

    int ans = 0;
    vector<int> oc;  // XOR values touched in frx during this test case
    int m = b.size();
    for(int i = 0; i < m; i++){
        for(int j = i+1; j < m; j++){
            int x = b[i]^b[j];
            int w = fr[b[i]]*fr[b[j]];  // index pairs with values {b[i], b[j]}
            if(frx[x] == 0) oc.push_back(x);
            frx[x] += w;
            // Pairs of index pairs from the same value pair either share an index
            // or form (v, v, w, w); drop them here, the latter is re-added via sm.
            ans -= comb(w);
        }
    }
    for(int x : oc) ans += comb(frx[x]);
    // A 4-set of distinct values with XOR 0 is counted once per each of its 3
    // pairings, so 3 * 8 = 24 orderings.
    ans *= 8;

    // sm = (sum C(f,2))^2 - sum C(f,2)^2 = 2 * sum_{x<y} C(f_x,2) * C(f_y,2),
    // so 4 * sm adds 8 valid orderings per (x, x, y, y) configuration.
    int sm = 0;
    for(int x : s) sm += comb(fr[x]);
    sm *= sm;
    for(int x : s) sm -= comb(fr[x])*comb(fr[x]);
    ans += sm*4;
    cout << ans << "\n";

    // Clear only the entries touched by this test case.
    for(int x : oc) frx[x] = 0;
    for(int x : a) fr[x] = 0;
}
// Entry point: reads the number of test cases through the strict checker, solves
// each one, then asserts that the input was fully consumed and that the sum of N
// over all test cases is at most 5000.
signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
#ifdef NCR
    init();
#endif
#ifdef SIEVE
    sieve();
#endif
    int testCases = inp.readInt(1,1000);
    inp.readEoln();
    for(; testCases > 0; --testCases){
        solve();
    }
    inp.readEof();
    assert(smn <= 5000);
    return 0;
}
Editorialist's code (Python)
import sys

# Values are at most 10^6 < 2^20, so every XOR of two values is below 2^21.
M = 1 << 21
# freq[v]: number of index pairs (p, q), p < q, seen so far with a[p] ^ a[q] == v.
freq = [0] * M
# ele_freq[v]: number of occurrences of value v in the current array.
ele_freq = [0] * M


def count_rectangles(a):
    """Count ordered tuples of distinct indices (i, j, k, l) of `a` whose XORs
    a[i]^a[j], a[j]^a[k], a[k]^a[l], a[l]^a[i] are the sides of a rectangle with
    positive area.

    Every 4-element index set with XOR zero is first counted 24 times (all
    orderings). Then orderings containing a zero side are removed: all 24 for four
    equal values, and 16 of the 24 for a multiset (x, x, y, y) with x != y.

    The module-level tables `freq` and `ele_freq` are used as scratch space and
    are cleared again before returning.
    """
    n = len(a)
    quads = 0
    for i in range(n):
        ai = a[i]
        for j in range(i):
            freq[ai ^ a[j]] += 1
        # Pair (i+1, j) for every j > i+1 with all pairs lying entirely before i+1,
        # so each 4-set p < q < r < s is counted exactly once via (p, q), (r, s).
        if i + 1 < n:
            b = a[i + 1]
            for j in range(i + 2, n):
                quads += freq[b ^ a[j]]
    ans = 24 * quads

    for x in a:
        ele_freq[x] += 1
    counts = [ele_freq[v] for v in set(a)]
    for idx, x in enumerate(counts):
        # 24 * C(x, 4): four equal values never form a rectangle.
        ans -= x * (x - 1) * (x - 2) * (x - 3)
        for y in counts[idx + 1:]:
            # 16 * C(x, 2) * C(y, 2): only 8 of the 24 orderings of (x, x, y, y) are valid.
            ans -= 4 * x * (x - 1) * y * (y - 1)

    # Clear exactly the entries touched above.
    for x in a:
        ele_freq[x] = 0
    for i in range(n):
        ai = a[i]
        for j in range(i):
            freq[ai ^ a[j]] = 0
    return ans


def main():
    """Read all test cases from stdin and print one answer per line."""
    data = sys.stdin.buffer.read().split()
    t = int(data[0])
    pos = 1
    out = []
    for _ in range(t):
        n = int(data[pos])
        pos += 1
        a = list(map(int, data[pos:pos + n]))
        pos += n
        out.append(f"{count_rectangles(a)}\n")
    sys.stdout.write("".join(out))


if __name__ == "__main__":
    main()