Spaces:
Running
Running
import gradio as gr # Để tạo giao diện web | |
import requests # Để thực hiện các yêu cầu HTTP | |
import json # Để làm việc với dữ liệu JSON | |
import hashlib # Để tạo hash MD5 | |
import time # Để xử lý thời gian | |
import os # Để làm việc với biến môi trường | |
from PIL import Image # Để mở và xử lý hình ảnh | |
from pathlib import Path # Để làm việc với đường dẫn | |
from io import BytesIO # Để xử lý dữ liệu binary như hình ảnh | |
# URL and personal information from the API | |
url_pre = "https://ap-east-1.tensorart.cloud/v1" | |
# Directory to save generated images | |
SAVE_DIR = "generated_images" | |
Path(SAVE_DIR).mkdir(exist_ok=True) | |
# Get API key from environment | |
api_key_token = os.getenv("api_key_token") | |
if not api_key_token: | |
raise ValueError("API key token not found in environment variables.") | |
# API request function | |
def txt2img(prompt, width, height): | |
model_id = "770694094415489962" # Fixed model ID | |
vae_id = "sdxl-vae-fp16-fix.safetensors" # Fixed VAE | |
lora_items = [ | |
{"loraModel": "766419665653268679", "weight": 0.7}, | |
{"loraModel": "777630084346589138", "weight": 0.7}, | |
{"loraModel": "776587863287492519", "weight": 0.7} | |
] | |
txt2img_data = { | |
"request_id": hashlib.md5(str(int(time.time())).encode()).hexdigest(), | |
"stages": [ | |
{ | |
"type": "INPUT_INITIALIZE", | |
"inputInitialize": { | |
"seed": -1, | |
"count": 1 | |
} | |
}, | |
{ | |
"type": "DIFFUSION", | |
"diffusion": { | |
"width": width, | |
"height": height, | |
"prompts": [ | |
{ | |
"text": prompt | |
} | |
], | |
"negativePrompts": [ | |
{ | |
"text": "nsfw" | |
} | |
], | |
"sdModel": model_id, | |
"sdVae": vae_id, | |
"sampler": "Euler a", | |
"steps": 20, | |
"cfgScale": 3, | |
"clipSkip": 1, | |
"etaNoiseSeedDelta": 31337, | |
"lora": { | |
"items": lora_items | |
} | |
} | |
} | |
] | |
} | |
body = json.dumps(txt2img_data) | |
# Use Bearer token for authorization | |
headers = { | |
'Content-Type': 'application/json', | |
'Accept': 'application/json', | |
'Authorization': f'Bearer {api_key_token}' | |
} | |
# Create a new job | |
response = requests.post(f"{url_pre}/jobs", json=txt2img_data, headers=headers) | |
if response.status_code != 200: | |
return f"Error: {response.status_code} - {response.text}" | |
response_data = response.json() | |
job_id = response_data['job']['id'] | |
print(f"Job created. ID: {job_id}") | |
# Polling for job completion | |
start_time = time.time() # Start the timeout counter | |
timeout = 300 # Maximum wait time is 300 seconds (5 minutes) | |
while True: | |
time.sleep(10) # Wait 10 seconds between each check | |
# Check if the timeout has been exceeded | |
elapsed_time = time.time() - start_time | |
if elapsed_time > timeout: | |
return f"Error: Job timed out after {timeout} seconds." | |
# Send a request to get the job status | |
response = requests.get(f"{url_pre}/jobs/{job_id}", headers={ | |
'Content-Type': 'application/json', | |
'Accept': 'application/json', | |
'Authorization': f'Bearer {api_key_token}' # Use the Bearer token here | |
}) | |
if response.status_code != 200: | |
return f"Error: {response.status_code} - {response.text}" | |
get_job_response_data = response.json() | |
print("Job response data:", get_job_response_data) # Print the whole response for debugging | |
job_status = get_job_response_data['job']['status'] | |
print(f"Job status: {job_status}") | |
if job_status == 'SUCCESS': | |
if 'successInfo' in get_job_response_data['job']: # Kiểm tra nếu 'successInfo' tồn tại | |
image_url = get_job_response_data['job']['successInfo']['images'][0]['url'] | |
print(f"Job succeeded. Image URL: {image_url}") | |
# Tải ảnh từ URL và trả về dưới dạng ảnh PIL để hiển thị trong Gradio | |
response_image = requests.get(image_url) | |
img = Image.open(BytesIO(response_image.content)) | |
return img # Trả về hình ảnh dưới dạng đối tượng PIL | |
else: | |
return "Error: Output is missing in the job response." | |
elif job_status == 'FAILED': | |
return "Error: Job failed. Please try again with different settings." | |
else: | |
print("Job is still in progress...") | |
# Cài đặt giao diện Gradio | |
with gr.Blocks() as demo: | |
gr.Markdown("# TAMS Image Generator") | |
prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...") | |
# Tạo một danh sách các tùy chọn cho kích thước hình ảnh | |
size_options = gr.Dropdown(choices=["1152x768", "768x1152"], label="Select Image Size") | |
output_image = gr.Image(label="Generated Image") # Để hiển thị hình ảnh | |
def generate(prompt, size_choice): | |
# Dựa vào tùy chọn kích thước mà người dùng chọn, thiết lập width và height | |
if size_choice == "1152x768": | |
width, height = 1152, 768 | |
else: | |
width, height = 768, 1152 | |
return txt2img(prompt, width, height) | |
generate_button = gr.Button("Generate") | |
generate_button.click(fn=generate, inputs=[prompt_input, size_options], outputs=output_image) | |
demo.launch() | |