首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >快速傅里叶变换FFT& 数论变换NTT

快速傅里叶变换FFT& 数论变换NTT

作者头像
饶文津
发布2020-06-02 16:04:01
8690
发布2020-06-02 16:04:01
举报
文章被收录于专栏:饶文津的专栏饶文津的专栏

相关知识

快速傅里叶变换FFT

递归的FFT(c++代码)

typedef complex<double> CD;
const double pi = acos(-1);
CD tmp[N],epsilon[N];
void init_epsilon(int n){
    for(int i = 0; i < n; ++i){
        epsilon[i] = CD(cos(2.0 * pi * i / n), sin(2.0 * pi * i / n)); 
        arti_epsilon[i] = conj(epsilon[i]);
    }
}
void recursive_fft(int n, CD* A,int offset, int step, CD* w){
    if(n==1)return;
    int m=n>>1;
    recursive_fft(m,A,offset,step<<1,w);
    recursive_fft(m,A,offset+step,step<<1,w);
    for(int k=0;k<m;++k){
        int pos=2*step*k;
        tmp[k]  =A[pos+offset]+w[k*step]*A[pos+offset+step];
        tmp[k+m]=A[pos+offset]-w[k*step]*A[pos+offset+step];
    }
    for(int i=0;i<n;++i)
        A[i*step+offset]=tmp[i];
}

迭代实现

但是递归需要比较大的空间,如何实现迭代的写法?

观察递归的过程,第一步:

0(000)2(010)4(100)6(110),1(001)3(011)5(101)7(111)

第二步:

0 (000)4(100),2(010) 6(110),1(001)5(101)3(011)7(111)

将下标的二进制翻转过来就是:

000,001,010,011,100,101,110,111

对应了0,1,2,3,...

用reverse(i)表示i的二进制翻转后的数。就相当于二进制从高位到低位的+1,是从左往右遇到第一个0,就改为1,左边的1改为0。

int reverse(int x){
    for(int l=1<<bit_length;(x^=l)<l;l>>=l);
    return x;
}

我们得到递归最后一步的数组:

0,4,2,6,1,5,3,7

就可以从下到上迭代了。

void bit_reverse(CD* A,int n){
    for(int i=0,j=0;i<n;++i){
        if(i>j)swap(A[i],A[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
}
void fft(CD* A, int n, CD* w){
    bit_reverse(A,n);
    for(int i=2;i<=n;i<<=1)//自下到上,i为每一层的步长,或者说子问题长度
      for(int j=0,m=i>>1;j<n;j+=i)//j为偏移量,或者说这一层的每一个子问题的起点
        for(int k=0;k<m;++k){//k=0..i/2,计算出子问题的第k个和第k+i/2个的值
          CD b=w[n/i*k]*A[j+m+k];
          A[j+m+k]=A[j+k]-b;
          A[j+k]+=b;
        }
}

离散傅里叶逆变换(IDFT)

FFT解决高精度乘法,c++代码

51 Nod 1028 大数乘法 V2

注意FFT因为用了cos和sin,以及是浮点数计算,所以会有精度误差,故加了0.5。

#include <bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for(int i=l;i<r;++i)
#define per(i,l,r) for(int i=r-1;i>=l;--i)
#define SZ(x) ((int)(x).size())

typedef double dd;
typedef complex<dd> CD;
const dd PI=acos(-1.0);
const int L=18,N=1<<L;

CD eps[N],inv_eps[N],f[N],g[N];
void init_eps(int p){
    rep(i,0,p)eps[i]=CD(cos(PI*i*2/p),sin(PI*i*2/p)),inv_eps[i]=conj(eps[i]);
}
void fft(CD p[], int n, CD w[]){
    for(int i=0,j=0;i<n;++i){
        if(i>j)swap(p[i],p[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
    for(int i=2;i<=n;i<<=1)
        for(int j=0,m=i>>1;j<n;j+=i)
            rep(k,0,m){
                CD b=w[n/i*k]*p[j+m+k];
                p[j+m+k]=p[j+k]-b;
                p[j+k]+=b;
            }
}
int ans[N];
int main(){
    string a,b;
    cin>>a>>b;
    int n=max(SZ(a),SZ(b)),p=1;
    while(p<n)p<<=1;p<<=1;
    rep(i,0,p)f[i]=g[i]=0;
    n=0;per(i,0,SZ(a))f[n++]=a[i]-'0';
    n=0;per(i,0,SZ(b))g[n++]=b[i]-'0';

    init_eps(p);
    fft(f,p,eps);fft(g,p,eps);
    rep(i,0,p)f[i]*=g[i];
    fft(f,p,inv_eps);

    int t=0;
    rep(i,0,p){
        ans[i]=t+(f[i].real()+0.5)/p;
        if(ans[i]>9){t=ans[i]/10;ans[i]%=10;}
        else t=0;
    }
    bool flag=0;
    per(i,0,p)if(ans[i]||flag){
        printf("%d",ans[i]);flag=1;
    }
    if(flag==0)puts("0");
    return 0;
}

NTT的c++代码

仍然是上面高精度乘法那题

#include <bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for(int i=l;i<r;++i)
#define per(i,l,r) for(int i=r-1;i>=l;--i)
#define SZ(x) ((int)(x).size())

typedef long long LL;
const int L=18,N=1<<L;

const LL C = 479;
const LL P = (C << 21) + 1;
const LL G = 3;

LL qpow(LL a, LL b, LL m){
    LL ans = 1;
    for(a%=m;b;b>>=1,a=a*a%m)if(b&amp;1)ans=ans*a%m;
    return ans;
}

LL eps[N],inv_eps[N],f[N],g[N];
void init_eps(int n){
    LL t=(P-1)/n, invG=qpow(G,P-2,P);
    rep(i,0,n) eps[i]=qpow(G,t*i,P),inv_eps[i]=qpow(invG,t*i,P);
}
void fft(LL p[], int n, LL w[]){
    for(int i=0,j=0;i<n;++i){
        if(i>j)swap(p[i],p[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
    for(int i=2;i<=n;i<<=1)
        for(int j=0,m=i>>1;j<n;j+=i)
            rep(k,0,m){
                LL b=w[n/i*k]*p[j+m+k]%P;
                p[j+m+k]=(p[j+k]-b+P)%P;
                p[j+k]=(p[j+k]+b)%P;
            }
}
LL ans[N];
int main(){
    string a,b;
    cin>>a>>b;
    int n=max(SZ(a),SZ(b)),p=1;
    while(p<n)p<<=1;p<<=1;
    rep(i,0,p)f[i]=g[i]=0;
    n=0;per(i,0,SZ(a))f[n++]=a[i]-'0';
    n=0;per(i,0,SZ(b))g[n++]=b[i]-'0';

    init_eps(p);
    fft(f,p,eps);fft(g,p,eps);
    rep(i,0,p)f[i]=f[i]*g[i]%P;
    fft(f,p,inv_eps);

    int t=0;
    LL invp=qpow(p,P-2,P);
    rep(i,0,p){
        ans[i]=(t+f[i]*invp%P)%P;
        if(ans[i]>9){t=ans[i]/10;ans[i]%=10;}
        else t=0;
    }
    bool flag=0;
    per(i,0,p)if(ans[i]||flag){
        printf("%lld",ans[i]);flag=1;
    }
    if(flag==0)puts("0");
    return 0;
}
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018-09-18 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 相关知识
  • 快速傅里叶变换FFT
  • 递归的FFT(c++代码)
    • 迭代实现
    • 离散傅里叶逆变换(IDFT)
    • FFT解决高精度乘法,c++代码
    • NTT的c++代码
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档