Help in Polynomial Multiplication

How do I solve this SPOJ problem using FFT?

This is exactly what FFT is used for: multiplying two polynomials of degree n in O(n log n) time. If you don’t know what FFT is, look it up; there are lots of resources on FFT online. This is one of the best resources.


I think I’ve got a pretty good explanation for this. What I will show is how to connect polynomial multiplication to calculating the FFT, but I will not go into the details of how to calculate the FFT itself. Let’s start by connecting polynomial multiplication to something called convolution.

Let A=[a_0,a_1,...,a_{n-1}] and B=[b_0,b_1,...,b_{m-1}] and let C=[c_0,c_1,...,c_{n+m-2}] be the coefficients corresponding to the polynomial found by multiplying (a_0 + a_1 x + ... + a_{n-1} x^{n-1})(b_0 + b_1 x + ... + b_{m-1} x^{m-1}).

In other words, C is the convolution of A and B, denoted C=A*B. Multiplying the polynomials is the same thing as convolving their coefficient lists.
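
Written out coefficient by coefficient, this says that c_k collects every product a_i b_j with i+j=k, where b_j is taken to be 0 whenever j<0 or j≥m:

```latex
c_k = \sum_{i+j=k} a_i\, b_j = \sum_{i=0}^{n-1} a_i\, b_{k-i}, \qquad k = 0, 1, \dots, n+m-2
```

For example, (1+2x)(3+x) = 3+7x+2x^2, so [1,2]*[3,1] = [3,7,2].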

One way to calculate a convolution is to run the following code:


Polynomial multiplication in O(n m) time

def conv(A,B):
    n = len(A)
    m = len(B)
    C = [0]*(n+m-1)
    for i in range(n):
        for j in range(m):
            C[i+j]+=A[i]*B[j]
    return C

Unfortunately this code runs in O(nm) time (O(n^2) when n ≈ m), and it is not directly something you can compute using FFT. However, there is
another kind of convolution that can easily be calculated in O(n \log(n)) using FFT. The natural “convolution” when using FFT is something called a circular convolution, defined as


Circular convolution in O(n^2) time

def circ_conv(A,B):
    assert(len(A)==len(B))
    n = len(A)
    C = [0]*(n)
    for i in range(n):
        for j in range(n):
            C[(i+j)%n]+=A[i]*B[j]
    return C
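
To see how the two convolutions relate, here is a small self-contained check (the names `linear_conv`, `circular_conv` and `fold` are hypothetical, chosen only for this sketch): the circular convolution of two length-n lists equals the ordinary convolution with every coefficient at index k ≥ n added back onto index k mod n.

```python
def linear_conv(A, B):
    # Ordinary (linear) convolution in O(n*m): index i+j collects A[i]*B[j]
    C = [0] * (len(A) + len(B) - 1)
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            C[i + j] += a * b
    return C

def circular_conv(A, B):
    # Circular convolution in O(n^2): the index i+j wraps around modulo n
    assert len(A) == len(B)
    n = len(A)
    C = [0] * n
    for i in range(n):
        for j in range(n):
            C[(i + j) % n] += A[i] * B[j]
    return C

def fold(C, n):
    # Add every coefficient at index k >= n back onto index k % n
    F = [0] * n
    for k, c in enumerate(C):
        F[k % n] += c
    return F

A = [1, 2, 3]
B = [4, 5, 6]
print(linear_conv(A, B))     # [4, 13, 28, 27, 18]
print(circular_conv(A, B))   # [31, 31, 28]
print(fold(linear_conv(A, B), 3))  # [31, 31, 28]
```

The wrapped-around terms (27 and 18 here) are exactly what the zero padding described next gets rid of.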

Using FFT and iFFT (the inverse fast Fourier transform), circular convolution can be calculated in O(n \log(n)) time by


Circular convolution in O(n \log(n)) time

def circ_conv(A,B):
    assert(len(A)==len(B))
    n = len(A)
    A = FFT(A)
    B = FFT(B)
    C = [0]*(n)
    for i in range(n):
        C[i]=A[i]*B[i]
    return iFFT(C)
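
The reason this works is the convolution theorem for the discrete Fourier transform. With \omega_n = e^{-2\pi i/n} (the same sign convention as the Cooley-Tukey code later in this post), the DFT and the circular convolution are

```latex
\hat{a}_k = \sum_{j=0}^{n-1} a_j\, \omega_n^{jk}, \qquad
c_k = \sum_{i=0}^{n-1} a_i\, b_{(k-i) \bmod n}
\;\Longrightarrow\;
\hat{c}_k = \sum_{i=0}^{n-1}\sum_{j=0}^{n-1} a_i\, b_j\, \omega_n^{(i+j)k} = \hat{a}_k\, \hat{b}_k
```

The middle step uses \omega_n^n = 1, so reducing i+j modulo n does not change \omega_n^{(i+j)k}. The inverse transform a_j = \frac{1}{n}\sum_k \hat{a}_k\, \omega_n^{-jk} then recovers c from \hat{c}, which is exactly the FFT, pointwise product, iFFT sequence in the code.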

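The only thing left to do is to connect ordinary convolution to the circular one. The standard way to do this is to pad the end of A and B with zeros up to a common length N of at least n+m-1, do a circular convolution, and then remove the extra zeros at the end of C. Since the product only has n+m-1 coefficients, nothing wraps around; N is chosen as a power of two because that is what Cooley-Tukey FFT needs. The following is one way to implement it: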


Polynomial multiplication in O((n+m) \log(n+m)) time

def conv(A,B):
    n = len(A)
    m = len(B)
    N = 1
    while N<n+m-1:
        N*=2
    A = A + [0]*(N-n)
    B = B + [0]*(N-m)
    C = circ_conv(A,B)
    return C[:n+m-1]
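
Why is padding to at least n+m-1 enough? After padding, the ordinary product has length n+m-1 ≤ N, so there is nothing at index N or beyond to wrap around. This self-contained sketch checks that with the O(n^2) circular convolution (helper names are hypothetical) and shows that any N ≥ n+m-1 gives the same answer; the power of two only matters for radix-2 Cooley-Tukey.

```python
def circular_conv(A, B):
    # O(n^2) circular convolution of two equal-length lists
    n = len(A)
    C = [0] * n
    for i in range(n):
        for j in range(n):
            C[(i + j) % n] += A[i] * B[j]
    return C

def padded_conv(A, B, N):
    # Zero-pad both inputs to length N, convolve circularly, drop the tail
    n, m = len(A), len(B)
    assert N >= n + m - 1, "N too small: products would wrap around"
    C = circular_conv(A + [0] * (N - n), B + [0] * (N - m))
    return C[:n + m - 1]

# Multiply (1+2x+3x^2) and (2+3x+4x^2+5x^3); here n+m-1 = 6
A = [1, 2, 3]
B = [2, 3, 4, 5]
for N in (6, 7, 8, 11):
    print(N, padded_conv(A, B, N))   # [2, 7, 16, 22, 22, 15] for every N

# With N = 5 (too small) the x^5 coefficient 15 wraps onto x^0: 2 + 15 = 17
print(circular_conv(A + [0, 0], B + [0]))  # [17, 7, 16, 22, 22]
```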

Hope that helps!

Some final remarks:
One reason I really like this way of implementing polynomial multiplication is that while implementing and debugging the code, you can swap in the O(n^2) algorithms to make sure that the FFT is not the problem. Also note that all the code I’ve written is Python and should run as-is, given that you’ve implemented FFT and iFFT.
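
As a sketch of that debugging workflow, here is a randomized cross-check of an FFT-based multiplier against the O(nm) reference. It is self-contained; `fft`, `fft_conv` and `naive_conv` are hypothetical names for this sketch (not the functions from the post), and the final rounding step assumes integer input coefficients.

```python
import random
from cmath import exp
from math import pi

def naive_conv(A, B):
    # O(n*m) reference implementation
    C = [0] * (len(A) + len(B) - 1)
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            C[i + j] += a * b
    return C

def fft(A, inverse=False):
    # Recursive radix-2 Cooley-Tukey; len(A) must be a power of two.
    # inverse=True uses the conjugate root but does NOT divide by n.
    n = len(A)
    if n == 1:
        return list(A)
    even = fft(A[0::2], inverse)
    odd = fft(A[1::2], inverse)
    sign = 1 if inverse else -1
    out = [0j] * n
    for k in range(n // 2):
        t = exp(sign * 2j * pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out

def fft_conv(A, B):
    # Pad to a power of two N >= n+m-1, multiply pointwise, invert, round
    size = len(A) + len(B) - 1
    N = 1
    while N < size:
        N *= 2
    FA = fft(A + [0] * (N - len(A)))
    FB = fft(B + [0] * (N - len(B)))
    C = fft([x * y for x, y in zip(FA, FB)], inverse=True)
    return [round(c.real / N) for c in C[:size]]

random.seed(1)
for trial in range(200):
    A = [random.randint(-50, 50) for _ in range(random.randint(1, 20))]
    B = [random.randint(-50, 50) for _ in range(random.randint(1, 20))]
    assert fft_conv(A, B) == naive_conv(A, B), (A, B)
print("all random tests passed")
```

If this ever fails, the small failing (A, B) pair printed by the assertion is usually enough to locate the bug.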

EDIT: Added an example of a fully working polynomial multiplier running in Python 3 with time complexity O(n \log(n)), using a recursive implementation of the Cooley-Tukey FFT algorithm.

from cmath import exp
from math import pi

def isPowerOfTwo(n):
    return n>0 and (n&(n-1))==0

# FFT using Cooley-Tukey, a divide and conquer algorithm
# running in O(n log(n)) time, implemented recursively.
# NOTE that Cooley-Tukey requires n to be a power of two
def FFT(A):
    n = len(A)
    if n==1:
        return A

    assert(isPowerOfTwo(n))

    even = FFT(A[::2])
    odd  = FFT(A[1::2])

    # Numerically stable way of "twiddling"
    return [even[k] + exp(-2*pi*k/n*1j)*odd[k] for k in range(n//2)] +\
           [even[k] - exp(-2*pi*k/n*1j)*odd[k] for k in range(n//2)]

# Inverse FFT
def iFFT(A):
    n = len(A)
    A = FFT([a.conjugate() for a in A])
    return [a.conjugate()/n for a in A]

# Circular convolution in O(nlog(n)) time
def circ_conv(A,B):
    assert(len(A)==len(B))
    n = len(A)
    A = FFT(A)
    B = FFT(B)
    C = [0]*(n)
    for i in range(n):
        C[i]=A[i]*B[i]
    return iFFT(C)

# Polynomial multiplication in O((n+m)log(n+m)) time
def conv(A,B):
    n = len(A)
    m = len(B)
    N = 1
    while N<n+m-1:
        N*=2
    A = A + [0]*(N-n)
    B = B + [0]*(N-m)
    C = circ_conv(A,B)
    return C[:n+m-1]

# Example
A = [1,1,2,3,4,5,6+7j,100]
print(A)
print(FFT(A))
print(iFFT(FFT(A)))

# Multiply (1+2x+3x^2) and (2+3x+4x^2+5x^3)
A = [1,2,3]
B = [2,3,4,5]
print(A,B,conv(A,B))
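
The example above prints complex numbers with small floating-point noise (real parts close to integers, imaginary parts close to zero). For integer polynomials you will usually want to round the result. A minimal sketch, with a hypothetical helper name and an arbitrarily chosen tolerance, that also fails loudly when the noise is too large to round safely:

```python
def to_int_coeffs(C, tol=0.25):
    # Round each (possibly complex) coefficient to the nearest integer.
    # If a value is further than tol from an integer, or has a large
    # imaginary part, precision has probably been lost: raise instead.
    out = []
    for c in C:
        c = complex(c)
        r = round(c.real)
        if abs(c.real - r) > tol or abs(c.imag) > tol:
            raise ValueError(f"coefficient {c} is not close to an integer")
        out.append(r)
    return out

print(to_int_coeffs([2.0000000000000004+1e-16j, 6.999999999999999-2e-16j, 16+0j]))
# [2, 7, 16]
```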

If you want to actually use this code, there are two things you should consider:

  1. My implementation of FFT currently has a large constant factor because it is recursive. Switch to an in-place iterative implementation of Cooley-Tukey if you want to make it noticeably quicker. The wiki is great at explaining how to do this; note, however, that following the wiki might lead to really bad numerical stability.

  2. My implementation is written in Python, and function calls (and therefore recursion) are especially slow in Python. Switch to C++ and the code will run much, much quicker.

EDIT2: What the heck, I might as well go the whole way. The following is a fairly quick and more advanced implementation of FFT and NTT (number-theoretic transform) that allows for convolution using either one, and you can switch between them freely. The basics of reducing convolution to circular convolution are still exactly the same.
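
The NTT replaces the complex root of unity by an element of order exactly N modulo a prime p, which exists when N divides p-1. The code derives it as ntt_root^((p-1)/N). A quick self-contained check, for the two primes used in the code ((479<<21)+1 = 1004535809 and (119<<23)+1 = 998244353, both with ntt_root = 3), that this really is a primitive N-th root: its N-th power is 1 and its (N/2)-th power is -1 mod p. The helper names are hypothetical.

```python
def max_power_of_two_length(p):
    # Largest power of two dividing p - 1 (called longest_conv in the code)
    L = 1
    while (p - 1) % (2 * L) == 0:
        L *= 2
    return L

def primitive_root_of_unity(p, g, N):
    # Element of order exactly N mod p, for N a power of two dividing p - 1.
    # Assumes g^((p-1)/2) == -1 mod p (true when g is a primitive root mod p).
    assert (p - 1) % N == 0
    w = pow(g, (p - 1) // N, p)
    # For N a power of two: w^N == 1 and w^(N/2) == -1 imply order exactly N
    assert pow(w, N, p) == 1
    assert N == 1 or pow(w, N // 2, p) == p - 1
    return w

for p in ((479 << 21) + 1, (119 << 23) + 1):
    L = max_power_of_two_length(p)
    w = primitive_root_of_unity(p, 3, L)
    print(p, "supports NTT lengths up to", L)
```

This is also why the NTT code asserts N <= longest_conv: for larger N no such root exists modulo that prime.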

from cmath import exp
from math import pi

# Quick convolution that can interchangeably use either
# FFT or NTT, see example at bottom.
# This uses a somewhat advanced implementation
# of Cooley-Tukey that hopefully runs quickly with high
# numerical stability.
# /pajenegod


def isPowerOfTwo(n):
    return n>0 and (n&(n-1))==0

# Permutes A in place by bit reversal of the indices
# Ex. [0,1,2,3,4,5,6,7]->[0,4,2,6,1,5,3,7]
def bit_reversal(A):
    n = len(A)
    assert(isPowerOfTwo(n))
    
    k = 0
    m = 1
    while m<n:m*=2;k+=1
    
    for i in range(n):
        I = i
        j = 0
        for _ in range(k):
            j = j*2 + i%2
            i //= 2
        if j>I:
            A[I],A[j]=A[j],A[I]
    return

### NTT ALGORITHM BASED ON COOLEY TUKEY

# Inplace NTT using Cooley-Tukey, a divide and conquer algorithm
# running in O(n log(n)) time implemented iteratively using bit reversal, 
# NOTE that Cooley-Tukey requires n to be a power of two
# and also that n <= longest_conv, the largest power of two
# dividing ntt_prime-1; basically n is limited by ntt_prime

# Remember to set ntt_prime and ntt_root before calling, for example
ntt_prime = (479<<21)+1
ntt_root = 3
def NTT_CT(A,inverse=False):
    # Some pre-calculations needed to do the NTT
    longest_conv = 1
    while (ntt_prime-1)%(2*longest_conv)==0:longest_conv*=2
    ntt_base = pow(ntt_root,(ntt_prime-1)//longest_conv,ntt_prime) 
   
    N = len(A)
    assert(isPowerOfTwo(N))
    assert(N<=longest_conv)
    
    for i in range(N):
        A[i]%=ntt_prime
    
    # Calculate the twiddle factors
    e = pow(ntt_base,longest_conv//N,ntt_prime)
    if inverse:
        e = pow(e,ntt_prime-2,ntt_prime)
    b = e
    twiddles = [1]
    while len(twiddles)<N//2:
        twiddles += [t*b%ntt_prime for t in twiddles]
        b = b**2%ntt_prime

    bit_reversal(A)

    n = 2
    while n<=N:
        offset = 0
        while offset<N: 
            depth = N//n
            for k in range(n//2):
                ind1 = k      + offset
                ind2 = k+n//2 + offset
                even = A[ind1]
                odd  = A[ind2]*twiddles[k*depth]

                A[ind1]  = (even + odd)%ntt_prime
                A[ind2]  = (even - odd)%ntt_prime
            
            offset += n
        n*=2

    if inverse:
        inv_N = pow(N,ntt_prime-2,ntt_prime) 
        for i in range(N):
            A[i] = A[i]*inv_N%ntt_prime
    return

### FFT ALGORITHM BASED ON Cooley-Tukey

# Inplace FFT using Cooley-Tukey, a divide and conquer algorithm
# running in O(n log(n)) time implemented iteratively using bit_reversal, 
# NOTE that Cooley-Tukey requires n to be a power of two
def FFT_CT(A,inverse=False):
    N = len(A)
    assert(isPowerOfTwo(N))
    
    # Calculate the twiddle factors, with very good numerical stability 
    e = -2*pi/N*1j
    if inverse:
        e = -e
    twiddles = [exp(e*k) for k in range(N//2)]
    
    bit_reversal(A)

    n = 2
    while n<=N:
        offset = 0
        while offset<N: 
            depth = N//n
            for k in range(n//2):
                ind1 = k      + offset
                ind2 = k+n//2 + offset
                even = A[ind1]
                odd  = A[ind2]*twiddles[k*depth]

                A[ind1]  = even + odd
                A[ind2]  = even - odd
            
            offset += n
        n*=2

    if inverse:
        inv_N = 1.0/N 
        for i in range(N):
            A[i]*=inv_N
    return A

# Circular convolution in O(nlog(n)) time
def circ_conv(A,B):
    assert(len(A)==len(B))
    n = len(A)
    
    A = list(A)
    B = list(B)
    FFT(A)
    FFT(B)
    
    C = [A[i]*B[i] for i in range(n)]
    FFT(C,inverse=True)
    return C

# Polynomial multiplication in O((n+m)log(n+m)) time
def conv(A,B):
    n = len(A)
    m = len(B)
    N = 1
    while N<n+m-1:
        N*=2
    A = A + [0]*(N-n)
    B = B + [0]*(N-m)
    C = circ_conv(A,B)
    return C[:n+m-1]

# example
for ntt in [False,True]:
    # Switch between using NTT or FFT for convolution
    if ntt:
        # Set ntt prime
        ntt_prime = (119<<23)+1
        ntt_root  = 3
        
        print('Using NTT for convolution')
        FFT = NTT_CT
    else:
        print('Using FFT for convolution')
        FFT = FFT_CT

    # Example
    A = [1,1,2,3,4,5,6,100]
    print('A',A)
    FFT(A)
    print('FFT(A)',A)
    FFT(A,inverse=True)
    print('iFFT(FFT(A))',A)

    # Multiply (1+2x+3x^2) and (2+3x+4x^2+5x^3)
    A = [1,2,3]
    B = [2,3,4,5]
    print('A=',A)
    print('B=',B)
    print('A*B=',conv(A,B))
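
One caveat when switching to NTT: the results are exact only modulo ntt_prime, so every true coefficient of the product must lie in [0, ntt_prime), or, if inputs can be negative, in (-ntt_prime/2, ntt_prime/2] so that residues can be mapped back. A rough sketch of a bound check, assuming non-empty lists of non-negative integers; both helper names are hypothetical:

```python
def ntt_is_exact(A, B, p):
    # Assumes non-empty lists of non-negative integers. Each c_k is a sum of
    # at most min(len(A), len(B)) products a_i*b_j, so this bounds every c_k.
    assert all(a >= 0 for a in A) and all(b >= 0 for b in B)
    bound = min(len(A), len(B)) * max(A) * max(B)
    return bound < p

def to_signed(r, p):
    # Map a residue in [0, p) back to (-p/2, p/2]; valid if |c_k| < p/2
    return r - p if r > p // 2 else r

p = (119 << 23) + 1  # 998244353, the prime used in the example loop
print(ntt_is_exact([1, 2, 3], [2, 3, 4, 5], p))          # True
print(ntt_is_exact([10**6] * 1000, [10**6] * 1000, p))   # False
print(to_signed(p - 1, p))                                # -1
```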

Can anyone provide a reference Java implementation?

@soham1234 I know the theory part but don’t know how to implement it.


The best option is to look at someone else’s code and learn how to use it if you don’t want to implement it yourself. I do the same thing, in fact.

@gorre_morre, let me handle the editing. 🙂

Have a look now. 🙂

Thanks for the help!

We need more good people like @gorre_morre who actually explain why a post needs editing, so that some poor mod isn’t confused here and there xD


@gorre_morre Can you help me with calculating the FFT in O(n log(n)), as in the approach you described?

@rds_98 I’ve just added an example of a fully working polynomial multiplier using Cooley-Tukey FFT algorithm.

In my original answer I didn’t describe how to implement the FFT, just how polynomial multiplication can be done using FFT. The most common and probably simplest way to implement FFT is the Cooley-Tukey algorithm, and Wikipedia has a pretty good article about it. So that is what I’m using for my example.
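
For reference, this is the split the recursive FFT in the EDIT above performs (its two list comprehensions). With E and O the length-n/2 DFTs of the even- and odd-indexed coefficients and \omega_n = e^{-2\pi i/n}:

```latex
\hat{a}_k = \sum_{j=0}^{n/2-1} a_{2j}\,\omega_{n/2}^{jk} + \omega_n^{k}\sum_{j=0}^{n/2-1} a_{2j+1}\,\omega_{n/2}^{jk} = E_k + \omega_n^{k}\, O_k,
\qquad
\hat{a}_{k+n/2} = E_k - \omega_n^{k}\, O_k, \qquad 0 \le k < n/2
```

The first identity uses \omega_n^{2} = \omega_{n/2}; the second uses \omega_n^{n/2} = -1 and the fact that E and O repeat with period n/2. Each level of recursion does O(n) work and there are \log_2 n levels, hence O(n \log n).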


@gorre_morre Thanks a lot! 🙂

I doubt that it is possible to solve this particular problem using an FFT written in Java, because Java has no analogue of C’s long double type, and I suspect that BigDecimal is too slow. I tried to solve this problem in D using the FFT from the standard library, which uses doubles, and got WA due to precision errors. Only when I wrote an FFT using the real type (long double in C) did I get AC. If I wanted to solve this problem in Java, I would use the Karatsuba multiplication algorithm (divide and conquer) instead.
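
Since Karatsuba is suggested as the alternative, here is a minimal sketch of Karatsuba polynomial multiplication, written in Python to match the rest of the thread (porting it to Java is mechanical). It works on exact integers, so there are no precision issues, and it runs in roughly O(n^{1.585}) time. The function names and the cutoff of 32 are arbitrary choices for this sketch.

```python
import random

def naive_mul(A, B):
    # O(n*m) schoolbook multiplication, used below the cutoff
    C = [0] * (len(A) + len(B) - 1)
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            C[i + j] += a * b
    return C

def add(P, Q):
    # Coefficient-wise sum of two polynomials of possibly different length
    if len(P) < len(Q):
        P, Q = Q, P
    R = list(P)
    for i, q in enumerate(Q):
        R[i] += q
    return R

def _karatsuba(A, B, cutoff):
    # A and B have the same length n >= 1; returns A*B (length 2n-1)
    n = len(A)
    if n <= cutoff:
        return naive_mul(A, B)
    h = n // 2
    A0, A1 = A[:h], A[h:]                               # A = A0 + x^h * A1
    B0, B1 = B[:h], B[h:]                               # B = B0 + x^h * B1
    Z0 = _karatsuba(A0, B0, cutoff)                     # A0*B0
    Z2 = _karatsuba(A1, B1, cutoff)                     # A1*B1
    Z1 = _karatsuba(add(A0, A1), add(B0, B1), cutoff)   # (A0+A1)(B0+B1)
    # C = Z0 + x^h * (Z1 - Z0 - Z2) + x^(2h) * Z2
    C = [0] * (2 * n - 1)
    for i, v in enumerate(Z0):
        C[i] += v
        C[i + h] -= v
    for i, v in enumerate(Z2):
        C[i + 2 * h] += v
        C[i + h] -= v
    for i, v in enumerate(Z1):
        C[i + h] += v
    return C

def karatsuba(A, B, cutoff=32):
    # Pad to a common length, multiply, then trim to len(A)+len(B)-1
    if not A or not B:
        return []
    n = max(len(A), len(B))
    C = _karatsuba(A + [0] * (n - len(A)), B + [0] * (n - len(B)), max(1, cutoff))
    return C[:len(A) + len(B) - 1]

random.seed(0)
for _ in range(100):
    A = [random.randint(-9, 9) for _ in range(random.randint(1, 70))]
    B = [random.randint(-9, 9) for _ in range(random.randint(1, 70))]
    assert karatsuba(A, B, cutoff=4) == naive_mul(A, B)
print(karatsuba([1, 2, 3], [2, 3, 4, 5]))  # [2, 7, 16, 22, 22, 15]
```

Padding both inputs to the same length keeps the recursion simple, at the cost of some wasted work when the two polynomials have very different degrees.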

@vijju123 This needs editing again; the code is not working again.