我正在FastAPI码头容器中运行稳定扩散。它运行良好,但在执行多个推理调用后,我注意到GPU的vRAM已满,推理失败。这就好像在进行推断之后内存没有被释放一样。知道如何强制释放记忆吗?
以下是main.py
中的脚本
import logging
import os
import random
import time
import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional
# Load default logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
# Load Stable Diffusion model
log.info('Load Stable Diffusion model')
model_path = './models/stable-diffusion-v1-4'
pipe = StableDiffusionPipeline.from_pretrained(
model_path,
revision='fp16',
torch_dtype=torch.float16
)
# Move pipeline to GPU for faster inference
pipe = pipe.to('cuda')
pipe.enable_attention_slicing()
# Declare inputs and outputs data types for the API endpoint
class Payload(BaseModel):
prompt: str # String of text used to generate the images
num_images = 1 # Number of images to be generated
height = 512 # Height of the images to be generated
width = 512 # Width of the images to be generated
seed: Optional[int] = None # Random integer used as a seed to guide the image generator
num_steps = 40 # Number of inference steps, results are better the more steps you use, at a cost of slower inference
guidance_scale = 8.5 # Forces generation to better match the prompt, 7 or 8.5 give good results, results are better the larger the number is, but will be less diverse
class Response(BaseModel):
images: List[str]
nsfw_content_detected: List[bool]
prompt: str
num_images: int
height: int
width: int
seed: int
num_steps: int
guidance_scale: float
# Create FastAPI app
log.info('Start API')
app = FastAPI(title='Stable Diffusion')
app.mount("/static", StaticFiles(directory="./static"), name="static") # Mount folder to expose generated images
# Declare imagine endpoint for inference
@app.post('/imagine', response_model=Response, description='Runs inferences with Stable Diffusion.')
def imagine(payload: Payload, request: Request):
"""The imagine function generates the /imagine endpoint and runs inferences"""
try:
# Check payload
log.info(f'Payload: {payload}')
# Default seed with a random integer if it is not provided by user
if payload.seed is None:
payload.seed = random.randint(-999999999, 999999999)
generator = torch.Generator('cuda').manual_seed(payload.seed)
# Create multiple prompts according to the number of images
prompt = [payload.prompt] * payload.num_images
# Run inference on GPU
log.info('Run inference')
with torch.autocast('cuda'):
result = pipe(
prompt=prompt,
height=payload.height,
width=payload.width,
num_inference_steps=payload.num_steps,
guidance_scale=payload.guidance_scale,
generator=generator
)
log.info('Inference completed')
# Save images
images_urls = []
for image in result.images:
image_name = str(time.time()).replace('.', '') + '.png'
image_path = os.path.join('static', image_name)
image.save(image_path)
image_url = request.url_for('static', path=image_name)
images_urls.append(image_url)
# Build response object
response = {}
response['images'] = images_urls
response['nsfw_content_detected'] = result['nsfw_content_detected']
response['prompt'] = payload.prompt
response['num_images'] = payload.num_images
response['height'] = payload.height
response['width'] = payload.width
response['seed'] = payload.seed
response['num_steps'] = payload.num_steps
response['guidance_scale'] = payload.guidance_scale
return response
except Exception as e:
log.error(repr(e))
raise HTTPException(status_code=500, detail=repr(e))
发布于 2022-10-17 10:44:09
我能够通过在进行推断之后添加这个代码片段来解决这个问题.我认为这确实应该添加到文档中的不同示例中。值得称赞的是我的同事,他是从稳定的扩散WebUI存储库中发现的。
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
https://stackoverflow.com/questions/74092819
复制相似问题