快速傅里叶变换(FFT)随笔 - Roor - 博客园
终于学会了FFT,水一篇随笔记录一下
前置知识网上一大堆,这里就不多赘述了,直接切入正题
01 介绍FFT
这里仅指出FFT在竞赛中的一般应用,即优化多项式乘法
一般情况下,计算两个规模为n的多项式相乘的结果,复杂度为O(n2),但是神奇的FFT可以将其优化至O(nlogn)
FFT的过程一般为:
多项式的系数表示⟶多项式的点值表示⟶多项式的系数表示
网上对每一步的叫法都有一定出入,这里称第一步变换为快速傅里叶变换,第二步为快速傅里叶逆变换
02快速傅里叶变换
先指出,接下来的每个n都是2的整数次幂
首先我们有一个已知系数表达的n项的多项式
A(x)=a0+a1x+a2x2+⋯+an−1xn−1
要确定其的点值表达(y0,y1,y2,…,yn−1),朴素的做法就是取n个不同值代进去,这么做显然是O(n2)
下面介绍快速傅里叶变换的做法
首先将多项式按照奇偶分类
A(x)=(a0+a2x2+⋯+an−2xn−2)+(a1x+a3x3+⋯+an−1xn−1)
A(x)=(a0+a2x2+⋯+an−2xn−2)+x⋅(a1+a3x2+⋯+an−1xn−2)
设
A1(x)=a0+a2x+⋯+an−2x2n−2
A2(x)=a1+a3x+⋯+an−1x2n−2
不难发现
A(x)=A1(x2)+xA2(x2)
令k<2n
将ωnk代入得
A(ωnk)=A1(ωn2k)+ωnkA2(ωn2k)
A(ωnk)=A1(ω2nk)+ωnkA2(ω2nk)
将ωnk+2n代入得
A(ωnk+2n)=A1(ωn2k+n)+ωnk+2nA2(ωn2k+n)
A(ωnk+2n)=A1(ωn2k⋅ωnn)−ωnkA2(ωn2k⋅ωnn)
A(ωnk+2n)=A1(ωn2k)−ωnkA2(ωn2k)
A(ωnk+2n)=A1(ω2nk)−ωnkA2(ω2nk)
显然的,这两个式子只有常数项不同
当k取遍[0,2n−1]中所有值时k+2n也取遍[2n,n−1]中所有值
因此,我们只需要在[0,2n−1]中枚举k,这样就可以算出A(ωni)(i∈[0,n−1])的所有值
如果我们已知A1(x),A2(x)在ω2n0,ω2n1,…,ω2n2n−1的值,通过上面的两个式子就可以在O(n)的时间内求出A(x)
而求A1(x),A2(x)正好是求A(x)的子问题,并且可以递归求解
03快速傅里叶逆变换
在上面我们将一个多项式的系数表示转换成了点值表示,这里我们要研究将一个多项式的点值表示转换成系数表示
记(a0,a1,…,an−1)是A(x)的系数向量,而我们已知A(x)的点值表达为(A(x0),A(x1),…,A(xn−1))
设向量(d0,d1,…,dn−1)是以(a0,a1,…,an−1)为系数向量时,快速傅里叶变换求得的点值表示
构造一个多项式F(x)=d0+d1x+d2x2+⋯+dn−1xn−1
设(c0,c1,…,cn−1)是F(x)在x=ωn−k时的点值表示,即ck=F(ωn−k),也就是ck=∑i=0n−1di(ωn−k)i
我们知道dk=A(ωnk),也就是dk=∑j=0n−1aj(ωnk)j
联立上面两个和式得
ck=∑i=0n−1[∑j=0n−1aj(ωni)j](ωn−k)i
=∑i=0n−1∑j=0n−1aj(ωnj)i(ωn−k)i
=∑j=0n−1aj∑i=0n−1(ωnjωn−k)i
=∑j=0n−1aj∑i=0n−1(ωnj−k)i
我们分情况讨论后面的一个和式∑i=0n−1(ωnj−k)i
j= k
那么后面的一个和式就转换为一个等比求和
∑i=0n−1(ωnj−k)i=1−ωnj−k(ωnj−k)0[1−(ωnj−k)n]
=1−ωnj−k1−(ωnj−k)n
=1−ωnj−k1−(ωnn)j−k
=1−ωnj−k1−1j−k
=1−ωnj−k0
=0
j=k
那么ωnj−k=1
∑i=0n−1(ωnj−k)i=n
由上面两种情况,我们知道当且仅当j=k时,整个式子才有值,其余情况都为0
所以有
cj=ajn
aj=ncj
到这里,我们就求出了A(x)的系数表达
从整个分析过程看,我们是将A(x)的点值表示(A(x0),A(x1),…,A(xn−1))当作一个新的多项式F(x)的系数表示,再对F(x)做快速傅里叶变换得到(c0,c1,…,cn−1),然后再除以n就得到A(x)的系数表示了。需要指出的是,快速傅里叶变换中x=ωnk但是在逆变换中代入的是ωn−k
04实现
学会了前面的方法,具体实现就不难了
对于求C(x)=A(x)⋅B(x)
将A(x)和B(x)都转化成点值表达,即(a0,a1,…,an−1)和(b0,b1,…,bn−1)
对应相乘(a0b0,a1b1,…,an−1bn−1),再将这一结果变换成C(x)的系数表达就完成了
贴一份C++的代码,这是洛谷上的FFT板子题P3803
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#define MAXN 4000006
using namespace std;
class complex
{
public:
complex(){}
complex(double a,double b)
{
this->a=a;
this->b=b;
}
double a,b;
}a[MAXN],b[MAXN];
complex operator+ (complex x,complex y)
{
return complex(x.a+y.a,x.b+y.b);
}
complex operator- (complex x,complex y)
{
return complex(x.a-y.a,x.b-y.b);
}
complex operator* (complex x,complex y)
{
return complex(x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a);
}
const double pi=acos(-1.0);
void FFT(int l,complex *arr,int f)
{
if(l==1) return;
int dl=l>>1;
complex a1[dl],a2[dl];
for(int i=0;i<l;i+=2)
{
a1[i>>1]=arr[i];
a2[i>>1]=arr[i+1];
}
FFT(dl,a1,f);
FFT(dl,a2,f);
complex wn=complex(cos(2.0*pi/l),sin(2.0*pi/l)*f),w=complex(1.0,0.0);
for(int i=0;i<dl;i++,w=w*wn)
{
arr[i]=a1[i]+w*a2[i];
arr[i+dl]=a1[i]-w*a2[i];
}
}
int n,m,N;
int main()
{
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)
scanf("%lf",&a[i].a);
for(int i=0;i<=m;i++)
scanf("%lf",&b[i].a);
N=1;
while(N<n+m+1) N<<=1;
FFT(N,a,1);
FFT(N,b,1);
for(int i=0;i<N;i++)
a[i]=a[i]*b[i];
FFT(N,a,-1);
for(int i=0;i<n+m+1;i++)
printf("%d ",(int)(a[i].a/N+0.5));
puts("");
return 0;
}
闲着没事干,再贴一份Python的
import numpy as np
pi = np.arccos(-1.0)
def read():
def get_numbers():
try:
read.s = input().split()
read.s_len = len(read.s)
if read.s_len == 0:
get_numbers()
read.cnt = 0
return 1
except:
return 0
if not hasattr(read, 'cnt'):
if not get_numbers():
return 0
if read.cnt == read.s_len:
if not get_numbers():
return 0
read.cnt += 1
return eval(read.s[read.cnt - 1])
n = int(read())
m = int(read())
class Complex:
# 复数类
def __init__(self, a=0.0, b=0.0):
self.a = a
self.b = b
def __add__(self, other):
return Complex(self.a + other.a, self.b + other.b)
def __sub__(self, other):
return Complex(self.a - other.a, self.b - other.b)
def __mul__(self, other):
return Complex(self.a * other.a - self.b * other.b, self.a * other.b + self.b * other.a)
def fft(num, f, args):
if num == 1:
return
div_num = num >> 1
a1 = []
a2 = []
for i in range(0, num, 2):
a1.append(args[i])
a2.append(args[i + 1])
fft(div_num, f, a1)
fft(div_num, f, a2)
wn = Complex(np.cos(2.0 * pi / num), np.sin(2.0 * pi / num) * f)
w = Complex(1.0, 0.0)
for i in range(0, div_num):
args[i] = a1[i] + w * a2[i]
args[i + div_num] = a1[i] - w * a2[i]
w = w * wn
aa = []
bb = []
for j in range(0, n + 1):
aa.append(Complex(float(read()), 0.0))
for j in range(0, m + 1):
bb.append(Complex(float(read()), 0.0))
nn = 1
while nn < n + m + 1:
nn <<= 1
for j in range(n + 1, nn):
aa.append(Complex(0.0, 0.0))
for j in range(m + 1, nn):
bb.append(Complex(0.0, 0.0))
fft(nn, 1, aa)
fft(nn, 1, bb)
for j in range(0, nn):
aa[j] = aa[j] * bb[j]
fft(nn, -1, aa)
for j in range(0, n + m + 1):
print(int(aa[j].a / nn + 0.5), end=' ')
无奈Python实在是太慢了……
05结语
总算是学会了快速傅里叶变换,某种程度上说是弥补了过去的某些遗憾吧。
这里贴一张大佬的图,解释了FFT的思路
这里也推荐一下大佬的博客,以供参考