encode的输入是变长的序列向量,每个向量之间会在batch内填充为固定长度,神经网络限制,不能输入变长的向量。 encode输出固定长度的向量,但序列数量和输入数量保持不变,也就是一个输入产生一个输出。每个输出之间是独立的。 encode的网络可以不固定,比如常见nlp任务用rnn,。 encode将可变序列编码为固定状态,decode将固定状态输入映射为其它可变序列。 decode的网络可以不固定,其中ctc 结合search策略也可以用来做decode。
通用的“编码器-解码器”接口定义:
from torch import nn
#在编码器接口中,我们只指定长度可变的序列作为编码器的输入X。任何继承这个Encoder基类的模型将完成代码实现。
class Encoder(nn.Module):
def __init__(self, **kwargs) -> None:
super(Encoder,self).__init__(**kwargs)
def forward(self, X, *args):
raise NotImplementedError
class Decoder(nn.Module):
def __init__(self, **kwargs):
super(Decoder, self).__init__(**kwargs)
def init_state(self, enc_outputs, *args):
raise NotImplementedError
def forard(self, X, state):
raise NotImplementedError
class EncoderDecoder(nn.Module):
def __init__(self, encoder, decoder, **kwargs):
super(EncoderDecoder,self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_outputs = self.encoder(enc_X, *args)
dec_state =self.decoder.init_state(enc_outputs, *args)
return self.decoder(dec_X, dec_state)
解码器接口中,我们新增一个init_state函数用于将编码器的输出(enc_outputs)转换为编码后的状态。注意,此步骤可能需要额外的输入,例如:输入序列的有效长度,逐个生成长度可变的标记序列,解码器在每个时间步都可以将输入(例如:在前一时间步生成的标记)和编码后的状态映射成当前时间步的输出标记。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/185012.html原文链接:https://javaforall.cn