SEARCHBIN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: boaca_andrei
Tester: yash_daga
Editorialist: iceknight1093

DIFFICULTY:

3363

PREREQUISITES:

Knowledge of binary search, recursion, combinatorics

PROBLEM:

A certain binary search process on an array (described in the statement) returns the sum of all the values it encounters along the way, until it either finds the required key or fails to do so.

Answer M queries of the following form:

  • Given x, find the sum of answers returned by the binary search, across all permutations of \{1, 2, \ldots, N\}.
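For experimenting with tiny N, the process can be brute-forced directly. The statement's exact definition is not reproduced in this editorial, so the binsearch below is an assumption reconstructed from the reference solutions: 1-indexed, mid = ⌊(l+r)/2⌋, the search stops and adds x when P_{mid} = x, and otherwise adds P_{mid} and continues left if P_{mid} > x or right if P_{mid} < x. An empty range adds nothing. Both function names are illustrative.

```python
from itertools import permutations


def binsearch(p, l, r, x):
    """Sum of the values visited while searching for x in p[l..r] (1-indexed).

    Assumed reconstruction of the statement's process: stop (adding x) when
    p[mid] == x, otherwise add p[mid] and continue on the side that could
    still hold x. An empty range contributes 0.
    """
    if l > r:
        return 0
    mid = (l + r) // 2
    v = p[mid - 1]  # p is a 0-indexed Python tuple
    if v == x:
        return v
    if v > x:
        return v + binsearch(p, l, mid - 1, x)
    return v + binsearch(p, mid + 1, r, x)


def brute_force(n, x):
    """Sum of binsearch answers over all n! permutations (tiny n only)."""
    return sum(binsearch(p, 1, n, x) for p in permutations(range(1, n + 1)))


print([brute_force(3, x) for x in range(1, 4)])  # prints [19, 20, 21]
```

Small values like these are handy for checking a faster solution.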

EXPLANATION:

First, let’s try to solve a single query.
Suppose we know x, and we want to find the sum of binsearch(P, 1, N, x) across all permutations P of \{1, 2, \ldots, N\}.

Consider the point where the binary search process terminates when searching for x, say i: either an index where x is found, or an empty range reached after the last comparison.
Notice that the path to i is uniquely determined: it has length \mathcal{O}(\log N), and is obtained by a unique sequence of left/right moves, each depending on how P_{mid} compares to x.

In particular, the search terminates at i exactly when every position on this path holds an element that compares correctly with x (smaller, larger, or equal, as the path requires), regardless of what the rest of P looks like.

Suppose we require L elements strictly smaller than x, G elements strictly larger than x, and E elements equal to x (note that E is always 0 or 1, depending on how the search terminates).
Let’s count the contribution of each integer from 1 to N to this configuration.

Consider some 1 \leq y \lt x. In how many permutations does the search for x terminate at i with y appearing on its path?
A bit of combinatorics tells us that:

  • y can be placed in any of the L positions for “less than” elements, giving L ways.
  • The remaining L-1 “less than” positions must receive distinct values from the other x-2 integers below x; they can be chosen and placed in \binom{x-2}{L-1}\cdot (L-1)! ways.
  • The “greater than” positions are similarly filled in \binom{N-x}{G}\cdot G! ways.
  • If E = 1, x must be placed at index i; otherwise x is simply one of the remaining free elements.
  • Finally, all the remaining elements can be permuted in (N-L-G-E)! ways.
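The count above can be checked by enumeration on tiny cases. Only which positions lie on the path matters, not where they are, so this sketch places the L “less than” slots, then the G “greater than” slots, then the E “equal” slot at the front of the array. closed_form and direct_count are illustrative names.

```python
from itertools import permutations
from math import comb, factorial


def closed_form(n, x, L, G, E):
    """L * C(x-2, L-1) * (L-1)! * C(n-x, G) * G! * (n-L-G-E)!, for a fixed y < x."""
    if L == 0 or x < 2 or L + G + E > n:
        return 0  # no "less than" slot, no y < x, or the path does not fit
    return (L * comb(x - 2, L - 1) * factorial(L - 1)
            * comb(n - x, G) * factorial(G) * factorial(n - L - G - E))


def direct_count(n, x, y, L, G, E):
    """Enumerate permutations; the path occupies positions 0 .. L+G+E-1."""
    count = 0
    for p in permutations(range(1, n + 1)):
        less, greater, equal = p[:L], p[L:L + G], p[L + G:L + G + E]
        if (all(v < x for v in less) and all(v > x for v in greater)
                and all(v == x for v in equal) and y in less):
            count += 1
    return count


print(closed_form(4, 3, 2, 1, 0), direct_count(4, 3, 1, 2, 1, 0))  # prints 2 2
```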

So, the overall contribution of y is

y\cdot L\binom{x-2}{L-1} \binom{N-x}{G} (L-1)! G! (N-L-G-E)!

The multiplier is the same for all y, so after computing it once, it can be multiplied by the sum of all y\lt x to get their overall contribution in \mathcal{O}(1).

Similarly, the contribution of all y\gt x can be computed in \mathcal{O}(1) time.
Also, if E = 1 then P_i = x must hold, so the contribution of x itself can be computed in a similar fashion.
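Spelled out, the three per-triplet multipliers are as follows, derived from the counting argument above and matching the terms computed in the author's findme. Here c_{<}, c_{>}, c_{=} and R = N-L-G-E are shorthand introduced only for this restatement, \binom{a}{b} = 0 whenever b < 0 or b > a, and the rewriting L(L-1)! = L! (resp. G(G-1)! = G!) applies for L \geq 1 (resp. G \geq 1); otherwise the corresponding multiplier is 0.

```latex
\begin{aligned}
c_{<} &= L\binom{x-2}{L-1}(L-1)!\,\binom{N-x}{G}G!\,R! = L!\binom{x-2}{L-1}\binom{N-x}{G}G!\,R! && \text{(each fixed } y < x\text{)}\\
c_{>} &= G\binom{N-x-1}{G-1}(G-1)!\,\binom{x-1}{L}L!\,R! = G!\binom{N-x-1}{G-1}\binom{x-1}{L}L!\,R! && \text{(each fixed } y > x\text{)}\\
c_{=} &= \binom{x-1}{L}L!\,\binom{N-x}{G}G!\,R! && \text{(} x \text{ itself, only if } E = 1\text{)}\\
\text{contribution} &= \frac{x(x-1)}{2}\,c_{<} + \left(\frac{N(N+1)}{2} - \frac{x(x+1)}{2}\right)c_{>} + E\,x\,c_{=}
\end{aligned}
```

The last line is the total for one termination point with that triplet.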

Notice that we now have a solution in \mathcal{O}(N) time for a single query: fix each of the 2N+1 possible termination points of the binary search (N indices where x can be found, plus N+1 empty ranges where the search fails), and compute its contribution in \mathcal{O}(1) time as we did above.
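The per-query idea can be sketched as follows, under the search variant assumed by the reference solutions (go left when P_{mid} > x). It uses exact Python integers instead of arithmetic modulo 10^9+7 (which the reference solutions use), so results can be compared with a brute force for tiny N. For readability it calls math.factorial instead of precomputing factorials, so unlike the real solution it is only meant for small N. contribution, answer_linear, and the zero-outside-range binomial C are hypothetical helper names.

```python
from math import comb, factorial


def C(a, b):
    """Binomial coefficient, taken as 0 unless 0 <= b <= a (so also for a < 0)."""
    return comb(a, b) if 0 <= b <= a else 0


def contribution(n, x, L, G, E):
    """Exact total of binsearch answers over the permutations whose search for x
    ends at one fixed termination point with comparison counts (L, G, E)."""
    R = n - L - G - E  # elements off the path, permuted freely
    if R < 0:
        return 0
    c_less = factorial(L) * C(x - 2, L - 1) * C(n - x, G) * factorial(G) * factorial(R)
    c_greater = factorial(G) * C(n - x - 1, G - 1) * C(x - 1, L) * factorial(L) * factorial(R)
    total = (x * (x - 1) // 2) * c_less
    total += (n * (n + 1) // 2 - x * (x + 1) // 2) * c_greater
    if E == 1:  # x itself sits at the final index of the path
        total += x * C(x - 1, L) * factorial(L) * C(n - x, G) * factorial(G) * factorial(R)
    return total


def answer_linear(n, x):
    """Sum of contributions over all 2n + 1 termination points."""
    total = 0

    def walk(l, r, L, G):
        nonlocal total
        if l > r:  # empty range: the search fails here
            total += contribution(n, x, L, G, 0)
            return
        mid = (l + r) // 2
        total += contribution(n, x, L, G, 1)  # stops at mid if P[mid] == x
        walk(l, mid - 1, L, G + 1)  # P[mid] > x: move left
        walk(mid + 1, r, L + 1, G)  # P[mid] < x: move right

    walk(1, n, 0, 0)
    return total


print([answer_linear(3, x) for x in (1, 2, 3)])  # prints [19, 20, 21]
```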


To optimize this further, observe that the actual termination point i didn't matter at all: we only cared about L, G, and E, the results of the comparisons on the way to i.
In particular, if multiple different indices had the same (L, G, E) triplet, they’d all have the same contribution to the answer.

Further, note that there can only be \mathcal{O}(\log^2 N) distinct such triplets at all, since 0 \leq L + G + E \leq \left\lfloor \log_2 N \right\rfloor + 1 (the maximum number of comparisons the binary search makes) and 0 \leq E \leq 1.
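A quick empirical check of this bound: enumerating every termination point, as the reference solutions do, gives 2N+1 of them but very few distinct triplets, and the largest L+G+E equals the recursion depth ⌊log_2 N⌋+1, which is Python's n.bit_length(). triplet_counts is an illustrative name, and the search direction (left when P_{mid} > x) follows the reference solutions.

```python
def triplet_counts(n):
    """Map each (L, G, E) to its number of termination points on n slots."""
    freq = {}

    def walk(l, r, L, G):
        if l > r:  # empty range: the search fails after L + G comparisons
            freq[(L, G, 0)] = freq.get((L, G, 0), 0) + 1
            return
        mid = (l + r) // 2
        freq[(L, G, 1)] = freq.get((L, G, 1), 0) + 1  # found: P[mid] == x
        walk(l, mid - 1, L, G + 1)  # P[mid] > x
        walk(mid + 1, r, L + 1, G)  # P[mid] < x

    walk(1, n, 0, 0)
    return freq


for n in (1, 10, 1000, 10**5):
    f = triplet_counts(n)
    deepest = max(L + G + E for (L, G, E) in f)
    # termination points, distinct triplets, max L+G+E, floor(log2 n) + 1
    print(n, sum(f.values()), len(f), deepest, n.bit_length())
```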

So, we can first traverse the entire recursion tree of the binary search once (following both branches at every step), to find all possible (L, G, E) triplets and their counts.
Then, for each query, compute the contribution of each triplet multiplied by its frequency, and sum them all up to obtain the final answer.
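Putting it together, here is a sketch of the grouped solution as a reusable function. It follows the same approach as the editorialist's code below, but takes the queries as a list instead of reading stdin and precomputes factorials only up to n; solve is an illustrative name, and the search direction is the one assumed by the reference solutions.

```python
MOD = 10**9 + 7


def solve(n, queries):
    """Answer every query x modulo MOD by grouping termination points by (L, G, E)."""
    # Factorials and inverse factorials up to n.
    fac = [1] * (n + 1)
    for i in range(1, n + 1):
        fac[i] = fac[i - 1] * i % MOD
    ifac = [1] * (n + 1)
    ifac[n] = pow(fac[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        ifac[i - 1] = ifac[i] * i % MOD

    def C(a, b):
        if b < 0 or a < b:  # also covers a < 0
            return 0
        return fac[a] * ifac[b] % MOD * ifac[a - b] % MOD

    # One traversal of the whole recursion tree: frequency of each triplet.
    freq = {}

    def walk(l, r, L, G):
        if l > r:  # empty range: the search fails here
            freq[(L, G, 0)] = freq.get((L, G, 0), 0) + 1
            return
        mid = (l + r) // 2
        freq[(L, G, 1)] = freq.get((L, G, 1), 0) + 1  # stops here if P[mid] == x
        walk(l, mid - 1, L, G + 1)  # P[mid] > x
        walk(mid + 1, r, L + 1, G)  # P[mid] < x

    walk(1, n, 0, 0)

    answers = []
    for x in queries:
        sum_less = x * (x - 1) // 2 % MOD
        sum_greater = (n * (n + 1) // 2 - x * (x + 1) // 2) % MOD
        ans = 0
        for (L, G, E), cnt in freq.items():
            rest = fac[n - L - G - E] * cnt % MOD  # free elements, times frequency
            c_less = fac[L] * C(x - 2, L - 1) % MOD * C(n - x, G) % MOD * fac[G] % MOD
            c_greater = fac[G] * C(n - x - 1, G - 1) % MOD * C(x - 1, L) % MOD * fac[L] % MOD
            total = (sum_less * c_less + sum_greater * c_greater) % MOD
            if E:  # x itself is on the path
                total = (total + x * C(x - 1, L) % MOD * fac[L] % MOD
                         * C(n - x, G) % MOD * fac[G]) % MOD
            ans = (ans + total * rest) % MOD
        answers.append(ans)
    return answers


print(solve(3, [1, 2, 3]))  # prints [19, 20, 21]
```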

TIME COMPLEXITY:

\mathcal{O}(N\log N + M\log^2 N) per testcase.

CODE:

Author's code (C++)
// OFFICIAL
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const ll mod=1e9+7;
ll fact[1000005],invfact[1000005];
ll pw(ll a, ll b)
{
    ll rez=1;
    while(b)
    {
        if(b&1)
            rez=(rez*a)%mod;
        b/=2;
        a=(a*a)%mod;
    }
    return rez;
}
ll inv(ll a)
{
    return pw(a,mod-2);
}
ll comb(ll a, ll b)
{
    if(a<0||b<0||a<b)
        return 0;
    ll rez=(fact[a]*invfact[b])%mod;
    rez=(rez*invfact[a-b])%mod;
    return rez;
}
ll aranj(ll a, ll b)
{
    return (comb(a,b)*fact[b])%mod;
}
ll sb(ll bile, ll urne)
{
    return comb(bile+urne-1,urne-1);
}
ll cayley(ll n,ll k)
{
    if(k==n)
        return 1;
    return (k*pw(n,n-k-1))%mod;
}
void makefact()
{
    fact[0]=1;
    for(ll i=1;i<=1000000;i++)
        fact[i]=(fact[i-1]*i)%mod;
    invfact[1000000]=inv(fact[1000000]);
    for(int j=1000000-1;j>=0;j--)
        invfact[j]=(invfact[j+1]*(j+1))%mod;
}
map<array<int,3>,ll> f;
ll smaller=0,bigger=0,eq=0;
ll ans=0;
ll n,m,x;
void findme(ll x,ll nr)
{
    // Values y < x on the path: (sum of y < x) * C(x-2, L-1) * L! * aranj(n-x, G) * (n-L-G-E)!,
    // where aranj(a, b) = C(a, b) * b!, times nr = number of termination points with this (L, G, E).
    ll val=x*(x-1)/2;
    val%=mod;
    val=(val*comb(x-2,smaller-1))%mod;
    val=(val*aranj(n-x,bigger))%mod;
    val=(val*fact[n-smaller-bigger-eq])%mod;
    val=(val*fact[smaller])%mod;
    val=(val*nr)%mod;
    ans=(ans+val)%mod;

    // Values y > x on the path, symmetrically.
    val=n*(n+1)/2-x*(x+1)/2;
    val%=mod;
    val=(val*aranj(x-1,smaller))%mod;
    val=(val*comb(n-x-1,bigger-1))%mod;
    val=(val*fact[n-smaller-bigger-eq])%mod;
    val=(val*fact[bigger])%mod;
    val=(val*nr)%mod;
    ans=(ans+val)%mod;

    // x itself, only when the search finds it (E = 1).
    if(eq==1)
    {
        val=x;
        val=(val*aranj(x-1,smaller))%mod;
        val=(val*aranj(n-x,bigger))%mod;
        val=(val*fact[n-smaller-bigger-eq])%mod;
        val=(val*nr)%mod;
        ans=(ans+val)%mod;
    }
}
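void calc(ll st,ll dr)
{
    if(st>dr)
        return;
    // The search may stop here with P[mij] == x (E = 1).
    eq++;
    f[{smaller,bigger,eq}]++;
    eq--;
    ll mij=(st+dr)/2;
    // P[mij] > x and the left half is empty: the search fails here.
    if(st>mij-1)
    {
        bigger++;
        f[{smaller,bigger,eq}]++;
        bigger--;
    }
    // P[mij] < x and the right half is empty: the search fails here.
    if(dr<mij+1)
    {
        smaller++;
        f[{smaller,bigger,eq}]++;
        smaller--;
    }
    // Going left means P[mij] > x; going right means P[mij] < x.
    bigger++;
    calc(st,mij-1);
    bigger--;
    smaller++;
    calc(mij+1,dr);
    smaller--;
}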
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cin>>n>>m;
    makefact();
    calc(1,n);
    for(int q=1;q<=m;q++)
    {
        ll x;
        cin>>x;
        ans=0;
        for(auto p:f)
        {
            smaller=p.first[0]%mod;
            bigger=p.first[1]%mod;
            eq=p.first[2]%mod;
            ll nr=p.second%mod;
            findme(x,nr);
        }
        cout<<ans<<'\n';
    }
    return 0;
}
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//In modulo problems, add a final check: if ans<0 then ans+=mod;
//In case of a near MLE, change the language to c++17 or c++14
//Check ans for n=1 
#pragma GCC target ("avx2")    
#pragma GCC optimize ("O3")  
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long      
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define mod 1000000007ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=300005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
{
    if(a==0)
    return 0;
    int res=1;
    a%=p;
    while(b>0)
    {
        if(b&1)
        res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}
//Mint

const int MOD=mod;
struct Mint {
    int val;
 
    Mint(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
        if (v >= MOD)
            v %= MOD;
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        } 
        return x < 0 ? x + m : x;
    } 
    explicit operator int() const {
        return val;
    }
    Mint& operator+=(const Mint &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    Mint& operator-=(const Mint &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
    Mint& operator*=(const Mint &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
    Mint& operator/=(const Mint &other) {
        return *this *= other.inv();
    }
    friend Mint operator+(const Mint &a, const Mint &b) { return Mint(a) += b; }
    friend Mint operator-(const Mint &a, const Mint &b) { return Mint(a) -= b; }
    friend Mint operator*(const Mint &a, const Mint &b) { return Mint(a) *= b; }
    friend Mint operator/(const Mint &a, const Mint &b) { return Mint(a) /= b; }
    Mint& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
    Mint& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
    // friend Mint operator<=(const Mint &a, const Mint &b) { return (int)a <= (int)b; }
    Mint operator++(int32_t) { Mint before = *this; ++*this; return before; }
    Mint operator--(int32_t) { Mint before = *this; --*this; return before; }
    Mint operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
    bool operator==(const Mint &other) const { return val == other.val; }
    bool operator!=(const Mint &other) const { return val != other.val; }
    Mint inv() const {
        return mod_inv(val);
    }
    Mint power(long long p) const {
        assert(p >= 0);
        Mint a = *this, result = 1;
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
        return result;
    }
    friend ostream& operator << (ostream &stream, const Mint &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, Mint &m) {
        return stream>>m.val;   
    }
};

Mint fact[N], inv[N], inv2[N];

void pre()
{
    fact[0]=inv[0]=1;
    rep(i,1,N)
    fact[i]=(fact[i-1]*i);
    rep(i,1,N)
    {
        inv[i]=(Mint)1/fact[i];
        inv2[i]=(Mint)1/(Mint)i;
    }
}
Mint nCr(int n, int r)
{
    if(min(n, r)<0 || r>n)
    return 0;
    if(n==r)
    return 1;
    return (((fact[n]*inv[r]))*inv[n-r]);
}

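// mp[{{smaller, greater}, found}] = number of termination points with that configuration.
map <pair <pii, int>, int>  mp;
void divide(int l, int r, int sm, int gr)
{
    if(r<l)
    return;
    int md=(l+r)/2;
    // The search may stop here with P[md] == x.
    mp[{{sm, gr}, 1}]++;
    // Empty left half: the search fails after seeing P[md] > x.
    if(md-1<l)
    mp[{{sm, gr+1}, 0}]++;
    // Empty right half: the search fails after seeing P[md] < x.
    if(md+1>r)
    mp[{{sm+1, gr}, 0}]++;
    divide(l, md-1, sm, gr+1);
    divide(md+1, r, sm+1, gr);
}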

Mint get(Mint n)
{
    return (n*(n+1)*inv2[2]);
}

Mint avg(int l, int r)
{
    if(r<l)
    return 0;
    Mint sum=get(r)-get(l-1);
    return sum*inv2[r-l+1];
}
int32_t main()
{
    IOS;
    pre();
    int n, m;
    cin>>n>>m;
    divide(1, n, 0, 0);
    vector <pair <pii, pii> > res;
    for(auto it:mp)
    res.pb({it.ff.ff, {it.ff.ss, it.ss}});
    while(m--)
    {
        int x;
        cin>>x;
        Mint ans=0;
        for(auto p:res)
        {
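            int sm=p.ff.ff, gr=p.ff.ss, eq=p.ss.ff;
            Mint co=p.ss.ss;
            // ways = permutations realising this path (co = termination points sharing it).
            // By symmetry, each "smaller" slot holds avg(1, x-1) on average and each "greater" slot avg(x+1, n).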
            if(eq)
            {
                Mint ways=(nCr(x-1, sm)*fact[sm]*nCr(n-x, gr)*fact[gr]*co*fact[n-sm-gr-1]);
                Mint sum=((avg(1, x-1)*sm) + (avg(x+1, n)*gr) + x);
                ans+=(ways*sum);
            }
            else
            {
                Mint ways=(nCr(x-1, sm)*fact[sm]*nCr(n-x, gr)*fact[gr]*co*fact[n-sm-gr]);
                Mint sum=((avg(1, x-1)*sm) + (avg(x+1, n)*gr));
                ans+=(ways*sum);
            }
        }
        cout<<ans<<"\n";
    }
}
Editorialist's code (Python)
from collections import defaultdict

mod = 10**9 + 7
MX = 10**6 + 10

# Factorials and inverse factorials modulo mod.
fac = [1]*MX
for i in range(1, MX): fac[i] = i*fac[i-1] % mod
inv = fac[:]
inv[-1] = pow(inv[-1], mod-2, mod)
for i in reversed(range(MX-1)):
    inv[i] = inv[i+1] * (i+1) % mod

def C(n, r):
    # Binomial coefficient; 0 when r < 0 or r > n (this also covers n < 0).
    if n < r or r < 0: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod

n, m = map(int, input().split())
a = list(map(int, input().split()))

# freq[(lo, hi, eq)]: number of termination points whose path makes lo "P[mid] < x"
# comparisons and hi "P[mid] > x" comparisons, with eq = 1 iff the search finds x.
freq = defaultdict(int)
def calc(l, r, curl, curm):
    if l > r:
        freq[(curl, curm, 0)] += 1  # empty range: the search fails here
        return
    mid = (l+r)//2
    freq[(curl, curm, 1)] += 1      # the search stops here if P[mid] == x
    calc(l, mid-1, curl, curm+1)    # P[mid] > x: go left
    calc(mid+1, r, curl+1, curm)    # P[mid] < x: go right
calc(1, n, 0, 0)

for x in a:
    ans = 0
    for (lo, hi, eq), ct in freq.items():
        # Values y < x on the path (their sum is x(x-1)/2).
        ways = C(x-2, lo-1) * C(n-x, hi) % mod * fac[lo] % mod * fac[hi] % mod * fac[n - lo - hi - eq] % mod * ct % mod
        ans += x * (x-1)//2 % mod * ways % mod

        # Values y > x on the path (their sum is n(n+1)/2 - x(x+1)/2).
        ways = C(x-1, lo) * C(n-x-1, hi-1) % mod * fac[lo] % mod * fac[hi] % mod * fac[n - lo - hi - eq] % mod * ct % mod
        ans += (n*(n+1)//2 - x*(x+1)//2) % mod * ways % mod

        # x itself, only when the search finds it.
        if eq == 1:
            ways = C(x-1, lo) * C(n-x, hi) % mod * fac[lo] % mod * fac[hi] % mod * fac[n - lo - hi - eq] % mod * ct % mod
            ans += x * ways % mod
        ans %= mod
    print(ans)