PRDRAW - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Jubayer Nirjhor
Tester: Raja Vardhan Reddy
Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Suffix Arrays or Suffix Tree, LCP Array, Combinatorics and Disjoint Set Union.

PROBLEM:

Given a string S and an integer K, do the following for each L with 1 \leq L \leq |S|:

  • K times, select a substring of S of length L.
  • Find the probability that all the K chosen strings are pairwise distinct.

All the computations are done modulo 998244353.

QUICK EXPLANATION

  • Build the suffix array and LCP array of the given string.
  • Selecting K objects of pairwise distinct kinds out of N objects (duplicates included) can be counted as the coefficient of x^K in \prod (1+a_i*x), where a_i denotes the frequency of each distinct object.
  • For length L, there are |S|-L+1 suffixes of length at least L in total, some of which share their first L characters. Consider L in decreasing order: the suffixes at indices i and i+1 of the suffix array share their first L characters exactly when L \leq LCP_i.
  • Using a DSU, we maintain the size of each group of suffixes with equal prefixes and, simultaneously, the first 1+K coefficients of this product. Whenever LCP_i = L, we divide the product by (1+a*x) and (1+b*x) and multiply it by (1+(a+b)*x), where a and b are the sizes of the two groups being merged.
  • Finally, we account for all K! orderings and divide by the total number of ways to select K substrings of length L to get the final probabilities.

EXPLANATION

A simple problem
Let’s consider a different problem. You have N buckets, the i-th of which contains A_i balls. Find the number of ways to select K \leq N balls such that at most one ball is selected from each bucket.

Written in terms of polynomials, the required number of ways is the coefficient of x^K in \displaystyle\prod_{i = 1}^N (1+A_i*x). (One way to interpret this: for each bucket, we either select no ball (in 1 way) or select one ball (in A_i ways).)

The following illustrates how the coefficients of the above polynomial behave:

polynomial:                 x^0   x^1     x^2           x^3
(1+a*x):                    1     a       0             0
(1+a*x)*(1+b*x):            1     a+b     a*b           0
(1+a*x)*(1+b*x)*(1+c*x):    1     a+b+c   a*b+(a+b)*c   a*b*c
(1+a*x)*(1+c*x):            1     a+c     a*c           0

Suppose we have the first K+1 coefficients of a polynomial P(x). Then we can find the coefficients of P(x)*(1+a*x) in O(K). Similarly, if (1+a*x) divides P(x), we can obtain the coefficients of P(x)/(1+a*x) in O(K) time.

So, we have a special DS, which stores a polynomial (initially just 1) and supports:

  • Multiply a polynomial by (1+a*x) in time O(K)
  • Divide a polynomial by (1+a*x) assuming (1+a*x) | P(x) in time O(K)
  • Return coefficient of x^K in time O(1)
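
The three operations above can be sketched as follows. This is only a minimal illustration, not the code from any of the solutions below: the struct name TruncatedProduct and its member names are made up here, and it assumes every factor passed to divide was multiplied in earlier.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (names are ours): keeps the first K+1 coefficients of a
// product of factors (1 + a*x), with all arithmetic modulo 998244353.
struct TruncatedProduct {
    static constexpr std::int64_t MOD = 998244353;
    std::vector<std::int64_t> coef;  // coef[i] = coefficient of x^i, 0 <= i <= K

    explicit TruncatedProduct(int K) : coef(K + 1, 0) { coef[0] = 1; }

    // P(x) <- P(x) * (1 + a*x).
    // Iterate from high to low degree so coef[i-1] still holds the old value.
    void multiply(std::int64_t a) {
        a %= MOD;
        for (int i = (int)coef.size() - 1; i >= 1; --i)
            coef[i] = (coef[i] + coef[i - 1] * a) % MOD;
    }

    // P(x) <- P(x) / (1 + a*x), assuming (1 + a*x) is one of the factors.
    // Iterate from low to high degree so coef[i-1] already holds the new value.
    void divide(std::int64_t a) {
        a %= MOD;
        for (int i = 1; i < (int)coef.size(); ++i)
            coef[i] = ((coef[i] - coef[i - 1] * a) % MOD + MOD) % MOD;
    }

    // Coefficient of x^K, read in O(1).
    std::int64_t coefficientK() const { return coef.back(); }
};
```

Division works on the truncated coefficients because, if P = Q*(1+a*x), then p_i = q_i + a*q_{i-1}, so q_i = p_i - a*q_{i-1} can be recovered from low degree to high without ever looking past x^K.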

Coming back to the original problem now.

The required probability for a given L can be written as the number of ways to select K pairwise distinct strings of length L (irrespective of order) \times K! (to account for every order of selection), divided by the total number of ways to select K strings of length L (given by (|S|-L+1)^K).

Hence, for a fixed L, if C_L denotes the number of ways to select K distinct substrings of length L irrespective of the order of selection, then the answer for length L is given as \displaystyle\frac{C_L*K!}{(|S|-L+1)^K} (in modular arithmetic). Our task now is to compute C_L for each length L.
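
As a rough sketch of the formula above in modular arithmetic: the function names modPow and probabilityFromCount are assumptions made for this illustration, and the division is done with Fermat's little theorem, which relies on 998244353 being prime.

```cpp
#include <cstdint>

const std::int64_t MOD = 998244353;

// Fast exponentiation: base^exp modulo MOD, for exp >= 0.
std::int64_t modPow(std::int64_t base, std::int64_t exp) {
    std::int64_t result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// c = C_L modulo MOD, m = |S| - L + 1 (number of substrings of length L),
// k = K. Returns C_L * K! / m^K modulo MOD.
std::int64_t probabilityFromCount(std::int64_t c, std::int64_t m, std::int64_t k) {
    std::int64_t fact = 1;
    for (std::int64_t i = 1; i <= k; ++i) fact = fact * i % MOD;
    std::int64_t denom = modPow(m, k);
    // Modular inverse via Fermat's little theorem (MOD is prime).
    return c % MOD * fact % MOD * modPow(denom, MOD - 2) % MOD;
}
```

For instance, with S = "ab", L = 1 and K = 2 we have C_L = 1 and m = 2, so the probability is 1 * 2! / 2^2 = 1/2, which is 499122177 modulo 998244353.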

Let’s iterate over L in decreasing order. For length L, we add the suffix of length L to our DS (equivalent to multiplying by (1+x)). Also, at length L, two suffixes may turn out to have the same first L characters.

For example, consider the string “ababc” and its two suffixes “ababc” and “abc”. As long as L > 2, the two suffixes stay in different groups, but at L = 2 they share the same first L characters (“ab”).
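
To make the grouping concrete, here is a small brute-force checker. It is not part of any reference solution and is far too slow for the real constraints; the function names groupSizes and bruteCount are hypothetical. It groups equal substrings of length L with a map and then multiplies the factors (1 + size*x) directly.

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Sizes of the groups of equal substrings of length L, listed in
// lexicographic order of the substring.
std::vector<int> groupSizes(const std::string& s, int L) {
    std::map<std::string, int> freq;
    for (int i = 0; i + L <= (int)s.size(); ++i) ++freq[s.substr(i, L)];
    std::vector<int> sizes;
    for (const auto& kv : freq) sizes.push_back(kv.second);
    return sizes;
}

// C_L modulo 998244353: the coefficient of x^K in the product of
// (1 + size*x) over all groups of length-L substrings.
std::int64_t bruteCount(const std::string& s, int L, int K) {
    const std::int64_t MOD = 998244353;
    std::vector<std::int64_t> coef(K + 1, 0);
    coef[0] = 1;
    for (int a : groupSizes(s, L))
        for (int i = K; i >= 1; --i)
            coef[i] = (coef[i] + coef[i - 1] * a) % MOD;
    return coef[K];
}
```

For “ababc” and L = 2 the groups are “ab” (size 2), “ba” (size 1) and “bc” (size 1), so with K = 2 the count is the x^2 coefficient of (1+2x)(1+x)(1+x), which is 5.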

This hints towards suffix arrays and LCP arrays. So, let’s build the suffix array and LCP array. Also, let’s maintain the current size of each group using a Disjoint Set Union.

Let’s iterate over length L in decreasing order. For every pair of suffixes adjacent in the suffix array, once L drops to their LCP, we need to merge them into the same group. Suppose the first suffix has group size a and the second suffix has group size b.

At this point, it is required to remove (1+a*x) and (1+b*x) from our DS and add (1+(a+b)*x) into our DS.

That is all we do. We iterate over all lengths L in decreasing order, add (1+x) for the suffix of the current length, merge all adjacent groups with LCP == L, and query the coefficient of x^K for each length, which is the required value of C_L.
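
The whole sweep can be sketched in a few dozen lines. This is an illustration under simplifying assumptions, not the setter's or tester's code: the function name countsPerLength is made up, and the suffix array and LCP values are built naively (quadratic or worse), which is fine for small inputs only. The DSU is indexed by position in the suffix array.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <string>
#include <vector>

const std::int64_t MOD = 998244353;

// Returns C, where C[L] (1 <= L <= n) is C_L modulo MOD.
std::vector<std::int64_t> countsPerLength(const std::string& s, int K) {
    int n = (int)s.size();
    // Naive suffix array: sort suffix start positions by comparing suffixes.
    std::vector<int> sa(n);
    std::iota(sa.begin(), sa.end(), 0);
    std::sort(sa.begin(), sa.end(), [&](int a, int b) {
        return s.compare(a, std::string::npos, s, b, std::string::npos) < 0;
    });
    // byLcp[v] lists ranks i such that LCP(sa[i], sa[i+1]) == v (naive LCP).
    std::vector<std::vector<int>> byLcp(n + 1);
    for (int i = 0; i + 1 < n; ++i) {
        int a = sa[i], b = sa[i + 1], v = 0;
        while (a + v < n && b + v < n && s[a + v] == s[b + v]) ++v;
        byLcp[v].push_back(i);
    }
    // DSU over ranks; groupSize is valid at group roots.
    std::vector<int> parent(n), groupSize(n, 1);
    std::iota(parent.begin(), parent.end(), 0);
    auto find = [&](int x) {
        while (parent[x] != x) x = parent[x] = parent[parent[x]];
        return x;
    };
    // First K+1 coefficients of the current product of (1 + size*x).
    std::vector<std::int64_t> coef(K + 1, 0), C(n + 1, 0);
    coef[0] = 1;
    auto mul = [&](std::int64_t a) {
        for (int i = K; i >= 1; --i) coef[i] = (coef[i] + coef[i - 1] * a) % MOD;
    };
    auto div = [&](std::int64_t a) {
        for (int i = 1; i <= K; ++i)
            coef[i] = ((coef[i] - coef[i - 1] * a) % MOD + MOD) % MOD;
    };
    for (int L = n; L >= 1; --L) {
        mul(1);  // the suffix of length exactly L joins as its own group
        for (int i : byLcp[L]) {
            int x = find(i), y = find(i + 1);
            div(groupSize[x]);
            div(groupSize[y]);
            parent[y] = x;
            groupSize[x] += groupSize[y];
            mul(groupSize[x]);
        }
        C[L] = coef[K];
    }
    return C;
}
```

Each adjacent pair is merged exactly once, and the merged groups always form contiguous ranges of the suffix array, so the two roots found for ranks i and i+1 are never the same group at merge time. For “ababc” with K = 2 this yields C_1 through C_5 equal to 8, 5, 3, 1 and 0.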

Learning resources
Suffix Arrays and LCP array: here and here
Disjoint Set Union

Problem to try
KPRB

After-thought
Can this problem be solved using a suffix automaton or suffix tree directly? Share your approaches.

TIME COMPLEXITY

The time complexity is O(N*K+N*log(MOD)) per test case.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int A = 26;
const int K = 505;
const int N = 200010;
const int MOD = 998244353;

char s[N];
bitset <N << 1> vis;
ll fac[K], dp[K], ans[N];
vector <int> g[N << 1], in[N], out[N];
int len[N << 1], link[N << 1], sz, last;
int t, n, k, cnt[N << 1], to[N << 1][A];

inline void init() {
  len[0] = 0, link[0] = -1, sz = 1, last = 0;
  memset(to[0], -1, sizeof to[0]);
}

void feed (char ch) {
  int cur = sz++, p = last, c = ch - 'a';
  len[cur] = len[last] + 1, link[cur] = 0, cnt[cur] = 1;
  memset(to[cur], -1, sizeof to[cur]);
  while (~p and to[p][c] == -1) to[p][c] = cur, p = link[p];
  if (~p) {
	int q = to[p][c];
	if (len[q] - len[p] - 1) {
	  int r = sz++;
	  len[r] = len[p] + 1, link[r] = link[q];
	  for (int i = 0; i < A; ++i) to[r][i] = to[q][i];
	  while (~p and to[p][c] == q) to[p][c] = r, p = link[p];
	  link[q] = link[cur] = r;
	} else link[cur] = q;
  } last = cur;
}

void go (int u = 0) {
  for (int v : g[u]) go(v), cnt[u] += cnt[v], cnt[u] %= MOD;
}

ll bigMod (ll a, ll e) {
  if (e < 0) e += MOD - 1;
  ll ret = 1;
  while (e) {
	if (e & 1) ret = ret * a % MOD;
	a = a * a % MOD, e >>= 1;
  }
  return ret;
}

void dfs (int u = 0) {
  vis[u] = 1;
  for (int i = 0; i < A; ++i) {
	int v = to[u][i];
	if (v == -1) continue;
	if (!vis[v]) dfs(v);
  }
  if (~link[u]) {
	int l = len[link[u]] + 1, r = len[u];
	in[l].emplace_back(cnt[u]);
	out[r].emplace_back(cnt[u]);
  }
}

int main() {
  fac[0] = 1;
  for (int i = 1; i < K; ++i) fac[i] = i * fac[i - 1] % MOD;
  cin >> t;
  while (t--) {
	scanf("%s %d", s, &k);
	n = strlen(s); init();
	for (int i = 1; i <= n; ++i) {
	  in[i].clear(), out[i].clear();
	}
	for (int i = 0; i < n; ++i) feed(s[i]);
	for (int i = 0; i < sz; ++i) if (~link[i]) {
	  g[link[i]].emplace_back(i);
	}
	go(); dfs();
	dp[0] = 1;
	for (int i = 1; i <= k; ++i) dp[i] = 0;
	for (int i = 1; i <= n; ++i) {
	  for (int s : in[i]) {
	    for (int j = k; j >= 1; --j) {
	      dp[j] += dp[j - 1] * s;
	      dp[j] %= MOD;
	    }
	  }
	  ans[i] = dp[k];
	  if (ans[i] < 0) ans[i] += MOD;
	  for (int s : out[i]) {
	    for (int j = 1; j <= k; ++j) {
	      dp[j] -= dp[j - 1] * s;
	      dp[j] %= MOD;
	    }
	  }
	}
	for (int i = 1; i <= n; ++i) {
	  ll mul = bigMod(n - i + 1, -k) * fac[k] % MOD;
	  ans[i] *= mul, ans[i] %= MOD;
	}
	for (int i = 1; i <= n; ++i) printf("%lld ", ans[i]);
	puts("");
	for (int i = 0; i < sz; ++i) {
	  vis[i] = cnt[i] = 0, g[i].clear();
	}
  }
  return 0;
}
Tester's Solution
//raja1999

//#pragma comment(linker, "/stack:200000000")
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")

#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string> 
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip> 
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16)a; cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 <<endl;prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val

using namespace std;
using namespace __gnu_pbds;
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (998244353)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define int ll
#define endl "\n"

typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;


//std::ios::sync_with_stdio(false);

const int MAXN = 1 << 21;
string s;
int N, gap;
int sa[MAXN], pos[MAXN], tmp[MAXN], lcp[MAXN];

bool sufCmp(int i, int j)
{
	if (pos[i] != pos[j])
		return pos[i] < pos[j];
	i += gap;
	j += gap;
	return (i < N && j < N) ? pos[i] < pos[j] : i > j;
}

void buildSA()
{
	N = s.length();
	int i;
	rep(i, N) sa[i] = i, pos[i] = s[i];
	for (gap = 1;; gap *= 2)
	{
		sort(sa, sa + N, sufCmp);
		rep(i, N - 1) tmp[i + 1] = tmp[i] + sufCmp(sa[i], sa[i + 1]);
		rep(i, N) pos[sa[i]] = tmp[i];
		if (tmp[N - 1] == N - 1) break;
	}
}

void buildLCP()
{	int k;
	for(int i = 0, k = 0; i < N; ++i){
		if (pos[i] != N - 1)
		{
			for(int j = sa[pos[i]+1];i+k < N && j+k < N && s[i + k] == s[j + k];)
				++k;
			lcp[pos[i]] = k;
			if (k)--k;
		}
		else{
			k=0;
		}
	
	} 
}

int k;
int ans[200005],new_coef[505],coef[505];
stack<int>st;
int ns[200005];
vector<vi> add(200005),divi(200005);
int power(int a,int b){
	int res=1;
	while(b>0){
		if(b%2){
			res*=a;
			res%=mod;
		}
		b/=2;
		a*=a;
		a%=mod;
	}
	return res;
}

int multiply(int c){
	int i; 
	fd(i,k,1){
		coef[i]=coef[i]+c*coef[i-1];
		coef[i]%=mod;
	}
	return 0;
}

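int divide(int c){
	int i;
	f(i,1,k+1){
		coef[i]=(coef[i]-c*coef[i-1]);
		coef[i]%=mod;
	}
	return 0; // non-void function must return a value
}
int range_update(int l,int r,int c){
	if(l>r){
		return 0;
	}
	add[l].pb(c);
	divi[r].pb(c);
	return 0; // non-void function must return a value
}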

signed main(){ // explicit return type; plain "int" is #defined to ll above
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	double clk,tim=0,tim_1=0,tim_2=0;
	cin>>t;
	while(t--){
		cin>>s>>k;
		int n,i,p,prev,l,cur,num,den,fact=1,temp,j;
		n=s.length();
		// n=1
		if(n==1){
			if(k==1){
				cout<<1<<endl;
			}
			else{
				cout<<0<<endl;
			}
			continue;
		}

		// construct suffix array
		clk=clock();
		buildSA();
		buildLCP();
		lcp[n-1]=0;
		tim+=(clock()-clk)/CLOCKS_PER_SEC;
		// next smaller elements
		st.push(0);
		i=1;
		while(i<n){
			if(!st.empty()){
				p=st.top();
			}
			while(!st.empty() && lcp[p]>lcp[i]){
				ns[p]=i;
				st.pop();
				if(!st.empty()){
					p=st.top();
				}
			}
			st.push(i);
			i++;
		}
		while(!st.empty()){
			p=st.top();
			st.pop();
			ns[p]=n;
		}

		clk=clock();
		prev=0;
		rep(i,n-1){
			l=(n-sa[i]);
			//update
			range_update(max(lcp[i]+1,prev+1),l,1);
			l=lcp[i];
			cur=i;
			while(l>prev){
				temp=ns[cur];
				//update
				range_update(max(lcp[temp]+1,prev+1),l,(temp-i+1));
				l=lcp[temp];
				cur=temp;
			}
			prev=lcp[i];
		}
		// i= n-1
		l=(n-sa[i]);
		// update
		range_update(lcp[i-1]+1,l,1);
		tim_1+=(clock()-clk)/CLOCKS_PER_SEC;

		clk=clock();
		coef[0]=1;
		f(i,1,n+1){
			rep(j,add[i].size()){
				multiply(add[i][j]);
			}
			ans[i]=coef[k];
			if(ans[i]<0){
				ans[i]+=mod;
			}
			rep(j,divi[i].size()){
				divide(divi[i][j]);
			}
			add[i].clear();
			divi[i].clear();
		}
		tim_2+=(clock()-clk)/CLOCKS_PER_SEC;
		f(i,1,k+1){
			fact*=i;
			fact%=mod;
		}
		f(i,1,n+1){
			num=ans[i];
			num*=fact;
			num%=mod;
			den=power(n-i+1,k);
			num*=power(den,mod-2);
			num%=mod;
			cout<<num<<" ";
		}
		cout<<endl;
	}
	cerr<<tim<<" "<<tim_1<<" "<<tim_2<<endl;
	return 0;
} 
Editorialist's Solution (TLEs but clear to read)
import java.util.*;
import java.io.*;
import java.util.stream.IntStream;
class PRDRAW{
	//SOLUTION BEGIN
	long MOD = 998244353;
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    String s = n();
	    int N = s.length();
	    int K = ni();
	    int[] sa = suffixArray(s), lcp = lcp(s, sa);
	    int[][] P = new int[N-1][];
	    for(int i = 0; i< N-1; i++)P[i] = new int[]{i, lcp[i]};
	    Arrays.sort(P, (int[] i1, int[] i2) -> Integer.compare(i2[1], i1[1]));
	    long[] ans = new long[1+N];
	    int ptr = 0;
	    int[] set = new int[N], sz = new int[N];
	    for(int i = 0; i< N; i++){set[i] = i;sz[i] = 1;}
	    
	    long[] ways = new long[1+K];
	    ways[0] = 1;
	    for(int len = N; len >= 1; len--){
	        addToDS(ways, K, 1);
	        while(ptr < N-1 && P[ptr][1] == len){
	            int idx = P[ptr][0];
	            removeFromDS(ways, K, sz[find(set, idx)]);
	            removeFromDS(ways, K, sz[find(set, idx+1)]);
	            sz[find(set, idx)] += sz[find(set, idx+1)];
	            set[find(set, idx+1)] = find(set, idx);
	            addToDS(ways, K, sz[find(set, idx)]);
	            ptr++;
	        }
	        ans[len] = ways[K];
	    }
	    long fact = 1;
	    for(int i = 1; i<= K; i++)fact = (fact*i)%MOD;
	    for(int len = 1; len <= N; len++){
	        ans[len] = (ans[len]*fact)%MOD;
	        ans[len] = (ans[len]*pow(pow(N-len+1, K), MOD-2))%MOD;
	    }
	    for(int len = 1; len <= N; len++)p(ans[len]+" ");pn("");
	}
	long pow(long a, long p){
	    long o = 1;
	    for(;p>0;p>>=1){
	        if((p&1)==1)o = (o*a)%MOD;
	        a = (a*a)%MOD;
	    }
	    return o;
	}
	void addToDS(long[] ways, int K, long x){
	    for(int i = K; i>= 1; i--)
	        ways[i] = (ways[i]+ways[i-1]*x)%MOD;
	}
	void removeFromDS(long[] ways, int K, long x){
	    for(int i = 1; i<= K; i++)
	        ways[i] = (ways[i]+MOD-(ways[i-1]*x)%MOD)%MOD;
	}
	int find(int[] set, int i){return set[i] = (set[i] == i)?i:find(set, set[i]);}
	//http://code-library.herokuapp.com/suffix-array/java
	public static int[] suffixArray(CharSequence S) {
	    int n = S.length();

	    // stable sort of characters
	    int[] sa = IntStream.range(0, n).mapToObj(i -> n - 1 - i).
	            sorted((a, b) -> Character.compare(S.charAt(a), S.charAt(b))).mapToInt(Integer::intValue).toArray();

	    int[] classes = S.chars().toArray();
	    // sa[i] - suffix on i'th position after sorting by first len characters
	    // classes[i] - equivalence class of the i'th suffix after sorting by first len characters

	    for (int len = 1; len < n; len *= 2) {
	        int[] c = classes.clone();
	        for (int i = 0; i < n; i++) {
	            // condition sa[i - 1] + len < n simulates 0-symbol at the end of the string
	            // a separate class is created for each suffix followed by simulated 0-symbol
	            classes[sa[i]] = i > 0 && c[sa[i - 1]] == c[sa[i]] && sa[i - 1] + len < n && c[sa[i - 1] + len / 2] == c[sa[i] + len / 2] ? classes[sa[i - 1]] : i;
	        }
	        // Suffixes are already sorted by first len characters
	        // Now sort suffixes by first len * 2 characters
	        int[] cnt = IntStream.range(0, n).toArray();
	        int[] s = sa.clone();
	        for (int i = 0; i < n; i++) {
	            // s[i] - order of suffixes sorted by first len characters
	            // (s[i] - len) - order of suffixes sorted only by second len characters
	            int s1 = s[i] - len;
	            // sort only suffixes of length > len, others are already sorted
	            if (s1 >= 0)
	                sa[cnt[classes[s1]]++] = s1;
	        }
	    }
	    return sa;
	}
	class Suffix implements Comparable<Suffix>{
	    int index, rank, next;
	    public Suffix(int ind, int r, int nr){
	        index = ind; rank = r; next = nr;
	    }
	    public int compareTo(Suffix s){
	        if(rank != s.rank)return Integer.compare(rank, s.rank);
	        return Integer.compare(next, s.next);
	    }
	}
	int[] lcp(String s, int[] sa){
	    int n = sa.length;
	    int[] lcp = new int[n];
	    int[] invSuf = new int[n];
	    for(int i = 0; i< n; i++)invSuf[sa[i]] = i;
	    int k = 0;
	    for(int i = 0; i< n; i++){
	        if(invSuf[i] == n-1){k = 0;continue;}
	        int j = sa[invSuf[i]+1];
	        while(i+k < n && j+k < n && s.charAt(i+k) == s.charAt(j+k))k++;
	        lcp[invSuf[i]] = k;
	        if(k > 0)k--;
	    }
	    return lcp;
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new PRDRAW().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome, as always. :slight_smile:

Somehow, the link is not working.
It’s showing - “Problem is not visible right now. Please try again later”

That’s a problem from ACM ICPC Kanpur Regionals 2019-20.
P.S. According to my memory, only 6 teams could solve it then.

No. The problem from Kanpur Regional asks for the probability that all strings drawn are the same, which is easy: just add occ^K over a range [l, r]. But this problem asks for the probability that the strings are all different.

That was in reply to topcoder31’s comment and the problem I was talking about was KTHPROB.

Oh, I had typed the link wrong.

Corrected now. You should be able to access it.

Yes, that’s correct :slight_smile: