import time
from io import BytesIO

import httpx
import streamlit as st
from anthropic import Anthropic
from anthropic import APIError as AnthropicAPIError
from openai import APIError as OpenAIAPIError
from openai import OpenAI
from PIL import Image

from .config import config
from .util import base64_decode_image_data_url
def txt2txt_generate(api_key, service, parameters, **kwargs):
    model = parameters.get("model", "")
    base_url = config.services[service].url

    # Hugging Face serves an OpenAI-compatible route per model
    if service == "hf":
        base_url = f"{base_url}/{model}/v1"

    try:
        if service == "anthropic":
            client = Anthropic(api_key=api_key)
            with client.messages.stream(**parameters, **kwargs) as stream:
                return st.write_stream(stream.text_stream)
        else:
            client = OpenAI(api_key=api_key, base_url=base_url)
            stream = client.chat.completions.create(stream=True, **parameters, **kwargs)
            return st.write_stream(stream)
    except AnthropicAPIError as e:
        return e.message
    except OpenAIAPIError as e:
        # OpenAI uses this message for streaming errors and attaches response.error to error.body
        # https://github.com/openai/openai-python/blob/v1.0.0/src/openai/_streaming.py#L59
        return e.body if e.message == "An error occurred during streaming" else e.message
    except Exception as e:
        return str(e)
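

# Example usage (illustrative sketch, not from this repo — the model ID and
# token budget are assumptions; `parameters` is splatted into the SDK call,
# so it must hold valid Messages API kwargs, and st.write_stream requires a
# running Streamlit app):
#
#   reply = txt2txt_generate(
#       api_key="sk-ant-...",
#       service="anthropic",
#       parameters={
#           "model": "claude-3-5-sonnet-20240620",
#           "max_tokens": 512,
#           "messages": [{"role": "user", "content": "Hello!"}],
#       },
#   )
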
def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
    headers = {}
    payload = {**parameters, **kwargs}  # request body (avoids shadowing the stdlib json module)

    if service == "bfl":
        headers["x-key"] = api_key
        payload["prompt"] = inputs

    if service == "fal":
        headers["Authorization"] = f"Key {api_key}"
        payload["prompt"] = inputs

    if service == "hf":
        headers["Authorization"] = f"Bearer {api_key}"
        headers["X-Wait-For-Model"] = "true"
        headers["X-Use-Cache"] = "false"
        payload = {
            "inputs": inputs,
            "parameters": {**parameters, **kwargs},
        }

    if service == "together":
        headers["Authorization"] = f"Bearer {api_key}"
        payload["prompt"] = inputs

    # Every service except Together scopes the endpoint URL by model
    base_url = config.services[service].url
    if service != "together":
        base_url = f"{base_url}/{model}"
    try:
        timeout = config.timeout
        response = httpx.post(base_url, headers=headers, json=payload, timeout=timeout)
        if response.status_code // 100 == 2:  # 2xx
            # BFL is async, so poll for the result (about `timeout` seconds at 1s intervals)
            # https://api.bfl.ml/docs
            if service == "bfl":
                task_id = response.json()["id"]
                url = f"{config.services[service].url}/get_result?id={task_id}"
                retries = 0
                while retries < timeout:
                    response = httpx.get(url, timeout=timeout)
                    if response.status_code // 100 != 2:
                        return f"Error: {response.status_code} {response.text}"
                    if response.json()["status"] == "Ready":
                        image = httpx.get(
                            response.json()["result"]["sample"],
                            headers=headers,
                            timeout=timeout,
                        )
                        return Image.open(BytesIO(image.content))
                    retries += 1
                    time.sleep(1)
                return "Error: API timeout"

            if service == "fal":
                # In sync mode the URL is a base64 data URL instead of a CDN link
                url = response.json()["images"][0]["url"]
                if parameters.get("sync_mode", True):
                    return base64_decode_image_data_url(url)
                else:
                    image = httpx.get(url, headers=headers, timeout=timeout)
                    return Image.open(BytesIO(image.content))

            if service == "hf":
                return Image.open(BytesIO(response.content))

            if service == "together":
                url = response.json()["data"][0]["url"]
                image = httpx.get(url, headers=headers, timeout=timeout)
                return Image.open(BytesIO(image.content))
        else:
            return f"Error: {response.status_code} {response.text}"
    except Exception as e:
        return str(e)
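

# Example usage (illustrative sketch, not from this repo — the model ID and
# parameter names are assumptions; valid values depend on config.services and
# each provider's API docs). On success this returns a PIL Image; on failure
# it returns an error string:
#
#   image = txt2img_generate(
#       api_key="hf_...",
#       service="hf",
#       model="stabilityai/stable-diffusion-xl-base-1.0",
#       inputs="an astronaut riding a horse on the moon",
#       parameters={"num_inference_steps": 30, "guidance_scale": 7.5},
#   )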