专栏首页数据分析与挖掘pytorch读取一张图像进行分类预测需要注意的问题(opencv、PIL)

pytorch读取一张图像进行分类预测需要注意的问题(opencv、PIL)

读取图像一般是两个库:opencv和PIL

1、使用opencv读取图像

import cv2
image=cv2.imread("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
print(image.shape)

(490, 410, 3)

2、使用PIL读取图像

import PIL
image=PIL.Image.open("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
print(image.shape)

这里会报错:

AttributeError                            Traceback (most recent call last)
<ipython-input-30-807ec7af434b> in <module>()
      1 import PIL
      2 image=PIL.Image.open("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
----> 3 print(image.shape)
AttributeError: 'JpegImageFile' object has no attribute 'shape'

我们要输出要这么做:

import numpy as np
print(np.array(image).shape)

(490, 410, 3)

需要注意的是:

使用opencv读取图像之后是BGR格式的,使用PIL读取图像之后是RGB格式的。

3、opencv格式的和PIL格式的之间的转换

这里参考:https://www.cnblogs.com/enumx/p/12359850.html

(1)opencv格式转换为PIL格式

import cv2
from PIL import Image
import numpy
 
img = cv2.imread("plane.jpg")
cv2.imshow("OpenCV",img)
image = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
image.show()
cv2.waitKey()

(2)PIL格式转换为opencv格式

import cv2
from PIL import Image
import numpy
 
image = Image.open("plane.jpg")
image.show()
img = cv2.cvtColor(numpy.asarray(image),cv2.COLOR_RGB2BGR)
cv2.imshow("OpenCV",img)
cv2.waitKey()

4、使用pytorch读取一张图片并进行分类预测

需要注意两个问题:

  • 输入要转换为:[1,channel,H,W]
  • 对输入的图像进行数据增强时要求是PIL.Image格式的
import torchvision
import sys
import torch
import torch.nn as nn
from PIL import Image
sys.path.append("/content/drive/My Drive/colab notebooks")
import glob
import numpy as np
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features,4,bias=False)
model.to(device)
model.eval()
save_path="/content/drive/My Drive/colab notebooks/checkpoint/resnet18_best_v2.t7" 
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model'])
print("当前模型准确率为:",checkpoint["epoch_acc"])
images_path="/content/drive/My Drive/colab notebooks/data/dataset/test/four"
transform = transforms.Compose([transforms.Resize((224,224))])
def predict():
  true_labels=[]
  output_labels=[]
  for image in glob.glob(images_path+"/*.png"):
    print(image)
    true_labels.append(0)
    #image=Image.open(image)
    #image=image.resize((224,224))
    image=cv2.imread(image)
    image=cv2.resize(image,(224,224))
    image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_BGR2RGB))
    #print(np.array(image).shape)
    tensor=torch.from_numpy(np.asarray(image)).permute(2,0,1).float()/255.0
    tensor=tensor.reshape((1,3,224,224))
    tensor=tensor.to(device)
    #print(tensor.shape)
    output=model(tensor)
    print(output)
    _, pred = torch.max(output.data,1)
    output_labels.append(pred.item())
  return true_labels,output_labels

true_labels,output_labels=predict()
print("正确的标签是:")
print(true_labels)
print("预测的标签是:")
print(output_labels)

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • mybatis插件开发初探

    说明:执行sql方法时会调用四大对象,如果不是自己配置拦截的类型,就放过,否则就进行拦截。我们定义的插件是拦截StatementHandler类中的parame...

    绝命生
  • 【python-opencv】轨迹栏作为调色板

    在这里,我们将创建一个简单的应用程序,以显示您指定的颜色。您有一个显示颜色的窗口,以及三个用于指定B、G、R颜色的跟踪栏。滑动轨迹栏,并相应地更改窗口颜色。默认...

    绝命生
  • springboot消息之AmqpAdmin管理组件的使用

    绝命生
  • 傅里叶变换

    低频位于频率变换图像的中心。 这些示例的变换图像显示实心图像具有大多数低频分量(如中心亮点所示)。 条纹转换图像包含白色和黑色区域的低频以及这些颜色之间的边...

    小飞侠xp
  • Python文件夹批处理操作代码实例

    砸漏
  • Street Lanes Finder - 检测自动驾驶汽车的车道

    在今天的文章中,将使用基本的计算机视觉技术来解决对于自动驾驶汽车至关重要的街道车道检测问题。到本文结束时,将能够使用Python和OpenCV执行实时通道检测。

    代码医生工作室
  • 颜色转换,利用HSV颜色空间检测

    绘制出这些通道的灰度版本 以便观察各通道的强度,像素越亮 代表的红色、绿色或蓝色值就越高。我们可以看到 粉色气球的红色值很高 蓝色值也相对比较高,但值大小不一 ...

    小飞侠xp
  • Creating a Filter, Edge Detection

    Below, you've been given one common type of edge detection filter: a Sobel opera...

    小飞侠xp
  • Python从入门到摔门(5):18式优雅你的Python

    在cmd中输入jupyter notebook --generate-config,然后找到生成的配置文件jupyter_notebook_config.py,...

    Python之道
  • 2018-05-30 <通用魔术>天才们想法太超前了

    Albert陈凯

扫码关注云+社区

领取腾讯云代金券