XSQRH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

You’re given an array A containing N integers.
Count the number of ordered tuples (i, j, k, l) of pairwise distinct indices such that the values A_i\oplus A_j, A_j\oplus A_k, A_k\oplus A_l, A_l\oplus A_i, in that order, can be the sides of a rectangle with positive area.
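For small arrays the definition can be checked directly. The following is a minimal brute-force sketch, useful only for testing since it is O(N^4). It assumes the four indices must be pairwise distinct. Four lengths in order form a rectangle exactly when opposite sides are equal, and the area is positive exactly when no side is 0. The name `brute_force` is hypothetical.

```python
from itertools import permutations

def brute_force(a):
    """Count ordered tuples (i, j, k, l) of distinct indices whose cyclic XORs form a rectangle."""
    count = 0
    for i, j, k, l in permutations(range(len(a)), 4):
        sides = (a[i] ^ a[j], a[j] ^ a[k], a[k] ^ a[l], a[l] ^ a[i])
        # opposite sides must match, and a zero side would give zero area
        if sides[0] == sides[2] and sides[1] == sides[3] and min(sides) > 0:
            count += 1
    return count

print(brute_force([2, 4, 2, 4]))  # 8
```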

EXPLANATION:

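It is recommended that you read the solution to the easy version first.
In short: the four values form a rectangle exactly when opposite sides are equal, and A_i\oplus A_j = A_k\oplus A_l \iff A_i\oplus A_j\oplus A_k\oplus A_l = 0 \iff A_j\oplus A_k = A_l\oplus A_i. So the easy version counts sets of four indices whose values XOR to 0, and each such set gives 24 ordered tuples.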

Now that A can contain duplicates, the only new issue is that some side lengths can become zero, since the easy-version count never checks for that.
Rather than modifying that solution to handle this midway, it’s easier to count everything exactly as in the easy version, and then subtract the ‘bad’ tuples.
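The easy-version count can be sketched as below. This is a minimal sketch modelled on the editorialist's loop, with a dictionary in place of the fixed-size array; `base_count` is a hypothetical name. A pair (p, q) is recorded only once q ≤ i, and it is matched against pairs (i+1, s) with s > i+1, so the two pairs never share an index. Each index set p < q < r < s with A_p⊕A_q = A_r⊕A_s is therefore found exactly once, and contributes 24 orderings. Zero sides are deliberately not excluded here.

```python
def base_count(a):
    """Easy-version count: 24 times the number of index sets p < q < r < s
    with a[p] ^ a[q] == a[r] ^ a[s]. Zero sides are NOT excluded."""
    n = len(a)
    seen = {}  # XOR value -> number of pairs (p, q) with p < q <= i
    quads = 0
    for i in range(n):
        for p in range(i):
            v = a[p] ^ a[i]
            seen[v] = seen.get(v, 0) + 1
        # the second pair (i + 1, s) starts after every recorded pair ends,
        # so the two pairs never overlap
        for s in range(i + 2, n):
            quads += seen.get(a[i + 1] ^ a[s], 0)
    return 24 * quads

print(base_count([2, 4, 2, 4]))  # 24
```

For A = [2, 4, 2, 4] this gives 24, which still includes the 16 orderings that contain a zero side.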

Let’s analyze when a side length of 0 can occur.
Recall that x\oplus y = 0 \iff x = y.
So a zero side means that some two cyclically adjacent indices of the tuple hold equal values.
Looking at cases:

  1. Suppose all four values are equal.
    Then no ordering of the indices is valid: every XOR is 0.
    Such a set of indices has been counted 24 times, but should be counted 0 times.
  2. Suppose (after rotating the tuple if necessary) that A_i = A_j.
    Then, since A_i\oplus A_j = A_k \oplus A_l, we must also have A_k\oplus A_l = 0, meaning A_k = A_l.
    So the values form two equal pairs, say (x, x, y, y), where x \neq y (otherwise we are back in the first case).
    A set of four indices with values x, x, y, y has been counted 24 times (once per ordering), but only 8 of those orderings are valid: the value patterns (x, y, x, y) and (y, x, y, x), each with 2\cdot 2 = 4 ways to place the actual indices.
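The case analysis can be confirmed by enumerating all 24 orderings of four indices with given values. This is a small sketch; `valid_orderings` is a hypothetical helper that applies the same rectangle test as the problem statement (opposite sides equal, no zero side).

```python
from itertools import permutations

def valid_orderings(values):
    """Count orderings of four indices (holding `values`) whose cyclic XORs form a rectangle."""
    count = 0
    for i, j, k, l in permutations(range(4)):
        s = (values[i] ^ values[j], values[j] ^ values[k],
             values[k] ^ values[l], values[l] ^ values[i])
        if s[0] == s[2] and s[1] == s[3] and min(s) > 0:
            count += 1
    return count

print(valid_orderings([7, 7, 7, 7]))  # 0: all four values equal
print(valid_orderings([3, 3, 5, 5]))  # 8: pattern (x, x, y, y)
print(valid_orderings([1, 2, 4, 7]))  # 24: distinct values with XOR 0
```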

Now, let’s subtract these tuples out.
Let f_x denote the number of times element x appears in A.
Then,

  • For each distinct value x, there are \displaystyle\binom{f_x}{4} ways to choose four distinct indices that all hold the value x.
    Each of these should be subtracted 24 times, so subtract \displaystyle24\cdot \binom{f_x}{4} from the answer.
    This takes \mathcal{O}(N) time overall, since we only iterate over the distinct values x that appear in A.
  • For each pair of distinct values x \lt y, there are \displaystyle\binom{f_x}{2} \cdot \binom{f_y}{2} ways of choosing four indices, two holding x and two holding y.
    Each of these sets has been counted 24 times but should be counted only 8 times, so subtract 16\cdot\displaystyle\binom{f_x}{2} \cdot \binom{f_y}{2} from the answer.
    This takes \mathcal{O}(N^2) time, since we iterate only over pairs of values that actually occur in A.
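Both subtractions can be sketched as a single function of the value frequencies; `bad_tuple_correction` is a hypothetical name, and `math.comb` assumes Python 3.8 or later. As a design choice, it keeps a running sum of \binom{f}{2} over the values already seen instead of looping over all pairs x \lt y, which makes this step linear in the number of distinct values. The author's code uses the same idea with its `fact` and `sum` variables; the pair loop described above is equally correct.

```python
from collections import Counter
from math import comb

def bad_tuple_correction(a):
    """Amount to subtract from the easy-version count because of zero sides."""
    total = 0
    pairs_so_far = 0  # sum of C(f, 2) over the values already processed
    for f in Counter(a).values():
        # four indices all holding this value: counted 24 times, valid 0 times
        total += 24 * comb(f, 4)
        # values (x, x, y, y), y = this value, x = an earlier value: counted 24, valid 8
        total += 16 * pairs_so_far * comb(f, 2)
        pairs_so_far += comb(f, 2)
    return total

print(bad_tuple_correction([2, 4, 2, 4]))  # 16
```

For A = [2, 4, 2, 4], the easy-version count is 24 and the correction is 16, leaving the 8 valid tuples.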

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long

ll cnt[2000005];
ll cnt2[2000005];

int main() {
	ll tt=1;
    cin>>tt;
    while(tt--){
        ll n;
        cin>>n;
        ll a[n];
        set <ll> s;
        for(int i=0;i<n;i++){
            cin>>a[i];
            cnt2[a[i]]++;
            s.insert(a[i]);
        }
        ll ans=0;
        vector <ll> diff;
        for(auto it:s){
            diff.push_back(it);
        }
        ll m=diff.size();
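        ll fact=0;   // running sum of C(f,2) over the values processed so far
        ll sum=0;    // sum over value pairs x<y of C(f_x,2)*C(f_y,2)
        vector <ll> pre;   // XOR values whose cnt[] entry was touched
        for(int i=0;i<m;i++){
            sum+=fact*(cnt2[diff[i]]*(cnt2[diff[i]]-1))/2;
            fact+=(cnt2[diff[i]]*(cnt2[diff[i]]-1))/2;
            for(int j=i+1;j<m;j++){
                if(cnt[diff[i]^diff[j]]==0){
                    pre.push_back(diff[i]^diff[j]);
                }
                // drop every pair of index pairs that both use the values {x, y}:
                // they either share an index or form an (x, x, y, y) set
                ans-=((cnt2[diff[i]]*cnt2[diff[j]])*(cnt2[diff[i]]*cnt2[diff[j]]-1))/2;
                // cnt[v] = number of index pairs with different values and XOR v
                cnt[diff[i]^diff[j]]+=cnt2[diff[i]]*cnt2[diff[j]];
            }
        }
        // pairs of index pairs with the same nonzero XOR; a 4-set with distinct values
        // is counted 3 times here (once per way of splitting it into two pairs)
        for(auto it:pre){
            ans+=(cnt[it]*(cnt[it]-1))/2;
        }
        ans+=sum;   // add back each (x, x, y, y) set once
        ans*=8;     // 3*8 = 24 orderings for distinct values, 1*8 = 8 for (x, x, y, y)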
        cout<<ans<<"\n";
        for(int i=0;i<m;i++){
            cnt2[diff[i]]=0;
            for(int j=i+1;j<m;j++){
                cnt[diff[i]^diff[j]]=0;
            }
        }
    }
}
Tester's code (C++)
// Har Har Mahadev
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // Common file
#include <ext/pb_ds/tree_policy.hpp>
#define ll long long
#define int long long
#define rep(i,a,b) for(int i=a;i<b;i++)
#define rrep(i,a,b) for(int i=a;i>=b;i--)
#define repin rep(i,0,n)
#define precise(i) cout<<fixed<<setprecision(i)
#define vi vector<int>
#define si set<int>
#define mii map<int,int>
#define take(a,n) for(int j=0;j<n;j++) cin>>a[j];
#define give(a,n) for(int j=0;j<n;j++) cout<<a[j]<<' ';
#define vpii vector<pair<int,int>>
#define db double
#define be(x) x.begin(),x.end()
#define pii pair<int,int>
#define pb push_back
#define pob pop_back
#define ff first
#define ss second
#define lb lower_bound
#define ub upper_bound
#define bpc(x) __builtin_popcountll(x) 
#define btz(x) __builtin_ctz(x)
using namespace std;

using namespace __gnu_pbds;

typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int, int>, null_type,less<pair<int, int> >, rb_tree_tag,tree_order_statistics_node_update> ordered_multiset;

const long long INF=1e18;
const long long M=1e9+7;
const long long MM=998244353;
  
int power( int N, int M){
    int power = N, sum = 1;
    if(N == 0) sum = 0;
    while(M > 0){if((M & 1) == 1){sum *= power;}
    power = power * power;M = M >> 1;}
    return sum;
}

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
}inp;


int fr[2000001];
int frx[2000001];

int comb(int n){
    return n*(n-1)/2;
}

int smn = 0;
 
void solve()
{
    int n;
    // cin >> n;
    n = inp.readInt(1,5000);
    smn += n;
    inp.readEoln();
    vi a(n);
    // take(a,n);
    repin{
        a[i] = inp.readInt(0,1000'000);
        if(i == n-1)inp.readEoln();
        else inp.readSpace();
    }
    int mx = (*max_element(be(a)));
    si s(be(a));
    vi b;
    for(auto x : s)b.pb(x);
    repin{
        fr[a[i]]++;
    }
    int ans = 0;
    vi oc;
    int m = b.size();
    rep(i,0,m){
        rep(j,i+1,m){
            if(b[i] == b[j])continue;
            if(frx[b[i]^b[j]] == 0)oc.pb(b[i]^b[j]);
            frx[b[i]^b[j]] += fr[b[i]]*fr[b[j]];
            ans -= comb(fr[b[i]]*fr[b[j]]);
        }
    }

    for(auto x : oc){
        ans += comb(frx[x]);
    }

    ans *= 8;

    int sm = 0;
    for(auto x : s){
        sm += comb(fr[x]);
    }
    sm *= sm;
    for(auto x : s){
        sm -= comb(fr[x])*comb(fr[x]);
    }

    ans += sm*4;
    
    cout << ans << "\n";
    
    rep(i,0,m){
        rep(j,i+1,m){
            if(b[i] == b[j])continue;
            frx[b[i]^b[j]] -= fr[b[i]]*fr[b[j]];
        }
    }
    repin{
        fr[a[i]]--;
    }
}

signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    #ifdef NCR
        init();
    #endif
    #ifdef SIEVE
        sieve();
    #endif
    int t;
    // cin >> t;
    t = inp.readInt(1,1000);
    inp.readEoln();
    while(t--)
        solve();
    inp.readEof();
    assert(smn <= 5000);
    return 0;
}
Editorialist's code (Python)
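M = 1 << 21                    # values are at most 10^6 < 2^20, so every XOR is below M
freq = [0]*M                   # freq[v]: pairs (p, q), p < q <= i, with a[p] ^ a[q] == v
ele_freq = [0]*M               # ele_freq[x]: number of occurrences of x in a

for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    ans = 0
    # Easy-version count of index sets p < q < r < s with a[p]^a[q] == a[r]^a[s].
    # The pair (i+1, j) starts after every recorded pair ends, so the pairs never overlap.
    for i in range(n):
        for j in range(i):
            freq[a[i] ^ a[j]] += 1
        for j in range(i+2, n): ans += freq[a[i+1] ^ a[j]]
    ans *= 24                  # each such index set gives 24 ordered tuples
    for x in a: ele_freq[x] += 1

    distinct = list(set(a))
    sz = len(distinct)
    for i in range(sz):
        x = ele_freq[distinct[i]]
        ans -= x*(x-1)*(x-2)*(x-3)           # 24 * C(x, 4): all four values equal
        for j in range(i+1, sz):
            y = ele_freq[distinct[j]]
            ans -= 4 * x*(x-1) * y*(y-1)     # 16 * C(x, 2) * C(y, 2): values x, x, y, y
    print(ans)

    # reset only the entries this testcase touched
    for x in a: ele_freq[x] = 0
    for x in a:
        for y in a:
            freq[x^y] = 0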


My solution is indeed O(n^2).

But during the contest, it kept getting TLE.
After the contest, the same code got ACCEPTED.
I guess the online judge’s load was pretty high during the contest.


PyPy 3 is accepted, but not Python 3.


Some accepted solutions now also show time limit exceeded … my code is not accepted.

I think the solutions need to be rejudged; a lot of code was not accepted because of the compiler.

Here it says we are removing the bad tuples, and that’s fine, but what about the actual counting?

Suppose we have 2 4 2 4. Then the index pairs (1,2) and (2,3) would also be formed (in the two loops), and both have XOR 6. But here i, j, k, l are 1, 2, 2, 3, which are not pairwise distinct. How are we accounting for that? They shouldn’t be there…

In the easy version there are no duplicates, so this case can’t arise and we can directly use brute force to count. Here such cases do exist, so where are they removed?

These types of quadruples are not counted in the first place, since 1^2 != 2^3.
We are not checking something like 1^2^3 == 0; instead, we are checking whether
a^b = c^d, in which case a^b^c^d = 0 definitely holds.

I meant (1,2) and (2,3) as indices.

Even if they are indices, they are never counted. Take this example:
indices - [1,2,2,3]
arr - [1,2,3]
arr[1]^arr[2] != arr[2]^arr[3]

Bruh, read my first post again. I mentioned the array 2 4 2 4, and (1,2), (2,3) are indices into that array; that is what I meant.