COMPCOUNT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: gunpoint_88
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

Given N, M, and K, count the number of N\times M black-and-white grids such that exactly K cells are black, and all the black cells form a connected component (with travel allowed in all 8 directions).

EXPLANATION:

Notice the small constraint on K (K \le 6).
This means there aren’t many possible “shapes” the connected component can take — apart from the shape itself, the only thing that varies is where in the grid it appears.

Suppose we fix the “shape” of the black cells’ connected component.
How many positions of the grid can it occur at?

Answer

Let the ‘height’ of the shape (the number of rows it spans, i.e. one more than the distance between its topmost and bottom-most cells) be H, and its ‘width’ (the number of columns it spans, defined similarly) be W.
Then, if H\gt N or W\gt M the shape can’t exist in the grid at all; otherwise it can occur at
(N-H+1)\cdot (M-W+1) positions of the grid.

This can be visualized by first pushing the shape upwards and leftwards as much as possible. Then there are N-H+1 possible vertical positions (shift down by 0, 1, \ldots, N-H rows) and M-W+1 possible horizontal positions (shift right by 0, 1, \ldots, M-W columns). All of them are allowed, and the vertical shift can be chosen independently of the horizontal shift.


Note that we didn’t even really care about what the shape was to compute its contribution — we only cared about its height and width.
So, suppose we’re able to precompute f(K, H, W) — the number of distinct shapes of size K with height H and width W.
Then, the answer for a single testcase can be found easily in \mathcal{O}(K^2) since it just becomes

\sum_{H=1}^{\min(K, N)} \sum_{W=1}^{\min(K, M)} f(K, H, W) \cdot (N-H+1)\cdot (M-W+1)

All that remains is to actually precompute f(K, H, W).
That can be done by pure brute force!

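Observe that a connected shape with K cells must fit inside a K\times K box. Every row (and every column) between its extremes must contain at least one of its cells, so it spans at most K rows and at most K columns.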
So, one possible solution is to just brute force all possible K\times K grids with K black cells, and check if the resulting shape is connected or not; if it is, compute its height and width and add 1 to the appropriate f(K, H, W).
To avoid counting the same shape several times (once for each position it can occupy inside the box), only count placements where the shape touches both the upper border and the left border of the box.

Even for K = 6, there are \binom{36}{6} \approx 2\cdot 10^6 possible grids to check, so this is fast enough.

TIME COMPLEXITY:

\mathcal{O}(K^2) per testcase, after precomputing all possible shapes.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
// 64-bit integer alias used throughout the author's solution.
typedef long long ll;
// Modulus for the final answer: 10^9 + 7.
const ll mod = 1000000007LL;

// Debug hook: without -DANI, dbg(...) expands to the no-op expression 0.
// With -DANI, a local header is included instead (presumably defining dbg).
#ifndef ANI
#define dbg(...) 0
#else
#include "D:/DUSTBIN/local_inc.h"
#endif

/*
	Brute-forces every 8-connected shape of at most K cells.

	Returns a (K+1) x (K+1) x (K+1) table res where
		res[u][h][w] = number of distinct shapes of exactly u black cells that
		               are 8-connected and whose bounding box spans h rows
		               and w columns.

	A connected shape of u <= K cells spans at most u rows and u columns, so
	it always fits in a K x K box. Each shape is counted exactly once because
	only placements touching both the top row and the left column of the box
	are kept.

	Cost: every subset of at most K cells of the K x K box is enumerated
	(about 2.4 * 10^6 subsets for K = 6), each checked in O(K^2).
*/
vector<vector<vector<long long>>> precomp(long long K=6) {
	// Same alias as at file scope; repeated so this function is self-contained.
	using ll = long long;

	vector<vector<vector<ll>>> res(K+1,vector<vector<ll>>(K+1,vector<ll>(K+1,0)));
	// g[r][c] = 1 if cell (r, c) is black in the current subset.
	// vis holds the DFS marks. Only indices [0, K) are used.
	vector<vector<ll>> g(K+1,vector<ll>(K+1,0)),vis=g;

	// Counts the unvisited black cells 8-connected to (i, j), including (i, j).
	function<ll(ll,ll)> dfs=[&](ll i,ll j)->ll{
		ll cnt=1;
		vis[i][j]=1;
		for(ll ii=i-1;ii<=i+1;ii++) {
			for(ll jj=j-1;jj<=j+1;jj++) {
				if(ii<0||jj<0||ii>K-1||jj>K-1||vis[ii][jj]||!g[ii][jj]) continue;
				cnt+=dfs(ii,jj);
			}
		}
		return cnt;
	};

	// Decides cell (i, j) in row-major order. u = black cells chosen so far.
	function<void(ll,ll,ll)> rec=[&](ll u,ll i,ll j)->void{
		if(i==K) {
			// Every cell is decided: classify the subset.
			if(!u) return;
			ll mni=K+1,mxi=0,mnj=K+1,mxj=0,si=-1,sj=-1;
			for(ll x=0;x<K;x++) {
				for(ll y=0;y<K;y++) {
					vis[x][y]=0;
					if(g[x][y]) {
						si=x; sj=y;
						mni=min(mni,x);
						mxi=max(mxi,x);
						mnj=min(mnj,y);
						mxj=max(mxj,y);
					}
				}
			}
			// Not in the canonical top-left position: this shape is counted
			// through its canonical placement instead.
			if(mni||mnj) return;
			// Connected iff a DFS from one black cell reaches all u of them.
			if(dfs(si,sj)==u) {
				res[u][mxi+1][mxj+1]++;
			}
			return;
		}

		const ll ni=(j+1==K)?i+1:i;
		const ll nj=(j+1==K)?0:j+1;
		rec(u,ni,nj);      // (i, j) stays white
		if(u==K) return;   // no more black cells allowed
		g[i][j]=1;         // (i, j) becomes black
		rec(u+1,ni,nj);
		g[i][j]=0;
	};

	rec(0,0,0);
	return res;
}

// Shape table for every size 1..6: box[k][h][w] counts the connected k-cell
// shapes whose bounding box spans h rows and w columns.
auto box = precomp(6);

// Number of n x m grids whose k black cells form one 8-connected component,
// modulo mod. Each of the box[k][h][w] shapes with an h x w bounding box fits
// in (n-h+1)*(m-w+1) positions; bounding boxes larger than the grid are skipped.
ll solve(ll n,ll m,ll k) {
	const ll rows = std::min(n, k), cols = std::min(m, k);
	ll total = 0;
	for (ll h = 1; h <= rows; ++h) {
		for (ll w = 1; w <= cols; ++w) {
			const ll shapes = box[k][h][w];
			if (shapes == 0) continue;
			const ll placements = (n - h + 1) * (m - w + 1);
			total = (total + placements * shapes % mod) % mod;
		}
	}
	return total;
}

// Reads t test cases "n m k" and prints solve(n, m, k) for each.
// The asserts check the input constraints: t <= 1e4, 1 <= n, m <= 1e5,
// 1 <= k <= min(6, n*m), and the sum of n*m over all test cases <= 1e5.
int main() {

	ll t; cin>>t; assert(t<=1e4);
	ll nmsum=0;
	while(t--) {
		ll n,m,k; cin>>n>>m>>k;
		assert(n>=1&&n<=1e5);
		assert(m>=1&&m<=1e5);
		// k <= 6 also keeps box[k] (built by precomp(6)) in range.
		assert(k>=1&&k<=min(6ll,n*m));
		nmsum+=n*m;
		cout<<solve(n,m,k)<<"\n";
	}
	assert(nmsum<=1e5);
}
Tester's code (C++)
// library link: https://github.com/manan-grover/My-CP-Library/blob/main/library.cpp
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Shorthand macros used by the tester's solution.
// Loops: asc(i,a,n) runs i = a .. n-1 upward; dsc(i,a,n) runs i = n-1 .. a downward.
#define asc(i,a,n) for(I i=a;i<n;i++)
#define dsc(i,a,n) for(I i=n-1;i>=a;i--)
// Iterator loops over a container, forward and reverse.
#define forw(it,x) for(A it=(x).begin();it!=(x).end();it++)
#define bacw(it,x) for(A it=(x).rbegin();it!=(x).rend();it++)
// Member-function / field abbreviations.
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define lb(x) lower_bound(x)
#define ub(x) upper_bound(x)
#define fbo(x) find_by_order(x)
#define ook(x) order_of_key(x)
#define all(x) (x).begin(),(x).end()
#define sz(x) (I)((x).size())
#define clr(x) (x).clear()
// Type abbreviations (I = long long int, D = long double, A = auto, B = bool, ...).
#define U unsigned
#define I long long int
#define S string
#define C char
#define D long double
#define A auto
#define B bool
#define CM(x) complex<x>
#define V(x) vector<x>
#define P(x,y) pair<x,y>
#define OS(x) set<x>
#define US(x) unordered_set<x>
#define OMS(x) multiset<x>
#define UMS(x) unordered_multiset<x>
#define OM(x,y) map<x,y>
#define UM(x,y) unordered_map<x,y>
#define OMM(x,y) multimap<x,y>
#define UMM(x,y) unordered_multimap<x,y>
#define BS(x) bitset<x>
#define L(x) list<x>
#define Q(x) queue<x>
// Policy-based ordered set/map.
// NOTE(review): both use less<I> as the comparator whatever the key type x is;
// neither is used in this file.
#define PBS(x) tree<x,null_type,less<I>,rb_tree_tag,tree_order_statistics_node_update>
#define PBM(x,y) tree<x,y,less<I>,rb_tree_tag,tree_order_statistics_node_update>
// pi as long double; md is the answer modulus 10^9 + 7.
#define pi (D)acos(-1)
#define md 1000000007
// rnd expands to randGen(rng); both names must be in scope wherever rnd is used.
#define rnd randGen(rng)
// Brute-force state shared by cnt() and cal(), all zero-initialised:
//   a[r][c]     - 1 while cell (r, c) of the 6 x 6 box is black
//   vis[r][c]   - flood-fill marks used by cnt()
//   dp[k][h][w] - number of connected k-cell shapes whose bounding box
//                 spans h rows and w columns
I a[6][6]{};
I vis[6][6]{};
I dp[7][7][7]{};
// B check(){
//     asc(i,0,6){
//         asc(j,0,6){
//             if(a[i][j]){
//                 B f=true;
//                 if(i>0){
//                     if(a[i-1][j]){
//                         f=false;
//                     }
//                     if(j>0){
//                         if(a[i-1][j-1]){
//                             f=false;
//                         }
//                     }
//                     if(j<5){
//                         if(a[i-1][j+1]){
//                             f=false;
//                         }
//                     }
//                 }
//                 if(i<5){
//                     if(a[i+1][j]){
//                         f=false;
//                     }
//                     if(j>0){
//                         if(a[i+1][j-1]){
//                             f=false;
//                         }
//                     }
//                     if(j<5){
//                         if(a[i+1][j+1]){
//                             f=false;
//                         }
//                     }
//                 }
//                 if(j>0){
//                     if(a[i][j-1]){
//                         f=false;
//                     }
//                 }
//                 if(j<5){
//                     if(a[i][j+1]){
//                         f=false;
//                     }
//                 }
//                 if(f){
//                     return false;
//                 }
//             }
//         }
//     }
//     return true;
// }
// Flood fill over the 8-connected black cells of a[][] (6 x 6) from (i, j).
// Marks every newly reached cell in vis[][] and adds the number of newly
// reached cells (including (i, j) itself, if it was unvisited) to res.
void cnt(I i,I j,I &res){
    if(vis[i][j]) return;
    vis[i][j]=1;
    ++res;
    for(I di=-1;di<=1;++di){
        const I ni=i+di;
        if(ni<0||ni>5) continue;
        for(I dj=-1;dj<=1;++dj){
            const I nj=j+dj;
            if(nj<0||nj>5) continue;
            // (ni, nj) == (i, j) is harmless: it is already marked visited.
            if(a[ni][nj]) cnt(ni,nj,res);
        }
    }
}
// Enumerates every non-empty set of at most 6 cells of the 6 x 6 box and
// records the 8-connected ones in dp.
//   x      - number of cells already black before cell (y, z) is decided
//   (y, z) - next cell to decide, in row-major order
// Each set is examined exactly once: right after its last (row-major) cell is
// made black. It is counted in dp[size][rows][cols] only when it is connected
// and touches row 0 and column 0, so each shape is counted once.
void cal(I x,I y,I z){
    // Successor of (y, z) in row-major order; (5, 5) has none.
    const bool hasNext=(z<5)||(y<5);
    const I ny=(z<5)?y:y+1;
    const I nz=(z<5)?z+1:0;

    // Branch 1: (y, z) stays white.
    if(hasNext){
        cal(x,ny,nz);
    }

    // Branch 2: (y, z) becomes black; check the set whose last cell it is.
    a[y][z]=1;
    asc(i,0,6){
        asc(j,0,6){
            vis[i][j]=0;
        }
    }
    I res=0;
    cnt(y,z,res);
    // Connected iff the flood fill from (y, z) reaches all x+1 black cells.
    if(res==x+1){
        // Bounding box: rows [l, r], columns [b, u].
        I l=6,r=-1,b=6,u=-1;
        asc(i,0,6){
            asc(j,0,6){
                if(a[i][j]){
                    l=min(l,i);
                    r=max(r,i);
                    b=min(b,j);
                    u=max(u,j);
                }
            }
        }
        // x <= 5 always (x only grows while x < 5), so dp[x+1] is in range.
        if(l==0 && b==0){
            dp[x+1][r+1][u+1]++;
        }
    }
    // Add further cells only while fewer than 6 are black.
    if(x<5 && hasNext){
        cal(x+1,ny,nz);
    }
    a[y][z]=0;
}
// Tester's driver: tabulates dp once, then answers each "n m k" query by
// summing shape counts times the placements of their bounding boxes.
int main(){
  ios_base::sync_with_stdio(false);cin.tie(NULL);
  #ifndef ONLINE_JUDGE
  freopen("input.txt", "r", stdin);
  freopen("output.txt", "w", stdout);
  #endif
  // Fill dp[k][h][w] for every k <= 6.
  cal(0,0,0);
  I t;
  cin>>t;
  while(t--){
      I n,m,k;
      cin>>n>>m>>k;
      I ans=0;
      // Bounding boxes h x w with h <= min(n, 6), w <= min(m, 6); each such box
      // fits in (n-h+1)*(m-w+1) positions of the grid.
      asc(i,1,min(n+1,7ll)){
          asc(j,1,min(m+1,7ll)){
              ans+=(n-i+1)*(m-j+1)*dp[k][i][j];
              ans%=md;
          }
      }
      cout<<ans<<"\n";
  }
  return 0;
}
Editorialist's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
// 64-bit signed integer alias, used for overflow-prone products.
typedef long long int ll;
// Time-seeded 64-bit Mersenne Twister; the editorialist's code never draws from it.
mt19937_64 rng(static_cast<unsigned long long>(chrono::high_resolution_clock::now().time_since_epoch().count()));

// Disjoint-set union over elements 0..n-1, with union by size and path compression.
// parent_or_size[v] < 0 means v is a root whose set has -parent_or_size[v]
// elements; otherwise it stores v's parent.
struct DSU {
private:
	std::vector<int> parent_or_size;
public:
	// Creates n singleton sets.
	DSU(int n = 1): parent_or_size(n, -1) {}

	// Representative of u's set; every node on the path is re-pointed at it.
	int get_root(int u) {
		int root = u;
		while (parent_or_size[root] >= 0) root = parent_or_size[root];
		while (u != root) {
			const int next = parent_or_size[u];
			parent_or_size[u] = root;
			u = next;
		}
		return root;
	}

	// Number of elements in u's set.
	int size(int u) { return -parent_or_size[get_root(u)]; }

	// True when u and v are in the same set.
	bool same_set(int u, int v) { return get_root(u) == get_root(v); }

	// Unites the sets of u and v; returns false if they were already united.
	// The root of the larger set (ties: u's root) becomes the new root.
	bool merge(int u, int v) {
		int a = get_root(u), b = get_root(v);
		if (a == b) return false;
		if (-parent_or_size[a] < -parent_or_size[b]) std::swap(a, b);
		parent_or_size[a] += parent_or_size[b];
		parent_or_size[b] = a;
		return true;
	}

	// All sets as member lists, ordered by root index, members ascending.
	std::vector<std::vector<int>> group_up() {
		const int n = static_cast<int>(parent_or_size.size());
		std::vector<std::vector<int>> buckets(n);
		for (int i = 0; i < n; ++i) buckets[get_root(i)].push_back(i);
		std::vector<std::vector<int>> groups;
		for (auto &bucket : buckets)
			if (!bucket.empty()) groups.push_back(std::move(bucket));
		return groups;
	}
};

// Editorialist's driver: brute-forces the shape table once, then answers
// each "n m k" query with sum over (h, w) of ct[k][h][w] * (n-h+1) * (m-w+1).
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    // ct[k][h][w] = number of 8-connected k-cell shapes whose bounding box
    // spans h rows and w columns. Each shape is counted once, via its placement
    // touching the top row and the left column of a k x k box.
    vector ct(7, vector(7, vector(7, 0)));
    // Right, down, down-right and up-right: half of the 8 neighbour offsets.
    // Scanning every black cell with these visits each adjacent pair once.
    vector<array<int, 2>> dir = {{0, 1}, {1, 0}, {1, 1}, {-1, 1}};
    for (int sz = 1; sz <= 6; ++sz) {
        // s is a row-major sz x sz grid; next_permutation walks every
        // placement of sz black cells ('1') exactly once.
        string s(sz*sz-sz, '0');
        s += string(sz, '1');
        do {
            bool up = false, left = false;
            for (int i = 0; i < sz; ++i) up |= s[i] == '1';
            for (int i = 0; i < sz; ++i) left |= s[sz*i] == '1';
            // Keep only the canonical placement. In a do-while, `continue`
            // still evaluates next_permutation.
            if (!up or !left) continue;
            DSU D(sz*sz);
            int comps = sz;
            int right = 0, down = 0;
            for (int i = 0; i < sz*sz; ++i) if (s[i] == '1') {
                right = max(right, i%sz);
                down = max(down, i/sz);
                for (auto [dx, dy] : dir) {
                    int x = i/sz + dx, y = i%sz + dy;
                    if (min(x, y) < 0) continue;
                    if (max(x, y) >= sz) continue;
                    int pos = sz*x + y;
                    if (s[pos] == '1') comps -= D.merge(i, pos);
                }
            }
            if (comps > 1) continue;
            // Height (rows) first, then width (columns), matching the query
            // below, where h is paired with n.
            ++ct[sz][down+1][right+1];
        } while (next_permutation(begin(s), end(s)));
    }

    const int mod = 1e9 + 7;
    int t; cin >> t;
    while (t--) {
        int n, m, k; cin >> n >> m >> k;
        int ans = 0;
        // A shape of height h and width w fits in (n-h+1)*(m-w+1) positions.
        for (int h = 1; h <= min(n, k); ++h) for (int w = 1; w <= min(m, k); ++w) {
            ll add = 1LL*(n-h+1)*(m-w+1)%mod;
            add = add*ct[k][h][w];
            ans = (ans + add) % mod;
        }
        cout << ans << '\n';
    }
}
1 Like