# Rodin / Rodin.py
# skkk's picture
# rodin
# ff1e8e9
# raw
# history blame
# 7.34 kB
import socketio
import requests
import json
import random
import base64
import io
from PIL import Image
from requests_toolbelt.multipart.encoder import MultipartEncoder
from constant import *
def login(email, password):
    """Log in to the backend at ``BASE_URL`` and return the session credentials.

    Args:
        email: Account e-mail; left out of the request body when falsy.
        password: Account password.

    Returns:
        A ``(user_uuid, token)`` tuple read from the login response.

    Raises:
        Exception: If the response contains a truthy ``error`` entry; the
            error value is used as the exception argument.
    """
    credentials = {'password': password}
    if email:
        credentials['email'] = email

    reply = requests.post(f"{BASE_URL}/user/login", json=credentials).json()

    failure = reply['error'] if 'error' in reply else None
    if failure:
        raise Exception(failure)

    print("Login successful")
    return reply['user_uuid'], reply['token']
def rodin_history(task_uuid, token):
    """Fetch the history record of a task.

    Posts the task UUID as form data to ``/task/rodin_history`` using bearer
    authentication and returns the decoded JSON body.

    Args:
        task_uuid: Identifier of the task to look up.
        token: Bearer token (as returned by ``login``).
    """
    auth_headers = {'Authorization': f'Bearer {token}'}
    endpoint = f"{BASE_URL}/task/rodin_history"
    form = {"uuid": task_uuid}
    reply = requests.post(endpoint, data=form, headers=auth_headers)
    return reply.json()
def rodin_preprocess_image(generate_prompt, image, name, token):
    """Upload one image to ``/task/rodin_mesh_image_process``.

    Args:
        generate_prompt: Truthy to send ``generate_prompt=true``, otherwise
            ``generate_prompt=false``.
        image: File-like object or bytes holding the image; uploaded with the
            ``image/jpeg`` content type.
        name: File name reported for the upload.
        token: Bearer token used for authentication.

    Returns:
        The decoded JSON body of the response.
    """
    if generate_prompt:
        prompt_flag = "true"
    else:
        prompt_flag = "false"

    form_fields = {
        'generate_prompt': prompt_flag,
        'images': (name, image, 'image/jpeg'),
    }
    encoder = MultipartEncoder(fields=form_fields)

    # The encoder owns the multipart boundary, so the Content-Type header
    # must come from it.
    request_headers = {
        'Content-Type': encoder.content_type,
        'Authorization': f'Bearer {token}',
    }
    endpoint = f"{BASE_URL}/task/rodin_mesh_image_process"
    reply = requests.post(endpoint, data=encoder, headers=request_headers)
    return reply.json()
def crop_image(image, type):
    """Stitch sixteen 360x360 tiles cut from *image* into one horizontal strip.

    The source is sampled once every 720 pixels across a width of 11520
    pixels; each sample is a 360x360 square whose position inside its
    720-pixel column is shifted by the offsets in *type*.

    Args:
        image: Image object exposing ``crop((left, upper, right, lower))``.
        type: Indexable pair where ``type[0]`` is the top offset and
            ``type[1]`` is the left offset of the tile within each column.

    Returns:
        A new RGB image, 5760 pixels wide and 360 pixels high.
    """
    tile_size = 360            # width and height of every cropped tile
    column_stride = 720        # one tile is cut every 720 source pixels
    tile_count = 11520 // column_stride

    strip = Image.new('RGB', (tile_size * tile_count, tile_size))
    for index in range(tile_count):
        x0 = index * column_stride + type[1]
        y0 = type[0]
        box = (x0, y0, x0 + tile_size, y0 + tile_size)
        strip.paste(image.crop(box), (index * tile_size, 0))
    return strip
