|
import time |
|
from io import BytesIO |
|
|
|
import httpx |
|
import streamlit as st |
|
from anthropic import Anthropic |
|
from anthropic import APIError as AnthropicAPIError |
|
from openai import APIError as OpenAIAPIError |
|
from openai import OpenAI |
|
from PIL import Image |
|
|
|
from .config import config |
|
from .util import base64_decode_image_data_url |
|
|
|
|
|
def txt2txt_generate(api_key, service, parameters, **kwargs):
    """Stream a chat completion into the Streamlit app.

    For the ``"anthropic"`` service the Anthropic client is used; every other
    service goes through the OpenAI client pointed at the service URL from
    ``config.services``. For ``"hf"`` the model name from ``parameters`` is
    appended to that URL as ``/{model}/v1``.

    Args:
        api_key: Key passed to the selected client.
        service: Key into ``config.services``.
        parameters: Mapping of request parameters, forwarded as keyword
            arguments to the client call (its ``"model"`` entry is also read
            to build the ``"hf"`` URL).
        **kwargs: Extra keyword arguments forwarded to the client call.

    Returns:
        Whatever ``st.write_stream`` returns on success. On failure, an error
        description taken from the raised exception: its ``message``, its
        ``body`` for OpenAI streaming errors, or ``str(exception)``.
    """
    model_name = parameters.get("model", "")
    endpoint = config.services[service].url

    if service == "hf":
        endpoint = f"{endpoint}/{model_name}/v1"

    try:
        if service != "anthropic":
            openai_client = OpenAI(api_key=api_key, base_url=endpoint)
            chunks = openai_client.chat.completions.create(stream=True, **parameters, **kwargs)
            return st.write_stream(chunks)

        # The Anthropic client is created without the configured URL.
        anthropic_client = Anthropic(api_key=api_key)
        with anthropic_client.messages.stream(**parameters, **kwargs) as message_stream:
            return st.write_stream(message_stream.text_stream)
    except AnthropicAPIError as err:
        return err.message
    except OpenAIAPIError as err:
        # For the generic streaming failure message, return the error body
        # instead of the message.
        if err.message == "An error occurred during streaming":
            return err.body
        return err.message
    except Exception as err:
        return str(err)
|
|
|
|
|
# BFL result statuses after which the task can no longer become "Ready".
# NOTE(review): values follow the BFL get_result status enum — confirm against
# the current API reference. Unknown statuses simply keep polling.
_BFL_FAILED_STATUSES = frozenset({"Error", "Request Moderated", "Content Moderated"})


def _download_image(url, headers, timeout):
    """Fetch ``url`` with ``httpx.get`` and open the response body as a PIL image."""
    image_response = httpx.get(url, headers=headers, timeout=timeout)
    return Image.open(BytesIO(image_response.content))


def _poll_bfl_result(result_url, headers, timeout):
    """Poll a BFL result URL once per second until the image is ready.

    Returns the decoded image when the status is ``"Ready"``, or an
    ``"Error: ..."`` string when a poll request is not 2xx, when the API
    reports a terminal failure status, or when ``timeout`` polls have elapsed.
    """
    attempts = 0
    # ``timeout`` doubles as the maximum number of one-second polls.
    while attempts < timeout:
        poll = httpx.get(result_url, timeout=timeout)
        if poll.status_code // 100 != 2:
            return f"Error: {poll.status_code} {poll.text}"

        # Parse the body once per poll instead of once per field access.
        body = poll.json()
        status = body["status"]
        if status == "Ready":
            return _download_image(body["result"]["sample"], headers, timeout)
        if status in _BFL_FAILED_STATUSES:
            # Fail fast: waiting out the timeout cannot change the outcome.
            return f"Error: {status}"

        attempts += 1
        time.sleep(1)

    return "Error: API timeout"


def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
    """Request an image from a text-to-image service and return it.

    Builds service-specific headers and JSON payload, POSTs them to the
    service URL from ``config.services`` (with ``/{model}`` appended for every
    service except ``"together"``), and turns the response into an image.

    Args:
        api_key: Credential placed in the service's auth header.
        service: One of ``"bfl"``, ``"fal"``, ``"hf"`` or ``"together"``;
            also the key into ``config.services``.
        model: Model path segment appended to the service URL.
        inputs: The prompt text.
        parameters: Mapping of generation parameters merged into the payload.
        **kwargs: Extra parameters merged over ``parameters``.

    Returns:
        The generated image on success (a PIL image, or for ``"fal"`` in sync
        mode whatever ``base64_decode_image_data_url`` returns). On a non-2xx
        response or a BFL failure, an ``"Error: ..."`` string; on any
        exception raised while talking to the service, ``str(exception)``.
        ``None`` if the request succeeds for a service not listed above.
    """
    headers = {}
    payload = {**parameters, **kwargs}

    if service == "bfl":
        headers["x-key"] = api_key
        payload["prompt"] = inputs
    elif service == "fal":
        headers["Authorization"] = f"Key {api_key}"
        payload["prompt"] = inputs
    elif service == "hf":
        headers["Authorization"] = f"Bearer {api_key}"
        headers["X-Wait-For-Model"] = "true"
        headers["X-Use-Cache"] = "false"
        # This service takes the prompt and parameters as separate fields.
        payload = {
            "inputs": inputs,
            "parameters": {**parameters, **kwargs},
        }
    elif service == "together":
        headers["Authorization"] = f"Bearer {api_key}"
        payload["prompt"] = inputs

    service_url = config.services[service].url
    # Every service except "together" addresses the model in the URL path.
    endpoint = service_url if service == "together" else f"{service_url}/{model}"

    try:
        timeout = config.timeout
        response = httpx.post(endpoint, headers=headers, json=payload, timeout=timeout)

        if response.status_code // 100 != 2:
            return f"Error: {response.status_code} {response.text}"

        if service == "bfl":
            # The POST only returns a task id; the image is fetched by polling.
            task_id = response.json()["id"]
            result_url = f"{service_url}/get_result?id={task_id}"
            return _poll_bfl_result(result_url, headers, timeout)

        if service == "fal":
            image_url = response.json()["images"][0]["url"]
            # With sync_mode on (the default here) the "url" is handed to the
            # data-URL decoder instead of being downloaded.
            if parameters.get("sync_mode", True):
                return base64_decode_image_data_url(image_url)
            return _download_image(image_url, headers, timeout)

        if service == "hf":
            # The response body is the image itself.
            return Image.open(BytesIO(response.content))

        if service == "together":
            image_url = response.json()["data"][0]["url"]
            return _download_image(image_url, headers, timeout)

        return None
    except Exception as e:
        return str(e)
|
|