首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何通过预先训练的骨干网络来更快地训练火炬视觉的RPN

如何通过预先训练的骨干网络来更快地训练火炬视觉的RPN
EN

Stack Overflow用户
提问于 2021-08-07 16:29:59
回答 1查看 183关注 0票数 0

正如标题所提到的,如果我已经预先训练了骨干,并且我想只训练RPN,而不是使用torchvision的更快的R-CNN来训练分类器。

是否有可以传递给create_model函数的参数,或者是否要停止训练我的train()函数中的分类器?

我在用手机,所以请原谅我的编辑

这是我的create model函数

代码语言:javascript
运行
复制
Create your backbone from timm
backbone = timm.create_model(
“resnet50”,
pretrained=True,
num_classes=0, # this is important to remove fc layers
global_pool="" # this is important to remove fc layers
)

backbone.out_channels = backbone.feature_info[-1][“num_chs”]

anchor_generator = AnchorGenerator(
sizes=((16, 32, 64, 128, 256),), aspect_ratios=((0.25, 0.5, 1.0, 2.0),)
)
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
featmap_names=[“0”], output_size=7, sampling_ratio=2
)
fastercnn_model = FasterRCNN(
backbone=backbone,
num_classes=1000,
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler,
)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-08-07 16:54:05

您可以执行以下操作

代码语言:javascript
运行
复制
# First you can use model.children() method to see the idx of the backbone
for idx, child in enumerate(fastercnn_model.children()):
    if idx == 1:
        # Now set requires_grad for that idx to False
        for param in child.parameters():
            param.requires_grad = False
        break

# ===============  UPDATED  ========================
# This will train only the box_predictor not even the RPN. You can try out
# Different strategies and find the best for you.
# setting everything to false

for child in fastercnn_model.children():
    for param in child.parameters():
        param.requires_grad = False
        
for idx, child in enumerate(fastercnn_model.children()):
    if idx == 3:
        for i, param in enumerate(child.parameters()):
            if i==1:
                param.requires_grad = True
        break
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68694221

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档