File size: 913 Bytes
2faf743
9c51c0d
 
 
 
2faf743
 
 
 
 
 
 
 
 
 
 
9c51c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os 
import torch 
from langchain import HuggingFacePipeline
from transformers import AutoTokenizer
import transformers

def get_openai_chat_model(API_key):
    """Return a LangChain ``OpenAI`` LLM configured with *API_key*.

    The key is exported as the ``OPENAI_API_KEY`` environment variable right
    before ``OpenAI()`` is constructed with no arguments, so the wrapper is
    expected to pick the key up from the environment.

    Args:
        API_key: OpenAI API key string.

    Returns:
        A newly constructed ``langchain.llms.OpenAI`` instance.

    Raises:
        ImportError: If ``langchain.llms.OpenAI`` cannot be imported. The
            underlying import error is chained as ``__cause__``.
    """
    # Import deferred to call time so an import failure only surfaces when
    # the OpenAI model is actually requested.
    try:
        from langchain.llms import OpenAI
    except ImportError as err:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception, include the original error text, and chain it.
        raise ImportError(
            f"{err}, unable to load openAI. "
            "Please install openai and add OPENAI_API_KEY"
        ) from err
    # Side effect: the key is set process-wide (visible to child processes).
    # It is only exported after the import succeeded.
    os.environ["OPENAI_API_KEY"] = API_key
    return OpenAI()


def get_llama_model(
    model="meta-llama/Llama-2-7b-chat-hf",
    max_length=1000,
    trust_remote_code=False,
):
    """Wrap a Hugging Face text-generation pipeline as a LangChain LLM.

    Loads the tokenizer and model for *model* via ``transformers`` and runs
    them in ``bfloat16``. ``device_map="auto"`` lets transformers place the
    weights across the available devices.

    Args:
        model: Hugging Face Hub model id or local path. Defaults to the
            Llama-2 7B chat checkpoint.
        max_length: Forwarded to the pipeline as ``max_length``. In
            transformers this limit counts prompt tokens as well as generated
            tokens.
        trust_remote_code: Whether to execute custom modelling code shipped
            inside the model repository. Off by default because it runs
            arbitrary Python from the repo. Llama-2 loads with transformers'
            built-in classes and does not need it.

    Returns:
        A ``HuggingFacePipeline`` wrapping the text-generation pipeline.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)

    generator = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        trust_remote_code=trust_remote_code,
        device_map="auto",
        max_length=max_length,
        # Stop generation at the tokenizer's end-of-sequence token.
        eos_token_id=tokenizer.eos_token_id,
    )
    # NOTE(review): when an already-built pipeline is passed in, LangChain's
    # HuggingFacePipeline may only store model_kwargs without forwarding them
    # to generation. temperature=0 might therefore not take effect; confirm
    # against the installed langchain version.
    llm = HuggingFacePipeline(pipeline=generator, model_kwargs={'temperature': 0})
    return llm