PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: mathmodel, BurnedChicken
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Hard
PREREQUISITES:
Combinatorics, NTT, Inclusion-exclusion
PROBLEM:
You’re given N and K.
Count the number of permutations P of [1, \ldots, N] such that P_i + P_{i+1} \not\equiv 0 \pmod K for each 1 \leq i \lt N.
EXPLANATION:
Just as in the easy version, let’s reduce all numbers modulo K, and count the number of valid arrangements with c_i copies of i (0 \leq i \lt K).
In the end, we can multiply the answer by \displaystyle\prod_{i=0}^{K-1} c_i!, to account for all the labelings.
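The reduction to residues can be sanity-checked on tiny inputs with the minimal C++ sketch below. The function names `bruteForce` and `viaResidues` are hypothetical and do not come from the admin's code.

- `bruteForce` enumerates all N! permutations.
- `viaResidues` enumerates only the distinct arrangements of the residue multiset and then multiplies by \prod c_i!.

Both are exponential and are meant only for cross-checking.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Brute force over permutations of [1..n]: counts those with no adjacent pair
// whose sum is divisible by k. Only feasible for tiny n.
long long bruteForce(int n, int k) {
    std::vector<int> p(n);
    std::iota(p.begin(), p.end(), 1);
    long long count = 0;
    do {
        bool ok = true;
        for (int i = 0; i + 1 < n && ok; i++)
            if ((p[i] + p[i + 1]) % k == 0) ok = false;
        count += ok;
    } while (std::next_permutation(p.begin(), p.end()));
    return count;
}

// Same count after reducing modulo k: enumerate the distinct arrangements of the
// residue multiset (c_r copies of r), then multiply by prod c_r! for the labelings.
long long viaResidues(int n, int k) {
    std::vector<int> res, c(k, 0);
    for (int v = 1; v <= n; v++) {
        res.push_back(v % k);
        c[v % k]++;
    }
    std::sort(res.begin(), res.end());  // next_permutation then visits each distinct arrangement once
    long long arrangements = 0;
    do {
        bool ok = true;
        for (int i = 0; i + 1 < n && ok; i++)
            if ((res[i] + res[i + 1]) % k == 0) ok = false;
        arrangements += ok;
    } while (std::next_permutation(res.begin(), res.end()));
    long long labelings = 1;
    for (int r = 0; r < k; r++)
        for (int f = 2; f <= c[r]; f++) labelings *= f;
    return arrangements * labelings;
}
```

For example, with N = 4 and K = 2 both functions return 8. The residues must alternate in parity (2 patterns), and each pattern can be labeled in 2! \cdot 2! = 4 ways.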
When working with the values \lt K, we care only about pairs (x, y) such that x+y\equiv 0 \pmod K.
These are:
- (0, 0)
- (\frac K 2, \frac K 2) when K is even
- (x, K-x) for all 0 \lt x \lt K-x
It’s hard to build a valid configuration directly, because unlike in the case of K = 3, there’s no nice structure to exploit.
Instead, we’ll try to count invalid configurations and subtract them out from the total.
An invalid configuration is one for which (A_i + A_{i+1}) \equiv 0 \pmod K for some i.
An initial try to count these might look something like this:
- Let’s fix (i, i+1) as a “bad” pair of indices, and count the number of configurations.
This is simple combinatorics.
- To account for overcounting, we want to subtract out configurations where there are two pairs of bad indices.
However, we immediately run into an issue: there are cases based on whether the two pairs are adjacent or disjoint.
- Further, even if both cases are handled, we’ll need to add back in configurations with three bad pairs (because they would’ve been added three times in the first case, and subtracted three times in the second).
The number of possible interactions between pairs blows up even further, since we can have anywhere between 1 and 3 disjoint bad segments.
This quickly becomes impossible to handle, because the pairs merge together in various ways to form longer segments.
Instead, let’s try to build up configurations in a more systematic way, using these very segments as a base.
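That is, let’s try to compute \text{ans}[y]: the number of ways to pick a configuration together with a split of it into y consecutive alternating segments. An alternating segment is one in which every adjacent pair sums to 0 modulo K. Segments of length 1 are allowed, and segments need not be maximal.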
To compute the required values, it’s useful to look at a single (x, K-x) pair of values.
Suppose x \lt K-x, and there are c_1 copies of x and c_2 copies of K-x.
Note that because of how c_1 and c_2 are computed, we’ll always have either c_1 = c_2 or c_1 = c_2 + 1.
Let’s compute the number of ways we can arrange these copies into exactly y disjoint alternating segments - call this value p_{x, y}.
Each alternating segment has either even or odd length.
Even-length segments contain equal numbers of copies of x and K-x, while odd-length segments have one extra occurrence of one of them.
Let’s fix i to be the number of even-length segments, so there are y-i odd-length segments.
Then, note that:
- Each even-length segment has two choices: [x, K-x, x, K-x, \ldots] or [K-x, x, K-x, x, \ldots].
Both are always valid.
- Each odd-length segment has two choices too, in a similar fashion.
However, we must ensure that the number of odd-length segments starting with x exceeds the number starting with K-x by exactly c_1 - c_2. This guarantees that the final counts of x and K-x are (c_1, c_2) respectively.
In particular, the number of odd-length segments starting with x must equal \left\lceil \frac{y-i}{2} \right\rceil. This is only possible when the parities work out: y-i should be even if c_1 = c_2, and odd otherwise.
So, with both y and i fixed (the number of segments and the number of even-length segments), the number of configurations can be computed as follows:
- \binom{y}{i} ways to choose which i segments have even length.
- 2^i ways to choose orientations for each of them.
- \displaystyle\binom{y-i}{\left\lceil \frac{y-i}{2} \right\rceil} ways to choose which odd-length segments start with x (or 0, if y-i has the wrong parity).
- Finally, this must be multiplied by the number of ways to choose the lengths, i.e. the number of ways of having i positive even numbers and y-i positive odd numbers add up to c_1 + c_2.
This can be computed using stars-and-bars, and is another binomial coefficient.
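The counting argument above can be checked directly with the C++ sketch below. The helper names `pDirect`, `pBrute` and `binom` are hypothetical. It assumes c_1 \in \{c_2, c_2 + 1\}.

- `pDirect` evaluates the sum over i for one (c_1, c_2, y) in \mathcal{O}(y) time.
- `pBrute` treats an ordered list of y alternating segments as an arrangement together with a choice of which adjacent (x, K-x) pairs stay glued.

```cpp
#include <cassert>

// Exact binomial coefficient (multiplicative formula); fine for small n.
long long binom(int n, int k) {
    if (n < 0 || k < 0 || k > n) return 0;
    long long r = 1;
    for (int i = 1; i <= k; i++) r = r * (n - k + i) / i;
    return r;
}

// p_{x,y} from the even/odd-segment argument: c1 copies of x, c2 copies of K-x.
long long pDirect(int c1, int c2, int y) {
    int S = c1 + c2, d = c1 - c2;
    if (y == 0) return S == 0 ? 1 : 0;
    long long total = 0;
    for (int i = 0; i <= y; i++) {       // i = number of even-length segments
        int j = y - i;                    // j = number of odd-length segments
        if ((j - d) % 2 != 0) continue;   // parity must match c1 - c2
        if ((S - j) % 2 != 0) continue;   // defensive: lengths could not sum to S
        // even lengths 2a (a >= 1), odd lengths 2b + 1 (b >= 0): shift so all are >= 0
        int T = (S - j) / 2 - i;
        if (T < 0) continue;
        long long lengths = binom(T + y - 1, y - 1);  // stars and bars over y parts
        total += binom(y, i) * (1LL << i) * binom(j, (j + 1) / 2) * lengths;
    }
    return total;
}

// Brute force: every arrangement of c1 x's (bit 0) and c2 (K-x)'s (bit 1); a list of
// y alternating segments = an arrangement plus which n - y differing adjacencies stay glued.
long long pBrute(int c1, int c2, int y) {
    int n = c1 + c2;
    long long total = 0;
    for (int mask = 0; mask < (1 << n); mask++) {
        int ones = 0, differing = 0;
        for (int t = 0; t < n; t++) ones += (mask >> t) & 1;
        if (ones != c2) continue;
        for (int t = 0; t + 1 < n; t++)
            differing += ((mask >> t) & 1) != ((mask >> (t + 1)) & 1);
        total += binom(differing, n - y);
    }
    return total;
}
```

For instance, c_1 = c_2 = 1 gives p_{x,1} = 2 ([x, K-x] or [K-x, x]) and p_{x,2} = 2 ([x][K-x] or [K-x][x]).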
Thus, for a given (x, K-x) pair and a fixed segment count y, we can compute p_{x, y} in \mathcal{O}(y) time.
Doing this for every y is already too slow, since it takes quadratic time for a single pair.
However, looking at the summation for a fixed y, it can be seen that it looks to be of the form \displaystyle\sum_{i=0}^y f(i)\cdot g(y-i) for some functions f and g.
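This is a convolution, and so we can use NTT to compute the p_{x, y} values for all y in \mathcal{O}((c_1 + c_2)\log (c_1 + c_2)) time.
The admin's code below obtains the same values through a slightly different convolution. It first counts w_i, the number of arrangements of the c_1 + c_2 elements with exactly i maximal runs of equal values; the i-1 run boundaries are exactly the bad pairs. It then uses p_{x, c_1+c_2-j} = \displaystyle\sum_i w_i \binom{i-1}{j}, since any j of those boundaries can be kept inside segments.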
We now know how to compute the counts for a fixed (x, K-x) pair where x \neq K-x.
There are also classes where x = K-x, namely x = 0 and, when K is even, x = \frac K 2, so there are at most two of them.
The p_{x, y} values for these are much easier to find, since all elements are the same and every adjacent pair is bad. With y blocks, we only want the number of ways of writing c_x as a sum of y positive integers, which is \binom{c_x - 1}{y - 1} by stars-and-bars.
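The stars-and-bars count used here can be checked with the small C++ sketch below (hypothetical names). The number of ways to write c as an ordered sum of y positive integers is \binom{c-1}{y-1}, with the convention that it is 1 when c = y = 0. The admin's `work` lambda computes the same value as `F.solutions(i, y - i)`, where `y = f[x]` is the frequency and `i` is the number of blocks.

```cpp
#include <cassert>

// Ordered ways to write `total` as a sum of `parts` positive integers,
// by direct recursion (exponential; only for checking small cases).
long long compositionsBrute(int total, int parts) {
    if (parts == 0) return total == 0 ? 1 : 0;
    long long ways = 0;
    for (int first = 1; first <= total; first++)  // size of the first block
        ways += compositionsBrute(total - first, parts - 1);
    return ways;
}

// Stars and bars: choose parts - 1 of the total - 1 gaps as block boundaries.
long long compositionsFormula(int total, int parts) {
    if (parts == 0) return total == 0 ? 1 : 0;
    int n = total - 1, k = parts - 1;
    if (k > n) return 0;
    long long r = 1;
    for (int i = 1; i <= k; i++) r = r * (n - k + i) / i;
    return r;
}
```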
Each p_{x} array is computed in \mathcal{O}(s \log s) time, where s is the sum of the frequencies of x and K-x. So doing this for all x \leq K-x still takes \mathcal{O}(N\log N) time overall.
Next, we need to combine the results of all the individual calculations - that is, combine the segments to form a full array.
So, suppose we choose y_x segments of the class (x, K-x).
The number of ways to arrange these segments into a full array is then
\displaystyle\binom{\sum_x y_x}{y_0, y_1, y_2, \ldots}\cdot\prod_{x} p_{x, y_x}
This is because there are p_{x, y_x} ways to choose the configuration of segments corresponding to (x, K-x). With a total of \sum y_x segments, there are then \binom{\sum y_x}{y_0, y_1, y_2, \ldots} ways to interleave them, since the relative order of the segments within each class is already fixed by its chosen configuration.
Here, \binom{\sum y_x}{y_0, y_1, y_2, \ldots} is a multinomial coefficient, equal to \frac{(\sum y_x)!}{y_0! \cdot y_1! \cdot \ldots}.
To compute \text{ans}[y], we need to sum the above quantity over all choices of (y_0, y_1, \ldots) with \sum y_x = y.
This is a sum of products over index tuples with a fixed sum, so it can once again be computed quickly for all y using NTT.
Specifically, to deal with the multinomial coefficient, treat each p_x as an EGF. That is, use \frac{1}{y!} \cdot p_{x, y} instead and perform the convolution, which implicitly brings in the denominators. Then multiply \text{ans}[y] by y! at the end to bring in the numerator.
Since we have several polynomials whose total degree is N, the standard way of multiplying them all together is to use divide-and-conquer (or repeatedly choose the two lowest-degree ones) to obtain a complexity of \mathcal{O}(N\log^2 N).
However, the structure specific to this problem in fact allows for optimization to \mathcal{O}(N\log N).
Observe that we don’t actually have too many distinct polynomials: there’s one corresponding to x = 0, (maybe) one corresponding to x = \frac{K}{2}, and then for each (x, K-x) pair the resulting polynomial depends purely on their frequencies c_x and c_{K-x}.
Moreover, since the values are exactly 1, 2, \ldots, N, every frequency must be either \left\lfloor\frac{N}{K} \right\rfloor or \left\lceil\frac{N}{K} \right\rceil. So there are at most four distinct polynomials possible (in fact, with a bit more analysis, you can show that there are at most two distinct polynomials).
So, the product we want to compute is really just the product of powers of a small number of polynomials.
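The frequency claim can be checked with a short C++ sketch (hypothetical names `residueCounts`, `distinctPairs`). Write N = qK + s. Residue 0 and residues s+1, \ldots, K-1 occur q times, while residues 1, \ldots, s occur q+1 times. The test also confirms, on small inputs only, the remark that at most two distinct (c_x, c_{K-x}) pairs occur, and hence at most two distinct pair polynomials.

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// Residue frequencies c_r = #{v in [1, N] : v % K == r}.
std::vector<long long> residueCounts(long long n, int k) {
    std::vector<long long> c(k, n / k);               // every residue gets floor(N/K)...
    for (long long r = 1; r <= n % k; r++) c[r]++;    // ...residues 1..(N mod K) get one more
    return c;
}

// Distinct (c_x, c_{K-x}) pairs over 0 < x < K - x; each distinct pair gives one
// distinct polynomial in the product, raised to its multiplicity.
std::set<std::pair<long long, long long>> distinctPairs(long long n, int k) {
    auto c = residueCounts(n, k);
    std::set<std::pair<long long, long long>> pairs;
    for (int x = 1; 2 * x < k; x++) pairs.insert({c[x], c[k - x]});
    return pairs;
}
```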
The r-th power of a polynomial of degree m can be computed in \mathcal{O}(rm \log{rm}) time by using how FFT/NTT actually work. First, pad the input polynomial to a length greater than rm, so that the product doesn't wrap around. Then perform NTT to obtain the point-value form, raise each value to the r-th power, and perform an inverse NTT to recover the coefficient form.
Once all the \text{ans}[y] values are computed, we need to extract the final answer from them.
Observe that a fair bit of overcounting would have happened when we combined segments, because we didn’t do anything to enforce adjacent segments being of different types.
For example, with K = 3, the configuration [1, 2, 1, 2] is counted once as a single segment. It is also counted as [1, 2] [1, 2] or [1] [2, 1, 2] (two segments), as [1] [2, 1] [2] (three segments), and so on.
More generally, consider a final configuration where there are exactly x pairs (i, i+1) such that (A_{i} + A_{i+1}) \equiv 0 \pmod K.
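This configuration will then be counted exactly \binom{x}{y} times in \text{ans}[N-x+y]: once for each way of choosing which y of the x bad pairs are split across segment boundaries. The remaining x-y pairs are kept inside segments, and each kept pair merges two elements into one segment, leaving N-(x-y) segments.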
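To account for this, we apply inclusion-exclusion. It’s not hard to work out that the requisite quantity is
\displaystyle\sum_{y=1}^{N} (-1)^{N-y}\cdot\text{ans}[y]
which is then multiplied by \prod_{i=0}^{K-1} c_i! as described at the start.
To see why, note that a configuration with exactly x bad pairs is counted once for every subset of those pairs kept glued inside segments. A subset of size g leaves N-g segments and so receives sign (-1)^{g}. Since \sum_{g=0}^{x} (-1)^{g}\binom{x}{g} equals 1 when x = 0 and 0 otherwise, only valid configurations survive.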
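To tie everything together, below is a slow C++ reference that follows the same pipeline without any NTT (all names are hypothetical). It differs from the fast solution in a few ways, all chosen for simplicity:

- Per-class counts are obtained by enumerating arrangements rather than with the formulas above.
- Classes are combined with explicit binomial interleaving weights, which is equivalent to the multinomial/EGF step.
- The alternating sum is taken at the end and then multiplied by \prod c_i!.

It is only a cross-check for small N and K.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

using ll = long long;

// Exact binomial coefficient (multiplicative formula); fine for the tiny n used here.
ll binom(int n, int k) {
    if (n < 0 || k < 0 || k > n) return 0;
    ll r = 1;
    for (int i = 1; i <= k; i++) r = r * (n - k + i) / i;
    return r;
}

// p[y] for one class: ordered lists of y segments (every adjacent sum inside a segment
// is 0 mod k) using exactly the multiset `vals`. Each distinct arrangement with `bad`
// bad adjacencies yields C(bad, |vals| - y) such lists (choose which pairs stay glued).
std::vector<ll> classCounts(std::vector<int> vals, int k) {
    int n = vals.size();
    std::vector<ll> p(n + 1, 0);
    std::sort(vals.begin(), vals.end());
    do {
        int bad = 0;
        for (int i = 0; i + 1 < n; i++) bad += (vals[i] + vals[i + 1]) % k == 0;
        for (int y = 0; y <= n; y++) p[y] += binom(bad, n - y);
    } while (std::next_permutation(vals.begin(), vals.end()));
    return p;
}

// Whole pipeline without NTT: combine classes with interleaving weights C(a + b, a),
// take sum_y (-1)^(N - y) ans[y], then multiply by prod c_r! for the labelings.
ll viaSegments(int n, int k) {
    std::vector<std::vector<int>> byRes(k);
    for (int v = 1; v <= n; v++) byRes[v % k].push_back(v % k);
    std::vector<ll> ans = {1};
    for (int x = 0; 2 * x <= k; x++) {  // class {x, k - x}; x = 0 and x = k/2 stand alone
        std::vector<int> vals = byRes[x];
        if (x != 0 && 2 * x != k) vals.insert(vals.end(), byRes[k - x].begin(), byRes[k - x].end());
        std::vector<ll> p = classCounts(vals, k);
        std::vector<ll> next(ans.size() + p.size() - 1, 0);
        for (int a = 0; a < (int)ans.size(); a++)
            for (int b = 0; b < (int)p.size(); b++)
                next[a + b] += ans[a] * p[b] * binom(a + b, a);  // interleavings keeping per-class order
        ans = next;
    }
    ll total = 0;
    for (int y = 0; y <= n; y++) total += ((n - y) % 2 ? -1 : 1) * ans[y];
    for (int r = 0; r < k; r++)
        for (ll f = 2; f <= (ll)byRes[r].size(); f++) total *= f;
    return total;
}

// Plain permutation enumeration, used to validate viaSegments.
ll bruteForce(int n, int k) {
    std::vector<int> p(n);
    std::iota(p.begin(), p.end(), 1);
    ll count = 0;
    do {
        bool ok = true;
        for (int i = 0; i + 1 < n && ok; i++)
            if ((p[i] + p[i + 1]) % k == 0) ok = false;
        count += ok;
    } while (std::next_permutation(p.begin(), p.end()));
    return count;
}
```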
TIME COMPLEXITY:
\mathcal{O}(N \log N) per testcase.
CODE:
Admin's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
const int mod = 998244353;
struct mint{
int x;
mint (){ x = 0;}
mint (int32_t xx){ x = xx % mod; if (x < 0) x += mod;}
mint (long long xx){ x = xx % mod; if (x < 0) x += mod;}
int val(){
return x;
}
mint &operator++(){
x++;
if (x == mod) x = 0;
return *this;
}
mint &operator--(){
if (x == 0) x = mod;
x--;
return *this;
}
mint operator++(int32_t){
mint result = *this;
++*this;
return result;
}
mint operator--(int32_t){
mint result = *this;
--*this;
return result;
}
mint& operator+=(const mint &b){
x += b.x;
if (x >= mod) x -= mod;
return *this;
}
mint& operator-=(const mint &b){
x -= b.x;
if (x < 0) x += mod;
return *this;
}
mint& operator*=(const mint &b){
long long z = x;
z *= b.x;
z %= mod;
x = (int)z;
return *this;
}
mint operator+() const {
return *this;
}
mint operator-() const {
return mint() - *this;
}
mint operator/=(const mint &b){
return *this = *this * b.inv();
}
mint power(long long n) const {
mint ok = *this, r = 1;
while (n){
if (n & 1){
r *= ok;
}
ok *= ok;
n >>= 1;
}
return r;
}
mint inv() const {
return power(mod - 2);
}
friend mint operator+(const mint& a, const mint& b){ return mint(a) += b;}
friend mint operator-(const mint& a, const mint& b){ return mint(a) -= b;}
friend mint operator*(const mint& a, const mint& b){ return mint(a) *= b;}
friend mint operator/(const mint& a, const mint& b){ return mint(a) /= b;}
friend bool operator==(const mint& a, const mint& b){ return a.x == b.x;}
friend bool operator!=(const mint& a, const mint& b){ return a.x != b.x;}
mint power(mint a, long long n){
return a.power(n);
}
friend ostream &operator<<(ostream &os, const mint &m) {
os << m.x;
return os;
}
explicit operator bool() const {
return x != 0;
}
};
// Remember to check MOD
struct factorials{
int n;
vector <mint> ff, iff;
factorials(int nn){
n = nn;
ff.resize(n + 1);
iff.resize(n + 1);
ff[0] = 1;
for (int i = 1; i <= n; i++){
ff[i] = ff[i - 1] * i;
}
iff[n] = ff[n].inv();
for (int i = n - 1; i >= 0; i--){
iff[i] = iff[i + 1] * (i + 1);
}
}
mint C(int n, int r){
if (n == r) return mint(1);
if (n < 0 || r < 0 || r > n) return mint(0);
return ff[n] * iff[r] * iff[n - r];
}
mint P(int n, int r){
if (n < 0 || r < 0 || r > n) return mint(0);
return ff[n] * iff[n - r];
}
mint solutions(int n, int r){
// Solutions to x1 + x2 + ... + xn = r, xi >= 0
return C(n + r - 1, n - 1);
}
mint catalan(int n){
return ff[2 * n] * iff[n] * iff[n + 1];
}
};
const int PRECOMP = 3e6 + 69;
factorials F(PRECOMP);
// REMEMBER To check MOD and PRECOMP
int ceil_pow2(int n) {
int x = 0;
while ((1U << x) < (unsigned int)(n)) x++;
return x;
}
int bsf(unsigned int n){
return __builtin_ctz(n);
}
void butterfly(std::vector<mint>& a) {
static constexpr int g = 3; // primitive root
int n = (int)(a.size());
int h = ceil_pow2(n);
static bool first = true;
static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mod - 1);
mint e = mint(g).power((mod - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
// e^(2^i) == 1
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_e[i] = es[i] * now;
now *= ies[i];
}
}
for (int ph = 1; ph <= h; ph++) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint now = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p] * now;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
}
now *= sum_e[bsf(~(unsigned int)(s))];
}
}
}
void butterfly_inv(std::vector<mint>& a) {
static constexpr int g = 3; // primitive root
int n = (int)(a.size());
int h = ceil_pow2(n);
static bool first = true;
static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mod - 1);
mint e = mint(g).power((mod - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
// e^(2^i) == 1
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_ie[i] = ies[i] * now;
now *= es[i];
}
}
for (int ph = h; ph >= 1; ph--) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint inow = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p];
a[i + offset] = l + r;
a[i + offset + p] = (mod + l.val() - r.val()) * inow.val();
}
inow *= sum_ie[bsf(~(unsigned int)(s))];
}
}
}
std::vector<mint> convolution_p2(std::vector<mint> a, std::vector<mint> b) {
int n = (int)(a.size()), m = (int)(b.size());
if (!n || !m) return {};
if (std::min(n, m) <= 60) {
if (n < m) {
std::swap(n, m);
std::swap(a, b);
}
std::vector<mint> ans(n + m - 1);
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
ans[i + j] += a[i] * b[j];
}
}
return ans;
}
int z = 1 << ceil_pow2(n + m - 1);
a.resize(z);
butterfly(a);
b.resize(z);
butterfly(b);
for (int i = 0; i < z; i++) {
a[i] *= b[i];
}
butterfly_inv(a);
a.resize(n + m - 1);
mint iz = mint(z).inv();
for (int i = 0; i < n + m - 1; i++) a[i] *= iz;
return a;
}
std::vector<mint> convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
int n = (int)(a.size()), m = (int)(b.size());
if (!n || !m) return {};
std::vector<mint> a2(n), b2(m);
for (int i = 0; i < n; i++) {
a2[i] = mint(a[i]);
}
for (int i = 0; i < m; i++) {
b2[i] = mint(b[i]);
}
auto c2 = convolution_p2(move(a2), move(b2));
return c2;
}
void Solve()
{
int n, k; cin >> n >> k;
vector <int> f(k, 0);
for (int i = 1; i <= n; i++){
f[i % k]++;
}
// ans = product of all per-class EGFs; coefficient s corresponds to s segments
vector <mint> ans(1, 0);
ans[0] = 1;
auto mult = [&](vector <mint> P){
ans = convolution(ans, P);
};
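auto work = [&](int x){
// class x = K - x (x = 0 or K/2): all f[x] elements are equal and every adjacency is bad,
// so p[i] = C(f[x] - 1, i - 1) (compositions into i blocks), stored as the EGF p[i] / i!
int y = f[x];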
if (y == 0){
return;
}
vector <mint> ans(y + 1, 0);
for (int i = 1; i <= y; i++){
ans[i] = F.solutions(i, y - i);
ans[i] /= F.ff[i];
}
mult(ans);
};
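auto get = [&](int x, int y){
// Class {v, K - v} with x copies of one value and y copies of the other.
// Returns the EGF coefficients p[s] / s!, where p[s] counts ordered lists of
// s alternating segments that use exactly these x + y elements.
if (x + y == 0){
// empty class (only possible when K > N): just the empty list of segments
return vector <mint>(1, mint(1));
}
vector <mint> ans(x + y + 1, 0);
vector <mint> f(x + y + 1, 0), g(x + y + 1, 0);
for (int i = 1; i <= x + y; i++){
// w = number of arrangements with exactly i maximal runs of equal values;
// solutions(r, x - r) = C(x - 1, r - 1) splits the x copies into r runs
mint w = 0;
{
// arrangements starting with the y-value
int r = i / 2;
int b = (i + 1) / 2;
w += F.solutions(r, x - r) * F.solutions(b, y - b);
}
{
// arrangements starting with the x-value
int r = (i + 1) / 2;
int b = i / 2;
w += F.solutions(r, x - r) * F.solutions(b, y - b);
}
f[i] += w * F.ff[i - 1];
}
// g[x + y - d] = 1 / (d - 1)!, so h[x + y + j] = sum_i w_i * (i - 1)! / (i - 1 - j)!
for (int d = 1; d <= x + y; d++){
g[x + y - d] = F.iff[d - 1];
}
auto h = convolution(f, g);
// ans[j] = sum_i w_i * C(i - 1, j): keep j of the i - 1 run boundaries (bad pairs) glued
for (int j = 0; j < x + y; j++){
ans[j] += h[x + y + j] * F.iff[j];
}
// gluing j pairs leaves x + y - j segments
reverse(ans.begin(), ans.end());
// EGF scaling
for (int i = 1; i <= x + y; i++){
ans[i] /= F.ff[i];
}
return ans;
};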
work(0);
if (k % 2 == 0){
work(k / 2);
}
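// classes {i, K - i} with the same frequency pair have identical polynomials,
// so group them and raise each distinct polynomial to its multiplicity
map <pair<int, int>, int> mp;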
for (int i = 1; 2 * i < k; i++){
mp[make_pair(f[i], f[k - i])]++;
}
for (auto [p, c] : mp){
auto P = get(p.first, p.second);
vector <mint> res(1); res[0] = 1;
// binary exponentiation: res = P^c (no cap on c, no squaring after the last bit)
while (c > 0){
if (c & 1){
res = convolution(res, P);
}
c >>= 1;
if (c > 0){
P = convolution(P, P);
}
}
mult(res);
}
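// label the equal residues with actual values: multiply by prod c_i!
mint factor = 1;
for (int i = 0; i < k; i++){
factor *= F.ff[f[i]];
}
// multiplying by i! turns the EGF coefficient back into the count ans[i]
for (int i = 1; i <= n; i++){
ans[i] *= factor;
ans[i] *= F.ff[i];
}
// inclusion-exclusion: answer = sum over i of (-1)^(n - i) * ans[i]
mint got = 0;
for (int i = 1; i <= n; i++){
if (i % 2 == n % 2){
got += ans[i];
} else {
got -= ans[i];
}
}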
cout << got << "\n";
}
int32_t main()
{
auto begin = std::chrono::high_resolution_clock::now();
ios_base::sync_with_stdio(0);
cin.tie(0);
int t = 1;
// freopen("in", "r", stdin);
// freopen("out", "w", stdout);
cin >> t;
for(int i = 1; i <= t; i++)
{
// cout << "Case #" << i << ": \n";
Solve();
}
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
return 0;
}