import socketio
import requests
import json
import os
import time
import random
import base64
import io
import PIL
from PIL import Image
from io import BytesIO
import gradio as gr
from requests_toolbelt.multipart.encoder import MultipartEncoder
from constant import *

# Geometry used by `crop_image`: across a strip _STRIP_WIDTH pixels wide, one
# _CROP_SIZE x _CROP_SIZE tile is taken every _CROP_STRIDE pixels.
_STRIP_WIDTH = 11520
_CROP_STRIDE = 720
_CROP_SIZE = 360

# Longest side, in pixels, that `load_image` rescales uploaded images to.
_PREPROCESS_MAX_SIDE = 512

# Maximum number of preprocessing requests `Generator.preprocess` makes; only a
# 401 (token expired) response triggers another attempt.
_PREPROCESS_MAX_ATTEMPTS = 4


def login(email, password):
    """Authenticate against the backend and return ``(user_uuid, token)``.

    Args:
        email: account e-mail; omitted from the request when falsy.
        password: account password.

    Raises:
        ValueError: the response body is not valid JSON (the decode error
            raised by ``response.json()`` is logged and re-raised unchanged).
        Exception: the server reported a truthy ``error`` field.
        KeyError: the response lacks ``user_uuid`` or ``token``.
    """
    payload = {'password': password}
    if email:
        payload['email'] = email
    response = requests.post(f"{BASE_URL}/user/login", json=payload)
    try:
        response_data = response.json()
    except ValueError:
        # requests may raise its own JSONDecodeError (simplejson-based when that
        # package is installed), which is not always a json.JSONDecodeError but
        # is always a ValueError.
        log("ERROR", f"Error in login: {response}")
        raise
    if 'error' in response_data and response_data['error']:
        raise Exception(response_data['error'])
    log("INFO", "Logged successfully")
    return response_data['user_uuid'], response_data['token']


def rodin_history(task_uuid, token):
    """Fetch the generation history of ``task_uuid`` and return the decoded JSON body."""
    headers = {'Authorization': f'Bearer {token}'}
    response = requests.post(f"{BASE_URL}/task/rodin_history", data={"uuid": task_uuid}, headers=headers)
    return response.json()


def rodin_preprocess_image(generate_prompt, image, name, token):
    """Upload one image to the preprocessing endpoint and return the raw response.

    Args:
        generate_prompt: when truthy the ``generate_prompt`` form field is
            ``"true"``, otherwise ``"false"``.
        image: image payload (``Generator.preprocess`` passes the PNG bytes
            returned by ``load_image``).
        name: filename reported for the multipart image part.
        token: bearer token.
    """
    # NOTE(review): the part is labelled image/jpeg while load_image() produces
    # PNG bytes; left unchanged so the request the server sees is identical —
    # confirm which content type the endpoint expects.
    m = MultipartEncoder(
        fields={
            'generate_prompt': "true" if generate_prompt else "false",
            'images': (name, image, 'image/jpeg'),
        }
    )
    headers = {'Content-Type': m.content_type, 'Authorization': f'Bearer {token}'}
    return requests.post(f"{BASE_URL}/task/rodin_mesh_image_process", data=m, headers=headers)


def crop_image(image, type):
    """Cut a row of square tiles out of a wide image and join them left to right.

    One ``_CROP_SIZE`` x ``_CROP_SIZE`` tile is taken every ``_CROP_STRIDE``
    pixels across ``_STRIP_WIDTH`` pixels, and the tiles are pasted side by
    side into a new RGB image of height ``_CROP_SIZE``.

    Args:
        image: PIL image to crop; ``None`` means nothing has been generated yet.
        type: offset sequence — ``type[0]`` is the top edge of every tile and
            ``type[1]`` is added to each tile's left edge. The name shadows the
            builtin but is kept so keyword callers keep working.

    Raises:
        gr.Error: ``image`` is None.
    """
    if image is None:
        raise gr.Error("Please generate the object first")
    top, left_offset = type[0], type[1]
    tile_count = _STRIP_WIDTH // _CROP_STRIDE
    new_image = Image.new('RGB', (_CROP_SIZE * tile_count, _CROP_SIZE))
    for i in range(tile_count):
        left = i * _CROP_STRIDE + left_offset
        tile = image.crop((left, top, left + _CROP_SIZE, top + _CROP_SIZE))
        new_image.paste(tile, (i * _CROP_SIZE, 0))
    return new_image


