前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用pytorch和卷积实现stft/istft

使用pytorch和卷积实现stft/istft

原创
作者头像
languageX
发布2021-12-01 20:34:48
4.3K2
发布2021-12-01 20:34:48
举报
文章被收录于专栏:计算机视觉CV计算机视觉CV

语音项目中我们通常会使用stft对特征进行提取,很多python库也提供了接口。本文主要介绍使用librosa,torch,以及卷积方式进行stft和istft的运算。

1. stft运算

关于傅里叶变换和逆变换的基础知识在之前文章中已经做过介绍:https://cloud.tencent.com/developer/article/1811451

这里就不再介绍了,下面直接通过代码来得出音频振幅谱和相位谱。

2. librosa接口

librosa提供的接口非常简单,我们通过一个例子进行stft和istft来恢复一段音频

代码语言:javascript
复制
def test_lib(data):
    win_len = 320
    win_hop = 160
    fft_len = 512
    spec = librosa.stft(data, window='hann', win_length=win_len, n_fft=fft_len, hop_length=win_hop,
                                center=True)
    outputs = librosa.istft(spec, window='hann', win_length=win_len, hop_length=win_hop,
                 center=True)
    sf.write('./lib_stft.wav', outputs, 16000)
    return outputs

其中librosa_stft是一个复数形式,我们可以获取其中的一些特征,比如

代码语言:javascript
复制
# 实部
real = np.real(spec)
# 虚部
imag = np.imag(spec)
# 振幅谱
mags = np.sqrt(real ** 2 + imag ** 2)
# 相位谱
phase = np.angle(spec)

3. torch接口

同样我们通过一个例子使用torch提供的接口来进行stft和istft恢复一段音频

代码语言:javascript
复制
def test_torch(inputs):
    fft_len=512
    win_len=320
    len_hop=160
    inputs = torch.from_numpy(inputs.reshape(1,-1).astype(np.float32))
    window = torch.hann_window(win_len)
    spec = torch.stft(inputs, fft_len, len_hop, win_len, window, center=True, return_complex=False)
    print("stft out", spec.shape)
    out = torch.istft(spec, fft_len, len_hop, win_len, window, True, return_complex=False)
    return out

其中spec是一个虚部和实部concatenate一起的,我们同样可以获取其中的一些特征:

代码语言:javascript
复制
real = spec[:, :, :, 0]  # 实部
imag = spec[:, :, :, 1]  # 虚部
mags = torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2)))
phase = torch.atan2(imag.data, real.data)

4. 利用卷积实现stft

python中使用librosa以及pytorch中使用接口都是很常用的特征提取方式,但是有时我们需要将算子移植到终端就比较麻烦,框架通常不直接提供这两个op,所以使用卷积实现stft和istft更容易进行工程移植。

我参考了这里的实现:https://github.com/huyanxin/DeepComplexCRN/blob/master/conv_stft.py

其中在使用test_fft()测试时会提示错误,所以对代码进行了一点修改,其中修改地方添加了注释:

代码语言:javascript
复制
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)#**0.5
    
    N = fft_len
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T
    
    if invers :
        kernel = np.linalg.pinv(kernel).T 

    kernel = kernel*window
    kernel = kernel[:, None, :]
    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32))


class ConvSTFT(nn.Module):

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        super(ConvSTFT, self).__init__() 
        
        if fft_len == None:
            self.fft_len = np.int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        
        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        #self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        # 注意这里pad方式的对齐
        inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride], mode='reflect')
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
        # 前半段系数为实数,后半段系数为虚数
        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim//2+1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real**2+imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

class ConviSTFT(nn.Module):

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        super(ConviSTFT, self).__init__() 
        if fft_len == None:
            self.fft_len = np.int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
        #self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:,None,:])

    def forward(self, inputs, phase=None):
        """
        inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
        phase: [B, N//2+1, T] (if not none)
        """ 

        if phase is not None:
            real = inputs*torch.cos(phase)
            imag = inputs*torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) 

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1,1,inputs.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs/(coff+1e-8)
        #outputs = torch.where(coff == 0, outputs, outputs/coff)
        outputs = outputs[...,self.win_len-self.stride:-(self.win_len-self.stride)] 
        return outputs

def test_fft():
    torch.manual_seed(20)
    win_len = 320
    win_inc = 160 
    fft_len = 512
    inputs = torch.randn([1, 1, 16000*4])
    fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real')
    import librosa
    outputs1 = fft(inputs)[0]
    outputs1 = outputs1.numpy()[0]
    np_inputs = inputs.numpy().reshape([-1])
    # center=True, 在input的两侧,分别镜像填充n_fft//2个数据
    librosa_stft = librosa.stft(np_inputs, window='hann',win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=True)
    print(np.mean((outputs1 - np.abs(librosa_stft))**2))

def test_conv_complex(data):
    inputs = data.reshape([1, 1, -1])
    N = 320
    inc = 160
    fft_len = 512
    fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    inputs = torch.from_numpy(inputs.astype(np.float32))
    outputs1 = fft(inputs)
    outputs2 = ifft(outputs1)
    sf.write('./conv_stft_complex.wav', outputs2.numpy()[0, 0, :], 16000)
    return outputs2.numpy()[0, 0, :]
    
if __name__ == '__main__':
    test_fft()
    #test_conv_complex(data)

总结下如果是python项目可以直接使用librosa接口,如果是pytorch项目可以直接使用torch接口,如果是需要模型移植到终端的项目,建议可使用卷积方式方便移植~

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. stft运算
  • 2. librosa接口
  • 3. torch接口
  • 4. 利用卷积实现stft
相关产品与服务
语音识别
腾讯云语音识别(Automatic Speech Recognition,ASR)是将语音转化成文字的PaaS产品,为企业提供精准而极具性价比的识别服务。被微信、王者荣耀、腾讯视频等大量业务使用,适用于录音质检、会议实时转写、语音输入法等多个场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档