awacke1 commited on
Commit
1530ac8
1 Parent(s): dc25cb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -45
app.py CHANGED
@@ -11,44 +11,42 @@ headers = {
11
  "Content-Type": "application/json"
12
  }
13
 
14
- endpoint_url = API_URL
15
- hf_token = API_KEY
16
- client = InferenceClient(endpoint_url, token=hf_token)
17
- gen_kwargs = dict(
18
- max_new_tokens=512,
19
- top_k=30,
20
- top_p=0.9,
21
- temperature=0.2,
22
- repetition_penalty=1.02,
23
- stop_sequences=["\nUser:", "<|endoftext|>", "</s>"],
24
- )
25
  prompt = f"Write instructions to teach anyone to write a discharge plan. List the entities, features and relationships to CCDA and FHIR objects in boldface."
26
- stream = client.text_generation(prompt, stream=True, details=True, **gen_kwargs)
27
- report=[]
28
- res_box = st.empty()
29
- collected_chunks=[]
30
- collected_messages=[]
31
- for r in stream:
32
- if r.token.special:
33
- continue
34
- if r.token.text in gen_kwargs["stop_sequences"]:
35
- break
36
- collected_chunks.append(r.token.text)
37
- chunk_message = r.token.text
38
- collected_messages.append(chunk_message)
39
 
40
- try:
41
- report.append(content)
42
- if len(r.token.text) > 0:
43
- result="".join(report).strip()
44
- res_box.markdown(f'*{result}*')
45
- except:
46
- st.write(' ')
47
-
48
- #full_reply = ''.join()
49
- #st.markdown(r.token.text, end = "")
50
- #st.write(r.token.text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
 
 
 
 
 
 
 
52
 
53
  def query(payload):
54
  response = requests.post(API_URL, headers=headers, json=payload)
@@ -60,17 +58,11 @@ def get_output(prompt):
60
 
61
  def main():
62
  st.title("Medical Llama Test Bench with Inference Endpoints Llama 7B")
63
- example_input = st.text_input("Enter your example text:")
 
64
 
65
- if st.button("Summarize with Variation 1"):
66
- prompt = f"Write instructions to teach anyone to write a discharge plan. List the entities, features and relationships to CCDA and FHIR objects in boldface. {example_input}"
67
- output = get_output(prompt)
68
- st.markdown(f"**Output:** {output}")
69
-
70
- if st.button("Summarize with Variation 2"):
71
- prompt = f"Provide a summary of the medical transcription. Highlight the important entities, features, and relationships to CCDA and FHIR objects. {example_input}"
72
- output = get_output(prompt)
73
- st.markdown(f"**Output:** {output}")
74
 
75
  if __name__ == "__main__":
76
  main()
 
11
  "Content-Type": "application/json"
12
  }
13
 
14
+ # Prompt Set of Examples:
 
 
 
 
 
 
 
 
 
 
15
  prompt = f"Write instructions to teach anyone to write a discharge plan. List the entities, features and relationships to CCDA and FHIR objects in boldface."
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ def StreamLLMChatResponse(prompt):
18
+ endpoint_url = API_URL
19
+ hf_token = API_KEY
20
+ client = InferenceClient(endpoint_url, token=hf_token)
21
+ gen_kwargs = dict(
22
+ max_new_tokens=512,
23
+ top_k=30,
24
+ top_p=0.9,
25
+ temperature=0.2,
26
+ repetition_penalty=1.02,
27
+ stop_sequences=["\nUser:", "<|endoftext|>", "</s>"],
28
+ )
29
+ stream = client.text_generation(prompt, stream=True, details=True, **gen_kwargs)
30
+ report=[]
31
+ res_box = st.empty()
32
+ collected_chunks=[]
33
+ collected_messages=[]
34
+ for r in stream:
35
+ if r.token.special:
36
+ continue
37
+ if r.token.text in gen_kwargs["stop_sequences"]:
38
+ break
39
+ collected_chunks.append(r.token.text)
40
+ chunk_message = r.token.text
41
+ collected_messages.append(chunk_message)
42
 
43
+ try:
44
+ report.append(r.token.text)
45
+ if len(r.token.text) > 0:
46
+ result="".join(report).strip()
47
+ res_box.markdown(f'*{result}*')
48
+ except:
49
+ st.write(' ')
50
 
51
  def query(payload):
52
  response = requests.post(API_URL, headers=headers, json=payload)
 
58
 
59
  def main():
60
  st.title("Medical Llama Test Bench with Inference Endpoints Llama 7B")
61
+ prompt = f"Write instructions to teach anyone to write a discharge plan. List the entities, features and relationships to CCDA and FHIR objects in boldface."
62
+ example_input = st.text_input("Enter your example text:", value=prompt)
63
 
64
+ if st.button("Run Prompt With Dr Llama"):
65
+ StreamLLMChatResponse(example_input)
 
 
 
 
 
 
 
66
 
67
  if __name__ == "__main__":
68
  main()