Veda0718 committed
Commit 0308771 · verified · 1 Parent(s): b6710e7

Update app.py

Files changed (1)
  1. app.py +99 -84
app.py CHANGED
@@ -13,91 +13,100 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain_core.prompts import PromptTemplate
 from langchain.chains import RetrievalQA
+from streamlit_chat import message
 
+# Check if device is available
 DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-loader = PyPDFLoader("Medical_Book.pdf")
-docs = loader.load()
-
-embeddings = HuggingFaceInstructEmbeddings(
-    model_name="hkunlp/instructor-large", model_kwargs={"device": DEVICE}
-)
-
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
-texts = text_splitter.split_documents(docs)
-
-db = Chroma.from_documents(texts, embeddings, persist_directory="db")
-
-model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
-model_basename = "model"
-
-tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
-
-model = AutoGPTQForCausalLM.from_quantized(
-    model_name_or_path,
-    revision="gptq-4bit-128g-actorder_True",
-    model_basename=model_basename,
-    use_safetensors=True,
-    trust_remote_code=True,
-    inject_fused_attention=False,
-    device=DEVICE,
-    quantize_config=None,
-)
-
-DEFAULT_SYSTEM_PROMPT = """
-You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
-""".strip()
-
-
-def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
-    return f"""
-[INST] <<SYS>>
-{system_prompt}
-<</SYS>>
-
-{prompt} [/INST]
-""".strip()
-
-streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-text_pipeline = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    max_new_tokens=1024,
-    temperature=0,
-    top_p=0.95,
-    repetition_penalty=1.15,
-    streamer=streamer,
-)
-
-llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0})
-
-SYSTEM_PROMPT = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
-
-template = generate_prompt(
-    """
-{context}
-
-Question: {question}
-""",
-    system_prompt=SYSTEM_PROMPT,
-)
-
-prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-
-qa_chain = RetrievalQA.from_chain_type(
-    llm=llm,
-    chain_type="stuff",
-    retriever=db.as_retriever(search_kwargs={"k": 2}),
-    return_source_documents=True,
-    chain_type_kwargs={"prompt": prompt},
-)
-
-# result = qa_chain("what is Doppler ultrasonography?")
-# print(result["source_documents"][0].page_content)
+# Initialize everything in session state to avoid reloading
+if "initialized" not in st.session_state:
+    st.session_state.initialized = False
+
+if not st.session_state.initialized:
+    # Load PDF
+    loader = PyPDFLoader("Medical_Book.pdf")
+    docs = loader.load()
+
+    # Initialize embeddings
+    embeddings = HuggingFaceInstructEmbeddings(
+        model_name="hkunlp/instructor-large", model_kwargs={"device": DEVICE}
+    )
+
+    # Split documents into chunks
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
+    texts = text_splitter.split_documents(docs)
+
+    # Create Chroma vectorstore
+    db = Chroma.from_documents(texts, embeddings, persist_directory="db")
+
+    # Load model and tokenizer
+    model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
+    model_basename = "model"
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+    model = AutoGPTQForCausalLM.from_quantized(
+        model_name_or_path,
+        revision="gptq-4bit-128g-actorder_True",
+        model_basename=model_basename,
+        use_safetensors=True,
+        trust_remote_code=True,
+        inject_fused_attention=False,
+        device=DEVICE,
+        quantize_config=None,
+    )
+
+    # Set system prompt
+    DEFAULT_SYSTEM_PROMPT = """
+    You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+    If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
+    """.strip()
+
+    def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
+        return f"""
+    [INST] <<SYS>>
+    {system_prompt}
+    <</SYS>>
+    {prompt} [/INST]
+    """.strip()
+
+    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    text_pipeline = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=1024,
+        temperature=0,
+        top_p=0.95,
+        repetition_penalty=1.15,
+        streamer=streamer,
+    )
+
+    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0})
+
+    SYSTEM_PROMPT = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
+
+    template = generate_prompt(
+        """
+    {context}
+    Question: {question}
+    """,
+        system_prompt=SYSTEM_PROMPT,
+    )
+
+    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+
+    qa_chain = RetrievalQA.from_chain_type(
+        llm=llm,
+        chain_type="stuff",
+        retriever=db.as_retriever(search_kwargs={"k": 2}),
+        return_source_documents=True,
+        chain_type_kwargs={"prompt": prompt},
+    )
+
+    st.session_state.qa_chain = qa_chain
+    st.session_state.initialized = True
+
 st.title("Medical Chatbot")
 
 if "history" not in st.session_state:
@@ -106,6 +115,12 @@ if "history" not in st.session_state:
 user_input = st.text_input("Ask a question:", key="input")
 
 if st.button("Submit"):
-    if user_input:
-        result = qa_chain(user_input)
-        st.write(result)
+    if user_input:
+        result = st.session_state.qa_chain(user_input)
+        answer = result["result"]  # RetrievalQA returns a dict; the generated answer text is under "result"
+        st.session_state.history.append({"question": user_input, "answer": answer})
+
+# Display chat history using streamlit-chat
+for i, chat in enumerate(st.session_state.history):
+    message(chat['question'], is_user=True, key=f"user_{i}")
+    message(chat['answer'], key=f"bot_{i}")
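
Note on the initialization guard introduced above: `st.session_state` is scoped to a single browser session, so every new visitor still triggers the full PDF parse, embedding pass, and GPTQ model load. A lighter-weight alternative is Streamlit's `st.cache_resource`, which caches the object once per server process and can also reuse the Chroma index already persisted in the `db` directory. The sketch below is illustrative only, not part of this commit; `build_vectorstore` is a hypothetical helper name, and it assumes the same `Medical_Book.pdf` and `db` layout used in app.py.

import os

import streamlit as st
import torch
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


@st.cache_resource  # hypothetical refactor: built once per process, shared across sessions and reruns
def build_vectorstore() -> Chroma:
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="hkunlp/instructor-large", model_kwargs={"device": DEVICE}
    )
    if os.path.isdir("db"):
        # Reuse the index persisted by an earlier run instead of re-embedding the PDF.
        return Chroma(persist_directory="db", embedding_function=embeddings)
    docs = PyPDFLoader("Medical_Book.pdf").load()
    texts = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64).split_documents(docs)
    return Chroma.from_documents(texts, embeddings, persist_directory="db")


db = build_vectorstore()  # near-instant after the first call

The tokenizer, GPTQ model, and RetrievalQA chain could be wrapped the same way, so they would load once per process rather than once per session.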