# Hugging Face Space status (scraped page header): Running
import requests | |
import os | |
import gradio as gr | |
from huggingface_hub import HfApi | |
from slugify import slugify | |
import gradio as gr | |
import uuid | |
from typing import Optional | |
def get_json_data(url): | |
api_url = f"https://civitai.com/api/v1/models/{url.split('/')[4]}" | |
try: | |
response = requests.get(api_url) | |
response.raise_for_status() | |
return response.json() | |
except requests.exceptions.RequestException as e: | |
print(f"Error fetching JSON data: {e}") | |
return None | |
def check_nsfw(json_data):
    """Return True when the model is safe for work, False otherwise.

    Despite the name, a True result means "passed the NSFW check": the model
    itself is not flagged ``nsfw`` and every image of every model version has
    its ``nsfw`` field equal to the string ``"None"``.
    """
    if json_data["nsfw"]:
        return False
    return all(
        picture["nsfw"] == "None"
        for version in json_data["modelVersions"]
        for picture in version["images"]
    )
def extract_info(json_data):
    """Describe the first SDXL LoRA version of a CivitAI model that has a primary file.

    Args:
        json_data: The CivitAI ``/api/v1/models/<id>`` response.

    Returns:
        A dict with ``urls_to_download`` (the primary weights file, plus the
        first preview image when the version has one), the version ``id``,
        ``modelId``, ``name``, ``description``, ``trainedWords`` and the
        ``creator`` username. Returns ``None`` when the model is not a LoRA or
        has no "SDXL 1.0"/"SDXL 0.9" version with a primary file.
    """
    from urllib.parse import urlparse

    if json_data["type"] != "LORA":
        return None
    for model_version in json_data["modelVersions"]:
        if model_version["baseModel"] not in ("SDXL 1.0", "SDXL 0.9"):
            continue
        primary = next(
            (candidate for candidate in model_version["files"] if candidate["primary"]),
            None,
        )
        if primary is None:
            continue
        urls_to_download = [
            {"url": primary["downloadUrl"], "filename": primary["name"], "type": "weightName"},
        ]
        # A version without preview images used to crash with IndexError; the
        # preview is optional, so only queue it when one exists.
        images = model_version.get("images") or []
        if images:
            image_url = images[0]["url"]
            urls_to_download.append({
                "url": image_url,
                # Basename of the URL *path*, so a query string never ends up
                # in the local filename.
                "filename": os.path.basename(urlparse(image_url).path),
                "type": "imageName",
            })
        return {
            "urls_to_download": urls_to_download,
            "id": model_version["id"],
            "modelId": model_version["modelId"],
            "name": json_data["name"],
            "description": json_data["description"],
            "trainedWords": model_version["trainedWords"],
            "creator": json_data["creator"]["username"],
        }
    return None
def download_files(info, folder="."):
    """Download every entry of ``info["urls_to_download"]`` into *folder*.

    Args:
        info: Metadata dict whose ``urls_to_download`` entries carry ``url``,
            ``filename`` and ``type`` keys.
        folder: Destination directory.

    Returns:
        A dict mapping each entry ``type`` (always including ``"imageName"``
        and ``"weightName"``) to the filenames that actually exist in *folder*
        after the download attempt. Failed downloads are left out so callers
        never reference a file that is not there.
    """
    downloaded_files = {
        "imageName": [],
        "weightName": [],
    }
    for item in info["urls_to_download"]:
        download_file(item["url"], item["filename"], folder)
        # download_file reports failures by printing rather than raising, so
        # confirm the file landed in *folder* before recording it.
        name = os.path.basename(item["filename"])
        if os.path.isfile(os.path.join(folder, name)):
            downloaded_files.setdefault(item["type"], []).append(name)
    return downloaded_files
def download_file(url, filename, folder=".", *, chunk_size=1024 * 1024, timeout=60):
    """Stream *url* into ``<folder>/<basename of filename>``.

    Errors are printed rather than raised, keeping the best-effort behaviour
    callers rely on.

    Args:
        url: Source URL.
        filename: Target file name. It comes from the CivitAI API, so only its
            final path component is used.
        folder: Destination directory.
        chunk_size: Bytes written per chunk while streaming.
        timeout: Seconds passed to ``requests`` as the connect/read timeout.

    Returns:
        True when the file was written, False otherwise.
    """
    # Untrusted name: keep only the last component so "../x" cannot escape *folder*.
    safe_name = os.path.basename(filename)
    if safe_name in ("", ".", ".."):
        print(f"Error downloading file: invalid filename {filename!r}")
        return False
    path = os.path.join(folder, safe_name)
    partial_path = path + ".part"
    try:
        # Stream to disk instead of holding the whole response in memory.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Only expose the file under its final name once it is complete.
        os.replace(partial_path, path)
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
        if os.path.exists(partial_path):
            os.remove(partial_path)
        return False
    print(f"{filename} downloaded.")
    return True
