PROBLEM LINK:
Author: Anurag Tiwari
Tester: Sankalp Gupta
DIFFICULTY:
HARD.
PREREQUISITES:
Bit mask, DP, Math
PROBLEM:
You are given an array A containing the n distinct numbers from 1 to n, and two numbers p and q.
Your task is to choose a subsequence such that for any two numbers x and y in the subsequence, all 4 conditions below hold:
- x-y ≠ p
- y-x ≠ p
- x-y ≠ q
- y-x ≠ q
Print the length of the longest subsequence that satisfies these conditions.
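For tiny inputs the statement can be checked directly. The following is only a brute-force sketch for cross-checking faster solutions, not part of the intended solution; the function name `bruteForce` and the limit n ≤ 20 are assumptions made for illustration.

```cpp
#include <algorithm>
#include <bitset>
#include <cassert>

// Brute force for small n (assumed n <= 20): enumerate every subset of
// {1..n} as a bitmask, where bit i stands for the number i + 1.
int bruteForce(int n, int p, int q) {
    assert(n >= 1 && n <= 20);
    assert(p >= 1 && p < 31 && q >= 1 && q < 31);
    int best = 0;
    for (int mask = 0; mask < (1 << n); ++mask) {
        // Two chosen numbers differ by d iff mask & (mask >> d) is nonzero.
        if ((mask & (mask >> p)) != 0 || (mask & (mask >> q)) != 0)
            continue;
        best = std::max(best, (int)std::bitset<32>(mask).count());
    }
    return best;
}
```

For example, with n=5, p=1, q=2 any two chosen numbers must differ by at least 3, so the best subsets are {1,4} and {2,5}.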
EXPLANATION:
First of all, the order of elements does not matter, so we are simply choosing a subset of the numbers.
The key idea of the task is to prove that there is an optimal answer in which the set of chosen elements is periodic with period s=p+q. Let’s work with 0,1,…,n−1 instead of 1,2,…,n.
Firstly, let’s prove that if we’ve chosen a correct set x_{1},x_{2},…,x_{k} in the interval [l,l+s), then after adding all x_{i}+s the set {x_{1},…,x_{k},x_{1}+s,…,x_{k}+s} is correct as well.
Differences inside the original set and inside the shifted copy are unchanged, so only pairs x_{i}+s, x_{j} need checking, and since both x_{i} and x_{j} lie in [l,l+s), the difference x_{i}+s−x_{j} is positive. By contradiction: suppose we have x_{i}+s−x_{j}=p, then x_{i}−x_{j}=p−s=p−p−q=−q, or |x_{i}−x_{j}|=q (contradiction). Similarly, if we have x_{i}+s−x_{j}=q, then x_{i}−x_{j}=q−s=q−p−q=−p, or |x_{i}−x_{j}|=p (contradiction).
It means that if we take a correct set in the interval [l,l+s), we can create a periodic answer by copying this interval several times.
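The copying argument can be tried out on small sets. This is a minimal sketch under the assumption that the base set lies inside one interval of length s=p+q; the helper names `isValid` and `repeatWithPeriod` are hypothetical.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// True if no two elements of xs differ by exactly p or exactly q.
bool isValid(const std::vector<long long>& xs, long long p, long long q) {
    for (size_t i = 0; i < xs.size(); ++i)
        for (size_t j = i + 1; j < xs.size(); ++j) {
            long long d = std::llabs(xs[i] - xs[j]);
            if (d == p || d == q) return false;
        }
    return true;
}

// Returns base, base + s, base + 2s, ..., base + (copies - 1) * s merged
// into one list. The lemma assumes base lies in an interval [l, l + s).
std::vector<long long> repeatWithPeriod(const std::vector<long long>& base,
                                        long long s, int copies) {
    std::vector<long long> out;
    for (int c = 0; c < copies; ++c)
        for (long long x : base) out.push_back(x + c * s);
    return out;
}
```

For p=2, q=3 the set {0,1} is correct, and `repeatWithPeriod({0, 1}, 5, 4)` gives {0,1,5,6,10,11,15,16}, whose pairwise differences all have the form 5k or 5k±1 and therefore never equal 2 or 3.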
Next, let’s prove that there is an optimal periodic answer. Let’s look at any optimal answer and its indicator vector id (a binary vector of length n where id_{i}=1 iff i is in the set). Let r=n mod s.
Let’s split the vector into 2⌊n/s⌋+1 intervals: [0,r),[r,s),[s,s+r),[s+r,2s),…,[n−r,n). The 1st, 3rd, 5th,… segments have length r and the 2nd, 4th,… segments have length s−r. If we choose any two consecutive segments, their total length is equal to s, and we can use them to make a periodic answer by replacing all length r segments with the chosen one of length r and all length s−r segments with the other one.
We can prove that we can always find two consecutive segments such that the induced answer is greater than or equal to the initial one. If we create a vector v where v_{i} is equal to the sum of id_{j} in the i^{th} segment, then the task is equivalent to finding v_{i} and v_{i+1} such that replacing all v_{i±2z} by v_{i} and all v_{i±2z+1} by v_{i+1} won’t decrease the sum of v. The proof is down below.
Now, since the answer is periodic, taking element c (0≤c<s) is equivalent to taking all elements d ≡ c mod s, so for each c we can calculate val_{c}, the number of integers in [0,n) with remainder c modulo s. And for each c we either take it or not.
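The values val_{c} have a closed form: every remainder occurs ⌊n/s⌋ times, and the first n mod s remainders occur once more. A small sketch, with the function name `residueCounts` chosen here only for illustration:

```cpp
#include <cassert>
#include <vector>

// val[c] = how many integers in [0, n) are congruent to c modulo s.
std::vector<long long> residueCounts(long long n, int s) {
    assert(n >= 0 && s >= 1);
    std::vector<long long> val(s);
    for (int c = 0; c < s; ++c)
        val[c] = n / s + (c < n % s ? 1 : 0);
    return val;
}
```

For n=7 and s=3 the numbers 0..6 split into {0,3,6}, {1,4} and {2,5}, so the counts are 3, 2, 2.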
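So we can write dp[s][2^{max(p,q)}], where dp[i][mk] is the maximum sum if we have processed the first i elements and the last max(p,q) of them are described by the mask mk (bit j of mk is set iff element i−1−j was taken). We start with dp[0][0] and, when we look at the i^{th} element, either take it (allowed only if neither element i−p nor element i−q was taken) or skip it.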
Time complexity is O((p+q)*2^{max(p,q)}).
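The DP can also be written as a standalone function, which makes it easy to compare against a brute force on small inputs. This is a sketch under the assumptions that max(p,q) ≤ 20 and that bit j of the mask records whether the element j+1 positions back was taken; the name `longestSubset` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

long long longestSubset(long long n, int p, int q) {
    assert(n >= 1 && p >= 1 && q >= 1 && std::max(p, q) <= 20);
    const int s = p + q;            // period
    const int m = std::max(p, q);   // how many previous residues we remember
    const int full = (1 << m) - 1;
    const long long NEG = -1;       // marks an unreachable state (real sums are >= 0)
    std::vector<long long> cur(1 << m, NEG), nxt(1 << m, NEG);
    cur[0] = 0;
    for (int c = 0; c < s; ++c) {
        // Number of integers in [0, n) congruent to c modulo s.
        long long val = n / s + (c < n % s ? 1 : 0);
        std::fill(nxt.begin(), nxt.end(), NEG);
        for (int mask = 0; mask <= full; ++mask) {
            if (cur[mask] == NEG) continue;
            // Option 1: skip residue c.
            int skip = (mask << 1) & full;
            nxt[skip] = std::max(nxt[skip], cur[mask]);
            // Option 2: take residue c unless c - p or c - q was taken.
            bool blocked = ((mask >> (p - 1)) & 1) || ((mask >> (q - 1)) & 1);
            if (!blocked) {
                int take = skip | 1;
                nxt[take] = std::max(nxt[take], cur[mask] + val);
            }
        }
        cur.swap(nxt);
    }
    return *std::max_element(cur.begin(), cur.end());
}
```

Only two layers of size 2^{max(p,q)} are kept, since layer i+1 depends on layer i alone.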
Let’s prove the claim used above: for any array v_{1},v_{2},…,v_{2n+1} we can find a pair v_{s},v_{s+1} such that replacing all v_{s±2z} with v_{s} and all v_{s±2z+1} with v_{s+1} won’t decrease the total sum. (In this proof n and s are local names: 2n+1 is the number of segments and s is an index, not the period.)
Let’s define So=∑_{i=1}^{n+1} v_{2i−1} and Se=∑_{i=1}^{n} v_{2i}. Let’s make an array b_{1},…,b_{2n+1}, where b_{2i−1}=(n+1)*v_{2i−1}−So and b_{2i}=n*v_{2i}−Se. The meaning of b_{i} is the change in the total sum if we replace all elements with the same index parity by v_{i}.
Note that finding a good pair v_{s},v_{s+1} is equivalent to finding s with b_{s}+b_{s+1}≥0. Also note that ∑_{i=1}^{n+1} b_{2i−1}=(n+1)*So−(n+1)*So=0 and, analogously, ∑_{i=1}^{n} b_{2i}=n*Se−n*Se=0.
Let’s prove by contradiction: suppose that for any i, b_{i}+b_{i+1}<0. Let’s look at ∑_{i=1}^{2n+1} b_{i}=∑_{i=1}^{n+1} b_{2i−1}+∑_{i=1}^{n} b_{2i}=0. On the other hand, we know that b_{2}+b_{3}<0, b_{4}+b_{5}<0, ..., b_{2n}+b_{2n+1}<0, so b_{1}>0, otherwise ∑_{i=1}^{2n+1} b_{i} would be negative.
In the same way, since b_{1}+b_{2}<0, b_{4}+b_{5}<0, ..., b_{2n}+b_{2n+1}<0, we get b_{3}>0. Analogously we can prove that each b_{2i−1}>0, but ∑_{i=1}^{n+1} b_{2i−1}=0 (contradiction). So there is always an index s with b_{s}+b_{s+1}≥0, i.e. a suitable pair v_{s},v_{s+1}.
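The proof is constructive enough to test numerically. The sketch below builds b as defined above (with 0-based indices, so even positions play the role of the odd ones in the text) and returns the first good index; the names `findGoodPair` and `sumAfterReplacement` are assumptions for illustration.

```cpp
#include <cassert>
#include <vector>

// Sum of the array after every position with the parity of i is replaced by
// v[i] and every position of the other parity by v[i + 1] (0-based indices).
long long sumAfterReplacement(const std::vector<long long>& v, int i) {
    long long total = 0;
    for (int j = 0; j < (int)v.size(); ++j)
        total += (j % 2 == i % 2) ? v[i] : v[i + 1];
    return total;
}

// Returns a 0-based index i with b[i] + b[i + 1] >= 0.
// Assumes v has odd length 2n + 1 >= 3.
int findGoodPair(const std::vector<long long>& v) {
    int len = (int)v.size();
    assert(len >= 3 && len % 2 == 1);
    long long cnt[2] = {0, 0}, sum[2] = {0, 0};
    for (int j = 0; j < len; ++j) {
        cnt[j % 2] += 1;
        sum[j % 2] += v[j];
    }
    // b[j] = change of the total if all positions of j's parity become v[j].
    std::vector<long long> b(len);
    for (int j = 0; j < len; ++j)
        b[j] = cnt[j % 2] * v[j] - sum[j % 2];
    for (int i = 0; i + 1 < len; ++i)
        if (b[i] + b[i + 1] >= 0) return i;
    return -1; // not reached if the proof above holds
}
```

For v={3,1,0,5,2} we get b={4,−4,−5,4,1}, so the first good pair is at index 0 and the replaced array {3,1,3,1,3} has sum 11, the same as the original.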
SOLUTIONS:
Setter's Solution
#include <bits/stdc++.h>
using namespace std;

const int INF = 1e9;
const int N = 22;

int dp[2][1 << N]; // dp[0] is the current layer, dp[1] is the next layer
int val[2 * N];    // val[i] = count of numbers in [0, n) congruent to i modulo x + y

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        int n, x, y;
        scanf("%d%d%d", &n, &x, &y);
        int k = x + y;     // period
        int m = max(x, y); // number of previous elements kept in the mask
        int FULL = (1 << m) - 1;
        for (int i = 0; i < k; ++i)
            val[i] = n / k + (i < n % k);
        for (int mask = 0; mask < (1 << m); ++mask)
            dp[0][mask] = -INF;
        dp[0][0] = 0;
        for (int i = 0; i < k; ++i) {
            for (int mask = 0; mask < (1 << m); ++mask)
                dp[1][mask] = -INF;
            for (int mask = 0; mask < (1 << m); ++mask) {
                if (dp[0][mask] == -INF)
                    continue;
                // Skip element i.
                int nmask = (mask << 1) & FULL;
                dp[1][nmask] = max(dp[1][nmask], dp[0][mask]);
                // Element i is forbidden if element i - x or i - y was taken.
                if (((mask >> (x - 1)) & 1) | ((mask >> (y - 1)) & 1))
                    continue;
                // Take element i.
                nmask |= 1;
                dp[1][nmask] = max(dp[1][nmask], dp[0][mask] + val[i]);
            }
            swap(dp[0], dp[1]);
        }
        int ans = 0;
        for (int mask = 0; mask < (1 << m); ++mask)
            ans = max(ans, dp[0][mask]);
        printf("%d\n", ans);
    }
}
Tester's Solution
#include<bits/stdc++.h>
#define ll long long int
#define ull unsigned long long int
#define vi vector<int>
#define vll vector<ll>
#define vvi vector<vi>
#define vvl vector<vll>
#define pb push_back
#define mp make_pair
#define all(v) v.begin(), v.end()
#define pii pair<int,int>
#define pll pair<ll,ll>
#define vpii vector<pii >
#define vpll vector<pll >
#define ff first
#define ss second
#define PI 3.14159265358979323846
#define fastio ios_base::sync_with_stdio(false) , cin.tie(NULL) ,cout.tie(NULL)
ll power(ll a,ll b){ ll res=1; while(b>0){ if(b&1) res*=a; a*=a; b>>=1;} return res; }
ll power(ll a,ll b,ll m){ ll res=1; while(b>0){ if(b&1) res=(res*a)%m; a=(a*a)%m; b>>=1;} return res;}
bool pp(int a,int b) {return a>b;}
using namespace std;
const int INF = 1e9;
const int N = 21;
ll dp[2][1 << N];
ll val[2 * N];
void solve(){
    ll n,x,y;
    cin>>n>>x>>y;
    assert(n>0&&n<=100000000);
    assert(x>0&&x<=20);
    assert(y>0&&y<=20);
    ll k = x+y;
    ll m = max(x,y);
    m = (1<<m);
    ll FULL = m-1;
    for(ll i=0;i<k;i++){
        val[i] = n/k + (i<n%k);
    }
    for (int mask = 0; mask < m; ++mask)
        dp[0][mask] = -INF;
    dp[0][0] = 0;
    for (int i = 0; i < k; ++i) {
        for (int mask = 0; mask < m; ++mask)
            dp[1][mask] = -INF;
        for (int mask = 0; mask < m; ++mask) {
            if (dp[0][mask] == -INF)
                continue;
            ll nmask = (mask << 1) & FULL;
            dp[1][nmask] = max(dp[1][nmask], dp[0][mask]);
            if (((mask >> (x - 1)) & 1) | ((mask >> (y - 1)) & 1))
                continue;
            nmask |= 1;
            dp[1][nmask] = max(dp[1][nmask], dp[0][mask] + val[i]);
        }
        swap(dp[0], dp[1]);
    }
    ll ans = 0;
    for (int mask = 0; mask < m; ++mask)
        ans = max(ans, dp[0][mask]);
    cout<<ans<<"\n";
}
int main()
{
    fastio;
    int t;
    cin>>t;
    assert(t>0&&t<=10);
    while(t--){
        solve();
    }
    return 0;
}