PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: still_me
Testers: the_hyp0cr1t3, rivalq
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
None
PROBLEM:
Given an integer K with N digits A_1, A_2, \ldots, A_N in base P (most significant first, N \leq P), find the smallest integer X \geq 0 such that K+X contains all the digits 0, 1, 2, \ldots, P-1 in base P. Output X modulo 10^9 + 7.
EXPLANATION:
This is more of an implementation task than anything else, but being a bit careful can lead to a relatively short and clean implementation.
Instead of finding the smallest value of X, let’s find the smallest possible value of K+X; in the end, we can obtain X by subtracting K from it.
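First, note that K+X will contain either P or P+1 digits: at least P since we need P distinct digits, and at most P+1 because N \leq P — the smallest (P+1)-digit number containing every digit, [1, 0, 0, 2, 3, \ldots, P-1], is already larger than any K with at most P digits, so we never need to go any higher.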
In particular, for the answer to have P digits, all of its digits need to be distinct.
Now let’s get rid of a couple of simple edge cases:
- If N \lt P, then the optimal value of K+X is [1, 0, 2, 3, 4, \ldots, P-1]; this is the smallest integer with all the digits.
- Now we’re left with N = P always.
- The largest possible P-digit number with distinct digits is [P-1, P-2, \ldots, 2, 1, 0]. If the given number is larger than this (which can be checked by comparing them lexicographically), then we must use P+1 digits, and the answer is [1, 0, 0, 2, 3, 4, \ldots, P-1]; this is the smallest (P+1)-digit number containing all digits.
Once those cases are out of the way, we know that it’s always possible to achieve a P-digit answer; so our task is to make all the digits distinct.
To minimize K+X, we’d like to keep as large a prefix of K unchanged as possible.
So, let’s iterate i from 1 to P.
- If A_i hasn’t appeared before, we don’t need to increase this index right now.
- If A_i has appeared before, we must change some index \leq i in order to not have repeated digits.
Let x be the first time we meet a repeated digit (if there are no repeated digits, the answer is obviously 0 so we only focus on the case when a valid x exists).
We have to increase some digit at a position \leq x. So, let’s do the following:
- Let’s check if we can replace A_x with something higher than it. This is only possible if there exists a digit \gt A_x that has not appeared before position x (recall that our aim is distinct digits).
- If we can’t replace A_x with something larger than it, instead check for position x-1, then position x-2, and so on. One of these will definitely satisfy the condition since we know a P-digit answer is possible.
- One simple way of quickly checking whether replacement is possible is to maintain a set of unseen values and find its maximum. When moving from x to x-1, insert A_{x-1} into this set of unseen elements.
- This allows each position to be processed in \mathcal{O}(\log N) time, for \mathcal{O}(N\log N) overall. You can also maintain a boolean `mark` array together with a pointer to the largest unseen value, for \mathcal{O}(N) time.
- Suppose we’ve found the position y that needs to be increased. Do the following:
- Set A_y to the smallest unseen element larger than it
- Then, fill positions y+1 to P with the remaining unseen elements in ascending order.
Now we have the array representing K+X, so we need to compute X = (K+X) - K.
Notice that we only need X modulo 10^9 + 7.
So, compute both K and K+X modulo 10^9 + 7 (for example with Horner's rule, or using binary exponentiation to compute the powers of P modulo 10^9 + 7), then output ((K+X) - K) \bmod (10^9 + 7), adding the modulus if the difference is negative.
TIME COMPLEXITY:
\mathcal{O}(P) per testcase (the answer always has P or P+1 digits, and N \leq P).
CODE:
Setter's code (C++)
// Code by Reyaan Jagnani
#include<bits/stdc++.h>
// Type shorthands.
#define ll long long int
#define ld long double
// Pair-member and container shorthands.
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
// Read / print the first n elements of an indexable container `a`.
// NOTE: these expand to several statements, so wrap them in braces inside an unbraced if/else.
#define scanit(a,n) for(ll indexaa=0; indexaa<n; indexaa++) cin>>a[indexaa];
#define printit(a,n) for(ll indexaa=0; indexaa<n; indexaa++) cout<<a[indexaa]<<" "; cout<<endl;
#define pb push_back
#define precision(a) cout<<fixed<<setprecision(a)
// Reads a test count t from cin, then repeats the following statement/block t times.
#define testcase ll t; cin>>t; while(t--)
// NOTE: endl is redefined as a plain "\n", so `cout << endl` no longer flushes.
#define endl "\n"
#define iendl "\n", cout<<flush // FOR INTERACTIVE PROBLEMS
// Turns off C stdio / iostream synchronization and unties cin/cout.
#define quick ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL)
// Prints the processor time used so far (clock() / CLOCKS_PER_SEC) to stderr.
#define timetaken cerr<<fixed<<setprecision(10); cerr << "time taken : " << (float)clock() / CLOCKS_PER_SEC << " secs" << endl
using namespace std;
// Modulus for all modular arithmetic in this program.
constexpr ll M = 1000000007;
// Size of the precomputed tables below (valid indices 0 .. maxN-1).
constexpr ll maxN = 200001;
// Tables filled by calcFact() and sieve(); zero-initialized until then.
ll fact[maxN]{};
ll smallestPrimeFactor[maxN]{};
ll isPrimeSieve[maxN]{};
// 64-bit Mersenne Twister seeded from the steady clock.
mt19937_64 my_rand(chrono::steady_clock::now().time_since_epoch().count());
// Strict "less than" comparator: sorting with it yields ascending (small to big) order.
inline bool comp(ll lhs, ll rhs)
{
    return rhs > lhs;
}
// Reduces any value, negative ones included, into [0, M).
inline ll mod(ll value)
{
    const ll rem = value % M;
    return rem < 0 ? rem + M : rem;
}
// x^y mod p by binary exponentiation, O(log y).
// - Products are formed in __int128 (GCC/Clang extension), so every modulus up to the
//   default LLONG_MAX is safe; the old 64-bit x*x overflowed once p exceeded ~3.04e9.
// - Negative x is normalized into [0, p).
// - x^0 is 1 % p; the old early return for x % p == 0 made 0^0 come out as 0.
inline long long power(long long x, unsigned long long y, long long p = LLONG_MAX)
{
    x %= p;
    if (x < 0) x += p;
    long long res = 1 % p;
    while (y > 0)
    {
        if (y & 1) res = static_cast<long long>(static_cast<__int128>(res) * x % p);
        x = static_cast<long long>(static_cast<__int128>(x) * x % p);
        y >>= 1;
    }
    return res;
}
// Modular inverse via Fermat's little theorem: a^(p-2) mod p.
// Only meaningful when p is prime and a is not a multiple of p.
inline ll inversePrimeModular(ll a, ll p)
{
    const ll exponent = p - 2;
    return power(a, exponent, p);
}
// Fills fact[0..n] with i! mod M (n defaults to the last table index).
inline void calcFact(ll n = maxN-1)
{
    fact[0] = 1;
    ll idx = 1;
    while (idx <= n)
    {
        fact[idx] = mod(fact[idx-1]*idx);
        ++idx;
    }
}
// Binomial coefficient C(n, r) mod M from the factorial table; calcFact() must have
// filled fact[0..n]. Returns 0 whenever r lies outside [0, n]. The old check was only
// n < r, so a negative r indexed fact[] out of bounds.
inline ll ncr(ll n, ll r)
{
    if(r < 0 || n < r) return 0;
    return mod(inversePrimeModular(mod(fact[n-r]*fact[r]),M)*fact[n]);
}
// Ceiling of a / b for operands of any sign; returns LLONG_MAX as a sentinel when b == 0.
// The old (a+b-1)/b formula was wrong whenever a or b is negative (e.g. ceil(-4,2)
// gave -1) and could overflow in a+b-1.
inline long long ceil(long long a, long long b)
{
    if (b == 0) return LLONG_MAX;
    long long q = a / b;  // truncates toward zero
    // A positive quotient with a remainder was rounded down by truncation: bump it up.
    if (a % b != 0 && ((a < 0) == (b < 0))) ++q;
    return q;
}
// Hasher for unordered containers keyed by 64-bit integers: the splitmix64 finalizer
// applied to the key plus an offset read once per process from the steady clock.
struct custom_hash
{
    static uint64_t splitmix64(uint64_t z)
    {
        z += 0x9e3779b97f4a7c15;
        z ^= z >> 30;
        z *= 0xbf58476d1ce4e5b9;
        z ^= z >> 27;
        z *= 0x94d049bb133111eb;
        return z ^ (z >> 31);
    }
    size_t operator()(uint64_t key) const
    {
        static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(key + FIXED_RANDOM);
    }
};
// Smallest-prime-factor sieve over [1, n]. Afterwards smallestPrimeFactor[v] is the least
// prime dividing v (and 1 for v == 1), and isPrimeSieve[v] is set to 1 for every prime v.
void sieve(ll n = maxN-1)
{
    for (ll v = 1; v <= n; ++v)
        smallestPrimeFactor[v] = v;
    for (ll prime = 2; prime * prime <= n; ++prime)
    {
        if (smallestPrimeFactor[prime] != prime)
            continue;  // composite: skip, only primes cross off multiples
        for (ll multiple = prime * prime; multiple <= n; multiple += prime)
            smallestPrimeFactor[multiple] = min(smallestPrimeFactor[multiple], prime);
    }
    for (ll v = 2; v <= n; ++v)
        if (smallestPrimeFactor[v] == v)
            isPrimeSieve[v] = 1;
}
// dbg(x) prints "x : <value>" to stderr in local builds and expands to nothing when
// ONLINE_JUDGE is defined. The expansion is several statements, so brace any
// unbraced if/else body that uses it.
#ifndef ONLINE_JUDGE
#define dbg(x) cerr << #x << " : "; _print_(x);cerr << endl;
#else
#define dbg(x)
#endif
// Debug printers behind dbg(x); each overload writes one value to stderr.
// Containers and pairs are taken by const reference. The previous by-value signatures
// copied the whole container, and every element again in the loop, on each call.
void _print_(long long t) {std::cerr << t;}
void _print_(int t) {std::cerr << t;}
void _print_(const std::string &t) {std::cerr << t;}
void _print_(char t) {std::cerr << t;}
void _print_(long double t) {std::cerr << t;}
void _print_(double t) {std::cerr << t;}
// Forward declarations so nested containers (e.g. vector<pair<...>>) find these overloads.
template <class T, class V> void _print_(const std::pair<T, V> &p);
template <class T> void _print_(const std::vector<T> &v);
template <class T> void _print_(const std::set<T> &v);
template <class T, class V> void _print_(const std::map<T, V> &v);
template <class T> void _print_(const std::multiset<T> &v);
// Shared body for iterable containers: prints "[ e1 e2 ... ]".
template <class C> void _print_range_(const C &c) {
    std::cerr << "[ ";
    for (const auto &e : c) {
        _print_(e);
        std::cerr << " ";
    }
    std::cerr << "]";
}
template <class T, class V> void _print_(const std::pair<T, V> &p) {
    std::cerr << "{"; _print_(p.first); std::cerr << ","; _print_(p.second); std::cerr << "}";
}
template <class T> void _print_(const std::vector<T> &v) {_print_range_(v);}
template <class T> void _print_(const std::set<T> &v) {_print_range_(v);}
template <class T> void _print_(const std::multiset<T> &v) {_print_range_(v);}
template <class T, class V> void _print_(const std::map<T, V> &v) {_print_range_(v);}
long long readInt(long long l,long long r,char endd){
    // Reads one integer from stdin terminated by exactly `endd` and asserts on malformed
    // input: characters other than digits / '-', leading zeros, "-0", a second '-' or a '-'
    // after digits, an empty token, more than 19 digits, or a value outside [l, r].
    long long x=0;
    int cnt=0;      // digits consumed so far
    int fi=-1;      // first digit, -1 until one is seen
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            // At most one sign, and only before the first digit ("--5" used to pass).
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
            // Length is checked before accumulating so x can never overflow.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+(g-'0');
        } else if(g==endd){
            // An empty token or a lone "-" used to be accepted as 0.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    // Collects characters up to the delimiter `endd` (consumed but not stored).
    // Asserts that EOF never comes first and that the length lies in [l, r].
    string ret;
    for(char g=getchar();; g=getchar()){
        assert(g!=-1);
        if(g==endd) break;
        ret+=g;
    }
    const int len=ret.size();
    assert(l<=len && len<=r);
    return ret;
}
// Delimiter-specific wrappers: *Sp expects the token to end with a space, *Ln with a newline.
long long readIntSp(long long lo,long long hi){ return readInt(lo,hi,' '); }
long long readIntLn(long long lo,long long hi){ return readInt(lo,hi,'\n'); }
string readStringLn(int minLen,int maxLen){ return readString(minLen,maxLen,'\n'); }
string readStringSp(int minLen,int maxLen){ return readString(minLen,maxLen,' '); }
// Appends the answer digits to `final` when position i is the one to be raised:
//   vect[0..i-1] unchanged, then the smallest digit in (vect[i], large] not marked in m
//   (falling back to `large`), then every still-unmarked digit in ascending order.
// The chosen digit is marked in m. The parameter n is not used.
void case2(vector<ll> &vect, vector<ll> &final, vector<ll> &m, ll n, ll i, ll large)
{
    for(ll k = 0; k < i; ++k)
        final.push_back(vect[k]);

    ll chosen = large;
    for(ll d = vect[i] + 1; d < large; ++d)
    {
        if(!m[d])
        {
            chosen = d;
            break;
        }
    }
    final.push_back(chosen);
    m[chosen] = 1;

    const ll total = m.size();
    for(ll d = 0; d < total; ++d)
        if(!m[d])
            final.push_back(d);
}
// Returns (value(final) - value(vect)) mod M, where both vectors hold base-p digits with
// the most significant first, i.e. X = (K+X) - K.
// Each number is evaluated with Horner's rule, so the cost is O(len) with no per-digit
// modular exponentiation. Neither input is modified any more; the previous version
// reversed and zero-padded both vectors in place.
ll solve(vector<ll> &vect, vector<ll> &final, ll p)
{
    const ll base = p % M;
    auto value = [base](const vector<ll> &digits) {
        ll acc = 0;
        for (ll d : digits)
            acc = (acc * base + d) % M;  // acc, base < M keeps the product below 2^63
        return acc;
    };
    return mod(value(final) - value(vect));
}
// Validates the input with the strict readers above. For each test, prints X mod M, where
// K+X is the smallest number >= K that contains every base-p digit.
int main()
{
quick;
#ifndef ONLINE_JUDGE
freopen("edge.in", "r", stdin);
freopen("edge.out", "w", stdout);
// freopen("error.txt", "w", stderr);
#endif
// sum: running total of p over all tests (asserted <= 1e6 at the end).
// k is only incremented and never read (leftover from debugging).
ll sum = 0, k = 2;
ll T = readIntLn(1,1e4);
while(T--)
{
// dbg(k);
k+=2;
ll n = readIntSp(1,1e6);
ll p = readIntLn(1,1e6);
sum += p;
assert(n<=p);
// Digits of K, most significant first; the leading digit must be nonzero.
vector<ll> vect(n);
for(ll i=0; i<n-1; i++)
vect[i] = readIntSp(0,p-1);
vect[n-1] = readIntLn(0,p-1);
assert(vect[0]!=0);
// m[d] == 1 marks digit d as used by the kept prefix; final receives the digits of K+X.
vector<ll> m(p), final;
// Fewer than p digits: the answer is 1, 0, 2, 3, ..., p-1.
if(n<p)
{
final.pb(1);
final.pb(0);
for(ll i=2; i<=p-1; i++)
final.pb(i);
cout<<solve(vect, final, p)<<endl;
continue;
}
// n == p: scan left to right up to the first repeated digit.
// large: largest digit not yet marked (only recomputed when a repeat is found).
// check stays 1 if every digit is distinct, i.e. K already works and X = 0.
ll large = p-1, i = 0;
bool check = 1;
while(i<n)
{
if(m[vect[i]])
{
// vect[i] repeats an earlier digit: refresh large to the largest unmarked digit.
while(large>=0 && m[large])
large--;
if(large < vect[i])
{
// Position i cannot be raised. Walk left, releasing each prefix digit, until some
// position i has an unmarked digit larger than vect[i].
i--;
while(i>=0 && large < vect[i])
{
m[vect[i]] = 0;
large = max(large, vect[i]);
i--;
}
if(i<0)
{
// No position can be raised: the answer needs p+1 digits, 1, 0, 0, 2, 3, ..., p-1.
final.pb(1);
final.pb(0);
final.pb(0);
for(ll j=2; j<p; j++)
final.pb(j);
}
else
{
// Release vect[i] as well, then raise position i.
m[vect[i]] = 0;
case2(vect, final, m, n, i, large);
}
}
else
// Position i itself can be raised.
case2(vect, final, m, n, i, large);
check = 0;
break;
}
else
m[vect[i]] = 1;
i++;
}
if(check)
cout<<"0"<<endl;
else
cout<<solve(vect, final, p)<<endl;
}
// Global input-format checks: bounded sum of p, nothing after the last test.
assert(sum<=1e6);
assert(getchar()==-1); // Ensures that there are no extra characters at the end.
cerr<<"SUCCESS\n"; // You should see this on the http://campus.codechef.com/files/stderr/SUBMISSION_ID page, at the bottom.
timetaken;
return 0;
}
/*
1. Binary Search / Binary Search on Answer
2. Bit
3. Parity (Odd / Even)
4. DP / Greedy
5. Graph / Bi-Partite
*/
Tester's code (C++)
/**
* the_hyp0cr1t3
* 04.01.2023 16:18:06
**/
#ifdef W
#include <k_II.h>
#else
#include <bits/stdc++.h>
using namespace std;
#endif
// -------------------- Input Validator Start --------------------
// Convenience readers for a validator object named `val` in scope. Each one records the
// calling source line and the variable name so that validation errors point at the read.
#define read_int_sp(x, L, R) val.read_int(x, L, R, ' ', __LINE__, #x)
#define read_int_ln(x, L, R) val.read_int(x, L, R, '\n', __LINE__, #x)
#define read_vec(vec, N, L, R) val.read_vector(vec, N, L, R, __LINE__, #vec)
#define read_str_sp(x, L, R, chset) val.read_string(x, L, R, chset, ' ', __LINE__, #x)
#define read_str_ln(x, L, R, chset) val.read_string(x, L, R, chset, '\n', __LINE__, #x)
// Longest digit string read_int accepts; with exactly 19 digits the leading digit must be 1.
constexpr int max_digits = 19;
// Whether the input starts with a test count (multi_tests) or is one test (single_test).
enum test_type { single_test, multi_tests };
// Character classes read_string can enforce: a-z, 0/1, 0-9, and '.'/'#'.
enum char_set { alpha, binary, digit, gridwalls };
// Strict stdin validator. It reads tokens with exact delimiters and aborts with a
// diagnostic (caller variable, source line, test number, input position) on any violation.
// On destruction it also requires that the whole input was consumed (unless W is defined).
template <test_type T = single_test> class validator {
    int tests, current_test {0}, input_line_no {1}, input_col_no {0};

    // Renders a character returned by std::getchar() for an error message. The previous
    // code streamed `&c` as a C string, but a single char is not NUL-terminated, so that
    // read past the variable.
    static std::string describe(int c) {
        if (c == EOF) return "EOF";
        return std::string(1, static_cast<char>(c));
    }

public:
    // Single-test input: nothing to read up front.
    template <test_type U = T,
        std::enable_if_t<
            std::is_same<validator<U>, validator<single_test>>::value> * = nullptr>
    validator() : tests {1} {}

    // Multi-test input: the first line holds the number of tests, in [tests_lb, tests_ub].
    template <test_type U = T,
        std::enable_if_t<
            std::is_same<validator<U>, validator<multi_tests>>::value> * = nullptr>
    validator(int tests_lb, int tests_ub) {
        read_int(tests, tests_lb, tests_ub, '\n', -1, "tests");
    }

// Prints `msg` plus context and aborts when `cond` holds; expects `label` and `line` in scope.
#define FAIL(cond, msg) \
    if (cond) { \
        std::cerr << msg "while reading\n" \
            << "> symbol \"" << label << "\" (line " << line << ")\n" \
            << "> in test " << current_test << '\n' \
            << "> at pos " << input_line_no << ':' << input_col_no << '\n'; \
        abort(); \
    }

    // Reads one integer terminated by exactly `delim` into x and checks L <= x <= R.
    template <typename U = int, typename = std::enable_if_t<std::is_integral<U>::value>>
    void read_int(
        U &x, int64_t L, int64_t R, char delim, int line = -1, const char *label = "") {
        int64_t res = 0;
        int len = 0, leading = -1;
        bool is_negative = false;
        while (true) {
            // Keep getchar()'s int result so EOF stays distinguishable from a 0xFF byte.
            const int c = std::getchar();
            ++input_col_no;
            if (c == '-') {
                // One sign, and only before the first digit ("--5" used to pass).
                FAIL(len > 0 or is_negative, "error: found invalid symbol \'-\'\n")
                is_negative = true;
            } else if ('0' <= c and c <= '9') {
                if (++len == 1)
                    leading = c - '0';
                FAIL(leading == 0 and len > 1, "error: found leading zeroes\n")
                FAIL(leading == 0 and is_negative, "error: found negative zero\n")
                // Checked before the multiply so res itself can never overflow.
                FAIL(len > max_digits or (len == max_digits and leading > 1),
                    "error: value will overflow int64_t\n")
                res = res * 10 + (c - '0');
            } else if (c == delim) {
                // An empty token or a lone '-' used to be accepted as 0.
                FAIL(len == 0, "error: expected an integer\n")
                if (is_negative)
                    res *= -1;
                if (res < L or R < res) {
                    std::cerr << "error: found value " << res
                        << " expected to be in range [" << L << ", " << R << "]\n";
                    FAIL(true, "")
                }
                x = res;
                if (delim == '\n') ++input_line_no, input_col_no = 0;
                return;
            } else {
                std::cerr << "error: found invalid character \'" << describe(c) << "\'\n";
                FAIL(true, "")
            }
        }
    }

    // Reads a string terminated by `delim` whose characters all belong to `chset` and
    // whose length lies in [L, R]; stores it in s and returns its length.
    int read_string(std::string &s,
        int L, int R, char_set chset, char delim, int line = -1, const char *label = "") {
        std::string res;
        int c = 0;  // initialized: the loop may not run at all
        while (static_cast<int>(res.size()) <= R) {
            c = std::getchar();
            ++input_col_no;
            if (c == EOF or c == delim)
                break;
            res += static_cast<char>(c);
            bool ok = true;
            switch (chset) {
                case alpha: ok = 'a' <= c and c <= 'z'; break;
                case binary: ok = c == '0' or c == '1'; break;
                case digit: ok = '0' <= c and c <= '9'; break;
                case gridwalls: ok = c == '.' or c == '#'; break;
            }
            if (!ok) {
                std::cerr << "error: found invalid character \'" << describe(c) << "\'\n";
                FAIL(true, "")
            }
        }
        FAIL(c == EOF, "Unexpected EOF\n")
        if (c == '\n') ++input_line_no, input_col_no = 0;
        const int len = static_cast<int>(res.size());
        if (len < L or R < len) {
            std::cerr << "error: found string of length " << len
                << " expected to be in range [" << L << ", " << R << "]\n";
            FAIL(true, "")
        }
        s = res;
        return len;
    }

    // Reads N space-separated integers on one line (the last one ends with '\n').
    template <typename U = int, typename = std::enable_if_t<std::is_integral<U>::value>>
    void read_vector(
        std::vector<U> &vec, int N, int L, int R, int line = -1, const char *label = "") {
        vec.resize(N);
        for (int i = 0; i < N - 1; i++)
            read_int(vec[i], L, R, ' ', line, label);
        read_int(vec[N - 1], L, R, '\n', line, label);
    }

    // Advances to the next test; false once all tests have been consumed.
    bool do_test() { return ++current_test <= tests; }

    ~validator() {
#ifndef W
        if (std::getchar() != EOF) {
            std::cerr << "error: expected EOF\n";
            abort();
        }
#endif
    }
};
// -------------------- Input Validator End --------------------
// Validates every test, builds the digits B of the answer K+X, and prints (B - A) mod MOD.
// Input statistics and per-case counters are written to stderr.
int main() {
#if __cplusplus > 201703L
namespace R = ranges;
#endif
ios_base::sync_with_stdio(false), cin.tie(nullptr);
constexpr int MOD = 1'000'000'000 + 7;
int64_t sum_b = 0, sum_b2 = 0, case1 = 0, case2 = 0, case3 = 0;
validator<multi_tests> val(1, 10'000);
// Value of the digit vector p (most significant first) in base b, mod MOD (Horner's rule).
auto eval = [](const vector<int> &p, int b) {
int64_t ans = 0;
for (auto x: p)
ans = (ans * b + x) % MOD;
return ans;
};
while (val.do_test()) {
int n, b;
read_int_sp(n, 1, 1e6);
read_int_ln(b, 2, 1e6);
assert(n <= b);
sum_b += b;
sum_b2 += 1LL * b * b;
vector<int> A, B(b);
read_vec(A, n, 0, b - 1);
assert(A[0] > 0);
// case 1: make it [1, 0, 2, 3, 4...]
if (n < b) {
iota(B.begin(), B.end(), 0);
swap(B[0], B[1]);
++case1;
} else {
// B = b-1, ..., 1, 0: the largest b-digit number with distinct digits.
iota(B.begin(), B.end(), 0);
reverse(B.begin(), B.end());
// case 2: make it [1, 0, 0, 2, 3, 4...]
if (A > B) {
B.push_back(1);
reverse(B.begin(), B.end());
B[2] = 0;
++case2;
} else {
B = A;
// Scan up to the first repeated digit. can_inc ends as the last position i whose digit
// can be raised (some digit unused by A[0..i-1] exceeds A[i]); top is the largest unused digit.
vector<bool> used(n);
int can_inc = -1, top = n - 1;
for (int i = 0; i < n; i++) {
while (top >= 0 and used[top]) top--;
if (top > A[i])
can_inc = i;
if (used[A[i]])
break;
used[A[i]] = true;
}
// Fewer than n marks means a repeat was found: keep A[0..can_inc-1], raise position
// can_inc to the smallest unused larger digit, then fill the rest with unused digits ascending.
if (count(used.begin(), used.end(), true) < n) {
used.assign(n, 0);
int i = 0;
while (i < can_inc) {
used[A[i]] = true;
++i;
}
++B[i];
while (used[B[i]]) ++B[i];
used[B[i]] = true;
top = 0;
while (++i < n) {
while (used[top]) top++;
B[i] = top++;
}
++case3;
} // else A is already good
}
}
cout << (eval(B, b) - eval(A, b) + MOD) % MOD << '\n';
}
cerr << "Sum B: " << sum_b << '\n';
cerr << "Sum B^2: " << sum_b2 << '\n';
cerr << "10234... cnt: " << case1 << '\n';
cerr << "100234... cnt: " << case2 << '\n';
cerr << "other cnt: " << case3 << '\n';
} // ~W
Tester's code (C++)
// Jai Shree Ram
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,n) for(int i=a;i<n;i++)
#define ll long long
// NOTE: every later `int` in this program means `long long` (hence `signed main`
// and `int32_t` for the postfix ++/-- dummy parameters).
#define int long long
#define pb push_back
#define all(v) v.begin(),v.end()
// endl is a plain "\n" here, so it never flushes.
#define endl "\n"
// Pair-member renames: any identifier x or y below turns into first / second.
#define x first
#define y second
#define gcd(a,b) __gcd(a,b)
#define mem1(a) memset(a,-1,sizeof(a))
#define mem0(a) memset(a,0,sizeof(a))
#define sz(a) (int)a.size()
#define pii pair<int,int>
// The modulus 10^9 + 7.
#define hell 1000000007
#define elasped_time 1.0 * clock() / CLOCKS_PER_SEC
// Stream a pair as "first second".
template<typename T1,typename T2>
istream& operator>>(istream& in,pair<T1,T2> &a)
{
    in >> a.first >> a.second;
    return in;
}
template<typename T1,typename T2>
ostream& operator<<(ostream& out,pair<T1,T2> a)
{
    out << a.first << " " << a.second;
    return out;
}
// Update-in-place helpers: a becomes max(a, b) / min(a, b); its new value is returned.
template<typename T,typename T1>
T maxs(T &a,T1 b)
{
    if(b > a)
        a = b;
    return a;
}
template<typename T,typename T1>
T mins(T &a,T1 b)
{
    if(b < a)
        a = b;
    return a;
}
// -------------------- Input Checker Start --------------------
long long readInt(long long l, long long r, char endd)
{
    // Reads one integer from stdin terminated by exactly `endd`. Asserts on characters
    // other than digits / '-', leading zeros, "-0", a repeated or misplaced '-', an empty
    // token, more than 19 digits, or a value outside [l, r] (which is echoed to stderr).
    long long value = 0;
    int cnt = 0, fi = -1;   // digits consumed, first digit (-1 until seen)
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            // One sign, and only before the first digit ("--5" used to pass).
            assert(fi == -1 && !is_neg);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            // Length is checked before accumulating so the value can never overflow.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            value = value * 10 + (g - '0');
        }
        else if(g == endd)
        {
            // An empty token or a lone "-" used to be accepted as 0.
            assert(cnt > 0);
            if(is_neg)
                value = -value;
            if(!(l <= value && value <= r))
            {
                std::cerr << l << ' ' << r << ' ' << value << '\n';
                assert(false);
            }
            return value;
        }
        else
        {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd)
{
    // Accumulates characters until `endd`, which is consumed but not stored.
    // Reaching EOF first, or a final length outside [l, r], trips an assert.
    string ret;
    char g;
    do
    {
        g = getchar();
        assert(g != -1);
        if(g != endd)
            ret.push_back(g);
    } while(g != endd);
    const int cnt = ret.size();
    assert(l <= cnt && cnt <= r);
    return ret;
}
// Delimiter-specific wrappers: *Sp expects a trailing space, *Ln a trailing newline.
long long readIntSp(long long lo, long long hi)
{
    return readInt(lo, hi, ' ');
}
long long readIntLn(long long lo, long long hi)
{
    return readInt(lo, hi, '\n');
}
string readStringLn(int minLen, int maxLen)
{
    return readString(minLen, maxLen, '\n');
}
string readStringSp(int minLen, int maxLen)
{
    return readString(minLen, maxLen, ' ');
}
// Asserts that no input remains.
void readEOF()
{
    assert(getchar() == EOF);
}
// Reads n values in [l, r] on a single line: space after each but the last, newline after the last.
vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> values(n);
    for(int idx = 0; idx + 1 < n; ++idx)
        values[idx] = readIntSp(l, r);
    values[n - 1] = readIntLn(l, r);
    return values;
}
// -------------------- Input Checker End --------------------
const int MOD = hell;
// Integer modulo MOD kept in canonical form 0 <= val < MOD after every operation.
// NOTE: `int` is #defined to `long long` in this program, so `val` is 64-bit here.
struct mod_int {
    int val;

    // Accepts any 64-bit value, negatives included, and reduces it into [0, MOD).
    mod_int(long long v = 0) {
        if (v < 0)
            v = v % MOD + MOD;
        if (v >= MOD)
            v %= MOD;
        val = v;
    }

    // Inverse of a modulo m by the extended Euclidean algorithm (assumes gcd(a, m) == 1).
    static int mod_inv(int a, int m = MOD) {
        int g = m, r = a, coef_g = 0, coef_r = 1;
        while (r != 0) {
            int q = g / r;
            g %= r; swap(g, r);
            coef_g -= q * coef_r; swap(coef_g, coef_r);
        }
        return coef_g < 0 ? coef_g + m : coef_g;
    }

    explicit operator int() const {
        return val;
    }

    mod_int& operator+=(const mod_int &other) {
        val += other.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }

    mod_int& operator-=(const mod_int &other) {
        val -= other.val;
        if (val < 0) val += MOD;
        return *this;
    }

    // Reduces v modulo m. The inline x86 `divl` path is now compiled only for 32-bit
    // Windows, presumably to avoid a slow 64-bit `%` helper there. Previously it sat after
    // an unconditional `return` on every other target, so it was still compiled, and
    // non-x86 toolchains could reject its x86-only register constraints.
    static unsigned fast_mod(uint64_t v, unsigned m = MOD) {
#if !defined(_WIN32) || defined(_WIN64)
        return v % m;
#else
        unsigned v_high = v >> 32, v_low = (unsigned) v;
        unsigned quot, rem;
        asm("divl %4\n"
            : "=a" (quot), "=d" (rem)
            : "d" (v_high), "a" (v_low), "r" (m));
        return rem;
#endif
    }

    mod_int& operator*=(const mod_int &other) {
        val = fast_mod((uint64_t) val * other.val);
        return *this;
    }

    mod_int& operator/=(const mod_int &other) {
        return *this *= other.inv();
    }

    friend mod_int operator+(const mod_int &a, const mod_int &b) { return mod_int(a) += b; }
    friend mod_int operator-(const mod_int &a, const mod_int &b) { return mod_int(a) -= b; }
    friend mod_int operator*(const mod_int &a, const mod_int &b) { return mod_int(a) *= b; }
    friend mod_int operator/(const mod_int &a, const mod_int &b) { return mod_int(a) /= b; }

    mod_int& operator++() {
        val = val == MOD - 1 ? 0 : val + 1;
        return *this;
    }

    mod_int& operator--() {
        val = val == 0 ? MOD - 1 : val - 1;
        return *this;
    }

    // Postfix forms take int32_t because `int` is redefined to `long long` in this file.
    mod_int operator++(int32_t) { mod_int before = *this; ++*this; return before; }
    mod_int operator--(int32_t) { mod_int before = *this; --*this; return before; }

    mod_int operator-() const {
        return val == 0 ? 0 : MOD - val;
    }

    bool operator==(const mod_int &other) const { return val == other.val; }
    bool operator!=(const mod_int &other) const { return val != other.val; }

    mod_int inv() const {
        return mod_inv(val);
    }

    // this^p by binary exponentiation; p must be non-negative.
    mod_int pow(long long p) const {
        assert(p >= 0);
        mod_int a = *this, result = 1;
        while (p > 0) {
            if (p & 1)
                result *= a;
            a *= a;
            p >>= 1;
        }
        return result;
    }

    friend ostream& operator<<(ostream &stream, const mod_int &m) {
        return stream << m.val;
    }

    // Reads a plain integer and normalizes it. The old version wrote the raw value into
    // `val`, breaking the [0, MOD) invariant every operator above relies on.
    friend istream& operator >> (istream &stream, mod_int &m) {
        long long v = 0;
        if (stream >> v)
            m = mod_int(v);
        return stream;
    }
};
// Handles one test case: validates its input, then prints X mod MOD, where K+X is the
// smallest number >= K containing every base-b digit. Always returns 0.
int solve(){
int n = readIntSp(1,1e6);
// The sum of N over all tests must stay within 1e6.
static int sum_n = 0;
sum_n += n;
assert(sum_n <= 1e6);
int b = readIntLn(max(2LL,n),1e6);
vector<int> a = readVectorInt(n,0,b - 1);
assert(a[0] != 0);
// temp = K mod MOD; st = set of distinct digits of K.
mod_int temp = 0;
set<int> st;
for(auto i:a){
st.insert(i);
temp = temp*b + i;
}
// K already contains every digit: X = 0.
if(st.size() == b){
cout << 0 << endl;
return 0;
}
// Fewer than b digits: target 1, 0, 2, 3, ..., b-1 (the prefix "1 0" has value b).
if(n < b){
mod_int val = b;
for(int i = 2; i < b; i++){
val = val*b + i;
}
cout << val - temp << endl;
return 0;
}
// mx = b-1, ..., 1, 0, the largest b-digit number with distinct digits. If K is larger,
// the target is 1, 0, 0, 2, ..., b-1 (the prefix "1 0 0" has value b*b).
vector<int> mx;
for(int i = b - 1; i >= 0; i--) mx.push_back(i);
if(a > mx){
mod_int val = b*b;
for(int i = 2; i < b; i++){
val = val*b + i;
}
cout << val - temp << endl;
return 0;
}
// rem: digits not used by the prefix scanned so far.
set<int> rem;
for(int i = 0; i <= b - 1; i++) rem.insert(i);
mod_int val = 0;
// inv is never used.
bool inv = false;
// pos: last index, up to the first repeated digit, where some unused digit exceeds a[i],
// i.e. the rightmost position that can be raised.
int pos = -1;
int cnt = 0;
for(auto i:a){
auto it = rem.upper_bound(i);
if(it != rem.end()){
pos = cnt;
}
if(rem.count(i) == 0) break;
rem.erase(i);
cnt++;
}
// Rebuild the target: keep a[0..pos-1], put the smallest unused digit > a[pos] at pos,
// then append the remaining unused digits in ascending order.
rem.clear();
for(int i = 0; i <= b - 1; i++) rem.insert(i);
for(int i = 0; i < pos; i++){
val = val*b + a[i];
rem.erase(a[i]);
}
auto it = rem.upper_bound(a[pos]);
val = val*b + *it;
rem.erase(it);
for(auto i: rem){
val = val*b + i;
}
cout << val - temp << endl;
return 0;
}
signed main(){
    // Unsynchronized, untied iostreams for output; input is read through getchar().
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
#ifdef SIEVE
    sieve();
#endif
#ifdef NCR
    init();
#endif
    for(int remaining = readIntLn(1,10000); remaining > 0; --remaining)
        solve();
    return 0;
}
Editorialist's code (Python)
# Answers are reported modulo this prime (10^9 + 7).
mod = 1_000_000_007
def solve(n, b, a):
    """Digits (most significant first) of the smallest number >= K that uses every
    base-b digit, where ``a`` holds the n base-b digits of K.

    When the answer keeps b digits, ``a`` is modified in place and returned;
    otherwise a new list is returned.
    """
    # Fewer than b digits: 1, 0, 2, 3, ..., b-1 is the smallest b-digit pandigital number.
    if n < b:
        return [1, 0] + list(range(2, b))

    # K above b-1, b-2, ..., 0 (the largest b-digit number with distinct digits)
    # forces b+1 digits: 1, 0, 0, 2, 3, ..., b-1.
    for idx in range(n):
        ceiling = b - 1 - idx
        if a[idx] != ceiling:
            if a[idx] > ceiling:
                return [1, 0, 0] + list(range(2, b))
            break

    seen = [False] * b
    for i in range(n):
        if not seen[a[i]]:
            seen[a[i]] = True
            continue

        # a[i] repeats an earlier digit. Walk left from i to the last position that can be
        # raised, i.e. where some digit unused by the preceding prefix exceeds the current one.
        best = b - 1
        while seen[best]:
            best -= 1
        pos = i
        while pos >= 0 and best <= a[pos]:
            pos -= 1
            if pos >= 0:
                seen[a[pos]] = False  # a[pos] leaves the kept prefix
                best = max(best, a[pos])

        # Raise a[pos] to the smallest unused digit above it, then fill the rest ascending.
        chosen = next((d for d in range(a[pos] + 1, best) if not seen[d]), best)
        seen[chosen] = True
        a[pos] = chosen
        unused = (d for d in range(b) if not seen[d])
        for j in range(pos + 1, n):
            a[j] = next(unused)
        return a

    return a
def convert(a, base):
    """Value of the digit list ``a`` (most significant first) in ``base``, reduced modulo ``mod``."""
    result = 0
    for digit in a:
        result = (result * base + digit) % mod
    return result
# For each test: read n, b and the digits of K, then print X mod 10^9+7.
test_count = int(input())
for _ in range(test_count):
    n, b = map(int, input().split())
    digits = list(map(int, input().split()))
    original = list(digits)  # solve() may overwrite `digits` in place
    best = solve(n, b, digits)
    print((convert(best, b) - convert(original, b)) % mod)