def rodin_mesh(prompt, group_uuid, settings, images, name, token):
    """Submit a Rodin mesh-generation task and return the raw response.

    Args:
        prompt: text prompt.
        group_uuid: value of the ``group_uuid`` form field (may be None).
        settings: dict serialized to JSON for the ``settings`` form field.
        images: base64 strings (optionally ``data:...;base64,`` URLs), such as
            the processed image returned by ``Generator.preprocess``; each one
            is uploaded as its own ``images`` part.
        name: filename used for every image part.
        token: bearer token.
    """
    fields = [
        ('prompt', prompt),
        ('group_uuid', group_uuid),
        ('settings', json.dumps(settings)),
    ]
    # A list of (key, value) pairs lets the ``images`` field repeat; a dict can
    # hold only one value per key, which would upload just the last image.
    fields.extend(
        ('images', (name, convert_base64_to_binary(img), 'image/jpeg'))
        for img in images
    )
    m = MultipartEncoder(fields=fields)
    headers = {'Content-Type': m.content_type, 'Authorization': f'Bearer {token}'}
    return requests.post(f"{BASE_URL}/task/rodin_mesh", data=m, headers=headers)


def convert_base64_to_binary(base64_string):
    """Decode a base64 string (or ``data:...;base64,`` URL) into a BytesIO buffer.

    Needed because the processed image from ``rodin_preprocess_image`` is
    base64-encoded; anything up to the first comma is treated as a data-URL
    header and dropped.
    """
    if ',' in base64_string:
        base64_string = base64_string.split(',')[1]
    return io.BytesIO(base64.b64decode(base64_string))


def rodin_update(prompt, task_uuid, token, settings):
    """Request a new generation for an existing task and return the raw response.

    Args:
        prompt: new text prompt.
        task_uuid: uuid of the task to update.
        token: bearer token.
        settings: dict (serialized to JSON, as ``rodin_mesh`` does) or an
            already-encoded string/bytes value (sent unchanged).
    """
    if settings is not None and not isinstance(settings, (str, bytes)):
        # Form-encoding a dict directly makes requests iterate over it, which
        # sends only the key names and silently drops every value.
        settings = json.dumps(settings)
    headers = {'Authorization': f'Bearer {token}'}
    return requests.post(
        f"{BASE_URL}/task/rodin_update",
        data={"uuid": task_uuid, "prompt": prompt, "settings": settings},
        headers=headers,
    )


def load_image(img_path, max_side=_PREPROCESS_MAX_SIDE):
    """Open an image, rescale it so its longest side is ``max_side`` px, return PNG bytes.

    The aspect ratio is preserved; images smaller than ``max_side`` are scaled
    up as well as down.

    Args:
        img_path: path (or file object) accepted by ``PIL.Image.open``.
        max_side: target length of the longest side, in pixels.

    Raises:
        gr.Error: Pillow cannot identify the image format.
    """
    try:
        image = Image.open(img_path)
    except PIL.UnidentifiedImageError as e:
        raise gr.Error("Unsupported Image Format") from e
    with image:  # close the underlying file once the pixels are resized
        width, height = image.size
        scale = max_side / max(width, height)
        # Clamp to 1 px so extremely thin images never request a 0-sized resize.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        resized_image = image.resize(new_size)
    if resized_image.mode in ("CMYK", "YCbCr"):
        # PNG cannot store these modes; saving them would raise OSError.
        resized_image = resized_image.convert("RGB")
    byte_io = BytesIO()
    resized_image.save(byte_io, format='PNG')
    return byte_io.getvalue()


