from transformers import AutoTokenizer
import transformers
import torch
import streamlit as st
import os

# Hugging Face access token, read from the environment; the Llama 2 checkpoints
# are gated, so the account behind this token must have been granted access.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Streamlit text box for the user's prompt, with a simple default value.
prompt = st.text_input('Prompt', 'Hello, how are you doing?')

model = "meta-llama/Llama-2-13b-chat-hf"

def load_model(model):
    # Load the tokenizer and build a text-generation pipeline for the chat model.
    # device_map="auto" places the weights on a GPU when one is available.
    tokenizer = AutoTokenizer.from_pretrained(model, token=HF_TOKEN)
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float32,  # torch.float16 roughly halves memory for the 13B weights
        device_map="auto",
        token=HF_TOKEN,
    )
    return pipeline

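# Optional caching sketch (an addition for illustration, not part of the original script):
# Streamlit re-runs the whole file on every interaction, which would reload the 13B model
# each time. On Streamlit >= 1.18, st.cache_resource can keep the pipeline alive across
# reruns; one way to wire it up, assuming the load_model() helper above:
#
#   @st.cache_resource
#   def load_model_cached(model_name):
#       return load_model(model_name)
#
#   pipeline = load_model_cached(model)
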
def get_llama_response(pipeline, prompt):
    # Generate a single completion; sampling options are passed with the call.
    sequences = pipeline(
        prompt,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        max_length=256,
    )
    # Return the generated text so the Streamlit app can display it.
    return sequences[0]['generated_text']

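# Aside (not in the original script): the -chat checkpoints were instruction-tuned on
# Meta's [INST] ... [/INST] prompt format, so wrapping the user text usually gives
# better answers, e.g.:
#
#   wrapped_prompt = f"[INST] {prompt} [/INST]"
#   response = get_llama_response(pipeline, wrapped_prompt)
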
# Build the pipeline once, generate a completion, and render it in the app.
pipeline = load_model(model)

response = get_llama_response(pipeline, prompt)

st.write('Answer: ', response)
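
# Usage sketch (file name and shell lines are illustrative, not from the original):
# save the script as e.g. llama_chat_app.py, export the token, and launch Streamlit:
#
#   export HF_TOKEN=hf_...          # token for an account with Llama 2 access
#   streamlit run llama_chat_app.py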