Spaces:
Sleeping
Sleeping
import json | |
import warnings | |
from typing import List, Optional, Union | |
import argilla as rg | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from gradio.oauth import ( | |
OAuthToken, | |
get_space, | |
) | |
from huggingface_hub import whoami | |
from jinja2 import Environment, meta | |
from synthetic_dataset_generator.constants import argilla_client | |
def get_duplicate_button():
    """Return a large ``gr.DuplicateButton`` when running inside a Space.

    Returns:
        The button when ``get_space()`` reports a Space, otherwise None.
    """
    space = get_space()
    if space is None:
        return None
    return gr.DuplicateButton(size="lg")
def list_orgs(oauth_token: Union[OAuthToken, None] = None):
    """Return the Hub namespaces the logged-in user can push to.

    The user's own name is always the first entry. For ``oauth`` and
    ``access_token`` auth types every organization reported by ``whoami`` is
    included. For any other auth type (fine-grained tokens), only entities
    scoped with the ``repo.write`` permission are included.

    Args:
        oauth_token: Token from the Gradio login flow, or None when logged out.

    Returns:
        A list of namespace names. The list is empty when logged out, when
        ``whoami`` fails, or when the token payload cannot be interpreted (in
        the last case a warning is emitted and a ``gr.Info`` toast is shown).
    """
    if oauth_token is None:
        return []
    try:
        data = whoami(oauth_token.token)
    except Exception:
        # Could not identify the token (e.g. it is invalid): behave as logged out.
        return []
    try:
        user_name = data["name"]
        if data["auth"]["type"] in ("oauth", "access_token"):
            organizations = [org["name"] for org in data.get("orgs", [])]
        else:
            # Fine-grained token: keep only entities the token may write repos to.
            organizations = [
                entry["entity"]["name"]
                for entry in data["auth"]["accessToken"]["fineGrained"]["scoped"]
                if "repo.write" in entry["permissions"]
            ]
            # The user is prepended below; avoid listing them twice.
            organizations = [org for org in organizations if org != user_name]
    except Exception as e:
        warnings.warn(str(e))
        gr.Info(
            "Your user token does not have the necessary permissions to push to organizations. "
            "Please check your OAuth permissions in https://huggingface.co/settings/connected-applications. "
            "Update your token permissions to include repo.write: https://huggingface.co/settings/tokens."
        )
        return []
    return [user_name] + organizations
def get_org_dropdown(oauth_token: Union[OAuthToken, None] = None):
    """Build the editable "Organization" dropdown for the current user.

    Args:
        oauth_token: Token from the Gradio login flow, or None when logged out.

    Returns:
        A ``gr.Dropdown`` whose choices come from ``list_orgs``. The first
        choice is preselected; the value is None when there are no choices.
    """
    choices = [] if oauth_token is None else list_orgs(oauth_token)
    default_choice = choices[0] if choices else None
    return gr.Dropdown(
        label="Organization",
        choices=choices,
        value=default_choice,
        allow_custom_value=True,
        interactive=True,
    )
def swap_visibility(oauth_token: Union[OAuthToken, None]):
    """Return a Gradio update that sets the main UI's CSS class by login state.

    Args:
        oauth_token: Truthy when the user is logged in.

    Returns:
        ``gr.update`` with ``elem_classes`` set to ``["main_ui_logged_in"]``
        or ``["main_ui_logged_out"]``.
    """
    css_class = "main_ui_logged_in" if oauth_token else "main_ui_logged_out"
    return gr.update(elem_classes=[css_class])
def get_argilla_client() -> Union[rg.Argilla, None]:
    """Return the shared ``argilla_client`` from ``synthetic_dataset_generator.constants``.

    Per the annotation, this may be None (presumably when Argilla is not
    configured; confirm in ``constants``).
    """
    client = argilla_client
    return client
def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
    """Normalise and de-duplicate label names.

    Each label is lower-cased and stripped of surrounding whitespace.
    Duplicates after normalisation are dropped, keeping the first occurrence,
    so the output order follows the input and is stable between runs.

    Args:
        labels: Raw label strings, or None.

    Returns:
        The normalised unique labels; an empty list for None or empty input.
    """
    if not labels:
        return []
    # dict.fromkeys preserves insertion order. A set's iteration order for
    # strings can change between interpreter runs (hash randomisation).
    return list(dict.fromkeys(label.lower().strip() for label in labels))
