oflakne26 committed on
Commit
d17ac0e
·
verified ·
1 Parent(s): 9ab2ecb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -12
main.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, HTTPException
2
- from typing import Any, Dict
3
  from pydantic import BaseModel
4
  from os import getenv
5
  from huggingface_hub import InferenceClient
@@ -24,6 +24,8 @@ class InputData(BaseModel):
24
  system_prompt: str
25
  user_input: str
26
  history: str = ""
 
 
27
 
28
  class WordCheckData(BaseModel):
29
  string: str
@@ -31,12 +33,21 @@ class WordCheckData(BaseModel):
31
 
32
  @app.post("/generate-response/")
33
  async def generate_response(data: InputData) -> Dict[str, Any]:
34
- client = InferenceClient(model=data.model, token=HF_TOKEN)
 
 
 
 
 
 
 
35
 
36
- user_sentences = tokenizer.tokenize(data.user_input)
 
 
 
 
37
 
38
- user_input_str = "\n".join(user_sentences)
39
-
40
  data.history += data.prompt_template.replace("{Prompt}", user_input_str)
41
 
42
  inputs = (
@@ -47,6 +58,7 @@ async def generate_response(data: InputData) -> Dict[str, Any]:
47
  seed = random.randint(0, 2**32 - 1)
48
 
49
  try:
 
50
  response = client.text_generation(
51
  inputs,
52
  temperature=1.0,
@@ -56,17 +68,21 @@ async def generate_response(data: InputData) -> Dict[str, Any]:
56
 
57
  response_str = str(response)
58
 
59
- ai_sentences = tokenizer.tokenize(response_str)
 
 
 
 
 
 
 
 
60
 
61
  cleaned_response = {
62
- "New response": ai_sentences,
63
- "Sentence count": min(len(ai_sentences), 3)
64
  }
65
 
66
- ai_response_str = "\n".join(ai_sentences)
67
-
68
- data.history += ai_response_str + "\n"
69
-
70
  return {
71
  "response": cleaned_response,
72
  "history": data.history + data.end_token
 
1
  from fastapi import FastAPI, HTTPException
2
+ from typing import Any, Dict, Optional
3
  from pydantic import BaseModel
4
  from os import getenv
5
  from huggingface_hub import InferenceClient
 
24
  system_prompt: str
25
  user_input: str
26
  history: str = ""
27
+ segment: bool = False
28
+ max_sentences: Optional[int] = None
29
 
30
  class WordCheckData(BaseModel):
31
  string: str
 
33
 
34
  @app.post("/generate-response/")
35
  async def generate_response(data: InputData) -> Dict[str, Any]:
36
+ if data.max_sentences is not None and data.max_sentences != 0:
37
+ data.segment = True
38
+ elif data.max_sentences == 0:
39
+ data.history += data.prompt_template.replace("{Prompt}", data.user_input)
40
+ return {
41
+ "response": "",
42
+ "history": data.history + data.end_token
43
+ }
44
 
45
+ if data.segment:
46
+ user_sentences = tokenizer.tokenize(data.user_input)
47
+ user_input_str = "\n".join(user_sentences)
48
+ else:
49
+ user_input_str = data.user_input
50
 
 
 
51
  data.history += data.prompt_template.replace("{Prompt}", user_input_str)
52
 
53
  inputs = (
 
58
  seed = random.randint(0, 2**32 - 1)
59
 
60
  try:
61
+ client = InferenceClient(model=data.model, token=HF_TOKEN)
62
  response = client.text_generation(
63
  inputs,
64
  temperature=1.0,
 
68
 
69
  response_str = str(response)
70
 
71
+ if data.segment:
72
+ ai_sentences = tokenizer.tokenize(response_str)
73
+ if data.max_sentences is not None:
74
+ ai_sentences = ai_sentences[:data.max_sentences]
75
+ ai_response_str = "\n".join(ai_sentences)
76
+ else:
77
+ ai_response_str = response_str
78
+
79
+ data.history += ai_response_str + "\n"
80
 
81
  cleaned_response = {
82
+ "New response": ai_sentences if data.segment else [response_str],
83
+ "Sentence count": len(ai_sentences) if data.segment else 1
84
  }
85
 
 
 
 
 
86
  return {
87
  "response": cleaned_response,
88
  "history": data.history + data.end_token