OpenCV dnn module produces different predictions than the original torch model

Stack Overflow user
Asked 2022-05-18 09:35:55
Answers: 1, Views: 181, Followers: 0, Score: 1

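I am trying to use a torch model with OpenCV's dnn module to segment images and remove their background.

The model is a pre-trained U2Net, which gives very good results for my task in torch. I exported the model to ONNX and then read it with the dnn.readNetFromONNX function, but the results are very poor.

I wrote code that shares almost everything between OpenCV and torch except, of course, the call to the model that makes the prediction. Instead of using the blobFromImage function to build the OpenCV dnn input, I use the same image preprocessing code that I use with torch.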
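Sharing the preprocessing is a reasonable choice here: in OpenCV 4.5.5, `cv.dnn.blobFromImage` subtracts a mean and then multiplies by a single scalar `scalefactor`, so the per-channel std division that U2Net uses has to be done by hand anyway. A minimal numpy sketch of the same math as the question's `preprocess_image`, assuming the HxWx3 (or HxWx4) image has already been resized; `to_blob`, `MEAN` and `STD` are illustrative names. It also makes the batch axis and the float32 dtype explicit, which is what OpenCV dnn and onnxruntime are normally fed:

```python
import numpy as np

# ImageNet statistics used by the question's preprocessing
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def to_blob(img):
    """Convert an already-resized HxWxC image into a 1x3xHxW float32 blob.

    Mirrors preprocess_image: divide by the image maximum, then subtract
    the per-channel mean and divide by the per-channel std of channels 0-2.
    """
    img = np.asarray(img, dtype=np.float64)
    img = img / np.max(img)
    chw = ((img[:, :, :3] - MEAN) / STD).transpose(2, 0, 1)
    # add the batch axis and cast to float32 in one contiguous array
    return np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)
```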

Here are the results on a test image:

The code (to be tested on Google Colab) is as follows:

### upgrade opencv to the latest version ###
!pip install --upgrade opencv-python

### clone git with U2Net ###
%cd /content
!git clone https://github.com/shreyas-bk/U-2-Net

### download weights for U2Net ###
!gdown --id 1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ -O /content/U-2-Net/u2net.pth

###
%cd /content/U-2-Net

### imports ###
from google.colab import files
from model import U2NET
import torch
import os
import numpy as np
from torchvision import transforms
import cv2 as cv
from skimage import io, transform
from PIL import Image

### instantiate U2Net ###
model_dir = '/content/U-2-Net/u2net.pth'
net = U2NET(3, 1)
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval() 

### export torch model to onnx ###
img = torch.randn(1, 3, 320, 320, requires_grad=False)
img = img.to(torch.device('cpu'))
output_dir = os.path.join('/content/u2net.onnx')
torch.onnx.export(net, img, output_dir, opset_version=11, verbose=True)

### load model from OpenCV ###
cv_net = cv.dnn.readNetFromONNX('/content/u2net.onnx')

### your test image here ###
IMG_PATH = '/content/<IMAGE_NAME>.png'

### load image ###
image = Image.open(IMG_PATH)

### preprocessing ###

def preprocess_image(image, output_size=320, for_torch=False):
  '''Resize, scale to [0, 1], normalise with ImageNet mean/std and return a CHW array (or tensor).'''
  # resize image
  img = transform.resize(image, (output_size, output_size), mode="constant")

  # mean subtraction and normalization
  tmp_img = np.zeros((img.shape[0], img.shape[1], 3))
  img = img / np.max(img)

  tmp_img[:, :, 0] = (img[:, :, 0] - 0.485) / 0.229
  tmp_img[:, :, 1] = (img[:, :, 1] - 0.456) / 0.224
  tmp_img[:, :, 2] = (img[:, :, 2] - 0.406) / 0.225
  tmp_img = tmp_img.transpose((2, 0, 1))
  
  if for_torch:
    return torch.from_numpy(tmp_img)

  return tmp_img

### predictions norm ###
def norm_pred(d):
  '''
  normalize predictions
  '''
  ma = d.max()
  mi = d.min()
  dn = (d - mi) / (ma - mi)
  return dn

### the magic ###
def remove_bg(image, processed, for_torch=False):
  pred = None
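  if for_torch:
    with torch.no_grad():
      inputs_test = processed.unsqueeze(0).float()
      preds, _, _, _, _, _, _ = net(inputs_test)
  else:
    # OpenCV dnn expects float32 input; processed may be a float64 tensor
    cv_net.setInput(np.expand_dims(np.asarray(processed, dtype=np.float32), axis=0))
    # without a layer name, forward() returns one blob, although U2NET has 7 outputs
    preds = cv_net.forward()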
    
  pred = preds[:, 0, :, :]
  # normalization
  pred_normalized = norm_pred(pred.cpu().detach().numpy() if for_torch else pred)
  # squeeze
  predict = pred_normalized.squeeze()
  # to RGB
  img_out = Image.fromarray(predict * 255).convert("RGB")
  image = image.resize((img_out.size), resample=Image.BILINEAR)
  empty_img = Image.new("RGBA", (image.size), 0)
  img_out = Image.composite(image, empty_img, img_out.convert("L"))

  # draw 
  img_out = img_out.resize((image.size), resample=Image.BILINEAR)
  empty_img = Image.new("RGBA", (image.size), 0)
  img_out = Image.composite(image, empty_img, img_out)

  return img_out

### preprocess image ###
sample = preprocess_image(np.array(image), for_torch=True)

### torch results ###
remove_bg(image, sample, for_torch=True)

### opencv results ###
remove_bg(image, sample, for_torch=False)

OpenCV => 4.5.5, Platform => Google Colab, Torch => 1.11.0+cu113
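Rather than judging the two backends only by the composited images, it can help to compare the raw prediction maps numerically. A minimal numpy sketch, assuming the two inputs are the `pred` arrays from the question's `remove_bg` (the torch one converted with `.numpy()`); `compare_predictions`, `norm_pred_safe` and the `atol` threshold are hypothetical, and `norm_pred_safe` adds an `eps` guard that the original `norm_pred` lacks:

```python
import numpy as np

def norm_pred_safe(d, eps=1e-8):
    """Min-max normalise a prediction map to [0, 1].

    eps avoids a division by zero when the map is constant.
    """
    d = np.asarray(d, dtype=np.float64)
    return (d - d.min()) / max(d.max() - d.min(), eps)

def compare_predictions(reference, candidate, atol=1e-3):
    """Return (max absolute difference, within_tolerance) between two
    normalised prediction maps of the same shape."""
    ref = norm_pred_safe(np.squeeze(reference))
    cand = norm_pred_safe(np.squeeze(candidate))
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    diff = float(np.max(np.abs(ref - cand)))
    return diff, diff <= atol
```

For example, `compare_predictions(torch_pred, cv_pred)`, where both names are placeholders for the two arrays. Since both backends receive the same preprocessed input, a large difference points at the converted model rather than at the preprocessing.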

Update

These are the warnings I get from torch.onnx.export:

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:780: UserWarning: Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.
  warnings.warn("Note that order of the arguments: ceil_mode and return_indices will change"
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3704: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1944: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")

Update 2

I also tried loading the model and running inference with the onnxruntime module, and everything works fine. At this point I think this is an OpenCV problem. Additional code:

### install onnxruntime (version 1.11.1) ###
!pip install onnxruntime

### loading model with onnxruntime ###
import onnxruntime
ort_session = onnxruntime.InferenceSession("/content/u2net.onnx")

### updating remove_bg ###
def remove_bg(image, processed, backend='torch'):
  '''Predict a foreground mask with the chosen backend and composite it over the image.'''
  if backend not in ['torch', 'OpenCV', 'onnx']:
    raise ValueError('Wrong backend.')

  # OpenCV dnn and onnxruntime both take a float32 NCHW numpy array
  blob = np.expand_dims(np.asarray(processed, dtype=np.float32), axis=0)

  pred = None
  if backend == 'torch':
    with torch.no_grad():
      inputs_test = processed.unsqueeze(0).float()
      preds, _, _, _, _, _, _ = net(inputs_test)
  elif backend == 'OpenCV':
    cv_net.setInput(blob)
    preds = cv_net.forward()
  else:  # backend == 'onnx'
    ort_inputs = {ort_session.get_inputs()[0].name: blob}
    ort_outs = ort_session.run(None, ort_inputs)
    preds = ort_outs[0]
    
  pred = preds[:, 0, :, :]
  # normalization
  pred_normalized = norm_pred(pred.numpy() if backend=='torch' else pred)
  # squeeze
  predict = pred_normalized.squeeze()
  # to RGB
  img_out = Image.fromarray(predict * 255).convert("RGB")
  image = image.resize((img_out.size), resample=Image.BILINEAR)
  empty_img = Image.new("RGBA", (image.size), 0)
  img_out = Image.composite(image, empty_img, img_out.convert("L"))

  # draw 
  img_out = img_out.resize((image.size), resample=Image.BILINEAR)
  empty_img = Image.new("RGBA", (image.size), 0)
  img_out = Image.composite(image, empty_img, img_out)

  return img_out


### inference with onnxruntime ###
remove_bg(image, sample, backend='onnx')
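Since onnxruntime gives good results on the same ONNX file, one thing worth ruling out before blaming OpenCV is output selection: U2NET returns seven outputs, the onnxruntime code takes `ort_outs[0]`, but `cv_net.forward()` called without a layer name returns a single blob (by default from the last layer in `getLayerNames()`), which is not guaranteed to be the same output. The sketch below is a hedged diagnostic rather than a confirmed fix: it ranks candidate outputs by how closely they match a reference array. `rank_outputs` is a hypothetical helper, and the commented usage assumes `cv_net`, `ort_session` and a float32 NCHW input array named `blob`:

```python
import numpy as np

def rank_outputs(reference, candidates):
    """Return (index, max_abs_diff) for every candidate whose shape matches
    the reference, sorted from closest to farthest."""
    ref = np.asarray(reference, dtype=np.float64)
    scores = []
    for i, cand in enumerate(candidates):
        cand = np.asarray(cand, dtype=np.float64)
        if cand.shape != ref.shape:
            continue  # skip outputs that cannot be compared directly
        scores.append((i, float(np.max(np.abs(ref - cand)))))
    return sorted(scores, key=lambda s: s[1])

# Usage sketch (requires the model loaded in both OpenCV and onnxruntime):
#   names = cv_net.getUnconnectedOutLayersNames()
#   cv_net.setInput(blob)
#   cv_outs = cv_net.forward(names)
#   ort_first = ort_session.run(None, {ort_session.get_inputs()[0].name: blob})[0]
#   for idx, diff in rank_outputs(ort_first, cv_outs):
#       print(names[idx], diff)
```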

1 Answer

Stack Overflow user

Answered 2022-05-18 12:26:38

I tried it with my own image and got similar results.

In the network's source code there is an _upsample_like function that looks like this:

F.upsample(src,size=tar.shape[2:],mode='bilinear')

But according to the official ONNX repo, bilinear interpolation is not supported, and the torch.onnx.export log shows that ONNX uses linear interpolation instead:

onnx::Resize[coordinate_transformation_mode="pytorch_half_pixel", cubic_coeff_a=-0.75, mode="linear", nearest_mode="floor"]
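For reference, the attributes in that log line describe how an ONNX `Resize` node maps output pixels back to input pixels. The ONNX operator spec gives, for `coordinate_transformation_mode="pytorch_half_pixel"`, `x_original = (x_resized + 0.5) / scale - 0.5` when the output length is greater than 1 (otherwise 0), and `mode="linear"` interpolates linearly along every resized axis, which for an NCHW tensor resized in H and W amounts to bilinear interpolation. The numpy sketch below reproduces this along a single axis; the function names are illustrative, and clamping coordinates at the borders is an assumption about edge handling, not something stated in the log:

```python
import numpy as np

def source_coords(length_in, length_out, mode="pytorch_half_pixel"):
    """Map output indices back to (fractional) input coordinates using the
    ONNX Resize coordinate_transformation_mode formulas."""
    scale = length_out / length_in
    x_out = np.arange(length_out, dtype=np.float64)
    if mode == "asymmetric":
        return x_out / scale
    if mode == "half_pixel":
        return (x_out + 0.5) / scale - 0.5
    if mode == "pytorch_half_pixel":
        # same as half_pixel, except a length-1 output maps to coordinate 0
        return (x_out + 0.5) / scale - 0.5 if length_out > 1 else np.zeros(length_out)
    raise ValueError(f"unsupported coordinate_transformation_mode: {mode}")

def resize_linear_1d(values, length_out, mode="pytorch_half_pixel"):
    """Linear resize along one axis; applying it to both H and W gives
    bilinear interpolation."""
    values = np.asarray(values, dtype=np.float64)
    n = values.shape[0]
    # clamp to the valid input range (assumed border handling)
    x = np.clip(source_coords(n, length_out, mode), 0, n - 1)
    x0 = np.floor(x).astype(int)
    x1 = np.minimum(x0 + 1, n - 1)
    w = x - x0
    return (1 - w) * values[x0] + w * values[x1]
```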

I think you should change this line in U2Net's source code:

F.upsample(src,size=tar.shape[2:],mode='bilinear')

to, for example,

F.upsample(src,size=tar.shape[2:],mode='linear')

or, better still, given the deprecation warning, to

F.interpolate(src,size=tar.shape[2:],mode='linear')

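You can use any other interpolation mode that ONNX supports.

Then export the model with the new interpolation method (or retrain it from scratch) to ONNX and load it with opencv.dnn.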

Score: 0

Page content provided by Stack Overflow; translation support by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/72286654
