COUNTIT - EDITORIAL

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Andrey Filimonov

Tester: Radoslav Dimitrov

Editorialist: Teja Vardhan Reddy

DIFFICULTY:

Medium - Hard

PREREQUISITES:

Lagrange Interpolation, combinatorics, counting

PROBLEM:

Consider all matrices with N rows (numbered 1 through N) and M columns (numbered 1 through M) containing only integers between 1 and K (inclusive). For each such matrix A, let’s form a sequence L_1, L_2, \ldots, L_{N+M}:

For each i (1 \le i \le N), L_i is the maximum of all elements in the i-th row of A. Let's call these N elements sequence B.
For each i (1 \le i \le M), L_{N+i} is the maximum of all elements in the i-th column of A. Let's call these M elements sequence C.

Find the number of different sequences L formed this way.

EXPLANATION

Let us analyse what kind of sequences can be obtained.
Claim 1: The maximum of B and the maximum of C must be equal.
Proof: By contradiction.
Suppose B_{max} \gt C_{max}.
Let B_{max} occur at cell (i,j) of A. Column j contains this element, so C_j \ge B_{max}. Since C_{max} \ge C_j, we get C_{max} \ge B_{max}, which is a contradiction. The case C_{max} \gt B_{max} is symmetric.

So the maxima of B and C are equal.

Claim 2: Any pair of arrays B and C with the same maximum can be obtained.
Proof: Let the common maximum be x, and suppose x occurs in B at position p and in C at position q. We give a construction that works whatever the remaining elements of B and C are (each can be anything from 1 to x).

A[p][i]=C_i for each i (1 \le i \le M)
A[i][q]=B_i for each i (1 \le i \le N)
Fill the rest of the cells with 1. (The two rules agree at cell (p,q) because B_p = C_q = x.)

Now let us see what the i-th row looks like.

i \ne p: 1,1,\ldots,1 (q-1 times), B_i, 1,1,\ldots,1. Hence the maximum is B_i.

i = p: C_1,C_2,\ldots,C_{q-1},C_q,C_{q+1},\ldots,C_M. Hence the maximum is C_{max} = C_q = B_p.

Similarly, looking along the columns shows that array C is also satisfied.
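The construction can also be checked mechanically. The sketch below is only an illustration: the helper names buildMatrix and realizes are made up here, and indices are 0-based, unlike the 1-based indices used in the proof.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Builds an N x M matrix whose row maxima are B and column maxima are C,
// following the construction in the proof of Claim 2.
// Assumes all values are >= 1 and max(B) == max(C).
std::vector<std::vector<int>> buildMatrix(const std::vector<int>& B,
                                          const std::vector<int>& C) {
    const int n = (int)B.size(), m = (int)C.size();
    // p and q are positions of the common maximum x in B and C.
    const int p = (int)(std::max_element(B.begin(), B.end()) - B.begin());
    const int q = (int)(std::max_element(C.begin(), C.end()) - C.begin());
    assert(B[p] == C[q]);
    std::vector<std::vector<int>> A(n, std::vector<int>(m, 1));  // fill with 1
    for (int j = 0; j < m; ++j) A[p][j] = C[j];  // row p carries C
    for (int i = 0; i < n; ++i) A[i][q] = B[i];  // column q carries B
    return A;
}

// Returns true if the row maxima of A equal B and the column maxima equal C.
bool realizes(const std::vector<std::vector<int>>& A,
              const std::vector<int>& B, const std::vector<int>& C) {
    for (size_t i = 0; i < B.size(); ++i)
        if (*std::max_element(A[i].begin(), A[i].end()) != B[i]) return false;
    for (size_t j = 0; j < C.size(); ++j) {
        int best = 0;
        for (size_t i = 0; i < B.size(); ++i) best = std::max(best, A[i][j]);
        if (best != C[j]) return false;
    }
    return true;
}
```

For example, with B = {2, 3, 1} and C = {3, 1, 2, 3}, calling realizes(buildMatrix(B, C), B, C) should return true.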

Now we count the pairs of arrays (B, C). We show how to count for a fixed maximum x, and then sum over x from 1 to k. (From here on, n, m and k denote N, M and K.)

The number of arrays of length n with each element in [1,x] is x^n (each element has x choices).

Now try to figure out why this is not the count we are looking for.

It also counts the arrays whose elements all lie in [1,x-1], for which the maximum is less than x.

There are (x-1)^n of those, so the number of arrays B with maximum exactly x is x^n - (x-1)^n. This is a polynomial in x of degree n-1 (the x^n terms cancel), so its degree is at most n.

Similarly, the number of C sequences is x^m-(x-1)^m, a polynomial in x of degree at most m.
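As a small worked instance (the values n = 2 and x = 3 are chosen here purely for illustration):

```latex
x^n - (x-1)^n = 3^2 - 2^2 = 5,
\qquad
\{(1,3),\ (2,3),\ (3,3),\ (3,1),\ (3,2)\}
```

These five arrays are exactly the arrays of length 2 over [1,3] that contain a 3.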

Now, we want \sum_{x=1}^{k} (x^n - (x-1)^n)* (x^m - (x-1)^m). The summand is a polynomial in x of degree at most n+m.
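For tiny inputs the formula can be compared against a direct enumeration of all matrices. The following is a sketch for sanity checking only, not part of any reference solution; bruteCount, formulaCount and ipow are names invented here, and plain 64-bit integers are used, so it is only meant for very small N, M, K.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Counts distinct sequences L by enumerating every N x M matrix over {1..K}.
// There are K^(N*M) matrices, so this is only feasible for tiny inputs.
long long bruteCount(int N, int M, int K) {
    assert(N >= 1 && M >= 1 && K >= 1);
    std::vector<int> cell(N * M, 1);        // current matrix, row-major
    std::set<std::vector<int>> seen;        // distinct sequences L found so far
    while (true) {
        std::vector<int> L(N + M, 0);
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < M; ++j) {
                const int v = cell[i * M + j];
                L[i] = std::max(L[i], v);           // row maximum
                L[N + j] = std::max(L[N + j], v);   // column maximum
            }
        seen.insert(L);
        // Advance to the next matrix like an odometer in base K.
        int pos = 0;
        while (pos < N * M && cell[pos] == K) cell[pos++] = 1;
        if (pos == N * M) break;
        ++cell[pos];
    }
    return (long long)seen.size();
}

// Plain integer power; no overflow protection (small inputs only).
long long ipow(long long base, int exp) {
    long long r = 1;
    while (exp-- > 0) r *= base;
    return r;
}

// Evaluates sum_{x=1}^{K} (x^N - (x-1)^N) * (x^M - (x-1)^M) directly.
long long formulaCount(int N, int M, int K) {
    long long total = 0;
    for (int x = 1; x <= K; ++x)
        total += (ipow(x, N) - ipow(x - 1, N)) * (ipow(x, M) - ipow(x - 1, M));
    return total;
}
```

For N = M = 2 and K = 3 the formula gives 1 + 9 + 25 = 35, and bruteCount(2, 2, 3) should agree with it.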

We can prove that \sum_{i=1}^k i^n is a polynomial of degree n+1 in k. Proof attached here

Let S(k) = \sum_{x=1}^{k} (x^n - (x-1)^n)* (x^m - (x-1)^m). Combining the two conclusions above, S(k) is a polynomial in k of degree at most n+m+1.

We can calculate S(1),S(2),\ldots,S(n+m+2) by iterating over k and computing S(k) from S(k-1).
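The running-sum step might look like the sketch below. The name prefixValues and the choice to return S(0) as well are assumptions made for this illustration; the modulus 10^9+7 is the one used in the solutions further down.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Fast exponentiation: base^exp modulo MOD.
long long power(long long base, long long exp) {
    long long r = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) r = r * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return r;
}

// Returns y with y[i] = S(i) mod MOD for i = 0..count, where S(0) = 0 and
// S(i) = S(i-1) + (i^n - (i-1)^n) * (i^m - (i-1)^m).
std::vector<long long> prefixValues(int n, int m, int count) {
    assert(n >= 1 && m >= 1 && count >= 0);
    std::vector<long long> y(count + 1, 0);
    long long prevN = 0, prevM = 0;  // (i-1)^n and (i-1)^m, starting from 0^n = 0^m = 0
    for (int i = 1; i <= count; ++i) {
        const long long curN = power(i, n), curM = power(i, m);
        const long long term =
            (curN - prevN + MOD) % MOD * ((curM - prevM + MOD) % MOD) % MOD;
        y[i] = (y[i - 1] + term) % MOD;
        prevN = curN;
        prevM = curM;
    }
    return y;
}
```

With n = m = 2, the values S(1), S(2), S(3) come out as 1, 10, 35.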

If k \lt n+m+3, the answer is already among the values computed above. So assume k \gt n+m+2.
We now know n+m+2 distinct points on S. Since S has degree at most n+m+1, these n+m+2 points determine S uniquely, and we can evaluate it using Lagrange interpolation.

Let the n+m+2 pairs be (x_1,y_1),\ldots,(x_{n+m+2},y_{n+m+2}).

S(x) = \sum\limits_{i=1}^{n+m+2} y_i \cdot l_i(x),
l_i(x) = \prod\limits_{1 \le j \le n+m+2,\ j \ne i}\frac{x-x_j}{x_i-x_j}.

Now we substitute x=k, x_i = i, and the values y_i = S(i) calculated above.

Let p = \prod_{i=1}^{n+m+2}(k-i).

Then l_i(k) = \frac{p/(k-i)}{\prod\limits_{1 \le j \le n+m+2,\ j \ne i}{(i-j)}}.

Now let us simplify the denominator:
\prod\limits_{1 \le j \le n+m+2,\ j \ne i}{(i-j)}
= (i-1)*(i-2)*\ldots*2*1*(-1)*(-2)*\ldots*(i-(n+m+2))
= (i-1)! * (n+m+2-i)! *(-1)^{n+m+2-i}.
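A quick check of the sign and factorial bookkeeping, using a made-up size of n+m+2 = 4 points:

```latex
i = 2:\quad \prod_{j \in \{1,3,4\}} (2-j) = (1)(-1)(-2) = 2 = 1!\cdot 2!\cdot(-1)^{2}
\\
i = 3:\quad \prod_{j \in \{1,2,4\}} (3-j) = (2)(1)(-1) = -2 = 2!\cdot 1!\cdot(-1)^{1}
```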

So we precompute p and the factorials; each l_i(k) can then be found in O(log(MOD)) time, because it needs a few modular inverses.

Substituting l_i(k) into S gives the required S(k).
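The evaluation step can be sketched as follows. This is a hedged illustration rather than any of the solutions below: the function name lagrangeAtConsecutive is invented here, the sample points are assumed to be x_i = i for i = 1..d, and k is assumed to satisfy 1 \le k \lt MOD so that every k-i is invertible.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Fast exponentiation: base^exp modulo MOD.
long long power(long long base, long long exp) {
    long long r = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) r = r * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return r;
}

// Modular inverse via Fermat's little theorem (MOD is prime).
long long inverse(long long a) { return power(a, MOD - 2); }

// Given y[i-1] = P(i) mod MOD for i = 1..d, where deg P < d,
// returns P(k) mod MOD. Assumes 1 <= k < MOD.
long long lagrangeAtConsecutive(const std::vector<long long>& y, long long k) {
    const int d = (int)y.size();
    assert(d >= 1 && k >= 1 && k < MOD);
    if (k <= d) return y[k - 1] % MOD;  // already tabulated
    std::vector<long long> fact(d, 1);
    for (int i = 1; i < d; ++i) fact[i] = fact[i - 1] * i % MOD;
    long long p = 1;  // p = prod_{i=1}^{d} (k - i)
    for (int i = 1; i <= d; ++i) p = p * ((k - i) % MOD) % MOD;
    long long result = 0;
    for (int i = 1; i <= d; ++i) {
        const long long num = p * inverse((k - i) % MOD) % MOD;
        const long long den = fact[i - 1] * fact[d - i] % MOD;  // (i-1)! (d-i)!
        long long term = y[i - 1] % MOD * num % MOD * inverse(den) % MOD;
        if ((d - i) & 1) term = (MOD - term) % MOD;  // sign (-1)^(d-i)
        result = (result + term) % MOD;
    }
    return result;
}
```

For instance, feeding the four values 1, 8, 27, 64 of P(x) = x^3 and k = 10 should give 1000. In the problem itself, d would be n+m+2 and y would hold S(1..d).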

And that's a wrap!!

TIME COMPLEXITY

Finding the first n+m+2 values takes O((n+m)*log(n+m)) time because each term needs fast exponentiation.

Finding p takes O(n+m) time.

Precomputing factorials takes O(n+m) time.

Computing each l_i takes O(log(MOD)) time as mentioned above.

Finally, combining everything to get S(k) takes O(n+m) time.

So the total complexity is O((n+m)*log(MOD)).

SOLUTIONS:

Setter's Solution
#define y1 askjdkasldjlkasd
#include <bits/stdc++.h>
#undef y1
 
using namespace std;
 
#define pb push_back
#define mp make_pair
#define fi(a, b) for(int i=a; i<=b; i++)
#define fj(a, b) for(int j=a; j<=b; j++)
#define fo(a, b) for(int o=a; o<=b; o++)
#define fdi(a, b) for(int i=a; i>=b; i--)
#define fdj(a, b) for(int j=a; j>=b; j--)
#define fdo(a, b) for(int o=a; o>=b; o--)
#define sz(x) (int)x.size()
 
typedef long long ll;
typedef long double ld;
typedef vector<int> vi;
typedef pair<int, int> pii;
typedef vector<pii> vpii;
typedef vector<ll> vl;
typedef pair<ll, ll> pll;
typedef vector<pll> vpll;
typedef vector<ll> vll;
 
#ifdef LOCAL
#define err(...) fprintf(stderr, __VA_ARGS__)
#else
#define err(...) while (0)
#endif
 
double START_TIME;
 
void exit() {
#ifdef LOCAL    
    cerr << "TIME: " << setprecision(5) << fixed << (clock() - START_TIME) / CLOCKS_PER_SEC << endl;
#endif  
    exit(0);
}
 
template<typename A, typename B>
ostream& operator<<(ostream& os, pair<A, B> p) {
    os << "(" << p.first << ", " << p.second << ")";
    return os;
}
 
template<typename T>
ostream& operator<<(ostream& os, vector<T> v) {
    fi(0, sz(v) - 1) {
        os << v[i] << " ";
    }
    return os;
}
 
template<typename T>
ostream& operator<<(ostream& os, set<T> t) {
    for (auto z : t) {
        os << z << " ";
    }
    return os;
}
 
template<typename T1, typename T2>
ostream& operator<<(ostream& os, map<T1, T2> t) {
    cerr << endl;
    for (auto z : t) {
        os << "\t" << z.first << " -> " << z.second << endl;
    }
    return os;
}
 
#ifdef LOCAL
#define dbg(x) {cerr << __LINE__ << "\t" << #x << ": " << x << endl;}
#define dbg0(x, n) {cerr << __LINE__ << "\t" << #x << ": "; for (int ABC = 0; ABC < n; ABC++) cerr << x[ABC] << ' '; cerr << endl;}
#else
#define dbg(x) while(0){}
#define dbg0(x, n) while(0){}
#endif
 
#ifdef LOCAL
#define ass(x) if (!(x)) { cerr << __LINE__ << "\tassertion failed: " << #x << endl, abort(); }
#else
#define ass(x) assert(x)
#endif
 
///////////////////////////////////////////////////
 
const int MOD = 1e9 + 7;
const int MAX = 2e5 + 41;
 
int add(int a, int b, int MOD) {
    a += b;
    if (a >= MOD) a -= MOD;
    return a;
}
 
int sub(int a, int b, int MOD) {
    a -= b;
    if (a < 0) a += MOD;
    return a;
}
 
int mul(int a, int b, int MOD) {
    return (ll) a * b % MOD;
}
 
int bp(int x, int d, int MOD) {
    int res = 1;
    while (d) {
        if (d & 1) res = mul(res, x, MOD);
        x = mul(x, x, MOD);
        d >>= 1;
    }
    return res;
}
 
int inv(int x, int MOD) {
    return bp(x, MOD - 2, MOD);
}
 
namespace FFT {
 
    const int MAGIC = 200;
 
    const int ROOT_PW = (1 << 20);
 
    int MODS[3] = {985661441, 976224257, 975175681};
    int ROOTS[3] = {717, 315, 1335};
    int IROOTS[3] = {92105044, 951431260, 590949158}; 
 
    void fft (vi & a, int root, int iroot, int root_pw, int mod, bool invert) {
        int n = (int) a.size();
 
        for (int i=1, j=0; i<n; ++i) {
            int bit = n >> 1;
            for (; j>=bit; bit>>=1)
                j -= bit;
            j += bit;
                if (i < j)
                swap (a[i], a[j]);
        }
 
        for (int len=2; len<=n; len<<=1) {
            int wlen = invert ? iroot : root;
            for (int i=len; i<root_pw; i<<=1)
                wlen = int (wlen * 1ll * wlen % mod);
            for (int i=0; i<n; i+=len) {
                int w = 1;
                for (int j=0; j<len/2; ++j) {
                    int u = a[i+j],  v = int (a[i+j+len/2] * 1ll * w % mod);
                    a[i+j] = u+v < mod ? u+v : u+v-mod;
                    a[i+j+len/2] = u-v >= 0 ? u-v : u-v+mod;
                    w = int (w * 1ll * wlen % mod);
                }
            }
        }
        if (invert) {
            int nrev = inv(n, mod);
            fi(0, n - 1) {
                a[i] = int (a[i] * 1ll * nrev % mod);
            }
        }
    }
 
    void multiply (const vi & a, const vi & b, vi & res, int root, int iroot, int root_pw, int mod, bool square = false) {
        vi fa (a.begin(), a.end()),  fb (b.begin(), b.end());
        size_t n = 1;
        while (n < max (a.size(), b.size()))  n <<= 1;
        n <<= 1;
        fa.resize (n),  fb.resize (n);
 
        fi(0, sz(fa) - 1) {
            fa[i] %= mod;
        }   
        fi(0, sz(fb) - 1) {
            fb[i] %= mod;
        }
 
        fft (fa, root, iroot, root_pw, mod, false);
        if (!square) {  
            fft (fb, root, iroot, root_pw, mod, false);
        } else {
            fb = fa;
        }
        fi(0, (int) n - 1) {
            fa[i] = mul(fa[i], fb[i], mod);
        }
        fft (fa, root, iroot, root_pw, mod, true);
 
        res.resize (n);
        fi(0, (int) n - 1) {
            res[i] = fa[i];
        }
    }
    
    const int INVMOD0INMODS1 = inv(MODS[0], MODS[1]);
    const int INV2 = inv(mul(MODS[0], MODS[1], MODS[2]), MODS[2]);
 
    int crt(const vi &rems, int MOD) {
        int k = (rems[1] - rems[0]) % MODS[1];
        if (k < 0) k += MODS[1];
        k = mul(k, INVMOD0INMODS1, MODS[1]);
        ll x = (ll) k * MODS[0] + rems[0];
 
        k = (int) ( (rems[2] - x) % MODS[2]);
        if (k < 0) k += MODS[2];
        k = mul(k, INV2, MODS[2]);
 
        int res = 0;
        res = add(res, (int) (x % MOD), MOD);
 
        int tmp = mul(MODS[0], MODS[1], MOD);
        tmp = mul(tmp, k, MOD);
        res = add(res, tmp, MOD);
 
        return res;
    }   
 
    vi multiplybrute(const vi &a, const vi &b, int MOD) {       
        vi res(sz(a) + sz(b) - 1, 0);
        fi(0, sz(a) - 1) {
            if (a[i]) {
                fj(0, sz(b) - 1) {
                    res[i + j] = add(res[i + j], mul(a[i], b[j], MOD), MOD);
                }
            }
        }
        return res;
    }
 
    vi multiply(const vi &a, const vi &b, int MOD, bool square = false) {           
        if ( (ll) sz(a) * sz(b) <= MAGIC) {
            return multiplybrute(a, b, MOD);
        }
        vector<vi> resp;
        fi(0, 2) {  
            vi c;
            multiply(a, b, c, ROOTS[i], IROOTS[i], ROOT_PW, MODS[i], square);
            resp.pb(c);
        }
        vi res;
        fi(0, sz(resp[0]) - 1) {
            vi rems;
            fo(0, 2) {
                rems.pb(resp[o][i]);                
            }
            int g = crt(rems, MOD);
            res.pb(g);
        }
        return res;
    }
};
 
vi getinversepolynomial(vi vec) {//return inverse polynomial
    vi v = vec;
    vi res(1, inv(vec[0], MOD));
    int n = sz(vec);
    for (int i = 1; true; i++) {
        
        int len = (1 << i);
        
        vi a(len, 0);       
        fj(0, len / 2 - 1) {
            a[j] = mul(res[j], 2, MOD);
        }
        
        vi res2 = FFT::multiply(res, res, MOD, true);
        if (sz(res2) > len) res2.resize(len);
 
        vi c;
        c.insert(c.begin(), v.begin(), v.begin() + min(len, sz(v)));
        vi b = FFT::multiply(res2, c, MOD);
 
        res.resize(len);
        fj(0, len - 1) {
            res[j] = sub(a[j], b[j], MOD);
        }
 
        if (len > n) break;
    }
 
    return res;
} 
 
int f[MAX];//factorials
int invf[MAX];//inverse factorials
int n, m, k;
 
int getc(int n, int k) {//return combinations
    if (n == k) return 1;
    if (k > n) return 0;
    return mul(f[n], mul(invf[k], invf[n - k], MOD), MOD);
}
 
void init() {
    f[0] = 1;
    fi(1, MAX - 1) f[i] = mul(f[i - 1], i, MOD);
    fi(0, MAX - 1) invf[i] = inv(f[i], MOD);
}
 
int b[MAX];//b[i] = i-th Bernoulli number of second kind
 
void findbernoulli() {//calculates Bernoulli numbers of second kind
    vi ve;
    fi(0, n + m + 1) {
        ve.pb(invf[i + 1]);
    }   
    ve = getinversepolynomial(ve);      
    ve.resize(n + m + 2);   
    fi(0, n + m + 1) {
        b[i] = mul(ve[i], f[i], MOD);       
    }   
    b[1] = inv(2, MOD);         
}
 
int degsum[MAX];//degsum[i] = sum_{j = 1}{k} j ^ i
 
void findfauhalber() {//calculates Faulhaber's sums and saves them in degsum
    vi a = vi(n + m + 2, 0);
    fi(0, n + m + 1) {
        int v1 = ::b[i];        
        int v2 = invf[i];
        int v3 = inv(bp(k, i, MOD), MOD);       
        int v = mul(v1, mul(v2, v3, MOD), MOD);     
        a[i] = v;
    }       
    vi b = vi(n + m + 2, 0);
    fi(0, n + m + 1) {
        b[i] = invf[i];
    }
    vi c = FFT::multiply(a, b, MOD);
    fi(0, n + m + 1) {
        c[i] = mul(c[i], f[i], MOD);        
        c[i] = mul(c[i], bp(k, i, MOD), MOD);
        c[i] = sub(c[i], ::b[i], MOD);
        c[i] = mul(c[i], inv(i, MOD), MOD);
    }   
    fi(0, n + m) {
        degsum[i] = c[i + 1];
    }   
}
 
void precalc() {
//  n = 100 * 1000;
//  m = 100 * 1000;
    findbernoulli();        
    findfauhalber();            
}
 
void solve() {
    int ans = 0;
    {//sum_{i = 1}{k} i ^ (n + m) + (i - 1) ^ (n + m)
        int ab = mul(2, degsum[n + m], MOD);
        ab = sub(ab, bp(k, n + m, MOD), MOD);
        ans = add(ans, ab, MOD);
    }
    {//sum_{i = 1}{k} (i - 1) ^ n * i ^ m       
        fj(0, m) {
            int sign = (j % 2 == 1 ? MOD - 1 : 1);
            int v1 = getc(m, j);
            int v2 = degsum[n + m - j];
            int v = mul(v1, mul(v2, sign, MOD), MOD);
            ans = sub(ans, v, MOD);
        }
    }
    {//sum_{i = 1}{k} (i - 1) ^ m * i ^ n       
        fj(0, n) {
            int sign = (j % 2 == 1 ? MOD - 1 : 1);
            int v1 = getc(n, j);
            int v2 = degsum[m + n - j];
            int v = mul(v1, mul(v2, sign, MOD), MOD);
            ans = sub(ans, v, MOD);
        }
    }
    printf("%d\n", ans);
}
 
int main() {
#ifdef LOCAL
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    START_TIME = (double)clock();
#endif
 
    init();
    int t;
    scanf("%d", &t);    
    while (t--) {
        scanf("%d %d %d", &n, &m, &k);  
        precalc();
        solve();
    }   
 
    exit();
    return 0;
}
Tester's Solution
#include <bits/stdc++.h>
#define endl '\n'
 
//#pragma GCC optimize ("O3")
//#pragma GCC target ("sse4")
 
#define SZ(x) ((int)x.size())
#define ALL(V) V.begin(), V.end()
#define L_B lower_bound
#define U_B upper_bound
#define pb push_back
 
using namespace std;
template<class T, class T2> inline int chkmax(T &x, const T2 &y) { return x < y ? x = y, 1 : 0; }
template<class T, class T2> inline int chkmin(T &x, const T2 &y) { return x > y ? x = y, 1 : 0; }
const int MAXN = (1 << 20);
const int mod = (int)1e9 + 7; 
 
int pw(int x, int p)
{
    int r = 1;
    while(p)
    {
        if(p & 1) r = r * 1ll * x % mod;
        x = x * 1ll * x % mod;
        p >>= 1;
    }
 
    return r;
}
 
int inv(int x) { return pw(x, mod - 2); } 
 
int n, m, x;
 
void read()
{
    cin >> n >> m >> x;
}
 
int y[MAXN];
int fact[MAXN];
 
void solve()
{
    y[0] = 0;
    for(int i = 1; i <= n + m; i++)
        y[i] = (y[i - 1] + (pw(i, n) - pw(i - 1, n) + mod) * 1ll * (pw(i, m) - pw(i - 1, m) + mod)) % mod;
 
    if(x <= n + m)
    {
        cout << y[x] << endl;
        return;
    }
 
    // F(k) = SUM y[i] * L(i, k)
    // L(i, k) = PRODUCT (k - j) / (i - j)
 
    int product_up = 1;
    for(int i = 0; i <= n + m; i++)
        product_up = product_up * 1ll * (x - i) % mod; 
 
    fact[0] = 1;
    for(int i = 1; i <= n + m; i++)
        fact[i] = fact[i - 1] * 1ll * i % mod;
 
    int ans = 0;
    for(int i = 0; i <= n + m; i++)
    {
        int v = product_up * 1ll * inv(x - i) % mod;
        v = v * 1ll * inv(fact[i]) % mod;
        v = v * 1ll * inv(fact[n + m - i]) % mod;
        if((n + m - i) & 1) 
            v = (mod - v) % mod;
 
        v = v * 1ll * y[i] % mod;
        ans = (ans + v) % mod;
    }
 
    cout << ans << endl;
}
 
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
 
    int T;
    cin >> T;
    while(T--)
    {
        read();
        solve();
    }
 
    return 0;
}
Editorialist's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
 
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
#define int ll
int powi(int a,int b){
    int ans=1;
    while(b){
        if(b%2){
            ans*=a;
            ans%=mod;
        }
        a*=a;
        a%=mod;
        b/=2;
    }
    return ans;
}
int getinv(int i){
    return powi(i,mod-2);
}
int fact[412345],getsum[412345],pn[412345],pm[412345];
 
signed main(){ // "int" is #defined to ll above, so the return type is spelled "signed"
    std::ios::sync_with_stdio(false); cin.tie(NULL);
    int t;
    cin>>t;
   
    int i;
    fact[0]=1;
    f(i,1,412345){
        fact[i]=fact[i-1]*i;
        fact[i]%=mod;
    }
    //cout<<"das"<<endl;
    while(t--){
        int n,m;
        cin>>n>>m;
         int k;
        cin>>k;
        rep(i,2*(n+m)+10){
            pn[i]=powi(i,n);
            pm[i]=powi(i,m);
            //cout<<pn[i]<<endl;
        }
        int val;
        getsum[0]=0;    
        f(i,1,2*(n+m)+10){
            val=pn[i]-pn[i-1];
            val*=pm[i]-pm[i-1];
            val%=mod;
            val+=mod;
            val%=mod;
            getsum[i]=getsum[i-1]+val;
            getsum[i]%=mod;
        }
        if(k<2*(n+m)+10){
            cout<<getsum[k]<<endl;
            continue;
        }
        ll gg=1;
        f(i,1,2*(n+m)+10){
            gg*=(k-i);
            gg%=mod;
        }
        ll wow=0;
        f(i,1,2*(n+m)+10){
            val=fact[i-1];
            val*=fact[2*(n+m)+9-i];
            val%=mod;
            if((2*(n+m)+9-i)%2){
                val*=-1;
                val%=mod;
                val+=mod;
                val%=mod;
            }
            val=getinv(val);
            val*=gg;
            val%=mod;
            val*=getinv(k-i);
            val%=mod;
            val*=getsum[i];
            val%=mod;
            wow+=val;
            wow%=mod;
        }
        cout<<wow<<endl;
 
    }
    return 0;   
}

Feel free to share your approach if it differs. Suggestions are always welcome. :slight_smile:

For people who are new to the Lagrange interpolation method, this question is a good one to upsolve.

Here’s another solution that I think should work but I gave up trying to implement it since I figured it out around 1:00 AM the night before and I had never implemented FFT.

After you get the relation \sum_{x = 1}^{k} (x^m - (x - 1)^m) (x^n - (x - 1)^n), rewrite it as follows:

\sum_{x = 1}^{k} (x^m - (x - 1)^m) (x^n - (x - 1)^n) = \sum_{x = 0}^{k - 1} ((x + 1)^m - x^m) ((x + 1)^n - x^n)\\ = \sum_{x = 0}^{k - 1} (x + 1)^{m + n} + x^{m + n} - x^m(x + 1)^n - (x+1)^mx^n\\ = 2 S_{m + n} + k^{m + n} - \sum_{x = 0}^{k - 1} \left(x^m(x + 1)^n + (x+1)^mx^n \right)\\

where S_{r} = 1^r + 2^r + \dots + (k - 1)^r.
Now, looking at x^m(x + 1)^n

\sum_{x = 0}^{k - 1} x^m(x + 1)^n = \sum_{x = 0}^{k - 1} x^m\sum_{j = 0}^n \binom{n}{j} x^j = \sum_{j = 0}^n \binom{n}{j} \sum_{x = 0}^{k - 1} x^{m + j} = \sum_{j = 0}^n \binom{n}{j} S_{m + j}

Thus, the final answer would be

2 S_{m + n} + k^{m + n} - \sum_{j = 0}^n \binom{n}{j} S_{m + j} - \sum_{j = 0}^m \binom{m}{j} S_{n + j}\\
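The identity can be checked numerically for small parameters. This is only a sanity-check sketch; the names directSum, powerSum, binom and viaPowerSums are invented here, and plain 64-bit integers are used, so it is meant for small n, m, k only.

```cpp
#include <cassert>

// Plain integer power; small inputs only.
long long ipow(long long base, int exp) {
    long long r = 1;
    while (exp-- > 0) r *= base;
    return r;
}

// Left-hand side: sum_{x=1}^{k} (x^m - (x-1)^m) * (x^n - (x-1)^n).
long long directSum(int n, int m, int k) {
    long long total = 0;
    for (int x = 1; x <= k; ++x)
        total += (ipow(x, m) - ipow(x - 1, m)) * (ipow(x, n) - ipow(x - 1, n));
    return total;
}

// S_r = 1^r + 2^r + ... + (k-1)^r.
long long powerSum(int r, int k) {
    long long total = 0;
    for (int x = 1; x <= k - 1; ++x) total += ipow(x, r);
    return total;
}

// Binomial coefficient C(n, j) by the multiplicative formula (exact division).
long long binom(int n, int j) {
    long long r = 1;
    for (int i = 0; i < j; ++i) r = r * (n - i) / (i + 1);
    return r;
}

// Right-hand side: 2 S_{m+n} + k^{m+n}
//   - sum_j C(n,j) S_{m+j} - sum_j C(m,j) S_{n+j}.
long long viaPowerSums(int n, int m, int k) {
    assert(n >= 1 && m >= 1 && k >= 1);
    long long ans = 2 * powerSum(m + n, k) + ipow(k, m + n);
    for (int j = 0; j <= n; ++j) ans -= binom(n, j) * powerSum(m + j, k);
    for (int j = 0; j <= m; ++j) ans -= binom(m, j) * powerSum(n + j, k);
    return ans;
}
```

For n = m = 2 and k = 3 both sides evaluate to 35.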

So, once you find S_r for all r \le m + n, you can simply find the answer in O(m + n) time, since the factorials needed for the binomial coefficients can be precomputed.

If we let

f(x) = \sum_{r = 0}^{\infty} S_r \frac{x^r}{r!} = \sum_{r = 0}^{\infty} \sum_{j = 1}^{k - 1} j^r \frac{x^r}{r!} = \sum_{j = 1}^{k - 1} \sum_{r = 0}^{\infty} \frac{(jx)^r}{r!} = \sum_{j = 1}^{k - 1} e^{jx} = \frac{e^{kx} - e^x}{e^x - 1},

then finding the coefficients of the power series of f(x) up to x^{m + n} would finish the job.

The numerator is easy to deal with since it’s simply

e^{kx} - e^x = \sum_{j = 0}^{\infty} (k^j - 1) \frac{x^j}{j!}

For the denominator, note that every power series g(x) = \sum_{j = 0}^{\infty} a_j x^j with a_0 \ne 0 has an inverse, i.e., there exists an h(x) = \sum_{j = 0}^{\infty} b_jx^j such that h(x) g(x) = 1.
Here e^x - 1 = \sum_{j = 1}^{\infty} \frac{x^j}{j!} has zero constant term, and so does the numerator, so we first divide both by x. We then need to find the inverse power series of

\frac{e^x - 1}{x} = \sum_{j = 0}^{\infty} \frac{x^j}{(j + 1)!}.

However, since we only need the coefficients of f(x) up to x^{m + n}, we can truncate the series there, i.e., find the power series of

\frac{1}{\sum_{j = 0}^{m + n} \frac{x^j}{(j + 1)!}}

modulo x^{m + n + 1}, after which we can simply multiply the two series to get our required S.
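To make the generating-function route concrete, here is a small sketch that recovers S_r from the series above. It is an illustration under stated assumptions, not the fast version: it uses a plain O(R^2) series division instead of FFT, the name powerSums is invented here, and because e^x - 1 has zero constant term it works with the numerator and denominator both divided by x.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Fast exponentiation: base^exp modulo MOD.
long long power(long long base, long long exp) {
    long long r = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) r = r * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return r;
}

// Returns S_r = 1^r + ... + (k-1)^r (mod MOD) for r = 0..R, read off from
//   f(x) = ((e^{kx} - e^x)/x) / ((e^x - 1)/x) = sum_r S_r x^r / r!.
// Naive O(R^2) power-series division; the denominator has constant term 1.
std::vector<long long> powerSums(long long k, int R) {
    assert(k >= 1 && R >= 0);
    std::vector<long long> fact(R + 2, 1), invFact(R + 2, 1);
    for (int i = 1; i <= R + 1; ++i) fact[i] = fact[i - 1] * i % MOD;
    invFact[R + 1] = power(fact[R + 1], MOD - 2);
    for (int i = R + 1; i >= 1; --i) invFact[i - 1] = invFact[i] * i % MOD;

    std::vector<long long> f(R + 1, 0), S(R + 1, 0);
    for (int r = 0; r <= R; ++r) {
        // Coefficient of x^r in (e^{kx} - e^x)/x is (k^{r+1} - 1)/(r+1)!.
        long long cur = (power(k, r + 1) - 1 + MOD) % MOD * invFact[r + 1] % MOD;
        // Subtract D_j * f_{r-j}, where D_j = 1/(j+1)! comes from (e^x - 1)/x.
        for (int j = 1; j <= r; ++j)
            cur = (cur - invFact[j + 1] * f[r - j] % MOD + MOD) % MOD;
        f[r] = cur;                 // f_r = S_r / r!
        S[r] = cur * fact[r] % MOD;
    }
    return S;
}
```

With k = 5 and R = 3 this should return 4, 10, 30, 100, i.e. the sums of 0th to 3rd powers of 1..4.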

Finding the inverse is a classic application of the Newton-Raphson method. For a power series f, let g_j be the polynomial such that g_j f \equiv 1 \pmod{x^{2^j}}, i.e., the first 2^j coefficients of g_jf are 1, 0, 0, \ldots. Then we get the relation

h(g) = f - g^{-1} \implies g_{n + 1} = g_n - \frac{h(g_n)}{h'(g_n)} = g_n - \frac{f - g_n^{-1}}{g_n^{-2}} = g_n(2 - g_nf)
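The iteration g \leftarrow g(2 - gf) can be sketched as below. To stay short, this sketch multiplies polynomials naively in quadratic time instead of using FFT, so it shows the doubling recurrence rather than the fast algorithm; mulTrunc and inverseSeries are names invented for this example, and coefficients are assumed to be already reduced into [0, MOD).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

const long long MOD = 1000000007LL;

// Fast exponentiation: base^exp modulo MOD.
long long power(long long base, long long exp) {
    long long r = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) r = r * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return r;
}

// Naive product of a and b, truncated to the first len coefficients.
std::vector<long long> mulTrunc(const std::vector<long long>& a,
                                const std::vector<long long>& b, int len) {
    std::vector<long long> c(len, 0);
    for (int i = 0; i < (int)a.size() && i < len; ++i)
        for (int j = 0; j < (int)b.size() && i + j < len; ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Returns g with f * g == 1 (mod x^len), using g <- g * (2 - f * g).
// Each round doubles the number of correct coefficients. Requires f[0] != 0.
std::vector<long long> inverseSeries(const std::vector<long long>& f, int len) {
    assert(len >= 1 && !f.empty() && f[0] % MOD != 0);
    std::vector<long long> g = {power(f[0], MOD - 2)};  // correct mod x^1
    for (int cur = 1; cur < len;) {
        cur = std::min(2 * cur, len);
        std::vector<long long> t = mulTrunc(f, g, cur);  // f * g mod x^cur
        for (long long& c : t) c = (MOD - c) % MOD;      // -(f * g)
        t[0] = (t[0] + 2) % MOD;                         // 2 - f * g
        g = mulTrunc(g, t, cur);
    }
    return g;
}
```

As a check, the inverse of 1 + x to five terms should come out as 1 - x + x^2 - x^3 + x^4 (with the negative coefficients represented as MOD - 1).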

this is precisely the method described in CF user adamant’s article here. There is also an alternate proof described in CF 438E’s editorial.

The most important thing that has not been described here is how to multiply two polynomials efficiently. This is done with FFT, and you can read more about it in adamant’s article linked above and the much-revered CLRS.
Note that we need to multiply the polynomials modulo 10^9+7, for which reading the section “Multiplication of arbitrary modulus” on page 6 of adamant’s article will be helpful.
Along with this, cp-algorithms.com (a translation of e-maxx.ru) has an article on FFT.

Also, while searching around I found SERSUM - Editorial, which was a problem in Long May 2018 Div 1; the part about computing L in that editorial is exactly the problem of computing S.
And PFRUIT - Editorial too.
adamant also has a blog post on computing S.

The final time complexity of the solution would be O((m + n) \log(m + n)) per test case. Note that \frac{x}{e^x - 1} does not depend on k, so its coefficients can be precomputed once up to a = \max(m + n), which is 2\times 10^5 here, in O(a \log a) time. Multiplying the j-th coefficient of a series g(x) by k^j gives g(kx), which also rescales the denominator, so the numerator \frac{e^{kx} - e^x}{x} still has to be multiplied by the precomputed series with one FFT multiplication in each test case.

I tried to implement my solution at like 7:00 AM but gave up when I realised something was wrong with my attempt to find \frac{x}{e^x - 1}, since I started getting waaayyy off answers if I tried to find more than ~30 coefficients of the series. Here’s my failed try if you want to see.

tl;dr: GenFun are cool, kids.

Although I could see that the problem was based on Lagrange interpolation the moment I converted it to the formula, I was a bit confused on seeing jtnydv25's solution having an execution time of around 4.5 sec on the problem. It took me days to figure out this approach, so I find this better than the original problem.

What are some good resources out there so that one can improve in solving math based challenges like this?

Can anyone explain
how to get the expression (x^n -(x-1)^n) *(x^m -(x-1)^m)?

Since max(Row) = max(Column) (explained above), fix the common maximum x.
x^n is the number of all sequences of length n with values in [1, x] (maximum at most x), and
(x-1)^n is the number of those with values in [1, x-1] (maximum at most x-1).
Therefore the number of sequences with maximum exactly x is x^n-(x-1)^n, and likewise x^m-(x-1)^m for the columns :slightly_smiling_face: @savaliya_vivek

Can someone explain the paragraphs after the above line to me?

these too
i. http://www.topcoder.com/stat?c=problem_statement&pm=10239
ii. http://www.topcoder.com/stat?c=problem_statement&pm=8725