def process_url(url, download_files=True, folder="."):
    """Validate a CivitAI model URL and optionally download its files.

    Args:
        url: CivitAI model page URL.
        download_files: When true, download the weights and preview image into
            *folder*; when false only the metadata is fetched.
        folder: Destination directory for downloads.

    Returns:
        ``(info, downloaded_files)``. ``info`` comes from ``extract_info``;
        ``downloaded_files`` maps ``"imageName"``/``"weightName"`` to lists of
        filenames (empty lists when nothing was downloaded).

    Raises:
        gr.Error: If the API request fails, ``check_nsfw`` rejects the model,
            or ``extract_info`` finds no SDXL LoRA version.
    """
    json_data = get_json_data(url)
    if not json_data:
        raise gr.Error("Something went wrong in fetching CivitAI API")
    if not check_nsfw(json_data):
        raise gr.Error("This model has content tagged as unsafe by CivitAI")
    info = extract_info(json_data)
    if not info:
        raise gr.Error("Only SDXL LoRAs are supported for now")
    if not download_files:
        # Same shape as a real download result, so callers can index it uniformly.
        return info, {"imageName": [], "weightName": []}
    # The ``download_files`` flag shadows the module-level function of the same
    # name; calling it directly raised "'bool' object is not callable". Fetch
    # the function from the module namespace instead.
    fetch_files = globals()["download_files"]
    return info, fetch_files(info, folder)
def create_readme(info, downloaded_files, is_author=True, folder="."):
    """Write a Hugging Face model card (``README.md``) for a CivitAI LoRA.

    Args:
        info: Metadata as returned by ``extract_info`` (``id``, ``modelId``,
            ``name``, ``description``, ``trainedWords``, ``creator``).
        downloaded_files: Mapping with an ``"imageName"`` list; its first entry,
            if any, is embedded as the preview image.
        is_author: When false, add a disclaimer crediting the CivitAI creator
            and linking the original page.
        folder: Directory in which ``README.md`` is written.
    """
    import json

    # info["id"] is the *version* id; CivitAI model pages are keyed by modelId,
    # and the query string selects the version.
    original_url = f"https://civitai.com/models/{info['modelId']}?modelVersionId={info['id']}"
    non_author_disclaimer = f'This model was originally uploaded on [CivitAI]({original_url}), by [{info["creator"]}](https://civitai.com/user/{info["creator"]}/models). The information below was provided by the author on CivitAI:'
    trained_words = info.get("trainedWords") or []
    images = downloaded_files.get("imageName") or []

    lines = [
        "---",
        "license: other",
        "tags:",
        "- text-to-image",
        "- stable-diffusion",
        "- lora",
        "- diffusers",
        "base_model: stabilityai/stable-diffusion-xl-base-1.0",
    ]
    if trained_words:
        # A JSON string is a valid YAML double-quoted scalar, so trigger words
        # containing ':' or '#' can no longer break the front matter.
        prompt = json.dumps(trained_words[0], ensure_ascii=False)
        lines += [f"instance_prompt: {prompt}", "widget:", f"- text: {prompt}"]
    lines += ["---", f"# {info['name']}"]
    if not is_author:
        lines.append(non_author_disclaimer)
    if images:
        lines.append(f"![Image]({images[0]})")
    lines.append(info["description"])

    # Descriptions may hold non-ASCII text; don't depend on the platform encoding.
    with open(os.path.join(folder, "README.md"), "w", encoding="utf-8") as file:
        file.write("\n".join(lines) + "\n")
def upload_civit_to_hf(profile: Optional[gr.OAuthProfile], url, progress=gr.Progress(track_tqdm=True)):
    """Download a CivitAI SDXL LoRA and upload it to a private Hugging Face repo.

    Args:
        profile: Logged-in user's profile injected by Gradio; ``None`` when
            nobody is logged in.
        url: CivitAI model page URL.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        A status message for the output textbox.

    Raises:
        gr.Error: When the user is not logged in, ``process_url`` rejects the
            model, or the Hugging Face upload fails.
    """
    import shutil

    if profile is None or not profile.name:
        # gr.Error must be raised, not returned, for Gradio to show it.
        raise gr.Error("Are you sure you are logged in?")
    folder = str(uuid.uuid4())
    os.makedirs(folder, exist_ok=False)
    try:
        # Pass ``folder`` by keyword. Positionally it landed in process_url's
        # ``download_files`` flag and create_readme's ``is_author`` flag, so
        # files and README were written to "." instead of the upload folder.
        info, downloaded_files = process_url(url, folder=folder)
        create_readme(info, downloaded_files, folder=folder)
        # NOTE(review): the original referenced an undefined ``hf_token``. Reading
        # the HF_TOKEN environment variable (e.g. a Space secret) is an assumption;
        # if repos should be created under the logged-in user's account, inject a
        # gr.OAuthToken and use its token instead.
        hf_token = os.environ.get("HF_TOKEN")
        try:
            api = HfApi(token=hf_token)
            username = api.whoami()["name"]
            # Fall back to an id-based name if slugify yields an empty slug.
            slug_name = slugify(info["name"]) or f"civitai-{info['modelId']}"
            repo_id = f"{username}/{slug_name}"
            api.create_repo(repo_id=repo_id, private=True, exist_ok=True)
            api.upload_folder(
                folder_path=folder,
                repo_id=repo_id,
                repo_type="model"
            )
        except Exception as e:
            # Keep the cause visible in the logs instead of a bare ``except:``.
            print(f"Error uploading to Hugging Face: {e}")
            raise gr.Error("something went wrong") from e
    finally:
        # Remove the per-request scratch folder so downloads don't pile up on disk.
        shutil.rmtree(folder, ignore_errors=True)
    return "Model uploaded!"
