FILLIN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: ro27
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Elementary combinatorics

PROBLEM:

You’re given a permutation P of length 2N with some of its elements missing; each missing element is denoted by a 0.
Find the number of ways of filling in the zeros such that

|P_1 - P_2| + |P_3 - P_4| + \ldots + |P_{2N-1} - P_{2N}|

is maximized.

EXPLANATION:

Recall from the solution to the easy version that we ideally want to pair small elements (\leq N) with large ones (\gt N).

Now, however, we might not always be able to do that because of already existing small-small and large-large pairs.

Suppose there are k_1 small-small and k_2 large-large pairs.
Also suppose that k_1 \leq k_2; the case k_1 \gt k_2 is symmetric, with the roles of small and large elements swapped.

Then, no matter how we fill in values, we’ll always end up with at least k_2 small-small pairs (observe that the number of large-large and small-small pairs must be equal in the end).
So, it’s in our best interest to keep this count to just k_2 - since the more small-small pairs we make, the fewer small-large ones we can make (and the latter is what we want).

This means we require an additional (k_2 - k_1) small-small pairs - which in turn means that an additional (k_2 - k_1) small elements will be added to the sum (rather than subtracted).

Since our aim is to maximize the sum, we should of course choose the largest (k_2 - k_1) unpaired small elements to be the ones added, and all the smaller ones to be subtracted.
Note that “unpaired” means that they either don’t appear in P at all, or they appear but next to a 0.

In any case, find these (k_2 - k_1) largest unpaired small elements.
Suppose there are x small elements not yet in P (excluding the ones among the (k_2 - k_1) above), and y large elements not yet in P.
Suppose also that among the (k_2-k_1) largest unpaired small elements, z of them are not yet in P.

Observe that the “large” small elements can essentially be treated as large elements - after all, to maximize the sum, they can only pair with “small” small elements, not among themselves or with large elements.
So, functionally we have y+z large elements not yet in P, and x small elements not yet in P.

Now, applying the exact same reasoning as the easy version, the number of arrangements is simply

x! \cdot (y+z)! \cdot 2^k

where k is the number of pairs with two zeros initially.

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

const long long mod = 998244353;

// Solves a single test case.
//
// Reads N and then the 2N entries of P from `cin` (a 0 marks a blank
// position) and prints, on standard output, the number of fillings counted by
// the formula (#absent in lower half)! * (#absent in upper half)! * 2^(blank
// pairs), reduced modulo `mod`.
void solve(istringstream cin) {
    int n;
    cin >> n;
    const int len = 2 * n;
    vector<int> p(len);
    for (int& value : p) {
        cin >> value;
    }

    // state[v] == -1 : v sits in a pair that has no blank;
    // state[v] ==  1 : v is present but its partner is blank;
    // state[v] ==  0 : v does not occur in p.
    // Index 0 only absorbs the writes caused by blanks and is never read.
    vector<int> state(len + 1);
    int blank_pairs = 0;
    for (int i = 0; i < len; i += 2) {
        const int left = p[i];
        const int right = p[i + 1];
        const int tag = (left == 0 || right == 0) ? 1 : -1;
        state[left] = tag;
        state[right] = tag;
        if (left == 0 && right == 0) {
            ++blank_pairs;
        }
    }

    // Number of values in 1..2N that are not locked into a complete pair.
    const int open = len - (int) count(state.begin() + 1, state.end(), -1);
    assert(open % 2 == 0);

    // Walk the unlocked values in increasing order, split them into a lower
    // and an upper half, and count the absent values falling in each half.
    int missing_low = 0;
    int missing_high = 0;
    int rank = 0;
    for (int v = 1; v <= len; ++v) {
        if (state[v] == -1) {
            continue;
        }
        if (state[v] == 0) {
            if (rank < open / 2) {
                ++missing_low;
            } else {
                ++missing_high;
            }
        }
        ++rank;
    }

    // missing_low! * missing_high! * 2^blank_pairs, all modulo mod.
    long long ans = 1;
    for (int i = 2; i <= missing_low; ++i) {
        ans = ans * i % mod;
    }
    for (int i = 2; i <= missing_high; ++i) {
        ans = ans * i % mod;
    }
    for (int i = 0; i < blank_pairs; ++i) {
        ans = ans * 2 % mod;
    }
    cout << ans << '\n';
}

////////////////////////////////////////

// When defined, carriage returns are dropped while slurping the input so that
// files with CRLF line endings validate like files with LF line endings.
#define IGNORE_CR

// Strict validator for whitespace-sensitive contest input.
//
// The constructor reads all of standard input into `buffer`; the read*
// members then consume it left to right through `pos`, asserting the exact
// layout (single spaces, newlines, value ranges, end of file).  Any violation
// trips an assert.
struct input_checker {
    string buffer;  // everything read from stdin ('\r' removed if IGNORE_CR)
    int pos;        // index of the next unread character of buffer

