if sample_rate != sr:
    # Build the resampling module once, then call it on the waveform;
    # calling the module instance is what invokes its ``forward`` method.
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=sr)
    waveform = resampler(waveform)
    sample_rate = sr

# Question: I want to know how Resample works here, so I looked at the
# torchaudio documentation. I expected a __call__ function, because Resample
# is used as a function -- I mean the Resample()(waveform) call. But inside it
# there are only __init__ and forward. I think forward is the function that
# does the work, but I don't know why it is named 'forward' instead of
# __call__. What am I missing?
class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    The work happens in :meth:`forward`. Instances are used like functions,
    ``Resample(orig, new)(waveform)``: ``torch.nn.Module`` supplies ``__call__``,
    which dispatches to ``forward``.

    Args:
        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (int, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
    """

    def __init__(self,
                 orig_freq: int = 16000,
                 new_freq: int = 16000,
                 resampling_method: str = 'sinc_interpolation') -> None:
        super().__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        # Not validated here; an unsupported value is reported when the
        # module is called (see ``forward``).
        self.resampling_method = resampling_method

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).

        Raises:
            ValueError: If ``resampling_method`` is not ``'sinc_interpolation'``.
        """
        if self.resampling_method != 'sinc_interpolation':
            raise ValueError(f'Invalid resampling method: {self.resampling_method}')

        # Pack batch: collapse all leading dims so the kaldi routine receives
        # a 2-D (batch, time) tensor.
        shape = waveform.size()
        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. a transposed tensor); ``reshape`` copies only if needed.
        waveform = waveform.reshape(-1, shape[-1])
        waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
        # Unpack batch: restore the leading dims; only the time length changes.
        return waveform.reshape(shape[:-1] + waveform.shape[-1:])

# -- Edit --
我查看了 torch.nn.Module 的源码，里面没有 def __call__，只有一行 __call__ : Callable[..., Any] = _call_impl。这就是答案吗？
发布于 2020-08-19 14:25:09
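下面用一个简化的类比来演示 PyTorch 中 forward 函数是如何被调用的。
关键在于：nn.Module 把类属性 __call__ 直接赋值为 _call_impl（__call__ : Callable[..., Any] = _call_impl），效果与写一个 def __call__ 相同。调用实例时，Python 会在类上查找 __call__，于是执行 _call_impl；_call_impl 处理完钩子等逻辑后，再调用 self.forward。因此子类只需实现 forward。请看下面的代码：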
from typing import Callable, Any
class parent:
    """Minimal stand-in for the call machinery of ``torch.nn.Module``.

    ``__call__`` is not written as ``def __call__``; it is a class attribute
    bound to ``_call_impl``. Python looks ``__call__`` up on the class, so
    ``instance(...)`` runs ``_call_impl``, which dispatches to ``self.forward``.
    Subclasses therefore only override ``forward``.
    """

    def _unimplemented_forward(self, *args):
        # Default ``forward``: a subclass that never overrides it fails loudly.
        raise NotImplementedError

    def _call_impl(self, *args, **kwargs):
        # The original nn.Module _call_impl contains a lot more code to handle
        # exceptions, hooks and other purposes; only the dispatch is kept here.
        # Return forward's result so ``instance(x)`` evaluates to the output,
        # matching how a module call is used (``y = module(x)``).
        return self.forward(*args, **kwargs)

    forward: Callable[..., Any] = _unimplemented_forward
    __call__: Callable[..., Any] = _call_impl
class child_1(parent):
    """Subclass that overrides ``forward``; calling an instance prints a message."""

    def forward(self, *args):
        print('forward function')


# Backward-compatible alias: the class was originally defined as ``child``,
# while the example session that follows instantiates ``child_1``.
child = child_1
class child_2(parent):
    """Subclass that deliberately leaves ``forward`` un-overridden.

    Calling an instance therefore reaches ``parent._unimplemented_forward``
    and raises ``NotImplementedError``.
    """

# When run:
>>> c1 = child_1()
>>> c1()
forward function
>>> c2 = child_2()
>>> c2()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".\callable.py", line 8, in _call_impl
self.forward(*args)
File ".\callable.py", line 5, in _unimplemented_forward
raise NotImplementedError
NotImplementedError

https://stackoverflow.com/questions/63480624
复制相似问题