GUESSALL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: vendx_greyback
Preparer: satyam_343
Testers: IceKnight1093, tabr
Editorialist: IceKnight1093

DIFFICULTY:

2562

PREREQUISITES:

None

PROBLEM:

There is a function F such that F(i) is predetermined (but secret) for 0 \leq i \lt K, and F(i) = -\sum_{j=1}^K F(i-j) for every i \geq K.

In one query, you can provide x and the judge will return F(x).

You are given an array B of length N. Using the least possible number of queries, compute the values of F(B_1), F(B_2), \ldots, F(B_N).

EXPLANATION:

Playing around with the function a little should give you a rather useful piece of information: F is periodic, with period K+1. That is, F(i) = F(i+K+1) for every i \geq 0.

Proof

Let’s try computing some values of F.

F(0), F(1), \ldots, F(K-1) are all fixed.
The formula gives us F(K) = -(F(0) + F(1) + \ldots + F(K-1)).
Now try computing the next few values:

  • F(K+1) = -(F(1) + F(2) + \ldots + F(K)) = F(0) if you substitute the value of F(K) from above.
  • F(K+2) = F(1) using a similar computation
    \vdots
  • F(2K) = F(K-1)
  • F(2K+1) = -(F(2K) + F(2K-1) + \ldots + F(K+1)) = -(F(0) + F(1) + \ldots + F(K-1)) = F(K)

Notice that these K+1 values, F(K+1), F(K+2), \ldots, F(2K+1), are exactly equal to the first K+1 values F(0), F(1), \ldots, F(K).
Now that the values have repeated, apply the same argument to this batch of K+1 values to obtain the next K+1, and so on.

This tells us that it is enough to know the values of F(0), F(1), \ldots, F(K). We can use these to answer every query, since F(x) = F(x\bmod{(K+1)}).

The first K values are hidden, so the only way to learn them is to query them. However, there are still a couple of optimizations to be made:

First, note that if we know the values of the function at 0, 1, 2, \ldots, K-1, then we can compute F(K) without a query. This means we never need more than K queries.

Second, if some value of x\bmod{(K+1)} doesn’t appear in B, we can simply skip asking for it since we’re never going to need it anyway.

This gives us our final solution:

  • Let S = \{B_i\bmod{(K+1)} \mid 1 \leq i \leq N\} be the set of all relevant remainders.
  • If the size of S is K+1, query for 0, 1, 2, \ldots, K-1 and compute F(K) from them.
  • Otherwise, query only for F(x) such that x \in S.
  • Using the results of these queries, compute the answer for each B_i.

TIME COMPLEXITY:

\mathcal{O}(N\log K + K) per testcase.

CODE:

Preparer's code (C++)
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
// Shorthand macros used by the preparer's solution. These are plain textual
// substitutions that apply to every later token in this translation unit.
#define ll long long
// Large sentinel constants (1e13 and 1e18); not referenced by solve() or main().
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;
#define pb push_back
#define mp make_pair
#define nline "\n"
// NOTE(review): the one-letter macros below rewrite ANY later identifier
// spelled `f` or `s` (e.g. a local `set<ll> s;`) into `first`/`second`;
// avoid those names anywhere after this point.
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
// debug(x): when ONLINE_JUDGE is not defined, writes the expression text, a
// space, its value (via the _print overloads) and a newline to stderr.
// When ONLINE_JUDGE is defined, debug(x) expands to just `;`.
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}     
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
// Order-statistic tree aliases built on the GNU pb_ds `tree` (red-black tree
// with the order-statistics node-update policy). Not used by solve() or main().
using ordered_set = tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update>;
// NOTE(review): less_equal is a non-strict comparator, the usual trick to
// allow duplicates; lookups/erasure by value behave unusually with it - verify
// before relying on find()/erase().
using ordered_multiset = tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update>;
using ordered_pset = tree<pair<ll, ll>, null_type, less<pair<ll, ll>>, rb_tree_tag, tree_order_statistics_node_update>;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;     
const ll MAX=200010;
void solve(){    
    ll k; cin>>k;
    vector<ll> track(k+5,0);
    ll n; cin>>n;
    vector<ll> b(n+5,0);
    set<ll> check; 
    for(ll i=1;i<=n;i++){
        cin>>b[i];
        check.insert(b[i]%(k+1));
    }
    if(check.size()<k){
        for(auto it:check){
            cout<<"? "<<it<<endl;
            cin>>track[it];
        }
        cout<<"! ";
        for(ll i=1;i<=n;i++){
            cout<<track[b[i]%(k+1)]<<" ";
        }
        cout<<endl;
    }
    else{
        for(ll i=0;i<k;i++){
            cout<<"? "<<i<<endl;
            cin>>track[i];
            track[k]-=track[i];
        }
        cout<<"! ";
        for(ll i=1;i<=n;i++){
            cout<<track[b[i]%(k+1)]<<" ";
        }
        cout<<endl;
    }
    return;                                
}                                                
// Program entry: switches to unsynchronised fast I/O, redirects the standard
// streams to local files when ONLINE_JUDGE is not defined, runs solve() once
// per test case, then reports the processor time used on stderr.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    ll test_cases = 1;
    cin >> test_cases;
    while (test_cases--) {
        solve();
    }
    cout << fixed << setprecision(10);
    const double elapsed_ms = 1000 * static_cast<double>(clock()) / static_cast<double>(CLOCKS_PER_SEC);
    cerr << "Time:" << elapsed_ms << "ms\n";
}
Tester's code (C++)
#include <bits/stdc++.h>

