前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >FastAPI + ONNX 部署机器学习模型最佳实践

FastAPI + ONNX 部署机器学习模型最佳实践

原创
作者头像
别惹CC
发布2025-01-13 14:21:40
发布2025-01-13 14:21:40
22800
代码可运行
举报
运行总次数:0
代码可运行

引言

随着人工智能的迅猛发展,将训练好的模型部署到生产环境中,为用户提供实时预测服务,已成为众多企业和开发者关注的重点。然而,模型部署并非易事,涉及到模型格式转换、服务框架选择、性能优化等多个方面。本篇文章将介绍如何结合 FastAPIONNX,实现机器学习模型的高效部署,并分享其中的最佳实践。

背景介绍 🎨

机器学习模型的部署,常常会遇到以下挑战:

  • 模型兼容性:不同的深度学习框架(如 TensorFlow、PyTorch)有各自的模型格式,直接部署可能会有兼容性问题,导致部署困难。
  • 性能瓶颈:模型推理速度直接影响用户体验和系统资源消耗,性能优化至关重要。
  • 服务稳定性:需要确保服务在高并发情况下的稳定性和可靠性,否则可能会崩溃。
  • 安全性:需要防范潜在的安全风险,如输入数据的验证、攻击防护等,保障应用安全。

看到这里,可能有人会问:“有没有一种简单的方法,可以解决这些问题呢?”答案就是——FastAPI + ONNX

为什么选择 FastAPI 与 ONNX

  • 高性能:FastAPI 与 ONNX Runtime 的组合,提供了高效的推理和响应速度,让你的服务飞起来!
  • 易于开发和维护:FastAPI 简洁的代码结构和自动文档生成功能,大大降低了开发和维护的成本,不再为繁琐的配置烦恼。
  • 跨框架支持:ONNX 支持多种主流的深度学习框架,方便模型的转换和部署,再也不用陷入框架之争。
  • 社区活跃:两个项目都有活跃的社区支持,丰富的资源和教程,遇到问题有人帮,进步之路不孤单。

最佳实践 🛠️

1.模型转换为 ONNX 格式

模型转换是部署的第一步。将训练好的模型转换为 ONNX 格式,可以提高模型的兼容性和性能。

PyTorch 模型转换

假设你有一个训练好的 PyTorch 模型,将其转换为 ONNX 格式呢只需几行代码,如下:

代码语言:python
代码运行次数:0
复制
import torch
import torch.onnx

# 加载训练好的模型
model = torch.load('model.pth')
model.eval()

# 定义一个输入张量(示例输入)
dummy_input = torch.randn(1, 3, 224, 224)

# 导出为 ONNX 格式
torch.onnx.export(model, dummy_input, 'model.onnx', 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['output'])

print("✅ 模型已成功转换为 ONNX 格式!")

TensorFlow 模型转换

对于 TensorFlow 的模型,也是类似的操作。

代码语言:python
代码运行次数:0
复制
import tensorflow as tf
import tf2onnx

# 加载训练好的模型
model = tf.keras.models.load_model('model.h5')

# 转换为 ONNX 格式
spec = (tf.TensorSpec(model.inputs[0].shape, dtype=tf.float32, name="input"),)
output_path = "model.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model, 
                                            input_signature=spec, 
                                            opset=11, 
                                            output_path=output_path)

print("✅ 模型已成功转换为 ONNX 格式!")

验证转换后的模型

转换完成后,别忘了验证一下模型是否正常工作!

代码语言:python
代码运行次数:0
复制
    import onnx
    import onnxruntime as ort

    # 加载 ONNX 模型
    onnx_model = onnx.load('model.onnx')
    onnx.checker.check_model(onnx_model)
    print("✅ 模型格式验证通过!")

    # 使用 ONNX Runtime 进行推理
    ort_session = ort.InferenceSession('model.onnx')

    # 准备输入数据
    import numpy as np
    input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

    # 运行推理
    outputs = ort_session.run(None, {'input': input_data})

    print("输出结果:", outputs)

示例输出:

2.构建 FastAPI 应用

现在,我们来创建一个基于 FastAPI 的应用,将模型部署为一个 API 服务。

安装依赖

首先,安装必要的依赖包:

代码语言:bash
复制
pip install fastapi uvicorn[standard] onnxruntime

定义 FastAPI 应用

编写应用主文件 main.py

代码语言:python
代码运行次数:0
复制
from fastapi import FastAPI
import onnxruntime as ort
import numpy as np
from pydantic import BaseModel

app = FastAPI(title="机器学习模型部署 API 🚀")

# 加载 ONNX 模型
ort_session = ort.InferenceSession('model.onnx')

# 定义输入数据模型
class InputData(BaseModel):
    data: list

@app.post("/predict")
async def predict(input_data: InputData):
    # 将输入数据转换为 numpy 数组
    input_array = np.array(input_data.data).astype(np.float32)
    
    # 进行推理
    outputs = ort_session.run(None, {'input': input_array})
    
    # 返回结果
    return {"prediction": outputs[0].tolist()}

运行应用

