Skip to content

快速傅里叶变换(FFT)随笔 - Roor - 博客园

约 2110 字大约 7 分钟

FFT数学

2023-02-14

终于学会了FFT,水一篇随笔记录一下

前置知识网上一大堆,这里就不多赘述了,直接切入正题

01 介绍FFT

这里仅指出FFT在竞赛中的一般应用,即优化多项式乘法

一般情况下,计算两个规模为nn的多项式相乘的结果,复杂度为O(n2)O(n^2),但是神奇的FFT可以将其优化至O(nlogn)O(nlogn)

FFT的过程一般为:

多项式的系数表示\longrightarrow多项式的点值表示\longrightarrow多项式的系数表示

网上对每一步的叫法都有一定出入,这里称第一步变换为快速傅里叶变换,第二步为快速傅里叶逆变换

02快速傅里叶变换

先指出,接下来的每个nn都是22的整数次幂

首先我们有一个已知系数表达的nn项的多项式

A(x)=a0+a1x+a2x2++an1xn1A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}

要确定其的点值表达(y0,y1,y2,,yn1)(y_0,y_1,y_2,\dots,y_{n-1}),朴素的做法就是取nn个不同值代进去,这么做显然是O(n2)O(n^2)

下面介绍快速傅里叶变换的做法

首先将多项式按照奇偶分类

A(x)=(a0+a2x2++an2xn2)+(a1x+a3x3++an1xn1)A(x)=(a_0+a_2x^2+\dots+a_{n-2}x^{n-2})+(a_1x+a_3x^3+\dots+a_{n-1}x^{n-1})

A(x)=(a0+a2x2++an2xn2)+x(a1+a3x2++an1xn2)A(x)=(a_0+a_2x^2+\dots+a_{n-2}x^{n-2})+x\cdot(a_1+a_3x^2+\dots+a_{n-1}x^{n-2})

A1(x)=a0+a2x++an2xn22A_1(x)=a_0+a_2x+\dots+a_{n-2}x^{\tfrac{n-2}{2}}

A2(x)=a1+a3x++an1xn22A_2(x)=a_1+a_3x+\dots+a_{n-1}x^{\tfrac{n-2}{2}}

不难发现

A(x)=A1(x2)+xA2(x2)A(x)=A_1(x^2)+xA_2(x^2)

k<n2k<\frac{n}{2}

ωnk\omega_{n}^k代入得

A(ωnk)=A1(ωn2k)+ωnkA2(ωn2k)A(\omega_{n}^k)=A_1(\omega_{n}^{2k})+\omega_{n}^{k}A_2(\omega_{n}^{2k})

A(ωnk)=A1(ωn2k)+ωnkA2(ωn2k)A(\omega_{n}^k)=A_1(\omega_{\tfrac{n}{2}}^{k})+\omega_{n}^{k}A_2(\omega_{\tfrac{n}{2}}^{k})

ωnk+n2\omega_{n}^{k+\frac{n}{2}}代入得

A(ωnk+n2)=A1(ωn2k+n)+ωnk+n2A2(ωn2k+n)A(\omega_{n}^{k+\tfrac{n}{2}})=A_1(\omega_{n}^{2k+n})+\omega_{n}^{k+\tfrac{n}{2}}A_2(\omega_{n}^{2k+n})

A(ωnk+n2)=A1(ωn2kωnn)ωnkA2(ωn2kωnn)A(\omega_{n}^{k+\tfrac{n}{2}})=A_1(\omega_{n}^{2k}\cdot\omega_{n}^{n})-\omega_{n}^{k}A_2(\omega_{n}^{2k}\cdot\omega_{n}^{n})

A(ωnk+n2)=A1(ωn2k)ωnkA2(ωn2k)A(\omega_{n}^{k+\tfrac{n}{2}})=A_1(\omega_{n}^{2k})-\omega_{n}^{k}A_2(\omega_{n}^{2k})

A(ωnk+n2)=A1(ωn2k)ωnkA2(ωn2k)A(\omega_{n}^{k+\tfrac{n}{2}})=A_1(\omega_{\tfrac{n}{2}}^{k})-\omega_{n}^{k}A_2(\omega_{\tfrac{n}{2}}^{k})

显然的,这两个式子只有常数项不同

kk取遍[0,n21][0,\frac{n}{2}-1]中所有值时k+n2k+\dfrac{n}{2}也取遍[n2,n1][\dfrac{n}{2},n-1]中所有值

因此,我们只需要在[0,n21][0,\dfrac{n}{2}-1]中枚举kk,这样就可以算出A(ωni)(i[0,n1])A(\omega_{n}^i)\quad(i\in[0,n-1])的所有值

如果我们已知A1(x),A2(x)A_1(x),A_2(x)ωn20,ωn21,,ωn2n21\omega_{\tfrac{n}{2}}^0,\omega_{\tfrac{n}{2}}^1,\dots,\omega_{\tfrac{n}{2}}^{\tfrac{n}{2}-1}的值,通过上面的两个式子就可以在O(n)O(n)的时间内求出A(x)A(x)

