COUNTSEQ2 - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Anmol Choudhary
Preparer: Ashley Khoo
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

3979

PREREQUISITES:

None

PROBLEM:

You are given 2 integers N and M.

For each x such that 0 \leq x < M, count the number of non-negative integer sequences A_1,A_2,\ldots,A_M such that:

  • \sum\limits_{i=1}^M A_i = N
  • \prod\limits_{i=1}^M i^{A_i} \equiv x \pmod{M}

Since the answer can be large, output the answer modulo 10^9+7.

QUICK EXPLANATION:

Note that x^0 \bmod M, x^1 \bmod M, \dots form a rho shape: the cycle length divides \phi(M), while the “handle” (the part before the cycle starts) has length at most 5 (from 32 = 2^5). This means that we can restrict each A_i to be at most \phi(M) + 5, and at the end we can distribute the remaining sum among the elements as whole cycles of length \phi(M).

Therefore, we have a straightforward DP solution: dp[i][sum][product][num_of_elements_in_cycles], which runs in O(M^5). With this, calculating answers for specific N is just straightforward combinatorics.

Note that this DP table is the same for all queries with the same M, regardless of N, so it must be computed once per distinct M rather than once per test case in order to get AC.

TIME COMPLEXITY:

Time complexity is O(M^6).

SOLUTION:

Preparer's Solution
/// Super Idol的笑容
//    都没你的甜
//  八月正午的阳光
//    都没你耀眼
//  热爱105°C的你
// 滴滴清纯的蒸馏水

#include <bits/stdc++.h>
using namespace std;

// Shorthand type names used throughout this solution.
#define ll long long
#define ii pair<ll,ll>
#define iii pair<ii,ll>
// Shorthands for the two members of a pair.
#define fi first
#define se second
// Every `endl` token is replaced by a newline character, so writing `endl` does not flush.
#define endl '\n'
// debug(x): prints the text of the expression x, then ": ", its value and a newline to cout.
#define debug(x) cout << #x << ": " << x << endl

// Abbreviations for container member functions and the binary-search algorithms.
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define lb lower_bound
#define ub upper_bound

// rep(x,start,end): loop with an int variable x over the range between start and end.
//   start <  end : x = start, start+1, ..., end-1   (ascending)
//   start >  end : x = start-1, start-2, ..., end   (descending)
//   start == end : the body is never executed
// start and end are re-evaluated on every iteration because they are pasted into
// the loop condition and the step expression.
#define rep(x,start,end) for(int x=(start)-((start)>(end));x!=(end)-((start)>(end));((start)<(end)?x++:x--))
// all(x): the begin()/end() pair of container x, for passing to algorithms.
#define all(x) (x).begin(),(x).end()
// sz(x): size of container x converted to int.
#define sz(x) (int)(x).size()

// Random generator seeded from the system clock; nothing in the visible solution calls it.
mt19937 rng(chrono::system_clock::now().time_since_epoch().count());

// Modulus 10^9+7 under which every count is computed and printed.
const int MOD=1000000007;

/// Computes b^p modulo m by binary (square-and-multiply) exponentiation.
///
/// @param b base; any value is accepted and is first reduced into [0, m)
/// @param p exponent; values <= 0 are treated as 0
/// @param m modulus, must be positive
/// @return b^p mod m as a canonical residue in [0, m); this is 0 when m == 1
long long qexp(long long b,long long p,int m){
    // Reduce the base up front: without this, b*b overflows for |b| above ~3e9,
    // and a negative base would produce negative residues.
    b%=m;
    if (b<0) b+=m;

    // 1 % m (rather than a literal 1) makes the empty product correct for m == 1.
    long long res=1%m;
    while (p>0){
        if (p&1) res=res*b%m;
        b=b*b%m;
        p>>=1;
    }
    return res;
}

// Modular inverse of i modulo MOD, computed as i^(MOD-2) (Fermat's little theorem).
// Meaningful only when i is not a multiple of MOD.
ll inv(ll i){
	const ll fermatExponent=MOD-2;
	return qexp(i,fermatExponent,MOD);
}

