# jclyo1 — "Update main.py" — commit b8c84f9
import html
import json
import os
import shutil
import subprocess
import uuid

import torch
from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionPipeline,
)
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
# ASGI application instance; every route and the static-file mount in this
# module are registered on it.
app: FastAPI = FastAPI()
def file_extension(filename):
    """Return the lower-cased extension of *filename*, without the leading dot.

    The extension is the text after the *last* dot of the final path
    component, so ``"my.photo.JPG"`` gives ``"jpg"`` and dots in directory
    names are ignored. Returns ``""`` when there is no extension (the previous
    ``split(".")[1]`` raised IndexError there and picked the wrong segment for
    multi-dot names).
    """
    _, extension = os.path.splitext(os.path.basename(filename))
    return extension[1:].lower()
@app.get("/generate")
def generate_image(prompt, model):
    """Generate an image from *prompt*, sign it with a C2PA assertion, and publish it.

    Query parameters:
        prompt: Text prompt. The raw text is what the diffusion model sees; an
            HTML-escaped copy is what gets recorded in the signed assertion.
        model: ``"<model_name>,<model_version>"``. The (HTML-escaped) name is
            passed to ``StableDiffusionPipeline.from_pretrained``; name and
            version are both recorded in the assertion. Any further
            comma-separated parts are ignored.

    Returns:
        ``{"response": filename}``, where *filename* is the signed JPEG written
        to the working directory and copied into ``static/``.

    Raises:
        HTTPException: 400 if *model* has no ``,<model_version>`` part.
        subprocess.CalledProcessError: if ``scripts/sign.sh`` exits non-zero.
    """
    torch.cuda.empty_cache()

    model_parts = html.escape(model).split(",")
    if len(model_parts) < 2:
        # Previously an unhandled IndexError (HTTP 500); this is a client error.
        raise HTTPException(
            status_code=400,
            detail="model must have the form '<model_name>,<model_version>'",
        )
    modelName, modelVersion = model_parts[0], model_parts[1]

    # Escape only the copy stored in the assertion (kept for whatever renders
    # it); escaping the model input would feed it e.g. '&amp;' instead of '&'.
    escaped_prompt = html.escape(prompt)

    # NOTE(review): modelName is client-controlled and passed unchecked to
    # from_pretrained — consider an allow-list. The pipeline is also reloaded
    # on every request.
    pipeline = StableDiffusionPipeline.from_pretrained(
        modelName, torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")
    image = pipeline(prompt, num_inference_steps=50, height=512, width=512).images[0]

    filename = f"{uuid.uuid4()}.jpg"
    image.save(filename)

    assertion = {
        "assertions": [
            {
                "label": "com.truepic.custom.ai",
                "data": {
                    "model_name": modelName,
                    "model_version": modelVersion,
                    "prompt": escaped_prompt,
                },
            }
        ]
    }
    # sign.sh is given the same path as input and output, i.e. it signs in place.
    subprocess.check_output(
        [
            "./scripts/sign.sh",
            filename,
            filename,
            "--assertions-inline",
            json.dumps(assertion),
        ]
    )
    # Copy (not move): the working-directory copy is left in place as before.
    shutil.copy(filename, os.path.join("static", filename))
    return {"response": filename}
def _parse_verify_line(line):
    """Return the stripped value of a ``key: value`` line printed by verify.sh ("" if no colon)."""
    _, _, value = line.partition(":")
    return value.strip()


@app.post("/verify")
def verify_image(fileUpload: UploadFile):
    """Check an uploaded image for a C2PA manifest / watermark and publish a result image.

    The upload is saved in the working directory under a random name that keeps
    the original extension, then ``scripts/verify.sh <input> <output>`` is run.
    Its first three output lines are read as ``key: value`` pairs giving, in
    order: whether C2PA data is present, whether a watermark is present, and
    the path of the original media (``n/a`` if none).

    Returns:
        ``{"response": <uploaded filename>, "contains_c2pa": ...,
        "contains_watermark": ..., "result_media": <name under static/ or "n/a">}``

    Raises:
        HTTPException: 400 if no file name / extension was supplied,
            500 if verify.sh prints fewer than three lines.
        subprocess.CalledProcessError: if ``scripts/verify.sh`` exits non-zero.
    """
    if not fileUpload.filename:
        # Previously fell through and returned null (or hit undefined names).
        raise HTTPException(status_code=400, detail="No file was uploaded")

    # Use only the final path component so the client-supplied name cannot
    # inject path separators into the server-side file names.
    upload_name = os.path.basename(fileUpload.filename)
    if "." not in upload_name:
        raise HTTPException(status_code=400, detail="Uploaded file must have an extension")
    extension = file_extension(upload_name)
    input_filename = f"{uuid.uuid4()}.{extension}"
    output_filename = f"{uuid.uuid4()}.{extension}"

    # Write the upload to exactly the path handed to verify.sh; `with` closes it.
    with open(input_filename, "wb") as destination:
        shutil.copyfileobj(fileUpload.file, destination)

    raw_output = subprocess.check_output(
        ["./scripts/verify.sh", input_filename, output_filename]
    )
    # fsdecode round-trips any path bytes the script prints (no repr() hack).
    lines = os.fsdecode(raw_output).splitlines()
    if len(lines) < 3:
        raise HTTPException(status_code=500, detail="Unexpected output from verify.sh")
    c2pa = _parse_verify_line(lines[0])
    watermark = _parse_verify_line(lines[1])
    original_media = _parse_verify_line(lines[2])

    if c2pa == "true":
        # NOTE(review): publishes the *uploaded* file under output_filename,
        # not verify.sh's output file — confirm this is intended.
        shutil.copy(input_filename, os.path.join("static", output_filename))
        result_media = output_filename
    elif original_media != "n/a":
        filename = f"{uuid.uuid4()}.{file_extension(os.path.basename(original_media))}"
        shutil.copy(original_media, os.path.join("static", filename))
        result_media = filename
    else:
        result_media = "n/a"

    return {
        "response": fileUpload.filename,
        "contains_c2pa": c2pa,
        "contains_watermark": watermark,
        "result_media": result_media,
    }
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page.

    Registered before the static mount below: Starlette matches routes in
    registration order, and a mount at "/" matches every path, so a route
    added after it is never reached. The path is relative to match
    ``StaticFiles(directory="static")``.
    """
    return FileResponse(path="static/index.html", media_type="text/html")


# Catch-all: serve every other path from ./static (html=True also serves a
# directory's index.html). Must remain the last registration on `app`.
app.mount("/", StaticFiles(directory="static", html=True), name="static")