Spaces:
Runtime error
Runtime error
File size: 2,510 Bytes
640efa7 c5b0047 d9f8b28 c5b0047 4fe68bc c5b0047 640efa7 c5b0047 d9f8b28 c5b0047 640efa7 d9f8b28 c5b0047 d9f8b28 c5b0047 4fe68bc 91fbea0 4fe68bc 2e517b5 4fe68bc 5b981d0 4fe68bc 5b981d0 4fe68bc 4bb483c 4fe68bc 0af8d1d 2e517b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
from dataclasses import dataclass
from enum import IntEnum
import yaml
from typing import Dict, Optional, List
from pydantic import BaseModel, ValidationError
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from openai import OpenAI
class OAuthProvider(IntEnum):
    """Identity provider a user authenticated with.

    Integer-valued so members compare equal to their plain int codes.
    """

    # No OAuth provider.
    NONE = 0
    # Google sign-in.
    GOOGLE = 1
@dataclass
class User:
    """A signed-in user of the application.

    Attributes:
        oauth: Provider the user authenticated with.
        username: The user's name.
        permissions_id: Permissions identifier (presumably a key into
            ``InferenceConfig.permissions`` -- confirm against callers).
    """

    oauth: OAuthProvider
    username: str
    permissions_id: str
class PileConfig(BaseModel):
    """Schema of the ``pile`` section of the model's ``config.yaml``."""

    # Presumably maps a data file name to a persona name -- confirm.
    file2persona: Dict[str, str]
    # Presumably maps a data file name to a text prefix -- confirm.
    file2prefix: Dict[str, str]
    # Persona name -> system prompt; Client merges this into its personas.
    persona2system: Dict[str, str]
    prompt: str
class InferenceConfig(BaseModel):
    """Schema of the ``inference`` section of the model's ``config.yaml``."""

    chat_template: str
    # Optional; absent in YAML means an empty mapping.
    # NOTE(review): pydantic copies mutable defaults per instance, so this
    # literal is not shared between models -- re-check if pydantic changes.
    permissions: Dict[str, list] = {}
class RepoConfig(BaseModel):
    """Schema of the ``repo`` section of the model's ``config.yaml``."""

    # Repository name.
    name: str
    # Repository tag.
    tag: str
class ModelConfig(BaseModel):
    """Top-level schema of the model's ``config.yaml``.

    Groups the ``pile``, ``inference`` and ``repo`` sections.
    """

    pile: PileConfig
    inference: InferenceConfig
    repo: RepoConfig

    @classmethod
    def from_yaml(cls, yaml_file="datasets/config.yaml"):
        """Load and validate a configuration from a YAML file.

        Args:
            yaml_file: Path to the YAML file.

        Returns:
            A validated instance of ``cls``.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            TypeError: If the top-level YAML value is not a mapping.
            ValidationError: If the mapping does not match the schema.
        """
        with open(yaml_file, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
        # An empty file parses to None. Validating an empty mapping instead
        # yields a ValidationError listing the missing sections rather than
        # an opaque TypeError from ``cls(**None)``.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"{yaml_file}: expected a mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
class Client:
    """Client for an OpenAI-compatible inference server.

    On construction it connects to ``{api_url}/v1``, reads the id of the
    first served model, and assembles the available personas (name ->
    system prompt) from the caller-supplied mapping plus, when a revision is
    known, the model repo's ``datasets/config.yaml``.
    """

    def __init__(self, api_url, api_key, personas=None):
        """Create the client and eagerly fetch model metadata and personas.

        Args:
            api_url: Base server URL; ``/v1`` is appended for the API.
            api_key: API key passed to the OpenAI client.
            personas: Optional mapping of persona name to system prompt.
                Same-named personas from the model config override these.
        """
        self.api_url = api_url
        self.api_key = api_key
        # ``None`` default avoids a mutable default argument shared by all
        # instances.
        self.input_personas = {} if personas is None else personas
        self.init_all()

    def init_all(self):
        """Create the API client, then load model metadata and personas."""
        self.init_client()
        self.get_metadata()
        self.get_personas()

    def init_client(self):
        """Create the OpenAI client pointed at ``{api_url}/v1``."""
        self.openai = OpenAI(
            base_url=f"{self.api_url}/v1",
            api_key=self.api_key,
        )

    def get_metadata(self):
        """Read the served model id and split it into name and revision.

        Uses the first entry returned by the models listing. The id is split
        on ``"@"``: the part before the first ``"@"`` becomes
        ``model_name``, and the next part (if any) becomes ``revision``.

        Sets ``vllm_model_name`` (raw id), ``model_name`` and ``revision``
        (``None`` when the id has no ``"@"``).

        Raises:
            IndexError: If the server lists no models.
        """
        models = self.openai.models.list()
        vllm_model_name = models.data[0].id
        model_name, *suffix = vllm_model_name.split("@")
        # Only the segment right after the first "@" is used as revision.
        revision = suffix[0] if suffix else None
        self.vllm_model_name = vllm_model_name
        self.model_name = model_name
        self.revision = revision

    def get_personas(self):
        """Build ``self.personas`` from input personas and the model config.

        When ``self.revision`` is set, ``datasets/config.yaml`` is downloaded
        from the ``model_name`` repo at that revision, and its
        ``pile.persona2system`` entries are merged in. They override
        same-named input personas. A missing config file is tolerated. A
        ``"vanilla"`` persona mapped to ``None`` is always added.

        Also sets ``self.config`` to the loaded ``ModelConfig``, or ``None``
        when no config was loaded.
        """
        self.config = None
        personas = {}
        if self.revision is not None:
            try:
                config_path = hf_hub_download(
                    self.model_name,
                    "config.yaml",
                    subfolder="datasets",
                    revision=self.revision,
                )
            except EntryNotFoundError:
                # Best effort: a repo without datasets/config.yaml simply
                # contributes no personas.
                config_path = None
            if config_path is not None:
                self.config = ModelConfig.from_yaml(config_path)
                # Copy so adding "vanilla" below does not mutate the loaded
                # config (whose values are declared as str).
                personas = dict(self.config.pile.persona2system)
        personas["vanilla"] = None
        self.personas = self.input_personas | personas