PROBLEM LINK:
Setter: Rami
Tester: Roman Bilyi
Editorialist: Taranpreet Singh
DIFFICULTY:
Easy
PREREQUISITES:
Combinatorics, Maths, Observation.
PROBLEM:
Given three integers N, M and K, find the number of arrays A of length N, with every element in the range [0, M-1], such that the prefix sum array S of A contains at least K elements divisible by M.
EXPLANATION
Let us solve a simpler problem first.
Given N, M and K, find the number of arrays of length N having each value in range [0, M-1] such that exactly K values are 0.
Here, we want K values to be zero and the remaining N-K values to be non-zero. For each non-zero value, we have exactly M-1 choices for that number. So we can select the non-zero values in (M-1)^{N-K} ways.
But so far we have only fixed the non-zero values and their relative order; they can occupy any N-K of the N positions.
Assume N = 4, K = 2 and the chosen non-zero values are [1, 2], in this order. Then there are the following six ways to place them in an array of size N (all remaining positions are 0):
0 0 1 2
0 1 0 2
0 1 2 0
1 0 0 2
1 0 2 0
1 2 0 0
It is easy to see that this is equivalent to choosing the N-K positions of the non-zero values out of N positions and, for each such set of positions, choosing the N-K non-zero values in (M-1)^{N-K} ways, giving a total of ^{N}C_{N-K} \cdot (M-1)^{N-K} arrays with exactly K zeroes.
Coming back to the original problem, let us consider the prefix sum array modulo M, since we only care about the values of the prefix sums modulo M, and (a+b)\%M = (a\%M+b\%M)\%M.
Let’s call this modulo prefix sum array T. All values in this array lie in the range [0, M-1].
Let’s assume T_i = x where 0 \leq x < M. What values can T_{i+1} take? We have T_{i+1} = (x+y)\%M, where y = A_{i+1} can be any value from 0 to M-1. Different values of y give different values of (x+y)\%M, and y takes M distinct values, so T_{i+1} can take every value in [0, M-1], each for exactly one choice of A_{i+1}, irrespective of T_i.
Hence, we can consider all possible arrays T such that each value is within range [0, M-1] and find a unique array A whose modulo prefix sum array is T. So, there is one-to-one mapping between array A and array T. So, the number of valid ways to choose A array is same as the number of valid ways to choose T array.
Hence, the original problem turns into counting arrays T of length N whose values are chosen independently from [0, M-1] such that at least K values are 0 (since 0 is the only value in [0, M-1] divisible by M). We fix the number of zeroes to p for each K \leq p \leq N; by the simpler problem above there are ^{N}C_{p} \cdot (M-1)^{N-p} arrays T with exactly p zeroes, so the answer is \displaystyle\sum_{p=K}^{N} {}^{N}C_{p} \cdot (M-1)^{N-p}.
For computing binomial coefficients, it is best to precompute factorials and their modular inverses in advance, since ^nC_r = \displaystyle\frac{n!}{r!(n-r)!}. Details can be found here.
TIME COMPLEXITY
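Time complexity is O(N \log N) per test case when each power (M-1)^{N-p} is computed with fast exponentiation, or O(N) per test case if the powers are built up incrementally, after a one-time precomputation of factorials and inverse factorials.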
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
// Shorthand macros used throughout the setter's solution.
#define ll long long
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sc second
#define fr first
using namespace std;

// Prime modulus for all arithmetic, and the size of the factorial tables.
const int mod = 1000000007;
const int N = 100010;

// Parameters of the current test case, read in main().
int n, m, k;

// fac[i] = i! mod mod; inv[i] = modular inverse of fac[i]
// (i.e. an inverse factorial, despite the name).
long long fac[N];
long long inv[N];
// Returns x^p modulo mod by recursive square-and-multiply:
// x^p = (x^(p/2))^2, times one extra factor x when p is odd; pw(x, 0) == 1.
ll pw(ll x, ll p){
    if (p == 0)
        return 1;
    const ll half = pw(x, p / 2);
    const ll squared = half * half % mod;
    return (p % 2 == 0) ? squared : squared * x % mod;
}
// Binomial coefficient C(x, y) modulo mod, computed from the precomputed
// factorial table fac[] and inverse-factorial table inv[].
// Returns 0 when y lies outside [0, x] (no way to choose), which also avoids
// indexing inv[] at a negative position. x must be below N (the table size).
ll c(ll x, ll y){
    if (y < 0 || y > x)
        return 0;
    ll res = fac[x] * inv[y] % mod;
    return res * inv[x - y] % mod;
}
int main() {
fac[0] = inv[0] = 1;
for(int i=1 ;i <N ;i ++){
fac[i] = fac[i-1]*i;
fac[i] %= mod;
inv[i] = pw(fac[i],mod-2);
}
int t;
cin>>t;
while(t--){
scanf("%d%d%d",&n,&m,&k);
ll res =0;
for(int i=k ; i <= n ;i ++){
res += (c(n,i) * pw(m-1,n-i))%mod;
res %= mod;
}
printf("%lld\n",res);
}
return 0;
}
Tester's Solution
// Tester's template: headers, optimization pragmas, shorthand macros,
// typedefs, constants, and the factorial tables F / IF used by C() and main().
#include "bits/stdc++.h"
#pragma GCC optimize("Ofast")
// Guarded because these instruction-set option names only exist on x86 targets.
#if defined(__x86_64__) || defined(__i386__)
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
#endif
using namespace std;
#define FOR(i,a,b) for (int i = (a); i < (b); i++)
#define RFOR(i,b,a) for (int i = (b) - 1; i >= (a); i--)
#define ITER(it,a) for (__typeof(a.begin()) it = a.begin(); it != a.end(); it++)
#define FILL(a,value) memset(a, value, sizeof(a))
#define SZ(a) (int)a.size()
#define ALL(a) a.begin(), a.end()
#define PB push_back
#define MP make_pair
typedef long long Int;
typedef vector<int> VI;
typedef pair<int, int> PII;
const double PI = acos(-1.0);
const int INF = 1000 * 1000 * 1000;
const Int LINF = INF * (Int) INF;
const int MAX = 100007;
const int MOD = 1000000007;
const double Pi = acos(-1.0);
// F[i] = i! mod MOD and IF[i] = (i!)^{-1} mod MOD; both are filled in main().
Int F[MAX];
Int IF[MAX];
// Returns a^k modulo MOD by iterative binary exponentiation: the running
// square of the base is folded into the accumulator for every set bit of k.
Int bpow(Int a, Int k)
{
    Int acc = 1;
    for (Int base = a; k != 0; k /= 2) {
        if (k % 2 != 0)
            acc = acc * base % MOD;
        base = base * base % MOD;
    }
    return acc;
}
// Binomial coefficient C(n, k) modulo MOD from the factorial table F and the
// inverse-factorial table IF. Returns 0 when k lies outside [0, n] instead of
// reading IF at a negative index; n must be below MAX (the table size).
Int C(int n, int k)
{
    if (k < 0 || k > n)
        return 0;
    return F[n] * IF[k] % MOD * IF[n - k] % MOD;
}
int main(int argc, char* argv[])
{
// freopen(“in.txt”, “r”, stdin);
//ios::sync_with_stdio(false); cin.tie(0);
F[0] = 1;
FOR(i,1,MAX)
F[i] = F[i - 1] * i % MOD;
FOR(i,0,MAX)
{
IF[i] = bpow(F[i], MOD - 2);
}
int t;
cin >> t;
FOR(tt,0,t)
{
int n, m, k;
cin >> n >> m >> k;
Int res = 0;
FOR(i,k,n + 1)
{
res += C(n, i) * bpow(m - 1, n - i);
res %= MOD;
}
cout << res << endl;
}
cerr << 1.0 * clock() / CLOCKS_PER_SEC << endl;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
/**
 * Editorialist's solution. For each test case (n, m, k) it prints
 * sum_{i=k}^{n} C(n, i) * (m-1)^(n-i) modulo 1e9+7.
 */
class ANGGRA{
    //SOLUTION BEGIN
    long mod = (long)1e9+7;
    // fif[0][i] = i! mod mod, fif[1][i] = (i!)^{-1} mod mod; filled by pre().
    long[][] fif;
    /** Precomputes factorials and inverse factorials for indices up to 1e5+5. */
    void pre() throws Exception{fif = fif((int)1e5+5);}
    /** Reads one test case (n, m, k) and prints the answer modulo mod. */
    void solve(int TC) throws Exception{
        int n = ni();long m = nl();int k = ni();
        long ans = 0;
        // i = exact count of zero residues: C(n, i) placements times (m-1)^(n-i) non-zero choices.
        for(int i = k; i<= n; i++)ans = add(ans, mul(C(fif, n, i), pow(m-1, n-i)));
        pn(ans);
    }
    /** Binomial coefficient C(n, r) modulo mod from the tables in fif; 0 when n < 0 or r is outside [0, n]. */
    long C(long[][] fif, int n, int r){
        if(n< 0 || n<r || r<0)return 0;
        return (fif[0][n]*((fif[1][r]*fif[1][n-r])%mod))%mod;
    }
    /** Returns {F, IF} with F[i] = i! and IF[i] = (i!)^{-1} modulo mod, for 0 <= i <= mx. */
    long[][] fif(int mx){
        mx++;
        long[] F = new long[mx], IF = new long[mx];
        F[0] = 1;
        for(int i = 1; i< mx; i++)F[i] = (F[i-1]*i)%mod;
        //GFG
        // Inverse of the largest factorial via the iterative extended Euclidean algorithm.
        long M = mod;
        long y = 0, x = 1;
        long a = F[mx-1];
        while(a> 1){
            long q = a/M;
            long t = M;
            M = a%M;
            a = t;
            t = y;
            y = x-q*y;
            x = t;
        }
        if(x<0)x+=mod;
        IF[mx-1] = x;
        // 1/i! = (i+1) * 1/(i+1)!
        for(int i = mx-2; i>= 0; i--)IF[i] = (IF[i+1]*(i+1))%mod;
        return new long[][]{F, IF};
    }
    /** Multiplies modulo mod; only operands >= mod are reduced, so non-negative inputs are expected. */
    long mul(long a, long b){
        if(a>=mod)a%=mod;
        if(b>=mod)b%=mod;
        a*=b;
        if(a>=mod)a%=mod;
        return a;
    }
    /** Adds modulo mod, first normalizing each operand (including negative ones) into [0, mod). */
    long add(long a, long b){
        if(Math.abs(a)>=mod)a%=mod;
        if(a<0)a+=mod;
        if(Math.abs(b)>=mod)b%=mod;
        if(b<0)b+=mod;
        a+=b;
        if(Math.abs(a)>=mod)a%=mod;
        return a;
    }
    /** Computes a^p modulo mod for p >= 0; any long base is accepted and reduced into [0, mod) first. */
    long pow(long a, long p){
        // m is read as a long, so a = m-1 may exceed mod; without this reduction a*o and a*a can overflow.
        a %= mod;
        if(a<0)a+=mod;
        long o = 1;
        while(p>0){
            if(p%2==1)o = (a*o)%mod;
            a = (a*a)%mod;
            p>>=1;
        }
        return o;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    DecimalFormat df = new DecimalFormat("0.00000000000");
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new ANGGRA().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    /** Whitespace-token reader over a BufferedReader (stdin or a file). */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        /** Returns the next whitespace-separated token, reading further lines as needed. */
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                String line;
                try{
                    line = br.readLine();
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
                // readLine() returns null at end of input; fail clearly instead of with a NullPointerException.
                if(line == null)throw new Exception("Unexpected end of input");
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }
        /** Returns the next raw line, or null at end of input. */
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
Feel free to share your approach, even if it is the same. Suggestions are welcome, as always.