MONSTBAT - Editorial

PROBLEM LINK:

Practice
Div-2 Contest
Div-1 Contest

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Game theory, Recursion

PROBLEM:

Chef has N monsters and Chefina has M monsters. Each monster has a value and is in either attack mode or defense mode. On their turn, a player either ends the game or uses one of their monsters in attack mode to kill one of the opponent’s monsters in defense mode. The monster that attacked then switches to defense mode. Chef moves first and both players play optimally. Find the sum of values of Chef’s living monsters minus the sum of values of Chefina’s living monsters when the game ends.

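QUICK EXPLANATION:

If a player attacks, their best attack is fixed: use their weakest attack monster to kill the opponent’s strongest defense monster. Write a recursive function that takes both players’ current sets of monsters. It compares two options: ending the game now, or making this attack and letting the opponent play optimally.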

EXPLANATION:

We analyze the game from the current position. Call the player whose turn it is player 1 and the opponent player 2. If player 1 decides to end the game, the final value of the game is the sum of the values of player 1’s monsters minus the sum of the values of player 2’s monsters at this point.

Otherwise, if player 1 has at least one attack monster and player 2 has at least one defense monster, then player 1 can attack. Which monster does player 1 use to attack, and which monster should be attacked?

The monster player 1 attacks with switches to defense mode. This makes it a possible target for player 2’s attacks, so player 1 may lose it. Therefore player 1 should risk the attack monster with the smallest value, in order to maximize the sum of values of player 1’s monsters.

As for which monster to attack, player 1 should kill the defense monster with the greatest value, in order to minimize the sum of values of player 2’s monsters.
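The two greedy choices above are argued informally. One hedged way to gain confidence in them is to compare against an exhaustive search on tiny inputs. The sketch below is not from the original solutions. It tries every attacker/target pair instead of only the greedy one. The function name `bruteForce` and the representation of each player’s monsters as plain vectors of values are assumptions made for illustration. The search is exponential, so use it only with a handful of monsters.

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Hypothetical brute force (illustration only): exhaustive minimax over ALL legal moves.
// atkMe/defMe: the mover's attack/defense monsters; atkOpp/defOpp: the opponent's.
// Returns (mover's living sum) - (opponent's living sum) under optimal play.
long long bruteForce(std::vector<long long> atkMe, std::vector<long long> defMe,
                     std::vector<long long> atkOpp, std::vector<long long> defOpp) {
    auto sum = [](const std::vector<long long>& v) {
        return std::accumulate(v.begin(), v.end(), 0LL);
    };
    // Option 1: end the game immediately.
    long long best = sum(atkMe) + sum(defMe) - sum(atkOpp) - sum(defOpp);
    // Option 2: try every (attacker, target) pair, not just the greedy one.
    for (size_t i = 0; i < atkMe.size(); ++i) {
        for (size_t j = 0; j < defOpp.size(); ++j) {
            std::vector<long long> a1 = atkMe, d1 = defMe, d2 = defOpp;
            d1.push_back(a1[i]);       // the attacker switches to defense mode
            a1.erase(a1.begin() + i);
            d2.erase(d2.begin() + j);  // the target dies
            // The opponent moves next; negate to return to the mover's point of view.
            best = std::max(best, -bruteForce(atkOpp, d2, a1, d1));
        }
    }
    return best;
}
```

Consider Chef with attack monsters {1, 100} and Chefina with attack monster {50} and defense monster {10}. The search returns 51, which is the same value the greedy strategy reaches (Chef attacks with 1 first).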

Suppose that we already know the final value of the game if player 1 attacks. Then the answer is the maximum over the two options: attacking and ending the game.

How do we find the final value of the game if player 1 attacks? After player 1’s turn ends, it is player 2’s turn. Player 2’s decision follows the same process that player 1 uses above. The only differences are that the roles are swapped and the sets of monsters have been updated.

This suggests that we should use a recursive solution, since the result of the problem for a set of monsters can depend on the result of the same problem for a smaller set of monsters. The pseudocode is shown below:

  • //Plays player 1 and returns max((player 1 value sum) - (player 2 value sum)).
  • //attack1 and defense1 are sets of monsters belonging to player 1, attack2 and defense2 are sets of monsters for player 2.
  • play(attack1, defense1, attack2, defense2)
    • value1 is set to the current sum of player 1’s monsters minus the current sum of player 2’s monsters.
    • If attack1 is not empty and defense2 is not empty, then player 1 can attack:
      • Let monster1 be the monster with least value in attack1.
      • Let monster2 be the monster with greatest value in defense2.
      • monster1 becomes defense, so we remove monster1 from attack1 and add monster1 to defense1.
      • monster2 dies, so we remove monster2 from defense2.
      • //It is now player 2’s turn. Note that we add a negative sign in front of play() because the function returns (player 2 value sum) - (player 1 value sum), but we want (player 1 value sum) - (player 2 value sum).
      • value2 = -play(attack2, defense2, attack1, defense1)
    • Return the maximum of value1 and value2, since player 1 picks the better of the two options. If player 1 could not attack, return value1.
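The pseudocode can be translated almost line by line into C++. The sketch below is the slow version, O((N+M)^2) overall. It assumes each set of monsters is a plain `std::vector<long long>` of values. The vectors are passed by value, so every call works on its own copy. monster1 and monster2 are found with linear scans.

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Sketch of play() from the pseudocode, with player 1 to move.
// Returns (player 1 value sum) - (player 2 value sum).
// Each call scans and re-sums the vectors: O(N+M) per call, O((N+M)^2) in total.
long long play(std::vector<long long> attack1, std::vector<long long> defense1,
               std::vector<long long> attack2, std::vector<long long> defense2) {
    auto sum = [](const std::vector<long long>& v) {
        return std::accumulate(v.begin(), v.end(), 0LL);
    };
    // value1: player 1 ends the game right now.
    long long value1 = sum(attack1) + sum(defense1) - sum(attack2) - sum(defense2);
    if (attack1.empty() || defense2.empty()) return value1;  // no attack is possible

    // monster1: weakest attacker of player 1; monster2: strongest defender of player 2.
    auto it1 = std::min_element(attack1.begin(), attack1.end());
    auto it2 = std::max_element(defense2.begin(), defense2.end());
    defense1.push_back(*it1);  // monster1 switches to defense mode
    attack1.erase(it1);
    defense2.erase(it2);       // monster2 dies

    // Player 2's turn; negate to get player 1's point of view.
    long long value2 = -play(attack2, defense2, attack1, defense1);
    return std::max(value1, value2);
}
```

For example, `play({10}, {}, {100}, {1})` returns -91. If Chef attacks, Chefina would kill Chef’s 10 with her 100, so Chef ends the game immediately.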

What is the time complexity of this algorithm? Each recursive call follows one attack, which kills one monster, so there are at most N+M calls. Each call takes O(N+M) time to find the minimum and maximum monsters and to recompute the sums. The total time complexity is therefore O((N+M)^2).

For the full solution, we need two small optimizations. First, store the sets of monsters in data structures (such as multiset in C++) that support querying the minimum/maximum, inserting, and removing values in O(\log(N+M)) time. Second, pass the sums of values of both players’ monsters into the recursive function and update them as monsters die. Then we don’t spend O(N+M) time recomputing value1 on every call. The final time complexity is O((N+M)\log(N+M)).
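Both players’ attacks are fixed by the greedy rule, so the recursion never branches on which monsters to use. It only decides where the game stops. This means it can also be unrolled into a loop:

- Simulate the chain of greedy attacks once.
- Record Chef’s sum minus Chefina’s sum after every attack.
- Fold backwards over the recorded values. Take the maximum at positions where Chef moves (even indices) and the minimum where Chefina moves (odd indices).

The setter’s and tester’s solutions below take this approach. The sketch here is a condensed rewrite under stated assumptions: the name `solveIterative` is hypothetical, and each player’s monsters are given as vectors of values. It runs in O((N+M)\log(N+M)) and avoids a recursion depth proportional to the number of attacks.

```cpp
#include <algorithm>
#include <functional>
#include <numeric>
#include <queue>
#include <vector>

// Hypothetical iterative sketch: returns Chef's living sum minus Chefina's living sum.
// Index 0 = Chef, index 1 = Chefina.
long long solveIterative(const std::vector<long long>& chefAtk,
                         const std::vector<long long>& chefDef,
                         const std::vector<long long>& chefinaAtk,
                         const std::vector<long long>& chefinaDef) {
    using MinHeap = std::priority_queue<long long, std::vector<long long>,
                                        std::greater<long long>>;
    using MaxHeap = std::priority_queue<long long>;
    // Attackers in min-heaps (weakest first), defenders in max-heaps (strongest first).
    MinHeap atk[2] = {MinHeap(chefAtk.begin(), chefAtk.end()),
                      MinHeap(chefinaAtk.begin(), chefinaAtk.end())};
    MaxHeap def[2] = {MaxHeap(chefDef.begin(), chefDef.end()),
                      MaxHeap(chefinaDef.begin(), chefinaDef.end())};

    auto sum = [](const std::vector<long long>& v) {
        return std::accumulate(v.begin(), v.end(), 0LL);
    };
    long long diff = sum(chefAtk) + sum(chefDef) - sum(chefinaAtk) - sum(chefinaDef);

    // states[i] = Chef minus Chefina after i greedy attacks; Chef moves at even i.
    std::vector<long long> states{diff};
    for (int me = 0;; me ^= 1) {
        int opp = me ^ 1;
        if (atk[me].empty() || def[opp].empty()) break;  // mover cannot attack: game ends
        long long attacker = atk[me].top();  // weakest attacker...
        atk[me].pop();
        def[me].push(attacker);              // ...switches to defense mode
        long long victim = def[opp].top();   // strongest enemy defender dies
        def[opp].pop();
        diff += (me == 0) ? victim : -victim;
        states.push_back(diff);
    }

    // Fold backwards: the mover at state i either stops (states[i]) or continues (best).
    long long best = states.back();
    for (int i = static_cast<int>(states.size()) - 2; i >= 0; --i)
        best = (i % 2 == 0) ? std::max(states[i], best) : std::min(states[i], best);
    return best;
}
```

On Chef {1-A, 100-A} against Chefina {50-A, 10-D}, the recorded states are 41, 51, 50, 100. The fold gives max(41, min(51, max(50, 100))) = 51.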

SOLUTIONS:

Setter's Solution
#include <iostream>
#include <algorithm>
#include <string>
#include <queue>
#include <assert.h>
using namespace std;
 
 
 
long long readInt(long long l,long long r,char endd){
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true){
		char g=getchar();
		if(g=='-'){
			assert(fi==-1);
			is_neg=true;
			continue;
		}
		if('0'<=g && g<='9'){
			x*=10;
			x+=g-'0';
			if(cnt==0){
				fi=g-'0';
			}
			cnt++;
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);
			
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd){
			assert(cnt>0);
			if(is_neg){
				x= -x;
			}
			assert(l<=x && x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
string readString(int l,int r,char endd){
	string ret="";
	int cnt=0;
	while(true){
		char g=getchar();
		assert(g!=-1);
		if(g==endd){
			break;
		}
		cnt++;
		ret+=g;
	}
	assert(l<=cnt && cnt<=r);
	return ret;
}
long long readIntSp(long long l,long long r){
	return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
	return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
	return readString(l,r,'\n');
}
string readStringSp(int l,int r){
	return readString(l,r,' ');
}
 
int T;
int n,m;
long long v1[100100];
string d1;
long long v2[100100];
string d2;
 
int sm_n=0;
int sm_m=0;
 
long long result[300300],nn=0;
long long dp[300300];
int main(){
	//freopen("01.txt","r",stdin);
	//freopen("01o.txt","w",stdout);
	T=readIntLn(1,100);
	while(T--){
		priority_queue<long long> def1,atk1,def2,atk2;
		n=readIntSp(1,100000);
		m=readIntLn(1,100000);
		sm_n += n;
		sm_m += m;
		assert(sm_n<=1000000);
		assert(sm_m<=1000000);
		nn=0;
		long long tot=0;
		for(int i=0;i<n;i++){
			if(i==n-1){
				v1[i]=readIntLn(1,1000000000);
			} else {
				v1[i]= readIntSp(1,1000000000);
			}
			tot += v1[i];
		}
		d1 = readStringLn(n,n);
		for(int i=0;i<n;i++){
			assert(d1[i]=='A' || d1[i] == 'D');
		}
		for(int i=0;i<m;i++){
			if(i==m-1){
				v2[i]=readIntLn(1,1000000000);
			} else {
				v2[i]= readIntSp(1,1000000000);
			}
			tot -= v2[i];
		}
		d2 = readStringLn(m,m);
		bool has_def=false;
		for(int i=0;i<m;i++){
			assert(d2[i]=='A' || d2[i] == 'D');
			if(d2[i]=='D')has_def=true;
		}
		assert(has_def);
 
		for(int i=0;i<n;i++){
			if(d1[i]=='A'){
				atk1.push(-v1[i]);
			} else {
				def1.push(v1[i]);
			}
		}
 
		for(int i=0;i<m;i++){
			if(d2[i]=='A'){
				atk2.push(-v2[i]);
			} else {
				def2.push(v2[i]);
			}
		}
		result[nn++]=tot;
		while(true){
			if(atk1.empty()){
				break;
			}
			def1.push(-atk1.top());
			atk1.pop();
			tot += def2.top();
			def2.pop();
			result[nn++]=tot;
			if(atk2.empty()){
				break;
			}
			def2.push(-atk2.top());
			atk2.pop();
			tot -= def1.top();
			def1.pop();
			result[nn++]=tot;
		}
		for(int i=nn-1;i>=0;i--){
			dp[i] = result[i];
			if(i% 2 == 0){
				if(i +1 < nn ){
					dp[i] = max(dp[i],dp[i+1]);
				}
			} else {
				if(i +1 < nn ){
					dp[i] = min(dp[i],dp[i+1]);
				}
			}
		}
		cout<<dp[0]<<endl;
	}
	assert(getchar()==-1);
}
Tester's Solution
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
 
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
#define int ll
 
int N;
int dp[412345],visit[412345];
int x[123456],y[123456];
int ans[412345];
int solve(int pos){
	if(pos==N-1)
		return ans[pos];
	if(visit[pos]==1)
		return dp[pos];
	visit[pos]=1;
	if(pos%2==0){
		dp[pos]=max(ans[pos],solve(pos+1));
	}
	else{
		dp[pos]=min(ans[pos],solve(pos+1));
	}
	return dp[pos];
}
signed main(){
    std::ios::sync_with_stdio(false); cin.tie(NULL);
    int tt;
    cin>>tt;
    while(tt--){
    	int n,m;
    	cin>>n>>m;
    	int i,elem;
    	rep(i,n+m+10){
    		visit[i]=0;
    	}
    	rep(i,n){
    		cin>>x[i];
    	}
    	string s;
    	cin>>s;
    	rep(i,m){
    		cin>>y[i];
    	}
    	string t;
    	cin>>t;
    	pqueue pq1,pq2;
    	pdqueue pdqueue1,pdqueue2;
    	int sum1=0,sum2=0;
    	rep(i,n){
    		if(s[i]=='A'){
    			pdqueue1.push(x[i]);
    		}
    		else{
    			pq1.push(x[i]);
    		}
    		sum1+=x[i];
    	}
    	rep(i,m){
    		if(t[i]=='A'){
    			pdqueue2.push(y[i]);
    		}
    		else{
    			pq2.push(y[i]);
    		}
    		sum2+=y[i];
    	}
    	int j=0;
    	while(1){
    		ans[j]=sum1-sum2;
    		j++;
    		if(pdqueue1.empty()){
    			break;
    		}
    		elem=pdqueue1.top();
    		pdqueue1.pop();
    		pq1.push(elem);
    		sum2-=pq2.top();
    		pq2.pop();
    		ans[j]=sum1-sum2;
    		j++;
    		if(pdqueue2.empty()){
    			break;
    		}
    		elem=pdqueue2.top();
    		pdqueue2.pop();
    		pq2.push(elem);
    		sum1-=pq1.top();
    		pq1.pop();
    	}
    	N=j;
    	cout<<solve(0)<<endl;
    }
    return 0;   
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

#define ll long long

int n[2], a[100000];

//player 1's turn
ll play(ll s1, ll s2, multiset<int> &attack1, multiset<int> &defense1, multiset<int> &attack2, multiset<int> &defense2) {
	//we can choose to end the game
	ll r=s1-s2;
	if(!attack1.empty()&&!defense2.empty()) {
		//choose best monsters
		int monster1=*attack1.begin(), monster2=*--defense2.end();
		//change monster1 to defense
		attack1.erase(attack1.find(monster1));
		defense1.insert(monster1);
		//monster2 dies
		s2-=monster2;
		defense2.erase(defense2.find(monster2));
		//now it's player 2's turn
		r=max(-play(s2, s1, attack2, defense2, attack1, defense1), r);
	}
	return r;
}

void solve() {
	//input
	cin >> n[0] >> n[1];
	//sum of monster values
	ll s[2]={};
	//multiset for attack & defense
	multiset<int> attack[2], defense[2];
	for(int k : {0, 1}) {
		for(int i=0; i<n[k]; ++i)
			cin >> a[i];
		string t;
		cin >> t;
		//update sums & multisets
		for(int i=0; i<n[k]; ++i) {
			if(t[i]=='A')
				attack[k].insert(a[i]);
			else
				defense[k].insert(a[i]);
			s[k]+=a[i];
		}
	}
	
	cout << play(s[0], s[1], attack[0], defense[0], attack[1], defense[1]) << "\n";
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);

	int t;
	cin >> t;
	while(t--)
		solve();
}

Please give me suggestions if anything is unclear so that I can improve. Thanks 🙂


Please explain the output of the test case given in the question.

If Chefina attacks on her next turn, she ends up in a state that decreases Y-X. Attacking would take us to the state

CHEF (5-A 60-A) and CHEFINA (15-D 16-A)

Chef would obviously then attack Chefina.
Now, if you calculate the difference CHEFINA - CHEF, it is (16 - 65).
If Chefina keeps attacking, you will see that the CHEFINA - CHEF difference never gets better than when she ends the game on her first turn.

Thanks Sir…