def log(level, info_text):
    """Print ``info_text`` to stdout with a level tag and a local timestamp."""
    print(f"[ {level} ] - {time.strftime('%Y%m%d_%H:%M:%S', time.localtime())} - {info_text}")


class Generator:
    """Drive the Rodin preprocess -> generate/update -> preview pipeline.

    The bearer token is supplied by the caller; ``preprocess`` refreshes it via
    ``login(user_id, password)`` when the server answers with status 401.
    """

    def __init__(self, user_id, password, token) -> None:
        self.token = token
        self.user_id = user_id
        self.password = password
        # Initialized here but not used by the methods below.
        self.task_uuid = None
        self.processed_image = None

    def preprocess(self, prompt, image_path, processed_image, task_uuid=""):
        """Preprocess the uploaded image (and optionally obtain a prompt).

        Args:
            prompt: current prompt; kept when refining an existing task
                (``prompt`` and ``task_uuid`` both set), otherwise replaced by
                the prompt the server generates.
            image_path: path of the uploaded image.
            processed_image: previously processed image; together with a
                prompt and no ``task_uuid`` it is returned as a cached result.
            task_uuid: uuid of an existing task, or ``""`` for a new one.

        Returns:
            ``(prompt, processed_image)`` where ``processed_image`` is a
            ``data:image/png;base64,...`` string.

        Raises:
            gr.Error: no image was given, the request or the server failed,
                the format is unsupported, or retries were exhausted.
        """
        if image_path is None:
            raise gr.Error("Please upload an image first")
        if processed_image and prompt and not task_uuid:
            log("INFO", "Using cached image and prompt...")
            return prompt, processed_image

        log("INFO", "Preprocessing image...")
        # Refining an existing task keeps the caller's prompt, so the server is
        # not asked to generate one.
        keep_prompt = bool(prompt and task_uuid)
        image_file = load_image(image_path)  # loop-invariant: load once
        image_name = os.path.basename(image_path)

        for _ in range(_PREPROCESS_MAX_ATTEMPTS):
            log("INFO", "Image loaded, processing...")
            res = None  # stays None if the request itself raises
            try:
                res = rodin_preprocess_image(
                    generate_prompt=not keep_prompt,
                    image=image_file,
                    name=image_name,
                    token=self.token,
                )
                preprocess_response = res.json()
                log("INFO", f"Image preprocessed: {preprocess_response.get('statusCode')}")
            except Exception as e:
                log("ERROR", f"Error in image preprocessing: {e} and response: {res}")
                raise gr.Error("Error in image preprocessing, please try again.") from e

            status_code = preprocess_response.get("statusCode")
            if 'error' in preprocess_response:
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Error in image preprocessing, please try again.")
            if status_code == 400:
                message = preprocess_response.get("message") or ""
                if "InvalidFile.Content" in str(message):
                    raise gr.Error("Unsupported Image Format")
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Busy connection, please try again later.")
            if status_code == 401:
                log("WARNING", "Token expired. Logging in again...")
                _, self.token = login(self.user_id, self.password)
                continue

            encoded_image = preprocess_response.get('processed_image')
            if not isinstance(encoded_image, str):
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Busy connection, please try again later.")
            if not keep_prompt:
                prompt = preprocess_response.get('prompt', None)
            return prompt, "data:image/png;base64," + encoded_image

        raise gr.Error("Failed to preprocess image")

    def generate_mesh(self, prompt, processed_image, task_uuid=""):
        """Create a new task (or regenerate an existing one) and fetch its preview.

        Blocks on a ``JobStatusChecker`` until the scheduler reports the job as
        succeeded, then downloads the ``preview_image`` of the last history
        entry.

        Args:
            prompt: text prompt.
            processed_image: base64 image returned by ``preprocess``.
            task_uuid: existing task to update; falsy starts a new task.

        Returns:
            ``(preview_image, task_uuid, crop_image(preview_image, DEFAULT))``.

        Raises:
            gr.Error: submitting the job or reading its history failed.
            RuntimeError: the preview image could not be downloaded.
        """
        log("INFO", "Generating mesh...")
        if not task_uuid:
            # One view with full weight; for multiple images use one weight per
            # image, e.g. [0.5, 0.5]. Every image must be preprocessed first.
            settings = {'view_weights': [1]}
            images = [processed_image]
            res = rodin_mesh(prompt=prompt, group_uuid=None, settings=settings,
                             images=images, name="images.jpeg", token=self.token)
            try:
                mesh_response = res.json()
                # The task uuid stays the same for the whole generation process.
                task_uuid = mesh_response['uuid']
                progress_checker = JobStatusChecker(BASE_URL, mesh_response['job']['subscription_key'])
                progress_checker.start()
            except Exception as e:
                log("ERROR", f"Error in generating mesh: {e} and response: {res}")
                raise gr.Error("Error in generating mesh, please try again later.") from e
        else:
            settings = {
                "view_weights": [1],
                "seed": random.randint(0, 10000),  # customize the seed here
                "escore": 5.5,  # described by the original author as "temperature"
            }
            res = rodin_update(prompt, task_uuid, self.token, settings)
            try:
                update_response = res.json()
                checker = JobStatusChecker(BASE_URL, update_response['job']['subscription_key'])
                checker.start()
            except Exception as e:
                log("ERROR", f"Error in updating mesh: {e}")
                raise gr.Error("Error in generating mesh, please try again later.") from e

        history = None  # stays None if rodin_history itself raises
        try:
            history = rodin_history(task_uuid, self.token)
            # Take the last entry of the history mapping (presumably the newest).
            preview_image = next(reversed(history.items()))[1]["preview_image"]
        except Exception as e:
            log("ERROR", f"Error in generating mesh: {history}")
            raise gr.Error("Busy connection, please try again later.") from e

        # The context manager releases the connection on every path, including
        # the error branch; .content applies any transfer decoding.
        with requests.get(preview_image, stream=True) as response:
            if response.status_code != 200:
                log("ERROR", f"Error in generating mesh: {response}")
                raise RuntimeError(f"Failed to download preview image: HTTP {response.status_code}")
            image = Image.open(BytesIO(response.content))
        return image, task_uuid, crop_image(image, DEFAULT)


