pminervini committed on
Commit
d2e6098
1 Parent(s): e5b0595
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import gradio as gr
3
 
4
  import torch
5
- from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria
6
  from openai import OpenAI
7
 
8
  from elasticsearch import Elasticsearch
@@ -79,8 +79,11 @@ def rag_pipeline(prompt, index="pubmed", num_docs=3, model_name="HuggingFaceH4/z
79
  print('OAI_RESPONSE', openai_res)
80
  response = openai_res.choices[0].message.content.strip()
81
  else:
 
 
 
82
  # Load your language model from HuggingFace Transformers
83
- generator = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")
84
 
85
  # Generate response using the HF LLM
86
  hf_response = generator(messages, **generation_kwargs)
 
2
  import gradio as gr
3
 
4
  import torch
5
+ from transformers import pipeline, StoppingCriteriaList, MaxTimeCriteria, AutoTokenizer, AutoModelForCausalLM
6
  from openai import OpenAI
7
 
8
  from elasticsearch import Elasticsearch
 
79
  print('OAI_RESPONSE', openai_res)
80
  response = openai_res.choices[0].message.content.strip()
81
  else:
82
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True)
83
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
84
+
85
  # Load your language model from HuggingFace Transformers
86
+ generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
87
 
88
  # Generate response using the HF LLM
89
  hf_response = generator(messages, **generation_kwargs)