|
from typing import Dict, List

import Linlada
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
|
|
|
# FastAPI application instance served by this module.
app = FastAPI()

# Allow cross-origin requests from any origin, method, and header.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# contradictory — browsers reject credentialed responses whose
# Access-Control-Allow-Origin is the "*" wildcard. Restrict the origin list
# (or drop allow_credentials) if cookies/auth headers are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Hugging Face Hub model id for Stable Diffusion v1.5.
model_id = "runwayml/stable-diffusion-v1-5"

# NOTE(review): `StableDiffusionPipeline` is never imported in this file
# (presumably `from diffusers import StableDiffusionPipeline`) and
# `auth_token` is never defined — as written, this line raises NameError at
# import time. Confirm whether these names are meant to come from `Linlada`
# or another module.
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=auth_token)

# Run inference on CPU (no GPU assumed).
pipe = pipe.to("cpu")

# Trade some speed for lower peak memory during attention computation.
pipe.enable_attention_slicing()
|
|
|
def dummy(images, **kwargs):
    """No-op replacement for the pipeline's safety checker.

    Passes the generated images through untouched and reports ``False``,
    i.e. "nothing was flagged as NSFW". Extra keyword arguments from the
    pipeline are accepted and ignored.
    """
    flagged = False  # never report NSFW content
    return images, flagged
|
|
|
# Swap in the no-op checker so the NSFW filter never blanks out images.
pipe.safety_checker = dummy
|
|
|
@app.get("/")
def hello():
    """Root endpoint: simple liveness greeting for the service."""
    greeting = "Hello, I'm Artist"
    return greeting
|
|
|
@app.post('/generate_completion')
async def generate_completion(
    model: str = Query('gpt-4', description='The model to use for generating the completion'),
    messages: List[Dict[str, str]] = Query(..., description='The list of messages to generate the completion for'),
    stream: bool = Query(False, description='Whether to stream the response')
):
    """Proxy a chat-completion request to the backing provider.

    Parameters (query):
        model: Provider model name (defaults to 'gpt-4').
        messages: Chat history as a list of role/content mappings.
        stream: Whether the provider should stream its response.

    Returns:
        A list of the message chunks produced by the provider.
    """
    # NOTE(review): `index` is not imported or defined anywhere in this file,
    # and `_create_completion` is a private API — confirm where this name is
    # supposed to come from (possibly the `Linlada` package).
    response = index._create_completion(model=model, messages=messages, stream=stream)

    # Materialize the (possibly streamed) iterable so it can be serialized
    # as a single JSON array in the HTTP response.
    return list(response)
|
|