LLAMA-QA-AudioFiles / llm_ops.py
Samarth991's picture
adding HF pipeline for llm model
9c51c0d
raw
history blame
No virus
913 Bytes
import os
import torch
from langchain import HuggingFacePipeline
from transformers import AutoTokenizer
import transformers
def get_openai_chat_model(API_key):
    """Return a LangChain ``OpenAI`` LLM configured with the given API key.

    Parameters
    ----------
    API_key : str
        OpenAI API key; exported as the ``OPENAI_API_KEY`` environment
        variable, which the OpenAI client reads at construction time.

    Returns
    -------
    OpenAI
        A LangChain OpenAI LLM instance with default settings.

    Raises
    ------
    ImportError
        If the ``langchain``/``openai`` packages are not installed.
    """
    try:
        from langchain.llms import OpenAI
    except ImportError as err:
        # Original code raised a plain string, which is itself a TypeError at
        # runtime; raise a real exception and chain the original cause.
        raise ImportError(
            "{}, unable to load openAI. Please install openai and add OPENAIAPI_KEY".format(err)
        ) from err
    os.environ["OPENAI_API_KEY"] = API_key
    llm = OpenAI()
    return llm
def get_llama_model(model_name="meta-llama/Llama-2-7b-chat-hf",
                    max_length=1000,
                    temperature=0):
    """Build a LangChain ``HuggingFacePipeline`` around a Llama-2 chat model.

    Parameters
    ----------
    model_name : str, optional
        Hugging Face model id to load (default: Llama-2 7B chat). Calling
        with no arguments reproduces the original hard-coded behaviour.
    max_length : int, optional
        Maximum total token length for the text-generation pipeline.
    temperature : float, optional
        Sampling temperature forwarded via ``model_kwargs`` (default 0).

    Returns
    -------
    HuggingFacePipeline
        LangChain wrapper around a transformers text-generation pipeline.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Local name chosen so it does not shadow transformers.pipeline.
    gen_pipeline = transformers.pipeline(
        "text-generation",
        model=model_name,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",  # let accelerate place weights on available devices
        max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
    )
    # NOTE(review): model_kwargs is applied at model load by HuggingFacePipeline;
    # confirm the temperature actually takes effect for this pipeline version.
    llm = HuggingFacePipeline(pipeline=gen_pipeline,
                              model_kwargs={'temperature': temperature})
    return llm