tams-image-gen / app.py
TDN-M's picture
Update app.py
744ebe6 verified
raw
history blame
5.97 kB
import gradio as gr # Để tạo giao diện web
import requests # Để thực hiện các yêu cầu HTTP
import json # Để làm việc với dữ liệu JSON
import hashlib # Để tạo hash MD5
import time # Để xử lý thời gian
import os # Để làm việc với biến môi trường
from PIL import Image # Để mở và xử lý hình ảnh
from pathlib import Path # Để làm việc với đường dẫn
from io import BytesIO # Để xử lý dữ liệu binary như hình ảnh
# URL and personal information from the API
url_pre = "https://ap-east-1.tensorart.cloud/v1"
# Directory to save generated images
SAVE_DIR = "generated_images"
Path(SAVE_DIR).mkdir(exist_ok=True)
# Get API key from environment
api_key_token = os.getenv("api_key_token")
if not api_key_token:
raise ValueError("API key token not found in environment variables.")
# API request function
def txt2img(prompt, width, height):
model_id = "770694094415489962" # Fixed model ID
vae_id = "sdxl-vae-fp16-fix.safetensors" # Fixed VAE
lora_items = [
{"loraModel": "766419665653268679", "weight": 0.7},
{"loraModel": "777630084346589138", "weight": 0.7},
{"loraModel": "776587863287492519", "weight": 0.7}
]
txt2img_data = {
"request_id": hashlib.md5(str(int(time.time())).encode()).hexdigest(),
"stages": [
{
"type": "INPUT_INITIALIZE",
"inputInitialize": {
"seed": -1,
"count": 1
}
},
{
"type": "DIFFUSION",
"diffusion": {
"width": width,
"height": height,
"prompts": [
{
"text": prompt
}
],
"negativePrompts": [
{
"text": "nsfw"
}
],
"sdModel": model_id,
"sdVae": vae_id,
"sampler": "Euler a",
"steps": 20,
"cfgScale": 3,
"clipSkip": 1,
"etaNoiseSeedDelta": 31337,
"lora": {
"items": lora_items
}
}
}
]
}
body = json.dumps(txt2img_data)
# Use Bearer token for authorization
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': f'Bearer {api_key_token}'
}
# Create a new job
response = requests.post(f"{url_pre}/jobs", json=txt2img_data, headers=headers)
if response.status_code != 200:
return f"Error: {response.status_code} - {response.text}"
response_data = response.json()
job_id = response_data['job']['id']
print(f"Job created. ID: {job_id}")
# Polling for job completion
start_time = time.time() # Start the timeout counter
timeout = 300 # Maximum wait time is 300 seconds (5 minutes)
while True:
time.sleep(10) # Wait 10 seconds between each check
# Check if the timeout has been exceeded
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
return f"Error: Job timed out after {timeout} seconds."
# Send a request to get the job status
response = requests.get(f"{url_pre}/jobs/{job_id}", headers={
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': f'Bearer {api_key_token}' # Use the Bearer token here
})
if response.status_code != 200:
return f"Error: {response.status_code} - {response.text}"
get_job_response_data = response.json()
print("Job response data:", get_job_response_data) # Print the whole response for debugging
job_status = get_job_response_data['job']['status']
print(f"Job status: {job_status}")
if job_status == 'SUCCESS':
if 'successInfo' in get_job_response_data['job']: # Kiểm tra nếu 'successInfo' tồn tại
image_url = get_job_response_data['job']['successInfo']['images'][0]['url']
print(f"Job succeeded. Image URL: {image_url}")
# Tải ảnh từ URL và trả về dưới dạng ảnh PIL để hiển thị trong Gradio
response_image = requests.get(image_url)
img = Image.open(BytesIO(response_image.content))
return img # Trả về hình ảnh dưới dạng đối tượng PIL
else:
return "Error: Output is missing in the job response."
elif job_status == 'FAILED':
return "Error: Job failed. Please try again with different settings."
else:
print("Job is still in progress...")
# Cài đặt giao diện Gradio
with gr.Blocks() as demo:
gr.Markdown("# TAMS Image Generator")
prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
# Tạo một danh sách các tùy chọn cho kích thước hình ảnh
size_options = gr.Dropdown(choices=["1152x768", "768x1152"], label="Select Image Size")
output_image = gr.Image(label="Generated Image") # Để hiển thị hình ảnh
def generate(prompt, size_choice):
# Dựa vào tùy chọn kích thước mà người dùng chọn, thiết lập width và height
if size_choice == "1152x768":
width, height = 1152, 768
else:
width, height = 768, 1152
return txt2img(prompt, width, height)
generate_button = gr.Button("Generate")
generate_button.click(fn=generate, inputs=[prompt_input, size_options], outputs=output_image)
demo.launch()