使用uvicorn启动应用:

代码语言:bash
复制
uvicorn main:app --host 0.0.0.0 --port 8000

测试接口

可以使用 curl 或其他工具测试一下接口是否正常工作:

代码语言:bash
复制
curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{
    "data": [0.5, 0.3, 0.2]
}'

示例输出:

代码语言:json
复制
{
    "prediction": [[0.1, 0.9]]
}

至此我们的 API 已经可以正常工作了!

3.性能优化

性能对于一个服务来说至关重要,这里介绍一些优化技巧。

模型优化

  • 使用模型优化工具:ONNX 提供了模型优化工具,可简化和加速模型。
代码语言:bash
复制
python -m onnxruntime.tools.optimizer_cli --input model.onnx --output model_optimized.onnx --optimization_level all
  • 量化模型:通过模型量化,将浮点数精度降低,减小模型大小,加速推理。
代码语言:bash
复制
python -m onnxruntime.quantization.quantize --input model.onnx --output model_quant.onnx --per_channel

推理加速

  • 使用 GPU 加速:如果有 GPU 资源,可以使用 GPU 提供商提升推理速度。
代码语言:python
代码运行次数:0
复制
ort_session = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
  • 多线程或多进程:根据服务器性能,调整并发数,充分利用硬件资源。

4.安全性考虑

安全是服务的底线,我们需要考虑以下几点。

输入验证

  • 数据格式验证:使用 Pydantic 模型,确保输入数据的格式和类型正确。
  • 异常处理:捕获可能的异常,如数据维度错误,返回友好的错误信息。
代码语言:python
代码运行次数:0
复制
@app.post("/predict")
async def predict(input_data: InputData):
    try:
        # 输入验证
        input_array = np.array(input_data.data).astype(np.float32)
        # 检查输入维度(根据模型需求调整)
        if input_array.shape != (1, 3, 224, 224):
            return {"error": "输入数据维度不正确"}
        # 进行推理
        outputs = ort_session.run(None, {'input': input_array})
        return {"prediction": outputs[0].tolist()}
    except Exception as e:
        return {"error": str(e)}

安全防护

  • 限制请求频率:通过中间件或网关,防止恶意请求和 DDoS 攻击。
  • SSL/HTTPS:在生产环境中,确保通信的安全性。

案例示例 🎯

下面以一个手写数字识别模型为例,展示完整的部署过程。

1.模型训练与转换

代码语言:python
代码运行次数:0
复制
# 使用 MNIST 数据集训练模型
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 定义简单的神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(28*28, 10)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc(x)
        return x

# 创建模型实例
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 加载数据集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

# 训练模型(这里只训练一个 epoch 作示例)
for epoch in range(1):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# 保存模型
torch.save(model.state_dict(), 'mnist.pth')

print("✅ 模型训练完成并已保存!")

# 加载模型并转换为 ONNX
model = Net()
model.load_state_dict(torch.load('mnist.pth'))
model.eval()

dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, dummy_input, 'mnist.onnx', input_names=['input'], output_names=['output'])

print("✅ 模型已成功转换为 ONNX 格式!")

2.构建 FastAPI 应用

代码语言:python
代码运行次数:0
复制
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import onnxruntime as ort
import numpy as np

app = FastAPI(title="手写数字识别 API 🖊️")
ort_session = ort.InferenceSession('mnist.onnx')

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # 读取上传的图片
    image = Image.open(file.file).convert('L')
    # 图片预处理
    image = image.resize((28, 28))
    image_data = np.array(image).astype(np.float32).reshape(1, 1, 28, 28)
    # 归一化
    image_data /= 255.0
    # 推理
    outputs = ort_session.run(None, {'input': image_data})
    prediction = np.argmax(outputs[0])
    # 返回结果
    return {"prediction": int(prediction)}

3.测试

使用工具(如 cURL、Postman)发送请求,验证接口功能。

代码语言:bash
复制
curl -X POST "http://localhost:8000/predict" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@test_digit.png;type=image/png"

示例输出:

代码语言:json
复制
    {
        "prediction": 7
    }

小结 ⚡️

通过遵循上述最佳实践,我们可以简化部署流程,提高模型的推理性能,增强服务的可靠性和安全性。当然,在实际应用中,我们还需要根据具体情况进行优化和调整,希望本篇文章可以对各位读者有所帮助!

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 引言
  • 背景介绍 🎨
    • 为什么选择 FastAPI 与 ONNX
  • 最佳实践 🛠️
    • 1.模型转换为 ONNX 格式
      • PyTorch 模型转换
      • TensorFlow 模型转换
      • 验证转换后的模型
    • 2.构建 FastAPI 应用
      • 安装依赖
      • 定义 FastAPI 应用
      • 运行应用
      • 测试接口
    • 3.性能优化
      • 模型优化
      • 推理加速
    • 4.安全性考虑
      • 输入验证
      • 安全防护
  • 案例示例 🎯
    • 1.模型训练与转换
    • 2.构建 FastAPI 应用
    • 3.测试
  • 小结 ⚡️
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档