MMNN01 - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Mamnoon Siam
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Online Fast Fourier Transform, Generating Functions.

PROBLEM:

Consider all possible arrays of length N in which each element lies in the range [1, M]. Partition each array into the minimum number of contiguous partitions such that no partition contains any value more than once. Compute the expected number of partitions over all such arrays.

For example, consider the array [1, 2, 2, 1, 1]. It needs at least three partitions, namely [1, 2], [2, 1], [1]. Let f(x) denote this minimum for an array x, so f([1, 2, 2, 1, 1]) = 3.

We want to compute the expected value of f(x) for x \in S, where S denotes the set of all arrays of length N with values in the range [1, M].
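To make f concrete, here is a minimal sketch of computing it for a single array with a greedy left-to-right scan. The function name `minPartitions` is only illustrative and does not appear in any of the solutions below.

```cpp
#include <unordered_set>
#include <vector>

// Greedy scan: keep extending the current partition until a value repeats,
// then start a new partition at that element.
int minPartitions(const std::vector<int>& arr) {
    if (arr.empty()) return 0;
    std::unordered_set<int> seen;  // values in the current partition
    int parts = 1;
    for (int v : arr) {
        if (seen.count(v)) {       // v would repeat: close the current partition
            ++parts;
            seen.clear();
        }
        seen.insert(v);
    }
    return parts;
}
```

For the example above, `minPartitions({1, 2, 2, 1, 1})` splits the array as [1, 2], [2, 1], [1].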

QUICK EXPLANATION

  • Let A_i denote the number of i-length sequences such that a partition ends at position i-1. Each such sequence contributes one partition for every choice of the remaining N-i elements, so the total number of partitions over all arrays is \sum_{i = 1}^{N} A_{i}*M^{N-i}. Dividing this by M^N gives the required expected value.
    We also assume a fictitious partition ending just before the first element, so that the first partition is counted.
  • To compute A_i, we iterate over the size x of the partition ending at position i-1, require all elements of that partition to be distinct, and require the value at position i to equal one of the values in that partition. Specifically, A_i = \sum_{x = 1}^{i-1} A_{i-x}*^{M-1}P_{x-1}*x, which solves the problem in O(N^2) time.
  • To speed this up, we apply the online version of FFT, which computes the same recurrence in O(N*log^2(N)) time.

EXPLANATION

I am going to discuss two solutions that are quite similar in idea: one using Online FFT, and another using generating functions, which is the setter's approach.

Firstly, the total number of N-length arrays with values in [1, M] is M^N, since each position has M independent choices.

So we need the sum of the number of partitions over all arrays. Let us view a partition boundary as a bar placed between two adjacent elements. With one-based indexing, the example above looks like [1, 2 | 2, 1 | 1].

For any array, the number of bars is exactly one less than the number of partitions. For simplicity, we also place a bar before the first element, so the number of bars always equals the number of partitions. We therefore need to count the total number of bars over all possible arrays. (This works because the bar at each position contributes independently to the number of partitions. You can read more about the Contribution Trick to get a better idea of this. This blog might help.)
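As a sanity check for the bar-counting view, the following brute-force sketch enumerates all M^N arrays and sums the bars, including the fictitious one before the first element. It is only feasible for tiny N and M, and the name `totalBarsBruteForce` is an assumption made for this illustration.

```cpp
#include <vector>

// Total number of bars (fictitious leading bar included) over all M^N arrays.
// Values are stored as 0..m-1 instead of 1..m; this does not change the count.
long long totalBarsBruteForce(int n, int m) {
    std::vector<int> arr(n, 0);
    long long total = 0;
    while (true) {
        std::vector<bool> seen(m, false);
        long long bars = 1;                 // fictitious bar before the first element
        for (int v : arr) {
            if (seen[v]) {                  // repeat: a bar goes just before this element
                ++bars;
                seen.assign(m, false);
            }
            seen[v] = true;
        }
        total += bars;
        // advance arr like an odometer in base m
        int pos = 0;
        while (pos < n && ++arr[pos] == m) arr[pos++] = 0;
        if (pos == n) break;                // wrapped around: every array was visited
    }
    return total;
}
```

For N = 2 and M = 2 the four arrays have 2, 1, 1 and 2 partitions, so the total is 6 and the expected value is 6/4.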

Let A_x denote the number of x-length sequences with values in the range [1, M] such that there is a bar between the (x-1)-th and x-th elements; all elements up to the x-th are fixed. A bar between positions x-1 and x is not affected by any element after position x, so each such sequence contributes M^{N-x} bars to the final sum.

Hence, our final answer is \displaystyle\frac{\sum_{i = 1}^{N} A_i*M^{N-i}}{M^N}. The only task left is to compute A_i. As a base case, A_1 = M, since we have M choices for the first element.
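Since the answer is reported modulo 10^9 + 7 (the modulus used in the solutions below), the division by M^N becomes a multiplication by a modular inverse. A minimal sketch, assuming the values A_1..A_N are already known modulo the prime and M is not a multiple of it; `expectedValue` and `modPow` are illustrative names:

```cpp
#include <vector>

const long long MOD = 1000000007LL;

// Fast exponentiation modulo MOD.
long long modPow(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// A[1..N] holds A_i mod MOD (A[0] is unused).
// Returns (sum_i A_i * M^(N-i)) / M^N modulo MOD.
long long expectedValue(const std::vector<long long>& A, long long M) {
    int N = (int)A.size() - 1;
    long long numerator = 0, power = 1;          // power = M^(N-i)
    for (int i = N; i >= 1; --i) {
        numerator = (numerator + A[i] % MOD * power) % MOD;
        power = power * (M % MOD) % MOD;
    }
    // here power == M^N; invert it with Fermat's little theorem
    return numerator * modPow(power, MOD - 2) % MOD;
}
```

With N = 3 and M = 2 the counts are A = (2, 2, 6), the numerator is 18 and the result represents 18/8 = 9/4.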

Now, assume we have already computed A_j for all j < i and want to compute A_i. Let us iterate over the size of the last partition. Suppose the partition ending at position i-1 has size x. Then there must also be a bar between positions i-x-1 and i-x. Hence, all values in the subarray [i-x, i-1] must be distinct, and the element at position i must equal one of the values in the subarray [i-x, i-1].

Since the element at position i-x is already fixed (it is counted in A_{i-x}), the values in the subarray [i-x+1, i-1] must be distinct and different from the element at position i-x. This can be done in ^{M-1}P_{x-1} ways. The element at position i can then be chosen in x ways.

Hence, A_i = \displaystyle\sum_{x = 1}^{i-1}A_{i-x}* ^{M-1}P_{x-1}*x. This recurrence computes the answer in O(N^2) time which, sadly, isn’t fast enough.
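The quadratic recurrence is still useful as a reference implementation for small inputs. A minimal sketch, assuming the modulus 10^9 + 7 from the solutions below; the function name `totalPartitionsQuadratic` is illustrative, and it returns the numerator \sum A_i*M^{N-i} rather than the expected value so that it can be compared with a brute force directly.

```cpp
#include <vector>

const long long MOD = 1000000007LL;

// Returns sum_{i=1..N} A_i * M^(N-i) mod MOD, i.e. the total number of
// partitions over all M^N arrays, using the O(N^2) recurrence.
long long totalPartitionsQuadratic(int N, long long M) {
    std::vector<long long> B(N + 1, 0), A(N + 1, 0);
    long long perm = 1;                          // P(M-1, x-1), starting at P(M-1, 0) = 1
    for (int x = 1; x <= N; ++x) {
        B[x] = perm * x % MOD;                   // B_x = P(M-1, x-1) * x
        long long factor = ((M - x) % MOD + MOD) % MOD;
        perm = perm * factor % MOD;              // becomes 0 once x reaches M
    }
    A[1] = M % MOD;                              // base case: M choices for the first element
    for (int i = 2; i <= N; ++i)
        for (int x = 1; x < i; ++x)              // x = size of the partition ending at i-1
            A[i] = (A[i] + A[i - x] * B[x]) % MOD;
    long long total = 0, power = 1;              // power = M^(N-i)
    for (int i = N; i >= 1; --i) {
        total = (total + A[i] * power) % MOD;
        power = power * (M % MOD) % MOD;
    }
    return total;
}
```

For N = 3 and M = 2 this gives A = (2, 2, 6) and a total of 2*4 + 2*2 + 6 = 18 partitions over the 8 arrays.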

Let’s write B_i = ^{M-1}P_{i-1}*i; then A_i = \displaystyle\sum_{x = 1}^{i-1}A_{i-x}*B_x. FFT may seem tempting, but we cannot apply it naively and hope for AC, because computing the i-th term of A requires all previous terms. There exists, however, an online variant of the Fast Fourier Transform which, at the cost of an extra log factor, computes this recurrence online. You may read about it in the blog here.
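One common way to organise such an online convolution is divide and conquer over the index range: finish the left half, push its contribution into the right half with a single product, then recurse on the right half. The sketch below shows only that structure and is not necessarily the arrangement used in the linked blog or in the solutions below. The product is a schoolbook placeholder (an assumption made to keep the example short); replacing `multiply` with an FFT-based product is what yields O(N*log^2(N)). All names are illustrative.

```cpp
#include <cstddef>
#include <vector>

const long long MOD = 1000000007LL;
using Vec = std::vector<long long>;

// Placeholder product: schoolbook convolution mod MOD.
// Swap in an FFT-based multiply to get the O(N log^2 N) bound.
Vec multiply(const Vec& a, const Vec& b) {
    if (a.empty() || b.empty()) return {};
    Vec c(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Finalises A[l..r-1], assuming contributions from indices < l were already added.
void solveRange(Vec& A, const Vec& B, int l, int r) {
    if (r - l <= 1) return;
    int mid = (l + r) / 2;
    solveRange(A, B, l, mid);                     // A[l..mid-1] is now final
    Vec left(A.begin() + l, A.begin() + mid);
    Vec kernel(B.begin(), B.begin() + (r - l));   // B[0..r-l-1], with B[0] = 0
    Vec prod = multiply(left, kernel);
    // prod[k] = sum_j A[l+j] * B[k-j], which is the left half's share of A[l+k]
    for (int i = mid; i < r; ++i)
        A[i] = (A[i] + prod[i - l]) % MOD;
    solveRange(A, B, mid, r);
}

// Returns sum_{i=1..N} A_i * M^(N-i) mod MOD.
long long totalPartitionsOnline(int N, long long M) {
    Vec B(N + 1, 0), A(N + 1, 0);
    long long perm = 1;                           // P(M-1, x-1)
    for (int x = 1; x <= N; ++x) {
        B[x] = perm * x % MOD;
        perm = perm * (((M - x) % MOD + MOD) % MOD) % MOD;
    }
    A[1] = M % MOD;
    solveRange(A, B, 1, N + 1);
    long long total = 0, power = 1;
    for (int i = N; i >= 1; --i) {
        total = (total + A[i] * power) % MOD;
        power = power * (M % MOD) % MOD;
    }
    return total;
}
```

Every pair (j, i) with j < i is handled at exactly one recursion level, namely the one where j falls in the left half and i in the right half, so each term A_j*B_{i-j} is added once.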

Setter’s solution
Building on the ideas discussed above, let f(n, k) denote the number of n-length sequences with k partitions whose last element is fixed. Clearly, f(n, 1) = ^{M-1}P_{n-1}.
Now, we can write f(n, k) = \displaystyle\sum_{x = 1}^{n-1} f(n-x, k-1)*B_x.

Writing these as polynomials C(x) = \sum_{n > 0}f(n, 1)*x^n and B(x) = \sum_{n>0} {^{M-1}P_{n-1}}*n*x^n, we can see that f(n, k) is the coefficient of x^n in C(x)*B(x)^{k-1}.

Let [C(x)]_n denote the coefficient of x^n in C(x).

The required sum is S = \sum_{i = 1}^{n} f(n, i)*i*M (M choices for the fixed last element).
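Before moving to generating functions, the definition of S can be checked directly with a slow table of f(n, k). This is a sketch for small N only (it runs in O(N^3)); the name `totalPartitionsByCount` is an assumption made for this illustration.

```cpp
#include <vector>

const long long MOD = 1000000007LL;

// S = M * sum_k k * f(N, k), with f filled by
// f(n, 1) = P(M-1, n-1) and f(n, k) = sum_{x=1}^{n-1} f(n-x, k-1) * B_x.
long long totalPartitionsByCount(int N, long long M) {
    std::vector<long long> perm(N + 1, 0), B(N + 1, 0);
    long long cur = 1;                                   // P(M-1, n-1)
    for (int n = 1; n <= N; ++n) {
        perm[n] = cur;
        B[n] = cur * n % MOD;
        cur = cur * (((M - n) % MOD + MOD) % MOD) % MOD;
    }
    // f[k][n] for 1 <= k, n <= N
    std::vector<std::vector<long long>> f(N + 1, std::vector<long long>(N + 1, 0));
    for (int n = 1; n <= N; ++n) f[1][n] = perm[n];
    for (int k = 2; k <= N; ++k)
        for (int n = 1; n <= N; ++n)
            for (int x = 1; x < n; ++x)
                f[k][n] = (f[k][n] + f[k - 1][n - x] * B[x]) % MOD;
    long long S = 0;
    for (int k = 1; k <= N; ++k) S = (S + f[k][N] * k) % MOD;
    return S * (M % MOD) % MOD;                          // M choices for the fixed element
}
```

For N = 3 and M = 2 the table gives f(3, 1) = 0, f(3, 2) = 3 and f(3, 3) = 1, so S = 2*(0 + 6 + 3) = 18, matching the bar-counting approach.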

We can rewrite the above as S = M*\displaystyle\sum_{i = 1}^{n} \Big[ i*C(x)*B(x)^{i-1} \Big]_n, that is,
S = M*\Big[ C(x)* \displaystyle\sum_{i = 0}^{n-1} (i+1)*B(x)^i \Big]_{n}. Using the formula for the sum of an arithmetico-geometric progression (AGP), we can rewrite this as

S = M*\Big[ C(x)*\Big[ \big\{ \frac{1}{B(x)-1} \big\} * \big\{ (n-1)*B(x)^n - \frac{B(x)^n-B(x)}{B(x)-1}\big\} + \big\{ \frac{B(x)^n-1}{B(x)-1}\big\} \Big] \Big]_n
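For readers who want to check the bracketed closed form, here is a short derivation sketch. The shorthand T for the AGP sum and B for B(x) is introduced only for this check.

```latex
\begin{align*}
T &= \sum_{i=0}^{n-1} (i+1)\,B^{i}
   = \frac{d}{dB}\sum_{i=0}^{n} B^{i}
   = \frac{d}{dB}\,\frac{B^{n+1}-1}{B-1}
   = \frac{n B^{n+1} - (n+1) B^{n} + 1}{(B-1)^{2}}, \\[4pt]
&\frac{1}{B-1}\left\{(n-1)B^{n} - \frac{B^{n}-B}{B-1}\right\} + \frac{B^{n}-1}{B-1}
   = \frac{(n-1)B^{n}(B-1) - (B^{n}-B) + (B^{n}-1)(B-1)}{(B-1)^{2}} \\
&\qquad = \frac{n B^{n+1} - (n+1) B^{n} + 1}{(B-1)^{2}} = T.
\end{align*}
```

Because B(x) has no constant term, B(x)-1 has constant term -1 and is therefore invertible as a power series, so every division above is well defined.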

The setter computes this S and divides it by M^N to get the expected value.
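The closed form can be evaluated term by term on power series truncated modulo x^{n+1}. The sketch below does this with schoolbook arithmetic so it stays short; this is an assumption made for illustration, whereas the setter's code further down uses FFT-based multiplication, inversion and powering. All names here are illustrative.

```cpp
#include <cstddef>
#include <vector>

const long long MOD = 1000000007LL;
using Poly = std::vector<long long>;      // coefficients mod MOD, fixed length n+1

long long norm(long long v) { return ((v % MOD) + MOD) % MOD; }

long long modPow(long long b, long long e) {
    long long r = 1;
    b = norm(b);
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// Schoolbook product, truncated to a.size() coefficients.
Poly mulTrunc(const Poly& a, const Poly& b) {
    Poly c(a.size(), 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; i + j < a.size() && j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Power-series inverse via the coefficient recurrence; requires d[0] != 0.
Poly invSeries(const Poly& d) {
    Poly e(d.size(), 0);
    long long inv0 = modPow(d[0], MOD - 2);
    e[0] = inv0;
    for (std::size_t k = 1; k < d.size(); ++k) {
        long long s = 0;
        for (std::size_t j = 1; j <= k; ++j) s = (s + d[j] * e[k - j]) % MOD;
        e[k] = norm(-s) * inv0 % MOD;
    }
    return e;
}

// Returns S = M * [x^n] C(x) * (closed form of the AGP sum), mod MOD.
long long totalPartitionsClosedForm(int n, long long M) {
    int len = n + 1;
    Poly C(len, 0), B(len, 0);
    long long perm = 1;                                   // P(M-1, i-1)
    for (int i = 1; i <= n; ++i) {
        C[i] = perm;
        B[i] = perm * i % MOD;
        perm = perm * norm(M - i) % MOD;
    }
    Poly Bn(len, 0);
    Bn[0] = 1;
    for (int i = 0; i < n; ++i) Bn = mulTrunc(Bn, B);     // B(x)^n
    Poly D = B;
    D[0] = norm(D[0] - 1);                                // B(x) - 1
    Poly iD = invSeries(D);                               // 1 / (B(x) - 1)
    Poly geo = Bn;
    geo[0] = norm(geo[0] - 1);
    geo = mulTrunc(geo, iD);                              // (B^n - 1) / (B - 1)
    Poly t(len), u(len);
    for (int i = 0; i < len; ++i) t[i] = norm(Bn[i] - B[i]);
    t = mulTrunc(t, iD);                                  // (B^n - B) / (B - 1)
    for (int i = 0; i < len; ++i) u[i] = norm((n - 1) % MOD * Bn[i] - t[i]);
    u = mulTrunc(u, iD);                                  // first term of the bracket
    long long coef = 0;                                   // [x^n] C(x) * (u + geo)
    for (int i = 0; i <= n; ++i)
        coef = (coef + C[i] * ((u[n - i] + geo[n - i]) % MOD)) % MOD;
    return coef * (M % MOD) % MOD;
}
```

For n = 2 and M = 2 the bracket evaluates to 1 + 2x + 4x^2, which equals 1 + 2*B(x) modulo x^3, and the result is 6.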

TIME COMPLEXITY

The time complexity is O(N*log^2(N)) per test case.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;

namespace algebra {
	const int inf = 1e9;
	const int magic = 50;
	const int mod = 1e9 + 7;
	const int N = 1 << 19;
	 
	inline int mul(int a, int b) {
	#if !defined(_WIN32) || defined(_WIN64)
	  return (int) ((long long) a * b % mod);
	#endif
	  unsigned long long x = (long long) a * b;
	  unsigned xh = (unsigned) (x >> 32), xl = (unsigned) x, d, m;
	  asm(
		"divl %4; \n\t"
		: "=a" (d), "=d" (m)
		: "d" (xh), "a" (xl), "r" (mod)
	  );
	  return m;
	}

	inline int qpow(int b, int p) {
		int ret = 1;
		while(p) {
			if(p & 1) ret = mul(ret, b);
			b = mul(b, b), p >>= 1;
		} return ret;
	}

	namespace tools {
		int npr_for_n_equal = -1;
		int inv[N], f[N], finv[N];
		int npr_table[N];
		inline int factorial(int n) {
			return f[n];
		}
		inline int invfact(int n) {
			return finv[n];
		}
		inline int inverse(int n) {
			if(0 <= n and n < N) return inv[n];
			return qpow(n, mod - 2);
		}
		void init() {
			inv[0] = inv[1] = f[0] = finv[0] = 1;
			for(int i = 2; i < N; i++) {
				inv[i] = mul(inv[mod % i], (mod - mod/i));
			}
			for(int i = 1; i < N; i++) {
				finv[i] = mul(finv[i-1], inv[i]); // inverse factorial: multiply by inv[i], not i
				f[i] = mul(f[i-1], i);
			}
		}
		void init_npr(int n, int r) {
			// prep table for fixed n
			// and r in range [0, n]
			assert(0 <= n);
			assert(0 <= r and r < N);
			npr_for_n_equal = n;
			npr_table[0] = 1;
			for(int i = 1; i <= r; i++, n--) {
				npr_table[i] = mul(npr_table[i-1], n);
			}
		}
		int npr(int n, int r) {
			assert(npr_for_n_equal == n);
			if(r > n or r < 0) return 0;
			return npr_table[r];
		}
	}

	namespace fft {
		using ld = double;
		ld PI = acos(-1);

		struct point {
			ld a, b;
			point(ld _a = 0.0, ld _b = 0.0) : a(_a), b(_b) {}
			const point operator + (const point &c) const 
			    { return point(a + c.a, b + c.b); }
			const point operator - (const point &c) const
			    { return point(a - c.a, b - c.b); }
			const point operator * (const point &c) const
			    { return point(a * c.a - b * c.b, a * c.b + b * c.a); }
			const point conj() const { return point(a, -b); }
		}; 

		vector<int> rev; 
		vector<point> w; 

		void prepare(int n) { // n is power of 2
			int sz = __builtin_ctz(n); 
			if(rev.size() != n) {
			    rev.assign(n, 0); 
			    for(int i = 0; i < n; ++i) {
			        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (sz - 1)); 
			    }
			}

			if(w.size() >= n) return; 
			if(w.empty()) w = {{0, 0}, {1, 0}}; 
			
			sz = __builtin_ctz(w.size());
			w.resize(max(2, n));

			// w[n + i] = w_{2n}^i, n power of 2, i < n
			while((1 << sz) < n) {
			    double ang = 2 * PI / (1 << (sz + 1)); 
			    point wn(cos(ang), sin(ang));
			    for(int i = 1 << (sz - 1); i < (1 << sz); ++i) {
			        w[i << 1] = w[i]; 
			        w[i << 1 | 1] = w[i] * wn; 
			    } ++sz; 
			}
		}

		void fft(point *p, int n) {
			prepare(n); 
			for(int i = 1; i < n - 1; ++i) 
			    if(i < rev[i]) swap(p[i], p[rev[i]]);
			for(int h = 1; h < n; h <<= 1) {
			    for(int s = 0; s < n; s += h << 1) {
			        for(int i = 0; i < h; ++i) {
			            point &u = p[s + i], &v = p[s + h + i], 
			                t = v * w[h + i];
			            v = u - t; u = u + t;
			        }
			    }
			} 
		}

		template<typename T>
		void mul_slow(vector<T> &a, const vector<T> &b) {
			vector<T> res(a.size() + b.size() - 1);
			for(size_t i = 0; i < a.size(); i++) {
				for(size_t j = 0; j < b.size(); j++) {
					res[i + j] += a[i] * b[j];
				}
			}
			a = res;
		}

		template<typename T>
		void mul(vector<T> &a, const vector<T> &b) {
			if(min(a.size(), b.size()) < magic) {
				mul_slow(a, b);
				return;
			}

			int n = a.size();
			int m = b.size();

			static point f[N], g[N], u[N], v[N];

			int sz = 1;
			while(sz < n + m - 1) sz <<= 1;
			for(int i = 0; i < n; i++) f[i] = point(a[i] & 32767, a[i] >> 15);
			for(int i = 0; i < m; i++) g[i] = point(b[i] & 32767, b[i] >> 15);
			for(int i = n; i < sz; i++) f[i] = point(0, 0); 
			for(int i = m; i < sz; i++) g[i] = point(0, 0);
			prepare(sz); 

			fft(f, sz); fft(g, sz);
			for(int i = 0; i < sz; i++) {
			    int j = (sz - i) & (sz - 1);
			    point a1 = (f[i] + f[j].conj()) * point(0.5, 0);
			    point a0 = (f[i] - f[j].conj()) * point(0, -0.5);
			    point b1 = (g[i] + g[j].conj()) * point(0.5 / sz, 0);
			    point b0 = (g[i] - g[j].conj()) * point(0, -0.5 / sz);
			    u[j] = a1 * b1 + a0 * b0 * point(0, 1);
			    v[j] = a1 * b0 + a0 * b1;
			}
			fft(u, sz); fft(v, sz);
			a.clear();
			for(int i = 0; i < sz; i++) {
			    T aa = (long long)(u[i].a + 0.5);
			    T bb = (long long)(v[i].a + 0.5);
			    T cc = (long long)(u[i].b + 0.5);
			    a.emplace_back(aa + T(bb << 15) + T(cc << 30));
			}
			return;
		}
	}
	template<typename T>
	T bpow(T x, int n) {
		T ret = T(1);
		while(n) {
			if(n & 1) ret = ret * x;
			x = x * x, n >>= 1;
		} return ret;
	}

	template<int m>
	struct modular {
		long long r;
		modular() : r(0) {}
		modular(int64_t rr) : r(rr) {if(abs(r) >= m) r %= m; if(r < 0) r += m;}
		modular inv() const {return bpow(*this, m - 2);}
		modular operator * (const modular &t) const {return (r * t.r) % m;}
		modular operator / (const modular &t) const {return *this * t.inv();}
		modular operator += (const modular &t) {r += t.r; if(r >= m) r -= m; return *this;}
		modular operator -= (const modular &t) {r -= t.r; if(r < 0) r += m; return *this;}
		modular operator + (const modular &t) const {return modular(*this) += t;}
		modular operator - (const modular &t) const {return modular(*this) -= t;}
		modular operator *= (const modular &t) {return *this = *this * t;}
		modular operator /= (const modular &t) {return *this = *this / t;}
	
		bool operator == (const modular &t) const {return r == t.r;}
		bool operator != (const modular &t) const {return r != t.r;}
	
		operator int64_t() const {return r;}
		// operator int32_t() const {return (int)r;}
	};
	template<int T>
	istream& operator >> (istream &in, modular<T> &x) {
		return in >> x.r;
	}


	template<typename T>
	struct poly {
		vector<T> a;
	
		void normalize() { // get rid of leading zeroes
			while(!a.empty() && a.back() == T(0)) {
				a.pop_back();
			}
		}
	
		poly(){}
		poly(T a0) : a{a0}{normalize();}
		poly(vector<T> t) : a(t){normalize();}
	
		poly operator += (const poly &t) {
			a.resize(max(a.size(), t.a.size()));
			for(size_t i = 0; i < t.a.size(); i++) {
				a[i] += t.a[i];
			}
			normalize();
			return *this;
		}
		poly operator -= (const poly &t) {
			a.resize(max(a.size(), t.a.size()));
			for(size_t i = 0; i < t.a.size(); i++) {
				a[i] -= t.a[i];
			}
			normalize();
			return *this;
		}
		poly operator + (const poly &t) const {return poly(*this) += t;}
		poly operator - (const poly &t) const {return poly(*this) -= t;}
	
		poly mod_xk(size_t k) const { // get same polynomial mod x^k
			k = min(k, a.size());
			return vector<T>(begin(a), begin(a) + k);
		}
		poly mul_xk(size_t k) const { // multiply by x^k
			poly res(*this);
			res.a.insert(begin(res.a), k, 0);
			return res;
		}
		poly div_xk(size_t k) const { // divide by x^k, dropping coefficients
			k = min(k, a.size());
			return vector<T>(begin(a) + k, end(a));
		}
		poly substr(size_t l, size_t r) const { // return mod_xk(r).div_xk(l)
			l = min(l, a.size());
			r = min(r, a.size());
			return vector<T>(begin(a) + l, begin(a) + r);
		}
		poly inv(size_t n) const { // get inverse series mod x^n
			assert(!is_zero());
			poly ans = a[0].inv();
			size_t a = 1;
			while(a < n) {
				poly C = (ans * mod_xk(2 * a)).substr(a, 2 * a);
				ans -= (ans * C).mod_xk(a).mul_xk(a);
				a *= 2;
			}
			return ans.mod_xk(n);
		}
	
		poly operator *= (const poly &t) {fft::mul(a, t.a); normalize(); return *this;}
		poly operator * (const poly &t) const {return poly(*this) *= t;}

		poly operator *= (const T &x) {
			for(auto &it: a) {
				it *= x;
			}
			normalize();
			return *this;
		}
		poly operator /= (const T &x) {
			for(auto &it: a) {
				it /= x;
			}
			normalize();
			return *this;
		}
		poly operator * (const T &x) const {return poly(*this) *= x;}
		poly operator / (const T &x) const {return poly(*this) /= x;}
	
		void print() const {
			for(auto it: a) {
				cout << it << ' ';
			}
			cout << endl;
		}
	
		T& lead() { // leading coefficient
			return a.back();
		}
		int deg() const { // degree
			return a.empty() ? -inf : a.size() - 1;
		}
		bool is_zero() const { // is polynomial zero
			return a.empty();
		}
		T operator [](int idx) const {
			return idx >= (int)a.size() || idx < 0 ? T(0) : a[idx];
		}
	
		T& coef(size_t idx) { // mutable reference at coefficient
			return a[idx];
		}
		bool operator == (const poly &t) const {return a == t.a;}
		bool operator != (const poly &t) const {return a != t.a;}
	
		poly deriv() { // calculate derivative
			vector<T> res;
			for(int i = 1; i <= deg(); i++) {
				res.push_back(T(i) * a[i]);
			}
			return res;
		}
		poly integr() { // calculate integral with C = 0
			vector<T> res = {0};
			for(int i = 0; i <= deg(); i++) {
				res.push_back(a[i] * T(tools::inverse(i+1)));
			}
			return res;
		}
		size_t leading_xk() const { // Let p(x) = x^k * t(x), return k
			if(is_zero()) {
				return inf;
			}
			int res = 0;
			while(a[res] == T(0)) {
				res++;
			}
			return res;
		}
		poly log(size_t n) { // calculate log p(x) mod x^n
			assert(a[0] == T(1));
			return (deriv().mod_xk(n) * inv(n)).integr().mod_xk(n);
		}
		poly exp(size_t md) { // calculate exp p(x) mod x^n
			if(is_zero()) {
				return T(1);
			}
			assert(a[0] == T(0));
			poly ans = T(1);
			size_t sz = 1;
			while(sz < md) {
				poly C = ans.log(2 * sz).div_xk(sz) - substr(sz, 2 * sz);
				ans -= (ans * C).mod_xk(sz).mul_xk(sz);
				sz *= 2;
			}
			return ans.mod_xk(md);
		}
		poly pow(size_t k, size_t n) { // calculate p^k(n) mod x^n
			if(is_zero()) {
				return *this;
			}
			int i = leading_xk();
			T j = a[i];
			poly t = div_xk(i) / j;
			return bpow(j, k) * (t.log(n) * T(k)).exp(n).mul_xk(i * k).mod_xk(n);
		}
	};
	template<typename T>
	poly<T> operator * (const T& a, const poly<T>& b) {
		return b * a;
	}
} using namespace algebra;

// algebra template ^^

typedef modular<mod> base;
typedef poly<base> polyn;

using ll = long long;

base solve(int n, int m) {
	/* 	p_i = P(m-1, i-1)
		q_i = P(m-1, i-1) * i
	*/
	polyn P, Q;
	base nPr = 1, dR = m - 1;
	P.a.emplace_back(0);
	Q.a.emplace_back(0);
	for(int i = 1; i <= n; i++) {
		P.a.emplace_back(nPr);
		Q.a.emplace_back(nPr * base(i));
		nPr *= dR;
		dR -= 1; // once dR becomes 0, nPr will be 0 afterwards
	}
	/* 	
		$$R(x)=P(x)\left[\left\{\frac{1}{Q(x)-1}\right\}\left\{(n-1)Q(x)^n-\frac{Q(x)^n-Q(x)}{Q(x)-1}\right\}+\left\{\frac{Q(x)^n-1}{Q(x)-1}\right\}\right]$$
		iQ(x) = 1 / (Q(x) - 1)
		Qn = Q(x)^n
	*/
	polyn iQ = (Q - base(1)).inv(n+1);
	polyn Qn = Q.pow(n, n+1);
	polyn S = ((Qn - base(1)) * iQ).mod_xk(n+1);

	polyn mQ = (iQ * ((base(n-1) * Qn) - ((Qn - Q) * iQ).mod_xk(n+1))).mod_xk(n+1) + S;

	polyn R = P * mQ;
	base ans  = R[n] * base(m);
	ans = ans / bpow(base(m), n);
	return ans;
}

int main(int argc, char const *argv[])
{
	// freopen("in", "r", stdin);
	int t; scanf("%d", &t);
	while(t--) {
		int n, m;
		scanf("%d %d", &n, &m);
		int ans = solve(n, m);
		printf("%d\n", ans);
	}
	return 0;
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define viii vector<tri>
 
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
#define int ll
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
 
#define ld double
#define vll vector<ll >
#define endl "\n"
 
struct base{
	ld x,y;
	base(){x=y=0;}
	base(ld _x, ld _y){x = _x,y = _y;}
	base(ld _x){x = _x, y = 0;}
	void operator = (ld _x){x = _x,y = 0;}
	ld real(){return x;}
	ld imag(){return y;}
	base operator + (const base& b){return base(x+b.x,y+b.y);}
	void operator += (const base& b){x+=b.x,y+=b.y;}
	base operator * (const base& b){return base(x*b.x - y*b.y,x*b.y+y*b.x);}
	void operator *= (const base& b){ld p = x*b.x - y*b.y, q = x*b.y+y*b.x; x = p, y = q;}
	void operator /= (ld k){x/=k,y/=k;}
	base operator - (const base& b){return base(x - b.x,y - b.y);}
	void operator -= (const base& b){x -= b.x, y -= b.y;}
	base conj(){ return base(x, -y);}
	base operator / (ld k) { return base(x / k, y / k);}
	void Print(){ cerr << x <<  " + " << y << "i\n";}
};
double PI = 2.0*acos(0.0);
const int MAXN = 19;
const int maxn = (1<<MAXN);
base W[maxn],invW[maxn], P1[maxn], Q1[maxn];
void precompute_powers(){
	for(int i = 0;i<maxn/2;i++){
	    double ang = (2*PI*i)/maxn; 
	    ld _cos = cos(ang), _sin = sin(ang);
	    W[i] = base(_cos,_sin);
	    invW[i] = base(_cos,-_sin);
	}
}
void fft (vector<base> & a, bool invert) {
	int n = (int) a.size();
 
	for (int i=1, j=0; i<n; ++i) {
	    int bit = n >> 1;
	    for (; j>=bit; bit>>=1)
	        j -= bit;
	    j += bit;
	    if (i < j)
	        swap (a[i], a[j]);
	}
	for (int len=2; len<=n; len<<=1) {
	    for (int i=0; i<n; i+=len) {
	        int ind = 0,add = maxn/len;
	        for (int j=0; j<len/2; ++j) {
	            base u = a[i+j],  v = (a[i+j+len/2] * (invert?invW[ind]:W[ind]));
	            a[i+j] = (u + v);
	            a[i+j+len/2] = (u - v);
	            ind += add;
	        }
	    }
	}
	if (invert) for (int i=0; i<n; ++i) a[i] /= n;
}
 
// 4 FFTs in total for a precise convolution
void mul_big_mod(vll &a, vll & b, ll mod){
	int n1 = a.size(),n2 = b.size();
	int final_size = a.size() + b.size() - 1;
	int n = 1;
	while(n < final_size) n <<= 1;
	vector<base> P(n), Q(n);
	int SQRTMOD = (int)sqrt(mod) + 10;
	for(int i = 0;i < n1;i++) P[i] = base(a[i] % SQRTMOD, a[i] / SQRTMOD);
	for(int i = 0;i < n2;i++) Q[i] = base(b[i] % SQRTMOD, b[i] / SQRTMOD);
	fft(P, 0);
	fft(Q, 0);
	base A1, A2, B1, B2, X, Y;
	for(int i = 0; i < n; i++){
	    X = P[i];
	    Y = P[(n - i) % n].conj();
	    A1 = (X + Y) * base(0.5, 0);
	    A2 = (X - Y) * base(0, -0.5);
	    X = Q[i];
	    Y = Q[(n - i) % n].conj();
	    B1 = (X + Y) * base(0.5, 0);
	    B2 = (X - Y) * base(0, -0.5);
	    P1[i] = A1 * B1 + A2 * B2 * base(0, 1);
	    Q1[i] = A1 * B2 + A2 * B1;
	}
	for(int i = 0; i < n; i++) P[i] = P1[i], Q[i] = Q1[i];
	fft(P, 1);
	fft(Q, 1);
	a.resize(final_size);
	for(int i = 0; i < final_size; i++){
	    ll x = (ll)(P[i].real() + 0.5);
	    ll y = (ll)(P[i].imag() + 0.5) % mod;
	    ll z = (ll)(Q[i].real() + 0.5);
	    a[i] = (x + ((y * SQRTMOD + z) % mod) * SQRTMOD) % mod;
	}
}
 
void polypower_big_mod(vll& a, int n, ll mod, int max_size = 100000000){
	vll x = {1}, b;
	while(n){
	    if(n&1) mul_big_mod(x,a,mod);
	    b = a;
	    mul_big_mod(a,b,mod);
	    n>>=1;
	    x.resize(min((int)x.size(), max_size));
	    a.resize(min((int)a.size(), max_size));
	}
	a = x;
}   
 
ll mod = (1000*1000*1000+7);
ll getpow(ll a,ll b){
	ll ans = 1;
	while(b){
		if(b%2){
			ans*=a;
			ans%=mod;
		}
		a*=a;
		a%=mod;
	    b/=2;
	}
	return ans;
}
int a[1234567],b[1234567];
 
vector<vi> vec(123);
int rempow[1234];
int mpow[1234567];
signed main(){ // "int" is #defined to ll above, so main's return type is spelled "signed"
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	precompute_powers();
	cin>>t;
	while(t--){
		int n,m;
		cin>>n>>m;
		ll i;
		ll val=1;
		ll gg=m-1;	
		rep(i,26){
			vec[i].clear();
		}
	    mpow[0]=1;
	    f(i,1,n+10){
	        mpow[i]=mpow[i-1]*m;
	        mpow[i]%=mod;
	        a[i]=0;
	    }
		f(i,1,2*n+10){
			b[i]=val*i;
			b[i]%=mod;
			val*=gg;
			val%=mod;
			gg--;
		}
		int sz=2;
		int ind=1;
	    rempow[1]=3;
		f(i,3,2*n+10){
			vec[ind].pb(b[i]);
			sz--;
			if(sz==0){
				ind++;
				sz=(1<<ind);
	            rempow[ind]=i+1;
			}
		}
		a[0]=m;
		a[1]+=a[0]*b[1];
		a[2]+=a[0]*b[2];
	    a[1]%=mod;
	    a[2]%=mod;
	    int j,k,curpow;
	    vi wow,gao;
	    int sumi=0;
		f(i,1,n){
			a[i+1]+=a[i]*b[1];
			a[i+2]+=a[i]*b[2];
	        a[i+1]%=mod;
	        a[i+2]%=mod;
			f(j,1,23){
				if(i%(1<<j)==0){
	                wow.clear();
	                gao.clear();
					fd(k,(1<<j),1){
						wow.pb(a[i-k]);
					}
	                rep(k,vec[j].size()){
	                    gao.pb(vec[j][k]);
	                }
	                mul_big_mod(wow,gao,mod);
 
	                curpow=rempow[j]+i-(1<<j);
	                rep(k,wow.size()){
	                    if(curpow>n+3)
	                        break;
	                    a[curpow]+=wow[k];
	                    if(a[curpow]>=mod)
	                        a[curpow]-=mod;
	                    curpow++;
	                }
 
				}
 
			}
	        // //brute
	        // int boo=0;
	        // rep(j,i){
	        //     boo+=a[j]*b[i-j];
	        //     boo%=mod;
	        // }
	        // cout<<boo<<" ds "<<a[i]<<endl;
	        //cout<<(a[i]*mpow[n-1-i])<<endl;
			sumi+=(a[i]*mpow[n-1-i]);
	        sumi%=mod;
		}
	    sumi+=mpow[n];
		sumi%=mod;
	    //cout<<sumi<<endl;
	    ll invden = getpow(getpow(m,n),mod-2);
	    //cout<<sumi<<endl;
		sumi*=invden;
	    sumi%=mod;	
		cout<<sumi<<endl;
	}
	return 0;   
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class MMNN01{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni();long m = nl();
	    long[] p = new long[1+n];//P^{M-1}_{i}
	    p[0] = 1;
	    for(int i = 1; i<= n; i++) p[i] = (p[i-1]*(m-i))%MOD;
	    long[] a = new long[1+n], b = new long[1+n];
	    b[0] = 1;
	    for(int i = 1; i<= n; i++)b[i] = (p[i-1]*i)%MOD;
	    a[1] = m;
//        O(N^2) solution
//        for(int i = 2; i<= n; i++)
//            for(int j = 1; j< i; j++)
//                a[i] += (a[j]*b[i-j])%MOD;
//        O(N*log^2(N)) solution
	    for(int i = 1; i<= n-1; i++){
	        if(i+1 <= n)a[i+1] = (a[i+1]+a[i]*b[1])%MOD;
	        if(i+2 <= n)a[i+2] = (a[i+2]+a[i]*b[2])%MOD;
	        for(int pw = 2; i%pw == 0 && pw+1 <= n; pw*=2)
	            convolve(a, b,i-pw, i-1, pw+1, Math.min(pw*2, n));
	    }
	    long ans = 0;
	    for(int i = 1; i<= n; i++)ans = (ans+a[i]*pow(m, n-i))%MOD;
	    ans = (ans*pow(pow(m, n), MOD-2))%MOD;
	    pn(ans);
	}
	FFT fft = new FFT();
	void convolve(long[] a, long[] b, int l1, int r1, int l2, int r2){
	    long[] c = fft.multiplyPrecise(Arrays.copyOfRange(a, l1, r1+1), Arrays.copyOfRange(b, l2, r2+1));
	    for(int i = 0; i< c.length && l1+l2+i < a.length; i++)a[l1+l2+i] = (a[l1+l2+i]+c[i])%MOD;
	}
	long pow(long a, long p){
	    long o = 1;a%=MOD;
	    for(;p>0;p>>=1){
	        if((p&1)==1)o = (o*a)%MOD;
	        a = (a*a)%MOD;
	    }
	    return o;
	}
	long MOD = (long)1e9+7;
	// Copied FFT template
	// SET maxk appropriately!!! ~log(maxn) //%
	class FFT {
	    final int maxk = 20, maxn = (1 << maxk) + 1;
	    // Init: wR, wI, rR, rI, aR, aI   to   new double[maxn] !!!
	    //#
	    double[] wR = new double[maxn],
	            wI = new double[maxn],
	            rR = new double[maxn],
	            rI = new double[maxn],
	            aR  = new double[maxn],
	            aI  = new double[maxn]; //$
	    int n, k, lastk = -1, dp[] = new int[maxn];

	    void fft(boolean inv) {
	        if (lastk != k) {
	            lastk = k;  dp[0] = 0;
	            for (int i = 1, g = -1; i < n; i++) {
	                if ((i & (i - 1)) == 0) g++;
	                dp[i] = dp[i ^ (1 << g)] ^ (1 << (k - 1 - g));
	            }
	            wR[1] = 1;
	            wI[1] = 0;
	            for (int t = 0; t < k - 1; t++) {
	                double a = Math.PI / n * (1 << (k - 1 - t));
	                double curR = Math.cos(a), curI = Math.sin(a);
	                int p2 = (1 << t), p3 = p2 * 2;
	                for (int j = p2, k = j * 2; j < p3; j++, k += 2) {
	                    wR[k] = wR[j];
	                    wI[k] = wI[j];
	                    wR[k + 1] = wR[j] * curR - wI[j] * curI;
	                    wI[k + 1] = wR[j] * curI + wI[j] * curR;
	                }
	            }
	        }
	        for (int i = 0; i < n; i++) {
	            int d = dp[i];
	            if (i < d) {
	                double tmp = aR[i];
	                aR[i] = aR[d];
	                aR[d] = tmp;
	                tmp = aI[i];
	                aI[i] = aI[d];
	                aI[d] = tmp;
	            }
	        }
	        if (inv) for (int i = 0; i < n; i++) aI[i] = -aI[i];
	        for (int len = 1; len < n; len <<= 1) {
	            for (int i = 0; i < n; i += len) {
	                int wit = len;
	                for (int it = 0, j = i + len; it < len; it++, i++, j++) {
	                    double tmpR = aR[j] * wR[wit] - aI[j] * wI[wit];
	                    double tmpI = aR[j] * wI[wit] + aI[j] * wR[wit];
	                    wit++;
	                    aR[j] = aR[i] - tmpR;
	                    aI[j] = aI[i] - tmpI;
	                    aR[i] += tmpR;
	                    aI[i] += tmpI;
	                }
	            }
	        }
	    }

	    long[] multiply(long[] a, long[] b) {
	        int na = a.length, nb = b.length;
	        for (k = 0, n = 1; n < na + nb - 1; n <<= 1, k++) {}
	        for (int i = 0; i < n; ++i) {
	            aR[i] = i < na ? a[i] : 0;
	            aI[i] = i < nb ? b[i] : 0;
	        }
	        fft(false);
	        aR[n] = aR[0];
	        aI[n] = aI[0];
	        double q = -1.0 / n / 4.0;
	        for (int i = 0; i <= n - i; ++i) {
	            double tmpR = aR[i] * aR[i] - aI[i] * aI[i];
	            double tmpI = aR[i] * aI[i] * 2;
	            tmpR -= aR[n - i] * aR[n - i] - aI[n - i] * aI[n - i];
	            tmpI -= aR[n - i] * aI[n - i] * -2;
	            aR[i] = -tmpI * q;
	            aI[i] = tmpR * q;
	            aR[n - i] = aR[i];
	            aI[n - i] = -aI[i];
	        }
	        fft(true);
	        long[] ans = new long[n = na + nb - 1]; // ONLY MOD IF NEEDED
	        for (int i = 0; i < n; i++) ans[i] = Math.round(aR[i]) % MOD;
	        return ans;
	    }
	    void fft2(double[][] xr, double[][] xi, boolean inv) {
	        n = xr[0].length;
	        k = Integer.numberOfTrailingZeros(n);
	        for (int i = 0; i < xr.length; i++) {
	            for (int j = 0; j < n; j++) { aR[j] = xr[i][j];  aI[j] = xi[i][j]; }
	            fft(inv);
	            for (int j=0;j<n;j++){xr[i][j] = aR[j] / (inv ? n : 1);  xi[i][j] = aI[j] / (inv ? -n : 1);}
	        }
	        n = xr.length;
	        k = Integer.numberOfTrailingZeros(n);
	        for (int j = 0; j < xr[0].length; j++) {
	            for (int i = 0; i < n; i++) { aR[i] = xr[i][j];  aI[i] = xi[i][j]; }
	            fft(inv);
	            for (int i=0;i<n;i++){xr[i][j] = aR[i] / (inv ? n : 1);  xi[i][j] = aI[i] / (inv ? -n : 1);}
	        }
	    }
	    long[][] multiply2(long[][] a, long[][] b) {
	        int n1, n2;
	        for (n1 = 1; n1 < a.length + b.length - 1; n1 <<= 1) {}
	        for (n2 = 1; n2 < a[0].length + b[0].length - 1; n2 <<= 1) {}
	        double[][] ar = new double[n1][n2], ai = new double[n1][n2];
	        double[][] br = new double[n1][n2], bi = new double[n1][n2];
	        for (int i = 0; i < a.length; i++) for(int j=0;j<a[i].length;j++) ar[i][j] = a[i][j];
	        for (int i = 0; i < b.length; i++) for(int j=0;j<b[i].length;j++) br[i][j] = b[i][j];
	        fft2(ar,ai,false); fft2(br,bi,false);
	        for (int i = 0; i < n1; i++) {
	            for(int j = 0; j < n2; j++) {
	                double r1 = ar[i][j], r2 = br[i][j];
	                double i1 = ai[i][j], i2 = bi[i][j];
	                double real = r1 * r2 - i1 * i2;
	                ai[i][j] = i1 * r2+ i2*r1;
	                ar[i][j] = real;
	            }
	        }
	        fft2(ar,ai,true);  long[][] result = new long[n1=a.length+b.length-1][n2=a[0].length+b[0].length-1];
	        for (int i = 0; i < n1; i++)
	            for(int j = 0; j < n2; j++) result[i][j] = Math.round(ar[i][j]);
	        return result;
	    }
	    //#
	    long[] multiplyPrecise(long[] a, long[] b) {
	        long k = (long)(Math.sqrt(MOD));
	        long[] a1 = new long[a.length], a2 = new long[a.length];
	        long[] b1 = new long[b.length], b2 = new long[b.length];
	        for(int i=0;i<a.length;i++) {
	            a1[i] = a[i] % k;
	            a2[i] = a[i] / k;
	        }
	        for(int i=0;i<b.length;i++) {
	            b1[i] = b[i] % k;
	            b2[i] = b[i] / k;
	        }
	        long[] r11 = multiply(a1, b1), r12 = multiply(a1, b2), r21 = multiply(a2, b1), r22 = multiply(a2, b2);
	        long[] res = new long[r11.length];
	        for(int i=0;i<res.length;i++)
	            res[i] = (k*k*r22[i] + k*(r12[i] + r21[i]) + r11[i]) % MOD;
	        return res;
	    }
	    //$
	    //#
	    long[] multiplyOr(long[] eq1, long[] eq2) {
	        int n = Math.max(eq1.length, eq2.length);
	        if((n & (n-1)) != 0)
	            n = Integer.highestOneBit(n)*2;
	        eq1 = fftOr(eq1, n, false);
	        eq2 = fftOr(eq2, n, false);
	        for(int i=0;i<eq1.length;i++)
	            eq1[i] *= eq2[i];
	        eq1 = fftOr(eq1, n, true);
	        return eq1;
	    }
	    //$
	    // To use: FFT both, product, iFFT (n is next power of 2)
	    long[] fftOr(long[] arr, int n, boolean invert) {
	        long[] ans = Arrays.copyOf(arr, n);
	        for (int b = 1; b < n; b <<= 1)
	            for (int i = 0; i < n; i++) {
	                if ((i & b) != 0) continue;
	                ans[i + b] += invert ? -ans[i] : ans[i];
	            }
	        return ans;
	    }
	    long[] fftXor(long[] arr, int n, boolean invert) {
	        long[] ans = Arrays.copyOf(arr, n);
	        for (int b = 1; b < n; b <<= 1)
	            for (int i = 0; i < n; i++) {
	                if((i & b) != 0) continue;
	                long u = ans[i], v = ans[i+b];
	                ans[i] = u + v;  ans[i + b] = u - v;
	            }
	        if (invert) for (int i = 0; i < n; i++) ans[i] /= n;
	        return ans;
	    }
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new MMNN01().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome, as always.