GOOD_OPR - Editorial

Author: wuhudsm
Tester: jay_1048576
Editorialist: iceknight1093

DIFFICULTY:

2939

PREREQUISITES:

Observation, dynamic programming

PROBLEM:

You’re given N intervals [L_i, R_i] where L_i \lt R_i.

A good operation is as follows:
Pick M of these intervals, and K_i \gt 1 distinct integers Y_{i, 1}, Y_{i, 2}, \ldots, Y_{i, K_i} from the i-th chosen interval such that:

• The product K_1\times K_2 \times \ldots \times K_M doesn’t exceed N.
• M is as large as possible.

The score of a good operation is the product of all chosen Y_{i, j}, i.e.,

\prod_{i=1}^M \prod_{j=1}^{K_i} Y_{i, j}

Find the sum of scores across all possible good operations, modulo 998244353.

EXPLANATION:

There are a lot of parts to this problem, so let’s analyze them one at a time.

Good operations

First, let’s look at when exactly an operation is good.
We pick M intervals and at least 2 elements from each chosen interval.
The product of the numbers of chosen elements shouldn’t exceed N, which already tells us that M can’t be very large since this product is at least 2^M.
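In particular, this gives us the upper bound M \leq \lfloor \log_2 N \rfloor.

This bound is always achievable: every interval contains at least two integers (because L_i \lt R_i), and there are N \geq \lfloor \log_2 N \rfloor intervals, so we can pick 2 elements from each of \lfloor \log_2 N \rfloor intervals, for a product of 2^{\lfloor \log_2 N \rfloor} \leq N.
Since M should be as large as possible, M = \lfloor \log_2 N \rfloor, i.e., M is fixed.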
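Since M depends only on N, it can be computed once per test case. A small Python sketch (the helper name `max_intervals` is hypothetical) that uses exact integer arithmetic instead of a floating-point log2: for N \geq 1, `N.bit_length() - 1` equals \lfloor \log_2 N \rfloor.

```python
def max_intervals(n: int) -> int:
    """M = floor(log2(n)) for n >= 1, computed with integers only."""
    # n has bit_length() bits, so 2**(bit_length - 1) <= n < 2**bit_length.
    return n.bit_length() - 1

for n in [2, 3, 6, 7, 8, 200000]:
    m = max_intervals(n)
    # two elements from each of m intervals fits; from m + 1 intervals it does not
    assert 2**m <= n < 2**(m + 1)
    print(n, m)
```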

Now, let’s look at the K_i values.
Clearly most of them should be 2, but if 2^M \lt N we have some leeway.
For example, if N = 6 we have M = 2, and we have 2\cdot 2 \leq 6 and 2\cdot 3 \leq 6, which means we can have one of the K_i equal to 3 if we like.

However, despite this leeway, the restrictions are in fact fairly tight.
Specifically, a little analysis should tell you the following:

• Every K_i will be \leq 3.
• At most one K_i can be 3.
Proof

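Suppose K_i \geq 4.
Then we can instead pick only 2 elements from this interval and 2 more from an interval that wasn't chosen (one exists, since M \leq \log_2 N \lt N). The product gets multiplied by 4/K_i \leq 1, so it still doesn't exceed N, while M grows by 1. This contradicts the maximality of M.

Similarly, suppose two intervals have K_i = 3. Their product is 9, and replacing them with 3 intervals with 2 elements each (again using an unchosen interval) gives product 8 \lt 9 and a larger M, which once again contradicts maximality.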

This puts a rather strict structure on what a good operation looks like:

• We always choose exactly M = \lfloor \log_2 N \rfloor intervals.
• Then, either we pick two elements from all M of these, or (if 3\cdot 2^{M-1} \leq N) we pick three elements from one of them and two from all the rest.
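These claims are easy to sanity-check by brute force for small N. The Python sketch below (not part of any reference solution; `best_patterns` is a hypothetical name) lists every non-increasing tuple of K values with all K \geq 2 and product \leq N, and keeps the longest ones.

```python
def best_patterns(n):
    """Brute force: list every non-increasing tuple (K_1 >= K_2 >= ...) with all K >= 2
    and product <= n, then return the largest length M and the tuples of that length."""
    found = []

    def extend(prefix, prod, cap):
        found.append(tuple(prefix))
        for k in range(2, cap + 1):
            if prod * k > n:
                break  # larger k only makes the product bigger
            prefix.append(k)
            extend(prefix, prod * k, k)  # cap = k keeps the tuple non-increasing
            prefix.pop()

    extend([], 1, n)
    m = max(len(t) for t in found)
    return m, sorted(t for t in found if len(t) == m)

print(best_patterns(6))  # (2, [(2, 2), (3, 2)])
print(best_patterns(8))  # (3, [(2, 2, 2)])
```

For N = 6 this reproduces the example from earlier: M = 2, and one K_i may be 3.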

This structure lets us compute the sum of scores using dynamic programming.

Computing the sum of scores

There are two cases here: either we pick 2 elements from every interval, or we pick two elements from all but one and 3 from the last one.
Dealing with both is similar, so let’s first see how to handle the first case (i.e., all twos).

To get an idea of how to do this, let’s look at a simple case first: there are two intervals [L_1, R_1] and [L_2, R_2], and we pick exactly two elements from each one.
How would we compute the sum of scores in this case?

Well, suppose we pick the elements x_1, x_2 from the first interval and x_3, x_4 from the second. They add x_1x_2x_3x_4 to the sum.
Summing this across all possible values, we want

\begin{align*} &\sum_{L_1 \leq x_1 \lt x_2 \leq R_1} \sum_{L_2 \leq x_3 \lt x_4 \leq R_2} \left(x_1x_2x_3x_4 \right) \\ \\ &= \sum_{(x_1, x_2)} \left(x_1x_2 \sum_{(x_3, x_4)} x_3x_4\right) \\ \\ &= \left(\sum_{(x_1, x_2)}x_1x_2\right) \left(\sum_{(x_3, x_4)}x_3x_4\right) \end{align*}

That is, if we know the sum of pairwise products for each of the two intervals, all we need to do is multiply them.

The same argument generalizes to M intervals: if P_i is the sum of pairwise products of the i-th interval, then what we’re looking for is the sum, over all ways of choosing M intervals, of the product of their P_i values (in other words, the M-th elementary symmetric polynomial of P_1, P_2, \ldots, P_N).
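For a fixed set of chosen intervals, the factorization can be checked numerically. This Python sketch (function names are hypothetical) brute-forces every way of taking two elements from each interval and compares the result with the product of the per-interval pair sums.

```python
from itertools import combinations, product
from math import prod

def pair_product_sum(lo, hi):
    """P for [lo, hi]: sum of x*y over lo <= x < y <= hi (brute force)."""
    return sum(x * y for x, y in combinations(range(lo, hi + 1), 2))

def joint_sum(intervals):
    """Sum, over every way of picking two elements from each interval, of the
    product of all picked elements (brute force over the Cartesian product)."""
    choices = [list(combinations(range(lo, hi + 1), 2)) for lo, hi in intervals]
    return sum(prod(prod(pair) for pair in picks) for picks in product(*choices))

ivs = [(1, 4), (3, 6)]
print(joint_sum(ivs), prod(pair_product_sum(lo, hi) for lo, hi in ivs))  # 4165 4165
```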

First, let’s compute the P_i values. That can be done with a bit of math.
For an interval [L_i, R_i], the corresponding P_i value is

\frac{(L_i + (L_i + 1) + \ldots + R_i)^2 - (L_i^2 + (L_i + 1)^2 + \ldots + R_i^2)}{2}

because:

• (L_i + (L_i + 1) + \ldots + R_i)^2 is the sum of xy over all ordered pairs (x, y) of values from the interval, with x = y allowed.
• We subtract (L_i^2 + (L_i + 1)^2 + \ldots + R_i^2) to remove the pairs where the same element is chosen twice.
• The remaining value still depends on order, i.e., x_1x_2 and x_2x_1 were both counted as different products.
To account for this, divide by 2.
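The identity can be checked directly in Python. In this sketch (hypothetical names `pair_sum` and `pair_sum_brute`) the power sums are taken naively for clarity; in a real solution they would come from closed forms.

```python
from itertools import combinations

def pair_sum(lo, hi):
    """P for [lo, hi] via ((sum of x)^2 - sum of x^2) / 2.
    Power sums are computed naively here, only to illustrate the identity."""
    s1 = sum(range(lo, hi + 1))
    s2 = sum(x * x for x in range(lo, hi + 1))
    return (s1 * s1 - s2) // 2  # the numerator is always even

def pair_sum_brute(lo, hi):
    """Sum of x*y over all pairs lo <= x < y <= hi."""
    return sum(x * y for x, y in combinations(range(lo, hi + 1), 2))

print(pair_sum(1, 4), pair_sum_brute(1, 4))  # 35 35
```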

Once all the P_i values are computed, we want the sum of their products taken M at a time.
That can be done in \mathcal{O}(N\cdot M) with a fairly simple DP.

Let \text{dp}[i][j] denote the sum of products of all j-element subsets of the first i values, with \text{dp}[i][0] = 1 for every i (the empty product) and \text{dp}[0][j] = 0 for j \gt 0. Then, for j \geq 1 (the first term skips the i-th value, the second uses it),
\text{dp}[i][j] = \text{dp}[i-1][j] + P_i\cdot\text{dp}[i-1][j-1]
The value we’re looking for is, of course, \text{dp}[N][M].
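A minimal Python sketch of this recurrence (the name `sum_of_m_products` is hypothetical). It keeps only one row of the table: iterating j downwards means \text{dp}[j-1] still holds the value for the first i-1 values when \text{dp}[j] is updated.

```python
MOD = 998244353

def sum_of_m_products(values, m, mod=MOD):
    """Sum over all m-element subsets of `values` of the product of the subset, mod `mod`.
    After processing i values, dp[j] equals dp[i][j] from the recurrence."""
    dp = [1] + [0] * m  # dp[0][0] = 1 (empty product), dp[0][j] = 0 for j > 0
    for p in values:
        for j in range(m, 0, -1):  # downwards, so dp[j - 1] is still the previous row
            dp[j] = (dp[j] + dp[j - 1] * p) % mod
    return dp[m]

print(sum_of_m_products([35, 119, 2, 7], 2))  # 5565
```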

As a small side note, the sum of products of the elements taken M at a time can be computed for every M from 1 to N in \mathcal{O}(N\log^2 N) using NTT, but that’s extreme overkill for this problem because we’re only interested in a single value of M, and the \mathcal{O}(N\cdot M) method is enough because M = \lfloor \log_2 N \rfloor.

This almost solves the problem: the only remaining case is when we take exactly three elements from one interval.
You might’ve noticed that this is not so different from the above case: indeed, once again we only really need to know, for each interval, the sum of products of all distinct triples of elements in it.

This can be computed almost exactly the same way we computed P_i, by using inclusion-exclusion.

Details

Just as before, we start with (L_i + \ldots + R_i)^3, which sums xyz over all ordered triples (x, y, z), disregarding both distinctness and order.

From this, subtract 3\cdot (L_i + \ldots + R_i)\cdot (L_i^2 + \ldots + R_i^2). Fix a pair of positions whose values must be equal (\binom{3}{2} = 3 ways); the sum over such triples is the sum of squares of all elements times the sum of all elements.

Next, add 2\cdot (L_i^3 + \ldots + R_i^3), to account for the case when all three elements are equal.
Such triples were subtracted thrice above but should’ve been subtracted only once; hence the multiplier of 2.

Now we’re left with the sum of products over ordered triples whose elements are all distinct.
Each unordered triple appears 3! = 6 times, once per ordering, so divide by 6 to get the final answer.
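The same kind of check works for triples. This Python sketch (hypothetical names `triple_sum` and `triple_sum_brute`, naive power sums for clarity) compares (S_1^3 - 3S_1S_2 + 2S_3)/6 with direct enumeration, where S_k is the sum of k-th powers over the interval.

```python
from itertools import combinations

def triple_sum(lo, hi):
    """Q for [lo, hi]: (S1^3 - 3*S1*S2 + 2*S3) / 6, with Sk = sum of x**k over the interval."""
    xs = range(lo, hi + 1)
    s1 = sum(xs)
    s2 = sum(x ** 2 for x in xs)
    s3 = sum(x ** 3 for x in xs)
    return (s1 ** 3 - 3 * s1 * s2 + 2 * s3) // 6  # the numerator is exactly divisible by 6

def triple_sum_brute(lo, hi):
    """Sum of x*y*z over all triples lo <= x < y < z <= hi."""
    return sum(x * y * z for x, y, z in combinations(range(lo, hi + 1), 3))

print(triple_sum(1, 4), triple_sum_brute(1, 4))  # 50 50
```

Note that an interval with only two integers correctly gets Q = 0.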

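Note that in intermediate steps, we needed the sums of first/second/third powers of a range of integers. The well-known formulas \sum_{k=1}^{n} k = \frac{n(n+1)}{2}, \sum_{k=1}^{n} k^2 = \frac{n(n+1)(2n+1)}{6} and \sum_{k=1}^{n} k^3 = \left(\frac{n(n+1)}{2}\right)^2 give prefix sums in \mathcal{O}(1), and the sum over [L_i, R_i] is the prefix sum up to R_i minus the prefix sum up to L_i - 1 (modulo 998244353, the divisions become multiplications by modular inverses).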
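A sketch of these closed forms modulo 998244353 in Python (the names `prefix_power_sum` and `range_power_sum` are hypothetical). Reducing n modulo the prime first is valid because each formula is a polynomial in n whose denominators (2 and 6) are invertible modulo the prime.

```python
MOD = 998244353
INV2 = pow(2, MOD - 2, MOD)  # modular inverse via Fermat's little theorem
INV6 = pow(6, MOD - 2, MOD)

def prefix_power_sum(n, k):
    """1^k + 2^k + ... + n^k (mod MOD) for k in {1, 2, 3}, using the closed forms."""
    n %= MOD
    if k == 1:
        return n * (n + 1) % MOD * INV2 % MOD
    if k == 2:
        return n * (n + 1) % MOD * (2 * n + 1) % MOD * INV6 % MOD
    if k == 3:
        t = n * (n + 1) % MOD * INV2 % MOD
        return t * t % MOD
    raise ValueError("only k = 1, 2, 3 are supported")

def range_power_sum(lo, hi, k):
    """Sum of x^k for lo <= x <= hi, modulo MOD."""
    return (prefix_power_sum(hi, k) - prefix_power_sum(lo - 1, k)) % MOD

print(range_power_sum(3, 6, 2), range_power_sum(1, 4, 3))  # 86 100
```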

Once these values are computed for each interval (let’s call them Q_i), the earlier dp can be easily modified to account for them.
Let \text{dp}[i][j][0/1] denote the sum of scores for the first i intervals if we’ve chosen exactly j of them so far, where the third index is 0 if no chosen interval has contributed three elements yet and 1 if exactly one has.
Then,

• \text{dp}[i][j][0] = \text{dp}[i-1][j][0] + \text{dp}[i-1][j-1][0] \cdot P_i
• \text{dp}[i][j][1] = \text{dp}[i-1][j][1] + \text{dp}[i-1][j-1][1] \cdot P_i + \text{dp}[i-1][j-1][0] \cdot Q_i

Once again, this takes \mathcal{O}(N\log N) time.
It can also be implemented with \mathcal{O}(\log N) memory, though that’s likely unnecessary to get AC.
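The transitions can be sketched in Python with two rolling 1-D arrays, i.e., the \mathcal{O}(\log N)-memory variant. This is a minimal sketch, assuming P and Q were already computed (and reduced modulo 998244353) for each interval, and that N \geq 2 so that M \geq 1; the function name `good_operation_sum` is hypothetical. It also applies the 3\cdot 2^{M-1} \leq N case split on the final answer.

```python
MOD = 998244353

def good_operation_sum(P, Q, n):
    """Two-state DP with rolling arrays.
    P[i], Q[i]: pair / triple product sums of interval i (reduced mod MOD).
    n: number of intervals, assumed >= 2 so that M = floor(log2 n) >= 1."""
    m = n.bit_length() - 1
    dp0 = [1] + [0] * m  # state 0: no interval has contributed three elements
    dp1 = [0] * (m + 1)  # state 1: exactly one interval contributed three elements
    for p, q in zip(P, Q):
        for j in range(m, 0, -1):  # downwards, so index j - 1 still holds the previous row
            dp1[j] = (dp1[j] + dp1[j - 1] * p + dp0[j - 1] * q) % MOD
            dp0[j] = (dp0[j] + dp0[j - 1] * p) % MOD
    if 3 * 2 ** (m - 1) <= n:
        return (dp0[m] + dp1[m]) % MOD
    return dp0[m]

# intervals [1, 4], [3, 6], [1, 2]: P = 35, 119, 2 and Q = 50, 342, 0
print(good_operation_sum([35, 119, 2], [50, 342, 0], 3))  # 548
```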

• If 3\cdot 2^{M-1} \leq N, we can have an interval with K_i = 3. In this case, the answer is \text{dp}[N][M][0] + \text{dp}[N][M][1].
• Otherwise, all intervals must have K_i = 2 and the answer is just \text{dp}[N][M][0].
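As an end-to-end sanity check, the Python sketch below compares the whole approach (hypothetical `solve`) against a brute force (hypothetical `brute`) that follows the problem statement literally, on tiny random inputs. Both assume N \geq 2, and `solve` uses naive power sums, so this is only meant for small intervals.

```python
from itertools import combinations, product as cartesian
from math import prod

MOD = 998244353

def solve(intervals):
    """Editorial approach on a list of (L, R) pairs; assumes len(intervals) >= 2."""
    n = len(intervals)
    m = n.bit_length() - 1  # M = floor(log2 N)
    dp0 = [1] + [0] * m     # no interval contributes three elements yet
    dp1 = [0] * (m + 1)     # exactly one interval contributes three elements
    for lo, hi in intervals:
        xs = range(lo, hi + 1)
        s1, s2, s3 = sum(xs), sum(x * x for x in xs), sum(x ** 3 for x in xs)
        p = (s1 * s1 - s2) // 2 % MOD                    # P: sum over pairs
        q = (s1 ** 3 - 3 * s1 * s2 + 2 * s3) // 6 % MOD  # Q: sum over triples
        for j in range(m, 0, -1):
            dp1[j] = (dp1[j] + dp1[j - 1] * p + dp0[j - 1] * q) % MOD
            dp0[j] = (dp0[j] + dp0[j - 1] * p) % MOD
    return (dp0[m] + dp1[m]) % MOD if 3 * 2 ** (m - 1) <= n else dp0[m]

def brute(intervals):
    """Try every set of intervals and every choice of K_i >= 2 distinct integers from
    each, keep those with product of K_i <= N, and sum the scores of those with largest M."""
    n = len(intervals)
    ops = []  # (M, score) for every operation satisfying the product bound
    for size in range(n + 1):
        for chosen in combinations(range(n), size):
            options = []
            for i in chosen:
                vals = range(intervals[i][0], intervals[i][1] + 1)
                options.append([c for k in range(2, len(vals) + 1)
                                for c in combinations(vals, k)])
            for picks in cartesian(*options):
                if prod(len(c) for c in picks) <= n:
                    ops.append((size, prod(prod(c) for c in picks)))
    best = max(m for m, _ in ops)
    return sum(s for m, s in ops if m == best) % MOD

example = [(1, 4), (3, 6), (1, 2)]
print(solve(example), brute(example))  # 548 548
```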

TIME COMPLEXITY

\mathcal{O}(N\log N) per test case.

CODE:

Author's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int N=200010;
const int LOGN=20;
const ll  TMD=998244353;
const ll  INF=2147483647;
int T,n,m;
int l[N],r[N];
ll  f2[N],f3[N];
ll  dp[N][LOGN][2];

ll pw(ll x,ll p)
{
    if(!p) return 1;
    ll y=pw(x,p>>1);
    y=(y*y)%TMD;
    if(p&1) y=(y*(x%TMD))%TMD;
    return y;
}

ll inv(ll x)
{
    return pw(x,TMD-2);
}

// 1^t + 2^t + ... + x^t (mod TMD), for t = 1, 2, 3
ll sum(int t,ll x)
{
    if(x==0)      return 0;
    if(t==1)      return x*(x+1)%TMD*inv(2)%TMD;
    else if(t==2) return x*(x+1)%TMD*(2*x+1)%TMD*inv(6)%TMD;
    else          return pw(sum(1,x),2);
}

// l^t + (l+1)^t + ... + r^t (mod TMD)
ll sum(int t,int l,int r)
{
    return (sum(t,r)-sum(t,l-1)+TMD)%TMD;
}

// f2[i]: sum of products of pairs (P_i), f3[i]: sum of products of triples (Q_i)
void cal_f()
{
    for(int i=1;i<=n;i++)
    {
        ll s1=sum(1,l[i],r[i]),s2=sum(2,l[i],r[i]),s3=sum(3,l[i],r[i]);
        f2[i]=(pw(s1,2)-s2+TMD)*inv(2)%TMD;
        if(r[i]-l[i]+1>2) f3[i]=(pw(s1,3)-3*s1*s2%TMD+2*s3+TMD*TMD)%TMD*inv(6)%TMD;
        else f3[i]=0;
    }
}

int main()
{
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++) scanf("%d%d",&l[i],&r[i]);
        cal_f();
        m=(int)log2(n);
        for(int i=1;i<=n;i++)
        {
            dp[i-1][0][0]=1;
            for(int j=1;j<=m;j++)
            {
                for(int k=0;k<=1;k++)
                {
                    dp[i][j][k]=dp[i-1][j][k];
                    dp[i][j][k]=(dp[i][j][k]+dp[i-1][j-1][k]*f2[i])%TMD;
                    if(k) dp[i][j][k]=(dp[i][j][k]+dp[i-1][j-1][0]*f3[i])%TMD;
                }
            }
        }
        // the dp values are long long, so print them with %lld
        if((1<<(m-1))*3<=n) printf("%lld\n",(dp[n][m][0]+dp[n][m][1])%TMD);
        else printf("%lld\n",dp[n][m][0]);
    }

    return 0;
}

Tester's code (C++)
/*...................................................................*
*............___..................___.....____...______......___....*
*.../|....../...\........./|...../...\...|.............|..../...\...*
*../.|...../.....\......./.|....|.....|..|.............|.../........*
*....|....|.......|...../..|....|.....|..|............/...|.........*
*....|....|.......|..../...|.....\___/...|___......../....|..___....*
*....|....|.......|.../....|...../...\.......\....../.....|./...\...*
*....|....|.......|../_____|__..|.....|.......|..../......|/.....\..*
*....|.....\...../.........|....|.....|.......|.../........\...../..*
*..__|__....\___/..........|.....\___/...\___/.../..........\___/...*
*...................................................................*
*/

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 1000000000000000000
#define MOD 998244353

int power(int a,int b)
{
    if(b==0)
        return 1;
    else
    {
        int x=power(a,b/2);
        int y=(x*x)%MOD;
        if(b%2)
            y=(y*a)%MOD;
        return y;
    }
}

int inverse(int a)
{
    return power(a,MOD-2);
}

int sum(int l,int r)
{
    int a = (r*(r+1)/2)%MOD;
    int b = (l*(l-1)/2)%MOD;
    return (a-b+MOD)%MOD;
}

int sum2(int l,int r)
{
    int a = (((r*(r+1))%MOD*(2*r+1))%MOD*inverse(6))%MOD;
    int b = (((l*(l-1))%MOD*(2*l-1))%MOD*inverse(6))%MOD;
    return (a-b+MOD)%MOD;
}

int sum3(int l,int r)
{
    int a = (r*(r+1)/2)%MOD;
    a = (a*a)%MOD;
    int b = (l*(l-1)/2)%MOD;
    b = (b*b)%MOD;
    return (a-b+MOD)%MOD;
}

pair<int,int> sop(int l,int r)
{
    if(l==r)
        return {0,0};
    else if(r==l+1)
        return {(l*r)%MOD,0};
    else
    {
        int s1 = sum(l,r);
        int s2 = sum2(l,r);
        int s3 = sum3(l,r);
        int pair_sum = ((s1*s1-s2+MOD)%MOD*inverse(2))%MOD;
        int triple_sum = ((((s1*s1)%MOD*s1-3*s1*s2+2*s3)%MOD+MOD)*inverse(6))%MOD;
        return {pair_sum,triple_sum};
    }
}

void solve(int tc)
{
    int n;
    cin >> n;
    pair<int,int> a[n];
    for(int i=0;i<n;i++)
        cin >> a[i].first >> a[i].second;
    pair<int,int> s[n];
    for(int i=0;i<n;i++)
        s[i] = sop(a[i].first,a[i].second);
    int m=0;
    for(int i=1;i<=n;i*=2)
        m++;
    m--;
    int pre[n+1][m+1],suf[n+1][m+1];
    memset(pre,0,sizeof(pre));
    memset(suf,0,sizeof(suf));
    pre[0][0] = 1;
    for(int i=0;i<n;i++)
    {
        pre[i+1][0]=1;
        for(int j=1;j<=m;j++)
            pre[i+1][j]=(pre[i][j]+pre[i][j-1]*s[i].first)%MOD;
    }
    suf[n][0]=1;
    for(int i=n-1;i>=0;i--)
    {
        suf[i][0]=1;
        for(int j=1;j<=m;j++)
            suf[i][j]=(suf[i+1][j]+suf[i+1][j-1]*s[i].first)%MOD;
    }
    int ans1 = pre[n][m];
    int ans2 = 0;
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<m;j++)
        {
            ans2 += (pre[i][j]*suf[i+1][m-1-j])%MOD*s[i].second;
            ans2 %= MOD;
        }
    }
    if((1<<(m-1))*3<=n)
        cout << (ans1+ans2)%MOD << '\n';
    else
        cout << ans1 << '\n';
}

int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int tc=1;
    cin >> tc;
    for(int ttc=1;ttc<=tc;ttc++)
        solve(ttc);
    return 0;
}

Editorialist's code (Python)
import sys
input = sys.stdin.readline  # faster reads; int() and split() ignore the trailing newline

mod = 998244353
inv2 = (mod + 1) // 2
inv6 = (mod + 1) // 6

def pref(n): return (n*(n+1)//2) % mod   # 1 + 2 + ... + n
def pref2(n):                            # 1^2 + 2^2 + ... + n^2
    ret = n*(n+1) % mod
    ret = (ret*(2*n+1)) % mod
    ret = (ret * inv6) % mod
    return ret

for _ in range(int(input())):
    n = int(input())

    # choose 2 or 3 from each interval, because if we take k >= 4 we can instead split into 2 and 2 for higher m and not worse cost
    # 3 from at most one interval, because 3*3 can be replaced by 2*2*2
    # dp2[j] / dp3[j] -> answer for the intervals read so far if we've taken from j of them, without / with one interval contributing 3

    lim = 1
    while True:
        if 2**(lim + 1) > n: break
        lim += 1

    dp2 = [0]*(lim + 1)
    dp3 = [0]*(lim + 1)
    dp2[0] = 1
    for _ in range(n):
        l, r = map(int, input().split())
        s1 = pref(r) - pref(l-1)    # l + ... + r (mod p, may be negative)
        s2 = pref2(r) - pref2(l-1)  # l^2 + ... + r^2

        # take 2: ((l+...+r)^2 - (l^2 + ... + r^2)) / 2
        mul2 = (s1*s1 - s2) % mod * inv2 % mod

        # take 3: ((l+...+r)^3 - 3(l+...+r)(l^2+...+r^2) + 2(l^3+...+r^3)) / 6
        mul3 = 0
        if r-l+1 >= 3:
            s3 = pref(r)**2 - pref(l-1)**2  # l^3 + ... + r^3
            mul3 = (s1**3 - 3*s1*s2 + 2*s3) % mod * inv6 % mod

        # j goes downwards so dp2[j-1], dp3[j-1] still describe the previous intervals;
        # not choosing this interval leaves dp unchanged
        for j in reversed(range(1, lim + 1)):
            dp3[j] = (dp3[j] + dp3[j-1] * mul2 + dp2[j-1] * mul3) % mod
            dp2[j] = (dp2[j] + dp2[j-1] * mul2) % mod

    ans = dp2[lim]
    if 3*(2**(lim-1)) <= n: ans += dp3[lim]
    print(ans % mod)