PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Jubayer Nirjhor
Tester: Raja Vardhan Reddy
Editorialist: Taranpreet Singh
DIFFICULTY:
Medium-Hard
PREREQUISITES:
Suffix Arrays or Suffix Tree, LCP Array, Combinatorics and Disjoint Set Union.
PROBLEM:
Given a string S and an integer K, do the following for each length L with 1 \leq L \leq |S|:
- K times, select a substring of S of length L.
- Find the probability that all K chosen strings are pairwise distinct.
All computations are done modulo 998244353.
QUICK EXPLANATION
- Build the suffix array and the LCP array of the given string.
- The number of ways to select K pairwise distinct objects out of N objects (duplicates included) is the coefficient of x^K in \prod (1+a_i*x), where a_i denotes the frequency of the i-th distinct object.
- For length L, there are |S|-L+1 suffixes of length at least L in total, and some of them might share their first L characters. Consider L in decreasing order: the suffixes at indices i and i+1 of the suffix array have the same first L characters exactly when L \leq LCP_i.
- Using a DSU, we maintain the size of each group of suffixes sharing their first L characters, and simultaneously maintain the first 1+K coefficients of this product. Whenever LCP_i = L, we merge the two groups, of sizes a and b: divide the product by (1+a*x) and (1+b*x), then multiply it by (1+(a+b)*x).
- Finally, we multiply by K! to account for all orderings and divide by the total number of ways to select K substrings of length L to get the final probabilities.
EXPLANATION
A simple problem
Let’s consider a different problem. You have N buckets, the i-th of which contains A_i balls. Find the number of ways to select K \leq N balls such that at most one ball is selected from each bucket.
Writing this in terms of a polynomial, the required number of ways is the coefficient of x^K in \displaystyle\prod_{i = 1}^N (1+A_i*x). (One way to interpret this: from each bucket we either select no ball, in 1 way, or select one ball, in A_i ways.)
The following illustrates how the coefficients of the above polynomial behave:
polynomial:                  x^0   x^1     x^2            x^3
(1+a*x):                     1     a       0              0
(1+a*x)*(1+b*x):             1     a+b     a*b            0
(1+a*x)*(1+b*x)*(1+c*x):     1     a+b+c   a*b+(a+b)*c    a*b*c
(1+a*x)*(1+c*x):             1     a+c     a*c            0
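As a concrete check of the table, assume three hypothetical buckets with a = 2, b = 1, c = 1 (values chosen only for illustration). Expanding the product, and then dividing out the factor of bucket b, gives:

```latex
\begin{aligned}
(1+2x)(1+x)(1+x) &= 1 + 4x + 5x^2 + 2x^3, \\
\frac{(1+2x)(1+x)(1+x)}{1+x} &= (1+2x)(1+x) = 1 + 3x + 2x^2 .
\end{aligned}
```

So there are 5 = 2*1 + 2*1 + 1*1 ways to pick two balls from different buckets, and 2 ways once bucket b is removed.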
Suppose we have the coefficients p_0, p_1, ..., p_K of a polynomial P(x) (only the first K+1 coefficients matter). We can find the coefficients of P(x)*(1+a*x) in O(K), since the new coefficient of x^j is p_j + a*p_{j-1}. Similarly, if (1+a*x) divides P(x), we can obtain the coefficients of P(x)/(1+a*x) in O(K) time by undoing the same recurrence, starting from the lowest degree.
So, we have a special DS, which stores a polynomial (initially just 1) and supports
- Multiply the polynomial by (1+a*x) in time O(K)
- Divide the polynomial by (1+a*x), assuming (1+a*x) | P(x), in time O(K)
- Return the coefficient of x^K in time O(1)
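The three operations above can be sketched as follows. This is a minimal illustration rather than code from the reference solutions; the name TruncatedProduct and its members are hypothetical, and the structure assumes that every divide(a) is preceded by a matching multiply(a).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the reference solutions): keeps the first K+1
// coefficients of a product of factors (1 + a*x), modulo MOD.
struct TruncatedProduct {
    static constexpr std::int64_t MOD = 998244353;
    std::vector<std::int64_t> coef;  // coef[j] = coefficient of x^j, 0 <= j <= K

    explicit TruncatedProduct(int K) : coef(K + 1, 0) {
        assert(K >= 0);
        coef[0] = 1;  // the empty product is the polynomial 1
    }

    // P(x) <- P(x) * (1 + a*x): new[j] = old[j] + a * old[j-1].
    // Going from high j to low j keeps old[j-1] intact until it is used.
    void multiply(std::int64_t a) {
        a %= MOD;
        for (int j = (int)coef.size() - 1; j >= 1; --j)
            coef[j] = (coef[j] + a * coef[j - 1]) % MOD;
    }

    // P(x) <- P(x) / (1 + a*x); assumes (1 + a*x) was multiplied in earlier.
    // Inverts the recurrence: old[j] = new[j] - a * old[j-1], from low j to high j.
    void divide(std::int64_t a) {
        a %= MOD;
        for (int j = 1; j < (int)coef.size(); ++j)
            coef[j] = ((coef[j] - a * coef[j - 1]) % MOD + MOD) % MOD;
    }

    // Coefficient of x^K.
    std::int64_t query() const { return coef.back(); }
};
```

For instance, with K = 2, multiplying in 2, 3 and 5 gives query() = 2*3 + 2*5 + 3*5 = 31, and dividing 3 back out gives 2*5 = 10.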
Coming back to the original problem now.
The required probability for a given L can be written as the number of ways to select K pairwise distinct strings of length L (ignoring the order of selection) \times K! (to account for every order of selection), divided by the total number of ways to make K selections, which is (|S|-L+1)^K.
Hence, for a fixed L, if C_L denotes the number of ways to select K distinct substrings of length L irrespective of the order of selection, then the answer for length L is given as \displaystyle\frac{C_L*K!}{(|S|-L+1)^K} (in modular arithmetic). Our task now is to compute C_L for each length L.
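A small sketch of this last step, assuming C_L is already known. The helper names powMod and probabilityFromCount are made up for this illustration; the modular inverse uses Fermat's little theorem, which applies because 998244353 is prime.

```cpp
#include <cassert>
#include <cstdint>

const std::int64_t MOD = 998244353;

// Fast exponentiation: base^exp mod MOD.
std::int64_t powMod(std::int64_t base, std::int64_t exp) {
    assert(exp >= 0);
    std::int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Hypothetical helper: turns C_L into C_L * K! / total^K (mod MOD),
// where total = |S| - L + 1 is the number of substrings of length L.
std::int64_t probabilityFromCount(std::int64_t countL, int K, std::int64_t total) {
    std::int64_t factorial = 1;
    for (int i = 1; i <= K; ++i) factorial = factorial * i % MOD;
    std::int64_t denominator = powMod(total, K);
    // MOD is prime, so denominator^(MOD-2) is the modular inverse of the denominator.
    std::int64_t inverse = powMod(denominator, MOD - 2);
    return countL % MOD * factorial % MOD * inverse % MOD;
}
```

For S = "ababc", L = 2 and K = 2, the substrings are ab, ba, ab, bc, so C_2 = 2*1 + 2*1 + 1*1 = 5 and the probability is 5 * 2! / 4^2 = 5/8; probabilityFromCount(5, 2, 4) returns that fraction as a residue modulo 998244353.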
Let’s iterate over L in decreasing order. For length L, we add the suffix of length exactly L into our DS (equivalent to multiplying the polynomial in our DS by (1+x)), since its prefix of length L is the only substring occurrence that was not available at length L+1. Also, at length L, it might be the case that two suffixes have the same first L characters.
For example, consider the string “ababc” and its two suffixes “ababc” and “abc”. While L > 2, their prefixes of length L differ, but at L = 2 both suffixes start with “ab”.
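To see the group sizes for a fixed L directly, the following brute-force sketch counts how often each distinct substring of length L occurs. It is only an illustration for small inputs, and the function name substringFrequencies is an assumption of this example, not part of any reference solution.

```cpp
#include <cassert>
#include <map>
#include <string>

// Brute-force illustration only (O(|S| * L) work per length): counts how often
// each distinct substring of length L occurs in s.
std::map<std::string, int> substringFrequencies(const std::string& s, int L) {
    assert(L >= 1);
    std::map<std::string, int> freq;
    for (std::size_t start = 0; start + L <= s.size(); ++start)
        ++freq[s.substr(start, L)];
    return freq;
}
```

On “ababc”, length 3 gives three substrings (aba, bab, abc) with frequency 1 each, while length 2 gives ab with frequency 2 and ba, bc with frequency 1, which is exactly the merge described above.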
This hints towards suffix arrays and LCP arrays. So, let’s build the suffix array and the LCP array. Also, let’s maintain the current size of each group using a Disjoint Set Union.
Let’s iterate over the length L in decreasing order. Every pair of adjacent suffixes in the suffix array with LCP \geq L must be in the same group. Since groups formed at larger lengths stay merged, at length L it is enough to merge the pairs whose LCP is exactly L. Suppose the group of the first suffix has size a and the group of the second suffix has size b.
At this point, we need to remove (1+a*x) and (1+b*x) from our DS and add (1+(a+b)*x) into it.
This is all we do. We iterate over all lengths L in decreasing order, add (1+x) for the suffix of the current length, merge all adjacent groups having LCP == L, and query the coefficient of x^K for each length, which is the required value of C_L.
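The whole loop can be sketched compactly as below. To keep it short, this sketch swaps in a naive suffix array (sorting suffixes by direct comparison) and a naive LCP computation, so it is far too slow for the real constraints and is meant only to show the order of operations. The function name countDistinctSelections is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

// Sketch: returns C[L] for L = 1..|s| (C[0] unused), where C[L] is the number of
// ways to choose K occurrences of pairwise distinct length-L substrings, mod MOD.
// The suffix array and LCP values are built naively here for readability.
std::vector<std::int64_t> countDistinctSelections(const std::string& s, int K) {
    const std::int64_t MOD = 998244353;
    const int n = (int)s.size();
    assert(K >= 1);

    // Naive suffix array: sort starting positions by the suffix they begin.
    std::vector<int> sa(n);
    std::iota(sa.begin(), sa.end(), 0);
    std::sort(sa.begin(), sa.end(), [&](int a, int b) {
        return s.compare(a, std::string::npos, s, b, std::string::npos) < 0;
    });

    // byLcp holds (lcp of sa[i] and sa[i+1], i), sorted by decreasing lcp.
    std::vector<std::pair<int, int>> byLcp;
    for (int i = 0; i + 1 < n; ++i) {
        int a = sa[i], b = sa[i + 1], l = 0;
        while (a + l < n && b + l < n && s[a + l] == s[b + l]) ++l;
        byLcp.push_back({l, i});
    }
    std::sort(byLcp.rbegin(), byLcp.rend());

    // DSU over suffix-array positions; groupSize is meaningful at roots only.
    std::vector<int> parent(n), groupSize(n, 1);
    std::iota(parent.begin(), parent.end(), 0);
    auto find = [&](int v) {
        while (parent[v] != v) v = parent[v] = parent[parent[v]];
        return v;
    };

    // First K+1 coefficients of the product of (1 + size*x) over current groups.
    std::vector<std::int64_t> coef(K + 1, 0), C(n + 1, 0);
    coef[0] = 1;
    auto multiply = [&](std::int64_t a) {
        for (int j = K; j >= 1; --j) coef[j] = (coef[j] + a * coef[j - 1]) % MOD;
    };
    auto divide = [&](std::int64_t a) {
        for (int j = 1; j <= K; ++j)
            coef[j] = ((coef[j] - a * coef[j - 1]) % MOD + MOD) % MOD;
    };

    std::size_t ptr = 0;
    for (int L = n; L >= 1; --L) {
        multiply(1);  // the suffix of length exactly L becomes long enough
        while (ptr < byLcp.size() && byLcp[ptr].first == L) {
            int u = find(byLcp[ptr].second), v = find(byLcp[ptr].second + 1);
            divide(groupSize[u]);
            divide(groupSize[v]);
            parent[v] = u;
            groupSize[u] += groupSize[v];
            multiply(groupSize[u]);
            ++ptr;
        }
        C[L] = coef[K];
    }
    return C;
}
```

For “ababc” with K = 2 this yields C_1..C_5 = 8, 5, 3, 1, 0. A real solution would replace the naive construction with an efficient suffix array and LCP array, as the tester's and editorialist's solutions do.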
Learning resources
Suffix Arrays and LCP array: here and here
Disjoint Set Union
Problem to try
KPRB
After-thought
Can this problem be solved using a suffix automaton or a suffix tree directly? (The setter's solution below is built on a suffix automaton.) Share your approaches.
TIME COMPLEXITY
The time complexity is O(N*K+N*log(MOD)) per test case.
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int A = 26;
const int K = 505;
const int N = 200010;
const int MOD = 998244353;
char s[N];
bitset <N << 1> vis;
ll fac[K], dp[K], ans[N];
vector <int> g[N << 1], in[N], out[N];
int len[N << 1], link[N << 1], sz, last;
int t, n, k, cnt[N << 1], to[N << 1][A];
inline void init() {
len[0] = 0, link[0] = -1, sz = 1, last = 0;
memset(to[0], -1, sizeof to[0]);
}
void feed (char ch) {
int cur = sz++, p = last, c = ch - 'a';
len[cur] = len[last] + 1, link[cur] = 0, cnt[cur] = 1;
memset(to[cur], -1, sizeof to[cur]);
while (~p and to[p][c] == -1) to[p][c] = cur, p = link[p];
if (~p) {
int q = to[p][c];
if (len[q] - len[p] - 1) {
int r = sz++;
len[r] = len[p] + 1, link[r] = link[q];
for (int i = 0; i < A; ++i) to[r][i] = to[q][i];
while (~p and to[p][c] == q) to[p][c] = r, p = link[p];
link[q] = link[cur] = r;
} else link[cur] = q;
} last = cur;
}
void go (int u = 0) {
for (int v : g[u]) go(v), cnt[u] += cnt[v], cnt[u] %= MOD;
}
ll bigMod (ll a, ll e) {
if (e < 0) e += MOD - 1;
ll ret = 1;
while (e) {
if (e & 1) ret = ret * a % MOD;
a = a * a % MOD, e >>= 1;
}
return ret;
}
void dfs (int u = 0) {
vis[u] = 1;
for (int i = 0; i < A; ++i) {
int v = to[u][i];
if (v == -1) continue;
if (!vis[v]) dfs(v);
}
if (~link[u]) {
int l = len[link[u]] + 1, r = len[u];
in[l].emplace_back(cnt[u]);
out[r].emplace_back(cnt[u]);
}
}
int main() {
fac[0] = 1;
for (int i = 1; i < K; ++i) fac[i] = i * fac[i - 1] % MOD;
cin >> t;
while (t--) {
scanf("%s %d", s, &k);
n = strlen(s); init();
for (int i = 1; i <= n; ++i) {
in[i].clear(), out[i].clear();
}
for (int i = 0; i < n; ++i) feed(s[i]);
for (int i = 0; i < sz; ++i) if (~link[i]) {
g[link[i]].emplace_back(i);
}
go(); dfs();
dp[0] = 1;
for (int i = 1; i <= k; ++i) dp[i] = 0;
for (int i = 1; i <= n; ++i) {
for (int s : in[i]) {
for (int j = k; j >= 1; --j) {
dp[j] += dp[j - 1] * s;
dp[j] %= MOD;
}
}
ans[i] = dp[k];
if (ans[i] < 0) ans[i] += MOD;
for (int s : out[i]) {
for (int j = 1; j <= k; ++j) {
dp[j] -= dp[j - 1] * s;
dp[j] %= MOD;
}
}
}
for (int i = 1; i <= n; ++i) {
ll mul = bigMod(n - i + 1, -k) * fac[k] % MOD;
ans[i] *= mul, ans[i] %= MOD;
}
for (int i = 1; i <= n; ++i) printf("%lld ", ans[i]);
puts("");
for (int i = 0; i < sz; ++i) {
vis[i] = cnt[i] = 0, g[i].clear();
}
}
return 0;
}
Tester's Solution
//raja1999
//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
//setbase - cout << setbase (16)a; cout << 100 << endl; Prints 64
//setfill - cout << setfill ('x') << setw (5); cout << 77 <<endl;prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x) cout<<fixed<<val; // prints x digits after decimal in val
using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (998244353)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define int ll
#define endl "\n"
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
//std::ios::sync_with_stdio(false);
const int MAXN = 1 << 21;
string s;
int N, gap;
int sa[MAXN], pos[MAXN], tmp[MAXN], lcp[MAXN];
bool sufCmp(int i, int j)
{
if (pos[i] != pos[j])
return pos[i] < pos[j];
i += gap;
j += gap;
return (i < N && j < N) ? pos[i] < pos[j] : i > j;
}
void buildSA()
{
N = s.length();
int i;
rep(i, N) sa[i] = i, pos[i] = s[i];
for (gap = 1;; gap *= 2)
{
sort(sa, sa + N, sufCmp);
rep(i, N - 1) tmp[i + 1] = tmp[i] + sufCmp(sa[i], sa[i + 1]);
rep(i, N) pos[sa[i]] = tmp[i];
if (tmp[N - 1] == N - 1) break;
}
}
void buildLCP()
{ int k;
for(int i = 0, k = 0; i < N; ++i){
if (pos[i] != N - 1)
{
for(int j = sa[pos[i]+1];i+k < N && j+k < N && s[i + k] == s[j + k];)
++k;
lcp[pos[i]] = k;
if (k)--k;
}
else{
k=0;
}
}
}
int k;
int ans[200005],new_coef[505],coef[505];
stack<int>st;
int ns[200005];
vector<vi> add(200005),divi(200005);
int power(int a,int b){
int res=1;
while(b>0){
if(b%2){
res*=a;
res%=mod;
}
b/=2;
a*=a;
a%=mod;
}
return res;
}
int multiply(int c){
int i;
fd(i,k,1){
coef[i]=coef[i]+c*coef[i-1];
coef[i]%=mod;
}
return 0;
}
int divide(int c){
int i;
f(i,1,k+1){
coef[i]=(coef[i]-c*coef[i-1]);
coef[i]%=mod;
}
return 0; // a non-void function must return a value
}
int range_update(int l,int r,int c){
if(l>r){
return 0;
}
add[l].pb(c);
divi[r].pb(c);
return 0; // a non-void function must return a value
}
signed main(){ // "int" is a macro for long long here, so main needs an explicit non-macro return type
std::ios::sync_with_stdio(false); cin.tie(NULL);
int t;
double clk,tim=0,tim_1=0,tim_2=0;
cin>>t;
while(t--){
cin>>s>>k;
int n,i,p,prev,l,cur,num,den,fact=1,temp,j;
n=s.length();
// n=1
if(n==1){
if(k==1){
cout<<1<<endl;
}
else{
cout<<0<<endl;
}
continue;
}
// construct suffix array
clk=clock();
buildSA();
buildLCP();
lcp[n-1]=0;
tim+=(clock()-clk)/CLOCKS_PER_SEC;
// next smaller elements
st.push(0);
i=1;
while(i<n){
if(!st.empty()){
p=st.top();
}
while(!st.empty() && lcp[p]>lcp[i]){
ns[p]=i;
st.pop();
if(!st.empty()){
p=st.top();
}
}
st.push(i);
i++;
}
while(!st.empty()){
p=st.top();
st.pop();
ns[p]=n;
}
clk=clock();
prev=0;
rep(i,n-1){
l=(n-sa[i]);
//update
range_update(max(lcp[i]+1,prev+1),l,1);
l=lcp[i];
cur=i;
while(l>prev){
temp=ns[cur];
//update
range_update(max(lcp[temp]+1,prev+1),l,(temp-i+1));
l=lcp[temp];
cur=temp;
}
prev=lcp[i];
}
// i= n-1
l=(n-sa[i]);
// update
range_update(lcp[i-1]+1,l,1);
tim_1+=(clock()-clk)/CLOCKS_PER_SEC;
clk=clock();
coef[0]=1;
f(i,1,n+1){
rep(j,add[i].size()){
multiply(add[i][j]);
}
ans[i]=coef[k];
if(ans[i]<0){
ans[i]+=mod;
}
rep(j,divi[i].size()){
divide(divi[i][j]);
}
add[i].clear();
divi[i].clear();
}
tim_2+=(clock()-clk)/CLOCKS_PER_SEC;
f(i,1,k+1){
fact*=i;
fact%=mod;
}
f(i,1,n+1){
num=ans[i];
num*=fact;
num%=mod;
den=power(n-i+1,k);
num*=power(den,mod-2);
num%=mod;
cout<<num<<" ";
}
cout<<endl;
}
cerr<<tim<<" "<<tim_1<<" "<<tim_2<<endl;
return 0;
}
Editorialist's Solution (TLEs but clear to read)
import java.util.*;
import java.io.*;
import java.util.stream.IntStream;
class PRDRAW{
//SOLUTION BEGIN
long MOD = 998244353;
void pre() throws Exception{}
void solve(int TC) throws Exception{
String s = n();
int N = s.length();
int K = ni();
int[] sa = suffixArray(s), lcp = lcp(s, sa);
int[][] P = new int[N-1][];
for(int i = 0; i< N-1; i++)P[i] = new int[]{i, lcp[i]};
Arrays.sort(P, (int[] i1, int[] i2) -> Integer.compare(i2[1], i1[1]));
long[] ans = new long[1+N];
int ptr = 0;
int[] set = new int[N], sz = new int[N];
for(int i = 0; i< N; i++){set[i] = i;sz[i] = 1;}
long[] ways = new long[1+K];
ways[0] = 1;
for(int len = N; len >= 1; len--){
addToDS(ways, K, 1);
while(ptr < N-1 && P[ptr][1] == len){
int idx = P[ptr][0];
removeFromDS(ways, K, sz[find(set, idx)]);
removeFromDS(ways, K, sz[find(set, idx+1)]);
sz[find(set, idx)] += sz[find(set, idx+1)];
set[find(set, idx+1)] = find(set, idx);
addToDS(ways, K, sz[find(set, idx)]);
ptr++;
}
ans[len] = ways[K];
}
long fact = 1;
for(int i = 1; i<= K; i++)fact = (fact*i)%MOD;
for(int len = 1; len <= N; len++){
ans[len] = (ans[len]*fact)%MOD;
ans[len] = (ans[len]*pow(pow(N-len+1, K), MOD-2))%MOD;
}
for(int len = 1; len <= N; len++)p(ans[len]+" ");pn("");
}
long pow(long a, long p){
long o = 1;
for(;p>0;p>>=1){
if((p&1)==1)o = (o*a)%MOD;
a = (a*a)%MOD;
}
return o;
}
void addToDS(long[] ways, int K, long x){
for(int i = K; i>= 1; i--)
ways[i] = (ways[i]+ways[i-1]*x)%MOD;
}
void removeFromDS(long[] ways, int K, long x){
for(int i = 1; i<= K; i++)
ways[i] = (ways[i]+MOD-(ways[i-1]*x)%MOD)%MOD;
}
int find(int[] set, int i){return set[i] = (set[i] == i)?i:find(set, set[i]);}
//http://code-library.herokuapp.com/suffix-array/java
public static int[] suffixArray(CharSequence S) {
int n = S.length();
// stable sort of characters
int[] sa = IntStream.range(0, n).mapToObj(i -> n - 1 - i).
sorted((a, b) -> Character.compare(S.charAt(a), S.charAt(b))).mapToInt(Integer::intValue).toArray();
int[] classes = S.chars().toArray();
// sa[i] - suffix on i'th position after sorting by first len characters
// classes[i] - equivalence class of the i'th suffix after sorting by first len characters
for (int len = 1; len < n; len *= 2) {
int[] c = classes.clone();
for (int i = 0; i < n; i++) {
// condition sa[i - 1] + len < n simulates 0-symbol at the end of the string
// a separate class is created for each suffix followed by simulated 0-symbol
classes[sa[i]] = i > 0 && c[sa[i - 1]] == c[sa[i]] && sa[i - 1] + len < n && c[sa[i - 1] + len / 2] == c[sa[i] + len / 2] ? classes[sa[i - 1]] : i;
}
// Suffixes are already sorted by first len characters
// Now sort suffixes by first len * 2 characters
int[] cnt = IntStream.range(0, n).toArray();
int[] s = sa.clone();
for (int i = 0; i < n; i++) {
// s[i] - order of suffixes sorted by first len characters
// (s[i] - len) - order of suffixes sorted only by second len characters
int s1 = s[i] - len;
// sort only suffixes of length > len, others are already sorted
if (s1 >= 0)
sa[cnt[classes[s1]]++] = s1;
}
}
return sa;
}
class Suffix implements Comparable<Suffix>{
int index, rank, next;
public Suffix(int ind, int r, int nr){
index = ind; rank = r; next = nr;
}
public int compareTo(Suffix s){
if(rank != s.rank)return Integer.compare(rank, s.rank);
return Integer.compare(next, s.next);
}
}
int[] lcp(String s, int[] sa){
int n = sa.length;
int[] lcp = new int[n];
int[] invSuf = new int[n];
for(int i = 0; i< n; i++)invSuf[sa[i]] = i;
int k = 0;
for(int i = 0; i< n; i++){
if(invSuf[i] == n-1){k = 0;continue;}
int j = sa[invSuf[i]+1];
while(i+k < n && j+k < n && s.charAt(i+k) == s.charAt(j+k))k++;
lcp[invSuf[i]] = k;
if(k > 0)k--;
}
return lcp;
}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new PRDRAW().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Feel free to share your approach. Suggestions are welcomed as always.