# Perform Rodin mesh operation
def rodin_mesh(prompt, group_uuid, settings, images, name, token):
    """Submit a mesh generation request to ``/task/rodin_mesh``.

    Args:
        prompt: Text prompt sent with the request.
        group_uuid: Value of the ``group_uuid`` form field (may be ``None``).
        settings: JSON-serialisable settings; sent as a JSON string.
        images: Iterable of base64 strings (optionally prefixed with a
            ``data:...;base64,`` header), decoded before upload.
        name: File name reported for every uploaded image.
        token: Bearer token used for authentication.

    Returns:
        The decoded JSON body of the response.
    """
    binary_images = [convert_base64_to_binary(img) for img in images]

    # A list of (field, value) pairs is used instead of a dict: a dict can
    # hold only one 'images' key, so building it from several images silently
    # kept just the last one. Repeating the field uploads every image.
    fields = [
        ('prompt', prompt),
        ('group_uuid', group_uuid),
        ('settings', json.dumps(settings)),  # settings travel as a JSON string
    ]
    fields.extend(
        ('images', (name, image, 'image/jpeg')) for image in binary_images
    )
    m = MultipartEncoder(fields=fields)

    headers = {
        'Content-Type': m.content_type,
        'Authorization': f'Bearer {token}'
    }
    response = requests.post(f"{BASE_URL}/task/rodin_mesh", data=m, headers=headers)
    return response.json()
# Convert base64 to binary since the result from `rodin_preprocess_image` is encoded with base64
def convert_base64_to_binary(base64_string):
    """Decode a base64 string into an in-memory binary buffer.

    When the input contains a comma (for example a ``data:...;base64,``
    prefix), only the segment between the first and second comma is decoded.

    Args:
        base64_string: Base64 text, with or without a comma-separated prefix.

    Returns:
        An ``io.BytesIO`` positioned at the start of the decoded bytes.
    """
    segments = base64_string.split(',')
    encoded = segments[1] if len(segments) > 1 else base64_string
    return io.BytesIO(base64.b64decode(encoded))
def rodin_update(prompt, task_uuid, token, settings):
    """Request an update of an existing task via ``/task/rodin_update``.

    Args:
        prompt: Prompt to send with the update.
        task_uuid: Identifier of the task being updated.
        token: Bearer token used for authentication.
        settings: Settings for the update. A dict (or any other non-string
            JSON-serialisable value) is sent as a JSON string, the same way
            ``rodin_mesh`` sends its settings; a string is sent unchanged.

    Returns:
        The decoded JSON body of the response.
    """
    headers = {
        'Authorization': f'Bearer {token}'
    }
    # requests form-encodes a dict *value* by iterating over it, which would
    # transmit only the setting names and drop every setting value. Serialise
    # to JSON first so the whole mapping reaches the server.
    if settings is None or isinstance(settings, (str, bytes)):
        settings_field = settings
    else:
        settings_field = json.dumps(settings)
    payload = {"uuid": task_uuid, "prompt": prompt, "settings": settings_field}
    response = requests.post(f"{BASE_URL}/task/rodin_update", data=payload, headers=headers)
    return response.json()
