EATROCK - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Testers: wuhudsm, satyam_343
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Point-update/range-query structures (segment trees or Fenwick trees), basic combinatorics

PROBLEM:

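There are N rocks, the i-th of which is at position X_i and has weight W_i. The positions are sorted (X_1 \lt X_2 \lt \ldots \lt X_N), and the weights form a permutation of 1, 2, \ldots, N.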

Rocky will eat a subset of these rocks, starting from the one with the least weight, then moving to the one with the second least weight, and so on.
Find the expected distance traveled by Rocky if every subset is equally likely to be chosen.

EXPLANATION:

First off, every subset is equally likely to be chosen, and there are 2^N subsets.
So, to find the expected value, it’s enough to compute the total distance traveled by Rocky across all subsets, and then divide it by 2^N.
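
This reduction can be sanity-checked on tiny inputs with a brute force that literally enumerates all 2^N subsets. The following is a minimal sketch, assuming 0-indexed std::vector inputs x (positions) and w (weights); the function name bruteTotalDistance is hypothetical and does not come from the reference codes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <vector>

// Brute force for tiny N (hypothetical helper): enumerate all 2^N subsets as
// bitmasks, visit the chosen rocks from lightest to heaviest, and add up the
// distance walked. Dividing the returned total by 2^N gives the expectation.
long long bruteTotalDistance(const std::vector<long long>& x, const std::vector<int>& w) {
    int n = static_cast<int>(x.size());
    assert(n < 20); // exponential: only meant for cross-checking small cases
    long long total = 0;
    for (int mask = 0; mask < (1 << n); ++mask) {
        std::vector<int> chosen;
        for (int i = 0; i < n; ++i)
            if ((mask >> i) & 1) chosen.push_back(i);
        // Rocky eats the chosen rocks in increasing order of weight.
        std::sort(chosen.begin(), chosen.end(),
                  [&](int a, int b) { return w[a] < w[b]; });
        for (size_t k = 1; k < chosen.size(); ++k)
            total += std::llabs(x[chosen[k]] - x[chosen[k - 1]]);
    }
    return total;
}
```

For the hypothetical instance X = (1, 2, 4), W = (2, 3, 1), it returns a total of 10, so the expected distance is 10 / 2^3 = 5/4.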

As is common in problems of this type (i.e., find the sum of something across all subsets), instead of fixing a subset and computing the desired quantity for it, we can fix a quantity and count how many subsets it occurs in.

In our case, the quantity we’re interested in is the distance between points, so let’s look at that.
In particular, we care about distances between points that are adjacent (in terms of weight) in the chosen subset, since the total distance covered is the sum of these distances.

This quickly gives rise to a solution in \mathcal{O}(N^2), as follows:

  • Let’s fix two indices i and j (without loss of generality, say W_i \lt W_j), and count the number of subsets in which Rocky travels directly from X_i to X_j.
  • In order for this direct travel to happen, the chosen subset cannot contain any weights lying strictly between W_i and W_j.
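    On the other hand, any weights \lt W_i or \gt W_j can be freely chosen: there are W_i - 1 of the former and N - W_j of the latter, since the weights form a permutation of 1, \ldots, N.
  • This gives us 2^{W_i - 1 + N - W_j} subsets in which this direct travel happens, each contributing a distance of |X_i - X_j|.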
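
To see the count in action, here is a hand-worked check on a small hypothetical instance (1-indexed) with N = 3, X = (1, 2, 4) and W = (2, 3, 1), listing every pair (i, j) with W_i \lt W_j:

```latex
\begin{aligned}
(i, j) = (3, 1):&\quad 2^{W_3 - 1 + N - W_1} = 2^{0 + 1} = 2 \text{ subsets},\quad |X_3 - X_1| = 3 \\
(i, j) = (3, 2):&\quad 2^{W_3 - 1 + N - W_2} = 2^{0 + 0} = 1 \text{ subset},\quad |X_3 - X_2| = 2 \\
(i, j) = (1, 2):&\quad 2^{W_1 - 1 + N - W_2} = 2^{1 + 0} = 2 \text{ subsets},\quad |X_1 - X_2| = 1 \\
\text{Total} &= 2 \cdot 3 + 1 \cdot 2 + 2 \cdot 1 = 10, \qquad \text{Expected} = \frac{10}{2^3} = \frac{5}{4}
\end{aligned}
```

Enumerating all 8 subsets directly gives the same total: the only nonzero ones are {1, 2} (distance 1), {1, 3} (distance 3), {2, 3} (distance 2) and {1, 2, 3} (distance 3 + 1 = 4).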

With i and j fixed, this is an \mathcal{O}(1) computation if powers of 2 are precomputed (or \mathcal{O}(\log {MOD}) with fast exponentiation), so summing over all pairs gives a quadratic solution.
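
A minimal C++ sketch of this quadratic solution follows. It assumes 0-indexed vectors, weights forming a permutation of 1..N, and the modulus 10^9 + 7 used by both reference codes; the names expectedQuadratic and power are hypothetical. Powers of 2 are precomputed, so each pair costs \mathcal{O}(1).

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

const long long MOD = 1000000007LL; // modulus used by both reference codes

// Fast exponentiation modulo MOD.
long long power(long long base, long long e) {
    long long result = 1;
    base %= MOD;
    while (e > 0) {
        if (e & 1) result = result * base % MOD;
        base = base * base % MOD;
        e >>= 1;
    }
    return result;
}

// O(N^2) contribution counting (hypothetical helper). For every pair with
// w[i] < w[j], Rocky walks directly between the two rocks in exactly
// 2^(w[i]-1 + n-w[j]) subsets. Returns the expected distance modulo MOD.
long long expectedQuadratic(const std::vector<long long>& x, const std::vector<int>& w) {
    int n = static_cast<int>(x.size());
    std::vector<long long> pw(n + 1, 1); // pw[k] = 2^k mod MOD
    for (int k = 1; k <= n; ++k) pw[k] = pw[k - 1] * 2 % MOD;
    long long total = 0;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j) {
            if (w[i] >= w[j]) continue; // each unordered pair once, lighter rock as i
            long long ways = pw[w[i] - 1 + n - w[j]];
            long long dist = std::llabs(x[i] - x[j]) % MOD;
            total = (total + ways * dist) % MOD;
        }
    // Divide by 2^n using Fermat's little theorem (MOD is prime).
    return total * power(pw[n], MOD - 2) % MOD;
}
```

On the instance X = (1, 2, 4), W = (2, 3, 1) it returns 250000003, which is 5/4 modulo 10^9 + 7.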

To optimize this, let’s fix j and look at the structure of all the i we need to sum over.

For now, only consider X_i \lt X_j; the other side can be computed similarly.

For a fixed i, the contribution is (X_j - X_i)\cdot 2^{W_i - 1 + N - W_j}.
Here, 2^{N - W_j} is a constant independent of i, so let’s ignore it for now and multiply by it at the end.
The remaining part can be split into (X_j \cdot 2^{W_i - 1}) and -(X_i\cdot 2^{W_i - 1}).

We want the sum of this across all i such that X_i \lt X_j and W_i \lt W_j.
Notice that all we really need are the sum of 2^{W_i - 1} and the sum of X_i \cdot 2^{W_i - 1} over these i; both summands depend only on i.
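
Written out in full, the left-side sum for a fixed j factors as follows (S_0(j) and S_1(j) are notation introduced here for illustration, not names from the original):

```latex
\begin{aligned}
\sum_{\substack{i:\, X_i < X_j \\ W_i < W_j}} (X_j - X_i)\, 2^{W_i - 1 + N - W_j}
  &= 2^{N - W_j} \bigl( X_j \cdot S_0(j) - S_1(j) \bigr), \\
S_0(j) &= \sum_{\substack{i:\, X_i < X_j \\ W_i < W_j}} 2^{W_i - 1}, \\
S_1(j) &= \sum_{\substack{i:\, X_i < X_j \\ W_i < W_j}} X_i \cdot 2^{W_i - 1}.
\end{aligned}
```

For X_i \gt X_j, the same two sums taken over the indices on the other side, say S'_0(j) and S'_1(j), give 2^{N - W_j} \bigl( S'_1(j) - X_j \cdot S'_0(j) \bigr), since there the distance is X_i - X_j.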

This looks like a 2D range query (one condition on positions, one on weights), but it can in fact be handled with a simple segment tree.

Let’s iterate over values of W_j from 1 to N (note that we’re iterating from small to large weight, not index).
Doing this automatically satisfies the constraint W_i \lt W_j whenever we do a query, so we only need to take care of the X_i \lt X_j condition.

This means we can simply query for all indices \lt j (since the input guarantees that X_1 \lt X_2 \lt \ldots \lt X_N).
If each index i stores the two values 2^{W_i-1} and X_i\cdot 2^{W_i - 1}, this reduces to a simple range-sum query over indices 1, \ldots, j-1.

To ensure that indices with weights \gt W_j don’t contribute, initially set the values at all indices to zero, and update an index’s value only after it has been processed as j.
This way, when doing the range-sum query, higher-weight indices contribute 0, which is exactly what we want.

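The case X_i \gt X_j is handled similarly, with a suffix query instead of a prefix query.
Finally, multiply each j’s contribution by the factor 2^{N - W_j} set aside earlier, and don’t forget to divide the total by 2^N at the end (modulo 10^9 + 7, this means multiplying by the modular inverse of 2^N, as both codes below do).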
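
Putting the sweep together, here is a compact C++ sketch. It swaps in a Fenwick tree (listed in the prerequisites as an alternative, and also what the tester’s code uses) for the setter’s segment tree. It assumes X is sorted increasingly, as the input guarantees, that W is a permutation of 1..N, and uses the modulus 10^9 + 7 from the reference codes. The names Fenwick, power and expectedFast are hypothetical. The suffix sum is obtained as total minus prefix, and index j is activated only after its own queries, which is the zero-initialization trick described above.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Fenwick tree over indices 0..n-1 holding values modulo MOD.
struct Fenwick {
    std::vector<long long> t;
    explicit Fenwick(int n) : t(n + 1, 0LL) {}
    void add(int i, long long v) {                 // a[i] += v
        for (++i; i < (int)t.size(); i += i & -i) t[i] = (t[i] + v) % MOD;
    }
    long long prefix(int i) const {                // a[0] + ... + a[i-1]
        long long s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s % MOD;
    }
};

long long power(long long b, long long e) {
    long long r = 1;
    for (b %= MOD; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// Sweep weights 1..n. When weight k is processed, the trees hold 2^(w_i-1) and
// x_i*2^(w_i-1) only at indices with smaller weight; every other entry is 0.
long long expectedFast(const std::vector<long long>& x, const std::vector<int>& w) {
    int n = static_cast<int>(x.size());
    std::vector<int> where(n + 1);
    for (int i = 0; i < n; ++i) where[w[i]] = i;   // index of the rock with weight k
    std::vector<long long> pw(n + 1, 1);           // pw[k] = 2^k mod MOD
    for (int k = 1; k <= n; ++k) pw[k] = pw[k - 1] * 2 % MOD;

    Fenwick cnt(n), sum(n);                        // sums of 2^(w-1) and x*2^(w-1)
    long long ans = 0;
    for (int k = 1; k <= n; ++k) {
        int j = where[k];
        long long xj = (x[j] % MOD + MOD) % MOD;
        // Prefix [0, j): rocks to the left, since x is sorted.
        long long c0 = cnt.prefix(j), s0 = sum.prefix(j);
        // Suffix (j, n): total minus prefix [0, j]; index j itself is still 0.
        long long c1 = (cnt.prefix(n) - cnt.prefix(j + 1) + MOD) % MOD;
        long long s1 = (sum.prefix(n) - sum.prefix(j + 1) + MOD) % MOD;
        long long left = (xj * c0 % MOD - s0 + MOD) % MOD;
        long long right = (s1 - xj * c1 % MOD + MOD) % MOD;
        ans = (ans + pw[n - k] * ((left + right) % MOD)) % MOD;
        // Activate index j only after it has been processed.
        cnt.add(j, pw[k - 1]);
        sum.add(j, pw[k - 1] * xj % MOD);
    }
    return ans * power(pw[n], MOD - 2) % MOD;      // divide by 2^n
}
```

For example, on X = (1, 2, 4), W = (2, 3, 1) it returns 250000003, i.e. 5/4 modulo 10^9 + 7. Each step does a constant number of \mathcal{O}(\log N) Fenwick operations, matching the stated complexity.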

TIME COMPLEXITY:

\mathcal{O}(N\log N) per test case.

CODE:

Setter's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Integers modulo p, where p is a prime
 * Source: Aeren (modified from tourist?)
 *         Modmul for 64-bit mod from kactl:ModMulLL
 * Works with p < 7.2e18 with x87 80-bit long double, and p < 2^52 ~ 4.5e12 with 64-bit
 */
template<typename T>
struct Z_p{
	using Type = typename decay<decltype(T::value)>::type;
	static vector<Type> MOD_INV;
	constexpr Z_p(): value(){ }
	template<typename U> Z_p(const U &x){ value = normalize(x); }
	template<typename U> static Type normalize(const U &x){
		Type v;
		if(-mod() <= x && x < mod()) v = static_cast<Type>(x);
		else v = static_cast<Type>(x % mod());
		if(v < 0) v += mod();
		return v;
	}
	const Type& operator()() const{ return value; }
	template<typename U> explicit operator U() const{ return static_cast<U>(value); }
	constexpr static Type mod(){ return T::value; }
	Z_p &operator+=(const Z_p &otr){ if((value += otr.value) >= mod()) value -= mod(); return *this; }
	Z_p &operator-=(const Z_p &otr){ if((value -= otr.value) < 0) value += mod(); return *this; }
	template<typename U> Z_p &operator+=(const U &otr){ return *this += Z_p(otr); }
	template<typename U> Z_p &operator-=(const U &otr){ return *this -= Z_p(otr); }
	Z_p &operator++(){ return *this += 1; }
	Z_p &operator--(){ return *this -= 1; }
	Z_p operator++(int){ Z_p result(*this); *this += 1; return result; }
	Z_p operator--(int){ Z_p result(*this); *this -= 1; return result; }
	Z_p operator-() const{ return Z_p(-value); }
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int>::value, Z_p>::type &operator*=(const Z_p& rhs){
		#ifdef _WIN32
		uint64_t x = static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value);
		uint32_t xh = static_cast<uint32_t>(x >> 32), xl = static_cast<uint32_t>(x), d, m;
		asm(
			"divl %4; \n\t"
			: "=a" (d), "=d" (m)
			: "d" (xh), "a" (xl), "r" (mod())
		);
		value = m;
		#else
		value = normalize(static_cast<int64_t>(value) * static_cast<int64_t>(rhs.value));
		#endif
		return *this;
	}
	template<typename U = T>
	typename enable_if<is_same<typename Z_p<U>::Type, int64_t>::value, Z_p>::type &operator*=(const Z_p &rhs){
		uint64_t ret = static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value) - static_cast<uint64_t>(mod()) * static_cast<uint64_t>(1.L / static_cast<uint64_t>(mod()) * static_cast<uint64_t>(value) * static_cast<uint64_t>(rhs.value));
		value = normalize(static_cast<int64_t>(ret + static_cast<uint64_t>(mod()) * (ret < 0) - static_cast<uint64_t>(mod()) * (ret >= static_cast<uint64_t>(mod()))));
		return *this;
	}
	template<typename U = T>
	typename enable_if<!is_integral<typename Z_p<U>::Type>::value, Z_p>::type &operator*=(const Z_p &rhs){
		value = normalize(value * rhs.value);
		return *this;
	}
	template<typename U>
	Z_p &operator^=(U e){
		if(e < 0) *this = 1 / *this, e = -e;
		Z_p res = 1;
		for(; e; *this *= *this, e >>= 1) if(e & 1) res *= *this;
		return *this = res;
	}
	template<typename U>
	Z_p operator^(U e) const{
		return Z_p(*this) ^= e;
	}
	Z_p &operator/=(const Z_p &otr){
		Type a = otr.value, m = mod(), u = 0, v = 1;
		if(a < (int)MOD_INV.size()) return *this *= MOD_INV[a];
		while(a){
			Type t = m / a;
			m -= t * a; swap(a, m);
			u -= t * v; swap(u, v);
		}
		assert(m == 1);
		return *this *= u;
	}
	template<typename U> friend const Z_p<U> &abs(const Z_p<U> &v){ return v; }
	Type value;
};
template<typename T> bool operator==(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value == rhs.value; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(const Z_p<T>& lhs, U rhs){ return lhs == Z_p<T>(rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator==(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) == rhs; }
template<typename T> bool operator!=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(const Z_p<T> &lhs, U rhs){ return !(lhs == rhs); }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> bool operator!=(U lhs, const Z_p<T> &rhs){ return !(lhs == rhs); }
template<typename T> bool operator<(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value < rhs.value; }
template<typename T> bool operator>(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value > rhs.value; }
template<typename T> bool operator<=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value <= rhs.value; }
template<typename T> bool operator>=(const Z_p<T> &lhs, const Z_p<T> &rhs){ return lhs.value >= rhs.value; }
template<typename T> Z_p<T> operator+(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(const Z_p<T> &lhs, U rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator+(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) += rhs; }
template<typename T> Z_p<T> operator-(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator-(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) -= rhs; }
template<typename T> Z_p<T> operator*(const Z_p<T> &lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(const Z_p<T>& lhs, U rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator*(U lhs, const Z_p<T> &rhs){ return Z_p<T>(lhs) *= rhs; }
template<typename T> Z_p<T> operator/(const Z_p<T> &lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(const Z_p<T>& lhs, U rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T, typename U, typename enable_if<is_integral<U>::value>::type* = nullptr> Z_p<T> operator/(U lhs, const Z_p<T> &rhs) { return Z_p<T>(lhs) /= rhs; }
template<typename T> istream &operator>>(istream &in, Z_p<T> &number){
	typename common_type<typename Z_p<T>::Type, int64_t>::type x;
	in >> x;
	number.value = Z_p<T>::normalize(x);
	return in;
}
template<typename T> ostream &operator<<(ostream &out, const Z_p<T> &number){ return out << number(); }

/*
using ModType = int;
struct VarMod{ static ModType value; };
ModType VarMod::value;
ModType &mod = VarMod::value;
using Zp = Z_p<VarMod>;
*/

constexpr int mod = 1e9 + 7; // 1000000007
// constexpr int mod = (119 << 23) + 1; // 998244353
// constexpr int mod = 1e9 + 9; // 1000000009
using Zp = Z_p<integral_constant<decay<decltype(mod)>::type, mod>>;

template<typename T> vector<typename Z_p<T>::Type> Z_p<T>::MOD_INV;
template<typename T = integral_constant<decay<decltype(mod)>::type, mod>>
void precalc_inverse(int SZ){
	auto &inv = Z_p<T>::MOD_INV;
	if(inv.empty()) inv.assign(2, 1);
	for(; inv.size() <= SZ; ) inv.push_back((mod - 1LL * mod / (int)inv.size() * inv[mod % (int)inv.size()]) % mod);
}

template<typename T>
vector<T> precalc_power(T base, int SZ){
	vector<T> res(SZ + 1, 1);
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * base;
	return res;
}

template<typename T>
vector<T> precalc_factorial(int SZ){
	vector<T> res(SZ + 1, 1); res[0] = 1;
	for(auto i = 1; i <= SZ; ++ i) res[i] = res[i - 1] * i;
	return res;
}

struct Data {
	Zp powsum = 0, pospowsum = 0;
}unit;

/**
 * Point-update Segment Tree
 * Source: kactl
 * Description: Iterative point-update segment tree, ranges are half-open i.e [L, R).
 *              f is any associative function.
 * Time: O(logn) update/query
 */

struct SegTree {
	using T = Data;
	T f(T a, T b) { 
		a.powsum += b.powsum;
		a.pospowsum += b.pospowsum;
		return a;
	}
	vector<T> s; int n;
	SegTree(int _n = 0, T def = unit) : s(2*_n, def), n(_n) {}
	void update(int pos, T val) {
		for (s[pos += n] = val; pos /= 2;)
			s[pos] = f(s[pos * 2], s[pos * 2 + 1]);
	}
	T query(int b, int e) {
		T ra = unit, rb = unit;
		for (b += n, e += n; b < e; b /= 2, e /= 2) {
			if (b % 2) ra = f(ra, s[b++]);
			if (e % 2) rb = f(s[--e], rb);
		}
		return f(ra, rb);
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n; cin >> n;
		vector<int> pos(n), wt(n), where(n+1);
		for (int &x : pos) cin >> x;
		for (int &x : wt) cin >> x;

		for (int i = 0; i < n; ++i) where[wt[i]] = i;

		Zp ans = 0;
		SegTree seg(n);
		for (int i = 1; i <= n; ++i) {
			int u = where[i];
			/**
			 * For v > u, (pos[v] - pos[u])*2^(n-i + wt[v]-1)
			 * 2^(n-i) is constant
			 * (pos[v] - pos[u]) * 2^(wt[v] - 1)
			 * = pos[v]*(2 ^ (wt[v] - 1)) - pos[u]*2^(wt[v] - 1)
			 * 
			 * For v < u, similar
			 */
			auto right = seg.query(u+1, n);
			Zp pw = Zp(2) ^ (n - i);
			ans += pw * (right.pospowsum - right.powsum*pos[u]);

			auto left = seg.query(0, u);
			ans += pw * (left.powsum*pos[u] - left.pospowsum);

			pw = Zp(2) ^ (i - 1);
			Data cur = {pw, pw * pos[u]};
			seg.update(u, cur);
		}
		ans /= Zp(2) ^ n;

		cout << ans << '\n';
	}
}

Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC target("avx,avx2,fma")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
const ll INF_ADD=1e18;
#define pb push_back               
#define mp make_pair        
#define nline "\n"                           
#define f first                                          
#define s second                                               
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>       
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}   
void _print(string x){cerr<<x;}     
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); 
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;    
const ll MAX=200200; 
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
struct FenwickTree{
    vector<ll> bit; 
    ll n;
    FenwickTree(ll n){
        this->n = n;
        bit.assign(n, 0);
    }
    FenwickTree(vector<ll> a):FenwickTree(a.size()){
        ll x=a.size();
        for(size_t i=0;i<x;i++)
            add(i,a[i]);
    }
    ll sum(ll r) {
        ll ret=0;
        for(;r>=0;r=(r&(r+1))-1)
            ret+=bit[r];
        ret%=MOD; 
        return ret;
    }
    ll sum(ll l,ll r) {
        if(l>r)
            return 0;
        return sum(r)-sum(l-1)+MOD;
    }
    void add(ll idx,ll delta) {
        for(;idx<n;idx=idx|(idx+1))
            bit[idx]+=delta;
    }
}; 
void solve(){               
    ll n; cin>>n;
    vector<ll> x(n),w(n);
    vector<pair<ll,ll>> track;
    for(ll i=0;i<n;i++){
        cin>>x[i];
    }
    for(ll i=0;i<n;i++){
        cin>>w[i];
        track.push_back({w[i],i});
    }
    sort(all(track));
    vector<ll> new_x(n);
    for(ll i=0;i<n;i++){
        auto it=track[i];
        new_x[i]=x[it.s];
    }
    x=new_x;
    track.clear();
    for(ll i=0;i<n;i++){
        track.push_back({x[i],i});
    }
    sort(all(track));
    vector<ll> pos(n,0);
    for(ll i=0;i<n;i++){
        pos[track[i].s]=i; 
    }
    ll ans=0;
    FenwickTree get_freq(n);
    FenwickTree get_sum(n);
    for(ll i=0;i<n;i++){
        ll p=pos[i];
        ll lft=(x[i]*get_freq.sum(0,p)-get_sum.sum(0,p))%MOD;
        ll rght=(get_sum.sum(p,n-1)-x[i]*get_freq.sum(p,n-1))%MOD;
        ll now=(lft+rght)*binpow(2,n-i-1,MOD);
        ans=(ans+now)%MOD;
        ans=(ans+MOD)%MOD;
        ll ways=binpow(2,i,MOD);
        get_freq.add(p,ways);
        ways=(ways*x[i])%MOD;
        get_sum.add(p,ways);
    }
    ll den=binpow(2,n,MOD);
    ans=(ans*inverse(den,MOD))%MOD;
    cout<<ans<<nline; 
    return;          
}                                               
int main()                                                                            
{                              
    ios_base::sync_with_stdio(false);                            
    cin.tie(NULL);                       
    #ifndef ONLINE_JUDGE               
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);    
    freopen("error.txt", "w", stderr);                        
    #endif        
    ll test_cases=1;                 
    cin>>test_cases; 
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(9);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}