首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何将BackgroundTasks的结果作为FastApi中的websocket应答返回?

如何将BackgroundTasks的结果作为FastApi中的websocket应答返回?
EN

Stack Overflow用户
提问于 2022-09-13 17:55:48
回答 1查看 79关注 0票数 1

我有下一个代码:

代码语言:javascript
运行
复制
from fastapi import FastAPI, WebSocket, BackgroundTasks
import uvicorn
import time

app = FastAPI()


def run_model():
    ...
    ## code of the model
    answer = [1, 2, 3]
    ...
    results = {"message": "the model has been excuted succesfully!!", "results": answer}
    return results


@app.post("/execute-model")
async def ping(background_tasks: BackgroundTasks):
    background_tasks.add_task(run_model)
    return {"message": "the model is executing"}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    while True:
        ## Here I wnat the results of run_model
        await websocket.send_text(1)

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8001)

我需要给/执行一个模型。这个端点将导出一个run_model函数作为后台任务。当run_model()完成时,我需要将答案返回到前面,并且我想在使用websockets时,但是我不知道如何做。请帮帮忙。

EN

回答 1

Stack Overflow用户

发布于 2022-09-30 18:43:31

我也有过类似的经历。下面是我是如何做到的(不是说它是最好的,甚至是一个好的解决方案,但到目前为止,它是有效的):

路由端点:

代码语言:javascript
运行
复制
# client makes a post request, gets saved model immeditely, while a background task is started to process the image
@app.post("/analyse", response_model=schemas.ImageAnalysis , tags=["Image Analysis"])
async def create_image_analysis( 
    img: schemas.ImageAnalysisCreate, 
    background_tasks: BackgroundTasks, 
    db: Session = Depends(get_db),
):
    saved = crud.create_analysis(db=db, img=img)
    background_tasks.add_task(analyse_image,db=db, img=img)

    #model includes a ws_token (some random string) that the client can connect to right away
    return saved

websocket端点:

代码语言:javascript
运行
复制
@app.websocket("/ws/{ws_token}")
async def websocket_endpoint(websocket: WebSocket, ws_token: str):
    #add the websocket to the connections dict (by ws_token)
    await socket_connections.connect(websocket,ws_token=ws_token)
    try:
        while True:
            print(socket_connections)
            await websocket.receive_text() #not really necessary
            
    except WebSocketDisconnect:
        socket_connections.disconnect(websocket,ws_token=ws_token)

analyse_image函数:

代码语言:javascript
运行
复制
#notice - the function is not async, as it does not work with background tasks otherwise!!
def analyse_image (db: Session, img: ImageAnalysis):

    print('analyse_image started')
    for index, round in enumerate(img.rounds):
        
        # some heavy workload etc

        # send update to user
        socket_connections.send_message({
                "status":EstimationStatus.RUNNING,
                "current_step":index+1,
                "total_steps":len(img.rounds)
            }, ws_token=img.ws_token)

    print("analysis finished")

连接管理器:

代码语言:javascript
运行
复制
import asyncio
from typing import Dict, List
from fastapi import  WebSocket

#notice: active_connections is changed to a dict (key= ws_token), so we know which user listens to which model
class ConnectionManager:
    
    def __init__(self):
        self.active_connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, ws_token: str):
        await websocket.accept()
        if ws_token in self.active_connections:
             self.active_connections.get(ws_token).append(websocket)
        else:
            self.active_connections.update({ws_token: [websocket]})


    def disconnect(self, websocket: WebSocket, ws_token: str):
        self.active_connections.get(ws_token).remove(websocket)
        if(len(self.active_connections.get(ws_token))==0):
            self.active_connections.pop(ws_token)

    # notice: changed from async to sync as background tasks messes up with async functions
    def send_message(self, data: dict,ws_token: str):
        sockets = self.active_connections.get(ws_token)
        if sockets:
            #notice: socket send is originally async. We have to change it to syncronous code - 
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            for socket in sockets:
                socket.send_text
                loop.run_until_complete(socket.send_json(data))


socket_connections = ConnectionManager()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73707373

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档