首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >计算torch.utils.data.DataLoader中数据对应的光流

计算torch.utils.data.DataLoader中数据对应的光流
EN

Stack Overflow用户
提问于 2019-04-12 20:15:53
回答 1查看 212关注 0票数 2

我已经在PyTorch的视频中建立了一个动作识别的CNN模型。我正在使用torch dataloader模块加载训练数据。

代码语言:javascript
运行
复制
train_loader = torch.utils.data.DataLoader(
            training_data,
            batch_size=8,
            shuffle=True,
            num_workers=4,
            pin_memory=True)

然后通过train_loader对模型进行训练。

代码语言:javascript
运行
复制
train_epoch(i, train_loader, action_detect_model, criterion, optimizer, opt,
                        train_logger, train_batch_logger)

现在我想添加一个额外的路径,它将采用视频帧的相应光流。为了计算光流,我使用了cv2.calcOpticalFlowFarneback。但问题是,我不确定如何获得与列车数据加载器张量中的数据相对应的图像,因为它们将被混洗。我不想预先计算光流,因为存储需求将是巨大的(每帧需要600kbs)。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-09-28 05:10:39

您必须使用自己的数据加载器类来动态计算光流。其思想是这个类获得一个文件名元组列表(curr image,next image),其中包含视频序列的当前和下一帧文件名,而不是简单的文件名列表。这允许在添加后缀文件名列表后获得正确的图像对。下面的代码给出了一个非常简单的示例实现:

代码语言:javascript
运行
复制
from torch.utils.data import Dataset
import cv2
import numpy as np

class FlowDataLoader(Dataset):
def __init__(self,
             filename_tuples):

    random.shuffle(filename_tuples)
    self.lines = filename_tuples

def __getitem__(self, index):
    img_filenames = self.lines[index]
    curr_img = cv2.cvtColor(cv2.imread(img_filenames[0]), cv2.BGR2GRAY)
    next_img = cv2.cvtColor(cv2.imread(img_filenames[1]), cv2.BGR2GRAY)
    flow = cv2.calcOpticalFlowFarneback(curr_img, next_img, ... [parameter])

    # code for loading the class label
    # label = ...
    #
    # this is a very simple data normalization
    curr_img= curr_img.astype(np.float32) / 255
    next_img = next_img .astype(np.float32) / 255
    # you can return the image and flow seperatly 
    return curr_img, flow, label
    # or stacked as follows
    # return np.dstack((curr_img,flow)), label

# at this place you need a function that create a list of training sample filenames
# that look like this
training_filelist = [(img000.png, img001.png), 
                     (img001.png, img002.png),
                     (img002.png, img003.png)] 

training_data = FlowDataLoader(training_filelist)
train_loader = torch.utils.data.DataLoader(
        training_data,
        batch_size=8,
        shuffle=True,
        num_workers=4,
        pin_memory=True)

这只是FlowDataLoader的一个简单示例。理想情况下,这应该被扩展,以便curr_image输出包含归一化的RGB值,并且光流也被归一化和修剪。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55651427

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档