GOODBINTREE - Editorial

PROBLEM LINK:

Practice
Div-3 Contest
Div-2 Contest
Div-1 Contest

Author: Vishesh Saraswat, Harshikaa Agrawal
Tester: Istvan Nagy
Editorialist: Harshikaa Agrawal

Difficulty:

Easy

Prerequisites:

2D Dynamic Programming, Prefix Sums, Modular Arithmetic (Only for output)

Problem:

You are given a perfect binary tree with N nodes and an integer M. Every node is assigned a value, and the tree is called good if the assignment satisfies all of the following rules:

  • Nodes’ values are positive integers no more than M
  • Nodes at even levels have values strictly more than their parents’ values.
  • Nodes at odd levels have values strictly less than their parents’ values.

Count the number of different value assignments that make the tree good, and output this count modulo 10^9+7.

Note:

  • The root of the tree is at level 1.
  • Two assignments are different if there is at least one node with different values in both assignments.
  • N can exceed the 32-bit range, so you may need a 64-bit integer type to read the input.

Explanation:

Since the tree is perfect, N = 2^h - 1, so its height is h = \log_2(N+1).

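In a good tree the values alternate along every root-to-leaf path: odd level < even level > odd level < even level > odd level …

Define two tables, indexed by the subtree height y and the root value x:
  • DP1[y][x] counts labelings of a perfect subtree of height y with root value x that follow the pattern odd level > even level < odd level > even level … (the root is greater than its children).
  • DP2[y][x] counts labelings of the same kind of subtree that follow the good-tree pattern odd level < even level > odd level < even level … (the root is smaller than its children).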

A DP1 subtree of height y is a root with two independent DP2 subtrees of height y-1 attached to it. The root is greater than its children, so each child must in turn be smaller than its own children, which is exactly the DP2 pattern.

Symmetrically, a DP2 subtree of height y is a root with two independent DP1 subtrees of height y-1 attached to it.

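For a root value x (1 \le x \le M) and a height y (2 \le y \le h), the two child subtrees are filled independently. The count for a single child is therefore squared:

DP1[y][x] = \left(\sum_{i=1}^{x-1} DP2[y-1][i]\right)^2

DP2[y][x] = \left(\sum_{i=x+1}^{M} DP1[y-1][i]\right)^2

The inner sums are a prefix sum of DP2[y-1] and a suffix sum of DP1[y-1]. Each row can therefore be filled in O(M) by sweeping x upward (for DP1) or downward (for DP2) while keeping a running sum.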

The base case is a single node (height 1): DP1[1][j] = DP2[1][j] = 1 for every 1 \le j \le M.

The answer is \sum_{k=1}^{M} DP2[h][k], since a good tree follows the DP2 pattern.

Time Complexity:

O(M \log_2 N) per test case: there are h = \log_2(N+1) levels, and each level takes O(M) using the running prefix/suffix sums.

Setter’s Solution

Setter's Solution
#include "bits/stdc++.h"
using namespace std;
/*
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using ordered_set = tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>;
*/

// Competitive-programming shorthands for passing whole containers / their size.
#define all(x) begin(x), end(x)
#define rall(x) rbegin(x), rend(x)
#define sz(x) (int)(x).size()

using ll = long long;
// Every `int` written after this line is textually replaced by `long long`:
// N needs 64 bits, and squaring a running sum that is kept below `mod`
// needs 64-bit arithmetic. (This is why main is declared as `signed main`.)
#define int ll
// All counts are reported modulo 1e9+7.
const int mod = 1e9+7;

void solve(int tc) {
    int n, m;
    cin >> n >> m;
    ++n;
    int h = 0;
    while (n%2==0) {
        ++h;
        n/=2;
    }
    vector<vector<int>> dp1(h+1, vector<int>(m+1)), dp2(h+1, vector<int>(m+1));
    // dp1[i][j] = number of trees height i with ith > i+1 and root = m
    // dp2[i][j] = number of trees height i with ith < i+1 and root = m
    // root is ith
    for (int i = 1; i <= m; ++i)
        dp1[1][i] = dp2[1][i] = 1;
    for (int i = 2; i <= h; ++i) {
        int cursum = 0;
        for (int j = 1; j <= m; ++j) {
            dp1[i][j] = (cursum * cursum);
            dp1[i][j] %= mod;
            cursum += dp2[i-1][j];
            cursum %= mod;
        }
        cursum = 0;
        for (int j = m; j >= 1; --j) {
            dp2[i][j] = (cursum * cursum);
            dp2[i][j] %= mod;
            cursum += dp1[i-1][j];
            cursum %= mod;
        }
    }
    int ans = 0;
    for (int j = 1; j <= m; ++j) {
        ans += dp2[h][j];
        ans %= mod;
    }
    cout << ans << '\n';
}

// `int` is macro-expanded to `long long` in this program, so main is declared
// with `signed` to keep the required `int` return type.
signed main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    int testCount = 1;
    cin >> testCount;
    for (int caseNo = 1; caseNo <= testCount; caseNo++)
        solve(caseNo);
    return 0;
}

Tester’s Solution

Tester's Solution
#include <iostream>
#include <algorithm>
#include <string>
#include <cassert>
#include <vector>
#include <numeric>
using namespace std;

#ifdef HOME
#define NOMINMAX
#include <windows.h>
#endif

// Validator helper: reads one integer token from stdin that is terminated by
// `endd`, asserts that it is well formed (optional leading '-', no leading
// zeros, no "-0", at most 19 digits) and that it lies in [l, r], and returns it.
// Characters other than digits, '-' and `endd` are skipped.
long long readInt(long long l, long long r, char endd) {
	long long x = 0;
	int cnt = 0;          // digits read so far
	int fi = -1;          // first digit, or -1 before any digit
	bool is_neg = false;
	while (true) {
		// Keep getchar()'s int result. Squeezing it into a char makes EOF
		// undetectable where char is unsigned, and running out of input
		// before `endd` used to spin in this loop forever.
		const int g = getchar();
		assert(g != EOF);
		if (g == '-') {
			// A minus sign is only allowed before the first digit.
			assert(fi == -1);
			is_neg = true;
			continue;
		}
		if ('0' <= g && g <= '9') {
			if (cnt == 0) {
				fi = g - '0';
			}
			cnt++;
			// No leading zeros and no "-0".
			assert(fi != 0 || cnt == 1);
			assert(fi != 0 || is_neg == false);
			// Length cap checked BEFORE accumulating so x itself can never
			// overflow: at most 19 digits, and a 19-digit token must start with 1.
			assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
			x = x * 10 + (g - '0');
		}
		else if (g == static_cast<unsigned char>(endd)) {
			assert(cnt > 0);
			if (is_neg) {
				x = -x;
			}
			assert(l <= x && x <= r);
			return x;
		}
		// Any other character is ignored (the original kept assert(false) commented out here).
	}
}

// Validator helper: reads characters from stdin up to (not including) `endd`,
// asserts that the resulting string's length is in [l, r], and returns it.
std::string readString(int l, int r, char endd) {
	std::string ret = "";
	int cnt = 0;
	while (true) {
		// getchar() returns int. Keeping it as int separates EOF from a 0xFF
		// byte (which a char-typed `g != -1` check rejected) and makes the
		// EOF check work where char is unsigned.
		const int g = getchar();
		assert(g != EOF);
		if (g == static_cast<unsigned char>(endd)) {
			break;
		}
		cnt++;
		ret += static_cast<char>(g);
	}
	assert(l <= cnt && cnt <= r);
	return ret;
}
long long readIntSp(long long l, long long r) {
	return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
	return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
	return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
	return readString(l, r, ' ');
}

// Tester: validates the input format with assertions and prints, for each test
// case, the number of good trees modulo 1e9+7.
int main() {
#ifdef HOME
	if (IsDebuggerPresent())
	{
		freopen("../in.txt", "rb", stdin);
		freopen("../out.txt", "wb", stdout);
	}
#endif
	const int T = static_cast<int>(readIntLn(1, 100));
	const uint64_t Mod = 1'000'000'007;
	for (int tc = 0; tc < T; ++tc)
	{
		// The tree is perfect, so N + 1 must be a power of two: N + 1 = 2^level.
		uint64_t N = readIntSp(1, ((1ull) << 59));
		++N;
		// Unsigned, to match the uint32_t loop index below (no signed/unsigned compare).
		const uint32_t M = static_cast<uint32_t>(readIntLn(1, 1'000));
		assert((N & (N - 1)) == 0);
		uint32_t level = 1;
		while (((1ull) << level) != N)
			++level;

		// a[i]: labelings of a perfect subtree of the current height whose root
		// value corresponds to index i (all 1 for a single node).
		// One step builds the next height: b[i] = (a[0] + ... + a[i-1])^2, i.e.
		// both (independent) children take a root value on the smaller-index side.
		// Reversing b flips the index orientation, so the next step's prefix sum
		// covers the opposite side; this realizes the alternating < / > rule.
		vector<uint64_t> a(M, 1), b(M);
		for (uint32_t currLevel = 1; currLevel < level; ++currLevel)
		{
			uint64_t su = 0;
			for (uint32_t i = 0; i < M; ++i)
			{
				b[i] = (su * su) % Mod;
				su = (su + a[i]) % Mod;
			}
			std::reverse(b.begin(), b.end());
			a.swap(b);
			std::fill(b.begin(), b.end(), 0);
		}
		// The sum over all root values does not depend on the final orientation.
		// Every a[i] < Mod and M <= 1000, so the sum fits in 64 bits.
		uint64_t res = std::accumulate(a.begin(), a.end(), 0ull);
		res %= Mod;
		// %llu expects unsigned long long; uint64_t may be a different type
		// (e.g. unsigned long on LP64), so convert explicitly.
		printf("%llu\n", static_cast<unsigned long long>(res));
	}
	// The input must end right after the last test case.
	assert(getchar() == EOF);
}

Editorialist’s Solution:

Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

int main() {
    long long t;
    cin>>t;
    while(t--)
    {
        long long n, m;
        cin>>n>>m;
        n++;
        long long height = 0;
        while(n!=0)
        {
            height++;
            n /= 2;
        }
        height--;
        
        long long dp1[height+1][m+1];
        long long dp2[height+1][m+1];
        
        for(long long i = 1; i <= m; i++)
        {
            dp1[1][i] = 1;
            dp2[1][i] = 1;
        }
        
        for(long long i = 2; i <= height; i++)
        {
            long long tempsum = 0;
            for(long long j = 1; j <= m; j++)
            {
                dp1[i][j] = tempsum*tempsum;
                dp1[i][j] %= 1000000007;
                tempsum += dp2[i-1][j];
                tempsum %= 1000000007;
            }
            
            tempsum = 0;
            for(long long j = m; j >= 1; j--)
            {
                dp2[i][j] = tempsum*tempsum;
                dp2[i][j] %= 1000000007;
                tempsum += dp1[i-1][j];
                tempsum %= 1000000007;
            }
        }
        
        long long answer = 0;
        for(long long i = 1; i <= m; i++)
        {
            answer += dp2[height][i];
            answer %= 1000000007;
        }
        
        cout<<answer<<"\n";
    }
	return 0;
}
7 Likes
#include <algorithm>
#include <chrono>
#include <climits>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <list>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>

using namespace std;

// Shorthand for a vector of integers; because of the `int` macro below, uses of
// `vi` expand to vector<long long>.
#define vi vector<int>
// Every `int` written after this line is textually replaced by `long long`,
// so products such as x*x with x < mod do not overflow.
#define int long long
// NOTE(review): every `double` written after this line becomes `long double`.
#define double long double

// Plain newline instead of std::endl, so output is not flushed on every line.
#define endl "\n"
// Fast I/O: stop syncing C++ streams with C stdio and untie cin/cout.
#define IOS                           \
    std::ios::sync_with_stdio(false); \
    cin.tie(NULL);                    \
    cout.tie(NULL);

const long long inf = 1e16;
const long long mod = 1e9 + 7;
const long long Max = 5e5 + 5;

// Memo table for solve(): dp[level][val] caches solve(level, val, mx, height),
// with -1 meaning "not computed yet". Cached values depend on the current
// test's mx and height, so cases() rebuilds the table for every test case.
std::vector<std::vector<long long>> dp;

// Counts (mod `mod`) the labelings of the perfect subtree rooted at a node on
// `level`, given that this node's parent holds value `val`.
//   level  - 1-based depth of the subtree root. Odd levels must be smaller than
//            their parent, even levels larger.
//   val    - the parent's value. The root is called with mx + 1, so any value
//            in [1, mx] is allowed for it.
//   mx     - the maximum allowed value M.
//   height - depth of the leaves.
// The two children of a node are filled independently, so the count for one
// child subtree is squared. Each memo state loops over up to mx values, which
// makes the total work O(height * mx^2). Running prefix sums over i (as in the
// editorial) would bring it down to O(height * mx).
long long solve(long long level, long long val, long long mx, long long height) {
    if (level == height) {
        // Leaf: an even level needs a value in (val, mx], an odd one in [1, val).
        return level % 2 == 0 ? mx - val : val - 1;
    }
    long long &memo = dp[level][val];
    if (memo != -1) return memo;

    // Values the node on `level` may take relative to its parent's value.
    const bool below_parent = (level % 2 == 1);
    const long long lo = below_parent ? 1 : val + 1;
    const long long hi = below_parent ? val - 1 : mx;

    long long ans = 0;
    for (long long i = lo; i <= hi; ++i) {
        const long long ways = solve(level + 1, i, mx, height);
        ans = (ans + ways * ways % mod) % mod;
    }
    memo = ans;
    return memo;
}

// Reads one test case (N, M) and prints the number of good trees modulo `mod`.
void cases()
{
    long long n, m;
    std::cin >> n >> m;

    // The tree is perfect, so N + 1 = 2^h; recover h.
    long long z = n + 1;
    long long h = 0;
    while (z != 1) {
        z /= 2;
        h++;
    }

    // BUG FIX: the original called dp.resize(h + 5, vi(m + 5, -1)). resize()
    // only appends or truncates rows, so rows left over from an earlier test
    // have two problems:
    //   * They still hold memoized values computed for a different M and
    //     height. E.g. the input "7 2" followed by "7 3" printed 1 instead
    //     of 41 for the second test.
    //   * They keep their old length, so a later test with a larger M indexes
    //     past the end of a row.
    // assign() rebuilds every row, sized for this M and reset to -1.
    dp.assign(h + 5, std::vector<long long>(m + 5, -1));

    // Start at the root with a virtual parent value of m + 1, which lets the
    // root take any value in [1, m].
    std::cout << solve(1, m + 1, m, h) << '\n';
}

// Entry point: enable fast I/O, then answer each test case in turn.
int32_t main()
{
    IOS;
    int testCount;
    cin >> testCount;
    for (; testCount != 0; --testCount)
        cases();
    return 0;
}

What is incorrect in this approach? (I am asking about the correctness of the code, not its time complexity.)