Marroco93 committed on
Commit c108da3
1 Parent(s): b86a744

no message

Files changed (1)
  1. main.py +34 -17
main.py CHANGED
@@ -1,24 +1,25 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
+from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 from huggingface_hub import InferenceClient
 import uvicorn
 from typing import Generator
-import json
+import json  # Make sure this line is at the top of the file
 import nltk
 import os
 from transformers import pipeline
 
-# Set up the environment for NLTK
+
 nltk.data.path.append(os.getenv('NLTK_DATA'))
 
-# Initialize the FastAPI app
 app = FastAPI()
 
 # Initialize the InferenceClient with your model
 client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
 
-# Initialize the summarization pipeline
+# summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
+
 summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
 
 class Item(BaseModel):
@@ -30,24 +31,23 @@ class Item(BaseModel):
     top_p: float = 0.15
     repetition_penalty: float = 1.0
 
-def summarize_history(history):
-    # Concatenate all history entries into a single string
-    full_history = " ".join(entry['content'] for entry in history if entry['role'] == 'user')
-    # Summarize the history
-    summarized_history = summarizer(full_history, max_length=1024, truncation=True)
-    return summarized_history[0]['summary_text']
-
 def format_prompt(current_prompt, history):
     formatted_history = "<s>"
-    formatted_history += f"[HISTORY] {history} [/HISTORY]"
+    for entry in history:
+        if entry["role"] == "user":
+            formatted_history += f"[USER] {entry['content']} [/USER]"
+        elif entry["role"] == "assistant":
+            formatted_history += f"[ASSISTANT] {entry['content']} [/ASSISTANT]"
     formatted_history += f"[USER] {current_prompt} [/USER]</s>"
     return formatted_history
 
+
 def generate_stream(item: Item) -> Generator[bytes, None, None]:
-    summarized_history = summarize_history(item.history)
-    formatted_prompt = format_prompt(item.prompt, summarized_history)
-    input_token_count = len(nltk.word_tokenize(formatted_prompt))
+    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
+    # Estimate token count for the formatted_prompt
+    input_token_count = len(nltk.word_tokenize(formatted_prompt))  # NLTK tokenization
 
+    # Ensure total token count doesn't exceed the maximum limit
     max_tokens_allowed = 32768
     max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))
 
@@ -60,16 +60,33 @@ def generate_stream(item: Item) -> Generator[bytes, None, None]:
         "seed": 42,
     }
 
+    # Stream the response from the InferenceClient
     for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
+        # This assumes 'details=True' gives you a structure where you can access the text like this
         chunk = {
            "text": response.token.text,
-            "complete": response.generated_text is not None
+            "complete": response.generated_text is not None  # Adjust based on how you detect completion
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"
 
+
+class SummarizeRequest(BaseModel):
+    text: str
+
 @app.post("/generate/")
 async def generate_text(item: Item):
+    # Stream response back to the client
    return StreamingResponse(generate_stream(item), media_type="application/x-ndjson")
 
+@app.post("/summarize")
+async def summarize_text(request: SummarizeRequest):
+    try:
+        # Perform the summarization
+        summary = summarizer(request.text, max_length=130, min_length=30, do_sample=False)
+        return JSONResponse(content={"summary": summary[0]['summary_text']})
+    except Exception as e:
+        # Handle exceptions that could arise during summarization
+        raise HTTPException(status_code=500, detail=str(e))
+
 if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
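
For reference, a minimal client sketch for the two endpoints defined in this version of main.py. It assumes the app is running locally on port 8000, that the Item model accepts the payload keys used below (prompt, history, and system_prompt, matching the attributes referenced in generate_stream; the rest of the model is not visible in this hunk), and that the requests library is available on the client side; none of this is declared by the repo itself.

import json

import requests  # assumed client-side dependency, not part of this repo

BASE_URL = "http://localhost:8000"  # assumes the server above is running locally

# Stream tokens from /generate/ and print them as they arrive (NDJSON, one object per line).
payload = {
    "prompt": "What is the capital of France?",
    "history": [],  # list of {"role": ..., "content": ...} dicts, as format_prompt expects
    "system_prompt": "You are a helpful assistant.",
}
with requests.post(f"{BASE_URL}/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)  # {"text": ..., "complete": ...}
        print(chunk["text"], end="", flush=True)
        if chunk["complete"]:
            break

# Call /summarize with a longer passage and print the returned summary.
summary_resp = requests.post(f"{BASE_URL}/summarize", json={"text": "A long passage to summarize..."})
summary_resp.raise_for_status()
print(summary_resp.json()["summary"])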