I am trying to use the dnn module in OpenCV with a torch model to segment images and remove the background.
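The model is a pretrained U2Net, and in torch it produces very good results for my task. I exported the model to ONNX and then read it with the dnn.readNetFromONNX function, but the results are very poor.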
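I wrote code that shares almost everything between OpenCV and torch, except, of course, the call to the model that makes the prediction. Instead of using the blobFromImage function to build the OpenCV dnn input, I use the same image preprocessing code that I use in torch.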
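The preprocessing described above (scale by the maximum value, per-channel ImageNet mean/std normalization, HWC to CHW, add a batch axis) can be sketched in plain NumPy as follows. This is a minimal illustration under stated assumptions, not the question's code: `to_nchw_blob` is a hypothetical name, the resize step is omitted (the input is assumed to be already 320x320), and the result is cast to float32, the dtype OpenCV's dnn module and onnxruntime normally expect. As far as I know, `cv.dnn.blobFromImage` in OpenCV 4.5.x takes a single scalar `scalefactor`, so it cannot apply the per-channel std division by itself.

```python
import numpy as np

# ImageNet statistics used by the question's preprocess_image (RGB channel order)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_nchw_blob(img: np.ndarray) -> np.ndarray:
    """Convert an already-resized HxWxC RGB(A) image into a 1x3xHxW float32 blob.

    Hypothetical helper: mirrors preprocess_image minus the resize step.
    """
    x = img.astype(np.float32)
    x = x / np.max(x)                             # max over all channels, as in the question
    rgb = x[:, :, :3]                             # ignore an alpha channel if present
    rgb = (rgb - MEAN) / STD                      # per-channel mean/std normalization
    chw = rgb.transpose(2, 0, 1)                  # HWC -> CHW
    return np.ascontiguousarray(chw[np.newaxis])  # add batch axis -> NCHW


if __name__ == "__main__":
    white = np.full((4, 4, 4), 255, dtype=np.uint8)  # tiny all-white RGBA image
    blob = to_nchw_blob(white)
    print(blob.shape, blob.dtype)  # (1, 3, 4, 4) float32
```

The resulting array can be passed straight to `cv_net.setInput(...)` or used as the onnxruntime input value.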
Here are the results on the test image:
The code (meant to be run on Google Colab) is as follows:
### upgrade opencv to the latest version ###
!pip install --upgrade opencv-python
### clone the git repo with U2Net ###
%cd /content
!git clone https://github.com/shreyas-bk/U-2-Net
### download weights for U2Net ###
!gdown --id 1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ -O /content/U-2-Net/u2net.pth
###
%cd /content/U-2-Net
### imports ###
from google.colab import files
from model import U2NET
import torch
import os
import numpy as np
from torchvision import transforms
import cv2 as cv
from skimage import io, transform
from PIL import Image
### instantiate U2Net ###
model_dir = '/content/U-2-Net/u2net.pth'
net = U2NET(3, 1)
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval()
### export torch model to onnx ###
img = torch.randn(1, 3, 320, 320, requires_grad=False)
img = img.to(torch.device('cpu'))
output_dir = os.path.join('/content/u2net.onnx')
torch.onnx.export(net, img, output_dir, opset_version=11, verbose=True)
### load model from OpenCV ###
cv_net = cv.dnn.readNetFromONNX('/content/u2net.onnx')
### your test image here ###
IMG_PATH = '/content/<IMAGE_NAME>.png'
### load image ###
image = Image.open(IMG_PATH)
### preprocessing ###
def preprocess_image(image, output_size=320, for_torch=False):
    '''Resize, normalize with ImageNet mean/std, and convert HWC -> CHW.'''
    # resize image (skimage returns floats in [0, 1])
    img = transform.resize(image, (output_size, output_size), mode="constant")
    # mean subtraction and normalization
    tmp_img = np.zeros((img.shape[0], img.shape[1], 3))
    img = img / np.max(img)
    tmp_img[:, :, 0] = (img[:, :, 0] - 0.485) / 0.229
    tmp_img[:, :, 1] = (img[:, :, 1] - 0.456) / 0.224
    tmp_img[:, :, 2] = (img[:, :, 2] - 0.406) / 0.225
    tmp_img = tmp_img.transpose((2, 0, 1))
    if for_torch:
        return torch.from_numpy(tmp_img)
    return tmp_img
### predictions norm ###
def norm_pred(d):
    '''
    normalize predictions (min-max scaling to [0, 1])
    '''
    ma = d.max()
    mi = d.min()
    dn = (d - mi) / (ma - mi)
    return dn
### the magic ###
def remove_bg(image, processed, for_torch=False):
    if for_torch:
        with torch.no_grad():
            inputs_test = processed.unsqueeze(0).float()
            preds, _, _, _, _, _, _ = net(inputs_test)
    else:
        # OpenCV dnn expects a float32 NCHW blob; `processed` is a float64 tensor here
        cv_net.setInput(np.expand_dims(np.asarray(processed, dtype=np.float32), axis=0))
        preds = cv_net.forward()
    # first channel of the returned output map
    pred = preds[:, 0, :, :]
    # normalization
    pred_normalized = norm_pred(pred.cpu().detach().numpy() if for_torch else pred)
    # squeeze
    predict = pred_normalized.squeeze()
    # to RGB
    img_out = Image.fromarray(predict * 255).convert("RGB")
    image = image.resize((img_out.size), resample=Image.BILINEAR)
    empty_img = Image.new("RGBA", (image.size), 0)
    img_out = Image.composite(image, empty_img, img_out.convert("L"))
    # draw
    img_out = img_out.resize((image.size), resample=Image.BILINEAR)
    empty_img = Image.new("RGBA", (image.size), 0)
    img_out = Image.composite(image, empty_img, img_out)
    return img_out
### preprocess image ###
sample = preprocess_image(np.array(image), for_torch=True)
### torch results ###
remove_bg(image, sample, for_torch=True)
### opencv results ###
remove_bg(image, sample, for_torch=False)
OpenCV => 4.5.5, Platform => Google Colab, Torch => 1.11.0+cu113
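The post-processing in `remove_bg` boils down to min-max normalizing the saliency map and using it as a mask over a fully transparent background. A NumPy-only sketch of the same idea follows; `min_max` and `cutout_rgba` are hypothetical names, the input image is assumed to be RGBA and the same size as the map, and unlike `norm_pred` it returns zeros for a constant map instead of dividing by zero. It approximates the `Image.composite(image, Image.new("RGBA", size, 0), mask)` calls, where every band (alpha included) is scaled by the mask; PIL's integer rounding may differ slightly.

```python
import numpy as np


def min_max(d: np.ndarray) -> np.ndarray:
    """Scale a saliency map to [0, 1] like norm_pred; a constant map becomes all zeros."""
    d = d.astype(np.float32)
    lo, hi = float(d.min()), float(d.max())
    if hi == lo:
        return np.zeros_like(d)
    return (d - lo) / (hi - lo)


def cutout_rgba(rgba: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """Blend an HxWx4 uint8 image over a transparent background using pred as the mask."""
    mask = min_max(pred)[:, :, np.newaxis]   # HxWx1 in [0, 1]
    out = rgba.astype(np.float32) * mask     # transparent (0, 0, 0, 0) background adds nothing
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


if __name__ == "__main__":
    image = np.full((2, 2, 4), 200, dtype=np.uint8)
    pred = np.array([[0.1, 0.1], [0.9, 0.5]])
    print(cutout_rgba(image, pred)[:, :, 3])  # alpha channel of the cut-out
```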
Update
Here are the warnings I get from torch.onnx.export:
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:780: UserWarning: Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.
warnings.warn("Note that order of the arguments: ceil_mode and return_indices will change"
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3704: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1944: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
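Update 2
I also tried loading the model and running inference with the onnxruntime module, and everything works fine. At this point I think this is an OpenCV problem. Additional code: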
### install onnxruntime (version 1.11.1) ###
!pip install onnxruntime
### loading model with onnxruntime ###
import onnxruntime
ort_session = onnxruntime.InferenceSession("/content/u2net.onnx")
### updating remove_bg ###
def remove_bg(image, processed, backend='torch'):
    '''Cut out the foreground of `image` using the chosen inference backend.'''
    if backend not in ['torch', 'OpenCV', 'onnx']:
        raise ValueError('Wrong backend.')
    if backend == 'torch':
        with torch.no_grad():
            inputs_test = processed.unsqueeze(0).float()
            preds, _, _, _, _, _, _ = net(inputs_test)
    elif backend == 'OpenCV':
        # OpenCV dnn expects a float32 NCHW blob
        cv_net.setInput(np.expand_dims(np.asarray(processed, dtype=np.float32), axis=0))
        preds = cv_net.forward()
    else:  # backend == 'onnx'
        # `processed` is a torch tensor, so convert it (tensors have no .astype)
        ort_inputs = {ort_session.get_inputs()[0].name: np.expand_dims(np.asarray(processed, dtype=np.float32), axis=0)}
        ort_outs = ort_session.run(None, ort_inputs)
        preds = ort_outs[0]  # first of the model's outputs
    pred = preds[:, 0, :, :]
    # normalization
    pred_normalized = norm_pred(pred.numpy() if backend == 'torch' else pred)
    # squeeze
    predict = pred_normalized.squeeze()
    # to RGB
    img_out = Image.fromarray(predict * 255).convert("RGB")
    image = image.resize((img_out.size), resample=Image.BILINEAR)
    empty_img = Image.new("RGBA", (image.size), 0)
    img_out = Image.composite(image, empty_img, img_out.convert("L"))
    # draw
    img_out = img_out.resize((image.size), resample=Image.BILINEAR)
    empty_img = Image.new("RGBA", (image.size), 0)
    img_out = Image.composite(image, empty_img, img_out)
    return img_out
### inference with onnxruntime ###
remove_bg(image, sample, backend='onnx')
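To put a number on "very poor" instead of judging by eye, one option is to compare the normalized maps from two backends directly: the largest and mean absolute difference, and the IoU of the thresholded foreground masks. A minimal NumPy sketch; `compare_maps` and the 0.5 threshold are hypothetical choices, and the inputs would be the `pred_normalized` arrays from two backends (the `remove_bg` above would need to return them, which it currently does not).

```python
import numpy as np


def compare_maps(a: np.ndarray, b: np.ndarray, threshold: float = 0.5) -> dict:
    """Compare two normalized saliency maps (values in [0, 1]) of the same shape."""
    a = np.squeeze(np.asarray(a, dtype=np.float32))
    b = np.squeeze(np.asarray(b, dtype=np.float32))
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    fg_a, fg_b = a >= threshold, b >= threshold   # binary foreground masks
    union = np.logical_or(fg_a, fg_b).sum()
    inter = np.logical_and(fg_a, fg_b).sum()
    return {
        "max_abs_diff": float(np.max(np.abs(a - b))),
        "mean_abs_diff": float(np.mean(np.abs(a - b))),
        # IoU is taken as 1.0 when both masks are empty
        "iou": float(inter / union) if union else 1.0,
    }


if __name__ == "__main__":
    ref = np.array([[0.0, 1.0], [1.0, 0.0]])
    bad = np.array([[0.0, 1.0], [0.2, 0.9]])
    print(compare_maps(ref, bad))
```

A torch-vs-onnxruntime comparison should give a near-zero difference and an IoU close to 1; a large gap for the OpenCV run would localize the problem to that backend.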
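Posted on 2022-05-18 12:26:38
I tried it on my own image and got similar results.
The network's source code has a function, _upsample_like, that looks like this:
F.upsample(src,size=tar.shape[2:],mode='bilinear')
However, according to the official ONNX repo, bilinear interpolation is not supported, and the torch.onnx.export log shows that ONNX uses linear interpolation instead: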
onnx::Resize[coordinate_transformation_mode="pytorch_half_pixel", cubic_coeff_a=-0.75, mode="linear", nearest_mode="floor"]
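The `coordinate_transformation_mode="pytorch_half_pixel"` attribute in that node defines how each output index is mapped back to an input coordinate. Per the ONNX `Resize` specification, x_original = (x_resized + 0.5) / scale - 0.5 when the output length is greater than 1 (and 0 otherwise), with scale = output length / input length; this is the half-pixel convention PyTorch uses for `align_corners=False`. The pure-Python sketch below applies that mapping with linear interpolation along one axis (the 2D "bilinear" case is the same operation applied along height and then width). `resize_linear_1d` is a hypothetical helper, and clamping out-of-range coordinates to the nearest edge is an assumption about border handling.

```python
def resize_linear_1d(values, out_len):
    """Linearly resize a 1-D sequence using the pytorch_half_pixel coordinate mapping."""
    in_len = len(values)
    scale = out_len / in_len
    result = []
    for x_out in range(out_len):
        # pytorch_half_pixel: half-pixel centers, except a length-1 output maps to 0
        x_in = (x_out + 0.5) / scale - 0.5 if out_len > 1 else 0.0
        x_in = min(max(x_in, 0.0), in_len - 1)  # clamp to the valid input range
        left = int(x_in)                         # floor, since x_in >= 0
        right = min(left + 1, in_len - 1)
        frac = x_in - left
        result.append(values[left] * (1 - frac) + values[right] * frac)
    return result


if __name__ == "__main__":
    print(resize_linear_1d([0.0, 1.0], 4))  # [0.0, 0.25, 0.75, 1.0]
```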
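I think you should change the U2Net source code, i.e. replace
F.upsample(src,size=tar.shape[2:],mode='bilinear')
with, for example,
F.upsample(src,size=tar.shape[2:],mode='linear')
or, better still, given the deprecation warning,
F.interpolate(src,size=tar.shape[2:],mode='linear')
You can use any other interpolation mode that ONNX supports. Note, however, that PyTorch accepts mode='linear' only for 3D (N, C, L) inputs; U2Net's feature maps are 4D (N, C, H, W), so this exact call raises an error on them, and a mode that supports 4D inputs, such as 'nearest', has to be used instead.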
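One ONNX-supported alternative is nearest-neighbour resizing (`mode='nearest'` in `F.interpolate`). As far as I know, PyTorch's `nearest` mode picks source index floor(i * in_len / out_len) for output index i, which is cheap and unambiguous across runtimes but gives blockier masks than linear interpolation. A pure-Python 1-D sketch of that rule; `resize_nearest_1d` is a hypothetical helper, not part of U2Net or PyTorch.

```python
import math


def resize_nearest_1d(values, out_len):
    """Nearest-neighbour resize: output i copies input floor(i * in_len / out_len)."""
    in_len = len(values)
    scale = in_len / out_len
    # clamp the index so rounding can never step past the last input element
    return [values[min(math.floor(i * scale), in_len - 1)] for i in range(out_len)]


if __name__ == "__main__":
    print(resize_nearest_1d([1, 2], 4))     # [1, 1, 2, 2]
    print(resize_nearest_1d([1, 2, 3], 2))  # [1, 2]
```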
Then export the model to ONNX again with the new interpolation method (or retrain it from scratch) and load it with opencv.dnn.
https://stackoverflow.com/questions/72286654