而求A1(x),A2(x)A_1(x),A_2(x)正好是求A(x)A(x)的子问题,并且可以递归求解

03快速傅里叶逆变换

在上面我们将一个多项式的系数表示转换成了点值表示,这里我们要研究将一个多项式的点值表示转换成系数表示

(a0,a1,,an1)(a_0,a_1,\dots,a_{n-1})A(x)A(x)的系数向量,而我们已知A(x)A(x)的点值表达为(A(x0),A(x1),,A(xn1))(A(x_0),A(x_1),\dots,A(x_{n-1}))

设向量(d0,d1,,dn1)(d_0,d_1,\dots,d_{n-1})是以(a0,a1,,an1)(a_0,a_1,\dots,a_{n-1})为系数向量时,快速傅里叶变换求得的点值表示

构造一个多项式F(x)=d0+d1x+d2x2++dn1xn1F(x)=d_0+d_1x+d_2x^2+\dots+d_{n-1}x^{n-1}

(c0,c1,,cn1)(c_0,c_1,\dots,c_{n-1})F(x)F(x)x=ωnkx=\omega_n^{-k}时的点值表示,即ck=F(ωnk)c_k=F(\omega_n^{-k}),也就是ck=i=0n1di(ωnk)ic_k=\sum_{i=0}^{n-1}d_i(\omega_n^{-k})^i

我们知道dk=A(ωnk)d_k=A(\omega_n^k),也就是dk=j=0n1aj(ωnk)jd_k=\sum_{j=0}^{n-1}a_j(\omega_n^k)^j

联立上面两个和式得

ck=i=0n1[j=0n1aj(ωni)j](ωnk)ic_k=\sum_{i=0}^{n-1} [\sum_{j=0}^{n-1}a_j(\omega_n^i)^j] (\omega_n^{-k})^i

=i=0n1j=0n1aj(ωnj)i(ωnk)i\quad \:=\sum_{i=0}^{n-1} \sum_{j=0}^{n-1}a_j(\omega_n^j)^i (\omega_n^{-k})^i

=j=0n1aji=0n1(ωnjωnk)i\quad \:=\sum_{j=0}^{n-1} a_j \sum_{i=0}^{n-1} (\omega_n^j \omega_n^{-k})^i

=j=0n1aji=0n1(ωnjk)i\quad \:=\sum_{j=0}^{n-1} a_j \sum_{i=0}^{n-1} (\omega_n^{j-k})^i

我们分情况讨论后面的一个和式i=0n1(ωnjk)i\sum_{i=0}^{n-1} (\omega_n^{j-k})^i

j kj \neq\ k

那么后面的一个和式就转换为一个等比求和

i=0n1(ωnjk)i=(ωnjk)0[1(ωnjk)n]1ωnjk\sum_{i=0}^{n-1} (\omega_n^{j-k})^i=\frac{{(\omega_n^{j-k})}^0 [1-(\omega_n^{j-k})^n]}{1-\omega_n^{j-k}}

=1(ωnjk)n1ωnjk\qquad \qquad \quad \: \: \:=\frac{1-(\omega_n^{j-k})^n}{1-\omega_n^{j-k}}

=1(ωnn)jk1ωnjk\qquad \qquad \quad \: \: \:=\frac{1-(\omega_n^n)^{j-k}}{1-\omega_n^{j-k}}

=11jk1ωnjk\qquad \qquad \quad \: \: \:=\frac{1-1^{j-k}}{1-\omega_n^{j-k}}

=01ωnjk\qquad \qquad \quad \: \: \:=\frac{0}{1-\omega_n^{j-k}}

=0\qquad \qquad \quad \: \: \:=0

j=kj = k

那么ωnjk=1\omega_n^{j-k} = 1

i=0n1(ωnjk)i=n\sum_{i=0}^{n-1} (\omega_n^{j-k})^i = n

由上面两种情况,我们知道当且仅当j=kj = k时,整个式子才有值,其余情况都为00

所以有

cj=ajnc_j=a_jn

aj=cjna_j = \frac{c_j}{n}

到这里,我们就求出了A(x)A(x)的系数表达

从整个分析过程看,我们是将A(x)A(x)的点值表示(A(x0),A(x1),,A(xn1))(A(x_0),A(x_1),\dots,A(x_{n-1}))当作一个新的多项式F(x)F(x)的系数表示,再对F(x)F(x)做快速傅里叶变换得到(c0,c1,,cn1)(c_0,c_1,\dots,c_{n-1}),然后再除以nn就得到A(x)A(x)的系数表示了。需要指出的是,快速傅里叶变换中x=ωnkx=\omega_n^k但是在逆变换中代入的是ωnk\omega_n^{-k}

04实现

学会了前面的方法,具体实现就不难了

对于求C(x)=A(x)B(x)C(x)=A(x) \cdot B(x)

A(x)A(x)B(x)B(x)都转化成点值表达,即(a0,a1,,an1)(a_0,a_1,\dots,a_{n-1})(b0,b1,,bn1)(b_0,b_1,\dots,b_{n-1})

