GOOD_PERM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: krypto_ray
Testers: iceknight1093, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Math

PROBLEM:

You have two arrays A and B.
Find the number of permutations P such that the following holds:

  • For each i, you can choose either the value A_i + B_{P_i} or A_i - B_{P_i}
  • There must exist a way to make these choices such that all N values obtained are equal.

EXPLANATION:

First, let’s replace each B_i by |B_i|.
This makes reasoning about things easier, and doesn’t actually affect what we can do: flipping the sign of B_j merely swaps the two candidate values A_i + B_j and A_i - B_j, and we are allowed to pick either of them anyway.

For convenience, let’s also sort the arrays A and B, so that A_i \leq A_{i+1} and B_i \leq B_{i+1}. Recall that every B_i is currently \geq 0 as well.

Now, B_N has to be matched with something, say A_i.
There are two cases:

  • Suppose we choose A_i + B_N. Then, if there exists a j such that A_j \lt A_i, it’s impossible to make all the elements equal, since every value obtainable from A_j satisfies A_j \pm B_k \leq A_j + B_N \lt A_i + B_N, no matter which k is chosen.
    This means that in this case, A_i must be the minimum element of the array.
  • Suppose we choose A_i - B_N. Symmetrically, if there exists a j such that A_j \gt A_i, then A_j \pm B_k \geq A_j - B_N \gt A_i - B_N for every k, so A_i must be the maximum element of the array in this case.

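So, there are only two cases to check: making everything equal to A_1 + B_N, or making everything equal to A_N - B_N.
Note that if A_1 + B_N = A_N - B_N, only one check needs to be made to prevent double-counting.
Double-counting can also happen when the two targets differ. Since every B_i \geq 0, the value matched with A_i for target x must be |A_i - x|. A single permutation can therefore work for two different targets x \neq y only if |A_i - x| = |A_i - y| for every i, i.e. A_i = \frac{x+y}{2} for all i. So when all the A_i are equal, both targets admit exactly the same permutations, and only one of them should be counted.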

We now have a subproblem to solve: suppose the final value should be x. How many permutations allow us to achieve this?

This can be solved as follows:

  • Let \text{freq} denote the frequency array of B, so \text{freq}[x] denotes the number of times x occurs in B.
  • Initialize the answer to 1.
  • For each i from 1 to N,
    • If A_i \lt x, we need to match it with x - A_i. Otherwise, we need to match it with A_i - x.
      This is because we made all the B_i \geq 0, which makes this step simpler: the value matched with A_i must be the non-negative difference between x and A_i.
    • In general, we need to match A_i with y = |A_i - x|.
    • There are \text{freq}[y] such choices, so multiply the answer with \text{freq}[y].
    • Then, decrease \text{freq}[y] by 1, to account for the one we used up.

So, for a fixed x, we can compute the number of valid permutations in \mathcal{O}(N) or \mathcal{O}(N\log N).
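There are at most 2 values of x that need to be checked, so compute the count for each distinct candidate and add up the answers to obtain the final answer. (If all A_i are equal, both candidates admit exactly the same permutations, so only one of them should be counted.)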

TIME COMPLEXITY

\mathcal{O}(N) or \mathcal{O}(N\log N) per test case.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
using namespace std;

// Contest-template boilerplate. rng, mod1, inf, pi, eps and pb are not referenced
// anywhere in this solution.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Every `int` below is textually replaced by `long long`, so a product of two values
// already reduced modulo `mod` cannot overflow. This is also why main() is declared
// as int32_t instead of int.
#define int  long long
#define F    first
#define S    second
#define pb   push_back
// "\n" instead of std::endl: avoids flushing the output stream on every line.
#define endl "\n"

const int mod=1e9+7;        // modulus for every printed answer
const int mod1=998244353;
const int inf=1e18;
const long double pi=2*acos(0.0);
const long double eps=1e-9;
const int N=2e5;            // largest index of the factorial table

// fact[i] = i! modulo `mod` for 0 <= i <= N; filled in once by main().
int fact[N+1];

// Binary exponentiation: returns a^b reduced modulo `mod`.
// Scans the bits of b from least significant upwards, squaring the base each step.
int power(int a,int b){
    int result=1;
    for(;b!=0;b/=2){
        if(b&1) result=result*a%mod;
        a=a*a%mod;
    }
    return result;
}

// Binomial coefficient n choose r modulo the prime `mod`: n! / (r! (n-r)!), where the
// denominator is inverted with Fermat's little theorem (x^(mod-2)).
// Expects 0 <= r <= n <= N so that every fact[] index is valid.
int C(int n,int r){
    const int denominator=fact[r]*fact[n-r]%mod;
    return fact[n]*power(denominator,mod-2)%mod;
}

// Counts the permutations that make every A-value equal to `target`.
// cnt_a and cnt_b map each value of A (resp. |B|) to its multiplicity; cnt_b is taken
// by value because it is consumed. Every copy of an A-value `a` must be paired with a
// B-value equal to |target - a|. If k such B-values are still unused and `a` occurs
// m times, there are k*(k-1)*...*(k-m+1) = C(k, m) * m! ordered ways to do that.
// Returns the product of these counts modulo `mod`, or 0 as soon as some A-value
// cannot be matched.
int count_ways(const map<int,int>& cnt_a, map<int,int> cnt_b, int target){
    int ways=1;
    for(const auto& entry:cnt_a){
        const int value=entry.first;
        const int occurrences=entry.second;
        const int required=abs(target-value);
        auto it=cnt_b.find(required);
        if(it==cnt_b.end() || it->second<occurrences){
            return 0;
        }
        ways=(ways*C(it->second,occurrences))%mod;
        ways=(ways*fact[occurrences])%mod;
        it->second-=occurrences;   // these B-values are now used up
    }
    return ways;
}

// Reads one test case and prints the number of valid permutations modulo `mod`.
// Only two final values are possible, min(A) + max|B| and max(A) - max|B|.
// The answer is the sum of the counts for the distinct candidates.
void solve(){
    int n;
    cin>>n;
    map<int,int> cnt_a,cnt_b;
    for(int i=0;i<n;i++){
        int x;
        cin>>x;
        cnt_a[x]++;
    }
    for(int i=0;i<n;i++){
        int x;
        cin>>x;
        // The sign of B_i never matters: both A_i + B_i and A_i - B_i are allowed.
        cnt_b[abs(x)]++;
    }
    // All A equal and all |B| equal: every permutation works. This is handled separately
    // because otherwise the two candidate targets (distinct when max|B| > 0) would both
    // count the same n! permutations.
    if(cnt_a.size()==1 && cnt_b.size()==1){
        cout<<fact[n]<<endl;
        return;
    }
    const int max_b=cnt_b.rbegin()->first;
    const int target_plus=cnt_a.begin()->first+max_b;
    const int target_minus=cnt_a.rbegin()->first-max_b;
    int ans=count_ways(cnt_a,cnt_b,target_plus);
    if(target_minus!=target_plus){   // identical targets must be counted only once
        ans+=count_ways(cnt_a,cnt_b,target_minus);
    }
    cout<<ans%mod<<endl;
}
// Entry point: precomputes factorials, then answers each test case.
// Declared int32_t because `int` is macro-replaced by `long long` in this file.
int32_t main(){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);

    // fact[i] = i! modulo `mod` for every index the table holds.
    fact[0]=1;
    for(int i=1;i<=N;++i) fact[i]=fact[i-1]*i%mod;

    int tests=1;
    cin>>tests;
    while(tests--) solve();   // one call per test case
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

// Strict validator for the test-data format. The whole input is buffered up front and
// then consumed token by token; every read asserts the exact layout (tokens separated by
// single spaces, '\n' at line ends, values within bounds, nothing after the last line).
struct input_checker {
    string buffer;  // the entire input being validated
    int pos;        // index of the next unread character in buffer

    // Character classes for the optional `pattern` argument of readString.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Buffers all of standard input.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back(static_cast<char>(c));
        }
    }

    // Validates `data` instead of standard input (e.g. to test the checker itself).
    explicit input_checker(string data) : buffer(std::move(data)), pos(0) {}

    // Returns the characters from pos up to (not including) the next ' ', '\n' or the
    // end of input. The separator itself is not consumed.
    string readOne() {
        assert(pos < static_cast<int>(buffer.size()));
        string res;
        while (pos < static_cast<int>(buffer.size()) && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token whose length lies in [min_len, max_len] and whose characters all
    // belong to `pattern` (any character is allowed when pattern is empty).
    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= static_cast<int>(res.size()));
        assert(static_cast<int>(res.size()) <= max_len);
        for (int i = 0; i < static_cast<int>(res.size()); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads one token as an int in [min_val, max_val].
    // The whole token must be the number: std::stoi alone would accept "12abc" or "1e5"
    // by parsing only a prefix. This also rejects a trailing '\r' (CRLF line endings).
    // Out-of-range tokens make stoi throw, which aborts validation as well.
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        assert(!token.empty());
        size_t parsed = 0;
        int res = stoi(token, &parsed);
        assert(parsed == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // long long counterpart of readInt, with the same whole-token requirement.
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        assert(!token.empty());
        size_t parsed = 0;
        long long res = stoll(token, &parsed);
        assert(parsed == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads `size` space-separated ints in [min_val, max_val] (no trailing space).
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Reads `size` space-separated long longs in [min_val, max_val] (no trailing space).
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert(static_cast<int>(buffer.size()) > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert(static_cast<int>(buffer.size()) > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert(static_cast<int>(buffer.size()) == pos);
    }
};

// Integer modulo the compile-time modulus `mod`; `value` is kept in [0, mod).
// operator*= multiplies two residues in a long long, so (mod - 1)^2 must fit in 64 bits.
template <long long mod>
struct modular {
    long long value;
    // Non-explicit, so plain integers convert implicitly (e.g. `*this += 1` below).
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;  // C++ % keeps the sign of x
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    // Multiplies by the modular inverse of other.value (extended Euclidean algorithm).
    // With x = other.value the loop keeps, modulo `mod`: a * x = m and b * x = c.
    // When c reaches 0, m = gcd(x, mod); if that gcd is 1, a is the inverse of x.
    // NOTE: if other.value == 0 the loop never runs, a stays 0 and the result is 0
    // (division by zero is not reported).
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    // Binary operators are friends (found via ADL), so an integer operand on either side
    // converts implicitly to modular.
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    // Orders by residue (e.g. so modular values can be keys of ordered containers).
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
// Decimal text of the canonical residue stored in x (a value in [0, mod)).
template <long long mod>
string to_string(const modular<mod>& x) {
    const long long residue = x.value;
    return std::to_string(residue);
}
// Streams the stored residue of x (not any wrapper formatting).
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    stream << x.value;
    return stream;
}
// Reads a (possibly negative or large) integer into x and normalises it into [0, mod).
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    // C++ % keeps the dividend's sign; adding mod and reducing again yields [0, mod).
    x.value = (x.value % mod + mod) % mod;
    return stream;
}

// Modulus for every answer (the prime 10^9 + 7) and the matching modular-integer type.
// Written as a plain literal instead of a C-style cast of a floating-point 1e9.
constexpr long long mod = 1000000007;
using mint = modular<mod>;

// Square-and-multiply: returns a^n in modular arithmetic (1 when n <= 0).
mint power(mint a, long long n) {
    mint result = 1;
    for (; n > 0; n >>= 1) {
        if (n & 1) result *= a;
        a *= a;
    }
    return result;
}

// Factorial and inverse-factorial tables modulo `mod`, grown on demand by C():
// fact[i] = i!, finv[i] = 1 / i!.
vector<mint> fact = {mint(1)};
vector<mint> finv = {mint(1)};

// Binomial coefficient C(n, k) modulo `mod`; 0 when k < 0 or k > n.
// First extends fact/finv so that index n is available.
mint C(int n, int k) {
    if (k < 0 || k > n) {
        return mint(0);
    }
    for (int i = static_cast<int>(fact.size()); i <= n; i++) {
        fact.emplace_back(fact.back() * i);
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

// Validates the whole input and, for each test case, prints the number of permutations
// P for which a choice of A_i + B_{P_i} / A_i - B_{P_i} makes all N values equal,
// modulo `mod`.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1e5);
        in.readEoln();
        sn += n;
        // Enforce sum(n) <= 2e5 as soon as it is exceeded. Since each n <= 1e5, sn stays
        // tiny. Checking only after the loop could let sn overflow (undefined behaviour)
        // on invalid input with up to 1e5 tests of size 1e5.
        assert(sn <= 2e5);
        auto a = in.readLongs(n, -1e9, 1e9);
        in.readEoln();
        auto b = in.readLongs(n, -1e9, 1e9);
        in.readEoln();
        for (int i = 0; i < n; i++) {
            // assert(a[i] != 0);
            // assert(b[i] != 0);
            // Only |B_i| matters, since both A_i + B_i and A_i - B_i are allowed.
            b[i] = abs(b[i]);
        }
        sort(a.begin(), a.end());
        sort(b.begin(), b.end());
        // t = product over distinct values v of |B| of cnt[v]!. For an achievable target,
        // equal B-values can be permuted freely among the A's that need them.
        mint t = 1;
        map<long long, int> cnt;
        for (int i = 0; i < n; i++) {
            cnt[b[i]]++;
        }
        for (const auto& p : cnt) {
            C(p.second, 0);  // called only to grow fact[] up to index p.second
            t *= fact[p.second];
        }
        mint ans = 0;
        // Required-value vectors already counted: two targets with the same vector admit
        // exactly the same permutations and must not be counted twice.
        set<vector<long long>> st;
        // Adds t to ans when target c is achievable, i.e. the multiset {|c - a_i|}
        // equals the multiset |B| (compared as sorted vectors).
        auto Check = [&](long long c) {
            vector<long long> d(n);
            for (int i = 0; i < n; i++) {
                d[i] = abs(c - a[i]);
            }
            auto e = d;  // per-index requirement for the sorted A, kept for dedup
            sort(d.begin(), d.end());
            if (b == d) {
                if (!st.count(e)) {
                    ans += t;
                    st.emplace(e);
                }
            }
        };
        Check(a[0] + b.back());
        Check(a.back() - b.back());
        cout << ans << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7
from collections import Counter

def calc(a, b, target):
	"""Count pairings that turn every element of `a` into `target`, modulo `mod`.

	Each a_i must be matched with a distinct element of `b` equal to |target - a_i|
	(the caller passes b already made non-negative). Equal b-values are
	interchangeable, so the count is the product of the remaining multiplicities.
	"""
	remaining = Counter(b)
	result = 1
	for value in a:
		need = abs(target - value)
		available = remaining[need]
		if available <= 0:
			# No unused b-value fits: the product contains a zero factor.
			return 0
		result = result * available % mod
		remaining[need] = available - 1
	return result
	

for _ in range(int(input())):
	n = int(input())
	a = list(map(int, input().split()))
	b = list(map(int, input().split()))

	# Only |B_i| matters, since both A_i + B_i and A_i - B_i are allowed.
	for i in range(n):
		b[i] = abs(b[i])

	lo, hi, top = min(a), max(a), max(b)
	minus_target, plus_target = hi - top, lo + top
	ans = calc(a, b, minus_target)
	# Count the second candidate only when it is a different target and A is not
	# constant; with constant A both targets admit exactly the same permutations.
	if lo != hi and minus_target != plus_target:
		ans += calc(a, b, plus_target)
	print(ans % mod)