from fastapi import FastAPI, Request, Body, HTTPException, Depends
from fastapi.security import APIKeyHeader
from typing import Optional
from huggingface_hub import InferenceClient
import random
import os

# Configuration is read from the environment.
API_URL = os.environ.get("API_URL")  # read but currently unused in this file
API_KEY = os.environ.get("API_KEY")  # shared secret expected in the `api_key` header
MODEL_NAME = os.environ.get("MODEL_NAME")  # model id passed to the Inference API

client = InferenceClient(MODEL_NAME)
app = FastAPI()

# Require a matching `api_key` header on every request; auto_error=False lets
# us raise our own 401 instead of FastAPI's default 403 response.
security = APIKeyHeader(name="api_key", auto_error=False)

def get_api_key(api_key: Optional[str] = Depends(security)) -> str:
    if api_key is None or api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key

def format_prompt(message: str, history: list) -> str:
    # Build a Mistral/Llama-style instruction prompt: each past turn is wrapped
    # in [INST] ... [/INST] followed by the bot response, then the new message.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
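
# For example, with one prior turn the prompt comes out as:
#   format_prompt("How are you?", [("Hi", "Hello!")])
#   -> "<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]"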

@app.post("/api/v1/generate_text", response_model=dict)
def generate_text(
    request: Request,
    body: dict = Body(...),
    api_key: str = Depends(get_api_key),
):
    # Read generation parameters from the JSON body, with defaults.
    prompt = body.get("prompt", "")
    sys_prompt = body.get("sysPrompt", "")
    temperature = body.get("temperature", 0.5)
    top_p = body.get("top_p", 0.95)
    max_new_tokens = body.get("max_new_tokens", 512)
    repetition_penalty = body.get("repetition_penalty", 1.0)
    print(f"temperature: {temperature}")
    history = []  # conversation history is not persisted; each request starts fresh
    formatted_prompt = format_prompt(f"{sys_prompt}, {prompt}", history)

    # With stream=False, text_generation returns the completed string rather
    # than an iterator of tokens, so name the result accordingly.
    generated_text = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(0, 10**7),
        stream=False,
        details=False,
        return_full_text=False,
    )

    return {"generated_text": generated_text}