def column_to_list(dataframe: pd.DataFrame, column_name: str) -> List[str]:
    """Return the values of ``column_name`` as a plain Python list.

    Raises:
        ValueError: if ``column_name`` is not a column of ``dataframe``.
    """
    if column_name not in dataframe.columns:
        raise ValueError(f"Column '{column_name}' does not exist.")
    column = dataframe[column_name]
    return column.tolist()
def _message_contents(messages, role: str) -> list:
    """Return the ``content`` of every dict message in ``messages`` with ``role``.

    Non-dict items are skipped instead of raising AttributeError.
    """
    return [
        message["content"]
        for message in messages
        if isinstance(message, dict) and message.get("role") == role
    ]


def _decode_json_messages(text: str) -> Optional[list]:
    """Decode ``text`` as a JSON list of dicts; return None if it is not one.

    Non-JSON strings, and JSON that is not a list of objects (e.g. "42",
    "true", "[1, 2]"), return None so callers treat them as plain text.
    """
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, list) and all(isinstance(item, dict) for item in parsed):
        return parsed
    return None


def _extract_instruction(value) -> str:
    """Return the instruction held in one cell of the instruction column.

    Chat cells (lists/arrays, or JSON-encoded lists of messages) yield the
    last ``user`` message, or "" if there is none. Other strings are returned
    unchanged. Any other value yields "".
    """
    if isinstance(value, (list, np.ndarray)):
        messages = value
    elif isinstance(value, str):
        messages = _decode_json_messages(value)
        if messages is None:
            return value
    else:
        return ""
    user_contents = _message_contents(messages, "user")
    return user_contents[-1] if user_contents else ""


def _extract_generations(value) -> list:
    """Return the generations held in one cell of a response column.

    - A list/array of role-tagged dicts yields its last ``assistant`` content.
    - Any other list/array contributes all of its items.
    - A JSON-encoded message list yields its last ``assistant`` content.
    - Any other string is a single generation.
    - Other values (e.g. None) contribute nothing.
    """
    if isinstance(value, (list, np.ndarray)):
        if all(isinstance(item, dict) and "role" in item for item in value):
            return _message_contents(value, "assistant")[-1:]
        return list(value)
    if isinstance(value, str):
        messages = _decode_json_messages(value)
        if messages is None:
            return [value]
        return _message_contents(messages, "assistant")[-1:]
    return []


def process_columns(
    dataframe,
    instruction_column: str,
    response_columns: Union[str, List[str]],
) -> List[dict]:
    """Turn dataframe rows into ``{"instruction", "generations"}`` records.

    Args:
        dataframe: Source rows (iterated with ``iterrows``).
        instruction_column: Column holding the prompt, either plain text or a
            chat (list of ``{"role", "content"}`` dicts, possibly JSON-encoded).
        response_columns: One column name or a list of names holding
            responses in the same formats.

    Returns:
        One dict per row with keys ``"instruction"`` (str) and
        ``"generations"`` (list gathered across ``response_columns`` in order).
    """
    if isinstance(response_columns, str):
        response_columns = [response_columns]
    data = []
    for _, row in dataframe.iterrows():
        instruction = _extract_instruction(row[instruction_column])
        generations = []
        for col in response_columns:
            generations.extend(_extract_generations(row[col]))
        data.append({"instruction": instruction, "generations": generations})
    return data
def extract_column_names(prompt_template: str) -> List[str]:
    """Return the variables a Jinja2 prompt template reads from its context.

    Args:
        prompt_template: Jinja2 template source.

    Returns:
        Names of the template's undeclared variables, sorted alphabetically.
    """
    env = Environment()
    parsed_content = env.parse(prompt_template)
    variables = meta.find_undeclared_variables(parsed_content)
    # find_undeclared_variables returns a set; sort so the column order is
    # stable between runs instead of depending on string hash randomisation.
    return sorted(variables)
def pad_or_truncate_list(lst, target_length):
    """Return a new list with exactly ``max(target_length, 0)`` items.

    Args:
        lst: Any iterable of items, or None (treated as empty).
        target_length: Desired length. Values <= 0 yield an empty list.

    Returns:
        The last ``target_length`` items when ``lst`` is too long, otherwise
        ``lst``'s items right-padded with None.
    """
    items = [] if lst is None else list(lst)
    if target_length <= 0:
        # Guard explicitly: items[-0:] would return the whole list.
        return []
    if len(items) >= target_length:
        return items[-target_length:]
    return items + [None] * (target_length - len(items))