Marroco93 committed on
Commit 0b9c508
1 Parent(s): f12ecf0

no message

Files changed (1)
  1. copy.py +75 -0
copy.py ADDED
@@ -0,0 +1,75 @@
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ from huggingface_hub import InferenceClient
+ from transformers import pipeline
+ import uvicorn
+ from typing import Generator
+ import json  # Make sure this line is at the top of the file
+ import nltk
+ import os
+
+ nltk.data.path.append(os.getenv('NLTK_DATA'))  # expects the 'punkt' tokenizer data under NLTK_DATA
+
+ app = FastAPI()
+
+ # Initialize the InferenceClient with your model
+ client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
+
+ summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")  # not used below
+
+
+ class Item(BaseModel):
+     prompt: str
+     history: list
+     system_prompt: str
+     temperature: float = 0.8
+     max_new_tokens: int = 12000
+     top_p: float = 0.15
+     repetition_penalty: float = 1.0
+
+ def format_prompt(current_prompt, history):
+     formatted_history = "<s>"
+     for entry in history:
+         if entry["role"] == "user":
+             formatted_history += f"[USER] {entry['content']} [/USER]"
+         elif entry["role"] == "assistant":
+             formatted_history += f"[ASSISTANT] {entry['content']} [/ASSISTANT]"
+     formatted_history += f"[USER] {current_prompt} [/USER]</s>"
+     return formatted_history
+
+
+ def generate_stream(item: Item) -> Generator[bytes, None, None]:
+     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
+     # Estimate the prompt's token count (NLTK word tokenization as a rough proxy)
+     input_token_count = len(nltk.word_tokenize(formatted_prompt))
+
+     # Ensure the total token count doesn't exceed the model's context limit
+     max_tokens_allowed = 32768
+     max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))
+
+     generate_kwargs = {
+         "temperature": item.temperature,
+         "max_new_tokens": max_new_tokens_adjusted,
+         "top_p": item.top_p,
+         "repetition_penalty": item.repetition_penalty,
+         "do_sample": True,
+         "seed": 42,
+     }
+
+     # Stream the response from the InferenceClient
+     for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
+         # With details=True each streamed item exposes the new token; generated_text is set only on the final chunk
+         chunk = {
+             "text": response.token.text,
+             "complete": response.generated_text is not None
+         }
+         yield json.dumps(chunk).encode("utf-8") + b"\n"
+
+ @app.post("/generate/")
+ async def generate_text(item: Item):
+     # Stream the response back to the client as newline-delimited JSON
+     return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
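
For reference, here is a minimal client sketch (not part of the commit) showing one way to consume the NDJSON stream from the /generate/ endpoint above. It assumes the server from copy.py is running on localhost:8000, that the requests package is installed, and the payload values are placeholders matching the Item model.

import json
import requests

# Placeholder request body matching the Item model in copy.py
payload = {
    "prompt": "Summarize the benefits of streaming responses.",
    "history": [],
    "system_prompt": "You are a helpful assistant.",
}

# Read the newline-delimited JSON chunks as they arrive
with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["text"], end="", flush=True)
        if chunk["complete"]:
            break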