using namespace std;

// Strict token reader for validating judge input. The whole of stdin is
// buffered up front; each read* method consumes from `pos` and asserts that
// the next token/character has exactly the expected form.
// NOTE(review): not referenced by the tester's main(); kept for validation use.
struct input_checker {
    std::string buffer;  // entire contents of standard input
    int pos;             // index of the next unread character in `buffer`

    // Character classes intended for readString's `pattern` argument.
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads standard input until end-of-file into `buffer`.
    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back(static_cast<char>(c));
        }
    }

    // Returns the index of the first ' ' or '\n' at or after `pos`, or
    // buffer.size() if there is none.
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes and returns the characters from `pos` up to (not including) the
    // next delimiter. Asserts that input remains; returns "" if `pos` already
    // sits on a delimiter.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        std::string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Returns true iff `s` is a canonical decimal integer: an optional '-',
    // then one or more digits, with no leading zeros and no "-0".
    // std::stoi/stoll alone would silently accept "+5", "\t5", "007" and
    // trailing junk such as "12abc", so readInt/readLong check this first.
    static bool isCanonicalInteger(const std::string &s) {
        const std::size_t first = (!s.empty() && s[0] == '-') ? 1 : 0;
        if (first == s.size()) {
            return false;  // empty token or a lone '-'
        }
        for (std::size_t i = first; i < s.size(); i++) {
            if (s[i] < '0' || s[i] > '9') {
                return false;
            }
        }
        // A leading '0' is only allowed for the literal "0" (rejects "-0", "07").
        return s[first] != '0' || s == "0";
    }

    // Reads a token of length in [minl, maxl] whose characters all occur in
    // `pattern` (any characters allowed when `pattern` is empty).
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads a canonical integer token in [minv, maxv]. A value that does not
    // fit in int makes std::stoi throw std::out_of_range, which also rejects it.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(isCanonicalInteger(token));
        int res = std::stoi(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Same as readInt, for long long values.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(isCanonicalInteger(token));
        long long res = std::stoll(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that all input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Tester's interactive solution. For each test case it reads K, N and B,
// queries the judge following the editorial's strategy, then prints "! "
// followed by " F(B_i)" for every i. Each queried x is r + (K+1)*1e9 for a
// needed remainder r: a far-away index with the same remainder mod (K+1).
// Together with the extra padding spaces this presumably exercises the judge's
// handling of large x and loose whitespace. At the end it prints the total N
// to stderr and asserts that it is at most 2e5.
int main() {
    int tests;
    std::cin >> tests;
    int total_n = 0;
    while (tests--) {
        int k, n;
        std::cin >> k >> n;
        total_n += n;
        const int period = k + 1;
        // Multiple of the period added to every queried index.
        const long long shift = period * (long long) 1e9;

        std::vector<int> b(n);
        std::set<int> residues;
        for (int &x : b) {
            std::cin >> x;
            assert(0 <= x && x <= 1e9);
            residues.emplace(x % period);
        }

        std::vector<long long> ans(period);
        if ((int) residues.size() == period) {
            // All K+1 remainders are needed: ask for 0..K-1, derive F(K).
            for (int r = 0; r < k; ++r) {
                std::cout << "? " << r + shift << "   " << std::endl;
                std::cin >> ans[r];
            }
            ans[k] = -std::accumulate(ans.begin(), ans.end(), 0LL);
        } else {
            for (int r : residues) {
                std::cout << "?   " << r + shift << "   " << std::endl;
                std::cin >> ans[r];
            }
        }

        std::cout << "! ";
        for (int x : b) {
            std::cout << " " << ans[x % period];
        }
        std::cout << std::endl;
    }
    std::cerr << total_n << std::endl;
    assert(total_n <= 2e5);
    return 0;
}
Editorialist's code (Python)
def query(x):
	"""Ask the judge for F(x) and return the integer it replies with.

	flush=True is required: when stdout is a pipe (as under an interactive
	judge) Python block-buffers it, so an unflushed query may never reach the
	judge and both processes would wait on each other forever.
	"""
	print('?', x, flush=True)
	return int(input())


def answer(a):
	"""Send the final line "! a_1 a_2 ... a_n" to the judge (flushed)."""
	print('!', *a, flush=True)


def solve_case(k, b, ask=query):
	"""Return [F(x) for x in b] using as few calls to ask(x) as possible.

	F has period k+1, so only x mod (k+1) matters. If every remainder 0..k is
	needed, ask for F(0..k-1) and derive F(k) = -(F(0) + ... + F(k-1));
	otherwise ask only for the remainders that actually occur in b.
	"""
	period = k + 1
	need = {x % period for x in b}
	vals = [0] * period
	if len(need) == period:
		for i in range(k):
			vals[i] = ask(i)
		vals[k] = -sum(vals[:k])
	else:
		for r in need:
			vals[r] = ask(r)
	return [vals[x % period] for x in b]


def main():
	"""Read each test case, solve it interactively and report the answers."""
	for _ in range(int(input())):
		k = int(input())
		int(input())  # N: the length of the next line; not needed separately
		b = list(map(int, input().split()))
		answer(solve_case(k, b))


if __name__ == "__main__":
	main()