Carlos Rosas committed on
Commit
aa38253
·
verified ·
1 Parent(s): 8088280

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -16
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import transformers
2
  import re
3
  from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
4
- from vllm import LLM, SamplingParams
5
  import torch
6
  import gradio as gr
7
  import json
@@ -15,15 +14,29 @@ import pandas as pd
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  # Define variables
18
- temperature = 0.7
19
  max_new_tokens = 3000
20
  top_p = 0.95
21
- repetition_penalty = 1.2
 
 
22
 
23
- model_name = "PleIAs/llama-reasoning-rag"
24
 
25
- # Initialize vLLM
26
- llm = LLM(model_name, max_model_len=8128)
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Connect to the LanceDB database
29
  db = lancedb.connect("content 5/lancedb_data")
@@ -37,7 +50,6 @@ def hybrid_search(text):
37
  for _, row in results.iterrows():
38
  hash_id = str(row['hash'])
39
  title = row['section']
40
- #content = row['text'][:100] + "..." # Truncate the text for preview
41
  content = row['text']
42
 
43
  document.append(f"**{hash_id}**\n{title}\n{content}")
@@ -53,16 +65,37 @@ class CassandreChatBot:
53
 
54
  def predict(self, user_message):
55
  fiches, fiches_html = hybrid_search(user_message)
56
- sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, presence_penalty=repetition_penalty, stop=["#END#"])
57
-
58
  detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Analysis ###\n"""
59
 
60
- prompts = [detailed_prompt]
61
- outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
62
- generated_text = outputs[0].outputs[0].text
63
- generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
64
- fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html
65
- return generated_text, fiches_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def format_references(text):
68
  ref_start_marker = '<ref text="'
@@ -104,7 +137,7 @@ def format_references(text):
104
  # Initialize the CassandreChatBot
105
  cassandre_bot = CassandreChatBot()
106
 
107
- # CSS for styling
108
  css = """
109
  .generation {
110
  margin-left:2em;
 
1
  import transformers
2
  import re
3
  from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
 
4
  import torch
5
  import gradio as gr
6
  import json
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  # Define variables
17
+ temperature = 0.4
18
  max_new_tokens = 3000
19
  top_p = 0.95
20
+ repetition_penalty = 1.0
21
+ min_new_tokens = 1000
22
+ early_stopping = False
23
 
24
+ model_name = "PleIAs/Pleias-Rag"
25
 
26
+ # Get Hugging Face token from environment variable
27
+ hf_token = os.environ.get('HF_TOKEN')
28
+ if not hf_token:
29
+ raise ValueError("Please set the HF_TOKEN environment variable")
30
+
31
+ # Initialize model and tokenizer
32
+ tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
33
+ model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
34
+ model.to(device)
35
+
36
+ # Set tokenizer configuration
37
+ tokenizer.pad_token = tokenizer.eos_token
38
+ tokenizer.pad_token_id = tokenizer.eos_token_id
39
+ tokenizer.eos_token = "<|end_of_text|>"
40
 
41
  # Connect to the LanceDB database
42
  db = lancedb.connect("content 5/lancedb_data")
 
50
  for _, row in results.iterrows():
51
  hash_id = str(row['hash'])
52
  title = row['section']
 
53
  content = row['text']
54
 
55
  document.append(f"**{hash_id}**\n{title}\n{content}")
 
65
 
66
  def predict(self, user_message):
67
  fiches, fiches_html = hybrid_search(user_message)
68
+
 
69
  detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Analysis ###\n"""
70
 
71
+ # Convert inputs to tensor
72
+ input_ids = tokenizer.encode(detailed_prompt, return_tensors="pt").to(device)
73
+ attention_mask = torch.ones_like(input_ids)
74
+
75
+ try:
76
+ output = model.generate(
77
+ input_ids,
78
+ attention_mask=attention_mask,
79
+ max_new_tokens=max_new_tokens,
80
+ do_sample=False,
81
+ early_stopping=early_stopping,
82
+ min_new_tokens=min_new_tokens,
83
+ temperature=temperature,
84
+ repetition_penalty=repetition_penalty,
85
+ pad_token_id=tokenizer.pad_token_id,
86
+ eos_token_id=tokenizer.eos_token_id
87
+ )
88
+
89
+ generated_text = tokenizer.decode(output[0])
90
+ generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
91
+ fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html
92
+ return generated_text, fiches_html
93
+
94
+ except Exception as e:
95
+ print(f"Error during generation: {str(e)}")
96
+ import traceback
97
+ traceback.print_exc()
98
+ return None, None
99
 
100
  def format_references(text):
101
  ref_start_marker = '<ref text="'
 
137
  # Initialize the CassandreChatBot
138
  cassandre_bot = CassandreChatBot()
139
 
140
+ # CSS for styling
141
  css = """
142
  .generation {
143
  margin-left:2em;