PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: sayeef_mahmud
Tester: apoorv_me
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Dynamic programming, bitmasks, stacks/segment trees
PROBLEM:
A string is called good if every character it contains appears an odd number of times.
You’re given a string S. Find the minimum size of a partition of S into good substrings.
EXPLANATION:
We’ll use dynamic programming.
Let dp_i denote the minimum number of good substrings required to partition the prefix of S of length i.
Clearly, dp_0 = 0 and dp_i = \min(dp_j + 1), taken over all j \lt i such that the string S[j+1\ldots i] is good.
Computed directly, this takes \mathcal{O}(N^2) transitions; our focus is now on speeding it up.
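As a reference point, here is a minimal sketch of the quadratic recurrence. It assumes S uses only the letters 'a' to 't' (at most 20 characters), and the function name minGoodPartitionNaive is hypothetical. While extending S[j+1\ldots i] to the left, it keeps two bitmasks: the set of characters present and the set appearing an odd number of times. The substring is good exactly when the two are equal.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical brute-force reference: dp[i] = min(dp[j] + 1) over good S[j+1..i].
// Assumes s only contains 'a'..'t'.
int minGoodPartitionNaive(const std::string& s) {
    const int n = static_cast<int>(s.size());
    std::vector<int> dp(n + 1, n + 1);
    dp[0] = 0;
    for (int i = 1; i <= n; ++i) {
        int parity = 0;   // bit c set <=> character c occurs an odd number of times
        int present = 0;  // bit c set <=> character c occurs at all
        // 1-indexed S[j+1..i] is the 0-indexed slice s[j..i-1]; grow it leftwards.
        for (int j = i - 1; j >= 0; --j) {
            int bit = 1 << (s[j] - 'a');
            parity ^= bit;
            present |= bit;
            if (parity == present)  // every present character is odd: good substring
                dp[i] = std::min(dp[i], dp[j] + 1);
        }
    }
    return dp[n];
}
```

For example, "aabb" needs 3 parts ("a", "ab", "b"), while "abba" needs only 2 ("ab", "ba").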
For a string T, let c(T) denote the set of characters that appear in T.
Observe that among all substrings S[j\ldots i] ending at a fixed index i, there are few distinct values of c(S[j\ldots i]): since c(S[j\ldots i]) = c(S[j+1\ldots i]) \cup \{S_j\}, the set can only grow as j decreases.
In particular, since S contains at most 20 distinct characters, there are at most 20 different character sets among the substrings ending at i.
Finding these sets is simple too: sort the characters by their last occurrence at an index \leq i, in decreasing order, and add them to the set one at a time in this order.
This also shows that each set corresponds to a contiguous range of left endpoints j \leq i. If the characters in this order are x_1, x_2, \ldots with last occurrences \ell_1 \gt \ell_2 \gt \ldots, then the set \{x_1, \ldots, x_k\} is obtained exactly for the left endpoints \ell_{k+1} \lt j \leq \ell_k (taking \ell_{k+1} = 0 if x_k is the last character that has appeared so far).
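The grouping of left endpoints can be sketched as follows. This assumes the alphabet 'a' to 't' and 1-indexed positions. The names SetRange and setsEndingAt are hypothetical. For clarity, the sketch recomputes the last occurrences from scratch; the real algorithm maintains them incrementally as i advances.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical record: every left endpoint j in [L, R] gives c(S[j..i]) == mask.
struct SetRange {
    int mask;
    int L, R;
};

// Hypothetical helper: lists the distinct character sets of substrings ending at i.
// Assumes s only contains 'a'..'t' and 1 <= i <= s.size().
std::vector<SetRange> setsEndingAt(const std::string& s, int i) {
    const int SIGMA = 20;
    std::vector<int> last(SIGMA, 0);  // last occurrence at or before i (0 = never)
    for (int k = 1; k <= i; ++k) last[s[k - 1] - 'a'] = k;
    std::vector<int> ord(SIGMA);
    for (int c = 0; c < SIGMA; ++c) ord[c] = c;
    // Decreasing order of last occurrence; unseen characters (last = 0) end up at the back.
    std::sort(ord.begin(), ord.end(), [&](int x, int y) { return last[x] > last[y]; });
    std::vector<SetRange> out;
    int mask = 0;
    for (int k = 0; k < SIGMA && last[ord[k]] > 0; ++k) {
        mask |= 1 << ord[k];
        int next = (k + 1 < SIGMA) ? last[ord[k + 1]] : 0;  // l_{k+1}, or 0 if none
        out.push_back({mask, next + 1, last[ord[k]]});
    }
    return out;
}
```

For S = "abcab" and i = 5, this yields \{b\} on [5, 5], \{a, b\} on [4, 4] and \{a, b, c\} on [1, 3].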
Let’s fix a set of characters this way. Let L and R be the bounds of j such that for every L \leq j \leq R, the substring S[j\ldots i] has this set.
Using bitmasking and prefix XORs, we can in fact find all j between L and R such that S[j\ldots i] is good.
How?
There are at most 20 distinct characters. Let’s map them to 2^0, 2^1, \ldots, 2^{19}.
Suppose A_i is the mapping of S_i.
Consider what happens to some subarray (l, r) of A (which represents a substring of S).
Specifically, consider the value A_l \oplus A_{l+1}\oplus\ldots\oplus A_r.
In this bitwise XOR:
- If 2^x is present in it, the character corresponding to it appears an odd number of times in the range.
- If 2^x is not present, the corresponding character appears an even number of times.
So, under this encoding, the bitwise XOR of a range tells us exactly which characters in it appear an odd number of times.
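A small sketch of this encoding follows. It assumes characters 'a' to 't' map to bits 0 to 19, and the helper names prefixXor and oddSet are hypothetical. The XOR of the range l..r (1-indexed) is P_r \oplus P_{l-1}.

```cpp
#include <string>
#include <vector>

// Hypothetical helper: P[0] = 0 and P[k] = A_1 xor ... xor A_k, where A_k = 2^(S_k - 'a').
std::vector<int> prefixXor(const std::string& s) {
    std::vector<int> P(s.size() + 1, 0);
    for (size_t k = 0; k < s.size(); ++k) P[k + 1] = P[k] ^ (1 << (s[k] - 'a'));
    return P;
}

// Hypothetical helper: bitmask of characters occurring an odd number of times in S[l..r].
int oddSet(const std::vector<int>& P, int l, int r) { return P[r] ^ P[l - 1]; }
```

With S = "abcab", oddSet(P, 1, 5) is 4 (only 'c' is odd), and oddSet(P, 3, 5) is 7 ("cab" has all three characters odd).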
Let P denote the prefix XOR array under the above encoding.
Then, we’re interested in all indices L\leq j\leq R such that S[j\ldots i] is good. Every character of S[j\ldots i] belongs to the fixed set, so S[j\ldots i] is good exactly when every character of the set appears an odd number of times, i.e., when P_i\oplus P_{j-1} equals the fixed set.
That is, if mask is the bitmask corresponding to our set of characters, we want all j in this range such that P_i\oplus P_{j-1} = mask, or equivalently P_{j-1} = mask\oplus P_i.
Notice that the right side is a constant, since i and mask are fixed.
So, if for each prefix XOR value we had the sorted list of indices k with P_k equal to it, all we need to do is look at the list for P_i\oplus mask, find the indices lying in [L-1, R-1] (a contiguous segment of the list, which binary search can locate), and take the minimum dp value among them.
Since this has essentially been reduced to a range minimum query problem with point updates (dp_i needs to be set once it’s computed), the problem can be solved using a segment tree.
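The \mathcal{O}(N\Sigma \log N) approach can be sketched as follows. This is not the author’s implementation (the author uses a dynamically allocated segment tree per prefix XOR, indexed by position). Because every P_k is known upfront, this sketch instead builds the per-value index lists offline and puts a static bottom-up min segment tree over each list. It assumes the alphabet 'a' to 't'; all names are hypothetical.

```cpp
#include <algorithm>
#include <numeric>
#include <string>
#include <vector>

const int SIGMA = 20;     // assumption: letters 'a'..'t'
const int INF = 1 << 29;

// Bottom-up range-minimum segment tree with point assignment.
struct MinSeg {
    int n;
    std::vector<int> t;
    explicit MinSeg(int n) : n(n), t(2 * n, INF) {}
    void set(int p, int v) {
        for (t[p += n] = v; p > 1; p >>= 1) t[p >> 1] = std::min(t[p], t[p ^ 1]);
    }
    int query(int l, int r) const {  // minimum over positions [l, r)
        int res = INF;
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = std::min(res, t[l++]);
            if (r & 1) res = std::min(res, t[--r]);
        }
        return res;
    }
};

int minGoodPartitionSegTree(const std::string& s) {
    const int n = static_cast<int>(s.size());
    std::vector<int> P(n + 1, 0);
    for (int k = 0; k < n; ++k) P[k + 1] = P[k] ^ (1 << (s[k] - 'a'));

    // Group prefix indices by prefix XOR; each list is sorted because k increases.
    std::vector<int> gid(1 << SIGMA, -1), slot(n + 1);
    std::vector<std::vector<int>> pos;
    for (int k = 0; k <= n; ++k) {
        if (gid[P[k]] == -1) { gid[P[k]] = static_cast<int>(pos.size()); pos.emplace_back(); }
        slot[k] = static_cast<int>(pos[gid[P[k]]].size());
        pos[gid[P[k]]].push_back(k);
    }
    std::vector<MinSeg> seg;
    for (const auto& v : pos) seg.emplace_back(static_cast<int>(v.size()));

    std::vector<int> dp(n + 1, INF), last(SIGMA + 1, 0), ord(SIGMA + 1);
    dp[0] = 0;
    seg[gid[P[0]]].set(slot[0], 0);
    for (int i = 1; i <= n; ++i) {
        last[s[i - 1] - 'a'] = i;
        std::iota(ord.begin(), ord.end(), 0);  // index SIGMA is a sentinel with last = 0
        std::sort(ord.begin(), ord.end(), [&](int x, int y) { return last[x] > last[y]; });
        int mask = 0;
        for (int k = 0; k < SIGMA && last[ord[k]] > 0; ++k) {
            mask |= 1 << ord[k];
            int L = last[ord[k + 1]] + 1, R = last[ord[k]];  // left endpoints with this set
            int g = gid[P[i] ^ mask];
            if (g == -1) continue;
            // Prefix indices j - 1 in [L - 1, R - 1] whose prefix XOR equals P[i] ^ mask.
            const auto& v = pos[g];
            int lo = static_cast<int>(std::lower_bound(v.begin(), v.end(), L - 1) - v.begin());
            int hi = static_cast<int>(std::upper_bound(v.begin(), v.end(), R - 1) - v.begin());
            if (lo < hi) dp[i] = std::min(dp[i], seg[g].query(lo, hi) + 1);
        }
        seg[gid[P[i]]].set(slot[i], dp[i]);  // dp[i] becomes available to later queries
    }
    return dp[n];
}
```

Building the lists offline avoids dynamic node allocation, at the cost of one 2^{20}-entry lookup table.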
The complexity is \mathcal{O}(N\Sigma \log N) where \Sigma = 20.
Even though it’s somewhat slow, if implemented reasonably, this should get AC.
However, we can do better!
Let’s make another observation: suppose the same character set mask arises at two right endpoints i_1 \lt i_2, with corresponding ranges of left endpoints [L_1, R_1] and [L_2, R_2].
Then either L_1 = L_2 or R_1 \lt L_2, i.e., the two ranges either share the same left endpoint or are disjoint, with the later one entirely to the right.
Proof
If some character x that is not in mask appears at a position p with i_1 \lt p \leq i_2, then L_2 \gt p, so L_2 \gt i_1 \geq R_1.
Otherwise, every character in positions (i_1, i_2] belongs to mask, so the nearest character not in mask to the left of i_2 is the same one as for i_1. It sits at index L_1 - 1, which gives L_1 = L_2.
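The lemma can be sanity-checked by brute force on small strings. The sketch below assumes the alphabet 'a' to 't', and the function name intervalLemmaHolds is hypothetical. For every right endpoint i, it records the range of left endpoints that produces each character set. It then checks every pair i_1 \lt i_2 that shares a set.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical brute-force check of: L1 == L2 or R1 < L2 (small strings only).
bool intervalLemmaHolds(const std::string& s) {
    const int n = static_cast<int>(s.size());
    // mask -> list of (L, R) ranges, one per right endpoint i, in increasing order of i
    std::map<int, std::vector<std::pair<int, int>>> ranges;
    for (int i = 1; i <= n; ++i) {
        int mask = 0, prevMask = -1;
        for (int j = i; j >= 1; --j) {
            mask |= 1 << (s[j - 1] - 'a');
            if (mask != prevMask) ranges[mask].push_back({j, j});  // new set: R = j
            else ranges[mask].back().first = j;                   // same set: extend L to j
            prevMask = mask;
        }
    }
    for (const auto& entry : ranges) {
        const auto& v = entry.second;
        for (size_t a = 0; a < v.size(); ++a)
            for (size_t b = a + 1; b < v.size(); ++b)
                if (!(v[a].first == v[b].first || v[a].second < v[b].first)) return false;
    }
    return true;
}
```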
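This allows us to keep, for each prefix XOR value, a stack of disjoint intervals of indices, each labeled with the minimum dp value inside it. When querying the range [L-1, R-1] for mask, repeatedly merge the top two intervals as long as their union fits inside this range.
After merging, the top interval (if it lies inside the range) holds the answer.
Every interval is pushed once and merged away at most once, so this stack simulation removes the \log N from the complexity.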
TIME COMPLEXITY:
Between \mathcal{O}(N\cdot \Sigma) and \mathcal{O}(N\Sigma\log N) per testcase, where \Sigma = 20 is the alphabet size.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define INF (int)1e9
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
struct Node{
int v;
int l, r;
};
const int N = 2e5 + 5;
const int M = 6e7 + 5;
Node seg[M];
int n;
int h[N];
string s;
int base[1 << 26];
int last[26];
int v = 1;
void upd(int l, int r, int pos, int qp, int val){
seg[pos].v = min(seg[pos].v, val);
if (l == r) return;
int mid = (l + r) / 2;
if (seg[pos].l == -1){
seg[pos].l = v++;
seg[pos].r = v++;
}
if (qp <= mid) upd(l, mid, seg[pos].l, qp, val);
else upd(mid + 1, r, seg[pos].r, qp, val);
}
int query(int l, int r, int pos, int ql, int qr){
if (l >= ql && r <= qr) return seg[pos].v;
else if (l > qr || r < ql) return INF;
int mid = (l + r) / 2;
if (seg[pos].l == -1){
seg[pos].l = v++;
seg[pos].r = v++;
}
return min(query(l, mid, seg[pos].l, ql, qr), query(mid + 1, r, seg[pos].r, ql, qr));
}
int query(int x, int l, int r){
if (base[x] == 0){
base[x] = v++;
}
return query(1, n, base[x], l, r);
}
void upd(int x, int i, int y){
if (base[x] == 0){
base[x] = v++;
}
upd(1, n, base[x], i, y);
}
void Solve()
{
cin >> n >> s;
for (int i = 0; i < M; i++){
seg[i].v = INF;
seg[i].l = -1;
seg[i].r = -1;
}
for (int i = 1; i <= n; i++){
h[i] = h[i - 1] ^ (1 << (s[i - 1] - 'a'));
}
vector <int> dp(n + 1, INF);
dp[0] = 0;
upd(h[0], 1, dp[0]);
for (int i = 1; i <= n; i++){
last[s[i - 1] - 'a'] = i;
vector <pair<int, int>> b;
for (int j = 0; j < 26; j++){
b.push_back({last[j], j});
}
b.push_back({0, 26});
sort(b.begin(), b.end(), greater<pair<int, int>>());
int curr = h[i];
int ok = __builtin_popcount(h[i]);
for (int j = 0; j < 26; j++) ok -= last[j] != 0; // subtract the number of distinct characters seen so far
if (ok == 0){
dp[i] = 1;
}
//if (i == 90000) return;
// cout << "FOR " << i << "\n";
for (int j = 0; j < 26; j++){
if (b[j].first == 0) break;
curr ^= 1 << (b[j].second);
int L = b[j + 1].first + 1;
int R = b[j].first;
dp[i] = min(dp[i], query(curr, L, R) + 1);
}
if (i != n)
upd(h[i], i + 1, dp[i]);
}
cout << dp[n] << "\n";
}
int32_t main()
{
auto begin = std::chrono::high_resolution_clock::now();
ios_base::sync_with_stdio(0);
cin.tie(0);
int t = 1;
// freopen("out", "w", stdout);
// cin >> t;
for(int i = 1; i <= t; i++)
{
//cout << "Case #" << i << ": ";
Solve();
}
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define dbg(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define dbg(...)
#endif
void __print(int32_t x) {cerr << x;}
void __print(int64_t x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(string x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
template<typename T>void __print(complex<T> x) {cerr << '{'; __print(x.real()); cerr << ','; __print(x.imag()); cerr << '}';}
template<typename T>
void __print(const T &x);
template<typename T, typename V>
void __print(const pair<T, V> &x) {cerr << '{'; __print(x.first); cerr << ','; __print(x.second); cerr << '}';}
template<typename T>
void __print(const T &x) {int f = 0; cerr << '{'; for (auto it = x.begin() ; it != x.end() ; it++) cerr << (f++ ? "," : ""), __print(*it); cerr << "}";}
void _print() {cerr << "]\n";}
template <typename T, typename... V>
void _print(T t, V... v) {__print(t); if (sizeof...(v)) cerr << ", "; _print(v...);}
struct input_checker {
string buffer;
int pos;
const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
const string number = "0123456789";
const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const string lower = "abcdefghijklmnopqrstuvwxyz";
input_checker() {
pos = 0;
while (true) {
int c = cin.get();
if (c == -1) {
break;
}
buffer.push_back((char) c);
}
}
int nextDelimiter() {
int now = pos;
while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
now++;
}
return now;
}
string readOne() {
assert(pos < (int) buffer.size());
int nxt = nextDelimiter();
string res;
while (pos < nxt) {
res += buffer[pos];
pos++;
}
return res;
}
string readString(int minl, int maxl, const string &pattern = "") {
assert(minl <= maxl);
string res = readOne();
assert(minl <= (int) res.size());
assert((int) res.size() <= maxl);
for (int i = 0; i < (int) res.size(); i++) {
assert(pattern.empty() || pattern.find(res[i]) != string::npos);
}
return res;
}
int readInt(int minv, int maxv) {
assert(minv <= maxv);
int res = stoi(readOne());
assert(minv <= res);
assert(res <= maxv);
return res;
}
long long readLong(long long minv, long long maxv) {
assert(minv <= maxv);
long long res = stoll(readOne());
assert(minv <= res);
assert(res <= maxv);
return res;
}
auto readInts(int n, int minv, int maxv) {
assert(n >= 0);
vector<int> v(n);
for (int i = 0; i < n; ++i) {
v[i] = readInt(minv, maxv);
if (i+1 < n) readSpace();
}
return v;
}
auto readLongs(int n, long long minv, long long maxv) {
assert(n >= 0);
vector<long long> v(n);
for (int i = 0; i < n; ++i) {
v[i] = readLong(minv, maxv);
if (i+1 < n) readSpace();
}
return v;
}
void readSpace() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == ' ');
pos++;
}
void readEoln() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == '\n');
pos++;
}
void readEof() {
assert((int) buffer.size() == pos);
}
};
constexpr int M = 't' - 'a' + 1;
int32_t main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
vector<vector<int>> con(1 << M);
con[0].push_back(0);
// input_checker inp;
// int N = inp.readInt(1, (int)1e5); inp.readEoln(); NN += N;
// string S = inp.readString(N, N, "abcdefghijklmnopqrst"); inp.readEoln();
int N; string S;
cin >> N >> S;
vector<int> last(M + 1), dp(N + 1, N + 1); dp[0] = 0;
vector<int> ord(M + 1); iota(ord.begin(), ord.end(), 0);
int mask = 0, req = 0;
for(int i = 0 ; i < N ; ++i) {
last[S[i] - 'a'] = i + 1;
mask ^= 1 << int(S[i] - 'a'), req = mask;
sort(ord.begin(), ord.end(), [&](int i, int j) {
return last[i] > last[j];
});
for(int j = 1 ; j <= M ; ++j) {
req ^= (1 << ord[j - 1]);
int p = lower_bound(con[req].begin(), con[req].end(), last[ord[j]]) - con[req].begin();
if(p < (int)con[req].size()) dp[i + 1] = min(dp[i + 1], dp[con[req][p]] + 1);
if(last[ord[j]] == 0) break;
}
while(!con[mask].empty() && dp[con[mask].back()] > dp[i + 1])
con[mask].pop_back();
con[mask].push_back(i + 1);
}
cout << dp[N] << '\n';
// inp.readEof();
return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
const int alp = 20;
int ID[1 << alp];
int main()
{
ios::sync_with_stdio(false); cin.tie(0);
int n; cin >> n;
string s; cin >> s;
vector<int> pmask(n+1);
for (int i = 0; i < n; ++i)
pmask[i+1] = pmask[i] ^ (1 << (s[i] - 'a'));
memset(ID, -1, sizeof ID);
ID[0] = 0;
int id = 1;
vector<int> dp(n+1);
vector<vector<array<int, 3>>> intervals(1);
intervals[0] = {{0, 0, 0}};
vector<int> last(alp+1, 0);
for (int i = 1; i <= n; ++i) {
dp[i] = i;
last[s[i-1] - 'a'] = i;
vector<int> ord(alp+1);
iota(begin(ord), end(ord), 0);
sort(rbegin(ord), rend(ord), [&] (int x, int y) {return last[x] < last[y];});
int mask = 0;
for (int c = 1; c < alp+1 and last[ord[c-1]]; ++c) {
mask |= 1 << (ord[c-1]);
int L = last[ord[c]], R = last[ord[c-1]];
int want = pmask[i] ^ mask;
if (ID[want] == -1) continue;
int x = ID[want];
auto &stk = intervals[x];
while (stk.size() >= 2) {
int k = stk.size();
int l = stk[k-2][0], r = stk[k-1][1], y = min(stk[k-2][2], stk[k-1][2]);
if (l >= L and r < R) {
stk.pop_back();
stk.back() = {l, r, y};
}
else break;
}
auto [l, r, y] = stk.back();
if (l >= L and r < R) dp[i] = min(dp[i], y + 1);
}
if (ID[pmask[i]] == -1) {
ID[pmask[i]] = id++;
intervals.emplace_back();
}
intervals[ID[pmask[i]]].push_back({i, i, dp[i]});
}
cout << dp[n] << '\n';
}