"""FastAPI service that generates AI images and signs Picsum photos.

Endpoints:
- ``GET /generate?prompt=...`` renders ``prompt`` with a Stable Diffusion
  pipeline and writes ``static/ai.jpg`` and ``static/ai.png``.
- ``GET /generate-picsum?prompt=<id>`` downloads Picsum image ``<id>`` at
  800x800, signs it with ``./truepic-sign`` and writes ``static/output.jpg``.
- ``GET /`` serves ``static/index.html``; every other path is served from the
  ``static`` directory.
"""

import contextlib
import json
import logging
import os
import platform  # NOTE(review): unused in this file; kept in case something relies on it.
import subprocess
import tempfile
import threading
import urllib.request

import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

logger = logging.getLogger(__name__)

print(f"Is CUDA available: {torch.cuda.is_available()}")

# Directory (relative to the working directory) holding the web UI and the
# generated/signed images.
STATIC_DIR = "static"

# Other checkpoints tried previously: "CompVis/stable-diffusion-v1-4",
# "stabilityai/stable-diffusion-2-1".
MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Half precision is only used on CUDA; CPU inference falls back to float32.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_DTYPE = torch.float16 if _DEVICE == "cuda" else torch.float32

PICSUM_URL = "https://picsum.photos/id/{image_id}/800/800"

# Lazily-loaded, process-wide diffusion pipeline (see _load_pipeline).
_pipeline = None
# Serializes pipeline loading and inference. The pipeline (including its
# scheduler) is shared across requests and FastAPI runs sync endpoints in a
# thread pool, so calls must not interleave. It also serializes writes to the
# shared static/ai.* output files.
_pipeline_lock = threading.Lock()

app = FastAPI()


def _load_pipeline():
    """Return the shared pipeline, loading it on first use.

    Must be called with ``_pipeline_lock`` held.
    """
    global _pipeline
    if _pipeline is None:
        logger.info("Loading %s on %s (%s)", MODEL_ID, _DEVICE, _DTYPE)
        _pipeline = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=_DTYPE).to(_DEVICE)
    return _pipeline


@app.get("/generate")
def generate_image(prompt: str):
    """Render ``prompt`` and save it as ``static/ai.jpg`` and ``static/ai.png``.

    The model is loaded once per process rather than on every request. A fixed
    seed (0) keeps the output deterministic for a given prompt.

    Returns:
        ``{"response": "success"}`` once both files are written.
    """
    with _pipeline_lock:
        pipeline = _load_pipeline()
        generator = torch.Generator(_DEVICE).manual_seed(0)
        image = pipeline(prompt, generator=generator).images[0]
        logger.debug("Generated image: %s", image)
        image.save(os.path.join(STATIC_DIR, "ai.jpg"))
        image.save(os.path.join(STATIC_DIR, "ai.png"))
    return {"response": "success"}


@app.get("/generate-picsum")
def generate_picsum(prompt: str):
    """Download Picsum image ``prompt`` and sign it into ``static/output.jpg``.

    Args:
        prompt: Picsum image id; must consist of ASCII digits only.

    Returns:
        ``{"response": "success"}`` after the signer exits successfully.

    Raises:
        HTTPException: 400 if ``prompt`` is not a numeric image id.
        subprocess.CalledProcessError: if ``./truepic-sign`` exits non-zero.
    """
    # The id is interpolated into the request URL, so reject anything that
    # could alter its path or query.
    if not (prompt.isascii() and prompt.isdigit()):
        raise HTTPException(status_code=400, detail="prompt must be a numeric Picsum image id")

    local_filename, _headers = urllib.request.urlretrieve(PICSUM_URL.format(image_id=prompt))
    assertion = {
        "assertions": [
            {
                "label": "com.truepic.custom.ai",
                "data": {
                    "model_name": "Picsum",
                    "model_version": "1.0",
                    "prompt": prompt,
                },
            }
        ]
    }
    assertion_path = None
    try:
        # A per-request file avoids concurrent requests overwriting each
        # other's assertions before the signer reads them.
        with tempfile.NamedTemporaryFile(
            "w", suffix=".json", delete=False, encoding="utf-8"
        ) as outfile:
            assertion_path = outfile.name
            json.dump(assertion, outfile, indent=4)
        subprocess.check_output(
            [
                "./truepic-sign", "sign", local_filename,
                "--profile", "demo",
                "--assertions", assertion_path,
                "--output", os.path.join(os.getcwd(), STATIC_DIR, "output.jpg"),
            ]
        )
    finally:
        # Remove the downloaded image and the assertion file; both are temporary.
        for path in (local_filename, assertion_path):
            if path is not None:
                with contextlib.suppress(OSError):
                    os.remove(path)
    return {"response": "success"}


@app.get("/")
def index() -> FileResponse:
    """Serve the web UI entry page."""
    return FileResponse(path=os.path.join(STATIC_DIR, "index.html"), media_type="text/html")


# Registered last: a mount at "/" matches every path, so any route declared
# after it would never be reached.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")