前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >java落地AI模型-cnn手写体识别

java落地AI模型-cnn手写体识别

作者头像
三更两点
发布2024-10-01 08:06:11
350
发布2024-10-01 08:06:11
举报

cnn手写体识别

1. 基本介绍

  1. 手写体识别,是指对图像进行识别,判断图像中的内容是否为手写文字。
  2. 本项目是一手写数字识别为主,采用的模型是cnn。
1.1 步骤
  1. 数据集:MNIST手写数字数据集,该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28,共10个类别。
  2. python的框架是pytorch,使用pytorch的框架进行训练和测试。
  3. 识别准确率为,98%
  4. 模型转化:将pytorch的模型转化为onnx格式,方便在安卓端使用。
  5. 以java的代码推理模型,在安卓端或者其他环境中实现手写数字识别。
1.2 项目结构
代码语言:javascript
复制
.
├── DNS_tunnel_detect
│   ├── DNS_tunnel_detect.iml
│   ├── README.md
│   ├── bin
│   ├── lib
│   ├── out
│   ├── source
│   └── src
├── cnn_py
│   ├── data
│   ├── main.py
│   └── model
├── model2onnx
│   ├── model
│   ├── model2onnx.py
│   └── test_onnx_model.py
└── 第3集: java落地AI项目案例:cnn手写字体识别.md
1.3 模型结构

第一层包含卷积、批量归一化、ReLU激活和最大池化操作; 第二层结构相同但输出通道数为32; 全连接层将前一层输出扁平化后接分类器。

代码语言:javascript
复制
import torch
import torch.nn as nn

# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
    
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
    return out

2.训练

代码语言:javascript
复制
model = ConvNet(num_classes).to(device)
print(model)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        print(images.size())
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

3.测试模型

代码语言:javascript
复制
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), './model/model.ckpt')
在这里插入图片描述
在这里插入图片描述

4. 模型转化

4.1 模型转化
代码语言:javascript
复制
import os
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

device = torch.device("cpu")
num_classes = 10
model = ConvNet(num_classes).to(device)
print(model)

model.load_state_dict(torch.load('../cnn_py/model/model.ckpt',map_location=device))

sample_input = torch.rand((1,1,28,28)).to(device)
print(sample_input)

model.eval()
with torch.no_grad():
    outputs = model(sample_input)
    print("output:",outputs)
    _, predicted = torch.max(outputs.data, 1)
    print("predicted:",predicted)
    
torch.onnx.export(model,
                  sample_input,
                  './model/model.onnx',
                  input_names=["input"],
                  output_names=["output"],
                  export_params=True,       # 是否保存模型参数
                  do_constant_folding=True)	# 是否执行常量折叠优化
    

torch.cuda.empty_cache()
在这里插入图片描述
在这里插入图片描述
4.2 pytorch模型转化为onnx模型
代码语言:javascript
复制
import os
import warnings
warnings.filterwarnings('ignore')

import onnxruntime
import torch

input_data = torch.rand(1,1,28,28)
session = onnxruntime.InferenceSession("./model/model.onnx")
input_name = session.get_inputs()[0].name
result = session.run([], {input_name: input_data.numpy()})
print("result: ",result)
print(result[0][0])
max_value = max(list(result[0][0]))
predict = list(result[0][0]).index(max_value)
print(predict)
在这里插入图片描述
在这里插入图片描述

5. java端使用onnx模型进行预测

  • 需要安装onnxruntime库
代码语言:javascript
复制
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtUtil;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class App {
    public static void main(String[] args) throws Exception {
        String model_path = "./source/model.onnx";
        System.out.println(model_path);

        float[][][][] feature = new float[1][1][28][28];
        // 初始化数组元素
        for (int i = 0; i < 1; i++) {
            for (int j = 0; j < 1; j++) {
                for (int k = 0; k < 28; k++) {
                    for (int l = 0; l < 28; l++) {
                        feature[i][j][k][l] = (i + 1) * (j + 1) * (k + 1) * (l + 1);
                    }
                }
            }
        }
        System.out.println(Arrays.toString(feature));

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.Result res = null;
        try (OrtSession session = env.createSession(model_path)){
            Map<String, OnnxTensor> container = new HashMap<>();

            OnnxTensor inputTensor = OnnxTensor.createTensor(env, feature);
            container.put("input", inputTensor);

            try(OrtSession.Result result = session.run(container)){
                OnnxTensor outputTensor = (OnnxTensor) result.get(0);
                float[][] result88 = (float[][])outputTensor.getValue();
                System.out.println(Arrays.toString(result88));
                for (int i = 0; i < result88.length; i++) {
                    for (int j = 0; j < result88[i].length; j++) {
                        System.out.println(result88[i][j]);
                    }
                }
            }
            OnnxValue.close(container);
        }catch (OrtException e) {
            throw new RuntimeException(e);
        } finally {
            System.out.println("all done");
        }
     }
}

6.总结

  1. 完成手写字体的python脚本训练和测试
  2. 完成onnx模型转化
  3. 完成java端使用onnx模型进行预测
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-09-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • cnn手写体识别
    • 1. 基本介绍
      • 1.1 步骤
      • 1.2 项目结构
      • 1.3 模型结构
    • 2.训练
      • 3.测试模型
        • 4. 模型转化
          • 4.1 模型转化
          • 4.2 pytorch模型转化为onnx模型
        • 5. java端使用onnx模型进行预测
          • 6.总结
          相关产品与服务
          腾讯云服务器利旧
          云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档