|
from openai import OpenAI |
|
import gradio as gr |
|
import os |
|
import json |
|
import functools |
|
import random |
|
import datetime |
|
|
|
# Credentials come from the environment; the value is None when the
# variable is unset.
api_key = os.getenv('FEATHERLESS_API_KEY')

# The stock OpenAI client is pointed at the Featherless endpoint.
client = OpenAI(api_key=api_key, base_url="https://api.featherless.ai/v1")
|
|
|
def respond(message, history, model):
    """Stream a chat completion for ``message`` using the selected model.

    Args:
        message: The newest user message.
        history: Earlier turns as an iterable of ``(user, assistant)`` pairs.
        model: Model identifier forwarded to the chat-completions endpoint.

    Yields:
        The assistant reply accumulated so far, emitted each time a streamed
        chunk contributes new text.
    """
    messages = []
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=1.0,
        stream=True,
        max_tokens=2000,
    )

    partial_message = ""
    for chunk in response:
        # A streamed chunk can arrive with an empty `choices` list; indexing
        # [0] on it would raise IndexError and abort the whole reply.
        if not chunk.choices:
            continue
        delta_text = chunk.choices[0].delta.content
        if delta_text is not None:
            partial_message += delta_text
            yield partial_message
|
|
|
# SVG markup that is interpolated into the page footer. Read inside a
# context manager so the file handle is closed deterministically.
with open('./logo.svg', encoding='utf-8') as f_logo:
    logo = f_logo.read()

# Model catalogue loaded once at import time. The rest of the file treats it
# as a mapping of model-class name -> list of model ids.
with open('./model-cache.json', 'r', encoding='utf-8') as f_model_cache:
    model_cache = json.load(f_model_cache)
|
|
|
|
|
# Allow-list keyed by model class: True means the class is offered in the
# dropdown, False means it is known but deliberately hidden.
model_class_filter = {
    **dict.fromkeys(
        (
            "mistral-v02-7b-std-lc",
            "llama3-8b-8k",
            "llama2-solar-10b7-4k",
            "mistral-nemo-12b-lc",
            "llama2-13b-4k",
            "llama3-15b-8k",
        ),
        True,
    ),
    **dict.fromkeys(
        (
            "qwen2-32b-lc",
            "llama3-70b-8k",
            "qwen2-72b-lc",
            "mixtral-8x22b-lc",
            "llama3-405b-lc",
        ),
        False,
    ),
}
|
|
|
def build_model_choices():
    """Build the dropdown entries for every enabled model class.

    Walks ``model_cache`` and keeps the models whose class is marked True in
    ``model_class_filter``. A class that is missing from the filter is
    reported on stdout and skipped; a class marked False is skipped silently.

    Returns:
        A list of ``(label, model_id)`` tuples, where the label is the model
        id followed by its class in parentheses.
    """
    choices = []
    for model_class, model_ids in model_cache.items():
        enabled = model_class_filter.get(model_class)
        if enabled is None:
            # Not listed in the filter at all: surface it, but do not offer it.
            print(f"Warning: new model class {model_class}. Treating as blacklisted")
            continue
        if enabled:
            choices.extend(
                (f"{model_id} ({model_class})", model_id) for model_id in model_ids
            )
    return choices


# Computed once at import time and reused for the random default selection.
model_choices = build_model_choices()
|
|
|
def initial_model(referer=None):
    """Choose the model id the dropdown should start on.

    Resolution order:
      1. A referer equal to the local dev URL returns a fixed model id.
      2. A referer under ``https://huggingface.co/`` returns the path after
         that prefix, provided it is one of the ids offered in
         ``model_choices``.
      3. Otherwise a pseudo-random entry of ``model_choices`` is returned,
         seeded with the ``RANDOM_SEED`` environment variable (or a built-in
         default) plus today's date, so the pick is stable within a day.

    Args:
        referer: Value of the HTTP ``Referer`` header, or None.

    Returns:
        A model id string, or None when no models are offered at all.
    """
    if referer == 'http://127.0.0.1:7860/':
        return 'Sao10K/Venomia-1.1-m7'

    hub_prefix = "https://huggingface.co/"
    if referer and referer.startswith(hub_prefix):
        possible_model = referer[len(hub_prefix):]
        # Validate against the ids actually offered in the dropdown rather
        # than the whole cache: the cache also holds classes that are
        # filtered out, and returning one of those would set the dropdown to
        # a value that is not among its choices.
        if any(possible_model == model_id for _, model_id in model_choices):
            return possible_model

    if not model_choices:
        # Nothing to choose from; avoid IndexError from Random.choice.
        return None

    key = os.environ.get('RANDOM_SEED', 'kcOtfNHA+e')
    rng = random.Random(f"{key}-{datetime.date.today().strftime('%Y-%m-%d')}")
    return rng.choice(model_choices)[1]
|
|
|
title_text = "HuggingFace's missing inference widget"

# `title` is passed by keyword: the first positional parameter of gr.Blocks
# is `theme`, so a positional string would be treated as a theme name.
with gr.Blocks(title=title_text, css='.logo-mark { fill: #ffe184; }') as demo:
    gr.HTML("""
        <h1 align="center">HuggingFace's missing inference widget</h1>
        <p align="center">
            Test any <=15B LLM from the hub.
        </p>
        <h2 align="center">
            Please select your model from the list 👇 as HF spaces can't see the refering model card.
        </h2>
    """)

    model_selector = gr.Dropdown(
        label="Select your Model",
        # Reuse the list computed at import time instead of rebuilding it
        # (rebuilding also repeats the unknown-model-class warnings).
        choices=model_choices,
        # A callable: evaluated by Gradio to produce the starting value.
        value=initial_model,
    )

    gr.ChatInterface(
        respond,
        additional_inputs=[model_selector],
        # The original string began with a stray "," before the <script> tag.
        head="""
        <script>console.log("Hello from gradio!")</script>
        """,
    )

    gr.HTML(f"""
        <p align="center">
            Inference by <a href="https://featherless.ai">{logo}</a>
        </p>
    """)

    def update_initial_model_choice(request: gr.Request):
        """Re-pick the starting model using the request's Referer header."""
        return initial_model(request.headers.get('referer'))

    # On page load, override the dropdown value with the referer-aware pick.
    demo.load(update_initial_model_choice, outputs=model_selector)

demo.launch()
|
|