ARRAY_BREAK - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: yash_daga
Tester: satyam_343
Editorialist: iceknight1093

DIFFICULTY:

2722

PREREQUISITES:

Dynamic programming, prefix sums, basic probability

PROBLEM:

You’re given an array A of length N and a string S of length K containing only the characters L and R.
For each character of S, in order:

  • If the array has size 1, do nothing.
  • Otherwise, randomly split A into two non-empty subarrays (a prefix and a suffix), with every split point equally likely.
  • If S_i is L, discard the right part; otherwise, discard the left part.

Find the expected sum of the final array, modulo 10^9 + 7.
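All three reference solutions print the expectation as a residue modulo 10^9+7: an expected value P/Q is printed as P \cdot Q^{-1} \bmod (10^9+7). Here is a minimal sketch of that conversion. It uses Fermat's little theorem, which applies because 10^9+7 is prime. The helper names `power`, `inverse` and `fractionMod` are our own and do not come from the problem.

```cpp
#include <cassert>
#include <cstdint>

const int64_t MOD = 1'000'000'007;  // prime modulus used by the reference codes

// base^exp mod MOD by repeated squaring (base is assumed non-negative).
int64_t power(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Modular inverse via Fermat's little theorem: a^(MOD-2) = a^(-1) mod MOD.
int64_t inverse(int64_t a) {
    assert(a % MOD != 0);  // zero has no inverse
    return power(a, MOD - 2);
}

// Representative of the fraction p / q modulo MOD.
int64_t fractionMod(int64_t p, int64_t q) {
    return (p % MOD + MOD) % MOD * inverse(q) % MOD;
}
```

For instance, `fractionMod(31, 7)` evaluates to 428571436, since 7 \cdot 428571436 \equiv 31 \pmod{10^9+7}.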

EXPLANATION:

Every move, we cut out either a prefix or a suffix of the remaining array.
That means in the end, we’ll be left with some subarray of A.
In fact, at any stage of the process we’ll have a subarray of A.

Let p(L, R) be the probability that we’re left with subarray [A_L, A_{L+1}, \ldots, A_R] after all the moves.
Then, by definition of expectation, the final answer is just \sum p(L, R)\cdot (A_L + A_{L+1} + \ldots + A_R) over all 1 \leq L \leq R \leq N.

So, it’s enough for us to find all the p(L, R) values.
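Once all p(L, R) values are known, the sum above takes \mathcal{O}(N^2) time to compute. For each L, the range sum A_L + \ldots + A_R is extended one element at a time, which is also what the reference solutions do at the end. The sketch below assumes 0-indexed arrays and probabilities already reduced modulo 10^9+7. The function name `expectedSum` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 1'000'000'007;

// p[L][R] (0-indexed, modulo MOD) is the probability that A[L..R] is the final
// subarray. Returns the sum over L <= R of p[L][R] * (A[L] + ... + A[R]) mod MOD.
int64_t expectedSum(const std::vector<std::vector<int64_t>>& p,
                    const std::vector<int64_t>& A) {
    assert(p.size() == A.size());
    const int n = static_cast<int>(A.size());
    int64_t answer = 0;
    for (int L = 0; L < n; ++L) {
        int64_t rangeSum = 0;  // A[L] + ... + A[R], grown one R at a time
        for (int R = L; R < n; ++R) {
            rangeSum = (rangeSum + A[R] % MOD + MOD) % MOD;
            answer = (answer + p[L][R] * rangeSum) % MOD;
        }
    }
    return answer;
}
```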


The remaining subarray changes after each move, so we also need information about which move we’re on.
Let p(L, R, k) be the probability that, after k moves, the remaining array is exactly [A_L, A_{L+1}, \ldots, A_R].
Initially, we have p(1, N, 0) = 1 and p(L, R, 0) = 0 for all other (L, R).

For k \geq 1, let’s look at all possible moves.

  • Suppose S_k = \texttt{L}, meaning we must keep the left subarray.
    Then [L, R] can only be obtained by splitting some [L, x] with x \gt R; since [L, x] has x - L equally likely split points, this gives
p(L, R, k) = \sum_{x = R+1}^N p(L, x, k-1) \cdot \frac{1}{x-L}
  • If L = R, we’ll also have an extra p(L, L, k-1) term to account for size-1 subarrays not being splittable.
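  • If S_k = \texttt{R}, the transitions are symmetric: [L, R] is obtained from some [x, R] with 1 \leq x \lt L, so p(L, R, k) = \sum_{x = 1}^{L-1} p(x, R, k-1) \cdot \frac{1}{R-x}, again with an extra p(L, L, k-1) term when L = R.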
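Here is a minimal sketch of this direct \mathcal{O}(N^3 K) DP. It assumes 0-indexed positions and arithmetic modulo 10^9+7, as in the reference solutions. `naiveProbabilities` is a hypothetical name, and it returns the table of final p(L, R) values.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

const int64_t MOD = 1'000'000'007;

// base^exp mod MOD by repeated squaring (base is assumed non-negative).
int64_t power(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Pull-style DP straight from the recurrence: returns p[L][R] = p(L, R, K),
// 0-indexed, modulo MOD. O(N^3) work per move, O(N^3 K) overall.
std::vector<std::vector<int64_t>> naiveProbabilities(int n, const std::string& S) {
    assert(n >= 1);
    std::vector<int64_t> inv(n + 1, 0);
    for (int i = 1; i <= n; ++i) inv[i] = power(i, MOD - 2);

    std::vector<std::vector<int64_t>> p(n, std::vector<int64_t>(n, 0));
    p[0][n - 1] = 1;  // initially the whole array is present
    for (char move : S) {
        std::vector<std::vector<int64_t>> next(n, std::vector<int64_t>(n, 0));
        for (int L = 0; L < n; ++L) {
            for (int R = L; R < n; ++R) {
                // Extra term: a size-1 subarray is never split.
                int64_t total = (L == R) ? p[L][L] : 0;
                if (move == 'L') {
                    // [L, R] is the left part of [L, x], x > R, w.p. 1 / (x - L).
                    for (int x = R + 1; x < n; ++x)
                        total = (total + p[L][x] * inv[x - L]) % MOD;
                } else {
                    // [L, R] is the right part of [x, R], x < L, w.p. 1 / (R - x).
                    for (int x = 0; x < L; ++x)
                        total = (total + p[x][R] * inv[R - x]) % MOD;
                }
                next[L][R] = total;
            }
        }
        p = std::move(next);
    }
    return p;
}
```

With N = 3 and S = "L", for example, [0, 0] and [0, 1] each end up with probability 1/2, stored as 500000004.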

Notice that we have \mathcal{O}(N^2 K) states, and \mathcal{O}(N) transitions for each one.
Memoizing states thus allows for a \mathcal{O}(N^3 K) dynamic programming solution, which is enough to pass subtask 1.


To optimize this solution further, it seems our only hope is to make transitions faster: all the states seem necessary.

One way to do this is to look at transitions from the opposite side.
Instead of looking at p(L, R, k) in terms of sums of p(L, x, k-1), let’s look at how p(L, R, k) contributes to p(L, x, k+1).

So, suppose [L, R] is fixed, and S_{k+1} = \texttt{L}, i.e., we keep the left part of the subarray.
There’s a \frac{1}{R-L} chance of obtaining each of [L, L], [L, L+1], \ldots, [L, R-1].
So, we add p(L, R, k) \cdot \frac{1}{R-L} to each of p(L, L, k+1), p(L, L+1, k+1), \ldots, p(L, R-1, k+1).

Notice that L is fixed for all of these, so we’re essentially just adding a constant value to a range!

Processing Q range-add updates on an array of length N can be done in \mathcal{O}(N + Q) time using prefix sums.

How?

Suppose we want to process Q updates on the array B.

Create an auxiliary array C of length N+1, initially filled with zeros (the extra slot holds C_{R+1} when R = N).
Then, for each update (L, R, x) (meaning we want to add x to the range [L, R] of B),

  • Increase C_L by x
  • Decrease C_{R+1} by x

Finally, take the prefix sums of array C.
The i-th prefix sum of C is the amount by which B_i will change.

A slightly more detailed explanation of why this works can be found in the “How?” section here.
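Here is a minimal sketch of this difference-array technique. It assumes 0-indexed, inclusive ranges that are all known in advance (offline). `applyRangeAdds` is a hypothetical helper that returns the total change for every position of B.

```cpp
#include <cassert>
#include <cstdint>
#include <tuple>
#include <vector>

// Applies offline range-add updates (l, r, x), meaning "add x to B[l..r]"
// (0-indexed, inclusive), to an array of length n in O(n + Q) time.
// Returns the amount by which each B[i] changes.
std::vector<int64_t> applyRangeAdds(
        int n, const std::vector<std::tuple<int, int, int64_t>>& updates) {
    std::vector<int64_t> C(n + 1, 0);  // extra slot so C[r + 1] is always valid
    for (const auto& [l, r, x] : updates) {
        assert(0 <= l && l <= r && r < n);
        C[l] += x;      // start adding x at position l
        C[r + 1] -= x;  // stop adding x after position r
    }
    std::vector<int64_t> delta(n);
    int64_t running = 0;
    for (int i = 0; i < n; ++i) {
        running += C[i];  // i-th prefix sum of C
        delta[i] = running;
    }
    return delta;
}
```

For example, the updates (1, 3, 2) and (2, 4, 1) on an array of length 5 give the changes {0, 2, 3, 3, 1}.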

In our case, we perform one range update for each (L, R), for about N^2 updates per move.
So, each move now takes only \mathcal{O}(N^2) operations!

This optimization makes the solution \mathcal{O}(N^2 K).

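Note that once again, a little care is needed when dealing with L = R: a size-1 subarray cannot be split, so p(L, L, k) carries over unchanged to p(L, L, k+1) instead of being spread over a range.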
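Putting the pieces together, here is a sketch of the \mathcal{O}(N^2)-per-move solution. It assumes 0-indexed positions and answers modulo 10^9+7. It records each move's range updates in a separate difference table, rather than updating in place as the reference codes do. `expectedFinalSum` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

const int64_t MOD = 1'000'000'007;

// base^exp mod MOD by repeated squaring (base is assumed non-negative).
int64_t power(int64_t base, int64_t exp) {
    int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Expected sum of the final array, modulo MOD, in O(N^2) per move.
int64_t expectedFinalSum(const std::vector<int64_t>& A, const std::string& S) {
    const int n = static_cast<int>(A.size());
    assert(n >= 1);
    std::vector<int64_t> inv(n + 1, 0);
    for (int i = 1; i <= n; ++i) inv[i] = power(i, MOD - 2);

    using Table = std::vector<std::vector<int64_t>>;
    Table p(n, std::vector<int64_t>(n, 0));  // p[L][R] = p(L, R, k), 0-indexed
    p[0][n - 1] = 1;
    for (char move : S) {
        Table next(n, std::vector<int64_t>(n, 0));
        Table diff(n + 1, std::vector<int64_t>(n + 1, 0));  // room for index n
        for (int L = 0; L < n; ++L) {
            next[L][L] = p[L][L];  // a size-1 array cannot be split
            for (int R = L + 1; R < n; ++R) {
                const int64_t share = p[L][R] * inv[R - L] % MOD;
                if (move == 'L') {
                    // Keep [L, L], ..., [L, R-1]: a range in R with L fixed.
                    diff[L][L] = (diff[L][L] + share) % MOD;
                    diff[L][R] = (diff[L][R] + MOD - share) % MOD;
                } else {
                    // Keep [L+1, R], ..., [R, R]: a range in L with R fixed.
                    diff[L + 1][R] = (diff[L + 1][R] + share) % MOD;
                    diff[R + 1][R] = (diff[R + 1][R] + MOD - share) % MOD;
                }
            }
        }
        // Prefix sums along the direction in which the ranges were recorded.
        for (int a = 0; a < n; ++a) {
            int64_t running = 0;
            for (int b = 0; b < n; ++b) {
                const int L = (move == 'L') ? a : b;
                const int R = (move == 'L') ? b : a;
                running = (running + diff[L][R]) % MOD;
                next[L][R] = (next[L][R] + running) % MOD;
            }
        }
        p = std::move(next);
    }

    int64_t answer = 0;
    for (int L = 0; L < n; ++L) {
        int64_t rangeSum = 0;  // A[L] + ... + A[R] modulo MOD
        for (int R = L; R < n; ++R) {
            rangeSum = (rangeSum + A[R] % MOD + MOD) % MOD;
            answer = (answer + p[L][R] * rangeSum) % MOD;
        }
    }
    return answer;
}
```

For S_{k+1} = \texttt{L}, the ranges run along R with L fixed, so prefix sums are taken along each row. For S_{k+1} = \texttt{R}, the state [L, R] feeds [L+1, R], \ldots, [R, R], so the difference is recorded in column R and prefix sums are taken over L. For example, A = [1, 2, 3] with S = "LR" ends in [1] or [2] with probability 1/2 each, giving 3/2, i.e. 500000005.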

TIME COMPLEXITY:

\mathcal{O}(KN^2) per testcase.

CODE:

Author's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow
//In mod questions, add a check at the end, i.e. if ans<0 then ans+=mod;
//In case of a close MLE, change language to C++17 or C++14
//Check ans for n=1
#pragma GCC target ("avx2")    
#pragma GCC optimize ("O3")  
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long      
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define mod 1000000007ll //998244353ll
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
const long long N=200005, INF=2000000000000000000;
const int inf=2e9 + 5;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
    {
        if(a==0)
        return 0;
        int res=1;
        a%=p;
        while(b>0)
        {
            if(b&1)
            res=(1ll*res*a)%p;
            b>>=1;
            a=(1ll*a*a)%p;
        }
        return res;
    }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}
const int MOD=mod;
struct Mint {
    int val;
 
    Mint(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
        if (v >= MOD)
            v %= MOD;
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        } 
        return x < 0 ? x + m : x;
    } 
    explicit operator int() const {
        return val;
    }
    Mint& operator+=(const Mint &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    Mint& operator-=(const Mint &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
    Mint& operator*=(const Mint &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
    Mint& operator/=(const Mint &other) {
        return *this *= other.inv();
    }
    friend Mint operator+(const Mint &a, const Mint &b) { return Mint(a) += b; }
    friend Mint operator-(const Mint &a, const Mint &b) { return Mint(a) -= b; }
    friend Mint operator*(const Mint &a, const Mint &b) { return Mint(a) *= b; }
    friend Mint operator/(const Mint &a, const Mint &b) { return Mint(a) /= b; }
    Mint& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
    Mint& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
    // friend Mint operator<=(const Mint &a, const Mint &b) { return (int)a <= (int)b; }
    Mint operator++(int32_t) { Mint before = *this; ++*this; return before; }
    Mint operator--(int32_t) { Mint before = *this; --*this; return before; }
    Mint operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
    bool operator==(const Mint &other) const { return val == other.val; }
    bool operator!=(const Mint &other) const { return val != other.val; }
    Mint inv() const {
        return mod_inv(val);
    }
    Mint power(long long p) const {
        assert(p >= 0);
        Mint a = *this, result = 1;
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
        return result;
    }
    friend ostream& operator << (ostream &stream, const Mint &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, Mint &m) {
        return stream>>m.val;   
    }
};

Mint dp[500][500];
int32_t main()
{
	IOS;
	int t;
	cin>>t;
	while(t--)
	{
		int n, k;
		cin>>n>>k;
		int a[n];
		rep(i,0,n)
		cin>>a[i];
		string s;
		cin>>s;
		rep(i,0,n)
		{
			rep(j,0,n)
			dp[i][j]=0;
		}
		dp[0][n-1]=1;
		Mint inv[n+1];
		inv[0]=1;
		for(int i=1;i<=n;i++)
			inv[i]=(Mint)1/i;
		for(int i=1;i<=k;i++)
		{
			if(s[i-1]=='L')
			{
				for(int l=0;l<n;l++)
				{
					Mint temp=dp[l][n-1];
					for(int r=n-1;r>l;r--)
					{
						dp[l][r]-=temp;
						Mint temp2=dp[l][r-1];
						dp[l][r-1]+=((temp*inv[r-l])+dp[l][r]);
						temp=temp2;
					}
				}
			}
			else
			{
				for(int r=0;r<n;r++)
				{
					Mint temp=dp[0][r];
					for(int l=0;l<r;l++)
					{
						dp[l][r]-=temp;
						Mint temp2=dp[l+1][r];
						dp[l+1][r]+=((temp*inv[r-l])+dp[l][r]);
						temp=temp2;
					}
				}
			}
		}
		Mint ans=0;
		for(int i=0;i<n;i++)
		{
			Mint sum=0;
			for(int j=i;j<n;j++)
			{
				sum+=a[j];
				ans+=(sum*dp[i][j]);
			}
		}
		cout<<ans<<"\n";
	}
}
Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
 
 
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
#define ll long long
const ll INF_MUL=1e13;
const ll INF_ADD=2e18; 
#define pb push_back                   
#define mp make_pair          
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>             
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);    
#endif       
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}     
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=1e9+7;     
const ll MAX=500500; 
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
void solve(){     
    ll n,k; cin>>n>>k;
    vector<ll> a(n+5,0);
    for(ll i=1;i<=n;i++){
        cin>>a[i];
    }
    vector<ll> inv(n+5);
    for(ll i=1;i<=n;i++){
        inv[i]=inverse(i,MOD); 
    }
    string s; cin>>s;
    vector<vector<ll>> dp(n+5,vector<ll>(n+5,0));
    dp[1][n]=1;
    for(auto it:s){
        vector<vector<ll>> adp(n+5,vector<ll>(n+5,0));
        for(ll i=1;i<=n;i++){
            adp[i][i]=dp[i][i];
        }
        if(it=='L'){
            vector<ll> sum(n+5,0);
            for(ll i=1;i<=n;i++){
                for(ll len=n;len>=1;len--){
                    ll l=i,r=i+len-1;
                    if(r>=n+1){
                        continue; 
                    }
                    adp[l][r]=(adp[l][r]+sum[l])%MOD;
                    sum[l]=(sum[l]+dp[l][r]*inv[len-1])%MOD;
                }
            }
        }
        else{
            vector<ll> sum(n+5,0);
            for(ll i=n;i>=1;i--){
                for(ll len=n;len>=1;len--){
                    ll l=i-len+1,r=i;
                    if(l<=0){
                        continue; 
                    }
                    adp[l][r]=(adp[l][r]+sum[r])%MOD;
                    sum[r]=(sum[r]+dp[l][r]*inv[len-1])%MOD;
                }
            }
        }
        swap(dp,adp);
    }
    ll ans=0;
    for(ll i=1;i<=n;i++){
        ll sum=0;
        for(ll j=i;j<=n;j++){
            sum=(sum+a[j])%MOD;
            ans=(ans+dp[i][j]*sum)%MOD;
        }
    }
    cout<<ans<<nline;
    return;                                
}                                                  
int main()                                                                                                 
{            
    ios_base::sync_with_stdio(false);                             
    cin.tie(NULL);        
    #ifndef ONLINE_JUDGE                     
    freopen("input.txt", "r", stdin);                                                    
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif                          
    ll test_cases=1;               
    cin>>test_cases;
    while(test_cases--){
        solve();
    } 
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}   
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

const int mod = 1e9 + 7;

int pw(int a, int n) {
	int ret = 1;
	while (n) {
		if (n & 1) ret = (1LL * ret * a) % mod;
		a = (1LL * a * a) % mod;
		n /= 2;
	}
	return ret;
}


int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
	
	vector<int> inv(1000);
	for (int i = 1; i < 1000; ++i) inv[i] = pw(i, mod-2);
    
    int t; cin >> t;
	while (t--) {
		int n, k; cin >> n >> k;
		vector<int> a(n);
		for (int &x : a) cin >> x;
		string s; cin >> s;


		vector dp(n, vector(n+1, 0));
		dp[0][n-1] = 1;
		for (int i = 0; i < k; ++i) {
			vector ndp(n, vector(n+1, 0));
			for (int L = 0; L < n; ++L) for (int R = L; R < n; ++R) {
				if (L == R) {
					ndp[L][R] = (ndp[L][R] + dp[L][R]) % mod;
					if (s[i] == 'L') ndp[L][R+1] = (ndp[L][R+1] + mod - dp[L][R]) % mod;
					else if (L > 0) ndp[L-1][L] = (ndp[L-1][L] + mod - dp[L][R]) % mod;
				}
				else {
					int val = (1LL * dp[L][R] * inv[R-L]) % mod;
					if (s[i] == 'L') ndp[L][L] = (ndp[L][L] + val) % mod;
					else ndp[R][R] = (ndp[R][R] + val) % mod;
					ndp[L][R] = (ndp[L][R] - val + mod) % mod;
				}
			}

			if (s[i] == 'L') {
				for (int L = 0; L < n; ++L) {
					for (int R = 1; R < n; ++R) {
						ndp[L][R] = (ndp[L][R] + ndp[L][R-1]) % mod;
					}
				}
			}
			else {
				for (int R = 0; R < n; ++R) {
					for (int L = n-2; L >= 0; --L) {
						ndp[L][R] = (ndp[L][R] + ndp[L+1][R]) % mod;
					}
				}
			}
			swap(dp, ndp);
		}
		ll ans = 0;
		for (int L = 0; L < n; ++L) {
			ll sum = 0;
			for (int R = L; R < n; ++R) {
				sum += a[R];
				sum %= mod;

				ans += 1LL * sum * dp[L][R];
				ans %= mod;
			}
		}
		cout << ans << '\n';
	}
}

Has anyone written a memoization solution for this? Please share.


Can someone explain test case 2 for me, please?


We have to multiply by the probability of choosing each path of the tree while calculating the answer.

So, we had to output 31 * power(7, M-2) % M, right? That comes out to 428571436 for me, whereas the expected answer is 416666673. I spent a lot of time thinking about where I was going wrong. Can you help?


We have to multiply by the probability of choosing each path of the tree.

Sorry, I do not understand. What should the answer be, according to you, without the modulo? Is it not 31/7?

No.
There are 7 possible final subarrays, sure.
However, it seems you’re assuming they all have the same probability of appearing, and that isn’t true here.

You can see the full decision tree in this comment.


Much thanks!

Thank you iceknight for the well-written editorial!!

Can someone please check what’s wrong with my solution?
submission id: 98844641