PROBLEM LINK:
Setter: Andrey Filimonov
Tester: Radoslav Dimitrov
Editorialist: Teja Vardhan Reddy
DIFFICULTY:
Medium - Hard
PREREQUISITES:
Lagrange Interpolation, combinatorics, counting
PROBLEM:
Consider all matrices with N rows (numbered 1 through N) and M columns (numbered 1 through M) containing only integers between 1 and K (inclusive). For each such matrix A, let’s form a sequence L_1, L_2, \ldots, L_{N+M}:
For each i (1 \le i \le N), L_i is the maximum of all elements in the i-th row of A. Let's call these N elements sequence B.
For each i (1 \le i \le M), L_{N+i} is the maximum of all elements in the i-th column of A. Let's call these M elements sequence C.
Find the number of different sequences L formed this way.
EXPLANATION
Let us first analyse what kind of sequences can be obtained.
Claim 1: The maximum of B and the maximum of C must be equal.
Proof: By contradiction.
Suppose B_{max} \gt C_{max}.
Let B_{max} occur at cell (i,j). Column j contains this cell, so C_j \geq B_{max}. Since C_{max} \geq C_j, we get C_{max} \geq B_{max}, which is a contradiction. The case C_{max} \gt B_{max} is symmetric.
So the maximum of B and the maximum of C are equal.
Claim 2: Any pair of arrays B and C with the same maximum can be obtained.
Proof: Let the common maximum of B and C be x. Say x occurs in B at position p and in C at position q. We give a construction that works whatever the remaining elements of B and C are, as long as they lie in [1,x].
A[p][i]=C_i for each i (1 \le i \le M)
A[i][q]=B_i for each i (1 \le i \le N)
Fill the rest of the cells with 1. (The two rules agree at cell (p,q), since C_q = B_p = x.)
Now let's see how the i-th row looks:
i \ne p: 1,1,\ldots,1 (q-1 times), B_i, 1,1,\ldots,1. Hence, the maximum is B_i.
i = p: C_1,C_2,\ldots,C_{q-1},C_q,C_{q+1},\ldots,C_M. Hence, the maximum is C_{max} = C_q = B_p.
Similarly, looking along the columns, we can see that array C is also satisfied.
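The construction from Claim 2 can be sketched in code. This is only an illustration: the function names `buildMatrix` and `realizes` are made up here, indices are 0-based, and the inputs are assumed to satisfy max(B) = max(C).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Builds an N x M matrix whose row maxima are B and whose column maxima are C.
// Assumption (from Claim 1): max(B) == max(C). Indices are 0-based here.
std::vector<std::vector<int>> buildMatrix(const std::vector<int>& B,
                                          const std::vector<int>& C) {
    const int N = static_cast<int>(B.size());
    const int M = static_cast<int>(C.size());
    const int p = static_cast<int>(std::max_element(B.begin(), B.end()) - B.begin());
    const int q = static_cast<int>(std::max_element(C.begin(), C.end()) - C.begin());
    assert(B[p] == C[q]);  // the two maxima must coincide

    std::vector<std::vector<int>> A(N, std::vector<int>(M, 1));  // fill with 1
    for (int j = 0; j < M; ++j) A[p][j] = C[j];  // row p carries C
    for (int i = 0; i < N; ++i) A[i][q] = B[i];  // column q carries B
    return A;
}

// Returns true if the row maxima of A equal B and the column maxima equal C.
bool realizes(const std::vector<std::vector<int>>& A,
              const std::vector<int>& B, const std::vector<int>& C) {
    for (std::size_t i = 0; i < B.size(); ++i)
        if (*std::max_element(A[i].begin(), A[i].end()) != B[i]) return false;
    for (std::size_t j = 0; j < C.size(); ++j) {
        int mx = 0;
        for (std::size_t i = 0; i < B.size(); ++i) mx = std::max(mx, A[i][j]);
        if (mx != C[j]) return false;
    }
    return true;
}
```

For example, with B = {2, 3, 1} and C = {1, 3, 3, 2}, row 2 of the built matrix (1-based) is C and column 2 is B, and every other cell is 1.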
Now we count the number of pairs of arrays B and C. (From here on we write n, m, k for N, M, K.) We will show how to count for a fixed maximum x, and later sum over x from 1 to k.
The number of arrays of length n with each element in [1,x] is x^n (since each element has x choices).
Now try to figure out why this is not the count we are looking for.
This count also includes the arrays in which all elements belong to [1,x-1], whose maximum is less than x.
So we need to subtract them: the number of B sequences with maximum exactly x is x^n - (x-1)^n. This is a polynomial in x of degree at most n.
Similarly, the number of C sequences is x^m-(x-1)^m. This is a polynomial in x of degree at most m.
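The subtraction argument can be checked on tiny inputs by plain enumeration. The sketch below uses hypothetical helper names (`ipow`, `countExactMaxBrute`, `countExactMaxFormula`) and no modulus, so it is only meant for small n and x.

```cpp
#include <cassert>

// Integer power for small inputs (no modulus; assumes the result fits in long long).
long long ipow(long long base, int exp) {
    long long result = 1;
    for (int i = 0; i < exp; ++i) result *= base;
    return result;
}

// Counts length-n arrays over [1, x] whose maximum is exactly x, by enumerating
// all x^n arrays as base-x numbers (digit d stands for the element d + 1).
long long countExactMaxBrute(int n, int x) {
    assert(n >= 1 && x >= 1);
    const long long total = ipow(x, n);
    long long count = 0;
    for (long long code = 0; code < total; ++code) {
        long long rest = code;
        bool hitsX = false;
        for (int pos = 0; pos < n; ++pos) {
            if (rest % x + 1 == x) hitsX = true;  // this element equals x
            rest /= x;
        }
        if (hitsX) ++count;
    }
    return count;
}

// Closed form from the text: x^n - (x-1)^n.
long long countExactMaxFormula(int n, int x) {
    return ipow(x, n) - ipow(x - 1, n);
}
```

For n = 3 and x = 2 both functions give 7: of the 8 arrays over {1, 2}, only (1, 1, 1) fails to reach the maximum 2.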
Now, we want \sum_{x=1}^{k} (x^n - (x-1)^n) \cdot (x^m - (x-1)^m). The summand is a polynomial in x of degree at most n+m.
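To gain some confidence in this sum, one can compare it with a brute force over all K^{NM} matrices for very small N, M, K. The following is a sketch under that assumption (tiny inputs, no modulus); the names `countSequencesBrute` and `countSequencesFormula` are invented for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Integer power for small inputs (no modulus).
long long ipow(long long base, int exp) {
    long long result = 1;
    for (int i = 0; i < exp; ++i) result *= base;
    return result;
}

// Enumerates every N x M matrix over [1, K] and counts distinct sequences L.
long long countSequencesBrute(int N, int M, int K) {
    assert(N >= 1 && M >= 1 && K >= 1);
    const int cells = N * M;
    std::vector<int> A(cells, 1);  // cell (i, j) is stored at A[i * M + j]
    std::set<std::vector<int>> seen;
    while (true) {
        std::vector<int> L(N + M, 0);
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < M; ++j) {
                L[i] = std::max(L[i], A[i * M + j]);          // row maximum
                L[N + j] = std::max(L[N + j], A[i * M + j]);  // column maximum
            }
        seen.insert(L);
        // Advance to the next matrix like an odometer.
        int pos = 0;
        while (pos < cells && A[pos] == K) A[pos++] = 1;
        if (pos == cells) break;
        ++A[pos];
    }
    return static_cast<long long>(seen.size());
}

// The sum from the text, evaluated directly.
long long countSequencesFormula(int N, int M, int K) {
    long long sum = 0;
    for (int x = 1; x <= K; ++x)
        sum += (ipow(x, N) - ipow(x - 1, N)) * (ipow(x, M) - ipow(x - 1, M));
    return sum;
}
```

For N = M = K = 2 the formula gives 1 + 3 \cdot 3 = 10, and the brute force over the 16 matrices agrees.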
We can prove that \sum_{i=1}^k i^n is a polynomial of degree n+1 in k. Proof attached here
Let S(k) = \sum_{x=1}^{k} (x^n - (x-1)^n) \cdot (x^m - (x-1)^m). Combining the two conclusions above, S(k) is a polynomial in k of degree at most n+m+1.
We can calculate S(1),S(2),\ldots,S(n+m+2) by iterating over k and computing S(k) from S(k-1).
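A minimal sketch of this prefix computation, assuming the modulus 10^9+7 used in the solutions below; `powMod` and `firstValues` are names chosen for this example only.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Fast exponentiation: base^exp modulo MOD.
long long powMod(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Returns S[0..count] modulo MOD, where S[0] = 0 and
// S[x] = S[x-1] + (x^n - (x-1)^n) * (x^m - (x-1)^m).
std::vector<long long> firstValues(int n, int m, int count) {
    assert(n >= 1 && m >= 1 && count >= 0);
    std::vector<long long> S(count + 1, 0);
    for (int x = 1; x <= count; ++x) {
        long long rows = (powMod(x, n) - powMod(x - 1, n) + MOD) % MOD;
        long long cols = (powMod(x, m) - powMod(x - 1, m) + MOD) % MOD;
        S[x] = (S[x - 1] + rows * cols) % MOD;
    }
    return S;
}
```

With n = m = 2, calling `firstValues(2, 2, 4)` yields S(1), \ldots, S(4) = 1, 10, 35, 84.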
If k \lt n+m+3, we can answer directly from these values. So from now on we assume k \gt n+m+2.
We now know n+m+2 distinct (point, value) pairs satisfied by S. Since S is a polynomial of degree at most n+m+1, these n+m+2 pairs determine S exactly, and we can recover it using Lagrange interpolation.
Let's say we have n+m+2 pairs (x_1,y_1),\ldots,(x_{n+m+2},y_{n+m+2}).
S(x) = \sum\limits_{i=1}^{n+m+2} y_i \cdot l_i(x).
l_i(x) = \prod\limits_{1 \le j \le n+m+2,\ j \ne i}\frac{x-x_j}{x_i-x_j}.
Now we need to substitute x=k, x_i = i and the respective values y_i = S(i) which we calculated above.
Let p = \prod_{i=1}^{n+m+2}(k-i).
Now l_i(k) = \frac{p/(k-i)}{\prod\limits_{1 \le j \le n+m+2,\ j \ne i}{(i-j)}}
Now let's simplify \prod\limits_{1 \le j \le n+m+2,\ j \ne i}{(i-j)}
= (i-1) \cdot (i-2) \cdots 2 \cdot 1 \cdot (-1) \cdot (-2) \cdots (i-(n+m+2))
= (i-1)! \cdot (n+m+2-i)! \cdot (-1)^{n+m+2-i}.
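As a quick sanity check of the factorial and sign formula, take a hypothetical small case with n+m+2 = 4 sample points and look at i = 2 and i = 3:

```latex
\begin{aligned}
i=2:&\quad \prod_{j \in \{1,3,4\}} (2-j) = (1)(-1)(-2) = 2 = 1! \cdot 2! \cdot (-1)^{4-2},\\
i=3:&\quad \prod_{j \in \{1,2,4\}} (3-j) = (2)(1)(-1) = -2 = 2! \cdot 1! \cdot (-1)^{4-3}.
\end{aligned}
```

In both cases the positive factors give (i-1)!, the negative factors give (4-i)! in absolute value, and the sign is decided by how many negative factors there are.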
So we precompute p and the factorial values; then we can find each l_i in O(log(MOD)) time, because we only need a few modular inverses.
And substituting l_i in S gives the required S(k).
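Putting the pieces together, the interpolation step might look like the sketch below. It is not taken from any of the solutions that follow: the names `powMod` and `interpolate` are made up, the sample points are assumed to be 1, 2, \ldots, d with values already reduced modulo 10^9+7, and k is assumed to satisfy 0 \le k \lt MOD so that no k-i is divisible by MOD.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Fast exponentiation: base^exp modulo MOD.
long long powMod(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Given y[i-1] = P(i) mod MOD for i = 1..d, where P has degree at most d-1,
// returns P(k) mod MOD. Assumes 0 <= k < MOD.
long long interpolate(const std::vector<long long>& y, long long k) {
    const int d = static_cast<int>(y.size());
    assert(d >= 1 && k >= 0 && k < MOD);
    if (k >= 1 && k <= d) return y[k - 1] % MOD;  // sample point: avoid k - i = 0

    std::vector<long long> fact(d, 1);  // fact[t] = t! mod MOD
    for (int t = 1; t < d; ++t) fact[t] = fact[t - 1] * t % MOD;

    long long p = 1;  // p = prod_{i=1..d} (k - i)
    for (int i = 1; i <= d; ++i) p = p * (((k - i) % MOD + MOD) % MOD) % MOD;

    long long result = 0;
    for (int i = 1; i <= d; ++i) {
        // Denominator (i-1)! * (d-i)! * (-1)^(d-i).
        long long denom = fact[i - 1] * fact[d - i] % MOD;
        if ((d - i) % 2 == 1) denom = (MOD - denom) % MOD;
        // Numerator p / (k - i), via Fermat inverse.
        long long num = p * powMod(((k - i) % MOD + MOD) % MOD, MOD - 2) % MOD;
        long long term = y[i - 1] % MOD * num % MOD * powMod(denom, MOD - 2) % MOD;
        result = (result + term) % MOD;
    }
    return result;
}
```

For n = m = 2 the six values S(1), \ldots, S(6) are 1, 10, 35, 84, 165, 286, and interpolating them at k = 7 gives 455, which matches adding (7^2-6^2)^2 = 169 to S(6).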
And that's a wrap!!
TIME COMPLEXITY
Finding the first n+m+2 values takes O((n+m) \cdot log(n+m)) time, because we need fast exponentiation for each term.
Finding p takes O(n+m) time.
Precomputing factorials takes O(n+m) time.
Computing each l_i takes O(log(MOD)) time, as mentioned above.
Finally, combining everything to get S(k) takes O(n+m) time on top of that.
So total complexity = O((n+m) \cdot log(MOD)).
SOLUTIONS:
Setter's Solution
#define y1 askjdkasldjlkasd
#include <bits/stdc++.h>
#undef y1
using namespace std;
#define pb push_back
#define mp make_pair
#define fi(a, b) for(int i=a; i<=b; i++)
#define fj(a, b) for(int j=a; j<=b; j++)
#define fo(a, b) for(int o=a; o<=b; o++)
#define fdi(a, b) for(int i=a; i>=b; i--)
#define fdj(a, b) for(int j=a; j>=b; j--)
#define fdo(a, b) for(int o=a; o>=b; o--)
#define sz(x) (int)x.size()
typedef long long ll;
typedef long double ld;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef vector<pii> vpii;
typedef vector<ll> vl;
typedef pair<ll, ll> pll;
typedef vector<pll> vpll;
typedef vector<ll> vll;
#ifdef LOCAL
#define err(...) fprintf(stderr, __VA_ARGS__)
#else
#define err(...) while (0)
#endif
double START_TIME;
void exit() {
#ifdef LOCAL
cerr << "TIME: " << setprecision(5) << fixed << (clock() - START_TIME) / CLOCKS_PER_SEC << endl;
#endif
exit(0);
}
template<typename A, typename B>
ostream& operator<<(ostream& os, pair<A, B> p) {
os << "(" << p.first << ", " << p.second << ")";
return os;
}
template<typename T>
ostream& operator<<(ostream& os, vector<T> v) {
fi(0, sz(v) - 1) {
os << v[i] << " ";
}
return os;
}
template<typename T>
ostream& operator<<(ostream& os, set<T> t) {
for (auto z : t) {
os << z << " ";
}
return os;
}
template<typename T1, typename T2>
ostream& operator<<(ostream& os, map<T1, T2> t) {
cerr << endl;
for (auto z : t) {
os << "\t" << z.first << " -> " << z.second << endl;
}
return os;
}
#ifdef LOCAL
#define dbg(x) {cerr << __LINE__ << "\t" << #x << ": " << x << endl;}
#define dbg0(x, n) {cerr << __LINE__ << "\t" << #x << ": "; for (int ABC = 0; ABC < n; ABC++) cerr << x[ABC] << ' '; cerr << endl;}
#else
#define dbg(x) while(0){}
#define dbg0(x, n) while(0){}
#endif
#ifdef LOCAL
#define ass(x) if (!(x)) { cerr << __LINE__ << "\tassertion failed: " << #x << endl, abort(); }
#else
#define ass(x) assert(x)
#endif
///////////////////////////////////////////////////
const int MOD = 1e9 + 7;
const int MAX = 2e5 + 41;
int add(int a, int b, int MOD) {
a += b;
if (a >= MOD) a -= MOD;
return a;
}
int sub(int a, int b, int MOD) {
a -= b;
if (a < 0) a += MOD;
return a;
}
int mul(int a, int b, int MOD) {
return (ll) a * b % MOD;
}
int bp(int x, int d, int MOD) {
int res = 1;
while (d) {
if (d & 1) res = mul(res, x, MOD);
x = mul(x, x, MOD);
d >>= 1;
}
return res;
}
int inv(int x, int MOD) {
return bp(x, MOD - 2, MOD);
}
namespace FFT {
const int MAGIC = 200;
const int ROOT_PW = (1 << 20);
int MODS[3] = {985661441, 976224257, 975175681};
int ROOTS[3] = {717, 315, 1335};
int IROOTS[3] = {92105044, 951431260, 590949158};
void fft (vi & a, int root, int iroot, int root_pw, int mod, bool invert) {
int n = (int) a.size();
for (int i=1, j=0; i<n; ++i) {
int bit = n >> 1;
for (; j>=bit; bit>>=1)
j -= bit;
j += bit;
if (i < j)
swap (a[i], a[j]);
}
for (int len=2; len<=n; len<<=1) {
int wlen = invert ? iroot : root;
for (int i=len; i<root_pw; i<<=1)
wlen = int (wlen * 1ll * wlen % mod);
for (int i=0; i<n; i+=len) {
int w = 1;
for (int j=0; j<len/2; ++j) {
int u = a[i+j], v = int (a[i+j+len/2] * 1ll * w % mod);
a[i+j] = u+v < mod ? u+v : u+v-mod;
a[i+j+len/2] = u-v >= 0 ? u-v : u-v+mod;
w = int (w * 1ll * wlen % mod);
}
}
}
if (invert) {
int nrev = inv(n, mod);
fi(0, n - 1) {
a[i] = int (a[i] * 1ll * nrev % mod);
}
}
}
void multiply (const vi & a, const vi & b, vi & res, int root, int iroot, int root_pw, int mod, bool square = false) {
vi fa (a.begin(), a.end()), fb (b.begin(), b.end());
size_t n = 1;
while (n < max (a.size(), b.size())) n <<= 1;
n <<= 1;
fa.resize (n), fb.resize (n);
fi(0, sz(fa) - 1) {
fa[i] %= mod;
}
fi(0, sz(fb) - 1) {
fb[i] %= mod;
}
fft (fa, root, iroot, root_pw, mod, false);
if (!square) {
fft (fb, root, iroot, root_pw, mod, false);
} else {
fb = fa;
}
fi(0, (int) n - 1) {
fa[i] = mul(fa[i], fb[i], mod);
}
fft (fa, root, iroot, root_pw, mod, true);
res.resize (n);
fi(0, (int) n - 1) {
res[i] = fa[i];
}
}
const int INVMOD0INMODS1 = inv(MODS[0], MODS[1]);
const int INV2 = inv(mul(MODS[0], MODS[1], MODS[2]), MODS[2]);
int crt(const vi &rems, int MOD) {
int k = (rems[1] - rems[0]) % MODS[1];
if (k < 0) k += MODS[1];
k = mul(k, INVMOD0INMODS1, MODS[1]);
ll x = (ll) k * MODS[0] + rems[0];
k = (int) ( (rems[2] - x) % MODS[2]);
if (k < 0) k += MODS[2];
k = mul(k, INV2, MODS[2]);
int res = 0;
res = add(res, (int) (x % MOD), MOD);
int tmp = mul(MODS[0], MODS[1], MOD);
tmp = mul(tmp, k, MOD);
res = add(res, tmp, MOD);
return res;
}
vi multiplybrute(const vi &a, const vi &b, int MOD) {
vi res(sz(a) + sz(b) - 1, 0);
fi(0, sz(a) - 1) {
if (a[i]) {
fj(0, sz(b) - 1) {
res[i + j] = add(res[i + j], mul(a[i], b[j], MOD), MOD);
}
}
}
return res;
}
vi multiply(const vi &a, const vi &b, int MOD, bool square = false) {
if ( (ll) sz(a) * sz(b) <= MAGIC) {
return multiplybrute(a, b, MOD);
}
vector<vi> resp;
fi(0, 2) {
vi c;
multiply(a, b, c, ROOTS[i], IROOTS[i], ROOT_PW, MODS[i], square);
resp.pb(c);
}
vi res;
fi(0, sz(resp[0]) - 1) {
vi rems;
fo(0, 2) {
rems.pb(resp[o][i]);
}
int g = crt(rems, MOD);
res.pb(g);
}
return res;
}
};
vi getinversepolynomial(vi vec) {//return inverse polynomial
vi v = vec;
vi res(1, inv(vec[0], MOD));
int n = sz(vec);
for (int i = 1; true; i++) {
int len = (1 << i);
vi a(len, 0);
fj(0, len / 2 - 1) {
a[j] = mul(res[j], 2, MOD);
}
vi res2 = FFT::multiply(res, res, MOD, true);
if (sz(res2) > len) res2.resize(len);
vi c;
c.insert(c.begin(), v.begin(), v.begin() + min(len, sz(v)));
vi b = FFT::multiply(res2, c, MOD);
res.resize(len);
fj(0, len - 1) {
res[j] = sub(a[j], b[j], MOD);
}
if (len > n) break;
}
return res;
}
int f[MAX];//factorials
int invf[MAX];//inverse factorials
int n, m, k;
int getc(int n, int k) {//return combinations
if (n == k) return 1;
if (k > n) return 0;
return mul(f[n], mul(invf[k], invf[n - k], MOD), MOD);
}
void init() {
f[0] = 1;
fi(1, MAX - 1) f[i] = mul(f[i - 1], i, MOD);
fi(0, MAX - 1) invf[i] = inv(f[i], MOD);
}
int b[MAX];//b[i] = i-th Bernoulli number of second kind
void findbernoulli() {//calculates Bernoulli numbers of second kind
vi ve;
fi(0, n + m + 1) {
ve.pb(invf[i + 1]);
}
ve = getinversepolynomial(ve);
ve.resize(n + m + 2);
fi(0, n + m + 1) {
b[i] = mul(ve[i], f[i], MOD);
}
b[1] = inv(2, MOD);
}
int degsum[MAX];//degsum[i] = sum_{j = 1}^{k} j ^ i
void findfauhalber() {//calculates Faulhaber's power sums, saves them in degsum
vi a = vi(n + m + 2, 0);
fi(0, n + m + 1) {
int v1 = ::b[i];
int v2 = invf[i];
int v3 = inv(bp(k, i, MOD), MOD);
int v = mul(v1, mul(v2, v3, MOD), MOD);
a[i] = v;
}
vi b = vi(n + m + 2, 0);
fi(0, n + m + 1) {
b[i] = invf[i];
}
vi c = FFT::multiply(a, b, MOD);
fi(0, n + m + 1) {
c[i] = mul(c[i], f[i], MOD);
c[i] = mul(c[i], bp(k, i, MOD), MOD);
c[i] = sub(c[i], ::b[i], MOD);
c[i] = mul(c[i], inv(i, MOD), MOD);
}
fi(0, n + m) {
degsum[i] = c[i + 1];
}
}
void precalc() {
// n = 100 * 1000;
// m = 100 * 1000;
findbernoulli();
findfauhalber();
}
void solve() {
int ans = 0;
{//sum_{i = 1}{k} i ^ (n + m) + (i - 1) ^ (n + m)
int ab = mul(2, degsum[n + m], MOD);
ab = sub(ab, bp(k, n + m, MOD), MOD);
ans = add(ans, ab, MOD);
}
{//sum_{i = 1}{k} (i - 1) ^ n * i ^ m
fj(0, m) {
int sign = (j % 2 == 1 ? MOD - 1 : 1);
int v1 = getc(m, j);
int v2 = degsum[n + m - j];
int v = mul(v1, mul(v2, sign, MOD), MOD);
ans = sub(ans, v, MOD);
}
}
{//sum_{i = 1}{k} (i - 1) ^ m * i ^ n
fj(0, n) {
int sign = (j % 2 == 1 ? MOD - 1 : 1);
int v1 = getc(n, j);
int v2 = degsum[m + n - j];
int v = mul(v1, mul(v2, sign, MOD), MOD);
ans = sub(ans, v, MOD);
}
}
printf("%d\n", ans);
}
int main() {
#ifdef LOCAL
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
START_TIME = (double)clock();
#endif
init();
int t;
scanf("%d", &t);
while (t--) {
scanf("%d %d %d", &n, &m, &k);
precalc();
solve();
}
exit();
return 0;
}
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
//#pragma GCC optimize ("O3")
//#pragma GCC target ("sse4")
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
using namespace std;
template<class T, class T2> inline int chkmax(T &x, const T2 &y) { return x < y ? x = y, 1 : 0; }
template<class T, class T2> inline int chkmin(T &x, const T2 &y) { return x > y ? x = y, 1 : 0; }
const int MAXN = (1 << 20);
const int mod = (int)1e9 + 7;
int pw(int x, int p)
{
int r = 1;
while(p)
{
if(p & 1) r = r * 1ll * x % mod;
x = x * 1ll * x % mod;
p >>= 1;
}
return r;
}
int inv(int x) { return pw(x, mod - 2); }
int n, m, x;
void read()
{
cin >> n >> m >> x;
}
int y[MAXN];
int fact[MAXN];
void solve()
{
y[0] = 0;
for(int i = 1; i <= n + m; i++)
y[i] = (y[i - 1] + (pw(i, n) - pw(i - 1, n) + mod) * 1ll * (pw(i, m) - pw(i - 1, m) + mod)) % mod;
if(x <= n + m)
{
cout << y[x] << endl;
return;
}
// F(k) = SUM y[i] * L(i, k)
// L(i, k) = PRODUCT (k - j) / (i - j)
int product_up = 1;
for(int i = 0; i <= n + m; i++)
product_up = product_up * 1ll * (x - i) % mod;
fact[0] = 1;
for(int i = 1; i <= n + m; i++)
fact[i] = fact[i - 1] * 1ll * i % mod;
int ans = 0;
for(int i = 0; i <= n + m; i++)
{
int v = product_up * 1ll * inv(x - i) % mod;
v = v * 1ll * inv(fact[i]) % mod;
v = v * 1ll * inv(fact[n + m - i]) % mod;
if((n + m - i) & 1)
v = (mod - v) % mod;
v = v * 1ll * y[i] % mod;
ans = (ans + v) % mod;
}
cout << ans << endl;
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int T;
cin >> T;
while(T--)
{
read();
solve();
}
return 0;
}
Editorialist's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill - cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x) cout<<fixed<<val; // prints x digits after decimal in val
using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout)
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// find_by_order() // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
#define int ll
int powi(int a,int b){
int ans=1;
while(b){
if(b%2){
ans*=a;
ans%=mod;
}
a*=a;
a%=mod;
b/=2;
}
return ans;
}
int getinv(int i){
return powi(i,mod-2);
}
int fact[412345],getsum[412345],pn[412345],pm[412345];
signed main(){ // explicit return type; "int" is redefined to ll by the macro above
std::ios::sync_with_stdio(false); cin.tie(NULL);
int t;
cin>>t;
int i;
fact[0]=1;
f(i,1,412345){
fact[i]=fact[i-1]*i;
fact[i]%=mod;
}
//cout<<"das"<<endl;
while(t--){
int n,m;
cin>>n>>m;
int k;
cin>>k;
rep(i,2*(n+m)+10){
pn[i]=powi(i,n);
pm[i]=powi(i,m);
//cout<<pn[i]<<endl;
}
int val;
getsum[0]=0;
f(i,1,2*(n+m)+10){
val=pn[i]-pn[i-1];
val*=pm[i]-pm[i-1];
val%=mod;
val+=mod;
val%=mod;
getsum[i]=getsum[i-1]+val;
getsum[i]%=mod;
}
if(k<2*(n+m)+10){
cout<<getsum[k]<<endl;
continue;
}
ll gg=1;
f(i,1,2*(n+m)+10){
gg*=(k-i);
gg%=mod;
}
ll wow=0;
f(i,1,2*(n+m)+10){
val=fact[i-1];
val*=fact[2*(n+m)+9-i];
val%=mod;
if((2*(n+m)+9-i)%2){
val*=-1;
val%=mod;
val+=mod;
val%=mod;
}
val=getinv(val);
val*=gg;
val%=mod;
val*=getinv(k-i);
val%=mod;
val*=getsum[i];
val%=mod;
wow+=val;
wow%=mod;
}
cout<<wow<<endl;
}
return 0;
}
Feel free to share your approach if it differs. Suggestions are always welcome.