File size: 4,573 Bytes
c108da3
caa64e7
c108da3
f84e083
 
 
9441c54
c108da3
7a31970
4849bdc
7bc74bc
4849bdc
c108da3
4849bdc
f84e083
 
 
ce8dee8
 
f84e083
c0b9a69
d11c4a1
c108da3
 
f12ecf0
f84e083
 
 
 
e40242b
245c296
f84e083
 
 
5b8435c
 
c108da3
 
 
 
 
5b8435c
 
 
c108da3
9441c54
c108da3
 
 
4cc4589
c108da3
4cc4589
 
 
9441c54
 
4cc4589
9441c54
 
 
4cc4589
9441c54
 
c108da3
9441c54
c108da3
d0c61b6
215f4a9
c108da3
215f4a9
d0c61b6
f84e083
c108da3
 
 
 
f84e083
 
c108da3
9441c54
d0c61b6
d11c4a1
1aafe2e
d11c4a1
c1fff5f
 
 
 
d11c4a1
c1fff5f
 
 
 
 
 
27153aa
c0b9a69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98648e1
d33d65c
c0b9a69
 
d33d65c
c108da3
 
 
c0b9a69
 
d33d65c
c108da3
c0b9a69
c108da3
27153aa
c0b9a69
9441c54
ce8dee8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from typing import Generator
import json  # Asegúrate de que esta línea esté al principio del archivo
import nltk
import os
from transformers import pipeline, AutoTokenizer,AutoModelForSeq2SeqLM


# Register an extra NLTK data directory when one is configured.
# os.getenv returns None if NLTK_DATA is unset, and None is not a usable
# search-path entry, so the path is only extended when the variable exists.
_nltk_data_dir = os.getenv('NLTK_DATA')
if _nltk_data_dir is not None:
    nltk.data.path.append(_nltk_data_dir)

app = FastAPI()

# Initialize the InferenceClient with your model
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")



# summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")


class Item(BaseModel):
    """Request body for the ``/generate/`` endpoint."""

    # Latest user message; generate_stream joins it onto system_prompt
    # as "<system_prompt>, <prompt>" before formatting.
    prompt: str
    # Earlier conversation turns; format_prompt reads each entry's
    # "role" and "content" keys.
    history: list
    # Instruction text prepended to the prompt (see above).
    system_prompt: str
    # Sampling settings forwarded unchanged to client.text_generation.
    temperature: float = 0.8
    # Requested generation length; generate_stream clamps it so that the
    # estimated prompt size plus this value stays within its 32768 budget.
    max_new_tokens: int = 12000
    top_p: float = 0.15
    repetition_penalty: float = 1.0

def format_prompt(current_prompt, history):
    """Render the conversation and the new prompt as one tagged string.

    Each history entry whose ``"role"`` is ``"user"`` or ``"assistant"`` is
    wrapped in ``[USER]``/``[ASSISTANT]`` tags; entries with any other role
    are skipped. The result starts with ``<s>`` and ends with the current
    prompt wrapped in ``[USER]`` tags followed by ``</s>``.

    Args:
        current_prompt: Text of the turn being asked now.
        history: Iterable of mappings with ``"role"`` and ``"content"`` keys.

    Returns:
        The concatenated prompt string.
    """
    # NOTE(review): the client in this file targets a Mistral instruct model,
    # whose published chat template uses [INST] tags - confirm that the
    # [USER]/[ASSISTANT] tags here are intentional.
    segments = ["<s>"]
    for turn in history:
        role = turn["role"]
        if role == "user":
            tag = "USER"
        elif role == "assistant":
            tag = "ASSISTANT"
        else:
            # Unknown roles contribute nothing to the prompt.
            continue
        segments.append(f"[{tag}] {turn['content']} [/{tag}]")
    segments.append(f"[USER] {current_prompt} [/USER]</s>")
    return "".join(segments)


def generate_stream(item: Item) -> Generator[bytes, None, None]:
    """Stream generated tokens for *item* as newline-delimited JSON.

    The system prompt and user prompt are joined with ", ", formatted together
    with the history, and sent to the module-level inference client. Each
    streamed token is emitted as one UTF-8 encoded JSON object followed by a
    newline, with keys ``"text"`` (the token text) and ``"complete"`` (true
    once the response carries the full generated text).

    Args:
        item: Request parameters (prompt, history and sampling settings).

    Yields:
        One encoded JSON line per streamed token.
    """
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)

    # Estimate the prompt size. NLTK words are not model tokens, so this is
    # only an approximation used to budget the generation length.
    try:
        input_token_count = len(nltk.word_tokenize(formatted_prompt))
    except LookupError:
        # NLTK raises LookupError when its tokenizer data is not installed.
        # The figure is only an estimate, so fall back to a whitespace split
        # rather than failing the whole request.
        input_token_count = len(formatted_prompt.split())

    # Keep prompt + generation within the overall budget (presumably the
    # model's context window - confirm), while always asking for >= 1 token.
    max_tokens_allowed = 32768
    max_new_tokens_adjusted = max(1, min(item.max_new_tokens, max_tokens_allowed - input_token_count))

    generate_kwargs = {
        "temperature": item.temperature,
        "max_new_tokens": max_new_tokens_adjusted,
        "top_p": item.top_p,
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }

    # Stream the response from the InferenceClient
    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True):
        # With details=True each streamed item exposes the token and, on the
        # final item, the full generated text.
        chunk = {
            "text": response.token.text,
            "complete": response.generated_text is not None,
        }
        yield json.dumps(chunk).encode("utf-8") + b"\n"


