pminervini committed on
Commit
8400fbd
1 Parent(s): 45b79c0
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
2
  import gradio as gr
3
- from transformers import pipeline
4
  from elasticsearch import Elasticsearch
5
 
 
6
  # Connect to Elasticsearch
7
  es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
8
  basic_auth=('elastic', os.environ['ES_PASSWORD']),
@@ -58,8 +59,18 @@ def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/z
58
  for message in messages:
59
  print('MSG', message)
60
 
 
 
 
 
 
 
 
 
 
 
61
  # Generate response using the LLM
62
- response = generator(messages, max_new_tokens=64, return_full_text=False)
63
 
64
  # Return the generated text and the documents
65
  return response[0]['generated_text'], joined_docs
 
1
  import os
2
  import gradio as gr
3
+ from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria
4
  from elasticsearch import Elasticsearch
5
 
6
+
7
  # Connect to Elasticsearch
8
  es = Elasticsearch(hosts=["https://data.neuralnoise.com:9200"],
9
  basic_auth=('elastic', os.environ['ES_PASSWORD']),
 
59
  for message in messages:
60
  print('MSG', message)
61
 
62
+ # Define the stopping criteria using MaxTimeCriteria
63
+ stopping_criteria = StoppingCriteriaList([MaxTimeCriteria(32)])
64
+
65
+ # Define the generation_kwargs with stopping criteria
66
+ generation_kwargs = {
67
+ "max_new_tokens": 128,
68
+ "generation_kwargs": {"stopping_criteria": stopping_criteria},
69
+ "return_full_text": False
70
+ }
71
+
72
  # Generate response using the LLM
73
+ response = generator(messages, **generation_kwargs)
74
 
75
  # Return the generated text and the documents
76
  return response[0]['generated_text'], joined_docs