Carlos Rosas commited on
Commit
130c728
·
verified ·
1 Parent(s): 7ce59b1

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +7 -7
  2. app(1).py +170 -0
  3. gitattributes +5 -0
  4. gitattributes(1) +36 -0
  5. requirements.txt +12 -0
  6. theme_builder.py +3 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Testrag
3
- emoji: 📚
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
- short_description: test_carlos
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Cassandre
3
+ emoji: 📜
4
+ colorFrom: gray
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.50.2
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app(1).py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import re
3
+ from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
4
+ from vllm import LLM, SamplingParams
5
+ import torch
6
+ import gradio as gr
7
+ import json
8
+ import os
9
+ import shutil
10
+ import requests
11
+ import lancedb
12
+ import pandas as pd
13
+
14
+ # Define the device
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ # Define variables
18
+ temperature = 0.7
19
+ max_new_tokens = 3000
20
+ top_p = 0.95
21
+ repetition_penalty = 1.2
22
+
23
+ model_name = "PleIAs/Cassandre-RAG"
24
+
25
+ # Initialize vLLM
26
+ llm = LLM(model_name, max_model_len=8128)
27
+
28
+ # Connect to the LanceDB database
29
+ db = lancedb.connect("content/lancedb_data")
30
+ table = db.open_table("eduv1")
31
+
32
+ def hybrid_search(text):
33
+ results = table.search(text, query_type="hybrid").limit(6).to_pandas()
34
+
35
+ document = []
36
+ document_html = []
37
+ for _, row in results.iterrows():
38
+ hash_id = str(row['hash'])
39
+ title = row['main_title']
40
+ #content = row['text'][:100] + "..." # Truncate the text for preview
41
+ content = row['text']
42
+
43
+ document.append(f"**{hash_id}**\n{title}\n{content}")
44
+ document_html.append(f'<div class="source" id="{hash_id}"><p><b>{hash_id}</b> : {title}<br>{content}</div>')
45
+
46
+ document = "\n\n".join(document)
47
+ document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
48
+ return document, document_html
49
+
50
+ class CassandreChatBot:
51
+ def __init__(self, system_prompt="Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées."):
52
+ self.system_prompt = system_prompt
53
+
54
+ def predict(self, user_message):
55
+ fiches, fiches_html = hybrid_search(user_message)
56
+ sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, presence_penalty=repetition_penalty, stop=["#END#"])
57
+
58
+ detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Answer ###\n"""
59
+
60
+ prompts = [detailed_prompt]
61
+ outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
62
+ generated_text = outputs[0].outputs[0].text
63
+ generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
64
+ fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html
65
+ return generated_text, fiches_html
66
+
67
+ def format_references(text):
68
+ ref_start_marker = '<ref text="'
69
+ ref_end_marker = '</ref>'
70
+
71
+ parts = []
72
+ current_pos = 0
73
+ ref_number = 1
74
+
75
+ while True:
76
+ start_pos = text.find(ref_start_marker, current_pos)
77
+ if start_pos == -1:
78
+ parts.append(text[current_pos:])
79
+ break
80
+
81
+ parts.append(text[current_pos:start_pos])
82
+
83
+ end_pos = text.find('">', start_pos)
84
+ if end_pos == -1:
85
+ break
86
+
87
+ ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
88
+ ref_text_encoded = ref_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
89
+
90
+ ref_end_pos = text.find(ref_end_marker, end_pos)
91
+ if ref_end_pos == -1:
92
+ break
93
+
94
+ ref_id = text[end_pos + 2:ref_end_pos].strip()
95
+
96
+ tooltip_html = f'<span class="tooltip" data-refid="{ref_id}" data-text="{ref_id}: {ref_text_encoded}"><a href="#{ref_id}">[{ref_number}]</a></span>'
97
+ parts.append(tooltip_html)
98
+
99
+ current_pos = ref_end_pos + len(ref_end_marker)
100
+ ref_number = ref_number + 1
101
+
102
+ return ''.join(parts)
103
+
104
+ # Initialize the CassandreChatBot
105
+ cassandre_bot = CassandreChatBot()
106
+
107
+ # CSS for styling
108
+ css = """
109
+ .generation {
110
+ margin-left:2em;
111
+ margin-right:2em;
112
+ }
113
+ :target {
114
+ background-color: #CCF3DF;
115
+ }
116
+ .source {
117
+ float:left;
118
+ max-width:17%;
119
+ margin-left:2%;
120
+ }
121
+ .tooltip {
122
+ position: relative;
123
+ cursor: pointer;
124
+ font-variant-position: super;
125
+ color: #97999b;
126
+ }
127
+
128
+ .tooltip:hover::after {
129
+ content: attr(data-text);
130
+ position: absolute;
131
+ left: 0;
132
+ top: 120%;
133
+ white-space: pre-wrap;
134
+ width: 500px;
135
+ max-width: 500px;
136
+ z-index: 1;
137
+ background-color: #f9f9f9;
138
+ color: #000;
139
+ border: 1px solid #ddd;
140
+ border-radius: 5px;
141
+ padding: 5px;
142
+ display: block;
143
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
144
+ }
145
+ """
146
+
147
+ # Gradio interface
148
+ def gradio_interface(user_message):
149
+ response, sources = cassandre_bot.predict(user_message)
150
+ return response, sources
151
+
152
+ # Create Gradio app
153
+ demo = gr.Blocks(css=css)
154
+
155
+ with demo:
156
+ gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")
157
+ with gr.Row():
158
+ with gr.Column(scale=2):
159
+ text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
160
+ text_button = gr.Button("Interroger Cassandre")
161
+ with gr.Column(scale=3):
162
+ text_output = gr.HTML(label="La réponse de Cassandre")
163
+ with gr.Row():
164
+ embedding_output = gr.HTML(label="Les sources utilisées")
165
+
166
+ text_button.click(gradio_interface, inputs=text_input, outputs=[text_output, embedding_output])
167
+
168
+ # Launch the app
169
+ if __name__ == "__main__":
170
+ demo.launch()
gitattributes ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ education_corrected/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
2
+ education_corrected/e150eb41-e894-45c4-b97c-80ced9ff2123/data_level0.bin filter=lfs diff=lfs merge=lfs -text
3
+ education_corrected/a9ac8f33-9498-450a-ae99-f116efb66330/data_level0.bin filter=lfs diff=lfs merge=lfs -text
4
+ education_corrected/6af97eb5-0cfa-40b2-a4df-732ca13bd66a/data_level0.bin filter=lfs diff=lfs merge=lfs -text
5
+ content/lancedb_data/eduv1.lance/_indices/fts/55ac048af92d47c0903552a94300d4e3.store filter=lfs diff=lfs merge=lfs -text
gitattributes(1) ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ education_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ einops
4
+ accelerate
5
+ tiktoken
6
+ scipy
7
+ vllm
8
+ lancedb
9
+ sentence_transformers
10
+ gradio
11
+ pandas
12
+ tantivy
theme_builder.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import gradio as gr
2
+
3
+ gr.themes.builder()