PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
Sorting
PROBLEM:
There are N items, in ascending order of size. Item i costs C_i.
You can display several items on the storefront, following these conditions:
- At least two items should be displayed when a customer visits.
- An item that’s displayed cannot be removed, unless a customer buys it.
- Each item can be bought by at most one customer; and if it is bought, is to be removed from the storefront and cannot be placed there again.
K customers will visit, in order. Each of them will buy either the largest item or the smallest item that’s on the storefront, represented by a binary string A.
The answer for this binary string is the maximum possible revenue you can obtain by placing items appropriately.
Compute the sum of answers across all substrings of A.
EXPLANATION:
Let’s first recall the solution to the easy version of the problem, where we solved for a single binary string.
If the string’s length was K, and the last c characters of the binary string are all 0,
consider the K most expensive items.
- If these exclude at least one of the largest (by size) K-c+1 items, then their sum is the answer.
- Otherwise, the answer is the sum of the K most expensive items, minus the minimum cost among the largest K-c+1 items, plus the cost of the (K+1)-th most expensive item.
(This is assuming the binary string ends with a 0.)
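The rule above can be sketched as a small function. This is only an illustration of the stated rule, not the reference solution: the name `singleAnswer`, the tie-breaking by index, and the requirement K < N (so that a (K+1)-th most expensive item exists) are assumptions, and no modulus is applied.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// cost[i] is the cost of the i-th item in ascending order of size (0-indexed).
// K is the string length, c the length of its maximal suffix of zeros.
// Assumption: 1 <= c <= K < N, so the (K+1)-th most expensive item exists.
long long singleAnswer(const std::vector<long long>& cost, int K, int c) {
    int n = (int)cost.size();
    assert(1 <= c && c <= K && K < n);

    // Indices sorted by decreasing cost; ties keep index order (an arbitrary choice).
    std::vector<int> ord(n);
    std::iota(ord.begin(), ord.end(), 0);
    std::stable_sort(ord.begin(), ord.end(),
                     [&](int x, int y) { return cost[x] > cost[y]; });

    int x = K - c + 1;   // how many of the largest-by-size items we inspect
    long long sum = 0;
    int taken = 0;       // how many of those x items are among the K most expensive
    for (int i = 0; i < K; i++) {
        sum += cost[ord[i]];
        if (ord[i] >= n - x) taken++;
    }
    if (taken < x) return sum;   // at least one of them is excluded

    long long mn = cost[n - 1];
    for (int i = n - x; i < n; i++) mn = std::min(mn, cost[i]);
    return sum - mn + cost[ord[K]];   // swap the cheapest of them for the next best item
}
```

For instance, with costs {1, 2, 5, 4}, K = 2 and c = 1, both of the two largest items are among the two most expensive, so the sketch returns 9 - 4 + 2 = 7.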
The main observation here is that the actual binary string doesn’t matter much at all: we only care about the parameters K and c (i.e. the length of the string and its longest suffix of zeros).
Let’s use this fact to solve the problem.
For convenience, we’ll first solve for only all those substrings that end with a 0.
We break the problem into two parts: the first is binary strings that consist of only zeros (i.e. K = c), and then everything else (i.e. K \gt c).
(Note that there’s a bit of abuse of notation going on here: K is now just a generic variable representing length, and has nothing to do with the length of the binary string given to us, which doesn’t really matter anymore.)
The K = c part is fairly easy to handle.
Details
Consider the K items with highest values.
If item N isn’t included in them, then their sum is the answer for K.
Otherwise, the answer is the sum of the K+1 largest values, minus the value of item N.
There will be some prefix of K for which the answer is just the sum of the largest K values, and then everything after that will be of the “largest (K+1) except N” form.
This makes it fairly easy to just compute the answer for every value of K; let this value be \text{ans}_K.
Then, count the number of all-0 substrings of each length (say, denoted ct_K), after which we can just add ct_K\cdot\text{ans}_K to the answer for each K.
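To find ct_K, note that each index holding a 0 adds 1 to each of ct_1, ct_2, \ldots, ct_\ell, where \ell is the length of the maximal run of zeros ending at that index (that is, we extend left until a 1 or the start of the string is reached), so a difference array with simple prefix sums can get the job done.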
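A minimal sketch of this counting step, using a difference array over lengths. The function name `countZeroSubstrings` is hypothetical, and the input is assumed to contain only the characters 0 and 1.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Returns ct, where ct[len] is the number of all-zero substrings of a with that length.
std::vector<long long> countZeroSubstrings(const std::string& a) {
    int n = (int)a.size();
    std::vector<long long> diff(n + 2, 0);
    int run = 0;   // length of the maximal run of zeros ending at the current index
    for (char ch : a) {
        assert(ch == '0' || ch == '1');
        run = (ch == '0') ? run + 1 : 0;
        if (run > 0) {       // this index contributes 1 to ct[1..run]
            diff[1]++;
            diff[run + 1]--;
        }
    }
    std::vector<long long> ct(n + 1, 0);
    for (int len = 1; len <= n; len++) ct[len] = ct[len - 1] + diff[len];
    return ct;
}
```

On the string 0010 this gives ct_1 = 3 and ct_2 = 1, matching the three single zeros and the one substring 00.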
We now focus entirely on K \gt c.
To make things a bit easier to work with, we’ll first compute an initial approximation of the answer: for every substring that ends with a 0, we’ll compute its best possible answer, and then compute adjustments where necessary.
Let \text{best}_i denote the sum of the largest i costs.
First, for each index i, compute the value c_i - the length of the maximal consecutive segment of zeros ending at index i.
Now, the possible lengths of segments ending at i, in which not all characters are equal, are c_i+1, c_i+2, \ldots, i.
So, to facilitate our initial approximation, we need to add
\text{best}_{c_i+1} + \text{best}_{c_i+2} + \ldots + \text{best}_i
to the answer, which can be done in constant time using prefix sums.
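The initial approximation can be sketched as follows: for every index i holding a 0, add the sum of \text{best}_K over K = c_i+1, \ldots, i via prefix sums of \text{best}. The function name `initialApproximation` is hypothetical, the string is assumed to be no longer than the item list, and the modulus used in the full solutions is omitted to keep the sketch short.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Sum, over every index i with a[i] == '0', of best[c_i + 1] + ... + best[i],
// where best[k] is the sum of the k largest costs (1-indexed positions).
long long initialApproximation(std::vector<long long> cost, const std::string& a) {
    int n = (int)cost.size(), m = (int)a.size();
    assert(m <= n);   // assumption: every substring length has a valid best value
    std::sort(cost.begin(), cost.end(), std::greater<long long>());

    std::vector<long long> best(m + 1, 0), pre(m + 1, 0);
    for (int i = 1; i <= m; i++) {
        best[i] = best[i - 1] + cost[i - 1];
        pre[i] = pre[i - 1] + best[i];   // prefix sums of best
    }

    long long total = 0;
    int run = 0;   // c_i: maximal run of zeros ending at position i
    for (int i = 1; i <= m; i++) {
        run = (a[i - 1] == '0') ? run + 1 : 0;
        if (run > 0) total += pre[i] - pre[run];   // lengths run+1 .. i
    }
    return total;
}
```

With costs {3, 1, 2} and the string 100, the contributions are \text{best}_2 = 5 at index 2 and \text{best}_3 = 6 at index 3, for a total of 11.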
Now, we need to perform adjustments.
Call a (K, c) pair “bad”, if the K best-valued items include the K-c+1 largest items.
For each bad pair, we need to do two things:
- Subtract the minimum value among the largest K-c+1 items.
This should be done once for each substring with (K, c), so we’d like to know the count of such substrings.
- Then, add the (K+1)-th largest value to the answer.
Again, this should be done once for each substring with this (K, c).
The two parts are independent, so we can do them separately.
Let p_i denote the order of the i-th item: that is, the item with maximum cost has p_i = 1, the second maximum cost has p_i = 2, and so on.
Observe that a (K, c) pair is bad if and only if \max(p_1, p_2, \ldots, p_{K-c+1}) \leq K.
Suppose we fix x = K-c+1.
Let mx = \max(p_1, p_2, \ldots, p_x), and let M denote the minimum cost among these same x items (items 1, 2, \ldots, x).
If we then fix c, the value of K is also fixed, being x+c-1.
So, a pair (K, c) will be bad if and only if mx \leq x+c-1, which can be rearranged to c \geq mx-x+1.
That is, once x=K-c+1 is fixed, all c that are “large enough” will have a unique K for which (K, c) is bad.
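As a sketch of this threshold: for each x, the smallest c that makes (x+c-1, c) bad is mx-x+1, which follows from a prefix maximum over the ranks p_i. The function name `minBadSuffix` and the tie-breaking by index are assumptions for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// cost[i - 1] is the cost of item i, in the same index order as p_1, ..., p_N.
// Returns need, where need[x] = mx - x + 1 is the smallest c making (x + c - 1, c) bad.
std::vector<int> minBadSuffix(const std::vector<long long>& cost) {
    int n = (int)cost.size();
    assert(n > 0);

    // rank[i] = p_{i+1}: 1 for the most expensive item, 2 for the next, and so on.
    std::vector<int> ord(n), rank(n);
    std::iota(ord.begin(), ord.end(), 0);
    std::stable_sort(ord.begin(), ord.end(),
                     [&](int x, int y) { return cost[x] > cost[y]; });
    for (int i = 0; i < n; i++) rank[ord[i]] = i + 1;

    std::vector<int> need(n + 1, 0);
    int mx = 0;   // running prefix maximum of ranks
    for (int x = 1; x <= n; x++) {
        mx = std::max(mx, rank[x - 1]);
        need[x] = mx - x + 1;   // always >= 1, since x distinct ranks have max >= x
    }
    return need;
}
```

For costs {1, 2, 5, 4} the ranks are 4, 3, 1, 2, the prefix maximum is always 4, and the thresholds for x = 1, 2, 3, 4 are 4, 3, 2, 1.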
Let’s look at this in terms of counting substrings.
Consider some index j. Recall that c_j denoted the length of the maximal segment of zeros ending at j. We have:
- If c_j \lt mx-x+1, there is no valid K corresponding to this index, for this value of x.
- If c_j \geq mx-x+1, there is exactly one valid K corresponding to this index, that being x+c_j-1.
The issue now is that K must be small enough that the substring exists at all, i.e. we must have x+c_j-1 \leq j.
So, with x fixed, our aim is to count the number of indices j such that c_j \geq mx-x+1 and x+c_j-1 \leq j.
The second inequality can be rearranged to become x-1 \leq j - c_j, so that the left side is a constant.
This has thus become just a 2D range query problem: if we create the point (c_j, j-c_j) for each j, we’re interested in the number of points inside some rectangular region.
This is a standard problem, with a variety of solutions. The fastest is to just solve the queries offline using a Fenwick tree or segment tree along with a sweepline.
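One generic way to do this offline is sketched below: sort points and queries by the first coordinate, sweep from large to small, and keep the second coordinates in a Fenwick tree. The names `Fenwick` and `countDominating` and the coordinate bound `maxCoord` are assumptions; this is a standard template, not the code used in either solution below.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

struct Fenwick {
    int n;
    std::vector<int> t;
    explicit Fenwick(int n_) : n(n_), t(n_ + 1, 0) {}
    void add(int i) {                 // increment position i (1-indexed)
        assert(1 <= i && i <= n);
        for (; i <= n; i += i & -i) t[i]++;
    }
    int sum(int i) const {            // total over positions 1..i
        int s = 0;
        for (; i > 0; i -= i & -i) s += t[i];
        return s;
    }
};

// For each query (a, b), counts points with first >= a and second >= b.
// All coordinates are assumed to lie in [0, maxCoord].
std::vector<int> countDominating(std::vector<std::pair<int, int>> pts,
                                 const std::vector<std::pair<int, int>>& qs,
                                 int maxCoord) {
    std::sort(pts.begin(), pts.end(),
              [](const auto& l, const auto& r) { return l.first > r.first; });
    std::vector<int> order(qs.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&](int l, int r) { return qs[l].first > qs[r].first; });

    Fenwick fw(maxCoord + 1);
    std::vector<int> res(qs.size(), 0);
    int ptr = 0;   // number of points inserted so far
    for (int qi : order) {
        while (ptr < (int)pts.size() && pts[ptr].first >= qs[qi].first) {
            fw.add(pts[ptr].second + 1);   // shift to 1-indexed
            ptr++;
        }
        // inserted points minus those whose second coordinate is below b
        res[qi] = ptr - fw.sum(qs[qi].second);
    }
    return res;
}
```

In the setting above, the points would be (c_j, j-c_j) and each fixed x yields one query (mx-x+1, x-1).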
At any rate, we now know, for a fixed x, the number of substrings that are “bad” with respect to x.
Let this count be b; we can then subtract b\cdot M from the answer (recall that M is the minimum value among these items).
The second part, of adding the (K+1)-th largest value to the answer for “bad” substrings, can be handled similarly.
This time, fix the value of K+1 instead.
We’re now only interested in indices \geq K+1.
Specifically, among all indices \geq K+1, we want to know the count of j for which
\max(p_1, p_2, \ldots, p_{K-c_j+1}) \leq K.
Since this is a prefix maximum, there will be some contiguous range of c_j for which the inequality is valid; and so yet again we end up with 2D range queries (this time with points (j, c_j)) which can be solved the same way as above.
This takes care of all substrings that end with a 0.
Everything else can be solved by inverting all the elements of the string, reversing the array C, and then running the same method again.
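The transformation itself is tiny; a sketch, with the hypothetical name `mirrorInstance` and the assumption that the string contains only the characters 0 and 1:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Flips every character of a and reverses the cost array in place, so that
// substrings ending with a 1 become substrings ending with a 0 on the mirrored input.
void mirrorInstance(std::string& a, std::vector<long long>& cost) {
    for (char& ch : a) {
        assert(ch == '0' || ch == '1');
        ch = (ch == '0') ? '1' : '0';
    }
    std::reverse(cost.begin(), cost.end());
}
```

Running the zero-ending routine once on the original input and once on the mirrored input, then adding the two results, covers every substring exactly once.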
TIME COMPLEXITY:
\mathcal{O}(N\log N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
struct FenwickTree{
int n;
vector <int> f;
vector <int> b;
inline void add(int i, int x){
b[i] += x;
for (int j = i; j <= n; j += j & (-j)){
f[j] += x;
}
}
inline void modify(int i, int x){
add(i, x - b[i]);
}
inline void init(int nn, vector <int> a){
n = nn;
if (a.size() == n){
vector <int> a2;
a2.push_back(0);
for (auto x : a) a2.push_back(x);
a = a2;
}
f.resize(n + 1);
b.resize(n + 1);
for (int i = 0; i <= n; i++) f[i] = 0, b[i] = 0;
for (int i = 1; i <= n; i++){
modify(i, a[i]);
}
}
inline int query(int x){
int ans = 0;
for (int i = x; i; i -= i & (-i)){
ans += f[i];
}
return ans;
}
inline int query(int l, int r){
return query(r) - query(l - 1);
}
};
const int mod = 1e9 + 7;
int sub;
void Solve()
{
int n, k; cin >> n >> k;
vector <int> c(n);
for (auto &x : c) cin >> x;
string s; cin >> s;
int ans = 0;
if (sub == 1){
if (s[k - 1] == '0'){
reverse(c.begin(), c.end());
}
int y = 1;
while (y < k && s[k - 1 - y] == s[k - 1]){
y++;
}
int x = k - y;
// not all from first x + 1
vector <int> ord(n);
iota(ord.begin(), ord.end(), 0);
sort(ord.begin(), ord.end(), [&](int x, int y){
return c[x] > c[y];
});
// not all from first x + 1
int take = 0;
int sum = 0;
for (int i = 0; i < k; i++){
sum += c[ord[i]];
take += ord[i] <= x;
}
if (take == x + 1){
int mn = INF;
for (int i = 0; i <= x; i++){
mn = min(mn, c[i]);
}
sum -= mn;
sum += c[ord[k]];
}
sum %= mod;
cout << sum << "\n";
return;
}
for (int _ = 0; _ < 2; _++){
reverse(c.begin(), c.end());
for (auto &x : s){
x ^= '0' ^ '1';
}
vector <int> f(n + 1, 0);
for (int i = 0; i < k; i++) if (s[i] == '0'){
int j = i;
while (j + 1 < k && s[j + 1] == '0'){
j++;
}
int len = (j - i + 1);
for (int k = 1; k <= len; k++){
f[k] += (len - k + 1);
}
i = j;
}
vector <int> a;
for (int i = 0; i < n - 1; i++){
a.push_back(c[i]);
}
sort(a.begin(), a.end(), greater<int>());
int sum = 0;
for (int l = 1; l <= k; l++){
sum += a[l - 1];
sum %= mod;
ans += sum * f[l];
ans %= mod;
}
vector <int> ord(n);
iota(ord.begin(), ord.end(), 0);
sort(ord.begin(), ord.end(), [&](int x, int y){
return c[x] > c[y];
});
vector <int> pos(n);
for (int i = 0; i < n; i++){
pos[ord[i]] = i;
}
vector <int> pre(n);
for (int i = 0; i < n; i++){
pre[i] = pos[i];
if (i > 0) pre[i] = max(pre[i], pre[i - 1]);
}
vector <int> y(n);
for (int i = 0; i < k; i++) if (s[i] == '1'){
int j = i;
while (j + 1 < k && s[j + 1] == s[i]){
j++;
}
for (int k = i; k <= j; k++){
y[k] = (k - i + 1);
}
i = j;
}
vector <int> diff(n + 1);
auto add = [&](int l, int r){
diff[l]++;
diff[r + 1]--;
};
for (int i = 0; i < k; i++){
if (s[i] == '1'){
int L = y[i] + 1;
int R = i + 1;
add(L, R);
}
}
for (int i = 1; i <= n; i++){
diff[i] += diff[i - 1];
}
sum = 0;
for (int i = 0; i < n; i++){
sum += c[ord[i]];
sum %= mod;
ans += sum * diff[i + 1];
ans %= mod;
}
vector <int> pref_mn(n);
for (int i = 0; i < n; i++){
pref_mn[i] = c[i];
if (i > 0) pref_mn[i] = min(pref_mn[i], pref_mn[i - 1]);
}
// doing += c[ord[length]] for now
vector<vector<int>> adj(n + 1);
for (int i = 0; i < k; i++){
if (s[i] == '1' && y[i] != (i + 1)){
adj[y[i] + 1].push_back(y[i]);
adj[i + 2].push_back(-y[i]);
}
}
FenwickTree fen;
vector <int> vv(n);
fen.init(n, vv);
for (int i = 1; i <= k; i++){
for (int y : adj[i]){
if (y > 0){
fen.add(y, +1);
} else {
fen.add(-y, -1);
}
}
int lo = 0, hi = i;
while (lo != hi){
int mid = (lo + hi + 1) / 2;
// this many y good?
int y = mid;
int x = i - y;
if (pre[x] < x + y){
hi = mid - 1;
} else {
lo = mid;
}
}
// upto lo is ok in terms of y
// greater things are bad
int count = fen.query(lo + 1, n);
ans += count * c[ord[i]];
ans %= mod;
}
// now -= pref_mn[x]
for (int i = 0; i <= n; i++){
adj[i].clear();
}
for (int i = 0; i < k; i++){
if (s[i] == '1' && y[i] != (i + 1)){
adj[1].push_back(y[i]);
adj[i - y[i] + 2].push_back(-y[i]);
}
}
for (int i = 1; i <= n; i++){
fen.modify(i, 0);
}
for (int x = 1; x <= k; x++){
for (int y : adj[x]){
if (y > 0){
fen.add(y, +1);
} else {
fen.add(-y, -1);
}
}
int lo = 0, hi = k - x;
while (lo != hi){
int mid = (lo + hi + 1) / 2;
int y = mid;
if (pre[x] < x + y){
hi = mid - 1;
} else {
lo = mid;
}
}
int count = fen.query(lo + 1, n);
ans -= count * pref_mn[x];
ans %= mod;
if (ans < 0) ans += mod;
}
}
cout << ans << "\n";
}
int32_t main()
{
auto begin = std::chrono::high_resolution_clock::now();
ios_base::sync_with_stdio(0);
cin.tie(0);
int t = 1;
// freopen("in", "r", stdin);
// freopen("out", "w", stdout);
// cin >> sub;
sub = 2;
cin >> t;
for(int i = 1; i <= t; i++)
{
//cout << "Case #" << i << ": ";
Solve();
}
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
a = min(a,b);
}
template<typename T>
void amax(T &a, T b) {
a = max(a,b);
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
template<typename T>
struct fenwick {
int n;
vector<T> tr;
int LOG = 0;
fenwick() {
}
fenwick(int n_) {
n = n_;
tr = vector<T>(n + 1);
while((1<<LOG) <= n) LOG++;
}
int lsb(int x) {
return x & -x;
}
void pupd(int i, T v) {
for(; i <= n; i += lsb(i)){
tr[i] += v;
}
}
T sum(int i) {
T res = 0;
for(; i; i ^= lsb(i)){
res += tr[i];
}
return res;
}
T query(int l, int r) {
if (l > r) return 0;
T res = sum(r) - sum(l - 1);
return res;
}
int lower_bound(T s){
// first pos with sum >= s
if(sum(n) < s) return n+1;
int i = 0;
rev(bit,LOG-1,0){
int j = i+(1<<bit);
if(j > n) conts;
if(tr[j] < s){
s -= tr[j];
i = j;
}
}
return i+1;
}
int upper_bound(T s){
return lower_bound(s+1);
}
};
void solve(int test_case){
ll n,k; cin >> n >> k;
vector<ll> a(n+5);
rep1(i,n) cin >> a[i];
string t; cin >> t;
t = "$" + t;
ll ans = 0;
auto go = [&](){
// solve for all [l..r] s.t t[r] = '1'
vector<ll> one_len(k+5);
rep1(i,k){
if(t[i] == '1') one_len[i] = one_len[i-1]+1;
}
vector<ll> full_one(k+5);
rep1(i,k){
if(t[i] != '1') conts;
full_one[one_len[i]]++;
}
rev(i,k,1) full_one[i] += full_one[i+1];
vector<pll> b;
rep1(i,n) b.pb({a[i],i});
sort(rall(b));
b.insert(b.begin(),{0,0});
vector<ll> mnp(n+5,inf2);
rep1(i,n) mnp[i] = min(mnp[i-1],a[i]);
vector<ll> pk(k+5);
{
vector<ll> enter[n+5], leave[n+5];
rep1(i,k){
if(t[i] != '1') conts;
ll cnt = i-one_len[i];
ll lk = one_len[i]+1, rk = lk+cnt;
enter[lk].pb(2);
leave[rk].pb(2+cnt);
}
vector<bool> taken(n+5);
ll p = 1;
ll sum = 0;
ll curr_shift = 0;
fenwick<ll> fenw(2*n+5);
rep1(c,k){
sum += b[c].ff;
sum %= MOD;
taken[b[c].ss] = 1;
while(taken[p]) p++;
pk[c] = p;
if(p == 1){
ans += sum*full_one[c];
}
else{
ans += (sum-a[1]+b[c+1].ff+MOD)*full_one[c];
}
ans %= MOD;
trav(x,enter[c]){
fenw.pupd(x-curr_shift+n,1);
}
trav(x,leave[c]){
fenw.pupd(x-curr_shift+n,-1);
}
ll tot_cnt = fenw.sum(2*n+3);
ll bad_cnt = fenw.sum(p-curr_shift+n-1);
ans += (tot_cnt*sum)+(bad_cnt*b[c+1].ff);
ans %= MOD;
curr_shift++;
}
}
{
vector<ll> enter[n+5], leave[n+5];
rep1(i,k){
if(t[i] != '1') conts;
ll cnt = i-one_len[i];
ll lk = one_len[i]+1, rk = lk+cnt;
enter[2].pb(lk);
leave[2+cnt].pb(rk);
}
fenwick<ll> fenw(2*n+5);
ll curr_shift = 0;
ll ptr = 1;
rep1(i,n){
trav(x,enter[i]){
fenw.pupd(x-curr_shift+n,1);
}
trav(x,leave[i]){
fenw.pupd(x-curr_shift+n,-1);
}
// find min ptr s.t pk[ptr] > i
while(ptr <= k and pk[ptr] <= i){
ptr++;
}
if(ptr > k) break;
ll bad_cnt = fenw.query(ptr-curr_shift+n,2*n+3);
ans -= bad_cnt*mnp[i];
ans = (ans%MOD+MOD)%MOD;
curr_shift++;
}
}
};
go();
rep1(i,k){
if(t[i] == '1') t[i] = '0';
else t[i] = '1';
}
reverse(a.begin()+1,a.begin()+n+1);
go();
cout << ans << endl;
}
int main()
{
fastio;
int t = 1;
cin >> t;
rep1(i, t) {
solve(i);
}
return 0;
}