对应相乘(a0b0,a1b1,,an1bn1)(a_0b_0,a_1b_1,\dots,a_{n-1}b_{n-1}),再将这一结果变换成C(x)C(x)的系数表达就完成了

贴一份C++的代码,这是洛谷上的FFT板子题P3803

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cmath>
    #define MAXN 4000006
    using namespace std;
    class complex
    {
    public:
        complex(){}
        complex(double a,double b)
        {
            this->a=a;
            this->b=b;
        }
        double a,b;
    }a[MAXN],b[MAXN];
    complex operator+ (complex x,complex y)
    {
        return complex(x.a+y.a,x.b+y.b);
    }
    complex operator- (complex x,complex y)
    {
        return complex(x.a-y.a,x.b-y.b);
    }
    complex operator* (complex x,complex y)
    {
        return complex(x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a);
    }
    const double pi=acos(-1.0);
    void FFT(int l,complex *arr,int f)
    {
        if(l==1) return;
        int dl=l>>1;
        complex a1[dl],a2[dl];
        for(int i=0;i<l;i+=2)
        {
            a1[i>>1]=arr[i];
            a2[i>>1]=arr[i+1];
        }
        FFT(dl,a1,f);
        FFT(dl,a2,f);
        complex wn=complex(cos(2.0*pi/l),sin(2.0*pi/l)*f),w=complex(1.0,0.0);
        for(int i=0;i<dl;i++,w=w*wn)
        {
            arr[i]=a1[i]+w*a2[i];
            arr[i+dl]=a1[i]-w*a2[i];
        }
    }
    int n,m,N;
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=0;i<=n;i++)
            scanf("%lf",&a[i].a);
        for(int i=0;i<=m;i++)
            scanf("%lf",&b[i].a);
        N=1;
        while(N<n+m+1) N<<=1;
        FFT(N,a,1);
        FFT(N,b,1);
        for(int i=0;i<N;i++)
            a[i]=a[i]*b[i];
        FFT(N,a,-1);
        for(int i=0;i<n+m+1;i++)
            printf("%d ",(int)(a[i].a/N+0.5));
        puts("");
        return 0;
    }

闲着没事干,再贴一份Python的

    import numpy as np
    
    pi = np.arccos(-1.0)
    
    
    def read():
        def get_numbers():
            try:
                read.s = input().split()
                read.s_len = len(read.s)
                if read.s_len == 0:
                    get_numbers()
                read.cnt = 0
                return 1
            except:
                return 0
    
        if not hasattr(read, 'cnt'):
            if not get_numbers():
                return 0
        if read.cnt == read.s_len:
            if not get_numbers():
                return 0
        read.cnt += 1
        return eval(read.s[read.cnt - 1])
    
    
    n = int(read())
    m = int(read())
    
    
    class Complex:
        # 复数类
    
        def __init__(self, a=0.0, b=0.0):
            self.a = a
            self.b = b
    
        def __add__(self, other):
            return Complex(self.a + other.a, self.b + other.b)
    
        def __sub__(self, other):
            return Complex(self.a - other.a, self.b - other.b)
    
        def __mul__(self, other):
            return Complex(self.a * other.a - self.b * other.b, self.a * other.b + self.b * other.a)
    
    
    def fft(num, f, args):
        if num == 1:
            return
        div_num = num >> 1
        a1 = []
        a2 = []
        for i in range(0, num, 2):
            a1.append(args[i])
            a2.append(args[i + 1])
        fft(div_num, f, a1)
        fft(div_num, f, a2)
        wn = Complex(np.cos(2.0 * pi / num), np.sin(2.0 * pi / num) * f)
        w = Complex(1.0, 0.0)
    
        for i in range(0, div_num):
            args[i] = a1[i] + w * a2[i]
            args[i + div_num] = a1[i] - w * a2[i]
            w = w * wn
    
    
    aa = []
    bb = []
    for j in range(0, n + 1):
        aa.append(Complex(float(read()), 0.0))
    for j in range(0, m + 1):
        bb.append(Complex(float(read()), 0.0))
    
    nn = 1
    while nn < n + m + 1:
        nn <<= 1
    
    for j in range(n + 1, nn):
        aa.append(Complex(0.0, 0.0))
    for j in range(m + 1, nn):
        bb.append(Complex(0.0, 0.0))
    
    fft(nn, 1, aa)
    fft(nn, 1, bb)
    
    for j in range(0, nn):
        aa[j] = aa[j] * bb[j]
    fft(nn, -1, aa)
    
    for j in range(0, n + m + 1):
        print(int(aa[j].a / nn + 0.5), end=' ')

无奈Python实在是太慢了……

05结语

总算是学会了快速傅里叶变换,某种程度上说是弥补了过去的某些遗憾吧。

这里贴一张大佬的图,解释了FFT的思路

这里也推荐一下大佬的博客,以供参考

快速傅里叶变换(FFT)详解 - 自为风月马前卒 - 博客园 (cnblogs.com)

一小时学会快速傅里叶变换(Fast Fourier Transform) - 知乎 (zhihu.com)