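"""Minimal Gradio chat app for mistralai/Mistral-7B-Instruct-v0.1 via the Hugging Face Inference API.

Responses are cached as (message, response) pairs in a local database.json file so that
repeated questions are answered from the cache instead of re-querying the model.
"""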
import json
import gradio as gr
import random
from huggingface_hub import InferenceClient

API_URL = "https://api-inference.huggingface.co/models/"  # base Inference API URL (unused; InferenceClient resolves the endpoint from the model id)

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

def format_prompt(message, history):
    """Build a Mistral-Instruct-style prompt from the chat history and the new message."""
    prompt = "You're a helpful assistant."
    for user_prompt, bot_response in history:
        prompt += f" [INST] {user_prompt} [/INST] {bot_response}</s> "
    prompt += f" [INST] {message} [/INST]"
    return prompt

def generate(prompt, history, temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion from the model, yielding the cumulative text after each token."""
    temperature = float(temperature) if temperature > 0 else 0.01  # the endpoint rejects a temperature of 0
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(0, 10**7),
    )

    formatted_prompt = format_prompt(prompt, history)

    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
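
# Illustrative usage of the streaming generator (chat_interface below consumes it the same way):
#     for partial in generate("Hello", history=[]):
#         print(partial)  # cumulative text generated so far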

def load_database():
    """Load cached (message, response) pairs from database.json, or return an empty list."""
    try:
        with open("database.json", "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        print("Error loading database: File not found or invalid format. Starting with an empty database.")
        return []

def save_database(data):
    """Write the cached pairs back to database.json."""
    try:
        with open("database.json", "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
    except (IOError, TypeError):  # json raises TypeError for non-serializable data; there is no json.JSONEncodeError
        print("Error saving database: Encountered an issue while saving.")

def chat_interface(message):
    database = load_database()

    # JSON stores the (message, response) pairs as lists, so look up by the first element.
    cached = next((item for item in database if item[0] == message), None)
    if cached is None:
        # Exhaust the streaming generator to obtain the full response text.
        response = ""
        for partial in generate(message, history=[]):
            response = partial
        database.append((message, response))
        save_database(database)
    else:
        response = cached[1]

    return response

iface = gr.Interface(fn=chat_interface, inputs="textbox", outputs="textbox", title="Chat Interface")
iface.launch()