How do I solve this SPOJ problem using FFT?
This is exactly what FFT is used for: multiplying two polynomials of degree n in O(n log n). If you don’t know what FFT is, look it up; there are lots of resources online. This is one of the best resources.
I think I’ve got a pretty good explanation for this. What I will show is how to connect polynomial multiplication to calculating the FFT, but I will not go into the details of how to calculate the FFT itself. Let’s start by connecting polynomial multiplication to something called convolution.
Let A=[a_0,a_1,...,a_{n-1}] and B=[b_0,b_1,...,b_{m-1}] and let C=[c_0,c_1,...,c_{n+m-2}] be the coefficients corresponding to the polynomial found by multiplying (a_0 + a_1 x + ... + a_{n-1} x^{n-1})(b_0 + b_1 x + ... + b_{m-1} x^{m-1}).
Another way of saying this is that C is defined as the convolution of A and B, denoted C=A*B. Basically, multiplying the polynomials is the same thing as convolving the coefficients.
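Written out explicitly (this formula is not in the original post, but it is the standard definition of the product coefficients), each c_k collects every pair of indices that sums to k:

```latex
c_k \;=\; \sum_{\substack{i + j = k \\ 0 \le i < n,\; 0 \le j < m}} a_i \, b_j ,
\qquad 0 \le k \le n + m - 2 .
```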
One way to calculate a convolution is to run the following code:
Polynomial multiplication in O(n m) time
def conv(A,B):
    n = len(A)
    m = len(B)
    C = [0]*(n+m-1)
    for i in range(n):
        for j in range(m):
            C[i+j]+=A[i]*B[j]
    return C
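As a quick sanity check, here is a small worked example of the naive routine. The function body is copied from the snippet above so the block runs on its own; the variable name `product` is just for illustration.

```python
def conv(A, B):
    """Naive O(n*m) polynomial multiplication on coefficient lists."""
    n = len(A)
    m = len(B)
    C = [0] * (n + m - 1)
    for i in range(n):
        for j in range(m):
            # a_i * x^i times b_j * x^j contributes to the x^(i+j) term
            C[i + j] += A[i] * B[j]
    return C

# (1 + 2x + 3x^2) * (2 + 3x + 4x^2 + 5x^3)
product = conv([1, 2, 3], [2, 3, 4, 5])
print(product)  # [2, 7, 16, 22, 22, 15]
```

Reading the result back as a polynomial gives 2 + 7x + 16x^2 + 22x^3 + 22x^4 + 15x^5.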
Unfortunately this code runs in O(n m) time (O(n^2) when both inputs have length n), and it is not something you can directly calculate using FFT. However, there is another kind of convolution that can easily be calculated in O(n \log(n)) using FFT. The natural “convolution” when using FFT is something called a circular convolution, defined as
Circular convolution in O(n^2) time
def circ_conv(A,B):
    assert(len(A)==len(B))
    n = len(A)
    C = [0]*(n)
    for i in range(n):
        for j in range(n):
            C[(i+j)%n]+=A[i]*B[j]
    return C
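To see how the circular version differs from ordinary polynomial multiplication, it may help to run it on inputs that are too short to hold the full product. This is a small illustration using the same O(n^2) routine (copied here so the block is self-contained); the name `wrapped` is only for the example.

```python
def circ_conv(A, B):
    """Naive O(n^2) circular convolution of two equal-length lists."""
    assert len(A) == len(B)
    n = len(A)
    C = [0] * n
    for i in range(n):
        for j in range(n):
            # indices wrap around: position i+j is taken modulo n
            C[(i + j) % n] += A[i] * B[j]
    return C

# Same polynomials as before, but squeezed into length 4.
# The true product has 6 coefficients: [2, 7, 16, 22, 22, 15].
wrapped = circ_conv([1, 2, 3, 0], [2, 3, 4, 5])
print(wrapped)  # [24, 22, 16, 22]
```

The x^4 and x^5 coefficients (22 and 15) have wrapped around and been added onto positions 0 and 1, which is why the inputs need to be padded before using circular convolution for polynomial multiplication.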
Using FFT and iFFT (the inverse FFT), circular convolution can be calculated in O(n \log(n)) by
Circular convolution in O(n \log(n)) time
def circ_conv(A,B):
    assert(len(A)==len(B))
    n = len(A)
    A = FFT(A)
    B = FFT(B)
    C = [0]*(n)
    for i in range(n):
        C[i]=A[i]*B[i]
    return iFFT(C)
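The reason the elementwise product works is the convolution theorem for the discrete Fourier transform. The post does not state it explicitly, so the notation here (\circledast for circular convolution, \odot for the elementwise product) is my own choice:

```latex
\mathrm{FFT}(A \circledast B) \;=\; \mathrm{FFT}(A) \odot \mathrm{FFT}(B)
\quad\Longrightarrow\quad
A \circledast B \;=\; \mathrm{iFFT}\bigl(\mathrm{FFT}(A) \odot \mathrm{FFT}(B)\bigr)
```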
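The only thing left to do is to connect common convolution to the circular one. The standard way to do this is to pad the end of A and B with zeros up to a common length N >= n+m-1, so that no index i+j wraps around, then do a circular convolution, and then remove the extra zeros at the end of C. The following is one way to implement it (N is rounded up to a power of two, which is what the Cooley-Tukey FFT used later requires):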
Polynomial multiplication in O((n+m) \log(n+m)) time
def conv(A,B):
    n = len(A)
    m = len(B)
    N = 1
    while N<n+m-1:
        N*=2
    A = A + [0]*(N-n)
    B = B + [0]*(N-m)
    C = circ_conv(A,B)
    return C[:n+m-1]
Hope that helps!
Some final remarks:
One reason that I really like this way of implementing polynomial multiplication is that during debugging/implementation of the code you can use the O(n^2) algorithms to be sure that it is not the FFT that is the problem. Also note that all the code I’ve written is Python code and should run as-is in Python, given that you’ve implemented FFT and iFFT.
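The debugging idea can be sketched as follows: plug the O(n^2) circular convolution into the padding routine and compare against the naive product on random inputs. If this passes but the FFT-based version fails, the bug is in the FFT. The function names `conv_naive`, `circ_conv_naive` and `conv_padded` are made up for this sketch so that the three variants can coexist.

```python
import random

def conv_naive(A, B):
    """Reference O(n*m) polynomial multiplication."""
    C = [0] * (len(A) + len(B) - 1)
    for i in range(len(A)):
        for j in range(len(B)):
            C[i + j] += A[i] * B[j]
    return C

def circ_conv_naive(A, B):
    """Reference O(n^2) circular convolution (stand-in for the FFT version)."""
    assert len(A) == len(B)
    n = len(A)
    C = [0] * n
    for i in range(n):
        for j in range(n):
            C[(i + j) % n] += A[i] * B[j]
    return C

def conv_padded(A, B):
    """Polynomial multiplication via zero padding + circular convolution."""
    n, m = len(A), len(B)
    N = 1
    while N < n + m - 1:
        N *= 2
    C = circ_conv_naive(A + [0] * (N - n), B + [0] * (N - m))
    return C[:n + m - 1]

# Randomized comparison of the two routes to the same product.
random.seed(0)
for _ in range(200):
    A = [random.randint(-9, 9) for _ in range(random.randint(1, 8))]
    B = [random.randint(-9, 9) for _ in range(random.randint(1, 8))]
    assert conv_padded(A, B) == conv_naive(A, B)
```

Once this agrees, swapping `circ_conv_naive` for the FFT-based `circ_conv` isolates any remaining mismatch to the transform itself (up to floating point rounding).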
EDIT: Added an example of a fully working polynomial multiplier running in Python 3 with time complexity O(n \log(n)), using a recursive implementation of the Cooley-Tukey algorithm for FFT.
from cmath import exp
from math import pi

def isPowerOfTwo(n):
    return n>0 and (n&(n-1))==0

# FFT using Cooley-Tukey, a divide and conquer algorithm
# running in O(n log(n)) time implemented recursively,
# NOTE that Cooley-Tukey requires n to be a power of two
def FFT(A):
    n = len(A)
    if n==1:
        return A
    assert(isPowerOfTwo(n))
    even = FFT(A[::2])
    odd = FFT(A[1::2])
    # Numerically stable way of "twiddling"
    return [even[k] + exp(-2*pi*k/n*1j)*odd[k] for k in range(n//2)] +\
           [even[k] - exp(-2*pi*k/n*1j)*odd[k] for k in range(n//2)]

# Inverse FFT
def iFFT(A):
    n = len(A)
    A = FFT([a.conjugate() for a in A])
    return [a.conjugate()/n for a in A]

# Circular convolution in O(nlog(n)) time
def circ_conv(A,B):
    assert(len(A)==len(B))
    n = len(A)
    A = FFT(A)
    B = FFT(B)
    C = [0]*(n)
    for i in range(n):
        C[i]=A[i]*B[i]
    return iFFT(C)

# Polynomial multiplication in O((n+m)log(n+m)) time
def conv(A,B):
    n = len(A)
    m = len(B)
    N = 1
    while N<n+m-1:
        N*=2
    A = A + [0]*(N-n)
    B = B + [0]*(N-m)
    C = circ_conv(A,B)
    return C[:n+m-1]

# Example
A = [1,1,2,3,4,5,6+7j,100]
print(A)
print(FFT(A))
print(iFFT(FFT(A)))

# Multiply (1+2x+3x^2) and (2+3x+4x^2+5x^3)
A = [1,2,3]
B = [2,3,4,5]
print(A,B,conv(A,B))
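Note that the FFT-based `conv` returns complex floats, so for integer polynomials the result usually has to be rounded back to integers. Below is a minimal self-contained sketch of that step; the helper names `fft` and `conv_int` and the `sign` parameter are my own and not part of the code above, and the rounding is only safe as long as the accumulated floating point error stays below 0.5.

```python
from cmath import exp
from math import pi

def fft(a, sign=-1):
    """Recursive Cooley-Tukey FFT; len(a) must be a power of two.
    sign=-1 gives the forward transform, sign=+1 the unscaled inverse."""
    n = len(a)
    if n == 1:
        return list(a)
    even = fft(a[0::2], sign)
    odd = fft(a[1::2], sign)
    tw = [exp(sign * 2j * pi * k / n) * odd[k] for k in range(n // 2)]
    return [even[k] + tw[k] for k in range(n // 2)] + \
           [even[k] - tw[k] for k in range(n // 2)]

def conv_int(A, B):
    """Multiply integer polynomials and round the result back to ints."""
    n, m = len(A), len(B)
    N = 1
    while N < n + m - 1:
        N *= 2
    fa = fft(A + [0] * (N - n))
    fb = fft(B + [0] * (N - m))
    fc = fft([x * y for x, y in zip(fa, fb)], sign=+1)
    # Divide by N to finish the inverse transform, then drop the tiny
    # imaginary part and round the real part to the nearest integer.
    return [round(c.real / N) for c in fc[:n + m - 1]]

print(conv_int([1, 2, 3], [2, 3, 4, 5]))  # [2, 7, 16, 22, 22, 15]
```

For large inputs with large coefficients the error can exceed 0.5, in which case the rounded answer is silently wrong; comparing against the O(n m) routine on a few test cases is a cheap way to notice this.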
If you want to actually use this code, there are two things you should consider:
- My implementation of FFT currently has a large constant cost because it is recursive. Switch to an in-place iterative implementation of Cooley-Tukey if you want to make it noticeably quicker. The wiki is great at explaining how to do this; note, however, that following the wiki might lead to really bad numerical stability.
- My implementation is written in Python, and the worst thing about Python is how slow it is with recursion. Switch to C++ and the code will run much, much quicker.
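The post does not say which part of the wiki version hurts numerical stability. One plausible reading (an assumption on my part) is the common pattern of building twiddle factors by repeated multiplication, where rounding error accumulates with every step, instead of calling `exp` once per factor as the code above does. This sketch measures the difference between the two; `N`, `step`, `repeated`, `direct` and `drift` are illustrative names.

```python
from cmath import exp
from math import pi

N = 1 << 16
step = exp(-2j * pi / N)

# Twiddles by repeated multiplication: each value inherits the rounding
# error of all previous multiplications.
repeated = [1 + 0j]
for _ in range(N // 2 - 1):
    repeated.append(repeated[-1] * step)

# Twiddles computed directly: one rounding per value, no accumulation.
direct = [exp(-2j * pi * k / N) for k in range(N // 2)]

# Largest deviation between the two tables.
drift = max(abs(r - d) for r, d in zip(repeated, direct))
print(drift)
```

The printed value depends on the platform, so no expected output is given; the point is only that it is typically several orders of magnitude above machine epsilon, and that this error feeds into every butterfly of the transform.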
EDIT2: What the heck, I might as well go the whole way. The following is a pretty quick and advanced implementation of FFT and NTT that allows for convolution using either one. You can switch between them freely. The basics of using convolution and circular convolution are still exactly the same.
from cmath import exp
from math import pi

# Quick convolution that can interchangeably use both
# FFT or NTT, see example at bottom.
# This uses a somewhat advanced implementation
# of Cooley-Tukey that hopefully runs quickly with high
# numerical stability.
# /pajenegod

def isPowerOfTwo(n):
    return n>0 and (n&(n-1))==0

# Permutes A in place with a bit reversal
# Ex. [0,1,2,3,4,5,6,7]->[0,4,2,6,1,5,3,7]
def bit_reversal(A):
    n = len(A)
    assert(isPowerOfTwo(n))
    # k = number of bits needed to index A, i.e. n == 2**k
    k = 0
    m = 1
    while m<n:
        m*=2
        k+=1
    for i in range(n):
        I = i
        j = 0
        # j becomes i with its k bits in reversed order
        for _ in range(k):
            j = j*2 + i%2
            i //= 2
        # Swap each pair only once
        if j>I:
            A[I],A[j]=A[j],A[I]
    return

### NTT ALGORITHM BASED ON COOLEY TUKEY
# Inplace NTT using Cooley-Tukey, a divide and conquer algorithm
# running in O(n log(n)) time implemented iteratively using bit reversal,
# NOTE that Cooley-Tukey requires n to be a power of two
# and also that n <= longest_conv, basically
# n is limited by the ntt_prime
# Remember to set ntt_prime and ntt_root before calling, for example
ntt_prime = (479<<21)+1
ntt_root = 3
def NTT_CT(A,inverse=False):
    # Some pre-calculations needed to do the ntt
    # longest_conv = largest power of two dividing ntt_prime-1
    longest_conv = 1
    while (ntt_prime-1)%(2*longest_conv)==0:
        longest_conv*=2
    ntt_base = pow(ntt_root,(ntt_prime-1)//longest_conv,ntt_prime)

    N = len(A)
    assert(isPowerOfTwo(N))
    assert(N<=longest_conv)
    for i in range(N):
        A[i]%=ntt_prime

    # Calculate the twiddle factors
    e = pow(ntt_base,longest_conv//N,ntt_prime)
    if inverse:
        e = pow(e,ntt_prime-2,ntt_prime)
    b = e
    twiddles = [1]
    while len(twiddles)<N//2:
        twiddles += [t*b%ntt_prime for t in twiddles]
        b = b**2%ntt_prime

    bit_reversal(A)
    n = 2
    while n<=N:
        offset = 0
        while offset<N:
            depth = N//n
            for k in range(n//2):
                ind1 = k + offset
                ind2 = k+n//2 + offset
                even = A[ind1]
                odd = A[ind2]*twiddles[k*depth]
                A[ind1] = (even + odd)%ntt_prime
                A[ind2] = (even - odd)%ntt_prime
            offset += n
        n*=2

    if inverse:
        inv_N = pow(N,ntt_prime-2,ntt_prime)
        for i in range(N):
            A[i] = A[i]*inv_N%ntt_prime
    return

### FFT ALGORITHM BASED ON Cooley-Tukey
# Inplace FFT using Cooley-Tukey, a divide and conquer algorithm
# running in O(n log(n)) time implemented iteratively using bit_reversal,
# NOTE that Cooley-Tukey requires n to be a power of two
def FFT_CT(A,inverse=False):
    N = len(A)
    assert(isPowerOfTwo(N))

    # Calculate the twiddle factors, with very good numerical stability
    e = -2*pi/N*1j
    if inverse:
        e = -e
    twiddles = [exp(e*k) for k in range(N//2)]

    bit_reversal(A)
    n = 2
    while n<=N:
        offset = 0
        while offset<N:
            depth = N//n
            for k in range(n//2):
                ind1 = k + offset
                ind2 = k+n//2 + offset
                even = A[ind1]
                odd = A[ind2]*twiddles[k*depth]
                A[ind1] = even + odd
                A[ind2] = even - odd
            offset += n
        n*=2

    if inverse:
        inv_N = 1.0/N
        for i in range(N):
            A[i]*=inv_N
    return A

# Circular convolution in O(nlog(n)) time
def circ_conv(A,B):
    assert(len(A)==len(B))
    n = len(A)
    A = list(A)
    B = list(B)
    FFT(A)
    FFT(B)
    C = [A[i]*B[i] for i in range(n)]
    FFT(C,inverse=True)
    return C

# Polynomial multiplication in O((n+m)log(n+m)) time
def conv(A,B):
    n = len(A)
    m = len(B)
    N = 1
    while N<n+m-1:
        N*=2
    A = A + [0]*(N-n)
    B = B + [0]*(N-m)
    C = circ_conv(A,B)
    return C[:n+m-1]

# Example
for ntt in [False,True]:
    # Switch between using NTT or FFT for convolution
    if ntt:
        # Set ntt prime
        ntt_prime = (119<<23)+1
        ntt_root = 3
        print('Using NTT for convolution')
        FFT = NTT_CT
    else:
        print('Using FFT for convolution')
        FFT = FFT_CT

    # Example
    A = [1,1,2,3,4,5,6,100]
    print('A',A)
    FFT(A)
    print('FFT(A)',A)
    FFT(A,inverse=True)
    print('iFFT(FFT(A))',A)

    # Multiply (1+2x+3x^2) and (2+3x+4x^2+5x^3)
    A = [1,2,3]
    B = [2,3,4,5]
    print('A=',A)
    print('B=',B)
    print('A*B=',conv(A,B))
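One caveat when switching to NTT: the result is only known modulo `ntt_prime`, so it equals the true product only when every true coefficient already lies in [0, ntt_prime). A conservative way to check this up front is sketched below; the helper name `ntt_result_is_exact` is hypothetical and the bound it uses is deliberately pessimistic.

```python
def ntt_result_is_exact(A, B, prime):
    """Conservative check that every coefficient of A*B fits in [0, prime).

    Assumes non-negative integer coefficients. Each product coefficient is
    a sum of at most min(len(A), len(B)) terms, each at most max(A)*max(B).
    """
    if min(A) < 0 or min(B) < 0:
        return False  # negative values come back as residues mod prime
    bound = min(len(A), len(B)) * max(A) * max(B)
    return bound < prime

prime = (119 << 23) + 1  # 998244353, the prime used in the example above

print(ntt_result_is_exact([1, 2, 3], [2, 3, 4, 5], prime))        # True
print(ntt_result_is_exact([10**5] * 1000, [10**5] * 1000, prime))  # False
```

When the check fails, the NTT output is still the correct product modulo the prime, which is fine if the problem asks for the answer modulo that prime and not otherwise.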
Can anyone provide a reference Java implementation?
The best approach is to have a look at someone’s code and learn how to use it if you don’t want to implement it yourself. I do the same thing, in fact.
Have a look now.
Thanks for the help!
We need more good people like @gorre_morre who actually explain why a post needs editing, so some poor mod isn’t confused here and there xD
@rds_98 I’ve just added an example of a fully working polynomial multiplier using the Cooley-Tukey FFT algorithm.
In my original answer I didn’t describe how to implement the FFT, just how polynomial multiplication can be done using FFT. The most common and probably simplest way to implement FFT is the Cooley-Tukey algorithm; the wiki has a pretty good article about it. So that is what I’m using for my example.
I doubt that it is possible to solve this particular problem using an FFT written in Java, because Java has no analogue of C’s long double type, and I suspect that BigDecimal is too slow. I tried to solve this problem in D using the FFT from the standard library, which uses doubles, and got WA due to precision errors. Only when I wrote an FFT with the real type (long double in C) did I get AC. If I wanted to solve this problem in Java, I would prefer the Karatsuba multiplication algorithm (divide and conquer).
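For reference, the Karatsuba alternative mentioned here can be sketched in a few lines. This is a minimal Python version (kept in Python to match the rest of the thread, not the Java the commenter has in mind); the names `naive_mul`, `karatsuba` and the `cutoff` parameter are my own. It works purely on integers, so it has no precision issues, at the cost of roughly O(n^1.585) time instead of O(n log n).

```python
def naive_mul(A, B):
    """O(n*m) schoolbook multiplication, used for small inputs."""
    C = [0] * (len(A) + len(B) - 1)
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            C[i + j] += a * b
    return C

def karatsuba(A, B, cutoff=16):
    """Multiply coefficient lists A and B (lowest degree first)."""
    out_len = len(A) + len(B) - 1
    n = max(len(A), len(B))
    if n <= cutoff:
        return naive_mul(A, B)
    # Pad to a common length and split into low and high halves.
    A = A + [0] * (n - len(A))
    B = B + [0] * (n - len(B))
    h = n // 2
    a_lo, a_hi = A[:h], A[h:]
    b_lo, b_hi = B[:h], B[h:]
    # Three recursive products instead of four.
    lo = karatsuba(a_lo, b_lo, cutoff)
    hi = karatsuba(a_hi, b_hi, cutoff)
    pad = [0] * (n - 2 * h)  # a_hi may be one element longer than a_lo
    a_sum = [x + y for x, y in zip(a_hi, a_lo + pad)]
    b_sum = [x + y for x, y in zip(b_hi, b_lo + pad)]
    mid = karatsuba(a_sum, b_sum, cutoff)
    # Combine: A*B = lo + x^h * (mid - lo - hi) + x^(2h) * hi
    C = [0] * (2 * n - 1)
    for i, v in enumerate(lo):
        C[i] += v
        C[i + h] -= v
    for i, v in enumerate(hi):
        C[i + 2 * h] += v
        C[i + h] -= v
    for i, v in enumerate(mid):
        C[i + h] += v
    return C[:out_len]

print(karatsuba([1, 2, 3], [2, 3, 4, 5]))  # [2, 7, 16, 22, 22, 15]
```

The `cutoff` value of 16 is an arbitrary choice for the sketch; the best threshold depends on the language and machine.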