STRPOW - Editorial

PROBLEM LINK:

Div-1 Contest
Div-2 Contest
Div-3 Contest
Practice

Author & Editorialist: Samarth Gupta
Tester: Aryan

DIFFICULTY:

Hard

PREREQUISITES:

Expected Value, Dynamic Programming, Generating Functions

PROBLEM:

Given a parity array B, a random string S of length k is constructed. Find the expected power of S, given that the frequency (modulo 2) of each of its characters matches the parity array B. Each occurrence of character c_i contributes a_i to the power of the string.

QUICK EXPLANATION:

Suppose character c_i occurs j_i times. For every such tuple (j_1, \dots, j_N) we can count the matching strings and compute their probability and their power. Generating functions then simplify the resulting sums for both the power and the probability.
We also need i^k for every query, and binary exponentiation for each term would TLE. However, only the at most R values i = 1, \dots, R are ever raised to a power, so after an O(R \times \sqrt K) pre-computation each i^k can be read off in O(1), where K = 4 \cdot 10^8 is the largest possible k.

EXPLANATION:

The problem reduces to a conditional expectation: given that the frequency (modulo 2) of each character of the random string S of length k matches the parity array B, find the expected power of S. Let character c_i occur j_i times in S; the conditional expected power can then be written as:

E = \frac{\sum\limits_{j_1+j_2+..+j_N = k, j_i\%2 == B_i} \frac{k!}{j_1!j_2!..j_N!} \times (a_1j_1+a_2j_2+..+a_Nj_N)\times p_1^{j_1} \times p_2^{j_2} \times .. \times p_N^{j_N}}{\sum\limits_{j_1+j_2+..+j_N = k, j_i\%2 == B_i} \frac{k!}{j_1!j_2!..j_N!} \times p_1^{j_1} \times p_2^{j_2} \times .. \times p_N^{j_N}}
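For very small inputs, E can be evaluated directly from this formula, which is handy for cross-checking the faster method. The sketch below (the function name bruteExpectedPower is hypothetical) enumerates every tuple (j_1, \dots, j_N) with j_1 + \dots + j_N = k, keeps those satisfying the parities, and accumulates numerator and denominator in floating point. It uses p_i exactly as the formula does; any normalisation of p_i cancels between numerator and denominator.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Brute-force E for tiny inputs (exponential time, for cross-checking only).
// a[i], p[i], B[i] follow the editorial's notation (0-indexed here).
// Returns NaN if no tuple satisfies the parity constraints.
double bruteExpectedPower(const std::vector<int>& a, const std::vector<int>& p,
                          const std::vector<int>& B, int k) {
    assert(a.size() == p.size() && p.size() == B.size());
    const int n = static_cast<int>(a.size());
    std::vector<int> j(n, 0);
    double num = 0.0, den = 0.0;
    // Choose j[idx] for idx = 0..n-2; the last count takes whatever is left.
    auto rec = [&](auto&& self, int idx, int left) -> void {
        if (idx == n - 1) {
            j[idx] = left;
            for (int i = 0; i < n; ++i)
                if (j[i] % 2 != B[i]) return;      // parity constraint j_i % 2 == B_i
            double w = std::tgamma(k + 1.0);       // k!
            double power = 0.0;
            for (int i = 0; i < n; ++i) {
                w = w / std::tgamma(j[i] + 1.0) * std::pow(p[i], j[i]);
                power += static_cast<double>(a[i]) * j[i];
            }
            num += w * power;                      // numerator term of E
            den += w;                              // denominator term of E
            return;
        }
        for (int x = 0; x <= left; ++x) {
            j[idx] = x;
            self(self, idx + 1, left - x);
        }
    };
    if (n > 0) rec(rec, 0, k);
    return den == 0.0 ? std::nan("") : num / den;
}
```

For example, with a = (1, 2), p = (1, 1), B = (0, 0) and k = 2 only (2, 0) and (0, 2) are valid, with powers 2 and 4 and equal weight, so E = 3.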

Define R = \sum_{i=1}^{N} p_i.
Next, define the multivariable function f_k as follows:

f_k(x_1, x_2, \dots, x_N) = \sum\limits_{j_1+j_2+\dots+j_N = k,\ j_i\%2 == B_i} \frac{k!}{j_1!j_2!\cdots j_N!} \times x_1^{a_1j_1}p_1^{j_1}\, x_2^{a_2j_2}p_2^{j_2} \cdots x_N^{a_Nj_N}p_N^{j_N}

Note that the denominator of E equals f_k(1, 1, \dots, 1) and the numerator equals \left(x_1\frac{\partial f_k}{\partial x_1} + x_2\frac{\partial f_k}{\partial x_2} + \dots + x_N\frac{\partial f_k}{\partial x_N}\right)_{x_1 = x_2 = \dots = x_N = 1}.
Next, we write the exponential generating function (EGF) of f_k over k:
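To see why this operator produces the power, apply it to a single monomial of f_k:

```latex
x_i \frac{\partial}{\partial x_i}\, x_i^{a_i j_i} = a_i j_i\, x_i^{a_i j_i}
\quad\Longrightarrow\quad
\sum_{i=1}^{N} x_i \frac{\partial}{\partial x_i} \prod_{m=1}^{N} x_m^{a_m j_m}
= (a_1 j_1 + a_2 j_2 + \dots + a_N j_N) \prod_{m=1}^{N} x_m^{a_m j_m}
```

Setting every x_i = 1 leaves exactly the weight (a_1j_1 + \dots + a_Nj_N) that multiplies each term of the numerator of E.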

\sum\limits_{k \ge 0} \frac{f_k(x_1, x_2, \dots, x_N)\,x^k}{k!} = \sum\limits_{j_1, j_2, \dots, j_N \ge 0,\ j_i\%2 == B_i} \frac{(p_1x_1^{a_1}x)^{j_1}}{j_1!}\cdots\frac{(p_Nx_N^{a_N}x)^{j_N}}{j_N!}

Since the parity of each j_i is constrained independently, the sum factorizes into a product:

\prod\limits_{i=1}^{N} \frac{e^{p_ix_i^{a_i}x} + (-1)^{B_i}e^{-p_ix_i^{a_i}x}}{2}
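The factorization uses the even and odd parts of the exponential series, applied to y = p_ix_i^{a_i}x for each i:

```latex
\sum_{j\ \text{even}} \frac{y^j}{j!} = \cosh y = \frac{e^{y} + e^{-y}}{2},
\qquad
\sum_{j\ \text{odd}} \frac{y^j}{j!} = \sinh y = \frac{e^{y} - e^{-y}}{2},
\qquad\text{so}\qquad
\sum_{j \ge 0,\ j\%2 == B_i} \frac{y^j}{j!} = \frac{e^{y} + (-1)^{B_i}e^{-y}}{2}.
```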

First, consider the denominator. Setting every x_i = 1 gives

D_k(x) = \prod\limits_{i=1}^{N} \left(\frac{e^{p_ix}+(-1)^{B_i}e^{-p_ix}}{2}\right)

Note that D_k(x) is a polynomial in e^x (negative powers allowed), that is, D_k(x) = \sum_{i=-R}^{R} D_{k, i}e^{ix}. Its coefficients can be computed with dynamic programming in O(N \cdot R): let dp[i][j] be the coefficient of e^{jx} after the first i factors have been multiplied, where 0 \leq i \leq N and -R \leq j \leq R. The constant factor 2^{-N} appears in both the numerator and the denominator, so it can be dropped.

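Next we move to the numerator, which needs N partial derivatives. Computing each derivative separately and then running the denominator's DP for each of them takes O(N^2 \cdot R), which is too slow. Instead, we use a prefix DP. Differentiating the i-th factor with respect to x_i and setting x_i = 1 gives \frac{p_ia_ix}{2}\left(e^{p_ix} - (-1)^{B_i}e^{-p_ix}\right); the factor x is shared by every term of the numerator, so it is set aside and only the bracket is multiplied in. Define dp[i][j][0/1] as the coefficient of e^{jx} after the first i factors have been multiplied, where state 0 means that no factor has been differentiated so far and state 1 means that exactly one factor has been differentiated.
Then dp[i][j][0] has the same transitions as dp[i][j] in the denominator.
For dp[i][j][1], we move from dp[i-1][j \pm p_i][1] (factor i is not differentiated, same weights as in state 0) or from dp[i-1][j \pm p_i][0] (factor i is the differentiated one, contributing p_ia_i from dp[i-1][j - p_i][0] and -(-1)^{B_i}p_ia_i from dp[i-1][j + p_i][0]). This DP also runs in O(N \cdot R).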
Now we have polynomials in e^x for both the numerator and the denominator. Suppose the denominator is D_k(x) = \sum\limits_{i=-R}^{R} D_{k, i}e^{ix}. Since e^{ix} = \sum_{k \ge 0} i^k \frac{x^k}{k!}, the coefficient of \frac{x^k}{k!} is \sum\limits_{i=-R}^{R} D_{k, i}\,i^k. For the numerator, the factor x set aside earlier shifts the index: the coefficient of \frac{x^k}{k!} in x\,e^{ix} is k\,i^{k-1}, which is why the setter's code evaluates the numerator sum with exponent k-1 and then multiplies by k. Terms for i and -i can be merged, since (-i)^k = (-1)^k i^k.
Evaluating these sums with binary exponentiation costs O(R \log k) per query, which is still too slow for the last subtask. However, the bases are only the O(R) values 1, 2, \dots, R. So, for each such i, we pre-compute i, i^2, \dots, i^{\sqrt K} and i^{2\sqrt K}, i^{3\sqrt K}, \dots, i^{K}, where K = 4 \cdot 10^8 is the largest possible k. Writing k = q\sqrt K + r with 0 \le r < \sqrt K gives i^k = i^{q\sqrt K} \cdot i^{r}, an O(1) lookup, which removes the \log k factor. This is the same idea as the Baby Step Giant Step algorithm.
The pre-computation takes O(R \sqrt K) time, and each query is then answered in O(R).
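A minimal sketch of the baby-step/giant-step power table for a single base, assuming modulus 998244353. The setter's solution stores both halves in one array pws[i][·] for every base i ≤ 2000; here the hypothetical struct PowTable keeps them in separate vectors, and powBinary is a reference implementation used only for comparison.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr uint64_t MOD = 998244353;

// Reference: plain binary exponentiation, O(log e) per call.
uint32_t powBinary(uint64_t b, uint64_t e) {
    uint64_t r = 1;
    b %= MOD;
    while (e > 0) {
        if (e & 1) r = r * b % MOD;
        b = b * b % MOD;
        e >>= 1;
    }
    return static_cast<uint32_t>(r);
}

// Baby-step giant-step table for one base b:
//   baby[r]  = b^r      for 0 <= r <= S
//   giant[q] = b^(q*S)  for 0 <= q <= S
// so b^k = giant[k / S] * baby[k % S] for any 0 <= k <= S*S, in O(1).
struct PowTable {
    uint64_t S;
    std::vector<uint32_t> baby, giant;

    PowTable(uint64_t b, uint64_t step) : S(step), baby(step + 1), giant(step + 1) {
        b %= MOD;
        baby[0] = 1;
        for (uint64_t r = 1; r <= S; ++r)
            baby[r] = static_cast<uint32_t>(baby[r - 1] * b % MOD);
        giant[0] = 1;
        for (uint64_t q = 1; q <= S; ++q)
            giant[q] = static_cast<uint32_t>(uint64_t(giant[q - 1]) * baby[S] % MOD);
    }

    uint32_t power(uint64_t k) const {
        assert(k <= S * S);
        return static_cast<uint32_t>(uint64_t(giant[k / S]) * baby[k % S] % MOD);
    }
};
```

With S = 20000 (the setter's mxk), one table covers every k up to 4 \cdot 10^8 using about 2S multiplications, which gives the O(R \sqrt K) pre-computation when built for all R bases.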

Overall Complexity: O((N + Q + \sqrt K) \cdot R)

SOLUTIONS:

Setter's & Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;
#define mod 998244353
#define sumpi 2000
#define mxk 20000
using ll = long long;
int pws[sumpi+1][2*mxk];
int mypower(int a, int b)
{
    int ans = 1;
    while(b)
    {
        if(b%2)
            ans = ans*1ll*a%mod;
        b=b/2;
        a=a*1ll*a%mod;
    }
    return ans;
}
// pws[i][j]       = i^j            for 0 <= j <= mxk   (baby steps)
// pws[i][mxk + t] = i^((t+1)*mxk)  for 0 <= t <  mxk   (giant steps)
void pre()
{
    int i, j;
    for(i=1;i<=sumpi;i++)
    {
        pws[i][0] = 1;
        pws[i][1] = i;
        for(j=2;j<=mxk;j++)
            pws[i][j] = pws[i][j-1]*1ll*i%mod;
        for(j=mxk+1;j<2*mxk;j++)
            pws[i][j] = pws[i][j-1]*1ll*pws[i][mxk]%mod;
    }
}
int calpw(int i, int k) // calculates i^k in O(1), for 1 <= i <= sumpi and 0 <= k <= mxk*mxk
{
    if(k <= mxk)
        return pws[i][k];
    int lhs = k/mxk;
    int lef = k%mxk;
    // i^k = i^(lhs*mxk) * i^lef, and i^(lhs*mxk) is stored at index mxk + lhs - 1
    int ans = pws[i][mxk+lhs-1]*1ll*pws[i][lef]%mod;
    return ans;
}
signed main() {
	ios_base::sync_with_stdio(false);
	cin.tie(0);
	pre();
	int n;
	cin >> n;
	int p[n+1], b[n+1], a[n+1];
	int i, j, sum = 0, odd = 0;
	for(i=1;i<=n;i++)
	{
	    cin >> a[i] >> p[i] >> b[i];
	    odd+=b[i];
	    sum+=p[i];
	}
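	// a valid string exists only if k has the same parity as odd (and k >= odd)
	ll dp[n+1][2*sum+1][2]; 
	// factor i (state 0):       (e^(pi*y) + (-1)^b[i]*e^(-pi*y)), b[i] = 0 for even, 1 for odd
	// its x_i-derivative at 1:  pi*ai*y*(e^(pi*y) + (-1)^(b[i]+1)*e^(-pi*y)); the common factor y is kept out of dp[.][.][1]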
	for(i=0;i<=n;i++)
	    for(j=0;j<=2*sum;j++)
	        dp[i][j][0] = dp[i][j][1] = 0;
	dp[0][sum][0] = 1;
	for(i=1;i<=n;i++)
	    for(j=-sum;j<=sum;j++)
	    {
	        // index j+sum <-> e^(j*y): receive from j-pi (times e^(pi*y)) and from j+pi (times e^(-pi*y))
	        int pi = p[i], ai = a[i];
	        if(j+sum-pi >= 0 && j-pi <= sum)
	        {
	            dp[i][j+sum][0] = dp[i-1][j+sum-pi][0];
	            dp[i][j+sum][1] = dp[i][j+sum][1] + dp[i-1][j+sum-pi][1];
	            dp[i][j+sum][1] = dp[i][j+sum][1] + (dp[i-1][j+sum-pi][0]*(pi*1ll*ai%mod))%mod;
	        }
	        if(j+sum+pi >= 0 && j+pi <= sum)
	        {
	            dp[i][j+sum][0] = dp[i][j+sum][0] + dp[i-1][j+sum+pi][0]*(b[i] == 0 ? 1 : -1);
	            dp[i][j+sum][1] = dp[i][j+sum][1] + dp[i-1][j+sum+pi][1]*(b[i] == 0 ? 1 : -1);
	            ll m = dp[i-1][j+sum+pi][0]*(pi*1ll*ai%mod)*(b[i] == 0 ? -1 : 1);
	            dp[i][j+sum][1] = dp[i][j+sum][1] + m%mod;
	        }
	        dp[i][j+sum][0] = dp[i][j+sum][0]%mod;
	        dp[i][j+sum][1] = dp[i][j+sum][1]%mod;
	    }
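    // merge the e^(d*y) and e^(-d*y) terms (d > 0): (-d)^k = (-1)^k * d^k and k has the parity of odd;
    // the numerator is evaluated at exponent k-1, hence the opposite sign there
    vector<pair<int, int>> prob, expec;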
    for(i=sum+1;i<=2*sum;i++)
    {
        if(dp[n][i][0] != 0)
        {
            int diff = i - sum;
            prob.push_back({(dp[n][i][0] + (odd%2 == 0 ? 1 : -1)*dp[n][2*sum-i][0])%mod, diff});
        }
        if(dp[n][i][1] != 0)
        {
            int diff = i - sum;
            expec.push_back({(dp[n][i][1] + (odd%2 == 0 ? -1 : 1)*dp[n][2*sum-i][1])%mod, diff});
        }
    }
    int q;
    cin >> q;
    while(q--)
    {
        int k;
        cin >> k;
        if(k%2 != odd%2 || k < odd)
        {
            cout << -1 << '\n';
            continue;
        }
        ll num = 0, den = 0;
        for(auto ele : expec)
            num = (num + ele.first*1ll*calpw(ele.second, k-1)%mod);
        for(auto ele : prob)
            den = (den + ele.first*1ll*calpw(ele.second, k)%mod);
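        // the d = 0 term of the numerator contributes 0^(k-1), which is 1 only when k == 1
        if(k == 1)
            num = (num + dp[n][sum][1]);
		num=num%mod;
		den=den%mod;
        // [y^k/k!] y*e^(d*y) = k*d^(k-1), hence the extra factor k
        num = num*1ll*k%mod;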
        assert(den != 0);
        num = num*1ll*mypower(den, mod-2)%mod;
        num = (num + mod)%mod;
        cout << num << '\n';
    }
	return 0;
}
Tester's Solution
/* in the name of Anton */
 
/*
  Compete against Yourself.
  Author - Aryan (@aryanc403)
  Atcoder library - https://atcoder.github.io/ac-library/production/document_en/
*/
 
#ifdef ARYANC403
    #include <header.h>
#else
    #pragma GCC optimize ("Ofast")
    #pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx")
    //#pragma GCC optimize ("-ffloat-store")
    #include<bits/stdc++.h>
    #define dbg(args...) 42;
#endif
 
using namespace std;
#define fo(i,n)   for(i=0;i<(n);++i)
#define repA(i,j,n)   for(i=(j);i<=(n);++i)
#define repD(i,j,n)   for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
#define endl "\n"
 
typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;
 
const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
    cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}
 
long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}
 
void readEOF(){
    assert(getchar()==EOF);
}
 
vi readVectorInt(lli l,lli r,int n){
    vi a(n);
    for(int i=0;i<n-1;++i)
        a[i]=readIntSp(l,r);
    a[n-1]=readIntLn(l,r);
    return a;
}
 
const lli INF = 0xFFFFFFFFFFFFFFFL;
 
lli seed;
mt19937 rng(seed=chrono::steady_clock::now().time_since_epoch().count());
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}
 
class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{    return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y ));   }};
 
void add( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt==m.end())         m.insert({x,cnt});
    else                    jt->Y+=cnt;
}
 
void del( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt->Y<=cnt)            m.erase(jt);
    else                      jt->Y-=cnt;
}
 
bool cmp(const ii &a,const ii &b)
{
    return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}
 
const lli mod = 1000000007L;
// const lli maxN = 1000000007L;
 
#include <cassert>
#include <numeric>
#include <type_traits>
 
#ifdef _MSC_VER
#include <intrin.h>
#endif
 
 
#include <utility>
 
#ifdef _MSC_VER
#include <intrin.h>
#endif
 
namespace atcoder {
 
namespace internal {
 
constexpr long long safe_mod(long long x, long long m) {
    x %= m;
    if (x < 0) x += m;
    return x;
}
 
struct barrett {
    unsigned int _m;
    unsigned long long im;
 
    barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}
 
    unsigned int umod() const { return _m; }
 
    unsigned int mul(unsigned int a, unsigned int b) const {
 
        unsigned long long z = a;
        z *= b;
#ifdef _MSC_VER
        unsigned long long x;
        _umul128(z, im, &x);
#else
        unsigned long long x =
            (unsigned long long)(((unsigned __int128)(z)*im) >> 64);
#endif
        unsigned int v = (unsigned int)(z - x * _m);
        if (_m <= v) v += _m;
        return v;
    }
};
 
constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
    if (m == 1) return 0;
    unsigned int _m = (unsigned int)(m);
    unsigned long long r = 1;
    unsigned long long y = safe_mod(x, m);
    while (n) {
        if (n & 1) r = (r * y) % _m;
        y = (y * y) % _m;
        n >>= 1;
    }
    return r;
}
 
constexpr bool is_prime_constexpr(int n) {
    if (n <= 1) return false;
    if (n == 2 || n == 7 || n == 61) return true;
    if (n % 2 == 0) return false;
    long long d = n - 1;
    while (d % 2 == 0) d /= 2;
    constexpr long long bases[3] = {2, 7, 61};
    for (long long a : bases) {
        long long t = d;
        long long y = pow_mod_constexpr(a, t, n);
        while (t != n - 1 && y != 1 && y != n - 1) {
            y = y * y % n;
            t <<= 1;
        }
        if (y != n - 1 && t % 2 == 0) {
            return false;
        }
    }
    return true;
}
template <int n> constexpr bool is_prime = is_prime_constexpr(n);
 
constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
    a = safe_mod(a, b);
    if (a == 0) return {b, 0};
 
    long long s = b, t = a;
    long long m0 = 0, m1 = 1;
 
    while (t) {
        long long u = s / t;
        s -= t * u;
        m0 -= m1 * u;  // |m1 * u| <= |m1| * s <= b
 
 
        auto tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
    }
    if (m0 < 0) m0 += b / s;
    return {s, m0};
}
 
constexpr int primitive_root_constexpr(int m) {
    if (m == 2) return 1;
    if (m == 167772161) return 3;
    if (m == 469762049) return 3;
    if (m == 754974721) return 11;
    if (m == 998244353) return 3;
    int divs[20] = {};
    divs[0] = 2;
    int cnt = 1;
    int x = (m - 1) / 2;
    while (x % 2 == 0) x /= 2;
    for (int i = 3; (long long)(i)*i <= x; i += 2) {
        if (x % i == 0) {
            divs[cnt++] = i;
            while (x % i == 0) {
                x /= i;
            }
        }
    }
    if (x > 1) {
        divs[cnt++] = x;
    }
    for (int g = 2;; g++) {
        bool ok = true;
        for (int i = 0; i < cnt; i++) {
            if (pow_mod_constexpr(g, (m - 1) / divs[i], m) == 1) {
                ok = false;
                break;
            }
        }
        if (ok) return g;
    }
}
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
 
}  // namespace internal
 
}  // namespace atcoder
 
 
#include <cassert>
#include <numeric>
#include <type_traits>
 
namespace atcoder {
 
namespace internal {
 
#ifndef _MSC_VER
template <class T>
using is_signed_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value ||
                                  std::is_same<T, __int128>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using is_unsigned_int128 =
    typename std::conditional<std::is_same<T, __uint128_t>::value ||
                                  std::is_same<T, unsigned __int128>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using make_unsigned_int128 =
    typename std::conditional<std::is_same<T, __int128_t>::value,
                              __uint128_t,
                              unsigned __int128>;
 
template <class T>
using is_integral = typename std::conditional<std::is_integral<T>::value ||
                                                  is_signed_int128<T>::value ||
                                                  is_unsigned_int128<T>::value,
                                              std::true_type,
                                              std::false_type>::type;
 
template <class T>
using is_signed_int = typename std::conditional<(is_integral<T>::value &&
                                                 std::is_signed<T>::value) ||
                                                    is_signed_int128<T>::value,
                                                std::true_type,
                                                std::false_type>::type;
 
template <class T>
using is_unsigned_int =
    typename std::conditional<(is_integral<T>::value &&
                               std::is_unsigned<T>::value) ||
                                  is_unsigned_int128<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using to_unsigned = typename std::conditional<
    is_signed_int128<T>::value,
    make_unsigned_int128<T>,
    typename std::conditional<std::is_signed<T>::value,
                              std::make_unsigned<T>,
                              std::common_type<T>>::type>::type;
 
#else
 
template <class T> using is_integral = typename std::is_integral<T>;
 
template <class T>
using is_signed_int =
    typename std::conditional<is_integral<T>::value && std::is_signed<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using is_unsigned_int =
    typename std::conditional<is_integral<T>::value &&
                                  std::is_unsigned<T>::value,
                              std::true_type,
                              std::false_type>::type;
 
template <class T>
using to_unsigned = typename std::conditional<is_signed_int<T>::value,
                                              std::make_unsigned<T>,
                                              std::common_type<T>>::type;
 
#endif
 
template <class T>
using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;
 
template <class T>
using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;
 
template <class T> using to_unsigned_t = typename to_unsigned<T>::type;
 
}  // namespace internal
 
}  // namespace atcoder
 
 
namespace atcoder {
 
namespace internal {
 
struct modint_base {};
struct static_modint_base : modint_base {};
 
template <class T> using is_modint = std::is_base_of<modint_base, T>;
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;
 
}  // namespace internal
 
template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
    using mint = static_modint;
 
  public:
    static constexpr int mod() { return m; }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }
 
    static_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    static_modint(T v) {
        long long x = (long long)(v % (long long)(umod()));
        if (x < 0) x += umod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    static_modint(T v) {
        _v = (unsigned int)(v % umod());
    }
 
    unsigned int val() const { return _v; }
 
    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }
 
    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        unsigned long long z = _v;
        z *= rhs._v;
        _v = (unsigned int)(z % umod());
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
 
    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }
 
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        if (prime) {
            assert(_v);
            return pow(umod() - 2);
        } else {
            auto eg = internal::inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }
 
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }
 
  private:
    unsigned int _v;
    static constexpr unsigned int umod() { return m; }
    static constexpr bool prime = internal::is_prime<m>;
};
 
template <int id> struct dynamic_modint : internal::modint_base {
    using mint = dynamic_modint;
 
  public:
    static int mod() { return (int)(bt.umod()); }
    static void set_mod(int m) {
        assert(1 <= m);
        bt = internal::barrett(m);
    }
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }
 
    dynamic_modint() : _v(0) {}
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        long long x = (long long)(v % (long long)(mod()));
        if (x < 0) x += mod();
        _v = (unsigned int)(x);
    }
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        _v = (unsigned int)(v % mod());
    }
 
    unsigned int val() const { return _v; }
 
    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }
 
    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        _v += mod() - rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        _v = bt.mul(_v, rhs._v);
        return *this;
    }
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }
 
    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }
 
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    mint inv() const {
        auto eg = internal::inv_gcd(_v, mod());
        assert(eg.first == 1);
        return eg.second;
    }
 
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }
 
  private:
    unsigned int _v;
    static internal::barrett bt;
    static unsigned int umod() { return bt.umod(); }
};
template <int id> internal::barrett dynamic_modint<id>::bt = 998244353;
 
using modint998244353 = static_modint<998244353>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;
 
namespace internal {
 
template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;
 
template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;
 
template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};
 
template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;
 
}  // namespace internal
 
}  // namespace atcoder
 
using namespace atcoder;
using mint = modint998244353;
// using mint = modint1000000007;
std::ostream& operator << (std::ostream& out, const mint& rhs) {
        return out<<rhs.val();
    }
 
namespace algebra {
    template <typename T>
vector<T>& operator+=(vector<T>& a, const vector<T>& b) {
  if (a.size() < b.size()) {
    a.resize(b.size());
  }
  for (int i = 0; i < (int) b.size(); i++) {
    a[i] += b[i];
  }
  return a;
}
 
template <typename T>
vector<T> operator+(const vector<T>& a, const vector<T>& b) {
  vector<T> c = a;
  return c += b;
}
 
template <typename T>
vector<T>& operator-=(vector<T>& a, const vector<T>& b) {
  if (a.size() < b.size()) {
    a.resize(b.size());
  }
  for (int i = 0; i < (int) b.size(); i++) {
    a[i] -= b[i];
  }
  return a;
}
 
template <typename T>
vector<T> operator-(const vector<T>& a, const vector<T>& b) {
  vector<T> c = a;
  return c -= b;
}
 
template <typename T>
vector<T> operator-(const vector<T>& a) {
  vector<T> c = a;
  for (int i = 0; i < (int) c.size(); i++) {
    c[i] = -c[i];
  }
  return c;
}
 
template <typename T>
vector<T> operator*(const vector<T>& a, const vector<T>& b) {
  if (a.empty() || b.empty()) {
    return {};
  }
  vector<T> c(a.size() + b.size() - 1, 0);
  for (int i = 0; i < (int) a.size(); i++) {
    for (int j = 0; j < (int) b.size(); j++) {
      c[i + j] += a[i] * b[j];
    }
  }
  return c;
}
 
template <typename T>
vector<T>& operator*=(vector<T>& a, const vector<T>& b) {
  return a = a * b;
}
 
template <typename T>
vector<T> inverse(const vector<T>& a) {
  assert(!a.empty());
  int n = (int) a.size();
  vector<T> b = {1 / a[0]};
  while ((int) b.size() < n) {
    vector<T> a_cut(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    vector<T> x = b * b * a_cut;
    b.resize(b.size() << 1);
    for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
      b[i] = -x[i];
    }
  }
  b.resize(n);
  return b;
}
 
template <typename T>
vector<T>& operator/=(vector<T>& a, const vector<T>& b) {
  int n = (int) a.size();
  int m = (int) b.size();
  if (n < m) {
    a.clear();
  } else {
    vector<T> d = b;
    reverse(a.begin(), a.end());
    reverse(d.begin(), d.end());
    d.resize(n - m + 1);
    a *= inverse(d);
    a.erase(a.begin() + n - m + 1, a.end());
    reverse(a.begin(), a.end());
  }
  return a;
}
 
template <typename T>
vector<T> operator/(const vector<T>& a, const vector<T>& b) {
  vector<T> c = a;
  return c /= b;
}
 
template <typename T>
vector<T> operator*(const vector<T>& a, T b) {
  vector<T> c = a;
  for(auto &x:c)
    x*=b;
  return c;
}
 
template <typename T>
vector<T>& operator%=(vector<T>& a, const vector<T>& b) {
  int n = (int) a.size();
  int m = (int) b.size();
  if (n >= m) {
    vector<T> c = (a / b) * b;
    a.resize(m - 1);
    for (int i = 0; i < m - 1; i++) {
      a[i] -= c[i];
    }
  }
  return a;
}
 
template <typename T>
vector<T> operator%(const vector<T>& a, const vector<T>& b) {
  vector<T> c = a;
  return c %= b;
}
 
template <typename T, typename U>
vector<T> power(const vector<T>& a, const U& b, const vector<T>& c) {
  assert(b >= 0);
  vector<U> binary;
  U bb = b;
  while (bb > 0) {
    binary.push_back(bb & 1);
    bb >>= 1;
  }
  vector<T> res = vector<T>{1} % c;
  for (int j = (int) binary.size() - 1; j >= 0; j--) {
    res = res * res % c;
    if (binary[j] == 1) {
      res = res * a % c;
    }
  }
  return res;
}
 
template <typename T>
vector<T> derivative(const vector<T>& a) {
  vector<T> c = a;
  for (int i = 0; i < (int) c.size(); i++) {
    c[i] *= i;
  }
  if (!c.empty()) {
    c.erase(c.begin());
  }
  return c;
}
 
template <typename T>
vector<T> integrate(const vector<T>& a) {
  vector<T> c = {0};
  for (int i = 0; i < (int) a.size(); i++) {
    c.push_back(a[i]/(i+1));
  }
  return c;
}
 
template <typename T>
vector<T> primitive(const vector<T>& a) {
  vector<T> c = a;
  c.insert(c.begin(), 0);
  for (int i = 1; i < (int) c.size(); i++) {
    c[i] /= i;
  }
  return c;
}
 
template <typename T>
vector<T> logarithm(const vector<T>& a) {
  assert(!a.empty() && a[0] == 1);
  vector<T> u = primitive(derivative(a) * inverse(a));
  u.resize(a.size());
  return u;
}
 
template <typename T>
vector<T> exponent(const vector<T>& a) {
  assert(!a.empty() && a[0] == 0);
  int n = (int) a.size();
  vector<T> b = {1};
  while ((int) b.size() < n) {
    vector<T> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    x[0] += 1;
    vector<T> old_b = b;
    b.resize(b.size() << 1);
    x -= logarithm(b);
    x *= old_b;
    for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
      b[i] = x[i];
    }
  }
  b.resize(n);
  return b;
}
 
template <typename T>
vector<T> sqrt(const vector<T>& a) {
  assert(!a.empty() && a[0] == 1);
  int n = (int) a.size();
  vector<T> b = {1};
  while ((int) b.size() < n) {
    vector<T> x(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    b.resize(b.size() << 1);
    x *= inverse(b);
    T inv2 = 1 / static_cast<T>(2);
    for (int i = (int) b.size() >> 1; i < (int) min(x.size(), b.size()); i++) {
      b[i] = x[i] * inv2;
    }
  }
  b.resize(n);
  return b;
}
 
template <typename T>
vector<T> multiply(const vector<vector<T>>& a) {
  if (a.empty()) {
    return {0};
  }
  function<vector<T>(int, int)> mult = [&](int l, int r) {
    if (l == r) {
      return a[l];
    }
    int y = (l + r) >> 1;
    return mult(l, y) * mult(y + 1, r);
  };
  return mult(0, (int) a.size() - 1);
}
 
template <typename T>
T evaluate(const vector<T>& a, const T& x) {
  T res = 0;
  for (int i = (int) a.size() - 1; i >= 0; i--) {
    res = res * x + a[i];
  }
  return res;
}
 
template <typename T>
vector<T> evaluate(const vector<T>& a, const vector<T>& x) {
  if (x.empty()) {
    return {};
  }
  if (a.empty()) {
    return vector<T>(x.size(), 0);
  }
  int n = (int) x.size();
  vector<vector<T>> st((n << 1) - 1);
  function<void(int, int, int)> build = [&](int v, int l, int r) {
    if (l == r) {
      st[v] = vector<T>{-x[l], 1};
    } else {
      int y = (l + r) >> 1;
      int z = v + ((y - l + 1) << 1);
      build(v + 1, l, y);
      build(z, y + 1, r);
      st[v] = st[v + 1] * st[z];
    }
  };
  build(0, 0, n - 1);
  vector<T> res(n);
  function<void(int, int, int, vector<T>)> eval = [&](int v, int l, int r, vector<T> f) {
    f %= st[v];
    if ((int) f.size() < 150) {
      for (int i = l; i <= r; i++) {
        res[i] = evaluate(f, x[i]);
      }
      return;
    }
    if (l == r) {
      res[l] = f[0];
    } else {
      int y = (l + r) >> 1;
      int z = v + ((y - l + 1) << 1);
      eval(v + 1, l, y, f);
      eval(z, y + 1, r, f);
    }
  };
  eval(0, 0, n - 1, a);
  return res;
}
 
template <typename T>
vector<T> interpolate(const vector<T>& x, const vector<T>& y) {
  if (x.empty()) {
    return {};
  }
  assert(x.size() == y.size());
  int n = (int) x.size();
  vector<vector<T>> st((n << 1) - 1);
  function<void(int, int, int)> build = [&](int v, int l, int r) {
    if (l == r) {
      st[v] = vector<T>{-x[l], 1};
    } else {
      int w = (l + r) >> 1;
      int z = v + ((w - l + 1) << 1);
      build(v + 1, l, w);
      build(z, w + 1, r);
      st[v] = st[v + 1] * st[z];
    }
  };
  build(0, 0, n - 1);
  vector<T> m = st[0];
  vector<T> dm = derivative(m);
  vector<T> val(n);
  function<void(int, int, int, vector<T>)> eval = [&](int v, int l, int r, vector<T> f) {
    f %= st[v];
    if ((int) f.size() < 150) {
      for (int i = l; i <= r; i++) {
        val[i] = evaluate(f, x[i]);
      }
      return;
    }
    if (l == r) {
      val[l] = f[0];
    } else {
      int w = (l + r) >> 1;
      int z = v + ((w - l + 1) << 1);
      eval(v + 1, l, w, f);
      eval(z, w + 1, r, f);
    }
  };
  eval(0, 0, n - 1, dm);
  for (int i = 0; i < n; i++) {
    val[i] = y[i] / val[i];
  }
  function<vector<T>(int, int, int)> calc = [&](int v, int l, int r) {
    if (l == r) {
      return vector<T>{val[l]};
    }
    int w = (l + r) >> 1;
    int z = v + ((w - l + 1) << 1);
    return calc(v + 1, l, w) * st[z] + calc(z, w + 1, r) * st[v + 1];
  };
  return calc(0, 0, n - 1);
}
 
// f[i] = 1^i + 2^i + ... + up^i
template <typename T>
vector<T> faulhaber(const T& up, int n) {
  vector<T> ex(n + 1);
  T e = 1;
  for (int i = 0; i <= n; i++) {
    ex[i] = e;
    e /= i + 1;
  }
  vector<T> den = ex;
  den.erase(den.begin());
  for (auto& d : den) {
    d = -d;
  }
  vector<T> num(n);
  T p = 1;
  for (int i = 0; i < n; i++) {
    p *= up + 1;
    num[i] = ex[i + 1] * (1 - p);
  }
  vector<T> res = num * inverse(den);
  res.resize(n);
  T f = 1;
  for (int i = 0; i < n; i++) {
    res[i] *= f;
    f *= i + 1;
  }
  return res;
}
 
// (x + 1) * (x + 2) * ... * (x + n)
// (can be optimized with precomputed inverses)
template <typename T>
vector<T> sequence(int n) {
  if (n == 0) {
    return {1};
  }
  if (n % 2 == 1) {
    return sequence<T>(n - 1) * vector<T>{n, 1};
  }
  vector<T> c = sequence<T>(n / 2);
  vector<T> a = c;
  reverse(a.begin(), a.end());
  T f = 1;
  for (int i = n / 2 - 1; i >= 0; i--) {
    f *= n / 2 - i;
    a[i] *= f;
  }
  vector<T> b(n / 2 + 1);
  b[0] = 1;
  for (int i = 1; i <= n / 2; i++) {
    b[i] = b[i - 1] * (n / 2) / i;
  }
  vector<T> h = a * b;
  h.resize(n / 2 + 1);
  reverse(h.begin(), h.end());
  f = 1;
  for (int i = 1; i <= n / 2; i++) {
    f /= i;
    h[i] *= f;
  }
  vector<T> res = c * h;
  return res;
}
 
template <typename T>
class OnlineProduct {
 public:
  const vector<T> a;
  vector<T> b;
  vector<T> c;
 
  OnlineProduct(const vector<T>& a_) : a(a_) {}
 
  T add(const T& val) {
    int i = (int) b.size();
    b.push_back(val);
    if ((int) c.size() <= i) {
      c.resize(i + 1);
    }
    c[i] += a[0] * b[i];
    int z = 1;
    while ((i & (z - 1)) == z - 1 && (int) a.size() > z) {
      vector<T> a_mul(a.begin() + z, a.begin() + min(z << 1, (int) a.size()));
      vector<T> b_mul(b.end() - z, b.end());
      vector<T> c_mul = a_mul * b_mul;
      if ((int) c.size() <= i + (int) c_mul.size()) {
        c.resize(i + c_mul.size() + 1);
      }
      for (int j = 0; j < (int) c_mul.size(); j++) {
        c[i + 1 + j] += c_mul[j];
      }
      z <<= 1;
    }
    return c[i];
  }
};
};
 
using namespace algebra;
 
lli T,n,i,j,k,in,cnt,l,r,u,v,x,y;
lli m,c;
string s;
vector<vi> a;
const mint inv2=mint(1)/2;
vector<mint> numerator,denominator;
 
const int BITS=10;
const lli BT=(1<<BITS)-1;
vector<vector<mint>> p0,p1,p2;
void precompute_powers()
{
    p0.resize(BT+1,vector<mint>(c,1));
    p2=p1=p0;
    for(int j=0;j+1<=BT;++j)
    for(int i=0;i<c;++i)
        p0[j+1][i]=p0[j][i]*(i+1);
 
    for(int j=0;j+1<=BT;++j)
    for(int i=0;i<c;++i)
        p1[j+1][i]=p1[j][i]*p0[BT][i]*p0[1][i];
 
    for(int j=0;j+1<=BT;++j)
    for(int i=0;i<c;++i)
        p2[j+1][i]=p2[j][i]*p1[BT][i]*p1[1][i];
}
 
void preComputations()
{
    auto getPoly=[&](const mint x){
        vector<mint> pref(n+1,1),suf(n+1,1),px(n);
        for(int i=0;i<n;++i)
            px[i]=x.pow(2*a[i][1]);
        for(int i=0;i<n;++i)
        {
            const lli ai=a[i][0],pi=a[i][1],b=a[i][2];
            mint cur=1;
            if(b&1)
                cur*=px[i]-1;
            else
                cur*=px[i]+1;
            pref[i+1]=cur*pref[i];
        }
 
        for(int i=n-1;i>=0;--i)
        {
            const lli ai=a[i][0],pi=a[i][1],b=a[i][2];
            mint cur=1;
            if(b&1)
                cur*=px[i]-1;
            else
                cur*=px[i]+1;
            suf[i]=cur*suf[i+1];
        }
 
        mint num=0,den=suf[0];
        for(int i=0;i<n;++i)
        {
            const lli ai=a[i][0],pi=a[i][1],b=a[i][2];
            mint cur=pref[i]*suf[i+1];
            cur*=ai;
            cur*=pi;
            if(b&1)
                cur*=px[i]+1;
            else
                cur*=px[i]-1;
            num+=cur;
        }
        return make_pair(num,den);
    };
 
    vector<mint> x,polynum,polyden;
    for(int i=2;i<=2*c+2;++i)
    {
        x.pb(mint(i));
        const auto ans=getPoly(i);
        polynum.pb(ans.X);
        polyden.pb(ans.Y);
    }
 
    numerator=interpolate(x,polynum);
    denominator=interpolate(x,polyden);
    dbg(numerator);
    dbg(denominator);
}
 
void get_powers(vector<mint> &a,lli k)
{
    a.resize(c);
    for(int i=0;i<c;++i)
        a[i]=p0[k&BT][i]*p1[(k>>BITS)&BT][i]*p2[(k>>(2*BITS))&BT][i];
}
 
vector<mint> aa;
mint solve(lli k)
{
    mint num=0,den=0,sg=0;
    if(k&1)
        sg=-1;
    else
        sg=1;
    get_powers(aa,k-1);
    for(int i=1;i<=c;++i)
    {
        num+=(numerator[i+c]-sg*numerator[-i+c])*aa[i-1];
        den+=(denominator[i+c]+sg*denominator[-i+c])*aa[i-1]*i;
    }
    return num*k/den;
}
 
void testpowers()
{
    c=5;
    k=1<<(2*BITS);
    precompute_powers();
    vector<mint> aa;
    get_powers(aa,k);
    dbg(aa);
    dbg(p2[1][1],mint(2).pow(k),(k>>(2*BITS))&BT,(k>>BITS)&BT,k&BT,k,BT);
    dbg(aa[1]);
    exit(0);
}
 
int main(void) {
    ios_base::sync_with_stdio(false);cin.tie(NULL);
    // freopen("txt.in", "r", stdin);
    // freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
    // testpowers();
    n=readIntLn(1,2000);
    a.resize(n,vi(3));
    lli kmin=0;
    mint k1=0;
    for(auto &v:a)
    {
        v[0]=readIntSp(1,4e8);
        v[1]=readIntSp(0,2e3);
        v[2]=readIntLn(0,1);
        c+=v[1];
        kmin+=v[2];
        k1+=v[2]*v[0];
    }
    assert(1<=c&&c<=2000);
    precompute_powers();
    preComputations();
 
    lli q=readIntLn(1,2e4);
    while(q--)
    {
        k=readIntLn(kmin,4e8);
        assert(kmin%2==k%2);
        if(k==1)
        {
          cout<<k1<<endl;
          // cerr<<k1<<endl;
          continue;
        }
        cout<<solve(k)<<endl;
    }
    aryanc403();
    cerr<<"Completed."<<endl;
    readEOF();
    return 0;
}

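The final expression is much simpler.

For N = 3, ignoring for now the ±1 factor that each term gets from the parity array (see the sign rule at the end):
D = (p0 + p1 + p2)^k + (p0 + p1 - p2)^k + (p0 - p1 + p2)^k + (p0 - p1 - p2)^k
A = (p0 + p1 + p2)^(k-1) * (a0p0 + a1p1 + a2p2) + (p0 + p1 - p2)^(k-1) * (a0p0 + a1p1 - a2p2) + …

Answer = k * A / D (the factor k comes from differentiating the k-th powers)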
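One quick way to check the closed form, including the factor k, is to compare it with the definition of E on tiny inputs. The code below is a floating-point sketch and does not come from either posted solution. The helper names expectedBrute and expectedClosedForm are hypothetical. The weights p_i are used unnormalized, because the common factor R^(-k) cancels in the ratio. The closed form also sums over all 2^N sign patterns instead of keeping p0 positive. This scales the numerator and the denominator by the same factor, so the answer does not change.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

// Brute force straight from the definition of E: enumerate every count vector
// (j_1, ..., j_N) with sum k and j_i % 2 == b_i, weighted by the multinomial.
double expectedBrute(const std::vector<double>& a, const std::vector<double>& p,
                     const std::vector<int>& b, int k) {
  int n = (int)p.size();
  double num = 0, den = 0;
  std::vector<int> j(n);
  std::function<void(int, int)> rec = [&](int i, int left) {
    if (i == n) {
      if (left != 0) return;
      double w = std::tgamma(k + 1.0);  // k!
      double power = 0;
      for (int t = 0; t < n; t++) {
        if (j[t] % 2 != b[t]) return;  // parity constraint violated
        w *= std::pow(p[t], j[t]) / std::tgamma(j[t] + 1.0);
        power += a[t] * j[t];
      }
      num += w * power;
      den += w;
      return;
    }
    for (int c = 0; c <= left; c++) {
      j[i] = c;
      rec(i + 1, left - c);
    }
  };
  rec(0, k);
  return num / den;
}

// Closed form k * A / D summed over all 2^N sign patterns, each term carrying
// the parity sign (-1)^(number of negated p_t with b_t == 1).
double expectedClosedForm(const std::vector<double>& a, const std::vector<double>& p,
                          const std::vector<int>& b, int k) {
  int n = (int)p.size();
  double A = 0, D = 0;
  for (int mask = 0; mask < (1 << n); mask++) {  // bit t set: p_t is negated
    double base = 0, weighted = 0;
    int sign = 1;
    for (int t = 0; t < n; t++) {
      bool neg = (mask >> t) & 1;
      base += neg ? -p[t] : p[t];
      weighted += neg ? -a[t] * p[t] : a[t] * p[t];
      if (neg && b[t] == 1) sign = -sign;
    }
    D += sign * std::pow(base, k);
    A += sign * std::pow(base, k - 1) * weighted;
  }
  return k * A / D;
}
```

For a = (3, 5), p = (1, 2), B = (0, 1) and k = 3, both functions return 93/7 ≈ 13.2857. Without the factor k, the closed form would give only 1/3 of that value.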

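At first glance the series of bases (p0 + p1 + p2 + …), (p0 + p1 - p2 + …), … has 2^(N-1) terms. In practice it is much smaller. Every base is an integer in [-R, R], where R = p0 + p1 + p2 + … ≤ 2000, so terms with the same base can be combined. That leaves at most 2R + 1 distinct bases.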

The only remaining work is computing p_sum^k for every distinct base p_sum in that series, for each query k.
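Suppose the per-base signed sums are already combined: cnt[v], the sum of the ±1 signs of the terms with base v, and wsum[v], the same sum with each sign multiplied by that term's (±a0p0 ± a1p1 ± …). Then each query is a single pass over the bases. The sketch below is hypothetical. It uses the names answerForK and power, the illustrative modulus MOD = 998244353, and plain binary exponentiation, which costs O(R log k) per query. The PowerFactory tables would replace power() with a few lookups. Because power(0, 0) returns 1, the base-0 term is handled correctly at k = 1, which the tester's code treats as a special case.

```cpp
#include <cassert>
#include <vector>

const long long MOD = 998244353;  // hypothetical prime modulus

// Binary exponentiation modulo MOD; power(x, 0) == 1, including x == 0.
long long power(long long base, long long e) {
  base %= MOD;
  if (base < 0) base += MOD;  // bases such as p0 - p1 - p2 can be negative
  long long r = 1;
  while (e > 0) {
    if (e & 1) r = r * base % MOD;
    base = base * base % MOD;
    e >>= 1;
  }
  return r;
}

// cnt[v + R] and wsum[v + R] hold the combined signed sums for base v in [-R, R].
// Returns k * (sum_v wsum[v] * v^(k-1)) / (sum_v cnt[v] * v^k) mod MOD, for k >= 1.
// Assumes the denominator is nonzero mod MOD (k has the same parity as sum B_i).
long long answerForK(int R, const std::vector<long long>& cnt,
                     const std::vector<long long>& wsum, long long k) {
  long long A = 0, D = 0;
  for (int v = -R; v <= R; v++) {
    long long pk1 = power(v, k - 1);                      // v^(k-1)
    long long pk = pk1 * ((v % MOD + MOD) % MOD) % MOD;   // v^k
    A = (A + wsum[v + R] * pk1) % MOD;
    D = (D + cnt[v + R] * pk) % MOD;
  }
  return k % MOD * A % MOD * power(D, MOD - 2) % MOD;  // Fermat inverse of D
}
```

Take a = (3, 5), p = (1, 2), B = (0, 1). At k = 1 this returns 5, since the only valid string is a single c1. At k = 3 it returns the residue of 93/7, which matches direct enumeration.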

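With some preprocessing per base, p_sum^k takes only a few table lookups per query, so I created one PowerFactory for each base. It splits the exponent into four 8-bit chunks and multiplies the matching entries of four 256-entry tables.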

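struct PowerFactory {
  // modp is the problem's prime modulus, assumed to be defined elsewhere.
  explicit PowerFactory(long long base) {
    base %= modp;
    if (base < 0) base += modp;  // bases such as p0 - p1 - p2 can be negative
    long long* tables[4] = {a0, a1, a2, a3};
    long long step = base;  // base^(2^(8t)) while filling table t
    for (int t = 0; t < 4; t++) {
      tables[t][0] = 1;
      for (int i = 1; i < 256; i++) tables[t][i] = tables[t][i - 1] * step % modp;
      step = tables[t][255] * step % modp;  // step^256
    }
  }
  long long get(int n) const {  // returns base^n for 0 <= n < 2^31
    int e0 = (n & 0xff);
    int e1 = ((n >> 8) & 0xff);
    int e2 = ((n >> 16) & 0xff);
    int e3 = ((n >> 24) & 0xff);
    return (((((a0[e0] * a1[e1]) % modp) * a2[e2]) % modp) * a3[e3]) % modp;
  }
  // long long rather than long: long is only 32 bits on some platforms,
  // which would overflow the products above.
  // a0[i] stores base^i
  long long a0[256];
  // a1[i] stores (base^i)^(2^8)
  long long a1[256];
  // a2[i] stores (base^i)^(2^16)
  long long a2[256];
  // a3[i] stores (base^i)^(2^24)
  long long a3[256];
};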

That’s all. You also need to handle the sign (+1 or -1) of each p_sum term. If a term negates T of the p_i, and B[i] is 1 for m of those T indices, multiply the term by (m % 2 == 0 ? 1 : -1). For example, with B = (0, 1, 1), the term (p0 + p1 - p2) gets -1, while (p0 - p1 - p2) gets +1.

@samarth2017 Nice editorial! After trying many times, I finally understood it.

If possible, please elaborate on this line: