随着人工智能的迅猛发展,将训练好的模型部署到生产环境中,为用户提供实时预测服务,已成为众多企业和开发者关注的重点。然而,模型部署并非易事,涉及到模型格式转换、服务框架选择、性能优化等多个方面。本篇文章将介绍如何结合 FastAPI 和 ONNX,实现机器学习模型的高效部署,并分享其中的最佳实践。
机器学习模型的部署,常常会遇到以下挑战:
看到这里,可能有人会问:“有没有一种简单的方法,可以解决这些问题呢?”答案就是——FastAPI + ONNX!
模型转换是部署的第一步。将训练好的模型转换为 ONNX 格式,可以提高模型的兼容性和性能。
假设你有一个训练好的 PyTorch 模型,将其转换为 ONNX 格式呢只需几行代码,如下:
import torch
import torch.onnx
# 加载训练好的模型
model = torch.load('model.pth')
model.eval()
# 定义一个输入张量(示例输入)
dummy_input = torch.randn(1, 3, 224, 224)
# 导出为 ONNX 格式
torch.onnx.export(model, dummy_input, 'model.onnx',
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'])
print("✅ 模型已成功转换为 ONNX 格式!")
对于 TensorFlow 的模型,也是类似的操作。
import tensorflow as tf
import tf2onnx
# 加载训练好的模型
model = tf.keras.models.load_model('model.h5')
# 转换为 ONNX 格式
spec = (tf.TensorSpec(model.inputs[0].shape, dtype=tf.float32, name="input"),)
output_path = "model.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model,
input_signature=spec,
opset=11,
output_path=output_path)
print("✅ 模型已成功转换为 ONNX 格式!")
转换完成后,别忘了验证一下模型是否正常工作!
import onnx
import onnxruntime as ort
# 加载 ONNX 模型
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
print("✅ 模型格式验证通过!")
# 使用 ONNX Runtime 进行推理
ort_session = ort.InferenceSession('model.onnx')
# 准备输入数据
import numpy as np
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 运行推理
outputs = ort_session.run(None, {'input': input_data})
print("输出结果:", outputs)
示例输出:
现在,我们来创建一个基于 FastAPI 的应用,将模型部署为一个 API 服务。
首先,安装必要的依赖包:
pip install fastapi uvicorn[standard] onnxruntime
编写应用主文件 main.py
:
from fastapi import FastAPI
import onnxruntime as ort
import numpy as np
from pydantic import BaseModel
app = FastAPI(title="机器学习模型部署 API 🚀")
# 加载 ONNX 模型
ort_session = ort.InferenceSession('model.onnx')
# 定义输入数据模型
class InputData(BaseModel):
data: list
@app.post("/predict")
async def predict(input_data: InputData):
# 将输入数据转换为 numpy 数组
input_array = np.array(input_data.data).astype(np.float32)
# 进行推理
outputs = ort_session.run(None, {'input': input_array})
# 返回结果
return {"prediction": outputs[0].tolist()}
使用uvicorn
启动应用:
uvicorn main:app --host 0.0.0.0 --port 8000
可以使用 curl
或其他工具测试一下接口是否正常工作:
curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{
"data": [0.5, 0.3, 0.2]
}'
示例输出:
{
"prediction": [[0.1, 0.9]]
}
至此我们的 API 已经可以正常工作了!
性能对于一个服务来说至关重要,这里介绍一些优化技巧。
python -m onnxruntime.tools.optimizer_cli --input model.onnx --output model_optimized.onnx --optimization_level all
python -m onnxruntime.quantization.quantize --input model.onnx --output model_quant.onnx --per_channel
ort_session = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
安全是服务的底线,我们需要考虑以下几点。
@app.post("/predict")
async def predict(input_data: InputData):
try:
# 输入验证
input_array = np.array(input_data.data).astype(np.float32)
# 检查输入维度(根据模型需求调整)
if input_array.shape != (1, 3, 224, 224):
return {"error": "输入数据维度不正确"}
# 进行推理
outputs = ort_session.run(None, {'input': input_array})
return {"prediction": outputs[0].tolist()}
except Exception as e:
return {"error": str(e)}
下面以一个手写数字识别模型为例,展示完整的部署过程。
# 使用 MNIST 数据集训练模型
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# 定义简单的神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(28*28, 10)
def forward(self, x):
x = x.view(-1, 28*28)
x = self.fc(x)
return x
# 创建模型实例
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 加载数据集
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
batch_size=64, shuffle=True)
# 训练模型(这里只训练一个 epoch 作示例)
for epoch in range(1):
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 保存模型
torch.save(model.state_dict(), 'mnist.pth')
print("✅ 模型训练完成并已保存!")
# 加载模型并转换为 ONNX
model = Net()
model.load_state_dict(torch.load('mnist.pth'))
model.eval()
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, dummy_input, 'mnist.onnx', input_names=['input'], output_names=['output'])
print("✅ 模型已成功转换为 ONNX 格式!")
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import onnxruntime as ort
import numpy as np
app = FastAPI(title="手写数字识别 API 🖊️")
ort_session = ort.InferenceSession('mnist.onnx')
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
# 读取上传的图片
image = Image.open(file.file).convert('L')
# 图片预处理
image = image.resize((28, 28))
image_data = np.array(image).astype(np.float32).reshape(1, 1, 28, 28)
# 归一化
image_data /= 255.0
# 推理
outputs = ort_session.run(None, {'input': image_data})
prediction = np.argmax(outputs[0])
# 返回结果
return {"prediction": int(prediction)}
使用工具(如 cURL、Postman)发送请求,验证接口功能。
curl -X POST "http://localhost:8000/predict" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@test_digit.png;type=image/png"
示例输出:
{
"prediction": 7
}
通过遵循上述最佳实践,我们可以简化部署流程,提高模型的推理性能,增强服务的可靠性和安全性。当然,在实际应用中,我们还需要根据具体情况进行优化和调整,希望本篇文章可以对各位读者有所帮助!
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。