MINSEG - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sayeef_mahmud
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming, bitmasks, stacks/segment trees

PROBLEM:

A string is called good if every character it contains appears an odd number of times.
You’re given a string S. Find the minimum size of a partition of S into good substrings.

EXPLANATION:

We’ll use dynamic programming.

Let dp_i denote the minimum number of good substrings required to partition the prefix of S of length i.
We trivially have dp_0 = 0 and dp_i = \min(dp_j + 1) across all j \lt i such that the substring S[j+1\ldots i] is good.
Our focus is now on speeding up this quadratic algorithm.
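
As a reference point, the quadratic recurrence above can be sketched as follows. This is only a brute-force illustration, not the intended solution; the function name minGoodPartitionBrute is made up for this example, and the input is assumed to consist of lowercase letters.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Quadratic reference DP (hypothetical helper, for illustration only):
// dp[i] = min over j < i of dp[j] + 1, where S[j+1..i] (1-indexed) is good.
int minGoodPartitionBrute(const std::string& s) {
    const int n = static_cast<int>(s.size());
    const int INF = n + 1;
    std::vector<int> dp(n + 1, INF);
    dp[0] = 0;
    for (int i = 1; i <= n; ++i) {
        int present = 0; // characters occurring in S[j+1..i]
        int odd = 0;     // characters occurring an odd number of times there
        for (int j = i - 1; j >= 0; --j) {
            assert('a' <= s[j] && s[j] <= 'z');
            const int bit = 1 << (s[j] - 'a'); // s[j] is the (j+1)-th character
            present |= bit;
            odd ^= bit;
            // Good <=> every character that is present appears an odd number of times.
            if (present == odd) dp[i] = std::min(dp[i], dp[j] + 1);
        }
    }
    return dp[n];
}
```

A brute force like this is mainly useful for stress-testing a faster solution on small random strings.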

For a string T, let c(T) denote the set of characters that appear in T.
Observe that among all the substrings S[j\ldots i] ending at an index i, there aren’t actually that many different values of c(S[j\ldots i]), because c(S[j\ldots i]) = c(S[j+1\ldots i]) \cup \{S_j\}: moving the left endpoint one step to the left can only add a character to the set, never remove one.
In particular, since S contains at most 20 distinct characters, there are at most 20 different sets of characters among the substrings ending at i.

Finding these sets is fairly simple too: sort the characters by their last occurrence at an index \leq i, and add them to the set one at a time in decreasing order of that index, i.e. moving from i downwards.
This also tells us that a fixed set of characters corresponds to a contiguous range of left endpoints j \leq i, and finding this range is easy: it ends (on the right) at the last occurrence of the most recently added character, and starts just after the last occurrence of the next character to be added.
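
A minimal sketch of this enumeration is given below. The struct SetRange and the function setsEndingAt are hypothetical names introduced for illustration; positions are 1-indexed and the string is assumed to consist of lowercase letters. For simplicity the sketch recomputes the last occurrences from scratch, whereas a real solution would maintain them incrementally as i increases.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// For every j in [L, R], the substring S[j..i] contains exactly the
// characters in `mask` (bit c set <=> character 'a' + c is present).
struct SetRange {
    int mask, L, R;
};

std::vector<SetRange> setsEndingAt(const std::string& s, int i) {
    assert(1 <= i && i <= static_cast<int>(s.size()));
    // last[c] = largest position p <= i with S[p] == 'a' + c, or 0 if none.
    std::array<int, 26> last{};
    for (int p = 1; p <= i; ++p) last[s[p - 1] - 'a'] = p;

    std::vector<std::pair<int, int>> byLast; // (last position, character)
    for (int c = 0; c < 26; ++c)
        if (last[c] > 0) byLast.push_back({last[c], c});
    std::sort(byLast.rbegin(), byLast.rend()); // latest occurrence first

    std::vector<SetRange> result;
    int mask = 0;
    for (std::size_t k = 0; k < byLast.size(); ++k) {
        mask |= 1 << byLast[k].second;
        const int R = byLast[k].first;
        const int L = (k + 1 < byLast.size()) ? byLast[k + 1].first + 1 : 1;
        result.push_back({mask, L, R});
    }
    return result;
}
```

For example, with S = "abcab" and i = 5 this yields the sets {b} on [5, 5], {a, b} on [4, 4] and {a, b, c} on [1, 3].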

Let’s fix a set of characters this way. Let L and R be the bounds of j such that for every L \leq j \leq R, the substring S[j\ldots i] has this set.
Using bitmasking and prefix XORs, we can in fact find all j between L and R such that S[j\ldots i] is good.

How?

There are 20 distinct characters. Let’s map them to 2^0, 2^1, \ldots, 2^{19}.
Suppose A_i is the mapping of S_i.

Consider what happens to some subarray (l, r) of A (which represents a substring of S).
Specifically, consider the value A_l \oplus A_{l+1}\oplus\ldots\oplus A_r.
In this bitwise XOR:

  • If 2^x is present in it, the character corresponding to it appears an odd number of times in the range.
  • If 2^x is not present, the corresponding character appears an even number of times.

So, under this encoding, the bitwise XOR of a range tells us exactly which characters in it appear an odd number of times.
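
To make the encoding concrete, here is a small sketch that computes this XOR directly for a range. The function name oddMask is hypothetical, indices are 1-indexed and inclusive, and lowercase letters are assumed.

```cpp
#include <cassert>
#include <string>

// Bitmask of the characters that occur an odd number of times in S[l..r].
int oddMask(const std::string& s, int l, int r) {
    assert(1 <= l && r <= static_cast<int>(s.size()));
    int x = 0;
    for (int k = l; k <= r; ++k) {
        x ^= 1 << (s[k - 1] - 'a'); // toggles the parity bit of this character
    }
    return x;
}
```

For instance, in "abcab" the characters a and b appear twice and c once, so oddMask("abcab", 1, 5) has only the bit of c set.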

Let P denote the prefix XOR array under the above encoding.
Then, we’re interested in all indices L\leq j\leq R such that the value P_i\oplus P_{j-1} contains exactly the set of characters we’ve fixed.

That is, if mask is the bitmask corresponding to our set of characters, we want to look at all j in this range such that P_i\oplus P_{j-1} = mask, or in other words P_{j-1} = mask\oplus P_i.

Notice that the right side is a constant, since i and mask are fixed.
So, if we had a list of indices for each prefix XOR value, all we need to do is look at the list for P_i\oplus mask, find which of its indices lie in the range [L-1, R-1] (these form a contiguous segment of the list), and take the minimum dp value among them.
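
The lookup described above can be sketched as follows, assuming the indices sharing one prefix XOR value are stored in increasing order. The helper name minDpInRange and its parameters are hypothetical; the linear scan over the matched segment is exactly the part that a range-minimum structure would replace.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

// positions: sorted indices k that all share one prefix XOR value.
// Returns min dp[k] over positions k in [lo, hi], or INT_MAX if there are none.
int minDpInRange(const std::vector<int>& positions, const std::vector<int>& dp,
                 int lo, int hi) {
    assert(lo <= hi);
    // Binary search for the contiguous segment of the list inside [lo, hi].
    auto first = std::lower_bound(positions.begin(), positions.end(), lo);
    auto last = std::upper_bound(positions.begin(), positions.end(), hi);
    int best = std::numeric_limits<int>::max();
    for (auto it = first; it != last; ++it) best = std::min(best, dp[*it]);
    return best;
}
```

Here lo and hi play the roles of L-1 and R-1 from the text.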

Since this has essentially been reduced to a range minimum query problem with point updates (dp_i needs to be set once it’s computed), the problem can be solved using a segment tree.
The complexity is \mathcal{O}(N\Sigma \log N) where \Sigma = 20.
Even though it’s somewhat slow, if implemented reasonably, this should get AC.
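
For readers unfamiliar with the data structure, a generic point-update, range-minimum segment tree might look like the sketch below. This is a standard bottom-up variant, not the dynamic tree used in the author's code; the struct name MinSegTree is made up here, and queries use half-open ranges [l, r).

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

struct MinSegTree {
    int n;
    std::vector<int> t; // leaves live at t[n .. 2n-1]

    explicit MinSegTree(int size)
        : n(size), t(2 * size, std::numeric_limits<int>::max()) {}

    // Set position p (0-indexed) to value v and refresh its ancestors.
    void set(int p, int v) {
        assert(0 <= p && p < n);
        p += n;
        t[p] = v;
        for (; p > 1; p >>= 1) t[p >> 1] = std::min(t[p], t[p ^ 1]);
    }

    // Minimum over positions in [l, r); INT_MAX if the range is empty or unset.
    int query(int l, int r) const {
        int res = std::numeric_limits<int>::max();
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = std::min(res, t[l++]);
            if (r & 1) res = std::min(res, t[--r]);
        }
        return res;
    }
};
```

One such tree per prefix XOR list (indexed by position within the list) is one possible way to support the queries, since the lists have total size N + 1.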


However, we can do better!

Let’s make another observation: suppose the same mask appears at indices i_1 and i_2 (i_1\lt i_2), with corresponding ranges [L_1, R_1] and [L_2, R_2].
Then, either L_1 = L_2 or R_1 \lt L_2, i.e., the two ranges are either disjoint or share the same left endpoint.

Proof

If some character x that’s not present in mask appears between indices i_1 and i_2, then clearly L_2 must be after this character (and hence after i_1 as well).
Otherwise, the first new character not in mask is at index L_1 - 1, and we’ll have L_1 = L_2.

This allows us to keep a disjoint set of intervals for each prefix XOR, and repeatedly merge the last two of them as long as they both fit into the current [L, R] for mask.
In the end, simply take the answer for the last interval.

This is easily simulated with a stack, removing the \log N from the complexity.
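
The merging step can be sketched roughly as below, in the spirit of the editorialist's code. The names Interval and queryAndMerge are hypothetical, and lo and hi stand for the inclusive bounds L-1 and R-1 on prefix indices. Note that merging discards information, so this is only valid because of the observation above (ranges for the same mask are either disjoint or share a left endpoint).

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <limits>
#include <vector>

// {l, r, best}: prefix indices l..r sharing one prefix XOR value,
// together with the minimum dp value among them.
using Interval = std::array<int, 3>;

// Merges the top two intervals while both fit inside [lo, hi], then returns
// the minimum of the top interval if it fits, or INT_MAX otherwise.
int queryAndMerge(std::vector<Interval>& stk, int lo, int hi) {
    assert(lo <= hi);
    while (stk.size() >= 2) {
        const Interval& a = stk[stk.size() - 2];
        const Interval& b = stk.back();
        if (a[0] < lo || b[1] > hi) break; // the pair does not fit: stop merging
        Interval merged = {a[0], b[1], std::min(a[2], b[2])};
        stk.pop_back();
        stk.back() = merged;
    }
    if (stk.empty()) return std::numeric_limits<int>::max();
    const Interval& top = stk.back();
    if (top[0] >= lo && top[1] <= hi) return top[2];
    return std::numeric_limits<int>::max();
}
```

Each push onto a stack can be undone by at most one merge, so the total work over all queries stays linear in the number of pushes plus the number of queries.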

TIME COMPLEXITY:

Between \mathcal{O}(N\cdot \Sigma) and \mathcal{O}(N\Sigma\log N) per testcase, where \Sigma = 20 is the alphabet size.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define INF (int)1e9

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct Node{
    int v;
    int l, r;
};

const int N = 2e5 + 5;
const int M = 6e7 + 5;
Node seg[M];
int n;
int h[N];
string s;
int base[1 << 26];
int last[26];
int v = 1;

void upd(int l, int r, int pos, int qp, int val){
    seg[pos].v = min(seg[pos].v, val);
    if (l == r) return;
    int mid = (l + r) / 2;
    if (seg[pos].l == -1){
        seg[pos].l = v++;
        seg[pos].r = v++;
    }

    if (qp <= mid) upd(l, mid, seg[pos].l, qp, val);
    else upd(mid + 1, r, seg[pos].r, qp, val);
}

int query(int l, int r, int pos, int ql, int qr){
    if (l >= ql && r <= qr) return seg[pos].v;
    else if (l > qr || r < ql) return INF;

    int mid = (l + r) / 2;

    if (seg[pos].l == -1){
        seg[pos].l = v++;
        seg[pos].r = v++;
    }

    return min(query(l, mid, seg[pos].l, ql, qr), query(mid + 1, r, seg[pos].r, ql, qr));
}

int query(int x, int l, int r){
    if (base[x] == 0){
        base[x] = v++;
    }
    return query(1, n, base[x], l, r);
}

void upd(int x, int i, int y){
    if (base[x] == 0){
        base[x] = v++;
    }

    upd(1, n, base[x], i, y);
}

void Solve() 
{
    cin >> n >> s;

    for (int i = 0; i < M; i++){
        seg[i].v = INF;
        seg[i].l = -1;
        seg[i].r = -1;
    }

    for (int i = 1; i <= n; i++){
        h[i] = h[i - 1] ^ (1 << (s[i - 1] - 'a'));
    }

    vector <int> dp(n + 1, INF);
    dp[0] = 0;
    upd(h[0], 1, dp[0]);

    for (int i = 1; i <= n; i++){
        last[s[i - 1] - 'a'] = i;

        vector <pair<int, int>> b;
        for (int j = 0; j < 26; j++){
            b.push_back({last[j], j});
        }
        b.push_back({0, 26});
        sort(b.begin(), b.end(), greater<pair<int, int>>());

        int curr = h[i];
        int ok = __builtin_popcount(h[i]);
        for (int j = 0; j < 26; j++) ok -= last[j] != 0; // one per distinct character seen so far
        if (ok == 0){
            dp[i] = 1;
        }

        //if (i == 90000) return;

      //  cout << "FOR " << i << "\n";

        for (int j = 0; j < 26; j++){
            if (b[j].first == 0) break;

            curr ^= 1 << (b[j].second);
            int L = b[j + 1].first + 1;
            int R = b[j].first;

            dp[i] = min(dp[i], query(curr, L, R) + 1);
        }

        if (i != n)
        upd(h[i], i + 1, dp[i]);
    }

    cout << dp[n] << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("out", "w", stdout);
    
   // cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#define dbg(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define dbg(...)
#endif
 
void __print(int32_t x) {cerr << x;}
void __print(int64_t x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(string x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
template<typename T>void __print(complex<T> x) {cerr << '{'; __print(x.real()); cerr << ','; __print(x.imag()); cerr << '}';}
 
template<typename T>
void __print(const T &x);
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto it = x.begin() ; it != x.end() ; it++) cerr << (f++ ? "," : ""), __print(*it); cerr << "}";}
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}

struct input_checker {
    string buffer;
    int pos;
 
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
 
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
 
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
 
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
 
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
 
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
 
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
 
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

constexpr int M = 't' - 'a' + 1;

int32_t main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);

	vector<vector<int>> con(1 << M);
	con[0].push_back(0);

	// input_checker inp;

	// int N = inp.readInt(1, (int)1e5);	inp.readEoln();	 NN += N;
	// string S = inp.readString(N, N, "abcdefghijklmnopqrst");	inp.readEoln();
	int N;	string S;
	cin >> N >> S;
	vector<int> last(M + 1), dp(N + 1, N + 1);	dp[0] = 0;

	vector<int> ord(M + 1);	iota(ord.begin(), ord.end(), 0);
	int mask = 0, req = 0;
	for(int i = 0 ; i < N ; ++i) {
		last[S[i] - 'a'] = i + 1;
		mask ^= 1 << int(S[i] - 'a'), req = mask;
		sort(ord.begin(), ord.end(), [&](int i, int j) {
			return last[i] > last[j];
		});
		for(int j = 1 ; j <= M ; ++j) {
			req ^= (1 << ord[j - 1]);
			int p = lower_bound(con[req].begin(), con[req].end(), last[ord[j]]) - con[req].begin();
			if(p < (int)con[req].size()) 	dp[i + 1] = min(dp[i + 1], dp[con[req][p]] + 1);
			if(last[ord[j]] == 0)	break;
		}
		while(!con[mask].empty() && dp[con[mask].back()] > dp[i + 1])
			con[mask].pop_back();
		con[mask].push_back(i + 1);
	}
	cout << dp[N] << '\n';

	// inp.readEof();

	return 0;
}

Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

const int alp = 20;
int ID[1 << alp];

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int n; cin >> n;
    string s; cin >> s;

    vector<int> pmask(n+1);
    for (int i = 0; i < n; ++i)
        pmask[i+1] = pmask[i] ^ (1 << (s[i] - 'a'));
    
    memset(ID, -1, sizeof ID);
    ID[0] = 0;
    int id = 1;
    vector<int> dp(n+1);
    vector<vector<array<int, 3>>> intervals(1);
    intervals[0] = {{0, 0, 0}};
    
    vector<int> last(alp+1, 0);
    for (int i = 1; i <= n; ++i) {
        dp[i] = i;
        last[s[i-1] - 'a'] = i;
        vector<int> ord(alp+1);
        iota(begin(ord), end(ord), 0);
        sort(rbegin(ord), rend(ord), [&] (int x, int y) {return last[x] < last[y];});
        int mask = 0;
        for (int c = 1; c < alp+1 and last[ord[c-1]]; ++c) {
            mask |= 1 << (ord[c-1]);
            int L = last[ord[c]], R = last[ord[c-1]];

            int want = pmask[i] ^ mask;
            if (ID[want] == -1) continue;

            int x = ID[want];
            auto &stk = intervals[x];
            while (stk.size() >= 2) {
                int k = stk.size();
                int l = stk[k-2][0], r = stk[k-1][1], y = min(stk[k-2][2], stk[k-1][2]);
                if (l >= L and r < R) {
                    stk.pop_back();
                    stk.back() = {l, r, y};
                }
                else break;
            }
            auto [l, r, y] = stk.back();
            if (l >= L and r < R) dp[i] = min(dp[i], y + 1);
        }

        if (ID[pmask[i]] == -1) {
            ID[pmask[i]] = id++;
            intervals.emplace_back();
        }
        intervals[ID[pmask[i]]].push_back({i, i, dp[i]});
    }

    cout << dp[n] << '\n';
}

What kind of variation of segment tree is this? It is different from normal range queries, because only a subset of the indices in the range is queried.

Exactly, I was wondering what that is. The editorial explains it as if it’s a standard thing. I’m totally confused.

The important part here is:

Essentially, you have many smaller lists (one corresponding to each value of prefix XOR), and instead of looking at some random subset of a range, you now have a continuous subarray within one of these lists (to find said subarray, you can binary search).
That transformation is what allows for the use of a segment tree (and yes, it does become a simple range min query now).
The sum of sizes of all these lists is N (each prefix appears in exactly one list) so you can even just build a separate segment tree for each one and be fine.