GIIKLUB - EDITORIAL

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Mladen Puzic

Tester: Michael Nematollahi

Editorialist: Taranpreet Singh

DIFFICULTY:

PREREQUISITES:

Binary Search, Meet-in-the-middle.

PROBLEM:

Given an N \times N matrix A consisting of positive values and an integer X, find the number of ways to move from (1, 1) to (N, N) such that the sum of the values at the positions visited on the path is less than or equal to X. From position (i, j), it is possible to move only to (i+1, j) or (i, j+1).

Print the number of such ways.

EXPLANATION

Let us consider all ways to get from (1, 1) to (N, N). Every such path consists of exactly N-1 moves down and N-1 moves to the right.

A slow solution is to enumerate all such paths recursively, keeping track of the sum of values up to the current position, and to count every path that reaches (N, N) with a sum less than or equal to X.

The catch is that there are \binom{2N-2}{N-1} different paths (the number of arrangements of N-1 right moves and N-1 down moves). With N = 20 this is about 3.5 \times 10^{10}, far too many for the time limit. Let us try something better.
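The brute force described above can be sketched as follows. This is only an illustration, not one of the reference solutions: the function name `countBrute`, the 0-indexed grid and the early exit (valid because all values are positive) are assumptions of this sketch.

```cpp
#include <vector>

// Hypothetical helper: counts monotone paths from cell (r, c) to the
// bottom-right cell whose total, together with the sum accumulated before
// entering (r, c), does not exceed limit. The grid is 0-indexed.
long long countBrute(const std::vector<std::vector<long long>>& a,
                     int r, int c, long long sum, long long limit) {
    int n = static_cast<int>(a.size());
    sum += a[r][c];
    if (sum > limit) return 0;               // values are positive, so the sum only grows
    if (r == n - 1 && c == n - 1) return 1;  // reached (N, N) within the limit
    long long ways = 0;
    if (r + 1 < n) ways += countBrute(a, r + 1, c, sum, limit);  // move down
    if (c + 1 < n) ways += countBrute(a, r, c + 1, sum, limit);  // move right
    return ways;
}
```

Calling `countBrute(a, 0, 0, 0, X)` gives the answer for small N, which makes it handy for cross-checking a faster solution on random small grids.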

Let us apply the meet-in-the-middle trick here. Each move increases x+y by exactly one, so every path from (1, 1) to (N, N) passes through exactly one position (x, y) on the anti-diagonal x+y = N+1, and each cell of this diagonal is at Manhattan distance exactly N-1 from both (1, 1) and (N, N).

At each position (x, y) on this diagonal, we maintain two lists of values. For each path from (1, 1) to (x, y), we add the sum of values on this path to the first list at position (x, y). There are 2^{N-1} such paths in total over all diagonal cells, so this takes 2^{N-1} iterations overall.

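Similarly, we consider all paths from (x, y) to (N, N) and add the sum of values on each path to the second list at position (x, y). This also takes 2^{N-1} iterations overall. The value at (x, y) itself must be counted in only one of the two lists, otherwise it would be added twice when the two halves are combined.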
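As a rough sketch of how the first lists can be generated, each mask of N-1 bits can be read as a sequence of moves. The helper name `forwardSums`, the 0-indexed grid and the choice to include the diagonal cell in the first half are assumptions of this example; the reference solutions below differ in which half includes that cell.

```cpp
#include <vector>

// Hypothetical helper: lists[r] holds the sums of all paths from the top-left
// cell to the anti-diagonal cell in row r (0-indexed), diagonal cell included.
// Assumes a is a non-empty square grid.
std::vector<std::vector<long long>> forwardSums(
        const std::vector<std::vector<long long>>& a) {
    int n = static_cast<int>(a.size());
    std::vector<std::vector<long long>> lists(n);
    for (int mask = 0; mask < (1 << (n - 1)); ++mask) {
        int r = 0, c = 0;
        long long sum = a[0][0];
        for (int i = 0; i < n - 1; ++i) {
            if ((mask >> i) & 1) ++r;  // bit 1: move down
            else ++c;                  // bit 0: move right
            sum += a[r][c];
        }
        lists[r].push_back(sum);       // after n-1 moves, r + c == n - 1
    }
    return lists;
}
```

The second half works the same way starting from the bottom-right cell and moving up or left, skipping the diagonal cell's value so that it is counted only once.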

Now, we have considered all paths from (1, 1) to (N, N), and for every position, we have two lists. Since total sum of values on the path should be less than or equal to X, we want to find the number of pairs (a, b) such that a is in the first list and b is in the second list, and a+b \leq X.

Comparing each pair naively would be as slow as the brute-force solution, so it is time for another trick.

We sort the first list. Then, for each element b in the second list, a binary search counts the number of values in the first list that are less than or equal to X - b. This counts the pairs in time proportional to the length of the lists times a log factor, plus the sorting time, and the problem is solved.
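The pair counting for a single diagonal cell might look like the sketch below. The name `countPairs` is hypothetical; `std::upper_bound` returns an iterator to the first element greater than the given value, so its distance from the beginning is the number of elements less than or equal to that value.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: number of pairs (a, b), a from first and b from second,
// with a + b <= X. first is taken by value so it can be sorted locally.
long long countPairs(std::vector<long long> first,
                     const std::vector<long long>& second, long long X) {
    std::sort(first.begin(), first.end());
    long long pairs = 0;
    for (long long b : second) {
        // Elements a <= X - b are exactly those before upper_bound(X - b).
        pairs += std::upper_bound(first.begin(), first.end(), X - b) - first.begin();
    }
    return pairs;
}
```

For example, with first list {5, 1, 3}, second list {2, 4} and X = 6, the pairs are (1, 2), (3, 2) and (1, 4), so the function returns 3.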

Additionally, we do not need to store the second lists at all: we can compute and sort the first list for each position, then run the binary search on the fly for each path from (x, y) to (N, N) as it is generated.

A recent problem on Meet-in-the-Middle (From last Cook-off) is here.

TIME COMPLEXITY

Time complexity Analysis:
Generating the lists takes a total of 2^{N-1} iterations, since each mask of N-1 bits corresponds to a distinct path from (1, 1) to a cell (x, y) on the diagonal x+y = N+1.

By a counting argument, the list at position (x+1, N-x) has size \binom{N-1}{x}, and sorting adds a log factor. The list sizes sum to 2^{N-1}, so the sorting time for all lists combined is bounded by 2^{N-1} \cdot \log_2(2^{N-1}) = (N-1) \cdot 2^{N-1}, which fits within the time limit.
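Spelled out, the bound on the total sorting work follows from the binomial identity and from the fact that no single list is longer than 2^{N-1}:

```latex
\sum_{x=0}^{N-1} \binom{N-1}{x} = 2^{N-1},
\qquad
\sum_{x=0}^{N-1} \binom{N-1}{x} \log_2 \binom{N-1}{x}
  \le \log_2\!\left(2^{N-1}\right) \sum_{x=0}^{N-1} \binom{N-1}{x}
  = (N-1) \cdot 2^{N-1}.
```

For N = 20 the right-hand side is 19 \cdot 2^{19} = 9,961,472 elementary steps.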

After that, the number of paths from (N, N) to the cells on the diagonal is also 2^{N-1}, and for each path we run a binary search on a list of size at most 2^{N-1}, which gives time complexity (N-1) \cdot 2^{N-1} for this part too.

Thus, the overall time complexity comes out to be O((N-1)*2^{N-1}) per test case.
The memory complexity of this solution is also O(2^{N-1}). (The memory required to store lists).

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#define MAXN 25
using namespace std;
long long a[MAXN][MAXN];
vector<long long> v[MAXN];
int main()
{
	int t; scanf("%d",&t);
	while(t--)
	{
		int n; long long br,sol=0;
		scanf("%d%lld",&n,&br);
		for(int i=0;i<n;i++) for(int j=0;j<n;j++) scanf("%lld",&a[i][j]);
		for(int k=0;k<(1<<(n-1));k++) //simulating every way to reach the diagonal from the top-left corner
		{
			int msk=k;
			int x=0,y=0;
			long long res=0;
			for(int i=0;i<n-1;i++) //simulating the route according to the bitmask - every 1 is a step down (x++), and every 0 is a step to the right (y++)
			{
				res+=a[x][y];
				if(msk&1) x++;
				else y++;
				msk=msk>>1;
			}
			v[x].push_back(res); //maintain a vector of possible values in each field on the diagonal
		}
		for(int i=0;i<n;i++) sort(v[i].begin(),v[i].end()); // sorting the values in each field on the diagonal, preparing it for binary search
		for(int k=0;k<(1<<(n-1));k++) //simulating every way to reach the diagonal from the bottom-right corner
		{
			int msk=k;
			int x=n-1,y=n-1;
			long long res=0;
			for(int i=0;i<n-1;i++) //simulating the route according to the bitmask - every 1 is a step up (x--), and every 0 is a step to the left (y--)
			{
				res+=a[x][y];
				if(msk&1) x--;
				else y--;
				msk=msk>>1;
			}
			sol+=upper_bound(v[x].begin(),v[x].end(),br-res-a[x][y])-v[x].begin(); //binary searching on the amount of ways that this path can be completed to a valid path with some path from the first half
		}
		printf("%lld\n",sol);
		for(int i=0;i<n;i++) v[i].clear();
	}
	return 0;
}
Tester's Solution
#include<bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int, int> pii;

#define F first
#define S second

const int MAXN = 20;

int n;
ll x, a[MAXN][MAXN];
vector<ll> vec[MAXN];

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	int te; cin >> te;
	while (te--){
		cin >> n >> x;
		for (int i = 0; i < n; i++) vec[i].clear();

		for (int i = 0; i < n; i++)
			for (int j = 0; j < n; j++)
				cin >> a[i][j];
		for (int mask = 0; mask < 1<<(n-1); mask++){
			int r = 0, c = 0;
			ll cur = a[r][c];
			for (int i = 0; i < n-1; i++){
				if (mask>>i&1)
					r++;
				else
					c++;
				cur += a[r][c];
			}
			vec[r].push_back(cur);
		}
		for (int r = 0; r < n; r++)
			sort(vec[r].begin(), vec[r].end());

		ll ans = 0;
		for (int mask = 0; mask < 1<<(n-1); mask++){
			int r = n-1, c = n-1;
			ll cur = (r+c == n-1? 0: a[r][c]);
			for (int i = 0; i < n-1; i++){
				if (mask>>i&1)
					r--;
				else
					c--;
				if (r+c != n-1)
					cur += a[r][c];
			}
			ans += upper_bound(vec[r].begin(), vec[r].end(), x - cur) - vec[r].begin();
		}
		cout << ans << "\n";
	}
	return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class GIIKLUB{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	    void solve(int TC) throws Exception{
	    int n = ni();
	    long M = nl();
	    long[][] a = new long[n][n];
	    for(int i = 0; i< n; i++)for(int j = 0; j< n; j++)a[i][j] = nl();
	    ArrayList<Long>[] forward = new ArrayList[n];
	    for(int i = 0; i< n; i++)forward[i] = new ArrayList<>();
	    for(int mask = (1<<(n-1))-1; mask>= 0; mask--){
	        int x = 0, y = 0;
	        long sum = 0;
	        for(int i = 0; i< n-1; i++){
	            sum += a[x][y];
	            if(((mask>>i)&1)==1)y++;
	            else x++;
	        }
	        forward[x].add(sum);
	    }
	    for(int i = 0; i< n; i++)Collections.sort(forward[i]);
	    long ans = 0;
	    for(int mask = (1<<(n-1))-1; mask>=0; mask--){
	        int x = n-1, y = n-1;
	        long sum = 0;
	        for(int i = 0; i< n-1; i++){
	            sum+=a[x][y];
	            if(((mask>>i)&1)==1)y--;
	            else x--;
	        }
	        sum += a[x][y];
	        if(sum <= M)ans+=smaller(forward[x], M-sum);
	    }
	    pn(ans);
	}
	//Returns number of elements <= x, list is sorted
	long smaller(ArrayList<Long> list, long x){
	    if(list.get(list.size()-1) <= x)return list.size();
	    int lo = 0, hi = list.size()-1;
	    while(lo+1 < hi){
	        int mid = lo+(hi-lo)/2;
	        if(list.get(mid) <= x)lo = mid;
	        else hi = mid;
	    }
	    for(int i = hi; i>= lo; i--)if(list.get(i) <= x)return i+1;
	    return 0;
	}
	class MyTreeSet<T>{
	    private int size;
	    private TreeMap<T, Integer> map;
	    public MyTreeSet(){
	        size = 0;
	        map = new TreeMap<>();
	    }
	    public int size(){return size;}
	    public int dsize(){return map.size();}
	    public boolean isEmpty(){return size==0;}
	    public void add(T t){
	        size++;
	        map.put(t, map.getOrDefault(t, 0)+1);
	    }
	    public boolean remove(T t){
	        if(!map.containsKey(t))return false;
	        size--;
	        int c = map.get(t);
	        if(c==1)map.remove(t);
	        else map.put(t, c-1);
	        return true;
	    }
	    public int freq(T t){return map.getOrDefault(t, 0);}
	    public boolean contains(T t){return map.getOrDefault(t,0)>0;}
	    public T ceiling(T t){return map.ceilingKey(t);}
	    public T floor(T t){return map.floorKey(t);}
	    public T first(){return map.firstKey();}
	    public T last(){return map.lastKey();}
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new GIIKLUB().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach, if you want to (even if it's the same :stuck_out_tongue: ). Suggestions are welcome, as always. :slight_smile:

Nice editorial :slight_smile:

What was special about the second subtask, k < 10^5? What trick was intended for someone who didn't know meet-in-the-middle (although I know it now)?

The second subtask was DP.
DP[i][j][k] is the number of paths that reach cell (i, j) with the sum of values on the path equal to k.

You can see my submission here:
https://www.codechef.com/viewsolution/24987934
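
A minimal sketch of such a DP, not taken from the linked submission: the name `countPathsDP`, the 0-indexed grid and the rolling one-row layout (used to avoid a full three-dimensional table) are assumptions of this example, and it requires X \geq 0 to be small enough to allocate N arrays of X+1 counters.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: counts paths with sum <= X by tracking, for every cell,
// how many paths reach it with each exact sum k in [0, X]. Assumes X >= 0.
long long countPathsDP(const std::vector<std::vector<long long>>& a, long long X) {
    int n = static_cast<int>(a.size());
    std::size_t width = static_cast<std::size_t>(X) + 1;
    // row[j][k]: number of paths to column j of the current row with sum k.
    std::vector<std::vector<long long>> row(n, std::vector<long long>(width, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            std::vector<long long> next(width, 0);
            long long v = a[i][j];
            if (i == 0 && j == 0) {
                if (v <= X) next[v] = 1;
            } else {
                for (long long k = 0; k + v <= X; ++k) {
                    long long from = 0;
                    if (i > 0) from += row[j][k];      // row[j] still holds cell (i-1, j)
                    if (j > 0) from += row[j - 1][k];  // row[j-1] already holds cell (i, j-1)
                    next[k + v] = from;
                }
            }
            row[j] = std::move(next);
        }
    }
    long long total = 0;
    for (long long ways : row[n - 1]) total += ways;  // all sums k <= X at (N, N)
    return total;
}
```

The answer is the sum of the counts over all k at the bottom-right cell, and the running time is proportional to N^2 times X.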

Thanks.