CSTREE - Editorial

PROBLEM LINK:

Div1
Div2
Practice

Setter- Lewin Gan
Tester- Suchan Park
Editorialist- Abhishek Pandey

DIFFICULTY:

HARD

PRE-REQUISITES:

Kirchhoff’s Theorem, Cayley’s Formula (Generalized) (also, here), Polynomial Interpolation, some command over algebra and its interpretations, Inclusion-Exclusion

Additionally, you should refer to the following resources to solve the problem-

PROBLEM:

Given a graph G with N nodes, we make a new graph G_b consisting of K disjoint copies of G (a single copy is also called G_s below). Let H be the complement of G_b. Find the number of spanning trees of H.
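A brute-force reference is handy for checking a faster solution on tiny inputs. The sketch below (the class `CsTreeBrute` and its methods are hypothetical, not taken from the setter's code) builds H explicitly from N, the 0-indexed edge list of G and K, and counts its spanning trees with Kirchhoff's matrix-tree theorem: the determinant of the Laplacian of H with one row and column deleted, modulo 998244353 (the setter's modulus). It costs O((NK)^3), so it is only meant for small cases.

```java
/**
 * Brute-force reference for tiny inputs (hypothetical helper, not the setter's code):
 * builds H = complement of G_b explicitly and applies Kirchhoff's matrix-tree theorem.
 */
final class CsTreeBrute {
    static final long MOD = 998244353L;

    static long pow(long b, long e) {
        long r = 1;
        b %= MOD;
        while (e > 0) {
            if ((e & 1) == 1) r = r * b % MOD;
            b = b * b % MOD;
            e >>= 1;
        }
        return r;
    }

    /** Determinant modulo MOD by Gaussian elimination; entries must lie in [0, MOD). */
    static long det(long[][] a) {
        int n = a.length;
        long res = 1;
        for (int col = 0; col < n; col++) {
            int piv = col;
            while (piv < n && a[piv][col] == 0) piv++;
            if (piv == n) return 0;
            if (piv != col) {  // a row swap flips the sign
                long[] t = a[piv]; a[piv] = a[col]; a[col] = t;
                res = (MOD - res) % MOD;
            }
            res = res * a[col][col] % MOD;
            long inv = pow(a[col][col], MOD - 2);
            for (int r = col + 1; r < n; r++) {
                long f = a[r][col] * inv % MOD;
                for (int c = col; c < n; c++) {
                    a[r][c] = (a[r][c] - f * a[col][c] % MOD + MOD) % MOD;
                }
            }
        }
        return res;
    }

    /** n = nodes of G, edges = 0-indexed edges of G, k = number of copies. */
    static long count(int n, int[][] edges, int k) {
        int total = n * k;
        boolean[][] forbidden = new boolean[total][total];  // edges of G_b
        for (int copy = 0; copy < k; copy++)
            for (int[] e : edges) {
                int u = copy * n + e[0], v = copy * n + e[1];
                forbidden[u][v] = forbidden[v][u] = true;
            }
        // Laplacian of H with the last row and column deleted.
        int m = total - 1;
        long[][] lap = new long[m][m];
        for (int u = 0; u < m; u++)
            for (int v = 0; v < total; v++)
                if (u != v && !forbidden[u][v]) {
                    lap[u][u]++;                     // degree of u in H
                    if (v < m) lap[u][v] = MOD - 1;  // -1 for the edge u-v
                }
        return det(lap);
    }
}
```

For example, with G a single edge (N = 2) and K = 2, H is a 4-cycle, which has 4 spanning trees.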

QUICK-EXPLANATION:

Key to AC- Notice the pattern in the degree matrix, or apply generalized Cayley’s formula with inclusion exclusion.

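Notice that H is the complete graph K_{NK} on N*K nodes with some edges missing, namely exactly the edges of G_b. Let's call these missing edges forbidden edges. Using the inclusion-exclusion principle, we get the answer as-
Ans= Spanning \ Trees \ in \ K_{NK} - Spanning \ Trees \ using \ 1 \ forbidden \ edge + Spanning \ Trees \ using \ 2 \ forbidden \ edges - ... , where "using j forbidden edges" means summing, over every set of j forbidden edges, the number of spanning trees that contain all of them. Only sets that form a forest contribute, since no tree contains a cycle.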

Let T=t_1+t_2+...+t_K, where t_i is the number of forbidden edges used in the i'th copy of the graph. Considering only these t_i forbidden edges, the i'th copy splits into N-t_i connected components of sizes s_{i,1},s_{i,2},...,s_{i,N-t_i}. Using the generalized Cayley’s formula (with N*K nodes and N*K-T components), we obtain-

Ans= \sum (-1)^{T} (\prod_{i=1}^{i=K} \prod_{j=1}^{j=N-t_i}s_{i,j}) * (K*N)^{(K*N-T-2)} over all forests of forbidden edges (i.e. over all values of T and all valid ways to split the nodes of each copy into components).

To calculate the sum of \prod_{i=1}^{i=K} \prod_{j=1}^{j=N-t_i}s_{i,j} over all such forests, we find the corresponding quantity for a single copy (i.e. G_s). Since the copies are disjoint, the sum factorizes over the copies, so polynomial multiplication combines the K copies.

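To do so, create the Laplacian matrix L_s of G_s and add a variable x to every diagonal entry. The determinant P_s(x)=\det(L_s+xI) is a polynomial of degree N, and the coefficient of x^c equals the sum, over all spanning forests of G_s with exactly c components, of \prod s_j (where s_j is the size of the j'th component). This is what we need, but for G_b. If we have the corresponding polynomial for G_b, then, since a forest with c components on N*K nodes uses N*K-c forbidden edges, our answer is-

Ans=\sum Coefficient_i*(-1)^i * (K*N)^{(K*N-i-2)} for all i, where Coefficient_i is the coefficient of x^{K*N-i} in the polynomial for G_b.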

Now all that is left is applying a few clever tricks. The polynomial for G_s has degree N, so the easiest way to obtain it is to compute its values at N+1 points, say x=0,1,...,N (each value is one determinant), and then interpolate them to find the coefficients. The corresponding polynomial for G_b is nothing but (polynomial \ for \ G_s)^K: the Laplacian of G_b is block diagonal with K copies of L_s, so its determinant factorizes. The research papers in the pre-requisites state this as well.

However, this polynomial raised to the K'th power has N*K+1 coefficients, which is too many when N*K is large. Hence, the final trick needed for full score is to realize that we can rewrite the formula for the answer as-

Ans=\sum Coefficient_i*(-1)^i * (K*N)^{(K*N-i-2)}=\sum Coefficient_i*(\frac {-1} {K*N})^i * (K*N)^{K*N-2} for all i.
What we did was just split (K*N)^{(K*N-i-2)} into (K*N)^{K*N-2} * (K*N)^{-i} and combine the latter term with (-1)^i. Now, Coefficient_i is the coefficient of x^{K*N-i} in the polynomial P_b(x)=P_s(x)^K for G_b (P_s being the polynomial for G_s), so \sum Coefficient_i*(\frac {-1} {K*N})^i = (\frac {-1} {K*N})^{K*N} * P_b(-K*N). Hence Ans=(-1)^{K*N} * P_s(-K*N)^K / (K*N)^2, and evaluating the polynomial for G_s at the single point x=-K*N and exponentiating is the final trick to obtain full score.

EXPLANATION:

Since some people may want to practice searching for relevant papers online and deriving answers from them, the appropriate research papers are given in the pre-requisites, along with the tester’s method of reading and drawing conclusions from them, so that you can compare your inferences. (It's given in the bonus section under Tester’s Notes.)

This editorial will chiefly focus on the idea of the problem, mainly because for things like polynomial interpolation one can use public libraries (see the setter's and tester’s solutions).

The plan for the editorial is as follows-

  • Discuss Cayley’s Formula
  • Discuss Inclusion Exclusion Principle
  • The polynomial interpretation of the problem
  • How to finally calculate the answer.

There are quite a few observations taken from research papers, so giving proofs for each and every one of them is difficult. I will nevertheless try to give informal intuition wherever possible, but the proofs are lengthy and difficult :frowning: .

Cayley's Formula (generalized)

The very first observation one makes about this question is that the graph H, in one way or another, resembles the complete graph K_{NK} with N*K nodes. It's missing some edges, but that's it.

Also, is it a coincidence that we already know (by Cayley’s Formula) that the number of spanning trees of the complete graph K_n with N nodes is N^{N-2}? Nope, this ain’t a coincidence!

It turns out that there is a generalized version of Cayley’s formula at our disposal as well.

The link in the pre-requisites beautifully discusses the problem of finding the number of spanning trees that use a given set of edges, and proceeds to describe how to use the generalized Cayley’s formula there.

Since what we are going to use in the next section is similar, let's discuss it a little.

What he essentially did there was to combine all nodes connected by these “important” edges into a supernode. If, after this, there are c connected components (supernodes) of sizes s_0,s_1,...,s_{c-1}, then the answer becomes-

Number \ of \ Spanning \ Trees=N^{c-2}*s_0*s_1*...*s_{c-1}

The exact proof and derivation of why it's correct is not needed, but the inference is important: if K_n (with N nodes) is split into c such supernodes (or connected components), the number of spanning trees containing all the important edges is N^{c-2}*\prod_{i=0}^{i=c-1} s_i.

This will form the basis of the inclusion-exclusion formula used in the next section.
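To make the generalized Cayley's formula concrete, here is a small brute-force check (the class `GeneralizedCayleyCheck` is a hypothetical helper, usable only for N up to about 6). It enumerates every (N-1)-edge subset of K_n, keeps the acyclic ones that contain a given forest, and compares that count with N^{c-2}*\prod s_i computed from the forest's component sizes.

```java
/** Brute-force check of the generalized Cayley formula on K_n (hypothetical helper, tiny n only). */
final class GeneralizedCayleyCheck {
    static int find(int[] parent, int x) {
        while (parent[x] != x) x = parent[x] = parent[parent[x]];  // path halving
        return x;
    }

    /** Counts spanning trees of K_n that contain every edge of `forest` (n <= 6). */
    static long bruteForce(int n, int[][] forest) {
        int m = n * (n - 1) / 2;
        int[][] ends = new int[m][];
        int[][] id = new int[n][n];
        int idx = 0;
        for (int u = 0; u < n; u++)
            for (int v = u + 1; v < n; v++) {
                ends[idx] = new int[]{u, v};
                id[u][v] = id[v][u] = idx++;
            }
        int required = 0;  // bitmask of edges every counted tree must contain
        for (int[] f : forest) required |= 1 << id[f[0]][f[1]];
        long count = 0;
        for (int mask = 0; mask < (1 << m); mask++) {
            if (Integer.bitCount(mask) != n - 1 || (mask & required) != required) continue;
            int[] parent = new int[n];
            for (int i = 0; i < n; i++) parent[i] = i;
            boolean acyclic = true;
            for (int e = 0; e < m && acyclic; e++) {
                if ((mask >> e & 1) == 0) continue;
                int a = find(parent, ends[e][0]), b = find(parent, ends[e][1]);
                if (a == b) acyclic = false; else parent[a] = b;
            }
            if (acyclic) count++;  // n - 1 edges without a cycle form a spanning tree
        }
        return count;
    }

    /** Generalized Cayley: n^(c-2) * s_1 * ... * s_c for the components of `forest`. */
    static long formula(int n, int[][] forest) {
        int[] parent = new int[n];
        for (int i = 0; i < n; i++) parent[i] = i;
        for (int[] f : forest) parent[find(parent, f[0])] = find(parent, f[1]);
        int[] size = new int[n];
        for (int v = 0; v < n; v++) size[find(parent, v)]++;
        long product = 1;
        int c = 0;
        for (int v = 0; v < n; v++)
            if (size[v] > 0) { product *= size[v]; c++; }
        if (c == 1) return 1;  // n^(-1) * n: the forest is already a spanning tree
        long result = product;
        for (int i = 0; i < c - 2; i++) result *= n;
        return result;
    }
}
```

For instance, fixing the edges 0-1 and 2-3 in K_5 leaves components of sizes 2, 2, 1, so the formula gives 5^{3-2}*2*2*1 = 20.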

Inclusion-Exclusion

Note that in the previous section we discussed the problem of finding the number of spanning trees of K_n such that a particular set of edges is used. Also, we know that for a complete graph K_n with N nodes, the number of spanning trees is N^{N-2}.

Now, our graph H is a complete graph with some edges missing. Had those edges been there, H would have been the complete graph K_{NK}, the answer for which is easy to calculate.

Now, what if I propose the following-

Take the complete graph with N*K nodes. There is a set of edges, exactly the edges of G_b, which we need to remove from this complete graph to get H. Let's call these edges “forbidden edges”. We know the number of spanning trees of this complete graph, and if we subtract the number of spanning trees that use at least 1 forbidden edge, then we are done! But that quantity is not easy to calculate directly, and this is where inclusion-exclusion comes into play.

Suppose that I am using T forbidden edges as of now. To make things easier, let's split T into T=t_1+t_2+...+t_K, where t_i is the number of forbidden edges used from the i'th copy of G_s (recall that the edges of G_b are the forbidden edges). Now, roughly, I can say that my answer would be something like-

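Ans= Spanning \ Trees \ in \ K_{NK} - Spanning \ Trees \ using \ 1 \ forbidden \ edge + Spanning \ Trees \ using \ 2 \ forbidden \ edges - Spanning \ Trees \ using \ 3 \ forbidden \ edges + ... .
More precisely, every fixed set of T forbidden edges contributes with sign (-1)^T, and only sets that form a forest contribute, since a tree cannot contain a cycle.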

If I am using t_i edges in the i'th copy, that copy splits into N-t_i connected components (or supernodes) of sizes s_{i,1},s_{i,2},...,s_{i,N-t_i}.

Applying the generalized Cayley’s formula to the above, we see that our answer will be of the form-

Ans= \sum (-1)^{T} (\prod_{i=1}^{i=K} \prod_{j=1}^{j=N-t_i}s_{i,j}) * (K*N)^{(K*N-T-2)} over all forests of forbidden edges (i.e. over all valid ways to split the nodes of each copy into components). Note that N*K is the total number of nodes in H, and K*N-T is the number of connected components. Hence, all we did was substitute Total \ Nodes= N*K and Number \ of \ Connected \ Components = N*K-T into the generalized Cayley’s formula discussed in the previous section.
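The formula above can be checked directly on tiny inputs by enumerating every subset of forbidden edges. A sketch under these assumptions (the class `InclusionExclusionCheck` is hypothetical, edges of G are 0-indexed, and plain `long` arithmetic is used, so only very small N*K and edge counts are safe): subsets containing a cycle are skipped, and every forest with T edges contributes (-1)^T * (NK)^{c-2} * \prod s with c = NK - T components.

```java
/** Inclusion-exclusion over forests of forbidden edges, exact arithmetic, tiny inputs only (hypothetical helper). */
final class InclusionExclusionCheck {
    static int find(int[] parent, int x) {
        while (parent[x] != x) x = parent[x] = parent[parent[x]];  // path halving
        return x;
    }

    /** n = nodes of G, edges = 0-indexed edges of G, k = copies; returns spanning trees of H. */
    static long count(int n, int[][] edges, int k) {
        int total = n * k;
        int m = edges.length * k;
        int[][] forbidden = new int[m][];  // the edges of G_b
        for (int copy = 0; copy < k; copy++)
            for (int i = 0; i < edges.length; i++)
                forbidden[copy * edges.length + i] =
                        new int[]{copy * n + edges[i][0], copy * n + edges[i][1]};
        long ans = 0;
        for (int mask = 0; mask < (1 << m); mask++) {
            int[] parent = new int[total];
            for (int v = 0; v < total; v++) parent[v] = v;
            boolean isForest = true;
            for (int e = 0; e < m && isForest; e++) {
                if ((mask >> e & 1) == 0) continue;
                int a = find(parent, forbidden[e][0]), b = find(parent, forbidden[e][1]);
                if (a == b) isForest = false; else parent[a] = b;
            }
            if (!isForest) continue;  // no spanning tree contains a cycle
            int[] size = new int[total];
            for (int v = 0; v < total; v++) size[find(parent, v)]++;
            long term = 1;
            int c = 0;  // number of components, equal to total - T
            for (int v = 0; v < total; v++)
                if (size[v] > 0) { term *= size[v]; c++; }
            if (c == 1) term = 1;  // total^(-1) * total
            else for (int i = 0; i < c - 2; i++) term *= total;
            ans += Integer.bitCount(mask) % 2 == 0 ? term : -term;  // sign (-1)^T
        }
        return ans;
    }
}
```

With G a single edge and K = 2, the four subsets contribute 16 - 8 - 8 + 4 = 4.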

The polynomial interpretation of the problem.

Now, the first major hurdle in our way is to calculate \sum \prod s_{i,j} over all ways to split the nodes into components. One way to find such counts or sums is to express them as the coefficient of some x^i in a polynomial (e.g. generating functions).

But how to deduce such a polynomial?

Look at the general (all-minors) Matrix Tree Theorem. Let's build the Laplacian matrix of the graph G_s. It is known that if we take a set of indices A and delete the i'th row and column for every i \in A, the determinant of the remaining matrix equals the number of spanning forests of the original graph in which every component contains exactly one node of A.

Hence, if we add a variable x to all diagonal elements, the determinant becomes a polynomial in x, and the coefficient of x^i is the sum of these minors over all sets A of size i. A spanning forest with i components of sizes s_1,...,s_i can pick its one node of A from each component in \prod s_j ways, so the coefficient of x^i counts each spanning forest with i components exactly \prod s_j times. This is what we need, but for G_b (as we need to count forests of forbidden edges).

Since our graph G_b is nothing but K disjoint copies of the graph G_s, its Laplacian is block diagonal with K copies of the Laplacian of G_s, so the determinant factorizes and the polynomial for G_b is the polynomial for G_s raised to the K'th power. This observation can also be confirmed from the research papers mentioned in the pre-requisites.

In this case, we can rephrase our formula as-

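Ans=\sum Coefficient_i*(-1)^i * (K*N)^{(K*N-i-2)} over all i, where Coefficient_i is the coefficient of x^{N*K-i} in the polynomial for G_b (a forest using i forbidden edges on N*K nodes has N*K-i components). The only difference between the previous formula and this one is that we substituted \sum \prod s_{i,j} = Coefficient_i.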
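As a small worked check of the polynomial interpretation, take G_s to be a single edge (N = 2) and K = 2, so N*K = 4 and H is the 4-cycle, which has 4 spanning trees. Coefficient_i is read off as the coefficient of x^{4-i} in the polynomial for G_b:

```latex
\begin{aligned}
L_s + xI &= \begin{pmatrix} 1+x & -1 \\ -1 & 1+x \end{pmatrix}, \qquad P_s(x) = (1+x)^2 - 1 = x^2 + 2x,\\
P_b(x) &= P_s(x)^2 = x^4 + 4x^3 + 4x^2,\\
\text{Ans} &= 1\cdot 4^{2} - 4\cdot 4^{1} + 4\cdot 4^{0} = 16 - 16 + 4 = 4.
\end{aligned}
```

In P_s, the x^2 term is the empty forest (two singletons, weight 1*1) and the x term is the edge itself (one component of size 2, weight 2). In P_b, the coefficients of x^4, x^3 and x^2 correspond to using 0, 1 and 2 forbidden edges, with weights 1, 2+2 and 2*2.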

This finishes the chief idea of the question. But there are still some things to see-

  • We just said “let's add x to all diagonal elements of…”. How exactly do we find the polynomial?
  • Once we have the polynomial, how do we raise it to the K'th power to calculate the final answer?

These aspects will be answered in the next section.

Calculating the Final Answer

So far we saw that the answer is equal to-

Ans=\sum Coefficient_i*(-1)^i * (K*N)^{(K*N-i-2)} over all i.

We also have a way of determining the required polynomial for G_b. But this will not be enough for the final subtask: because N*K can be pretty large, we cannot go on summing up every single term. What's more, we still do not have a concrete way of getting the required polynomial for G_s.

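Realize that the polynomial for G_s has degree exactly N (why? x appears only in the N diagonal entries, and the product of the diagonal contributes x^N). Hence, we can find the value of the polynomial at the N+1 points x=0,1,...,N (each value is one N \times N determinant) and interpolate to find the coefficients. This gives us the polynomial.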
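A minimal sketch of the interpolation step, assuming the values P_s(0), P_s(1), ..., P_s(N) are already known modulo 998244353. It uses plain O(N^2) Lagrange interpolation rather than the setter's Newton forward-difference scheme, and the class name `Interpolation` is hypothetical: build Q(x) = x(x-1)...(x-N), divide out each factor (x-i) by synthetic division, and scale the quotient by y_i / \prod_{j \ne i}(i-j).

```java
/**
 * Lagrange interpolation modulo 998244353 (hypothetical helper; the setter uses a
 * Newton forward-difference scheme instead). Given y[i] = P(i) for i = 0..d with
 * 0 <= y[i] < MOD, returns the coefficients of P, lowest degree first, assuming deg P <= d.
 */
final class Interpolation {
    static final long MOD = 998244353L;

    static long pow(long b, long e) {
        long r = 1;
        b %= MOD;
        while (e > 0) {
            if ((e & 1) == 1) r = r * b % MOD;
            b = b * b % MOD;
            e >>= 1;
        }
        return r;
    }

    static long[] fromValues(long[] y) {
        int d = y.length - 1;
        // q(x) = x (x - 1) ... (x - d), coefficients lowest degree first
        long[] q = new long[d + 2];
        q[0] = 1;
        for (int j = 0; j <= d; j++) {  // multiply q by (x - j)
            for (int t = j + 1; t >= 1; t--) q[t] = (q[t - 1] + q[t] * (MOD - j)) % MOD;
            q[0] = q[0] * (MOD - j) % MOD;
        }
        long[] coef = new long[d + 1];
        long[] part = new long[d + 1];
        for (int i = 0; i <= d; i++) {
            // part = q(x) / (x - i) by synthetic division; the remainder is 0
            long carry = 0;
            for (int t = d + 1; t >= 1; t--) {
                carry = (q[t] + carry * i) % MOD;
                part[t - 1] = carry;
            }
            long denom = 1;  // prod over j != i of (i - j)
            for (int j = 0; j <= d; j++)
                if (j != i) denom = denom * ((i - j + MOD) % MOD) % MOD;
            long scale = y[i] % MOD * pow(denom, MOD - 2) % MOD;
            for (int t = 0; t <= d; t++) coef[t] = (coef[t] + part[t] * scale) % MOD;
        }
        return coef;
    }
}
```

For G_s a single edge, the values 0, 3, 8 at x = 0, 1, 2 interpolate back to x^2 + 2x, i.e. the coefficient array {0, 2, 1}.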

But raising it to the K'th power is still a big deal, especially considering that the result has N*K+1 coefficients. To address that, we can use the following trick.

Note that, Ans=\sum Coefficient_i*(-1)^i * (K*N)^{(K*N-i-2)} over all i

Split the term of (K*N)^{K*N-i-2}=(K*N)^{K*N-2}*(K*N)^{-i}, and combine the term of (K*N)^{-i} with the (-1)^i term.

This turns our required value into-

Ans=\sum Coefficient_i*(\frac {-1} {K*N})^i * (K*N)^{K*N-2}.

Here, Coefficient_i are the coefficients, and the remaining terms are the factors we multiply each coefficient by.

Since (K*N)^{K*N-2} does not depend on i, let's take it out to get-

Ans=(K*N)^{K*N-2} * \sum Coefficient_i*(\frac {-1} {K*N})^i .

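Note that Coefficient_i is the coefficient of x^{K*N-i} in the polynomial P_b(x)=P_s(x)^K for G_b (P_s being the polynomial for G_s), because a forest using i forbidden edges has K*N-i components. Therefore \sum Coefficient_i*(\frac {-1} {K*N})^i = (\frac {-1} {K*N})^{K*N} * P_b(-K*N), and multiplying it with the other term finally yields the required answer! :smiley:

Ans=(-1)^{K*N} * \frac{P_s(-K*N)^K}{(K*N)^2}

So we only need the polynomial for G_s evaluated at the single point x=-K*N, raised to the K'th power and divided by (K*N)^2, with the sign fixed by the parity of K*N. This is exactly what the setter's code computes.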
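Putting everything together, here is a compact sketch of the whole computation (the class `CsTreeFormula` is hypothetical; edges of G are 0-indexed, the modulus is 998244353, and N*K is assumed not to be a multiple of the modulus so that (N*K)^2 is invertible). One simplification relative to the setter's code: since P_s(x)=\det(L_s+xI), the polynomial's value at x=-N*K equals the determinant of L_s-N*K*I, so this sketch computes that single determinant directly instead of interpolating first.

```java
/**
 * Sketch of Ans = (-1)^(NK) * P_s(-NK)^K / (NK)^2 modulo 998244353
 * (hypothetical helper; assumes N*K is not a multiple of the modulus).
 */
final class CsTreeFormula {
    static final long MOD = 998244353L;

    static long pow(long b, long e) {
        long r = 1;
        b %= MOD;
        while (e > 0) {
            if ((e & 1) == 1) r = r * b % MOD;
            b = b * b % MOD;
            e >>= 1;
        }
        return r;
    }

    /** Determinant modulo MOD by Gaussian elimination; entries must lie in [0, MOD). */
    static long det(long[][] a) {
        int n = a.length;
        long res = 1;
        for (int col = 0; col < n; col++) {
            int piv = col;
            while (piv < n && a[piv][col] == 0) piv++;
            if (piv == n) return 0;
            if (piv != col) {  // a row swap flips the sign
                long[] t = a[piv]; a[piv] = a[col]; a[col] = t;
                res = (MOD - res) % MOD;
            }
            res = res * a[col][col] % MOD;
            long inv = pow(a[col][col], MOD - 2);
            for (int r = col + 1; r < n; r++) {
                long f = a[r][col] * inv % MOD;
                for (int c = col; c < n; c++)
                    a[r][c] = (a[r][c] - f * a[col][c] % MOD + MOD) % MOD;
            }
        }
        return res;
    }

    /** n = nodes of G, edges = 0-indexed edges of G, k = number of copies. */
    static long count(int n, int[][] edges, long k) {
        long nk = (long) n * k;
        long nkMod = nk % MOD;
        long x = (MOD - nkMod) % MOD;  // evaluation point x = -N*K, reduced mod MOD
        long[][] m = new long[n][n];
        for (int i = 0; i < n; i++) m[i][i] = x;  // the "+ x" on every diagonal entry
        for (int[] e : edges) {  // add the Laplacian of one copy G_s
            int a = e[0], b = e[1];
            m[a][a] = (m[a][a] + 1) % MOD;
            m[b][b] = (m[b][b] + 1) % MOD;
            m[a][b] = (m[a][b] + MOD - 1) % MOD;
            m[b][a] = (m[b][a] + MOD - 1) % MOD;
        }
        long ps = det(m);  // P_s(-N*K) = det(L_s - N*K*I)
        long ans = pow(ps, k) * pow(nkMod * nkMod % MOD, MOD - 2) % MOD;
        if (nk % 2 == 1) ans = (MOD - ans) % MOD;  // the factor (-1)^(N*K)
        return ans;
    }
}
```

For example, N = 3 with no edges and K = 1 makes H the complete graph K_3, and the formula gives 3 = 3^{3-2}; N = 2 with one edge and K = 2 gives 4.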

SOLUTION

Setter
import java.io.OutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.BufferedWriter;
import java.io.Writer;
import java.io.OutputStreamWriter;
import java.util.InputMismatchException;
 
/**
 * Built using CHelper plug-in
 * Actual solution is at the top
 *
 * @author lewin
 */
public class Main {
    public static void main(String[] args) {
        InputStream inputStream = System.in;
        OutputStream outputStream = System.out;
        InputReader in = new InputReader(inputStream);
        OutputWriter out = new OutputWriter(outputStream);
        CSTree2 solver = new CSTree2();
        int testCount = Integer.parseInt(in.next());
        for (int i = 1; i <= testCount; i++)
            solver.solve(i, in, out);
        out.close();
    }
 
    static class CSTree2 {
        public static int mod = 998244353;
 
        public void solve(int testNumber, InputReader in, OutputWriter out) {
            int n = in.nextInt(), m = in.nextInt(), k = in.nextInt();
            // tmat = Laplacian matrix of a single copy G_s
            long[][] tmat = new long[n][n];
 
            for (int i = 0; i < m; i++) {
                int a = in.nextInt() - 1, b = in.nextInt() - 1;
                tmat[a][b]--;
                tmat[b][a]--;
                tmat[a][a]++;
                tmat[b][b]++;
            }
            // poly[i] = det(L + i*I) = P_s(i) for i = 0..n (n+1 points for a degree-n polynomial)
            long[] poly = new long[n + 1];
            long[][] r = new long[n][n];
            for (int i = 0; i <= n; i++) {
                for (int j = 0; j < n; j++) System.arraycopy(tmat[j], 0, r[j], 0, n);
                for (int j = 0; j < n; j++) r[j][j] += i;
                poly[i] = Determinant.det(r, mod);
            }
            // recover the coefficients of P_s from its values
            poly = Polynomial.interpolation(poly, mod);
            long c = (long) n * k;
            // ans = P_s(-nk)^k / (nk)^2
            long ans = Utils.mod_exp(Polynomial.eval(poly, mod - n * k, mod), k, mod) * Utils.inv(c * c % mod, mod) % mod;
            // multiply by (-1)^(nk)
            if ((n * k) % 2 == 1) ans = (mod - ans) % mod;
            out.println(ans);
        }
 
    }
 
    static class Polynomial {
        public static long[] interpolation(long[] y, int mod) {
            int n = y.length;
            long[] a = new long[n], b = new long[n];
            a[0] = y[0];
            b[0] = 1;
            long[] inv = new long[n + 1];
            inv[1] = 1;
            for (int i = 2; i <= n; i++) {
                inv[i] = (mod - mod / i) * inv[mod % i] % mod;
            }
            for (int i = 1; i < n; i++) {
                for (int j = 0; j + i < n; j++) y[j] = (y[j + 1] - y[j] + mod) % mod * inv[i] % mod;
                for (int j = i; j > 0; j--) b[j] = (b[j - 1] + (mod - i + 1) * b[j]) % mod;
                b[0] = b[0] * (mod - i + 1) % mod;
                for (int j = 0; j <= i; j++) {
                    a[j] = (a[j] + b[j] * y[0]) % mod;
                }
            }
            return a;
        }
 
        public static long eval(long[] y, long pos, long mod) {
            long res = 0;
            for (int i = y.length - 1; i >= 0; i--) {
                res = (res * pos + y[i]) % mod;
            }
            return res;
        }
 
    }
 
    static class InputReader {
        private InputStream stream;
        private byte[] buf = new byte[1 << 16];
        private int curChar;
        private int numChars;
 
        public InputReader(InputStream stream) {
            this.stream = stream;
        }
 
        public int read() {
            if (this.numChars == -1) {
                throw new InputMismatchException();
            } else {
                if (this.curChar >= this.numChars) {
                    this.curChar = 0;
 
                    try {
                        this.numChars = this.stream.read(this.buf);
                    } catch (IOException var2) {
                        throw new InputMismatchException();
                    }
 
                    if (this.numChars <= 0) {
                        return -1;
                    }
                }
 
                return this.buf[this.curChar++];
            }
        }
 
        public int nextInt() {
            int c;
            for (c = this.read(); isSpaceChar(c); c = this.read()) {
                ;
            }
 
            byte sgn = 1;
            if (c == 45) {
                sgn = -1;
                c = this.read();
            }
 
            int res = 0;
 
            while (c >= 48 && c <= 57) {
                res *= 10;
                res += c - 48;
                c = this.read();
                if (isSpaceChar(c)) {
                    return res * sgn;
                }
            }
 
            throw new InputMismatchException();
        }
 
        public String next() {
            int c;
            while (isSpaceChar(c = this.read())) {
                ;
            }
 
            StringBuilder result = new StringBuilder();
            result.appendCodePoint(c);
 
            while (!isSpaceChar(c = this.read())) {
                result.appendCodePoint(c);
            }
 
            return result.toString();
        }
 
        public static boolean isSpaceChar(int c) {
            return c == 32 || c == 10 || c == 13 || c == 9 || c == -1;
        }
 
    }
 
    static class OutputWriter {
        private final PrintWriter writer;
 
        public OutputWriter(OutputStream outputStream) {
            writer = new PrintWriter(new BufferedWriter(new OutputStreamWriter(outputStream)));
        }
 
        public OutputWriter(Writer writer) {
            this.writer = new PrintWriter(writer);
        }
 
        public void close() {
            writer.close();
        }
 
        public void println(long i) {
            writer.println(i);
        }
 
    }
 
    static class Determinant {
        public static long det(long[][] c, int mod) {
            int n = c.length;
            long res = 1;
            for (int p = 0; p < n; p++) {
                int pi = p;
                while (pi < n && c[pi][p] == 0) pi++;
                if (pi == n) return 0;
                if (p != pi) {
                    res = (mod - res) % mod;
                    long[] t1 = c[pi];
                    c[pi] = c[p];
                    c[p] = t1;
                }
                res = res * c[p][p] % mod;
                for (int i = p + 1; i < n; i++) {
                    c[i][p] = c[i][p] * inv(c[p][p], mod) % mod;
                    for (int j = p + 1; j < n; j++) {
                        c[i][j] -= c[p][j] * c[i][p];
                        c[i][j] %= mod;
                        if (c[i][j] < 0) c[i][j] += mod;
                    }
                }
            }
            return res;
        }
 
        public static long pow(long b, long e, long mod) {
            long r = 1;
            while (e > 0) {
                if ((e & 1) == 1) r = r * b % mod;
                b = b * b % mod;
                e >>= 1;
            }
            return r;
        }
 
        public static long inv(long a, long mod) {
            return pow(a, mod - 2, mod);
        }
 
    }
 
    static class Utils {
        public static long inv(long N, long M) {
            long x = 0, lastx = 1, y = 1, lasty = 0, q, t, a = N, b = M;
            while (b != 0) {
                q = a / b;
                t = a % b;
                a = b;
                b = t;
                t = x;
                x = lastx - q * x;
                lastx = t;
                t = y;
                y = lasty - q * y;
                lasty = t;
            }
            return (lastx + M) % M;
        }
 
        public static long mod_exp(long b, long e, long mod) {
            long res = 1;
            while (e > 0) {
                if ((e & 1) == 1)
                    res = (res * b) % mod;
                b = (b * b) % mod;
                e >>= 1;
            }
            return res % mod;
        }
 
    }
}
 
 
Tester
// copied from https://github.com/e-maxx-eng/e-maxx-eng-aux/blob/master/src/polynomial.cpp
// There are some changes in constants
// License: https://github.com/e-maxx-eng/e-maxx-eng-aux/blob/master/LICENSE
 
#include <bits/stdc++.h>
 
using namespace std;
namespace algebra {
  const int inf = 1e9;
  const int magic = 500; // threshold for sizes to run the naive algo
  
  namespace fft {
    const int maxn = 1 << 19;
 
    typedef double ftype;
    typedef complex<ftype> point;
 
    point w[maxn];
    const ftype pi = acos(-1);
    bool initiated = 0;
    void init() {
      if(!initiated) {
        for(int i = 1; i < maxn; i *= 2) {
          for(int j = 0; j < i; j++) {
            w[i + j] = polar(ftype(1), pi * j / i);
          }
        }
        initiated = 1;
      }
    }
    template<typename T>
    void fft(T *in, point *out, int n, int k = 1) {
      if(n == 1) {
        *out = *in;
      } else {
        n /= 2;
        fft(in, out, n, 2 * k);
        fft(in + k, out + n, n, 2 * k);
        for(int i = 0; i < n; i++) {
          auto t = out[i + n] * w[i + n];
          out[i + n] = out[i] - t;
          out[i] += t;
        }
      }
    }
    
    template<typename T>
    void mul_slow(vector<T> &a, const vector<T> &b) {
      vector<T> res(a.size() + b.size() - 1);
      for(size_t i = 0; i < a.size(); i++) {
        for(size_t j = 0; j < b.size(); j++) {
          res[i + j] += a[i] * b[j];
        }
      }
      a = res;
    }
    
    
    template<typename T>
    void mul(vector<T> &a, const vector<T> &b) {
      if(min(a.size(), b.size()) < magic) {
        mul_slow(a, b);
        return;
      }
      init();
      static const int shift = 15, mask = (1 << shift) - 1;
      size_t n = a.size() + b.size() - 1;
      while(__builtin_popcount(n) != 1) {
        n++;
      }
      a.resize(n);
      static point A[maxn], B[maxn];
      static point C[maxn], D[maxn];
      for(size_t i = 0; i < n; i++) {
        A[i] = point(a[i] & mask, a[i] >> shift);
        if(i < b.size()) {
          B[i] = point(b[i] & mask, b[i] >> shift);
        } else {
          B[i] = 0;
        }
      }
      fft(A, C, n); fft(B, D, n);
      for(size_t i = 0; i < n; i++) {
        point c0 = C[i] + conj(C[(n - i) % n]);
        point c1 = C[i] - conj(C[(n - i) % n]);
        point d0 = D[i] + conj(D[(n - i) % n]);
        point d1 = D[i] - conj(D[(n - i) % n]);
        A[i] = c0 * d0 - point(0, 1) * c1 * d1;
        B[i] = c0 * d1 + d0 * c1;
      }
      fft(A, C, n); fft(B, D, n);
      reverse(C + 1, C + n);
      reverse(D + 1, D + n);
      int t = 4 * n;
      for(size_t i = 0; i < n; i++) {
        int64_t A0 = llround(real(C[i]) / t);
        T A1 = llround(imag(D[i]) / t);
        T A2 = llround(imag(C[i]) / t);
        a[i] = A0 + (A1 << shift) + (A2 << 2 * shift);
      }
      return;
    }
  }
  template<typename T>
  T bpow(T x, size_t n) {
    return n ? n % 2 ? x * bpow(x, n - 1) : bpow(x * x, n / 2) : T(1);
  }
  template<typename T>
  T bpow(T x, size_t n, T m) {
    return n ? n % 2 ? x * bpow(x, n - 1, m) % m : bpow(x * x % m, n / 2, m) : T(1);
  }
  template<typename T>
  T gcd(const T &a, const T &b) {
    return b == T(0) ? a : gcd(b, a % b);
  }
  template<typename T>
  T nCr(T n, int r) { // runs in O(r)
    T res(1);
    for(int i = 0; i < r; i++) {
      res *= (n - T(i));
      res /= (i + 1);
    }
    return res;
  }
 
  template<int m>
  struct modular {
    int64_t r;
    modular() : r(0) {}
    modular(int64_t rr) : r(rr) {if(abs(r) >= m) r %= m; if(r < 0) r += m;}
    modular inv() const {return bpow(*this, m - 2);}
    modular operator * (const modular &t) const {return (r * t.r) % m;}
    modular operator / (const modular &t) const {return *this * t.inv();}
    modular operator += (const modular &t) {r += t.r; if(r >= m) r -= m; return *this;}
    modular operator -= (const modular &t) {r -= t.r; if(r < 0) r += m; return *this;}
    modular operator + (const modular &t) const {return modular(*this) += t;}
    modular operator - (const modular &t) const {return modular(*this) -= t;}
    modular operator *= (const modular &t) {return *this = *this * t;}
    modular operator /= (const modular &t) {return *this = *this / t;}
    
    bool operator == (const modular &t) const {return r == t.r;}
    bool operator != (const modular &t) const {return r != t.r;}
    
    operator int64_t() const {return r;}
  };
  template<int T>
  istream& operator >> (istream &in, modular<T> &x) {
    return in >> x.r;
  }
  
  
  template<typename T>
  struct poly {
    vector<T> a;
    
    void normalize() { // get rid of leading zeroes
      while(!a.empty() && a.back() == T(0)) {
        a.pop_back();
      }
    }
    
    poly(){}
    poly(T a0) : a{a0}{normalize();}
    poly(vector<T> t) : a(t){normalize();}
    
    poly operator += (const poly &t) {
      a.resize(max(a.size(), t.a.size()));
      for(size_t i = 0; i < t.a.size(); i++) {
        a[i] += t.a[i];
      }
      normalize();
      return *this;
    }
    poly operator -= (const poly &t) {
      a.resize(max(a.size(), t.a.size()));
      for(size_t i = 0; i < t.a.size(); i++) {
        a[i] -= t.a[i];
      }
      normalize();
      return *this;
    }
    poly operator + (const poly &t) const {return poly(*this) += t;}
    poly operator - (const poly &t) const {return poly(*this) -= t;}
    
    poly mod_xk(size_t k) const { // get same polynomial mod x^k
      k = min(k, a.size());
      return vector<T>(begin(a), begin(a) + k);
    }
    poly mul_xk(size_t k) const { // multiply by x^k
      poly res(*this);
      res.a.insert(begin(res.a), k, 0);
      return res;
    }
    poly div_xk(size_t k) const { // divide by x^k, dropping coefficients
      k = min(k, a.size());
      return vector<T>(begin(a) + k, end(a));
    }
    poly substr(size_t l, size_t r) const { // return mod_xk(r).div_xk(l)
      l = min(l, a.size());
      r = min(r, a.size());
      return vector<T>(begin(a) + l, begin(a) + r);
    }
    poly inv(size_t n) const { // get inverse series mod x^n
      assert(!is_zero());
      poly ans = a[0].inv();
      size_t a = 1;
      while(a < n) {
        poly C = (ans * mod_xk(2 * a)).substr(a, 2 * a);
        ans -= (ans * C).mod_xk(a).mul_xk(a);
        a *= 2;
      }
      return ans.mod_xk(n);
    }
    
    poly operator *= (const poly &t) {fft::mul(a, t.a); normalize(); return *this;}
    poly operator * (const poly &t) const {return poly(*this) *= t;}
    
    poly reverse(size_t n, bool rev = 0) const { // reverses and leaves only n terms
      poly res(*this);
      if(rev) { // If rev = 1 then tail goes to head
        res.a.resize(max(n, res.a.size()));
      }
      std::reverse(res.a.begin(), res.a.end());
      return res.mod_xk(n);
    }
    
    pair<poly, poly> divmod_slow(const poly &b) const { // when divisor or quotient is small
      vector<T> A(a);
      vector<T> res;
      while(A.size() >= b.a.size()) {
        res.push_back(A.back() / b.a.back());
        if(res.back() != T(0)) {
          for(size_t i = 0; i < b.a.size(); i++) {
            A[A.size() - i - 1] -= res.back() * b.a[b.a.size() - i - 1];
          }
        }
        A.pop_back();
      }
      std::reverse(begin(res), end(res));
      return {res, A};
    }
    
    pair<poly, poly> divmod(const poly &b) const { // returns quotient and remainder of a mod b
      if(deg() < b.deg()) {
        return {poly{0}, *this};
      }
      int d = deg() - b.deg();
      if(min(d, b.deg()) < magic) {
        return divmod_slow(b);
      }
      poly D = (reverse(d + 1) * b.reverse(d + 1).inv(d + 1)).mod_xk(d + 1).reverse(d + 1, 1);
      return {D, *this - D * b};
    }
    
    poly operator / (const poly &t) const {return divmod(t).first;}
    poly operator % (const poly &t) const {return divmod(t).second;}
    poly operator /= (const poly &t) {return *this = divmod(t).first;}
    poly operator %= (const poly &t) {return *this = divmod(t).second;}
    poly operator *= (const T &x) {
      for(auto &it: a) {
        it *= x;
      }
      normalize();
      return *this;
    }
    poly operator /= (const T &x) {
      for(auto &it: a) {
        it /= x;
      }
      normalize();
      return *this;
    }
    poly operator * (const T &x) const {return poly(*this) *= x;}
    poly operator / (const T &x) const {return poly(*this) /= x;}
    
    void print() const {
      for(auto it: a) {
        cout << it << ' ';
      }
      cout << endl;
    }
    T eval(T x) const { // evaluates in single point x
      T res(0);
      for(int i = int(a.size()) - 1; i >= 0; i--) {
        res *= x;
        res += a[i];
      }
      return res;
    }
    
    T& lead() { // leading coefficient
      return a.back();
    }
    int deg() const { // degree
      return a.empty() ? -inf : a.size() - 1;
    }
    bool is_zero() const { // is polynomial zero
      return a.empty();
    }
    T operator [](int idx) const {
      return idx >= (int)a.size() || idx < 0 ? T(0) : a[idx];
    }
    
    T& coef(size_t idx) { // mutable reference at coefficient
      return a[idx];
    }
    bool operator == (const poly &t) const {return a == t.a;}
    bool operator != (const poly &t) const {return a != t.a;}
    
    poly deriv() { // calculate derivative
      vector<T> res;
      for(int i = 1; i <= deg(); i++) {
        res.push_back(T(i) * a[i]);
      }
      return res;
    }
    poly integr() { // calculate integral with C = 0
      vector<T> res = {0};
      for(int i = 0; i <= deg(); i++) {
        res.push_back(a[i] / T(i + 1));
      }
      return res;
    }
    size_t leading_xk() const { // Let p(x) = x^k * t(x), return k
      if(is_zero()) {
        return inf;
      }
      int res = 0;
      while(a[res] == T(0)) {
        res++;
      }
      return res;
    }
    poly log(size_t n) { // calculate log p(x) mod x^n
      assert(a[0] == T(1));
      return (deriv().mod_xk(n) * inv(n)).integr().mod_xk(n);
    }
    poly exp(size_t n) { // calculate exp p(x) mod x^n
      if(is_zero()) {
        return T(1);
      }
      assert(a[0] == T(0));
      poly ans = T(1);
      size_t a = 1;
      while(a < n) {
        poly C = ans.log(2 * a).div_xk(a) - substr(a, 2 * a);
        ans -= (ans * C).mod_xk(a).mul_xk(a);
        a *= 2;
      }
      return ans.mod_xk(n);
      
    }
    poly pow_slow(size_t k, size_t n) { // if k is small
      return k ? k % 2 ? (*this * pow_slow(k - 1, n)).mod_xk(n) : (*this * *this).mod_xk(n).pow_slow(k / 2, n) : T(1);
    }
    poly pow(size_t k, size_t n) { // calculate p^k(n) mod x^n
      if(is_zero()) {
        return *this;
      }
      if(k < magic) {
        return pow_slow(k, n);
      }
      int i = leading_xk();
      T j = a[i];
      poly t = div_xk(i) / j;
      return bpow(j, k) * (t.log(n) * T(k)).exp(n).mul_xk(i * k).mod_xk(n);
    }
    poly mulx(T x) { // component-wise multiplication with x^k
      T cur = 1;
      poly res(*this);
      for(int i = 0; i <= deg(); i++) {
        res.coef(i) *= cur;
        cur *= x;
      }
      return res;
    }
    poly mulx_sq(T x) { // component-wise multiplication with x^{k^2}
      T cur = x;
      T total = 1;
      T xx = x * x;
      poly res(*this);
      for(int i = 0; i <= deg(); i++) {
        res.coef(i) *= total;
        total *= cur;
        cur *= xx;
      }
      return res;
    }
    vector<T> chirpz_even(T z, int n) { // P(1), P(z^2), P(z^4), ..., P(z^2(n-1))
      int m = deg();
      if(is_zero()) {
        return vector<T>(n, 0);
      }
      vector<T> vv(m + n);
      T zi = z.inv();
      T zz = zi * zi;
      T cur = zi;
      T total = 1;
      for(int i = 0; i <= max(n - 1, m); i++) {
        if(i <= m) {vv[m - i] = total;}
        if(i < n) {vv[m + i] = total;}
        total *= cur;
        cur *= zz;
      }
      poly w = (mulx_sq(z) * vv).substr(m, m + n).mulx_sq(z);
      vector<T> res(n);
      for(int i = 0; i < n; i++) {
        res[i] = w[i];
      }
      return res;
    }
    vector<T> chirpz(T z, int n) { // P(1), P(z), P(z^2), ..., P(z^(n-1))
      auto even = chirpz_even(z, (n + 1) / 2);
      auto odd = mulx(z).chirpz_even(z, n / 2);
      vector<T> ans(n);
      for(int i = 0; i < n / 2; i++) {
        ans[2 * i] = even[i];
        ans[2 * i + 1] = odd[i];
      }
      if(n % 2 == 1) {
        ans[n - 1] = even.back();
      }
      return ans;
    }
    template<typename iter>
    vector<T> eval(vector<poly> &tree, int v, iter l, iter r) { // auxiliary evaluation function
      if(r - l == 1) {
        return {eval(*l)};
      } else {
        auto m = l + (r - l) / 2;
        auto A = (*this % tree[2 * v]).eval(tree, 2 * v, l, m);
        auto B = (*this % tree[2 * v + 1]).eval(tree, 2 * v + 1, m, r);
        A.insert(end(A), begin(B), end(B));
        return A;
      }
    }
    vector<T> eval(vector<T> x) { // evaluate polynomial in (x1, ..., xn)
      int n = x.size();
      if(is_zero()) {
        return vector<T>(n, T(0));
      }
      vector<poly> tree(4 * n);
      build(tree, 1, begin(x), end(x));
      return eval(tree, 1, begin(x), end(x));
    }
    template<typename iter>
    poly inter(vector<poly> &tree, int v, iter l, iter r, iter ly, iter ry) { // auxiliary interpolation function
      if(r - l == 1) {
        return {*ly / a[0]};
      } else {
        auto m = l + (r - l) / 2;
        auto my = ly + (ry - ly) / 2;
        auto A = (*this % tree[2 * v]).inter(tree, 2 * v, l, m, ly, my);
        auto B = (*this % tree[2 * v + 1]).inter(tree, 2 * v + 1, m, r, my, ry);
        return A * tree[2 * v + 1] + B * tree[2 * v];
      }
    }
  };
  template<typename T>
  poly<T> operator * (const T& a, const poly<T>& b) {
    return b * a;
  }
  
  template<typename T>
  poly<T> xk(int k) { // return x^k
    return poly<T>{1}.mul_xk(k);
  }
 
  template<typename T>
  T resultant(poly<T> a, poly<T> b) { // computes resultant of a and b
    if(b.is_zero()) {
      return 0;
    } else if(b.deg() == 0) {
      return bpow(b.lead(), a.deg());
    } else {
      int pw = a.deg();
      a %= b;
      pw -= a.deg();
      T mul = bpow(b.lead(), pw) * T((b.deg() & a.deg() & 1) ? -1 : 1);
      T ans = resultant(b, a);
      return ans * mul;
    }
  }
  template<typename iter>
  poly<typename iter::value_type> kmul(iter L, iter R) { // computes (x-a1)(x-a2)...(x-an) without building tree
    if(R - L == 1) {
      return vector<typename iter::value_type>{-*L, 1};
    } else {
      iter M = L + (R - L) / 2;
      return kmul(L, M) * kmul(M, R);
    }
  }
  template<typename T, typename iter>
  poly<T> build(vector<poly<T>> &res, int v, iter L, iter R) { // builds evaluation tree for (x-a1)(x-a2)...(x-an)
    if(R - L == 1) {
      return res[v] = vector<T>{-*L, 1};
    } else {
      iter M = L + (R - L) / 2;
      return res[v] = build(res, 2 * v, L, M) * build(res, 2 * v + 1, M, R);
    }
  }
  template<typename T>
  poly<T> inter(vector<T> x, vector<T> y) { // interpolates minimum polynomial from (xi, yi) pairs
    int n = x.size();
    vector<poly<T>> tree(4 * n);
    return build(tree, 1, begin(x), end(x)).deriv().inter(tree, 1, begin(x), end(x), begin(y), end(y));
  }
};
 
using namespace algebra;
 
const int mod = 998244353;
 
typedef modular<mod> base;
typedef poly<base> polyn;
 
 
struct LinearRecurrence {
  using int64 = long long;
  using vec = std::vector<int64>;
 
  static void extand(vec &a, size_t d, int64 value = 0) {
    if (d <= a.size()) return;
    a.resize(d, value);
  }
  static vec BerlekampMassey(const vec &s, int64 mod) {
    std::function<int64(int64)> inverse = [&](int64 a) {
      return a == 1 ? 1 : (int64)(mod - mod / a) * inverse(mod % a) % mod;
    };
    vec A = {1}, B = {1};
    int64 b = s[0];
    for (size_t i = 1, m = 1; i < s.size(); ++i, m++) {
      int64 d = 0;
      for (size_t j = 0; j < A.size(); ++j) {
        d += A[j] * s[i - j] % mod;
      }
      if (!(d %= mod)) continue;
      if (2 * (A.size() - 1) <= i) {
        auto temp = A;
        extand(A, B.size() + m);
        int64 coef = d * inverse(b) % mod;
        for (size_t j = 0; j < B.size(); ++j) {
          A[j + m] -= coef * B[j] % mod;
          if (A[j + m] < 0) A[j + m] += mod;
        }
        B = temp, b = d, m = 0;
      } else {
        extand(A, B.size() + m);
        int64 coef = d * inverse(b) % mod;
        for (size_t j = 0; j < B.size(); ++j) {
          A[j + m] -= coef * B[j] % mod;
          if (A[j + m] < 0) A[j + m] += mod;
        }
      }
    }
    return A;
  }
  
  vec init, trans;
  int m;
 
  LinearRecurrence(const vec &s, const vec &c):
    init(s), trans(c), m(s.size()) {}
  LinearRecurrence(const vec &s) {
    vec A = BerlekampMassey(s, mod);
    if (A.empty()) A = {0};
    m = A.size() - 1;
    trans.resize(m);
    for (int i = 0; i < m; ++i) {
      trans[i] = (mod - A[i + 1]) % mod;
    }
    std::reverse(trans.begin(), trans.end());
    init = {s.begin(), s.begin() + m};
  }
};
 
// copied from https://github.com/zimpha/algorithmic-library/blob/master/cpp/mathematics/det-mod.cc
 
base det_mod(vector<vector<base>> mat) {
  int n = mat.size();
  base ret = 1;
  for (int i = 0; i < n; ++i) {
    for (int j = i + 1; j < n; ++j)
      for (; mat[j][i]; ret = -ret) {
        base t = mat[i][i] / mat[j][i];
        for (int k = i; k < n; ++k) {
          mat[i][k] = mat[i][k] - mat[j][k] * t;
          std::swap(mat[j][k], mat[i][k]);
        }
      }
    if (mat[i][i].r == 0) return 0;
    ret = ret * mat[i][i];
  }
  return ret;
}
 
std::mt19937 rng(0x173819);
 
void solve() {
  int N, M, K; cin >> N >> M >> K;
 
  vector<vector<base>> L(N, vector<base>(N));
  while(M--) {
    int u, v; cin >> u >> v;
    u -= 1;
    v -= 1;
    L[u][v] -= 1;
    L[v][u] -= 1;
    L[u][u] += 1;
    L[v][v] += 1;
  }
 
  vector<base> char_poly0_values;
  vector<base> lambdas;
  for(int lambda = 0; lambda <= N; lambda++) {
    vector<vector<base>> mat(N, vector<base>(N));
    for(int i = 0; i < N; i++) for(int j = 0; j < N; j++) {
      mat[i][j] = (i == j ? lambda : 0) - L[i][j];
    }
    lambdas.push_back(lambda);
    char_poly0_values.push_back(det_mod(mat));
  }
 
  poly<base> p0 = inter(lambdas, char_poly0_values); //.pow(K, N*K+1);
  base xx = base(N) * base(K);
  base ans = bpow(p0.eval(xx), K) / bpow(xx, 2);
  printf("%lld\n", ans.r);
}
 
int main() {
#ifdef IN_MY_COMPUTER
  freopen("cstree.in", "r", stdin);
#endif
  ios::sync_with_stdio(0);
  cin.tie(0);
 
  int T; cin >> T;
  while(T--) {
    solve();
  }
  return 0;
}   

Time Complexity=O(N^4) (N+1 determinant evaluations, each computed in O(N^3))
Space Complexity=O(N^3)
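
The code above interpolates p0(λ) = det(λI - L) and then evaluates it at the single point λ = NK. Because only that one value is used, the same number can be obtained by evaluating the determinant directly at λ = NK, which needs one O(N^3) elimination per test case. The sketch below does this with plain Gaussian elimination modulo 998244353. The function name `count_complement_trees`, the 0-indexed edge list, and the assumption that N*K is not a multiple of the modulus are illustrative choices, not part of the tester's code.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical standalone sketch (not the tester's code). Counts spanning trees of H,
// the complement of K disjoint copies of G, modulo 998244353 via
//   t(H) = det(NK * I - L(G))^K / (NK)^2.
const std::int64_t MOD = 998244353;

std::int64_t power(std::int64_t b, std::int64_t e) {
    std::int64_t r = 1;
    b = (b % MOD + MOD) % MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

// Determinant modulo MOD by Gaussian elimination; entries must lie in [0, MOD).
std::int64_t det_mod(std::vector<std::vector<std::int64_t>> a) {
    int n = static_cast<int>(a.size());
    std::int64_t det = 1;
    for (int col = 0; col < n; ++col) {
        int p = col;
        while (p < n && a[p][col] == 0) ++p;
        if (p == n) return 0;  // singular matrix
        if (p != col) {
            std::swap(a[p], a[col]);
            det = (MOD - det) % MOD;  // a row swap flips the sign
        }
        det = det * a[col][col] % MOD;
        std::int64_t inv = power(a[col][col], MOD - 2);  // Fermat inverse, MOD is prime
        for (int r = col + 1; r < n; ++r) {
            std::int64_t f = a[r][col] * inv % MOD;
            for (int c = col; c < n; ++c)
                a[r][c] = ((a[r][c] - f * a[col][c]) % MOD + MOD) % MOD;
        }
    }
    return det;
}

// n = |V(G)|, edges are 0-indexed pairs (u, v) of G, k = number of copies.
std::int64_t count_complement_trees(int n, const std::vector<std::pair<int, int>>& edges, int k) {
    std::int64_t nk = static_cast<std::int64_t>(n) * k % MOD;
    assert(nk != 0 && "N*K must not be a multiple of MOD");
    // Build NK*I - L(G): diagonal NK - deg(v), and +1 for every edge (L has -1 there).
    std::vector<std::vector<std::int64_t>> m(n, std::vector<std::int64_t>(n, 0));
    for (int i = 0; i < n; ++i) m[i][i] = nk;
    for (auto [u, v] : edges) {
        m[u][u] = (m[u][u] - 1 + MOD) % MOD;
        m[v][v] = (m[v][v] - 1 + MOD) % MOD;
        m[u][v] = (m[u][v] + 1) % MOD;
        m[v][u] = (m[v][u] + 1) % MOD;
    }
    std::int64_t d = power(det_mod(m), k);  // determinant over K identical diagonal blocks
    return d * power(nk * nk % MOD, MOD - 2) % MOD;
}
```

For example, with a single edge (N = 2) and K = 2, H is a 4-cycle and `count_complement_trees(2, {{0, 1}}, 2)` returns 4.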

CHEF VIJJU’S CORNER

Why the polynomial of G_s is of degree at most N

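We add x only to the N diagonal elements. Every term in the expansion of the determinant takes exactly one entry from each row, so it contains at most N diagonal entries and therefore at most N factors of x. Hence the determinant is a polynomial of degree at most N; in fact, the product of the diagonal entries contributes x^N with coefficient 1, so the degree is exactly N.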
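
The same bound can be read off from the standard expansion of a determinant with a diagonal shift. Here L is the Laplacian of G_s and L_{\overline{A}} denotes L with the rows and columns indexed by A removed (the determinant of an empty matrix is taken as 1):

```latex
\det(L + xI) \;=\; \sum_{A \subseteq \{1,\dots,N\}} x^{|A|}\,\det\!\big(L_{\overline{A}}\big)
```

The only subset with |A| = N is A = {1,...,N}, whose term is x^N * 1, so the polynomial has degree exactly N and leading coefficient 1. Its constant term is det L = 0, because every row of a Laplacian sums to zero.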

Interpolation

Interpolation may look like a purely academic tool, but it is crucial in fields like Machine Learning.

In Machine Learning, we often have data on the values of some function (e.g., total sales of cars) for various values of its parameters x, y, z, etc. A parameter can be anything we suspect affects the function, such as how big the city is or how many people in the city are “rich”.

Given this data, we want to predict future values of the function from those parameters. One way is to find a function that most closely reproduces the observed behavior. Most of the time, such functions are polynomials of some degree.

We first decide what degree of polynomial to use (this depends on the number of parameters) and then interpolate it from the data points to find the exact polynomial. Then we check how well this polynomial serves our needs and, if it does not, we rinse and repeat. Such concepts of algebra, although they might appear useless, form the crux of Data Science.

Setter's Notes

We use the generalized Cayley’s formula (see “Proof of Generalized Cayley's formula” on Mathematics Stack Exchange) and inclusion-exclusion.

For inclusion-exclusion, we need, for every t, the number of pairs (S, T) where S is a set of t forbidden edges and T is a spanning tree of the complete graph on k*n vertices that contains S. We can split t = t_1+t_2+...+t_k, where t_i is the number of forbidden edges we use in the i-th copy of the graph. We also need to take the sizes of the components into account: if we use t_i edges in the i-th copy (they must form a forest), there will be n-t_i components with some sizes s_{i,1},...,s_{i,n-t_i}. In total there are k*n-t components, so by the generalized Cayley’s formula the number of spanning trees over all such configurations is \sum \prod(s_{i,j}) * (k*n)^{(k*n-t-2)}, summed over all valid ways to split the sizes of each component. These totals are then added with sign (-1)^t.
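
A tiny worked example helps check the signs and exponents. Take G to be a single edge (n = 2) and k = 2, so kn = 4 and H is the complement of two disjoint edges, i.e. a 4-cycle, which has 4 spanning trees. Each row below is (number of ways to choose the t forbidden edges) × (product of component sizes) × (kn)^{kn-t-2}, the last factor being the generalized Cayley factor for kn - t components:

```latex
\begin{aligned}
t=0:&\quad 1 \cdot (1\cdot 1\cdot 1\cdot 1) \cdot 4^{4-0-2} = 16\\
t=1:&\quad 2 \cdot (2\cdot 1\cdot 1) \cdot 4^{4-1-2} = 16\\
t=2:&\quad 1 \cdot (2\cdot 2) \cdot 4^{4-2-2} = 4\\
\text{Ans} &= 16 - 16 + 4 = 4
\end{aligned}
```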

The second factor depends only on the number t of forbidden edges we use, so we can pull it out of the sum.
The first factor splits into an independent product over the copies, so we only need to compute these values for a single copy of the graph; polynomial multiplication then gives the values for all k copies.

To compute \sum \prod (s_{i,j}) for a single copy, use the generalized matrix-tree theorem. Consider the matrix M with M_{i,i} = deg(i), M_{i,j} = -1 if there is an edge between i and j, and 0 otherwise (i.e. the Laplacian). Fix some set A of nodes. If you delete the i-th row and column for every i in A, then the determinant of the remaining matrix is the number of spanning forests of the original graph with exactly |A| components, each containing exactly one node of A.
Now add a variable x along the diagonal of M. The determinant of M + xI is a polynomial in x, and expanding it gives \det(M + xI) = \sum_{A} x^{|A|} \det(M_{\bar{A}}), where M_{\bar{A}} is M with the rows and columns of A deleted. We claim that the coefficient of x^c is exactly what we need to compute: it sums, over all ways to choose c nodes A, the number of forests with c components that each contain exactly one node of A. A forest with components of sizes s_1,...,s_c is therefore counted once for every way of picking one node from each component, i.e. s_1 s_2 ... s_c times, which is exactly the product of sizes we need.
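
The claim can be sanity-checked on tiny graphs by brute force: enumerate every edge subset, keep those that form a forest, and add the product of component sizes to the bucket for its number of components. The helper `forest_weight_coefficients` below is a hypothetical illustration (exponential in the number of edges), not part of the setter's solution.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical brute-force helper, exponential in the number of edges (tiny graphs only).
// Returns coef where coef[c] = sum, over spanning forests F of G with exactly c
// components, of the product of the component sizes of F. The claim in the text is
// that coef[c] equals the coefficient of x^c in det(M + xI).
std::vector<std::int64_t> forest_weight_coefficients(int n, const std::vector<std::pair<int, int>>& edges) {
    int m = static_cast<int>(edges.size());
    assert(m < 31 && "brute force is only meant for tiny graphs");
    std::vector<std::int64_t> coef(n + 1, 0);
    for (std::uint32_t mask = 0; mask < (1u << m); ++mask) {
        // Union-find over the chosen edges; reject subsets that contain a cycle.
        std::vector<int> parent(n);
        std::iota(parent.begin(), parent.end(), 0);
        auto find = [&parent](int v) {
            while (parent[v] != v) v = parent[v] = parent[parent[v]];  // path halving
            return v;
        };
        bool is_forest = true;
        for (int e = 0; e < m && is_forest; ++e) {
            if (((mask >> e) & 1u) == 0) continue;
            int a = find(edges[e].first), b = find(edges[e].second);
            if (a == b) is_forest = false;  // this edge would close a cycle
            else parent[a] = b;
        }
        if (!is_forest) continue;
        // Component sizes, counted at each root.
        std::vector<std::int64_t> comp_size(n, 0);
        for (int v = 0; v < n; ++v) ++comp_size[find(v)];
        int components = 0;
        std::int64_t product = 1;
        for (std::int64_t s : comp_size)
            if (s > 0) { ++components; product *= s; }
        coef[components] += product;
    }
    return coef;
}
```

For the path 0-1-2 it returns {0, 3, 4, 1}, i.e. x^3 + 4x^2 + 3x = x(x+1)(x+3), which is indeed det(M + xI) because the path's Laplacian has eigenvalues 0, 1 and 3. For the triangle it returns {0, 9, 6, 1}, i.e. x(x+3)^2.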

So the solution will look something like this:

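  • Compute the determinant of matrix M + I * x to get a polynomial of degree n (the easiest way is to evaluate it at x=0,x=1,...,x=n, i.e. at n+1 points, and then use polynomial interpolation to get the coefficients).
  • Raise this polynomial to the k-th power. The result has degree k*n, so FFT/NTT-based multiplication may be needed for large inputs.
  • The coefficient of x^{k*n-t} corresponds to using t forbidden edges: multiply it by (-1)^t * (k*n)^{(k*n-t-2)} (for t = k*n-1 the exponent is -1, i.e. a modular inverse) and sum over t to get the final answer.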
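
A compact sketch of these three steps, assuming small inputs: it evaluates det(M + xI) at x = 0, ..., n with modular Gaussian elimination, recovers the coefficients by Lagrange interpolation, raises the polynomial to the k-th power with naive multiplication (a stand-in for FFT/NTT), and applies the signed Cayley weights. The function name `count_trees_by_forests`, the 0-indexed edge list, and the modulus 998244353 (taken from the tester's code) are assumptions of this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Sketch of the three steps above (hypothetical helper names). Assumes N*K is not a
// multiple of MOD and uses naive O(d^2) multiplication instead of FFT/NTT.
using i64 = std::int64_t;
using Poly = std::vector<i64>;  // coefficients, lowest degree first
const i64 MOD = 998244353;

i64 power(i64 b, i64 e) {
    i64 r = 1;
    b = (b % MOD + MOD) % MOD;
    for (; e > 0; e >>= 1, b = b * b % MOD)
        if (e & 1) r = r * b % MOD;
    return r;
}

i64 det_mod(std::vector<std::vector<i64>> a) {  // Gaussian elimination mod a prime
    int n = static_cast<int>(a.size());
    i64 det = 1;
    for (int col = 0; col < n; ++col) {
        int p = col;
        while (p < n && a[p][col] == 0) ++p;
        if (p == n) return 0;
        if (p != col) { std::swap(a[p], a[col]); det = (MOD - det) % MOD; }
        det = det * a[col][col] % MOD;
        i64 inv = power(a[col][col], MOD - 2);
        for (int r = col + 1; r < n; ++r) {
            i64 f = a[r][col] * inv % MOD;
            for (int c = col; c < n; ++c)
                a[r][c] = ((a[r][c] - f * a[col][c]) % MOD + MOD) % MOD;
        }
    }
    return det;
}

Poly multiply(const Poly& a, const Poly& b) {
    Poly c(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            c[i + j] = (c[i + j] + a[i] * b[j]) % MOD;
    return c;
}

// Lagrange interpolation through the points (i, ys[i]), i = 0..d; O(d^3) for simplicity.
Poly interpolate(const std::vector<i64>& ys) {
    int pts = static_cast<int>(ys.size());
    Poly result(pts, 0);
    for (int i = 0; i < pts; ++i) {
        Poly basis{1};
        i64 denom = 1;
        for (int j = 0; j < pts; ++j) {
            if (j == i) continue;
            basis = multiply(basis, Poly{(MOD - j) % MOD, 1});  // times (x - j)
            denom = denom * ((i - j + MOD) % MOD) % MOD;
        }
        i64 scale = ys[i] * power(denom, MOD - 2) % MOD;
        for (int t = 0; t < pts; ++t) result[t] = (result[t] + basis[t] * scale) % MOD;
    }
    return result;
}

i64 count_trees_by_forests(int n, const std::vector<std::pair<int, int>>& edges, int k) {
    i64 kn = static_cast<i64>(n) * k;
    assert(kn % MOD != 0);
    // Step 1: P(x) = det(M + xI) at x = 0..n (n + 1 points, degree n), then interpolate.
    std::vector<i64> values;
    for (int x = 0; x <= n; ++x) {
        std::vector<std::vector<i64>> m(n, std::vector<i64>(n, 0));
        for (int i = 0; i < n; ++i) m[i][i] = x;
        for (auto [u, v] : edges) {
            m[u][u] = (m[u][u] + 1) % MOD;
            m[v][v] = (m[v][v] + 1) % MOD;
            m[u][v] = (m[u][v] + MOD - 1) % MOD;
            m[v][u] = (m[v][u] + MOD - 1) % MOD;
        }
        values.push_back(det_mod(m));
    }
    Poly p = interpolate(values);
    // Step 2: P(x)^k by binary exponentiation (k disjoint copies of G).
    Poly pk{1}, sq = p;
    for (int e = k; e > 0; e >>= 1) {
        if (e & 1) pk = multiply(pk, sq);
        if (e > 1) sq = multiply(sq, sq);
    }
    // Step 3: x^c <-> t = kn - c forbidden edges, weight (-1)^t (kn)^(c-2);
    // the common factor (kn)^(-2) is applied once at the end.
    i64 ans = 0;
    for (std::size_t c = 0; c < pk.size(); ++c) {
        i64 term = pk[c] * power(kn, static_cast<i64>(c)) % MOD;
        bool odd_t = (kn - static_cast<i64>(c)) % 2 != 0;
        ans = odd_t ? (ans - term + MOD) % MOD : (ans + term) % MOD;
    }
    return ans * power(kn % MOD * (kn % MOD) % MOD, MOD - 2) % MOD;
}
```

For a single edge with k = 2 it returns 4, since H is then a 4-cycle; for one isolated vertex copied three times (n = 1, k = 3) it returns 3, the number of spanning trees of a triangle.
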

Tester's Notes / Trick

I could solve it easily with the help of Google. After searching for “laplacian spanning trees” and “laplacian complement” and skimming some short papers and presentations, I found a solution in these two materials.
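
The one-line answer in the tester's code, bpow(p0.eval(xx), K) / bpow(xx, 2) with p0(λ) = det(λI - L) and xx = NK, follows from a standard identity for complements. Let μ_1 = 0 ≤ μ_2 ≤ ... ≤ μ_n be the Laplacian eigenvalues of a graph G on n vertices, J the all-ones matrix and τ the number of spanning trees. The first line holds because the all-ones vector is an eigenvector of L(G) for μ_1 = 0 and the other eigenvectors can be chosen orthogonal to it, where J vanishes; the second is Kirchhoff's theorem in eigenvalue form; the third applies this with n = NK to G_b, whose Laplacian is block diagonal with K copies of L(G):

```latex
\begin{aligned}
L(\overline{G}) &= nI - J - L(G), \qquad \text{with eigenvalues } 0,\; n-\mu_2,\; \dots,\; n-\mu_n,\\
\tau(\overline{G}) &= \frac{1}{n}\prod_{i=2}^{n}(n-\mu_i) = \frac{1}{n^2}\prod_{i=1}^{n}(n-\mu_i) = \frac{\det\big(nI - L(G)\big)}{n^2},\\
\tau(H) &= \frac{\det\big(NK\,I - L(G_b)\big)}{(NK)^2} = \frac{\det\big(NK\,I_N - L(G)\big)^{K}}{(NK)^2}.
\end{aligned}
```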

Related Problems