    // Ready-made character classes for the `pattern` argument of readString.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads standard input until end of file.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            // Compare against the stream's own EOF value, not a literal -1.
            if (c == char_traits<char>::eof()) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    // Returns the characters up to (not including) the next ' ' or '\n'.
    // Asserts that input remains and that the token holds no other whitespace.
    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            // isspace requires a value representable as unsigned char; a
            // plain (possibly negative) char is undefined behaviour.
            assert(!isspace(static_cast<unsigned char>(buffer[pos])));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads one token, asserting its length is in [min_len, max_len] and, if
    // `pattern` is non-empty, that every character occurs in `pattern`.
    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads one token as an int in [min_val, max_val].
    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        size_t used = 0;
        int res = stoi(token, &used);
        // stoi stops at the first non-numeric character; without this check a
        // malformed token such as "12x" would be accepted as 12.
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads one token as a long long in [min_val, max_val].
    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        const string token = readOne();
        size_t used = 0;
        long long res = stoll(token, &used);
        // Reject trailing garbage after the number (see readInt).
        assert(used == token.size());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    // Reads `size` ints in [min_val, max_val] separated by single spaces.
    // The character following the last value is left unread.
    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Reads `size` long longs in [min_val, max_val] separated by single
    // spaces.  The character following the last value is left unread.
    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Validates the whole input file and runs the solution on every test case.
//
// Checks performed: 1 <= T <= 1e5; per test 1 <= N <= 2e5, followed by 2N
// space-separated values in [0, 2N] whose non-zero entries are pairwise
// distinct; sum of N over all tests <= 2e5; nothing after the last line.
// Each validated test is re-serialised and handed to solve().
int main() {
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    // 64-bit on purpose: the per-value bounds alone allow the sum to reach
    // 1e5 * 2e5 = 2e10, which would overflow an int before the final check.
    long long sn = 0;
    while (tt--) {
        int n = in.readInt(1, 2e5);
        in.readEoln();
        sn += n;
        auto p = in.readInts(2 * n, 0, 2 * n);
        in.readEoln();
        // a[v] == 1 once the non-zero value v has been seen in this test.
        vector<int> a(2 * n + 1);
        for (int i = 0; i < 2 * n; i++) {
            if (p[i]) {
                assert(a[p[i]] == 0);
                a[p[i]] = 1;
            }
        }
        // solve() parses text, so rebuild the test case as "N\np_1 ... p_2N\n".
        ostringstream sout;
        sout << n << '\n';
        for (int i = 0; i < 2 * n; i++) {
            sout << p[i] << " \n"[i == 2 * n - 1];
        }
        solve(istringstream(sout.str()));
    }
    cerr << sn << endl;
    assert(sn <= 2e5);
    in.readEof();
    return 0;
}
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

const int mod = 998244353;

int main()
{
    ios::sync_with_stdio(false); cin.tie(nullptr);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> p(2*n);
        for (int &v : p) cin >> v;

        // Values <= n are "small", values > n are "large".
        // Whichever side has fewer fully given same-side pairs is short by
        // `want` pairs; the `want` unpaired values of that side closest to
        // the small/large boundary are counted together with the other side.

        // mark[v] == 2 : v lies in a pair with no zero
        // mark[v] == 1 : v is given, its partner is a zero
        // mark[v] == 0 : v does not occur in p
        // (mark[0] only collects the increments caused by zeros.)
        vector<int> mark(2*n + 1);
        int empty = 0;
        for (int i = 0; i < 2*n; i += 2) {
            const int a = p[i], b = p[i+1];
            if (a == 0 and b == 0) {
                ++empty;
                continue;
            }
            const int filled = (a != 0) + (b != 0);
            mark[a] += filled;
            mark[b] += filled;
        }

        // Every pair of two zeros contributes a factor of 2.
        int ans = 1;
        for (int i = 0; i < empty; ++i)
            ans = (2ll * ans) % mod;

        // Small / large values that sit in fully given pairs.  Mixed pairs add
        // one to each count, so the difference of the halves below equals the
        // difference between the numbers of large-large and small-small pairs.
        const int small = (int) count(begin(mark)+1, begin(mark)+n+1, 2);
        const int large = (int) count(begin(mark)+n+1, end(mark), 2);
        const bool small_short = small <= large;

        int want = small_short ? large/2 - small/2 : small/2 - large/2;

        // x starts as the number of absent values on the side that is not
        // short; y collects the remaining absent values of the short side.
        int x = small_short ? (int) count(begin(mark)+n+1, end(mark), 0)
                            : (int) count(begin(mark)+1, begin(mark)+n+1, 0);
        int y = 0;

        // Scan the short side starting next to the boundary: n, n-1, ..., 1
        // for small values, n+1, n+2, ..., 2n for large values.
        const int step = small_short ? -1 : 1;
        int v = small_short ? n : n+1;
        for (int seen = 0; seen < n; ++seen, v += step) {
            if (mark[v] == 2) continue;
            const int absent = (mark[v] == 0);
            if (want) {
                x += absent;
                --want;
            }
            else y += absent;
        }

        // ans = 2^empty * x! * y!  (mod `mod`)
        for (int i = 2; i <= x; ++i) ans = (1ll * ans * i) % mod;
        for (int i = 2; i <= y; ++i) ans = (1ll * ans * i) % mod;
        cout << ans << '\n';
    }
}

Such a terribly written editorial. Sigh