class Generator:
    """Convenience wrapper that logs in once and drives the generation calls."""

    def __init__(self, user_id, password) -> None:
        """Log in with the given credentials and keep the bearer token."""
        _, self.token = login(user_id, password)
        # Kept for backward compatibility; no method in this class updates it.
        self.task_uuid = None

    def preprocess(self, prompt, image_path):
        """Upload the image at *image_path* for preprocessing.

        Args:
            prompt: Prompt to use. When falsy, the service is asked to
                generate one and the generated prompt is returned instead.
            image_path: Path of the image file to upload.

        Returns:
            A ``(prompt, processed_image)`` tuple where ``processed_image`` is
            the returned image as a ``data:image/png;base64,...`` string.

        Raises:
            RuntimeError: If the service reports an error or returns no
                processed image.
            OSError: If the image file cannot be opened.
        """
        # `with` closes the file even when the upload raises.
        with open(image_path, 'rb') as image_file:
            preprocess_response = rodin_preprocess_image(
                generate_prompt=not prompt,
                image=image_file,
                name="images.jpeg",
                token=self.token,
            )

        error = preprocess_response.get('error')
        if error:
            # Fail explicitly instead of falling through to an unbound
            # `processed_image` at the return statement.
            print("Error in image preprocessing:", error)
            raise RuntimeError(f"Error in image preprocessing: {error}")

        encoded_image = preprocess_response.get('processed_image')
        if not encoded_image:
            raise RuntimeError("Image preprocessing returned no processed image.")

        if not prompt:
            prompt = preprocess_response.get('prompt', 'Default prompt if none returned')
        return prompt, "data:image/png;base64," + encoded_image

    def generate_mesh(self, prompt, processed_image, task_uuid=""):
        """Start a new generation or update an existing one, then fetch its preview.

        Args:
            prompt: Prompt for the generation (or the new prompt for an update).
            processed_image: Base64 image string as returned by ``preprocess``;
                only used when a new task is created.
            task_uuid: Existing task to update. When empty, a new task is
                created and its UUID is used for the rest of the call.

        Returns:
            A ``(image, task_uuid, cropped)`` tuple: the preview image, the
            task UUID, and ``crop_image(image, DEFAULT)``.

        Raises:
            RuntimeError: If the preview image cannot be downloaded.
        """
        if not task_uuid:
            # Define weights as per your requirements; for multiple images use
            # multiple values, e.g. [0.5, 0.5].
            settings = {'view_weights': [1]}
            # List of images; all of them should be preprocessed first.
            images = [processed_image]
            mesh_response = rodin_mesh(prompt=prompt, group_uuid=None, settings=settings, images=images, name="images.jpeg", token=self.token)
            subscription_key = mesh_response['job']['subscription_key']
            # The task_uuid stays the same during the whole generation process.
            task_uuid = mesh_response['uuid']
        else:
            settings = {
                "view_weights": [1],
                "seed": random.randint(0, 10000),  # Customize your seed here
                "escore": 5.5,  # Temperature
            }
            update_response = rodin_update(prompt, task_uuid, self.token, settings)
            subscription_key = update_response['job']['subscription_key']

        # Blocks until the job status checker disconnects.
        JobStatusChecker(BASE_URL, subscription_key).start()

        preview_image = rodin_history(task_uuid, self.token)["v1"]["preview_image"]
        response = requests.get(preview_image)
        try:
            if response.status_code != 200:
                message = f"Can't get the preview image. Status code:{response.status_code}"
                print(message)
                raise RuntimeError(message)
            # Decode from the fully downloaded body: unlike `response.raw`,
            # `response.content` has any transfer/content encoding removed.
            image = Image.open(io.BytesIO(response.content))
        finally:
            # Release the connection on both the success and the error path.
            response.close()
        # DEFAULT is not defined in this file; presumably it comes from the
        # star import of `constant`.
        return image, task_uuid, crop_image(image, DEFAULT)
class JobStatusChecker:
    """Socket.IO client that waits for a job to report success."""

    def __init__(self, base_url, subscription_key):
        """Create the client and register its event handlers.

        Args:
            base_url: Base URL of the server to connect to.
            subscription_key: Key placed in the ``subscription`` query
                parameter when connecting.
        """
        self.base_url = base_url
        self.subscription_key = subscription_key
        self.sio = socketio.Client(logger=True, engineio_logger=True)
        client = self.sio

        @client.event
        def connect():
            print("Connected to the server.")

        @client.event
        def disconnect():
            print("Disconnected from server.")

        @client.on('message', namespace='*')
        def message(*args, **kwargs):
            # The job payload is read from the third positional argument.
            if len(args) <= 2:
                print("Received event with insufficient arguments.")
                return
            payload = args[2]
            if payload.get('jobStatus') == 'Succeeded':
                print("Job Succeeded! Please find the SDF image in history")
                self.sio.disconnect()

    def start(self):
        """Connect over websocket and block until the client stops waiting."""
        endpoint = f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}"
        self.sio.connect(
            endpoint,
            namespaces=['/api/scheduler_socket'],
            transports='websocket',
        )
        self.sio.wait()