# NOTE(review): this class is declared a second time, identically, further
# down this file (just before the /summarize route). The later definition
# rebinds the name, so one of the two declarations is redundant.
class SummarizeRequest(BaseModel):
    """Request body carrying the text to summarize."""

    text: str

@app.post("/generate/")
async def generate_text(item: Item):
    """Handle POST /generate/ by streaming the model output as NDJSON."""
    # generate_stream is a lazy generator; nothing is requested from the
    # model until the response body starts being consumed.
    token_stream = generate_stream(item)
    return StreamingResponse(token_stream, media_type="application/x-ndjson")

def split_text_by_tokens(text, max_tokens=1024):
    """Split *text* into chunks of at most *max_tokens* tokenizer tokens.

    Uses the module-level ``tokenizer``: the text is tokenized once, sliced
    into consecutive windows of ``max_tokens`` tokens, and each window is
    converted back to a string.

    Args:
        text: Text to split.
        max_tokens: Number of tokens per window.

    Returns:
        A ``(chunks, token_counts)`` tuple of two parallel lists, where each
        count is the length of ``tokenizer.encode`` applied to that chunk.
    """
    print("Tokenizing text...")
    all_tokens = tokenizer.tokenize(text)

    # Collect (chunk, count) pairs window by window, then unzip them.
    pairs = []
    for start in range(0, len(all_tokens), max_tokens):
        window = all_tokens[start:start + max_tokens]
        piece = tokenizer.convert_tokens_to_string(window)
        pairs.append((piece, len(tokenizer.encode(piece))))

    chunks = [piece for piece, _ in pairs]
    token_counts = [count for _, count in pairs]

    print("Tokenization complete.")
    return chunks, token_counts

# Load the legal-pegasus tokenizer and seq2seq model from the Hugging Face Hub.
# Both are created once at import time and kept as module-level globals:
# split_text_by_tokens uses `tokenizer`, and summarize_legal_text uses both.
tokenizer = AutoTokenizer.from_pretrained("nsi319/legal-pegasus")
model = AutoModelForSeq2SeqLM.from_pretrained("nsi319/legal-pegasus")

def summarize_legal_text(text):
    """Summarize *text* with the module-level legal-pegasus model.

    The input is encoded with truncation at 1024 tokens, so anything beyond
    that limit does not reach the model. Generation uses beam search with the
    settings listed below.

    Args:
        text: Source text to summarize.

    Returns:
        The decoded summary string with special tokens removed.
    """
    encoded_input = tokenizer.encode(
        text, return_tensors='pt', max_length=1024, truncation=True
    )

    # Beam-search settings for the summary.
    generation_options = {
        "num_beams": 9,
        "no_repeat_ngram_size": 3,
        "length_penalty": 2.0,
        "min_length": 150,
        "max_length": 250,
        "early_stopping": True,
    }
    generated_ids = model.generate(encoded_input, **generation_options)

    # Only the first returned sequence is decoded.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


class SummarizeRequest(BaseModel):
    """Request body for the ``/summarize`` endpoint."""

    # Raw text handed to summarize_legal_text.
    text: str

@app.post("/summarize")
async def summarize_text(request: SummarizeRequest):
    """Handle POST /summarize and return ``{"summary": ...}`` as JSON.

    Args:
        request: Body containing the text to summarize.

    Returns:
        A JSONResponse whose ``summary`` field holds the generated summary.

    Raises:
        HTTPException: With status 500 and the error message as detail when
            summarization fails.
    """
    # NOTE(review): summarize_legal_text is a synchronous call made directly
    # inside this coroutine - consider whether it should be offloaded so that
    # it does not hold up other requests while the model runs.
    try:
        summarized_text = summarize_legal_text(request.text)
    except Exception as e:
        print(f"Error during summarization: {e}")
        # Chain the original error so the real cause stays in the traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return JSONResponse(content={"summary": summarized_text})


# When executed as a script, serve the app with uvicorn on all network
# interfaces (0.0.0.0), port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)