Redis简介及其在机器学习中的作用
Redis 是一款开源的内存数据结构存储系统,以极高的速度、优秀的持久性以及对多种数据结构的支持,成为机器学习应用场景下缓存的理想之选。它能够满足实时推理任务对高吞吐量的需求。
本文将带你深入理解 Redis 缓存在机器学习工作流中的重要性。我们将通过 FastAPI 与 Redis 构建一个健壮的机器学习应用,涵盖 Redis 在 Windows 下的安装、如何在本地运行、以及如何集成到项目中。最后,还将通过发送重复与唯一请求,验证 Redis 缓存系统的有效性。
为什么机器学习要用 Redis 缓存?
在当今快节奏的数字时代,用户对机器学习应用的响应速度有极高期望。例如,电商平台通常用推荐模型为用户推荐商品。若利用 Redis 缓存重复请求,平台可以极大缩短响应时间。
当用户请求产品推荐时,系统首先检查是否已有缓存;若有,微秒级响应即可返回缓存内容,用户体验极为流畅。若没有,则模型处理请求并将结果缓存入 Redis,供后续复用。这种方式不仅提升用户满意度,还优化服务器资源,使模型能高效应对更多并发请求。
构建基于 Redis 的钓鱼邮件分类应用
本项目将开发一个钓鱼邮件分类应用,流程包括加载并处理 Kaggle 数据集、训练模型、评估效果、保存模型,以及基于 FastAPI + Redis 的服务部署。
1. 环境搭建
从 Kaggle 下载钓鱼邮件检测数据集,放入data/目录。
安装 Redis Python 客户端:
pip install redis
Windows 用户如未安装 WSL(Windows Subsystem for Linux),请参考微软官方指南启用 WSL,并从 Microsoft Store 安装 Linux 发行版(如 Ubuntu)。
设置好 WSL 后,在 WSL 终端执行:
sudo apt update
sudo apt install redis-server
启动 Redis 服务:
sudo service redis-server start
你会看到 redis-server 成功启动的提示。
2. 训练模型
训练脚本加载数据集、处理数据、训练模型并本地保存。
import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
def main():
# 加载数据集
df = pd.read_csv("data/Phishing_Email.csv") # 路径需自行调整
X = df["Email Text"].fillna("")
y = df["Email Type"]
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 构建 TF-IDF + 逻辑回归流水线
pipeline = Pipeline(
[
("tfidf", TfidfVectorizer(stop_words="english")),
("clf", LogisticRegression(solver="liblinear")),
]
)
# 训练模型
pipeline.fit(X_train, y_train)
# 保存模型
joblib.dump(pipeline, "phishing_model.pkl")
print("模型训练完成,已保存为 phishing_model.pkl")
if __name__ == "__main__":
main()
执行:
python train.py
输出:
模型训练完成,已保存为 phishing_model.pkl
3. 模型评估
评估脚本会加载数据集和模型,进行性能评估:
import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
import joblib
def main():
df = pd.read_csv("data/Phishing_Email.csv")
X = df["Email Text"].fillna("")
y = df["Email Type"]
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
model = joblib.load("phishing_model.pkl")
y_pred = model.predict(X_test)
print("准确率: ", accuracy_score(y_test, y_pred))
print("分类报告:")
print(classification_report(y_test, y_pred))
if __name__ == "__main__":
main()
执行:
python validate.py
输出示例:
准确率: 0.9723860589812332
分类报告:
precision recall f1-score support
Phishing Email 0.96 0.97 0.96 1457
Safe Email 0.98 0.97 0.98 2273
accuracy 0.97 3730
macro avg 0.97 0.97 0.97 3730
weighted avg 0.97 0.97 0.97 3730
4. 使用Redis和FastAPI部署模型服务
我们将利用 FastAPI 搭建 REST API,并集成 Redis 实现预测结果缓存。
import asyncio
import json
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
import redis.asyncio as redis
# 创建异步Redis客户端(确保Redis运行在localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
# 加载训练好的模型
model = joblib.load("phishing_model.pkl")
app = FastAPI()
# 定义请求和响应数据模型
class PredictionRequest(BaseModel):
text: str
class PredictionResponse(BaseModel):
prediction: str
probability: float
@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
# 用邮件文本作为缓存key
cache_key = f"prediction:{data.text}"
cached = await redis_client.get(cache_key)
if cached:
return json.loads(cached)
# 使用线程池执行模型推理,避免阻塞事件循环
pred = await asyncio.to_thread(model.predict, [data.text])
prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())
result = {"prediction": str(pred[0]), "probability": float(prob)}
# 缓存结果1小时(3600秒)
await redis_client.setex(cache_key, 3600, json.dumps(result))
return result
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
运行服务:
python serve.py
启动后,你可以访问 http://localhost:8000/docs 查看API交互文档。
Redis 缓存在机器学习应用中的工作原理
Redis 缓存的基本流程如下:
客户端提交输入,请求机器学习模型做出预测。
系统根据输入数据生成唯一标识(key),用于查找是否已有缓存。
查询Redis缓存,查找对应预测。
如找到缓存,直接返回JSON响应,速度极快。
如未找到缓存,将输入传递给模型生成新预测。
将新预测存入Redis缓存,供未来复用。
最终结果以JSON格式返回客户端。
5. 钓鱼邮件分类应用测试
现在我们通过发送多条邮件文本,测试API的准确性与缓存效果。我们将用cURL命令向/predict端点发送5次请求,其中3次为独特邮件,2次为前面邮件的重复。
echo "\n===== Testing API Endpoint with 5 Requests =====\n"
# 唯一邮件1
echo "\n----- Request 1 (First unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# 唯一邮件2
echo "\n\n----- Request 2 (Second unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
# 重复邮件1(与第1封重复)
echo "\n\n----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# 唯一邮件3
echo "\n\n----- Request 4 (Third unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'
# 重复邮件2(与第2封重复)
echo "\n\n----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
echo "\n\n===== Test Complete =====\n"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"
执行上述脚本后,API会对每封邮件返回预测结果。对于重复邮件,响应将直接来自Redis缓存,响应速度显著加快。
6. 验证Redis缓存
我们可以通过check_redis.py脚本检查Redis数据库中的缓存内容:
import redis
import json
from tabulate import tabulate
def main():
# 连接到Redis
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
# 获取所有 prediction: 前缀的 key
keys = redis_client.keys("prediction:*")
total_entries = len(keys)
print(f"Total number of cached prediction entries: {total_entries}\n")
table_data = []
# 只处理前5条
for key in keys[:5]:
email_text = key.replace("prediction:", "", 1)
value = redis_client.get(key)
try:
data = json.loads(value)
except json.JSONDecodeError:
data = {}
prediction = data.get("prediction", "N/A")
words = email_text.split()
truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")
table_data.append([truncated_text, prediction])
headers = ["Email Text (First 7 Words)", "Prediction"]
print(tabulate(table_data, headers=headers, tablefmt="pretty"))
if __name__ == "__main__":
main()
运行:
python check_redis.py
输出示例:
Total number of cached prediction entries: 3
+--------------------------------------------------+----------------+
| Email Text (First 7 Words) | Prediction |
+--------------------------------------------------+----------------+
| congratulations you have won a free iphone,... | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
| todays floor meeting you may get a... | Safe Email |
+--------------------------------------------------+----------------+
7. 总结与思考
通过对钓鱼邮件智能识别应用的多次请求测试,我们验证了API能够准确检测钓鱼邮件,并能高效缓存重复请求,大幅提升了预测服务的响应速度和整体性能。
尽管本例中模型相对简单,但对于更大型、更复杂的机器学习模型(如图像识别等),Redis 缓存能节省大量计算资源、显著提升响应效率——这对高并发业务场景尤为重要。
领取专属 10元无门槛券
私享最新 技术干货