pminervini commited on
Commit
482d0d4
1 Parent(s): 8400fbd
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
  import gradio as gr
 
 
3
  from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria
4
  from elasticsearch import Elasticsearch
5
 
@@ -38,7 +40,7 @@ def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/z
38
  A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.
39
  """
40
  # Load your language model from HuggingFace Transformers
41
- generator = pipeline("text-generation", model=model_name)
42
 
43
  num_docs = int(num_docs)
44
 
 
1
  import os
2
  import gradio as gr
3
+
4
+ import torch
5
  from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria
6
  from elasticsearch import Elasticsearch
7
 
 
40
  A simple RAG pipeline that retrieves documents and uses them to enrich the context for the LLM.
41
  """
42
  # Load your language model from HuggingFace Transformers
43
+ generator = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16)
44
 
45
  num_docs = int(num_docs)
46