首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >Pytorch中的nn.ModuleList改写成TF2.6时可以换成什么?

Pytorch中的nn.ModuleList改写成TF2.6时可以换成什么?

提问于 2022-03-09 09:57:23
回答 0关注 0查看 47

我想将一个Pytorch中的PyramidPooling类改写为TF2.6的版本使用,但其中的nn.ModuleList和nn.Sequential不知道如何修改,另外在class中的def和普通定义的函数def中,conv2d的使用有何区别?

代码语言:js
复制
import torch
from torch import nn

class PyramidPooling(nn.Module):
    def __init__(self, in_channels, out_channels, scales=(4, 8, 16, 32), ct_channels=1):
        super().__init__()
        self.stages = []
        self.stages = nn.ModuleList([self._make_stage(in_channels, scale, ct_channels) for scale in scales])
        self.bottleneck = nn.Conv2d(in_channels + len(scales) * ct_channels, out_channels, kernel_size=1, stride=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def _make_stage(self, in_channels, scale, ct_channels):
        prior = nn.AvgPool2d(kernel_size=(scale, scale))
        conv = nn.Conv2d(in_channels, ct_channels, kernel_size=1, bias=False)
        relu = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(prior, conv, relu)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = torch.cat([F.interpolate(input=stage(feats), size=(h, w), mode='nearest') for stage in self.stages] + [feats], dim=1)
        return self.relu(self.bottleneck(priors))

回答

和开发者交流更多问题细节吧,去 写回答
相关文章

相似问题

相关问答用户
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档