// fac[x]  : x! modulo MOD
// ifac[x] : modular inverse of x! modulo MOD
// Both tables are filled at the start of main.
ll fac[1000005];
ll ifac[1000005];

// Binomial coefficient C(i, j) modulo MOD for a possibly huge i and a small j >= 0.
// Only ifac[j] is looked up; the numerator i!/(i-j)! is accumulated as the
// product of the j consecutive integers i-j+1, ..., i, in ascending order.
ll nCk(int i,int j){
	ll product=ifac[j];
	int factor=i-j+1;
	while (factor<=i){
		product=product*factor%MOD;
		++factor;
	}
	return product;
}

// N[t], M[t]: the two integers of test case t, in input order (read in main).
int N[41],M[41];
// DP table for the modulus currently being processed in main, indexed as
// memo[sum][product mod m][count]. The third index follows the editorial's
// "num_of_elements_in_cycles"; values are stored modulo MOD.
int memo[1605][41][41];
// ans[t][x]: answer of test case t for residue x; printed at the end of main.
int ans[41][41];

// Reads all test cases, groups them by modulus so that the DP table is built
// once per distinct M, answers every test case and prints the answers in
// input order.
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	// Make malformed or truncated input throw instead of being silently ignored.
	cin.exceptions(ios::badbit | ios::failbit);
	
	// Factorials modulo MOD, then inverse factorials: one modular inverse for the
	// last entry, followed by the downward recurrence ifac[x-1] = ifac[x] * x.
	fac[0]=1;
	rep(x,1,1000005) fac[x]=fac[x-1]*x%MOD;
	ifac[1000004]=inv(fac[1000004]);
	rep(x,1000005,1) ifac[x-1]=ifac[x]*x%MOD;
	
	int TC;
	cin>>TC;
	rep(t,0,TC) cin>>N[t]>>M[t];
	
	// Test-case indices sorted by modulus, so equal moduli are adjacent.
	// They are consumed from the back, i.e. largest modulus first.
	vector<int> idx;
	rep(x,0,TC) idx.pub(x);
	sort(all(idx),[](int i,int j){
		return M[i]<M[j];
	});
	
	while (!idx.empty()){
		int m=M[idx.back()];
		
		// tot = number of integers in [1, m] coprime to m (Euler's totient of m).
		int tot=0;
		rep(x,1,m+1) if (__gcd(x,m)==1) tot++;
		
		// Start from the empty selection: sum 0, product 1, count 0.
		memset(memo,0,sizeof(memo));
		memo[0][1][0]=1;
		
		// curr = number of leading rows (sums 0..curr-1) scanned by the loops below.
		int curr=1;
		rep(x,1,m+1){
			// Pre-period length bound from the editorial ("at most 5").
			int lead=5;
			
			// mul1 = x^tot mod m, mul2 = x^lead mod m.
			int mul1=1,mul2=1;
			rep(y,0,tot) mul1=mul1*x%m;
			rep(y,0,lead) mul2=mul2*x%m;
			
			// Rows are visited from high to low, so the rows written below
			// (i+lead and i+lead+tot) have already been visited in this pass.
			// States with product 0 (j == 0) are never read.
			// NOTE(review): together with the accumulation pass further down, this
			// appears to cap the number of copies of x below lead+tot and to
			// increment the third index whenever at least `lead` copies are taken
			// -- confirm against the editorial.
			rep(i,curr,0) rep(j,1,m) rep(k,m,0) if (memo[i][j][k]){
				int temp=memo[i][j][k];
				
				// This inner `curr` is a product, not the row count: it shadows the
				// outer variable only inside this if-body.
				int curr=j*mul2%m;
				memo[i+lead][curr][k]=(memo[i+lead][curr][k]-temp+MOD)%MOD;
				memo[i+lead][curr][k+1]=(memo[i+lead][curr][k+1]+temp)%MOD;
				curr=curr*mul1%m;
				memo[i+lead+tot][curr][k+1]=(memo[i+lead+tot][curr][k+1]-temp+MOD)%MOD;
			}
			
			// Accumulation pass in ascending row order: each state also feeds the
			// state with one more copy of x (sum + 1, product * x).
			curr+=tot+lead;
			rep(i,0,curr) rep(j,1,m) rep(k,0,m) if (memo[i][j][k]){
				memo[i+1][j*x%m][k]=(memo[i+1][j*x%m][k]+memo[i][j][k])%MOD;
			}
		}
		
		// Answer every test case that uses this modulus.
		while (!idx.empty() && M[idx.back()]==m){
			int n=N[idx.back()],u=idx.back(); idx.pob();
			
			// Residue 0 starts as the number of all sequences, C(n+m-1, m-1);
			// every non-zero residue is subtracted from it below.
			ans[u][0]=nCk(n+m-1,m-1);
			rep(x,1,m){
				// States with z >= 1 in the third index: the remainder n-y must be a
				// non-negative multiple of tot; (n-y)/tot units are distributed over
				// z slots in C((n-y)/tot+z-1, z-1) ways.
				rep(y,0,curr) if ((n-y)%tot==0 && n-y>=0) rep(z,1,m) if (memo[y][x][z]){
					ans[u][x]=(ans[u][x]+memo[y][x][z]*nCk((n-y)/tot+z-1,z-1))%MOD;
				}
				// States with third index 0 contribute only when their sum is exactly n.
				if (n<curr) ans[u][x]=(ans[u][x]+memo[n][x][0])%MOD;
				ans[u][0]=(ans[u][0]-ans[u][x]+MOD)%MOD;
			}
		}
	}
	
	// One line per test case, in input order: the answers for x = 0..M-1.
	rep(x,0,TC){
		// Only the first statement belongs to the inner rep; the newline is
		// written once per test case.
		rep(y,0,M[x]) cout<<ans[x][y]<<" "; cout<<endl;
	}
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
// 64-bit integer type used for all modular arithmetic.
typedef long long ll;
// Shorthands for the two members of a pair.
#define fi first
#define se second
// Modulus 10^9+7 under which every count is computed and printed.
const ll mod=1e9+7;
// Upper bound (40) used to size the arrays below and as the largest modulus main iterates over.
const int iu=40;
// Greatest common divisor by the Euclidean algorithm, written as a loop:
// (x, y) is replaced by (y, x mod y) until the second component is zero.
int gcd(int x,int y){
	while(y!=0){
		int remainder=x%y;
		x=y;
		y=remainder;
	}
	return x;
}
// x^y modulo mod for y >= 0, by left-to-right binary exponentiation:
// the bits of y are processed from the most significant one down, squaring
// the accumulator for every bit and multiplying by x when the bit is set.
ll pw(ll x,ll y){
	// Position of the highest set bit of y (-1 when y == 0).
	int top=-1;
	for(ll rest=y; rest>0; rest>>=1) top++;

	ll acc=1;
	for(int bit=top; bit>=0; bit--){
		acc=acc*acc%mod;
		if((y>>bit)&1) acc=x*acc%mod;
	}
	return acc;
}
// dp[i][j][k], filled in main for one modulus m at a time: number of ways to
// reach total i using the values 0..j (value j may still be extended) with
// product congruent to k modulo m.
ll dp[2001][iu+1][iu+1];
// inv[i]: modular inverse of i modulo mod for 1 <= i <= iu, filled in main.
ll inv[iu+1];
// Modular inverse of a signed integer k, looked up in the inv table:
// inv[k] for positive k, and the negated inverse (mod - inv[-k]) otherwise.
ll in(int k){
	return k>0 ? inv[k] : mod-inv[-k];
}
// qm[i]: modulus M of the i-th query (1-based); used when printing the answers.
int qm[iu+1];
// ans[i][x]: answer of the i-th query for residue x.
ll ans[iu+1][iu+1];
// Hand-written per-modulus table for m = 0..30; only referenced from a
// commented-out line in main.
int cmg[31]={0,0,1,1,2,1,1,1,3,2,1,1,2,1,1,1,4,1,2,1,2,1,1,1,3,2,1,3,2,1,1};
// qr[m]: the (N, query index) pairs of all queries whose modulus is m.
// It must have iu+1 entries because main indexes it with every m from 2 to iu;
// with the former size of 31, qr[31]..qr[iu] were out-of-bounds accesses.
vector<pair<int,int> >qr[iu+1];
// Reads all queries, groups them by modulus, builds the DP table once per
// distinct modulus, answers every query of that modulus by Lagrange
// interpolation, and finally prints the answers in input order.
int main(){
	ios::sync_with_stdio(false);cin.tie(0);
	int t;cin >> t;
	for(int i=1; i<=t ;i++){
		int n,m;cin >> n >> m;
		qm[i]=m;
		qr[m].push_back({n,i});
	}
	for(int i=1; i<=iu ;i++){
		inv[i]=pw(i,mod-2);
	}
	// M = 1 is not covered by the loop below (its totient-based arithmetic would
	// divide by zero). The only sequence is A_1 = N and every integer is
	// congruent to 0 modulo 1, so the single answer is 1.
	for(auto c:qr[1]){
		ans[c.se][0]=1;
	}
	for(int m=2; m<=iu ;m++){
		if(qr[m].empty()) continue;
		// phi = number of integers in [1, m) coprime to m.
		int phi=0;
		for(int i=1; i<m ;i++) phi+=gcd(i,m)==1;
		// Uniform pre-period bound of 5 for every modulus (the per-modulus
		// table cmg is not consulted).
		int mg=5;
		// Largest total the table is filled for; every sample row read below
		// is smaller than z.
		int z=(mg+m+1)*phi;
		for(int i=0; i<=z ;i++){
			for(int j=0; j<m ;j++){
				for(int k=0; k<m ;k++){
					dp[i][j][k]=0;
				}
			}
		}
		// Empty selection: total 0, currently at value 0, product 1.
		dp[0][0][1]=1;
		for(int i=0; i<=z ;i++){
			for(int j=0; j<m ;j++){
				for(int k=0; k<m ;k++){
					// Either stop using value j and move on to value j+1 ...
					if(j!=m-1) dp[i][j+1][k]=(dp[i][j+1][k]+dp[i][j][k])%mod;
					// ... or take one more copy of value j.
					dp[i+1][j][k*j%m]=(dp[i+1][j][k*j%m]+dp[i][j][k])%mod;
				}
			}
		}
		for(auto c:qr[m]){
			ll n=c.fi;int id=c.se;
			// Write n = zero + rn*phi, where zero = mg*phi + (n mod phi).
			int zero=mg*phi+n%phi;
			ll rn=(n-zero)/phi;
			// For each residue i, interpolate through the m sample points
			// (j, dp[zero+j*phi][m-1][i]), j = 0..m-1, and evaluate at rn.
			// This relies on the count being a polynomial of degree below m
			// in rn along this residue class of n modulo phi.
			for(int i=0; i<m ;i++){
				for(int j=0; j<m ;j++){
					int rz=zero+j*phi;
					ll cur=dp[rz][m-1][i];
					// Lagrange basis factor: product over k != j of (rn-k)/(j-k).
					for(int k=0; k<m ;k++){
						if(j==k) continue;
						cur=cur*in(j-k)%mod*(rn-k+mod)%mod;
					}
					ans[id][i]=(ans[id][i]+cur)%mod;
				}
			}
		}
	}
	// One line per query, in input order: the answers for x = 0..M-1.
	for(int i=1; i<=t ;i++){
		for(int j=0; j<qm[i] ;j++){
			cout << ans[i][j] << ' ';
		}
		cout << '\n';
	}
}