def check_civit_link(profile: Optional[gr.OAuthProfile], url):
    """Check whether the CivitAI creator of *url* links to the logged-in HF account.

    Loads the creator's CivitAI models page, looks for a link button pointing
    at ``https://huggingface.co/...``, and compares the linked account with
    ``profile.username``.

    NOTE(review): ``driver``, ``WebDriverWait``, ``EC`` and ``By`` (Selenium)
    are not defined or imported in this file, so this currently fails with
    NameError. TODO: create the WebDriver and add the imports. Also confirm
    whether ``driver.quit()`` on every call is intended if the driver is shared.

    Returns:
        True when the linked Hugging Face account matches the logged-in user,
        False otherwise (not logged in, or no matching link found).
    """
    if profile is None:
        return False
    info, _ = process_url(url, download_files=False)
    url_creator = f"https://civitai.com/user/{info['creator']}/models"
    # Stays None if the link cannot be found; previously this name was unbound
    # in that case and the final comparison raised UnboundLocalError.
    extracted_part = None
    # Open the target URL
    driver.get(url_creator)
    # Define the XPath expression
    xpath_expression = "//a[contains(@class, 'mantine-UnstyledButton-root') and contains(@class, 'mantine-ActionIcon-root') and contains(@class, 'mantine-ubxmi3') and starts-with(@href, 'https://huggingface.co/')]"
    try:
        element = WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.XPATH, xpath_expression))
        )
        href = element.get_attribute("href")
        # Keep only the account segment: "https://huggingface.co/<user>/..." -> "<user>".
        extracted_part = href.replace("https://huggingface.co/", "").strip("/").split("/")[0]
    except Exception as e:
        print("Element not found or error occurred:", e)
    finally:
        driver.quit()
    # The URL path holds the HF *username*; profile.name is the display name.
    return extracted_part is not None and extracted_part == profile.username
def swap_fill(profile: Optional[gr.OAuthProfile]):
    """Choose which upload form is visible based on login state.

    Returns a pair of visibility updates for ``(disabled_area, enabled_area)``:
    the placeholder form is shown only while nobody is logged in, and the live
    form only once Gradio injects a profile.
    """
    logged_in = profile is not None
    return gr.update(visible=not logged_in), gr.update(visible=logged_in)
# Stylesheet passed to gr.Blocks below.
# - #login (the gr.LoginButton): the button's own label is hidden with
#   font-size 0 and replaced by the :after text; when the button is
#   :disabled the font size is restored and the :after text is cleared.
# - #disabled_upload (the placeholder column): dimmed with opacity and made
#   non-clickable with pointer-events: none.
css = '''
#login {
font-size: 0px;
width: 100% !important;
margin: 0 auto;
}
#login:after {
content: 'Authorize this app before uploading your model';
visibility: visible;
display: block;
font-size: var(--button-large-text-size);
}
#login:disabled{
font-size: var(--button-large-text-size);
}
#login:disabled:after{
content:''
}
#disabled_upload{
opacity: 0.5;
pointer-events:none;
}
'''
with gr.Blocks(css=css) as demo:
    gr.LoginButton(elem_id="login")
    # Placeholder form for logged-out visitors: dimmed and non-clickable via
    # #disabled_upload. None of its components are wired to events, so they
    # are not bound to the names the live form below uses (they used to be
    # assigned to the same variables and immediately rebound).
    with gr.Column(elem_id="disabled_upload") as disabled_area:
        with gr.Row():
            gr.Textbox(
                label="CivitAI model URL",
                info="URL of the CivitAI model, make sure it is a SDXL LoRA",
            )
        gr.Button("Upload model to Hugging Face and submit")
        gr.Textbox(label="Output progress")
    # Live form, made visible by swap_fill once a profile is present.
    with gr.Column(visible=False) as enabled_area:
        with gr.Row():
            submit_source_civit = gr.Textbox(
                label="CivitAI model URL",
                info="URL of the CivitAI model, make sure it is a SDXL LoRA",
            )
        # TODO: add an "Are you the model author?" checkbox and forward it to
        # create_readme(is_author=...) so non-authors get the disclaimer.
        instructions = gr.HTML("")
        submit_button_civit = gr.Button("Upload model to Hugging Face")
        output = gr.Textbox(label="Output progress")
    demo.load(fn=swap_fill, outputs=[disabled_area, enabled_area])
    # NOTE(review): check_civit_link returns a bool into a gr.HTML output and runs
    # on every edit of the URL textbox; confirm this wiring is intended.
    submit_source_civit.change(fn=check_civit_link, inputs=[submit_source_civit], outputs=[instructions])
    submit_button_civit.click(fn=upload_civit_to_hf, inputs=[submit_source_civit], outputs=[output])
demo.queue()
demo.launch()