PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter:
Tester: Rahul Dugar
Editorialist: Taranpreet Singh
DIFFICULTY
Medium
PREREQUISITES
Probability and Expectation, Dynamic Programming Optimization
PROBLEM
Given a communication channel with N+1 stations numbered 0 to N, station i passes the message to station i+1 for 0 \leq i < N. The probability that the transfer from station i to station i+1 succeeds is F_{i+1}/F_i (the values F_i are non-increasing, so each ratio is a valid probability). If the communication fails, the message has to be fetched again from station 0.
You are allowed to install memory disks on at most K stations, so that if communication fails at station x, the message can be fetched from the nearest station at or to the left of x that has a memory disk.
Find the minimum expected time for the message to travel from station 0 to station N, assuming that passing a message from one station to the next takes exactly one second.
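To make the model concrete, here is a small Monte Carlo simulator of the process just described. It is a sketch under stated assumptions: `simulateOnce` and `estimateExpectedTime` are hypothetical names, every attempted hop (successful or not) is assumed to cost one second, and refetching from a disk is assumed to take no time, which is how the derivation that follows counts time. It is only meant for sanity-checking formulas on tiny inputs.

```cpp
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

// Hypothetical helper: simulates one delivery from station 0 to station N.
// F[0..N] are the non-increasing positive values from the statement, so the
// hop from station i to i+1 succeeds with probability F[i+1] / F[i].
// hasDisk[i] marks stations holding a memory disk; station 0 is the source,
// so a failed hop restarts from the nearest disk at or left of the current station.
long long simulateOnce(const std::vector<long long>& F,
                       const std::vector<bool>& hasDisk,
                       std::mt19937_64& rng) {
    std::uniform_real_distribution<double> coin(0.0, 1.0);
    const int N = static_cast<int>(F.size()) - 1;
    int pos = 0, lastDisk = 0;
    long long seconds = 0;
    while (pos < N) {
        ++seconds;  // every attempted hop costs one second, failed or not
        if (coin(rng) < static_cast<double>(F[pos + 1]) / F[pos]) {
            ++pos;
            if (hasDisk[pos]) lastDisk = pos;
        } else {
            pos = lastDisk;  // refetch from the nearest disk to the left
        }
    }
    return seconds;
}

// Averages `trials` independent runs to estimate the expected delivery time.
double estimateExpectedTime(const std::vector<long long>& F,
                            const std::vector<bool>& hasDisk,
                            int trials, unsigned long long seed) {
    std::mt19937_64 rng(seed);
    double total = 0;
    for (int t = 0; t < trials; ++t) total += simulateOnce(F, hasDisk, rng);
    return total / trials;
}
```

For example, with F = (8, 4, 2, 1) and no disks, a few hundred thousand trials should land close to (8 + 4 + 2) / 1 = 14, the value given by the K = 0 formula.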
QUICK EXPLANATION
- For K = 0, we only need the expected time to pass the message from station 0 to station N, which is \displaystyle \frac{\sum_{i = 0}^{N-1} F_i}{F_N}
- With K > 0, once the message reaches a station with a memory disk, it never needs to go back, so the expected time becomes the sum of the expected times to pass the message from station 0 to station v_1, from v_1 to v_2, \ldots, and from v_k to N, where v_1 < v_2 < \ldots < v_k are the stations with disks.
- A simple Dynamic Programming computes the minimum expected time to send the message from station 0 to station u using k memory disks, the last one installed at station u.
- This DP can be sped up with the Divide and Conquer optimization, since the cost function satisfies the quadrangle inequality.
EXPLANATION
Solving without Memory disks
Let’s assume K = 0 and find the expected time taken for the message to go from station 0 to station N.
Each attempt ends in one of the following ways:
- The message is delivered successfully
- The message fails at station i
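Let E_x denote the expected number of steps to send the message from station x to station N.
Let p_i = F_{i+1}/F_i be the probability that the hop from station i to station i+1 succeeds. The probability that a message sent from station 0 reaches station N without any failure is \displaystyle \prod_{i = 0}^{N-1} p_i = \prod_{i = 0}^{N-1} \frac{F_{i+1}}{F_i} = \frac{F_N}{F_0}, since the product telescopes.
Let’s consider N = 3 for now, and call the edge from station i to station i+1 edge i.
Consider a single attempt that starts at station 0 (either the first attempt, or a retry after a failure sent the message back to station 0):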
- Edge numbered 0 will be visited with probability 1 and take 1 second
- Edge numbered 1 will be visited with probability p_0 and take 1 second
- Edge numbered 2 will be visited with probability p_0 * p_1 and take 1 second
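We consider the contribution of each edge separately, so one attempt takes 1 + p_0 + p_0 \times p_1 seconds in expectation.
After the attempt, there are two cases:
- the transmission succeeds, with probability p_0 \times p_1 \times p_2, in which case 0 more steps are required;
- the transmission fails, with probability (1-p_0 \times p_1 \times p_2), in which case the message is back at station 0 and E_0 more steps are required in expectation.
Hence, we can write E_0 = 1+p_0+p_0 \times p_1 + (1-p_0 \times p_1 \times p_2) \times E_0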
Solving this, we get \displaystyle E_0 = \frac{1+p_0+p_0 \times p_1}{p_0 \times p_1 \times p_2} = \frac{F_0+F_1+F_2}{F_3}
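As a quick numeric check of this formula, with values chosen only for illustration, take N = 3 and F = (8, 4, 2, 1), so that p_0 = p_1 = p_2 = 1/2:
\displaystyle E_0 = \frac{1 + \frac{1}{2} + \frac{1}{4}}{\frac{1}{2} \times \frac{1}{2} \times \frac{1}{2}} = \frac{7/4}{1/8} = 14 = \frac{8 + 4 + 2}{1} = \frac{F_0 + F_1 + F_2}{F_3}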
Generalizing, the expected time to transfer message from station 0 to station N with no memory disks in between is \displaystyle \frac{\sum_{i = 0}^{N-1} F_i}{F_N}.
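The two expressions for E_0 can be cross-checked numerically. A minimal sketch with hypothetical function names: `expectedByAttempts` follows the per-attempt argument (expected length of one attempt divided by its success probability), while `expectedClosedForm` evaluates the closed form directly.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Expected time with no memory disks, from the per-attempt argument:
// edge i is attempted with probability p_0 * ... * p_{i-1}, and a whole
// attempt succeeds with probability p_0 * ... * p_{N-1}.
double expectedByAttempts(const std::vector<long long>& F) {
    const int N = static_cast<int>(F.size()) - 1;
    double reach = 1.0;       // probability of reaching station i within one attempt
    double perAttempt = 0.0;  // expected seconds spent in one attempt
    for (int i = 0; i < N; ++i) {
        perAttempt += reach;                              // edge i costs 1 second when attempted
        reach *= static_cast<double>(F[i + 1]) / F[i];    // multiply by p_i
    }
    return perAttempt / reach;  // E_0 = (expected attempt length) / P(success)
}

// Closed form (F_0 + ... + F_{N-1}) / F_N.
double expectedClosedForm(const std::vector<long long>& F) {
    const int N = static_cast<int>(F.size()) - 1;
    long long sum = 0;
    for (int i = 0; i < N; ++i) sum += F[i];
    return static_cast<double>(sum) / F[N];
}
```

Both functions return 14 for F = (8, 4, 2, 1) and 13.5 for F = (10, 7, 7, 3, 2), up to floating-point rounding.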
Memory disks allowed
Now, with memory disks, once the message reaches a station u that has a memory disk, we never need to go back behind station u. Let’s assume memory disks are installed at stations v_1 < v_2 < \ldots < v_k, and set v_0 = 0 and v_{k+1} = N.
So the expected time with these memory disks is \displaystyle \sum_{l = 0}^{k} time(v_l, v_{l+1}), where \displaystyle time(u, v) = \frac{\sum_{x = u}^{v-1} F_x}{F_v} is the expected time to go from u to v when station u holds a disk (this is the K = 0 formula applied to stations u, \ldots, v).
All we need to do is choose the stations with memory disks optimally and compute the cost.
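With a prefix-sum array, each time(u, v) is an O(1) lookup, and the expected time of any fixed placement is the sum over its segments. A sketch, assuming 0-based vectors with F.size() == N + 1; `buildPrefix`, `segmentTime` and `placementTime` are hypothetical helpers, and the prefix here is exclusive (prefix[i] = F_0 + ... + F_{i-1}) rather than the inclusive S_i used in the optimization section.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// prefix[i] = F_0 + ... + F_{i-1}, so prefix[v] - prefix[u] = sum_{x=u}^{v-1} F_x.
std::vector<long long> buildPrefix(const std::vector<long long>& F) {
    std::vector<long long> prefix(F.size() + 1, 0);
    for (std::size_t i = 0; i < F.size(); ++i) prefix[i + 1] = prefix[i] + F[i];
    return prefix;
}

// time(u, v): expected seconds from station u (holding a disk) to station v.
double segmentTime(const std::vector<long long>& F,
                   const std::vector<long long>& prefix, int u, int v) {
    return static_cast<double>(prefix[v] - prefix[u]) / F[v];
}

// Total expected time for disks at `disks` (strictly increasing, inside (0, N)).
double placementTime(const std::vector<long long>& F, const std::vector<int>& disks) {
    const std::vector<long long> prefix = buildPrefix(F);
    const int N = static_cast<int>(F.size()) - 1;
    double total = 0;
    int prev = 0;  // v_0 = 0
    for (int v : disks) {
        total += segmentTime(F, prefix, prev, v);
        prev = v;
    }
    return total + segmentTime(F, prefix, prev, N);  // last segment ends at v_{k+1} = N
}
```

For F = (8, 4, 2, 1), a single disk at station 2 gives 12/2 + 2/1 = 8, compared with 14 without disks.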
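A naive Dynamic Programming works here, with state (last, count) storing the minimum expected time to send the message from station 0 to station last using count memory disks, the last of which is at station last (station N is treated as the final disk, which is why the answer uses K+1).
The DP recurrence is \displaystyle f(last, count) = \min_{prev < last} \left( f(prev, count-1) + time(prev, last) \right), with base case f(0, 0) = 0, and we need to compute f(N, K+1). Using all K disks is never worse when K < N, because splitting a segment from u to v at a station w never increases its cost: \displaystyle time(u, v) = \frac{\sum_{x=u}^{w-1} F_x + \sum_{x=w}^{v-1} F_x}{F_v} \geq time(u, w) + time(w, v), since F_v \leq F_w.
This DP has O(N K) states and O(N) transitions per state, leading to an O(N^2 K) solution, which would TLE.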
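For small inputs the naive recurrence is still handy as a reference to test a faster version against. A sketch with a hypothetical name `naiveMinTime`; it returns the best value over every count up to K + 1, which equals f(N, K+1) whenever K < N and also behaves sensibly when K \geq N.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Naive O(N^2 * K) DP: f[c][v] = min expected time from station 0 to station v
// using exactly c segments, the last one ending at v. Reference use only.
double naiveMinTime(const std::vector<long long>& F, int K) {
    const int N = static_cast<int>(F.size()) - 1;
    const double INF = std::numeric_limits<double>::infinity();
    std::vector<long long> prefix(N + 2, 0);  // prefix[i] = F_0 + ... + F_{i-1}
    for (int i = 0; i <= N; ++i) prefix[i + 1] = prefix[i] + F[i];
    auto cost = [&](int u, int v) {
        return static_cast<double>(prefix[v] - prefix[u]) / F[v];
    };

    std::vector<std::vector<double>> f(K + 2, std::vector<double>(N + 1, INF));
    f[0][0] = 0;  // base case f(0, 0) = 0
    for (int c = 1; c <= K + 1; ++c)
        for (int last = 1; last <= N; ++last)
            for (int prev = 0; prev < last; ++prev)
                if (f[c - 1][prev] < INF)
                    f[c][last] = std::min(f[c][last], f[c - 1][prev] + cost(prev, last));

    // "At most K" disks: take the best over every segment count.
    double best = INF;
    for (int c = 1; c <= K + 1; ++c) best = std::min(best, f[c][N]);
    return best;
}
```

For F = (8, 4, 2, 1) this gives 14, 8 and 6 for K = 0, 1 and 2.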
Optimizing DP
We have F_{i+1} \leq F_i, so it’s worth checking whether the cost function satisfies a quadrangle inequality, which would let some DP optimization apply to our problem.
A sufficient (but not necessary) condition for the Divide and Conquer optimization is cost(a, c)+cost(b, d) \leq cost(a, d)+cost(b, c) for all a \leq b \leq c \leq d.
Let’s assume a < b < c < d and define \displaystyle S_i = \sum_{j = 0}^{i} F_j, with S_{-1} = 0. Then \displaystyle cost(u, v) = \frac{S_{v-1}-S_{u-1}}{F_v}, which is exactly time(u, v).
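Substituting, we need
\displaystyle \frac{S_{c-1}-S_{a-1}}{F_c} + \frac{S_{d-1}-S_{b-1}}{F_d} \leq \frac{S_{d-1}-S_{a-1}}{F_d} + \frac{S_{c-1}-S_{b-1}}{F_c}
Moving the terms with F_c to the left and the terms with F_d to the right, the S_{c-1} and S_{d-1} terms cancel, leaving
\displaystyle \frac{S_{b-1}-S_{a-1}}{F_c} \leq \frac{S_{b-1}-S_{a-1}}{F_d}
Since S_{b-1}-S_{a-1} = F_a + \ldots + F_{b-1} > 0, this is equivalent to F_d \leq F_c, which is true because F is non-increasing and c < d.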
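The inequality can also be checked by brute force on small arrays, which is a cheap way to catch a sign error in a derivation like this. A sketch with a hypothetical `quadrangleHolds`; it tries every a < b < c < d with a small floating-point tolerance. An increasing array such as (1, 2, 4, 8) violates it, consistent with the argument relying on F being non-increasing.

```cpp
#include <cassert>
#include <vector>

// Brute-force check of cost(a,c) + cost(b,d) <= cost(a,d) + cost(b,c)
// for every 0 <= a < b < c < d <= N, where cost(u, v) = (F_u + ... + F_{v-1}) / F_v.
// A sanity check for small arrays, not part of a solution.
bool quadrangleHolds(const std::vector<long long>& F) {
    const int N = static_cast<int>(F.size()) - 1;
    std::vector<long long> prefix(N + 2, 0);  // prefix[i] = F_0 + ... + F_{i-1}
    for (int i = 0; i <= N; ++i) prefix[i + 1] = prefix[i] + F[i];
    auto cost = [&](int u, int v) {
        return static_cast<double>(prefix[v] - prefix[u]) / F[v];
    };
    const double eps = 1e-9;  // tolerance for floating-point rounding
    for (int a = 0; a <= N; ++a)
        for (int b = a + 1; b <= N; ++b)
            for (int c = b + 1; c <= N; ++c)
                for (int d = c + 1; d <= N; ++d)
                    if (cost(a, c) + cost(b, d) > cost(a, d) + cost(b, c) + eps) return false;
    return true;
}
```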
Hence the quadrangle inequality holds, so we can directly apply the Divide and Conquer optimization to speed up the DP.
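A compact sketch of the optimized DP, following the same layered scheme as the setter's and tester's solutions: each layer adds one disk, and `compute` fills a layer by divide and conquer over the monotone optimal split point. `ChannelDP` and its members are hypothetical names. The first layer is initialized with cost(0, v), and allowing prev == v (zero cost) makes each layer mean "at most" that many disks. Only the previous layer is kept, so memory is O(N).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

struct ChannelDP {
    std::vector<long long> F, prefix;  // prefix[i] = F_0 + ... + F_{i-1}
    std::vector<double> before, cur;   // previous layer and layer being built

    double cost(int u, int v) const {
        return static_cast<double>(prefix[v] - prefix[u]) / F[v];
    }

    // Fill cur[l..r], knowing the optimal prev for these positions lies in [optL, optR].
    void compute(int l, int r, int optL, int optR) {
        if (l > r) return;
        const int mid = (l + r) / 2;
        double best = std::numeric_limits<double>::infinity();
        int bestPrev = optL;
        for (int prev = optL; prev <= std::min(mid, optR); ++prev) {
            const double val = before[prev] + cost(prev, mid);
            if (val < best) { best = val; bestPrev = prev; }
        }
        cur[mid] = best;
        compute(l, mid - 1, optL, bestPrev);  // left half: optimum is <= bestPrev
        compute(mid + 1, r, bestPrev, optR);  // right half: optimum is >= bestPrev
    }

    // Minimum expected time with at most K disks: O(K * N * log N) time, O(N) memory.
    double solve(const std::vector<long long>& values, int K) {
        F = values;
        const int N = static_cast<int>(F.size()) - 1;
        prefix.assign(N + 2, 0);
        for (int i = 0; i <= N; ++i) prefix[i + 1] = prefix[i] + F[i];
        cur.assign(N + 1, 0.0);
        for (int v = 0; v <= N; ++v) cur[v] = cost(0, v);  // no disks yet
        for (int layer = 1; layer <= K; ++layer) {          // each layer adds one disk
            before = cur;                                    // keep only the previous layer
            compute(0, N, 0, N);
        }
        return cur[N];
    }
};
```

For example, `ChannelDP().solve({10, 7, 7, 3, 2}, 1)` should return 52/7 (one disk at station 2), matching the naive DP on the same input.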
The following are good resources on the Divide and Conquer optimization, with several practice problems:
CF blog (do read comments)
cp-algorithms (contains implementation)
Answer by Michael Levin
Also, Arjun Arul has given a lecture covering this optimization in the advanced track of IPC camp 2020, Day 3.
TIME COMPLEXITY
The time complexity is O(K*N*log(N)) per test case.
Memory complexity is O(K*N), which can be reduced to O(N) since only the previous layer is needed to compute the current one.
SOLUTIONS
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int, int>
#define F first
#define S second
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())
#define ld double
template<class T,class U>
ostream& operator<<(ostream& os,const pair<T,U>& p){
    os<<"("<<p.first<<", "<<p.second<<")";
    return os;
}
template<class T>
ostream& operator <<(ostream& os,const vector<T>& v){
    os<<"{";
    for(int i = 0;i < (int)v.size(); i++){
        if(i)os<<", ";
        os<<v[i];
    }
    os<<"}";
    return os;
}
#ifdef LOCAL
#define cerr cout
#else
#endif
#define TRACE
#ifdef TRACE
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
    cerr << name << " : " << arg1 << std::endl;
}
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
    const char* comma = strchr(names + 1, ',');cerr.write(names, comma - names) << " : " << arg1<<" | ";__f(comma+1, args...);
}
#else
#define trace(...)
#endif
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
            if(!(l<=x && x<=r))cerr<<l<<"<="<<x<<"<="<<r<<endl;
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
template<class T>
vector<T> readVector(int n, long long l, long long r){
    vector<T> ret(n);
    for(int i = 0; i < n; i++){
        ret[i] = i == n - 1 ? readIntLn(l, r) : readIntSp(l, r);
    }
    return ret;
}
const ll INF = 1e12;
const int N = 100005;
ld dp[N], dp2[N];
int main(){
    int t = readIntLn(1, 100000);
    int sn = 0;
    while(t--){
        int n = readIntSp(1, 100000);
        int k = readIntLn(0, 50);
        sn += n;
        assert(sn <= 100000);
        vector<ll> F = readVector<ll>(n + 1, 1, INF);
        vector<ll> prefix = F;
        for(int i = 1; i <= n; i++) prefix[i] += prefix[i - 1];
        function<ld(int, int)> cost = [&](int i, int j){
            if(i == j) return 0.;
            return (prefix[j - 1] - (i == 0 ? 0 : prefix[i - 1])) / (ld) F[j];
        };
        fill(dp, dp + n + 1, 1e18);
        dp[0] = 0;
        function<void(int, int, int, int)> recurse = [&](int l, int r, int opt_l, int opt_r){
            int mid = (l + r) >> 1;
            dp[mid] = 1e18;
            int opt = mid;
            for(int j = opt_l; j <= opt_r && j <= mid; j++){
                ld val = dp2[j] + cost(j, mid);
                if(val < dp[mid]){
                    dp[mid] = val;
                    opt = j;
                }
            }
            if(l == r) return;
            if(l <= mid - 1) recurse(l, mid - 1, opt_l, opt);
            if(mid + 1 <= r) recurse(mid + 1, r, opt, opt_r);
        };
        for(int i = 1; i <= k + 1; i++){
            for(int j = 0; j <= n; j++) dp2[j] = dp[j];
            recurse(1, n, 0, n);
        }
        cout << setprecision(10) << fixed << dp[n] << endl;
    }
}
Tester's Solution
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
#define endl '\n'
#endif
#define pb push_back
#define fi first
#define se second
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=c; a++)
#define rep(a,b,c) for(int a=b; a<c; a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
//const int mod=998244353;
const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
auto clk=clock();
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
    uniform_int_distribution<int> uid(0,lim-1);
    return uid(rang);
}
int powm(int a, int b) {
    int res=1;
    while(b) {
        if(b&1)
            res=(res*a)%mod;
        a=(a*a)%mod;
        b>>=1;
    }
    return res;
}
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
int f[100005];
int p[100005];
double C(int i, int j) {
    return (p[j]-p[i])/((double)f[j]);
}
vector<double> dp_before,dp_cur;
const ll INF=1e18;
void compute(int l, int r, int optl, int optr) {
    if(l>r)
        return;
    int mid=(l+r)>>1;
    pair<double, int> best={INF,-1};
    for(int k=optl; k<=min(mid,optr); k++)
        best=min(best,{dp_before[k]+C(k,mid),k});
    dp_cur[mid]=best.first;
    int opt=best.second;
    compute(l,mid-1,optl,opt);
    compute(mid+1,r,opt,optr);
}
int sum_n=0;
void solve() {
    int n,k;
    n=readIntSp(1,100000);
    sum_n+=n;
    assert(sum_n<=100000);
    k=readIntLn(0,min(50LL,n-1));
    fr(i,0,n) {
        if(i!=n)
            f[i]=readIntSp(1,1000'000'000'000LL);
        else
            f[i]=readIntLn(1,1000'000'000'000LL);
        if(i)
            assert(f[i]<=f[i-1]);
    }
    fr(i,1,n+1)
        p[i]=p[i-1]+f[i-1];
    dp_cur.resize(n+1);
    fr(i,0,n)
        dp_cur[i]=C(0,i);
    fr(i,1,k) {
        dp_before=dp_cur;
        compute(0,n,0,n);
    }
    cout<<dp_cur[n]<<endl;
}
signed main() {
    ios_base::sync_with_stdio(0),cin.tie(0);
    srand(chrono::high_resolution_clock::now().time_since_epoch().count());
    cout<<fixed<<setprecision(7);
    int t=readIntLn(1,100000);
    // cin>>t;
    fr(i,1,t)
        solve();
    assert(getchar()==EOF);
#ifdef rd
    // cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.DecimalFormat;
class CHANNEL{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), K = ni()+1;
        long[] F = new long[1+N];
        for(int i = 0; i<= N; i++){
            F[i] = nl();
            if(i > 0)F[i] += F[i-1];
        }
        double[][] DP = new double[1+K][1+N];
        for(int i = 0; i<= K; i++)Arrays.fill(DP[i], Long.MAX_VALUE);
        DP[0][0] = 0;
        for(int layer = 1; layer <= K; layer++)
            compute(DP[layer], DP[layer-1], F, 1, N, 0, N);
        pn(new DecimalFormat("0.000000").format(DP[K][N]));
    }
    double cost(long[] F, int l, int r){
        return (F[r-1]-(l == 0?0:F[l-1]))*1.0/(F[r]-F[r-1]);
    }
    void compute(double[] DP_new, double[] DP_prev, long[] F, int l, int r, int optl, int optr){
        if(l > r)return;
        int mid = (l+r)>>1;
        int bestPos = -1;
        double bestCost = Long.MAX_VALUE;
        for(int k = optl; k <= Math.min(mid, optr); k++){
            double cost = DP_prev[k] + cost(F, k, mid);
            if(bestCost > cost){
                bestPos = k;
                bestCost = cost;
            }
        }
        DP_new[mid] = bestCost;
        compute(DP_new, DP_prev, F, l, mid-1, optl, bestPos);
        compute(DP_new, DP_prev, F, mid+1, r, bestPos, optr);
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new CHANNEL().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
Feel free to share your approach. Suggestions are welcome as always.