class JobStatusChecker:
    """Block on a Socket.IO subscription until the scheduler reports success.

    ``start()`` connects and waits; the message handler disconnects the client
    once a payload with ``jobStatus == 'Succeeded'`` arrives, which ends the
    wait. NOTE(review): no other status triggers a disconnect, so a job that
    never succeeds keeps ``start()`` blocked — confirm the server's failure
    status and handle it here.
    """

    def __init__(self, base_url, subscription_key):
        self.base_url = base_url
        self.subscription_key = subscription_key
        self.sio = socketio.Client(logger=True, engineio_logger=True)

        @self.sio.event
        def connect():
            print("Connected to the server.")

        @self.sio.event
        def disconnect():
            print("Disconnected from server.")

        @self.sio.on('message', namespace='*')
        def message(*args, **kwargs):
            # The status payload is read from the third positional argument.
            if len(args) > 2:
                data = args[2]
                if isinstance(data, dict) and data.get('jobStatus') == 'Succeeded':
                    print("Job Succeeded! Please find the SDF image in history")
                    self.sio.disconnect()
            else:
                print("Received event with insufficient arguments.")

    def start(self):
        """Connect to the scheduler socket and block until the client disconnects."""
        self.sio.connect(
            f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
            namespaces=['/api/scheduler_socket'],
            transports='websocket',
        )
        self.sio.wait()