SUMRANGEPOW - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: arvindf232
Testers: iceknight1093, rivalq
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Segment trees with lazy propagation

PROBLEM:

You have an array A of length N and an integer K.
Let f(l, r) denote the number of distinct integers in the subarray [A_l, A_{l+1}, \ldots, A_r].

Find the following sum, modulo 998244353:

\sum_{l=1}^N \sum_{r=l}^N f^K(l, r)
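For checking faster solutions on small inputs, the sum can also be computed directly. The following is a minimal O(N^2 K) C++ sketch. The function name bruteSum is our own, and it assumes (as the tester's input checker below does) that every A_i lies in [1, N]. Like the reference solutions, it reduces modulo 998244353.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const int64_t MOD = 998244353;

// Brute force: for each l, extend r to the right while maintaining f(l, r)
// incrementally with a "seen" table. Values of A are assumed to lie in [1, N].
int64_t bruteSum(const std::vector<int>& A, int K) {
    int n = A.size();
    int64_t total = 0;
    for (int l = 0; l < n; ++l) {
        std::vector<bool> seen(n + 1, false);
        int64_t distinct = 0;                       // f(l, r) for the current r
        for (int r = l; r < n; ++r) {
            assert(1 <= A[r] && A[r] <= n);
            if (!seen[A[r]]) { seen[A[r]] = true; ++distinct; }
            int64_t p = 1;                          // distinct^K mod MOD
            for (int e = 0; e < K; ++e) p = p * distinct % MOD;
            total = (total + p) % MOD;
        }
    }
    return total;
}
```

For A = [1, 2, 1] the six values of f are 1, 2, 2, 1, 2, 1, so the result is 9 for K = 1 and 15 for K = 2.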

EXPLANATION:

First, let’s try to solve this problem for K = 1, and then generalize that solution to higher K.

Solving K = 1

Our task is to compute the sum of the number of distinct elements across all subarrays.
Counting the distinct elements of a single range is easy in \mathcal{O}(N), and even answering Q queries about different ranges is doable in \mathcal{O}((N+Q)\log N) (see here for several ways to do it, both online and offline). However, there are N(N+1)/2 subarrays, so querying each of them separately is still too slow.

Instead, we can slightly modify one of the offline approaches to get what we want.

Let’s fix r, and then let D be an array of size N such that D_l = f(l, r) (for l > r the subarray is empty, so D_l = 0).
If we know D for r, we can quickly update it for r+1 using the following observations:

  • If A_{r+1} occurs in the subarray A[l, r], then f(l, r+1) = f(l, r); i.e., there’s no need to change D_l.
  • Otherwise, f(l, r+1) = f(l, r) + 1; i.e., D_l increases by exactly 1.
  • Further, notice that if D_l remains unchanged, so does D_{l-1}; and if D_l increases by 1, so does D_{l+1}.

Essentially, this induces a range update on D: a certain range of values is increased by 1, and everything else remains unchanged.
Finding which range increases isn’t too hard either: if p is the last index \leq r with A_p = A_{r+1} (or p = 0 if there is none), then exactly D_{p+1}, \ldots, D_{r+1} increase by 1.
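The range claim can be sanity-checked by brute force. The hypothetical helpers below use 0-based indices, so the "previous occurrence" p is -1 when there is none. They recompute f naively and verify that appending A[r] increases f(l, ·) exactly for l > p and leaves it unchanged for l <= p. Values are assumed to lie in [1, N].

```cpp
#include <cassert>
#include <vector>

// f(l, r) computed naively: distinct values in A[l..r] (0-based, inclusive).
// Returns 0 when l > r (empty range). Values are assumed to lie in [1, A.size()].
int distinctCount(const std::vector<int>& A, int l, int r) {
    std::vector<bool> seen(A.size() + 1, false);
    int cnt = 0;
    for (int i = l; i <= r; ++i) {
        assert(1 <= A[i] && A[i] <= (int)A.size());
        if (!seen[A[i]]) { seen[A[i]] = true; ++cnt; }
    }
    return cnt;
}

// For every r, checks that f(l, r) = f(l, r - 1) + 1 exactly when l > p,
// where p is the previous index holding A[r] (-1 if none), and that
// f(l, r) = f(l, r - 1) for l <= p.
bool rangeClaimHolds(const std::vector<int>& A) {
    int n = A.size();
    std::vector<int> last(n + 1, -1);   // last[v] = latest index seen with value v
    for (int r = 0; r < n; ++r) {
        int p = last[A[r]];
        for (int l = 0; l <= r; ++l) {
            int expected = distinctCount(A, l, r - 1) + (l > p ? 1 : 0);
            if (distinctCount(A, l, r) != expected) return false;
        }
        last[A[r]] = r;
    }
    return true;
}
```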

In particular, we have the following algorithm:

  • Let D be as described above, initially containing all zeros.
  • We’ll consider r from 1 to N.
  • At r, let p be the position of the previous occurrence of A_r (or p = 0 if A_r has not appeared before).
  • Then, add 1 to the range [p+1, r] of D.
  • Finally, add D_1 + D_2 + \ldots + D_r to the answer.
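Before bringing in a data structure, here is a sketch of the sweep above transcribed literally, with D stored as a plain array. It runs in O(N^2), uses 1-based indices like the text, assumes A_i ∈ [1, N], and skips the modulus, so it is only meant for small inputs. The name sweepK1 is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Literal version of the K = 1 sweep: range add on D, then prefix sum of D.
int64_t sweepK1(const std::vector<int>& A) {
    int n = A.size();
    std::vector<int64_t> D(n + 1, 0);      // D[l] = f(l, r) for the current r
    std::vector<int> prevPos(n + 1, 0);    // last position of each value, 0 if unseen
    int64_t answer = 0;
    for (int r = 1; r <= n; ++r) {
        int v = A[r - 1];
        assert(1 <= v && v <= n);
        int p = prevPos[v];
        for (int l = p + 1; l <= r; ++l) D[l] += 1;   // add 1 on [p+1, r]
        for (int l = 1; l <= r; ++l) answer += D[l];  // D_1 + ... + D_r
        prevPos[v] = r;
    }
    return answer;
}
```

For A = [1, 2, 1], D evolves as [1], [2, 1], [2, 2, 1], and the answer is 1 + 3 + 5 = 9.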

To do this quickly, we need a data structure that supports adding a value to every element of a range and finding the sum of the values in a range.
This can be done using a segment tree with lazy propagation.
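A minimal recursive lazy segment tree for exactly these two operations might look like the sketch below. It is not the setter's or testers' template, and the struct and function names are hypothetical. It uses plain 64-bit sums without a modulus. That is enough for K = 1 even at N = 10^5, since the answer is at most N(N+1)(N+2)/6, roughly 1.7 * 10^14.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Lazy segment tree over positions 1..n supporting
//   add(l, r, x):  D_i += x for all l <= i <= r
//   query(l, r):   D_l + ... + D_r
struct RangeAddRangeSum {
    int n;
    std::vector<int64_t> sum, lazy;   // lazy = increment not yet pushed to children
    explicit RangeAddRangeSum(int n) : n(n), sum(4 * n, 0), lazy(4 * n, 0) {}

    void add(int l, int r, int64_t x) { add(1, 1, n, l, r, x); }
    int64_t query(int l, int r) { return query(1, 1, n, l, r); }

private:
    // Apply "+x to every position" to a node covering len positions.
    void applyNode(int node, int len, int64_t x) {
        sum[node] += x * len;
        lazy[node] += x;
    }
    void push(int node, int lo, int hi) {
        if (lazy[node] == 0) return;
        int mid = (lo + hi) / 2;
        applyNode(2 * node, mid - lo + 1, lazy[node]);
        applyNode(2 * node + 1, hi - mid, lazy[node]);
        lazy[node] = 0;
    }
    void add(int node, int lo, int hi, int l, int r, int64_t x) {
        if (r < lo || hi < l) return;
        if (l <= lo && hi <= r) { applyNode(node, hi - lo + 1, x); return; }
        push(node, lo, hi);
        int mid = (lo + hi) / 2;
        add(2 * node, lo, mid, l, r, x);
        add(2 * node + 1, mid + 1, hi, l, r, x);
        sum[node] = sum[2 * node] + sum[2 * node + 1];
    }
    int64_t query(int node, int lo, int hi, int l, int r) {
        if (r < lo || hi < l) return 0;
        if (l <= lo && hi <= r) return sum[node];
        push(node, lo, hi);
        int mid = (lo + hi) / 2;
        return query(2 * node, lo, mid, l, r) + query(2 * node + 1, mid + 1, hi, l, r);
    }
};

// O(N log N) K = 1 sweep; A_i is assumed to lie in [1, N].
int64_t sumDistinctK1(const std::vector<int>& A) {
    int n = A.size();
    assert(n >= 1);
    RangeAddRangeSum D(n);
    std::vector<int> prevPos(n + 1, 0);          // 0 = value not seen yet
    int64_t answer = 0;
    for (int r = 1; r <= n; ++r) {
        int v = A[r - 1];
        D.add(prevPos[v] + 1, r, 1);
        answer += D.query(1, r);
        prevPos[v] = r;
    }
    return answer;
}
```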

We do N updates and N queries, so K = 1 is solved in \mathcal{O}(N\log N).

Solving K \gt 1

Now, we need to generalize this solution to higher K.

Let’s do almost exactly the same thing, except that D_l is still f(l, r) but at each r we now add D_1^K + D_2^K + \ldots + D_r^K to the answer.

Our updates and queries are also exactly the same; the only difference is in how a range increment is applied to a segment tree node (and propagated down the tree).
In particular, given the sum of K-th powers of several values, we need to update it when all those values are increased by x.

Suppose we know \sum_i a_i^K, and we want to compute \sum_i (a_i + x)^K. That can be done as follows:

\sum_i (a_i + x)^K = \sum_i \sum_{j=0}^K \binom{K}{j} a_i^j x^{K-j} \\ = \sum_{j=0}^K \binom{K}{j}x^{K-j}\sum_i a_i^j

Notice that the inner sum \sum_i a_i^j is simply the sum of j-th powers of all the a_i, and does not depend on x.
If we maintain these sums for every j \leq K, the update can be done in \mathcal{O}(K).
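As a quick numeric check of the identity (values chosen arbitrarily), take a = (1, 2), x = 2 and K = 2, so that \sum_i a_i^0 = 2, \sum_i a_i^1 = 3 and \sum_i a_i^2 = 5:

\sum_i (a_i + 2)^2 = 3^2 + 4^2 = 25 = \binom{2}{0}2^2\cdot 2 + \binom{2}{1}2^1\cdot 3 + \binom{2}{2}2^0\cdot 5 = 8 + 12 + 5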

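So, in each node of the segtree we keep the sum of the j-th powers of the values in it, for each 0 \leq j \leq K (for j = 0, this is simply the number of positions the node covers).
When pushing an update, update all of these K+1 sums as described above; this takes \mathcal{O}(K^2) time, \mathcal{O}(K) for each. Since the new j-th sum depends on the old sums for all indices up to j, either compute the new sums from j = K down to j = 1 (the j = 0 sum never changes), or write them into a separate array first.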
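The per-node update can be sketched as a standalone C++ function. This sketch assumes K ≤ 5 (the bound the tester's checker enforces) and the modulus 998244353; shiftPowerSums, makeBinom and PowerSums are hypothetical names. Rewriting S[k] from the top down is what allows the update to be done in place.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int64_t MOD = 998244353;
constexpr int KMAX = 5;                              // assumed upper bound on K
using PowerSums = std::array<int64_t, KMAX + 1>;     // S[j] = sum_i a_i^j mod MOD

// Pascal's triangle: C[k][j] = binom(k, j) mod MOD for 0 <= j <= k <= KMAX.
std::array<PowerSums, KMAX + 1> makeBinom() {
    std::array<PowerSums, KMAX + 1> C{};
    for (int k = 0; k <= KMAX; ++k) {
        C[k][0] = 1;
        for (int j = 1; j <= k; ++j) C[k][j] = (C[k - 1][j - 1] + C[k - 1][j]) % MOD;
    }
    return C;
}

// Replaces the power sums of {a_i} by those of {a_i + x} in O(KMAX^2), in place.
// S[k] is rewritten from k = KMAX down to 1, so every S[j] with j < k that the
// formula reads still holds its old value; S[0] (the count) never changes.
void shiftPowerSums(PowerSums& S, int64_t x) {
    static const std::array<PowerSums, KMAX + 1> C = makeBinom();
    assert(x >= 0);
    PowerSums xp{};                                  // xp[e] = x^e mod MOD
    xp[0] = 1;
    for (int e = 1; e <= KMAX; ++e) xp[e] = xp[e - 1] * (x % MOD) % MOD;
    for (int k = KMAX; k >= 1; --k) {
        int64_t cur = 0;
        for (int j = 0; j <= k; ++j)
            cur = (cur + C[k][j] * xp[k - j] % MOD * S[j]) % MOD;
        S[k] = cur;
    }
}
```

For example, starting from the power sums of {1, 2} (2, 3, 5, 9, 17, 33) and x = 2, the result is the power sums of {3, 4}: 2, 7, 25, 91, 337, 1267.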

This gives us a solution in \mathcal{O}(K^2N\log N).
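Putting the pieces together, a compact version of the whole algorithm could look like the sketch below. It is written for readability: the names are hypothetical, the tree is recursive, and it assumes K ≤ 5, A_i ∈ [1, N] and the modulus 998244353. The setter's, tester's and editorialist's codes below remain the actual reference solutions. As in those codes, the answer is read from the root; this is valid because positions l > r still hold D_l = 0 and contribute nothing for K ≥ 1.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t MOD = 998244353;
constexpr int KMAX = 5;                            // assumed upper bound on K
using PowerSums = std::array<int64_t, KMAX + 1>;   // [j] = sum of D_l^j mod MOD

// Lazy segment tree over positions 1..n; each node stores the power sums of the
// D_l it covers. A lazy tag is an increment already applied to the node itself
// but not yet to its children.
struct PowerSumSegTree {
    int n;
    std::vector<PowerSums> S;
    std::vector<int64_t> lazy;
    int64_t C[KMAX + 1][KMAX + 1] = {};           // binomial coefficients

    explicit PowerSumSegTree(int n) : n(n), S(4 * n), lazy(4 * n, 0) {
        for (int k = 0; k <= KMAX; ++k) {
            C[k][0] = 1;
            for (int j = 1; j <= k; ++j) C[k][j] = (C[k - 1][j - 1] + C[k - 1][j]) % MOD;
        }
        build(1, 1, n);
    }
    // Every D_l starts at 0, so only the 0-th power sum (0^0 = 1) is non-zero.
    void build(int node, int lo, int hi) {
        S[node][0] = hi - lo + 1;
        if (lo == hi) return;
        int mid = (lo + hi) / 2;
        build(2 * node, lo, mid);
        build(2 * node + 1, mid + 1, hi);
    }
    // Adds x to every D_l under node: S_k <- sum_j C(k, j) x^(k-j) S_j, for k = KMAX..1.
    void applyNode(int node, int64_t x) {
        int64_t xp[KMAX + 1];
        xp[0] = 1;
        for (int e = 1; e <= KMAX; ++e) xp[e] = xp[e - 1] * (x % MOD) % MOD;
        for (int k = KMAX; k >= 1; --k) {
            int64_t cur = 0;
            for (int j = 0; j <= k; ++j)
                cur = (cur + C[k][j] * xp[k - j] % MOD * S[node][j]) % MOD;
            S[node][k] = cur;
        }
        lazy[node] += x;
    }
    void push(int node) {
        if (lazy[node] == 0) return;
        applyNode(2 * node, lazy[node]);
        applyNode(2 * node + 1, lazy[node]);
        lazy[node] = 0;
    }
    void add(int node, int lo, int hi, int l, int r, int64_t x) {
        if (r < lo || hi < l) return;
        if (l <= lo && hi <= r) { applyNode(node, x); return; }
        push(node);
        int mid = (lo + hi) / 2;
        add(2 * node, lo, mid, l, r, x);
        add(2 * node + 1, mid + 1, hi, l, r, x);
        for (int j = 0; j <= KMAX; ++j)
            S[node][j] = (S[2 * node][j] + S[2 * node + 1][j]) % MOD;
    }
};

// Sum of f(l, r)^K over all subarrays, mod MOD. Assumes 1 <= K <= KMAX and A_i in [1, N].
int64_t sumRangePow(const std::vector<int>& A, int K) {
    int n = A.size();
    assert(n >= 1 && 1 <= K && K <= KMAX);
    PowerSumSegTree tree(n);
    std::vector<int> prevPos(n + 1, 0);            // 0 = value not seen yet
    int64_t answer = 0;
    for (int r = 1; r <= n; ++r) {
        int v = A[r - 1];
        assert(1 <= v && v <= n);
        tree.add(1, 1, n, prevPos[v] + 1, r, 1);    // D_l += 1 for l in [p+1, r]
        answer = (answer + tree.S[1][K]) % MOD;     // root: D_1^K + ... + D_n^K
        prevPos[v] = r;
    }
    return answer;
}
```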

TIME COMPLEXITY:

\mathcal{O}(K^2 N\log N) per testcase.

CODE:

Setter's code (Kotlin)
// 2022.12.26 at 14:48:03 HKT
import java.io.BufferedInputStream
import java.io.File
import java.io.PrintWriter
import kotlin.system.measureTimeMillis
import java.util.TreeMap
import java.util.TreeSet
import kotlin.random.Random
import kotlin.random.nextInt
 
// 1. Modded
const val p = 998244353L
const val pI = p.toInt()
fun Int.adjust():Int{ if(this >= pI){ return this  - pI }else if (this < 0){ return this + pI };return this }
fun Int.snap():Int{ if(this >= pI){return this - pI} else return this}
infix fun Int.mm(b:Int):Int{ return ((this.toLong() * b) % pI).toInt() }
infix fun Int.mp(b:Int):Int{ val ans = this + b;return if(ans >= pI) ans - pI else ans }
infix fun Int.ms(b:Int):Int{ val ans = this - b;return if(ans < 0) ans + pI else ans }
fun Int.inverse():Int = intPow(this,pI-2,pI)
infix fun Int.modDivide(b:Int):Int{ return this mm (b.inverse()) }
fun intPow(x:Int,e:Int,m:Int):Int{
    var X = x ; var E =e ; var Y = 1
    while(E > 0){
        if(E and 1 == 0){
            X = ((1L * X * X) % m).toInt()
            E = E shr 1
        }else{
            Y = ((1L * X * Y) % m).toInt()
            E -= 1
        }
    }
    return Y
}
// 2. DP initial values
const val plarge = 1_000_000_727
const val nlarge = -plarge
const val phuge = 2_727_000_000_000_000_000L
const val nhuge = -phuge
// 3. convenience conversions
val Boolean.chi:Int get() = if(this) 1 else 0 //characteristic function
val BooleanArray.chiarray:IntArray get() = IntArray(this.size){this[it].chi}
val Char.code :Int get() = this.toInt() -  'a'.toInt()
//3. hard to write stuff
fun IntArray.put(i:Int,v:Int){ this[i] = (this[i] + v).adjust() }
val mint:MutableList<Int> get() = mutableListOf<Int>()
val mong:MutableList<Long> get() = mutableListOf<Long>()
//4. more outputs
fun List<Char>.conca():String = this.joinToString("")
val CharArray.conca :String get() = this.joinToString("")
val IntArray.conca :String get() = this.joinToString(" ")
@JvmName("concaInt")
fun List<Int>.conca():String = this.joinToString(" ")
val LongArray.conca:String get() = this.joinToString(" ")
@JvmName("concaLong")
fun List<Long>.conca():String = this.joinToString(" ")
//5. Pair of ints
const val longmask = (1L shl 32) - 1
fun makepair(a:Int, b:Int):Long = (a.toLong() shl 32) xor (longmask and b.toLong())
val Long.first get() = (this ushr 32).toInt()
val Long.second get() = this.toInt()
//6. strings
val String.size get() = this.length
const val randCount = 100
//7. bits
fun Int.has(i:Int):Boolean = (this and (1 shl i) != 0)
fun Long.has(i:Int):Boolean = (this and (1L shl i) != 0L)
//8 TIME
inline fun TIME(f:()->Unit){
    val t = measureTimeMillis(){
        f()
    }
    println("$t ms")
}
//9.ordered pair
fun order(a:Int, b:Int):Pair<Int,Int>{
    return Pair(minOf(a,b), maxOf(a,b))
}
//10 rand
fun rand(x:Int) = Random.nextInt(x)
fun rand(x:IntRange) = Random.nextInt(x)
const val interactive = false
object Reader{
    private const val BS = 1 shl 16
    private const val NC = 0.toChar()
    private val buf = ByteArray(BS)
    private var bId = 0
    private var size = 0
    private var c = NC
 
    var warningActive = true
    var fakein = StringBuilder()
 
    private var IN: BufferedInputStream = BufferedInputStream(System.`in`, BS)
    val OUT: PrintWriter = PrintWriter(System.out)
 
    private val char: Char
        get() {
            if(interactive){
                return System.`in`.read().toChar()
            }
            while (bId == size) {
                size = IN.read(buf) // no need for checked exceptions
                if (size == -1) return NC
                bId = 0
            }
            return buf[bId++].toChar()
        }
 
    fun nextInt(): Int {
        var neg = false
        if (c == NC) c = char
        while (c < '0' || c > '9') {
            if (c == '-') neg = true
            c = char
        }
        var res = 0
        while (c in '0'..'9') {
            res = (res shl 3) + (res shl 1) + (c - '0')
            c = char
        }
        return if (neg) -res else res
    }
    fun nextLong(): Long {
        var neg = false
        if (c == NC) c = char
        while (c < '0' || c > '9') {
            if (c == '-') neg = true
            c = char
        }
        var res = 0L
        while (c in '0'..'9') {
            res = (res shl 3) + (res shl 1) + (c - '0')
            c = char
        }
        return if (neg) -res else res
    }
    fun nextString():String{
        val ret = StringBuilder()
        while (true){
            c = char
            if(!isWhitespace(c)){ break}
        }
        ret.append(c)
        while (true){
            c = char
            if(isWhitespace(c)){ break}
            ret.append(c)
        }
        return ret.toString()
    }
    fun isWhitespace(c:Char):Boolean{
        return c == ' ' || c == '\n' || c == '\r' || c == '\t'
    }
    fun rerouteInput(){
        if(warningActive){
            put("Custom test enabled")
            println("Custom test enabled")
            warningActive = false
        }
        val S = fakein.toString()
        println("New Case ")
        println(S.take(80))
        println("...")
        fakein.clear()
        IN = BufferedInputStream(S.byteInputStream(),BS)
    }
    fun flush(){
        OUT.flush()
    }
    fun takeFile(name:String){
        IN = BufferedInputStream(File(name).inputStream(),BS)
    }
}
fun eat(){ val st1 = TreeSet<Int>(); val st2 = TreeMap<Int,Int>()}
fun put(aa:Any){
    Reader.OUT.println(aa)
    if(interactive){ Reader.flush()}
}
fun done(){ Reader.OUT.close() }
fun share(aa:Any){
    if(aa is IntArray){Reader.fakein.append(aa.joinToString(" "))}
    else if(aa is LongArray){Reader.fakein.append(aa.joinToString(" "))}
    else if(aa is List<*>){Reader.fakein.append(aa.toString())}
    else{Reader.fakein.append(aa.toString())}
    Reader.fakein.append("\n")
}
 
val getintfast:Int get() = Reader.nextInt()
val getint:Int get(){ val ans = getlong ; if(ans > Int.MAX_VALUE) IntArray(1000000000); return ans.toInt() }
val getlong:Long get() = Reader.nextLong()
val getstr:String get() = Reader.nextString()
fun getline(n:Int):IntArray{
    return IntArray(n){getint}
}
fun getlineL(n:Int):LongArray{
    return LongArray(n){getlong}
}
var dmark = -1
infix fun Any.dei(a:Any){
    dmark++
    var str = "<${dmark}>   "
    debug()
    if(this is String){ str += this
    }else if(this is Int){ str += this.toString()
    }else if(this is Long){ str += this.toString()
    }else{ str += this.toString()}
    if(a is List<*>){ println("$str : ${a.joinToString(" ")}")
    }else if(a is IntArray){ println("$str : ${a.joinToString(" ")}")
    }else if(a is LongArray){ println("$str : ${a.joinToString(" ")}")
    }else if(a is BooleanArray){ println("$str :${a.map{if(it)'1' else '0'}.joinToString(" ")}")
    }else if(a is Array<*>){
        println("$str : ")
        for(c in a){if(c is IntArray){println(c.joinToString(" "))}
        else if(c is LongArray){println(c.joinToString(" "))}
        else if(c is BooleanArray){println(c.map { if(it) '1' else '0' }.joinToString(""))
        }
        }
        println()
    }else{ println("$str : $a")
    }
}
const val just = " "
fun crash(){
    throw Exception("Bad programme")}
fun assert(a:Boolean){
    if(!a){
        throw Exception("Failed Assertion")
    }}
enum class solveMode {
    real, rand, tc
}
object solve{
    var mode:solveMode = solveMode.real
    var tcNum:Int = 0
    var rand:()->Unit = {}
    var TC:MutableMap<Int,()->Unit> = mutableMapOf()
    var tn:Long = 0
    fun cases(onecase:()->Unit){
        val t = if(mode == solveMode.real){if(singleCase) 1 else getint} else if(mode == solveMode.tc){1 } else randCount
        if(pI != 998_244_353 && pI != 1_000_000_007){
            throw Exception("Not usual primes!")
        }
        if(t == 1 && mode != solveMode.real){
            tn = System.currentTimeMillis()
        }
        repeat(t){
            if(mode == solveMode.tc){
                TC[tcNum]?.let { it() }
                Reader.rerouteInput()
            }else if(mode == solveMode.rand){
                rand()
                Reader.rerouteInput()
            }
            onecase()
        }
        if(t == 1 && mode != solveMode.real){
            val dt = System.currentTimeMillis() - tn
            println("Time $dt ms ")
        }
    }
    inline fun singleCase(a:solve.()->Unit){
        val t = if(mode != solveMode.rand){1} else randCount
        repeat(t) { a() }
    }
    fun rand(a:()->Unit){
        this.rand = a
    }
    fun tc(id:Int = 0,a:()->Unit){
        TC[id] = a
    }
    fun usetc(a:Int = 0 ){
        this.tcNum = a
        this.mode = solveMode.tc
    }
    fun userand(){
        this.mode = solveMode.rand
    }
}
fun debug(){}
 
val binom = Array(100300){IntArray(7)}
val pow = Array(100300){IntArray(7)}
fun choose(n:Int, r:Int):Int{
    return binom[n-r][r]
}
 
fun IntArray.shiftright(a:Int){
    //if arr[i] represents sum of x^i, then the results represents sum of (x+a)^i
    if(a == 0) return
    val new = IntArray(this.size)
    val k = this.lastIndex
    for(start in k downTo 0){
        for(end in start downTo 0){
            new[start] = new[start] mp (this[end] mm choose(start,end) mm pow[a][start-end])
        }
    }
    new.copyInto(this)
}
 
class lazyitem(var lazysum:Int = 0, val sum:IntArray){
    companion object{
        const val k = 5
        val empty:lazyitem get() = lazyitem(0, IntArray(k+1))
        fun construct(v:Int):lazyitem{
            assert(v == 0 )
            val have = IntArray(k+1)
            have[0] = 1
            return lazyitem(v,have)
        }
    }
    fun addlazy(v:Int){
        // can be called on bottom level nodes. Doesn't matter
        this.sum.shiftright(v)
        this.lazysum += v
    }
    fun push(left:lazyitem, right:lazyitem){
        //Actually, pushing is unnecessary in this problem. This function is never called
        left.addlazy(this.lazysum)
        right.addlazy(this.lazysum)
        this.lazysum = 0
    }
    fun mergewrite(left:lazyitem, right:lazyitem){
        //use lazy items from self
        for(i in 0..k){
            this.sum[i] = left.sum[i] mp right.sum[i]
        }
        this.sum.shiftright(this.lazysum)
    }
 
    fun debug():String{
        return "${sum[0]}"
    }
}
 
 
const val propagateOnly = false
class lazyBuildableSegTree (withArray: IntArray){
    //Items lenght must be a power of 2
 
    val nSuggest = withArray.size
    val n = if(nSuggest >= 2) (nSuggest - 1).takeHighestOneBit() shl 1 else nSuggest
    val levels = (31 - n.countLeadingZeroBits()) // (Levels) amount of levels then a layer of leaf
    val arr = Array(n * 2 ){
        if(it  - n in withArray.indices) lazyitem.construct(withArray[it - n]) else lazyitem.empty}
 
    init{
        updateAll()
    }
 
    private fun updateNode(i:Int, level:Int){
        // both nodes + Lazy
        if(!propagateOnly)
            arr[i].mergewrite(arr[i shl 1], arr[(i shl 1)+1])
    }
 
    fun updatePath(i:Int){
        // i is the endpoint, typically (n+i)
        // bottom up, recalculate
        var here = i
        var level = 0
        while(here > 1){
            here = here shr 1
            level ++
            updateNode(here,level)
        }
    }
    fun updateAll(){
        for(i in n-1 downTo 1){
            updateNode(i,i.calculateLevel)
        }
    }
    fun pushPath(i:Int){
        // i must be in [n,2n)
        for (s in levels downTo 1) {
            val i1 = i shr s
            arr[i1].push(arr[i1 shl 1], arr[(i1 shl 1) + 1])
        }
    }
    val firstIndex = n
    val lastIndex = (n shl 1 ) - 1
    val indices = firstIndex..lastIndex
 
    inline fun segDivision(l:Int, r:Int, act:(index:Int,level:Int)->Unit){
        var left = l + n
        var right = r + n + 1
        var level = 0
        while(left < right){
            if(left and 1 != 0){
 
                act(left,level)
                left += 1
            }
            if(right and 1 != 0){
                right -= 1
                act(right,level)
            }
            left = left shr 1
            right = right shr 1
            level ++
        }
    }
    fun rangeApplyLazy(l:Int, r:Int, inc:Int){
        if(r < l) return
 
        //THere is no need to push
//        pushPath(l+n)
//        pushPath(r+n)
        segDivision(l,r) { i, level ->
            arr[i].addlazy(inc)        // apply the update to the node , then store further children updates
        }
        updatePath(l+n)
        updatePath(r+n)
    }
    fun rangeQueryLazy(l:Int,r:Int): lazyitem {
        pushPath(l+n)
        pushPath(r+n)
        var ret = lazyitem.empty
        segDivision(l,r){i, level ->
            val new = lazyitem.empty
            new.mergewrite(ret,arr[i])
            ret = new
        }
        return ret
    }
 
    fun pointQuery(i:Int):lazyitem{
        pushPath(i+n)
        return arr[i+n]
    }
    inline fun pointset(i:Int, act:(x:lazyitem)->Unit){
        pushPath(i+n)
        act(arr[i+n])
        updatePath(i+n)
    }
 
    val Int.leftNode:Int get(){
        // assert(this <= n )
        return this shl 1
    }
    val Int.rightNode:Int get(){
        // assert(this <= n)
        return (this shl 1) + 1
    }
    val Int.endpoints:Pair<Int,Int> get(){
        val offSet = this - this.takeHighestOneBit()
        val widthLevel = levels - (31 - this.countLeadingZeroBits())
        return Pair(offSet shl widthLevel, (offSet shl widthLevel) + (1 shl widthLevel) - 1)
    }
    val Int.calculateLevel:Int get(){
        val base = 31 - this.countLeadingZeroBits()
        return levels - base
    }
 
 
    fun lazyPrint():String{
        val ret = mutableListOf<String>()
        for(i in 0 until n){
            ret.add(pointQuery(i).debug())
        }
        return ret.joinToString(" ")
    }
 
    inline fun segDivisonOrdered(l:Int, r:Int, act:(index:Int)->Unit){
        var left = l + n
        var right = r + n + 1
        var level = 0
        while(left < right){
            if(left and 1 != 0){
                act(left)
                left += 1
            }
            left = left shr 1
            right = right shr 1
            level ++
        }
        right = r + n + 1
        for(lev in level - 1 downTo 0){
            if((right shr lev) and 1 == 1){
                act((right shr lev) - 1 )
            }
        }
    }
//    override fun toString(): String {
//        return this.lazyPrint()
//    }
}
const val singleCase = false
fun main(){
    val a = binom.size
    val b = binom[0].size
    for(i in 0 until a){
        binom[i][0] = 1
    }
    for(j in 0 until b){
        binom[0][j] = 1
    }
    for(i in 1 until a){
        for(j in 1 until b){
            binom[i][j] = binom[i][j-1] mp binom[i-1][j]
        }
    }
    for(i in 0 until 100300){
        pow[i][0] = 1
        for(level in 1 until 7){
            pow[i][level] = pow[i][level-1] mm i
        }
    }
    solve.tc{
        share(100000)
        share(10)
        share(List(100000){it+1}.conca())
    }
//    solve.usetc()
    solve.cases{
        val n = getint
        val k = getint
        val L = getline(n)
        val sack = Array(n+1){TreeSet <Int>()}
        for((i,v) in L.withIndex()){
            sack[v].add(i)
        }
        val st = lazyBuildableSegTree(IntArray(n))
 
        var ret = 0
        for(i in 0 until n){
            val v = L[i]
            val previous = (sack[v].lower(i) ?: -1 ) + 1
            st.rangeApplyLazy(previous,i,1)
 
            val got = st.arr[1]
            ret = ret mp got.sum[k]
        }
        put(ret)
    }
//    just dei calls
    done()
}
Tester's code (C++)
// Jai Shree Ram  
  
#include<bits/stdc++.h>
using namespace std;

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define pb             push_back
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define gcd(a,b)       __gcd(a,b)
#define mem1(a)        memset(a,-1,sizeof(a))
#define mem0(a)        memset(a,0,sizeof(a))
#define sz(a)          (int)a.size()
#define pii            pair<int,int>
#define hell           1000000007
#define elasped_time   1.0 * clock() / CLOCKS_PER_SEC



template<typename T1,typename T2>istream& operator>>(istream& in,pair<T1,T2> &a){in>>a.x>>a.y;return in;}
template<typename T1,typename T2>ostream& operator<<(ostream& out,pair<T1,T2> a){out<<a.x<<" "<<a.y;return out;}
template<typename T,typename T1>T maxs(T &a,T1 b){if(b>a)a=b;return a;}
template<typename T,typename T1>T mins(T &a,T1 b){if(b<a)a=b;return a;}


// -------------------- Input Checker Start --------------------
 
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
 
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
void readEOF() { assert(getchar() == EOF); }
 
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}
 
// -------------------- Input Checker End --------------------

const int N = 1e5 + 5;

const int MOD = 998244353;
 
struct mod_int {
    int val;
 
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
 
        if (v >= MOD)
            v %= MOD;
 
        val = v;
    }
 
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, x = 0, y = 1;
 
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            x -= q * y; swap(x, y);
        }
 
        return x < 0 ? x + m : x;
    }
 
    explicit operator int() const {
        return val;
    }
 
    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
 
    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }
 
    static unsigned fast_mod(uint64_t x, unsigned m = MOD) {
           #if !defined(_WIN32) || defined(_WIN64)
                return x % m;
           #endif
           unsigned x_high = x >> 32, x_low = (unsigned) x;
           unsigned quot, rem;
           asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (x_high), "a" (x_low), "r" (m));
           return rem;
    }
 
    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }
 
    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }
 
    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }
 
    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }
 
    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }
 
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }
 
    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }
 
    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }
 
    mod_int inv() const {
        return mod_inv(val);
    }
 
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
 
        while (p > 0) {
            if (p & 1)
                result *= a;
 
            a *= a;
            p >>= 1;
        }
 
        return result;
    }
 
    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }
    friend istream& operator >> (istream &stream, mod_int &m) {
        return stream>>m.val;   
    }
};

struct node{
    vector<mod_int> a;
    node (int val=0){
        a.resize(6);
        a[0] = 1;
    }
    void merge(node &n1,node &n2){
        for(int j = 0; j <= 5; j++){
        	a[j] = n1.a[j] + n2.a[j];
        }
    }
};
// (a + x)^3 = a**3 + x**3 + 3ax^2 + 3a^2x
struct update{
  mod_int val=0;
  update(mod_int t=0){
     val=t;
  }
  void combine(update &par,int tl,int tr){
      val+=par.val;
  }
  void apply(node &node, int tl, int tr){
      vector<mod_int> b(6);
      auto &a = node.a;
      mod_int len = tr - tl + 1;
      b[0] = a[0];
      b[1] = a[1] + val*len;
      b[2] = a[2] + val.pow(2)*len + 2*a[1]*val;
      b[3] = a[3] + val.pow(3)*len + 3*a[1]*val.pow(2) + 3*a[2]*val;
      b[4] = a[4] + val.pow(4)*len + 4*a[1]*val.pow(3) + 4*a[3]*val + 6*a[2]*val.pow(2);
      b[5] = a[5] + val.pow(5)*len + 5*a[1]*val.pow(4) + 10*a[2]*val.pow(3) + 10*a[3]*val.pow(2) + 5*a[4]*val;
      a = b;
      val = 0;
  }
};
template<typename node,typename update>
struct segtree{
  node t[4*N];
  bool lazy[4*N];
  update zaker[4*N];
  int tl[4*N];
  int tr[4*N];
  node nul;
  inline void pushdown(int v){
     if(lazy[v]){    
       apply(zaker[v],v);
       lazy[v]=0;
       zaker[v].apply(t[v], tl[v], tr[v]);
     }
  }
  inline void apply(update &u,int v){
      if(tl[v]!=tr[v]){
          lazy[2*v]=lazy[2*v+1]=1;
          zaker[2*v].combine(u,tl[2*v],tr[2*v]);
          zaker[2*v+1].combine(u,tl[2*v+1],tr[2*v+1]);
      }
  }
  void build(int v,int start,int end){
      tl[v]=start;
      tr[v]=end;
      lazy[v] = 0;
      zaker[v].val = 0;
      t[v].a = vector<mod_int>(6,0);
      t[v].a[0] = 1;
      if(start==end) return;
      else{
          int m=(start+end)/2;
          build(2*v,start,m);
          build(2*v+1,m+1,end);
          t[v].merge(t[2*v],t[2*v+1]);
     }
  }
  void zeno(int v,int l,int r,update val){
      pushdown(v);
      if(tr[v]<l || tl[v]>r)return;
      if(l<=tl[v] && tr[v]<=r){
      	    zaker[v].combine(val, tl[v], tr[v]);
      	    zaker[v].apply(t[v],tl[v], tr[v]);	
            apply(val,v); 
            return;
      }
      zeno(2*v,l,r,val);
      zeno(2*v+1,l,r,val);
      t[v].merge(t[2*v],t[2*v+1]);
  }
  node query(int v,int l,int r){
      if(tr[v]<l || tl[v]>r)return nul;
      pushdown(v);
      if(l<=tl[v] && tr[v]<=r){
         return t[v];
      }
      node a=query(2*v,l,r);
      node b=query(2*v+1,l,r);
      node ans;
      ans.merge(a,b);
      return ans;
  }
public:
  node query(int l,int r){
      return query(1,l,r);
  }
  void upd(int l,int r,update val){
      return zeno(1,l,r,val);
  }   
};

segtree<node,update> seg;

int solve(){		
		static int sum_n = 0;
 		int n,k;
 		n = readIntSp(1,1e5);
 		sum_n += n;
 		assert(sum_n <= 1e5);
 		k = readIntLn(1,5);
 		vector<int> a = readVectorInt(n, 1, n);
 		
 		seg.build(1,1,n);
 		vector<int> last(n + 1,0), lastlast(n + 1,0);
 		mod_int ans = 0;
 		for(int i = 1; i <= n; i++){
 			seg.upd(last[a[i - 1]] + 1, i, mod_int(1));
 			last[a[i - 1]] = i;
 			auto s = seg.t[1].a;
 			//cout << s[k] << " ";
 			ans += s[k];
 		}	
 		cout << ans << endl;
 return 0;
}
signed main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #ifdef SIEVE
    sieve();
    #endif
    #ifdef NCR
    init();
    #endif
    int t = readIntLn(1,15000);
    while(t--){
        solve();
    }
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

const int mod = 998244353;
int add (int a, int b) {
	return (a + b) % mod;
}
int mul (int a, int b) {
	return (1LL * a * b) % mod;
}

vector C(100, vector(100, 0));
vector pows(15, vector(100005, 0));

struct Node {
	static const int kmax = 6;
	using T = array<int, kmax>;
	T unit {};
	T f(T a, T b) { 
		for (int i = 0; i < kmax; ++i) a[i] = ::add(a[i], b[i]);
		return a;
	}

	Node *l = 0, *r = 0;
	int lo, hi;
	int madd = 0;
	T val = unit;
	Node(int _lo,int _hi):lo(_lo),hi(_hi){ val[0] = hi - lo; }
	T query(int L, int R) {
		if (R <= lo || hi <= L) return unit;
		if (L <= lo && hi <= R) return val;
		push();
		return f(l->query(L, R), r->query(L, R));
	}
	void add(int L, int R, int x) {
		if (R <= lo || hi <= L) return;
		if (L <= lo && hi <= R) {
			madd += x;
			for (int k = kmax-1; k > 0; --k) {
				int cur = 0;
				for (int j = k; j >= 0; --j) {
					cur = ::add(cur, mul(val[j], mul(C[k][j], pows[k-j][x])));
				}
				val[k] = cur;
			}
		}
		else {
			push(), l->add(L, R, x), r->add(L, R, x);
			val = f(l->val, r->val);
		}
	}
	void push() {
		if (!l) {
			int mid = lo + (hi - lo)/2;
			l = new Node(lo, mid); r = new Node(mid, hi);
		}
		if (madd)
			l->add(lo,hi,madd), r->add(lo,hi,madd), madd = 0;
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	for (int i = 0; i < 100; ++i) C[i][0] = 1;
	for (int i = 1; i < 100; ++i) for (int j = 1; j <= i; ++j)
		C[i][j] = add(C[i-1][j], C[i-1][j-1]);
	
	for (int i = 1; i < 100005; ++i) for (int j = 0; j < 12; ++j)
		if (j == 0) pows[j][i] = 1;
		else pows[j][i] = mul(i, pows[j-1][i]);

	int t; cin >> t;
	while (t--) {
		int n, k; cin >> n >> k;
		vector<int> a(n);
		for (int &x : a) cin >> x;
		vector<int> prev(n+1, -1);
		Node *seg = new Node(0, n);
		
		int ans = 0;
		for (int R = 0; R < n; ++R) {
			seg -> add(prev[a[R]]+1, R+1, 1);
			ans = add(ans, (seg -> query(0, n))[k]);
			prev[a[R]] = R;
		}
		cout << ans << '\n';
	}
}

Cancer impl, XOR_X is much better.

I wouldn’t really call this cancer impl at all, especially if you already have a lazy segtree template. You modify the update function a bit (which is 3-4 lines) and that’s really about it.

At least to me, cancer impl is when I have to do annoying things like writing large numbers of if statements or input parsing, perhaps your opinion is different.

I use atcoder library’s segtree and I found it difficult to come up with a reasonable identity element. In my implementation, I also had to handle 0-th powers (lengths) separately. Also, I messed up by confusing prefix and suffix sums at some point while reversing the notation, but that’s on me.

Update: Finally, the constraints were pretty tough too, as I had to precompute binomials. I think that it is not something that should be enforced by the constraints. I usually only precompute factorials and their inverses, as the optimal way to compute binomials depends on the problem.

Oh, unfortunately I’ve never used ACL so I can’t comment on modifying that code.

I used the kactl lazy segtree, and found it easy enough to modify: the data type was an array of size 6, merge was adding two such arrays pointwise, identity is a bunch of zeros.

Regarding constraints: it’s surprising that you had to precompute binomials, in testing I did try using only factorials and inverse factorials and that passed in 0.7s (code), which seems pretty ok for a 2.5s TL, especially when we had other solutions running in 0.3s.
Maybe this is also a quirk of ACL? I don’t know.
fwiw my code, the other tester’s code, and the author’s code (in Kotlin, which is probably slower than C++?) all ran nowhere close to TL.

afair the only constant factor opt really needed is to compute binomials in \mathcal{O}(1) somehow (either precompute or fact + inv fact), doing it in \mathcal{O}(\log MOD) was too slow (but this seemed reasonable to enforce since it’s so common and an actual complexity improvement).
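For reference, the usual way to get O(1) binomials mentioned here is to precompute factorials and inverse factorials once. The following is a minimal sketch, assuming the prime modulus 998244353 and the hypothetical names power and Binomials. Only one modular exponentiation (Fermat's little theorem) is needed; the remaining inverses follow from ((i-1)!)^{-1} = (i!)^{-1} * i.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t MOD = 998244353;  // prime

// Fast exponentiation: b^e mod MOD.
int64_t power(int64_t b, int64_t e) {
    int64_t r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return r;
}

// fact[i] = i! and invFact[i] = (i!)^{-1} mod MOD, for 0 <= i <= maxN (maxN < MOD).
struct Binomials {
    std::vector<int64_t> fact, invFact;
    explicit Binomials(int maxN) : fact(maxN + 1), invFact(maxN + 1) {
        fact[0] = 1;
        for (int i = 1; i <= maxN; ++i) fact[i] = fact[i - 1] * i % MOD;
        invFact[maxN] = power(fact[maxN], MOD - 2);          // the only exponentiation
        for (int i = maxN; i >= 1; --i) invFact[i - 1] = invFact[i] * i % MOD;
    }
    // O(1) per call after the O(maxN + log MOD) precomputation.
    int64_t choose(int n, int k) const {
        if (k < 0 || k > n) return 0;
        assert(n < (int)fact.size());
        return fact[n] * invFact[k] % MOD * invFact[n - k] % MOD;
    }
};
```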

There are somewhat fast \mathcal{O}(N^2) solutions so we couldn’t reduce constraints/increase TL too much.

Edit: Looking at your code, you’re using a vector of size K. I recall small vectors having a large overhead, can you try your TLE submission but with std::array instead?

Ok looking closer at your TLE code, it seems that you didn’t actually precompute inverse factorials, and used division instead; that’s what made it much slower.

Simply precomputing inverse factorials makes your first submission run in 0.65s: https://www.codechef.com/viewsolution/87124169

Using std::array instead of std::vector brings this down to 0.42s, so there’s definitely a noticeable difference (luckily constant factor opt doesn’t seem to have made a difference in this task).
