Array Description

I was trying this question on CSES. Everytime I start dp problem by thinking recursively. I did the same here, but my recursive solution isn’t working and I am not able to figure out why. So it will be very helpful if one can have a look at my code.

#include<bits/stdc++.h>
#define int long long int
#define pb push_back
#define vi vector<int>
#define vb vector<bool>
#define vd vector<double>
#define vc vector<char>
#define vii vector<vi>
#define mp make_pair
#define vpi vector< pair<int, int> >
#define take_input freopen("input.txt", "r", stdin)
#define give_output freopen("output.txt", "w", stdout)
#define fastIO ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)
#define fi first
#define se second
#define mod 1000000007
#define min_pql priority_queue< int, vector<int>, greater<int> >

using namespace std;
using namespace std::chrono;

int solve(int n, int m, vi &arr, int i=0) {
    if(i == n) return 1;
    int cnt=0;
    if(arr[i] == 0) {
        if(i==0) {
            if( i+1<n && arr[i+1]!=0) {
                for(int j=-1; j<=1; j++){
                    if( (arr[i+1]+j>=1) && (arr[i+1]+j<=m)) {
                        arr[i] = arr[i+1]+j;
                        cnt += solve(n, m, arr, i+1);
                        arr[i] -= (arr[i+1]+j);
                    }
                }
            } else {
                for(int j=1; j<=m; j++){
                    arr[i] = j;
                    cnt += solve(n, m, arr, i+1);
                    arr[i] -= j;
                }
            } 
        } else {
            set<int> ind;
            if(i+1<n && arr[i+1] != 0) {
                for(int j=-1; j<=1; j++){
                    ind.insert(arr[i+1]+j);
                    ind.insert(arr[i-1]+j);
                }
            } else {
                for(int j=-1; j<=1; j++) {
                    ind.insert(arr[i-1]+j);
                }
            } 
            for(int x:ind) {
                arr[i] = x;
                if(x>=1 && x<=m) cnt += solve(n, m, arr, i+1);
                arr[i] -= x;
            }
        }
    } else {
        cnt += solve(n, m, arr, i+1);
    }
    return cnt;

}

int32_t main(){
    fastIO;
    //take_input;
    //give_output;
    int n, m; cin >> n >> m;
    vi arr(n);
    for(int &i:arr) cin >> i;
    cout << solve(n, m, arr);
}

Same Here!!

I also stuck in the same problem.
my program gives correct output in lower range but for higher range it gives WA.
I don’t know exactly what’s wrong in your code but maybe I can explain my intuition , by which you can get an idea .
my code https://cses.fi/paste/bfc818186c52bd982933be/


Well, here is my explanation for input
2 2
0 1
Here, the valid array would be
2 1
1 1
note: I have taken an array of size n+1 (assigning the element in array from index 1 to n)

total method dp(n,1)+dp(n,2)+dp(n,3)…dp(n,m)

a particular state dp(i,x) is possible if at index i-1 we get x-1||x||x+1

means dp(i-1,x-1)+dp(i-1,x)+dp(i-1,x+1)

recursion tree

If anyone have solve this problem using top-down approach, then, Please help!!

Finally solved it.

Recursive Code: https://cses.fi/paste/be9b8cc2a7103fea296010/

Iterative Code: https://cses.fi/paste/bf81fc70261c49e5295ffa/

5 Likes

A little bit different memoization solution.The prev variable here carries the suffix element of the array, when 0 > n < v.size() -1. You can use a single count variable instead of take,notTake.

Here, I’ve also handled the case separately, when n=1

ll fn(vector<ll> &v, ll n, ll m, ll prev, vector<vector<ll>> &dp){

    if(n==0){
        if(v[n]==0){
            if(prev==1||prev==m) return 2;
            return 3;
        }
       
        if(v[n] != 0 && abs(v[n]-prev)<2) return 1;
        return 0;
    }

    if(dp[n][prev] != -1) return dp[n][prev];

    ll take=0,notTake=0;

    if(n == v.size()-1){
        if(v[n] != 0){
            notTake += fn(v, n-1, m, v[n], dp);
        }
        else{
            for(ll i=1; i<=m; ++i){
                take += fn(v, n-1, m, i, dp);
            }
        }
    }

    else{
        if(v[n] != 0){
            if(abs(v[n]-prev)<2){
                notTake += fn(v, n-1, m, v[n], dp);
            }
        }

        else{
            ll lb = prev == 1? 1: prev-1;
            ll rb = prev == m? m: prev+1;

            for(ll i=lb; i<=rb; ++i){
                take += fn(v, n-1, m, i, dp);
            }
        }
    }

    dp[n][prev] = (take+notTake) % mod;
    return dp[n][prev];
}

int main(){

    ll n,m;
    cin >> n >> m;
    vector<ll> v(n);

    for(ll i=0; i<n; ++i){
        cin >> v[i];
    }

    ll ans = 0;

    if(n == 1){
        if(v[0] == 0) ans = m;
        else ans = 1;
    }

    vector<vector<ll>> dp(n, vector<ll>(m+1, -1));

    ll prev = m;
    if(n > 1){
        ans = fn(v,n-1,m,prev,dp);
    }
    cout << ans;
    cout << '\n';

	return 0;
}





1 Like

I also got WA for a higher range. How you fixed it?

Nice information. Thanks for update about iso compliance

It literally helped me, I’m stuck on this question for two days.

1 Like