"""Python file to serve as the frontend""" | |
import streamlit as st | |
from streamlit_chat import message | |
from langchain.chains import ConversationChain, LLMChain | |
from langchain import PromptTemplate | |
from langchain.llms.base import LLM | |
from langchain.memory import ConversationBufferWindowMemory | |
from typing import Optional, List, Mapping, Any | |
import torch | |
from peft import PeftModel | |
import transformers | |
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig | |
from transformers import BitsAndBytesConfig | |
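# The imports above roughly imply these packages (an assumption, not a pinned list):
#   pip install streamlit streamlit-chat langchain torch transformers peft accelerate bitsandbytes sentencepiece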
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

# Allow int8 weights that don't fit on the GPU to be offloaded to the CPU in fp32.
quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)

model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    # load_in_8bit=True,
    # torch_dtype=torch.float16,
    device_map="auto",
    # device_map={"": "cpu"},
    max_memory={"cpu": "15GiB"},
    quantization_config=quantization_config,
)
model = PeftModel.from_pretrained(
    model,
    "tloen/alpaca-lora-7b",
    # torch_dtype=torch.float16,
    device_map={"": "cpu"},
)
device = "cpu"
print("model device :", model.device, flush=True)
# model.to(device)
model.eval()
def evaluate_raw_prompt(
    prompt: str,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    **kwargs,
):
    """Run a raw Alpaca-style prompt through the model and return only the response text."""
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    # return output
    # The prompt is expected to end with "### Response:", so keep only the generated answer.
    return output.split("### Response:")[1].strip()
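# Illustrative usage (commented out so it doesn't run on app start): the function expects
# the full Alpaca-style prompt, including the trailing "### Response:" marker it splits on.
# example_prompt = (
#     "Below is an instruction that describes a task. "
#     "Write a response that appropriately completes the request.\n"
#     "### Instruction:\nName the capital of France.\n"
#     "### Response:"
# )
# print(evaluate_raw_prompt(example_prompt, temperature=0.1))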
class AlpacaLLM(LLM):
    """Minimal LangChain LLM wrapper around evaluate_raw_prompt."""

    temperature: float
    top_p: float
    top_k: int
    num_beams: int

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        answer = evaluate_raw_prompt(
            prompt,
            top_p=self.top_p,
            top_k=self.top_k,
            num_beams=self.num_beams,
            temperature=self.temperature,
        )
        return answer

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "top_p": self.top_p,
            "top_k": self.top_k,
            "num_beams": self.num_beams,
            "temperature": self.temperature,
        }
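# Quick sanity check (commented out so it doesn't run on app start): once instantiated with
# its sampling parameters, the wrapper can be called like any LangChain LLM.
# _llm = AlpacaLLM(top_p=0.75, top_k=40, num_beams=4, temperature=0.1)
# print(_llm("### Instruction:\nSay hello.\n### Response:"))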
template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. | |
### Instruction: | |
You are a chatbot, you should answer my last question very briefly. You are consistent and non repetitive. | |
### Chat: | |
{history} | |
Human: {human_input} | |
### Response:""" | |
prompt = PromptTemplate( | |
input_variables=["history","human_input"], | |
template=template, | |
) | |
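# For reference, prompt.format(...) fills the two placeholders, e.g.:
# prompt.format(history="Human: hi\nAI: Hello!", human_input="How are you?")
# returns the template above with {history} and {human_input} substituted.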
def load_chain():
    """Logic for loading the chain you want to use should go here."""
    llm = AlpacaLLM(top_p=0.75, top_k=40, num_beams=4, temperature=0.1)
    # chain = ConversationChain(llm=llm)
    chain = LLMChain(llm=llm, prompt=prompt, memory=ConversationBufferWindowMemory(k=2))
    return chain

chain = load_chain()
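# Illustrative usage (not executed here): ConversationBufferWindowMemory(k=2) feeds the last
# two exchanges back in as {history}, so a single turn looks like:
# answer = chain.predict(human_input="What is an alpaca?")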
# # From here down is all the StreamLit UI.
# st.set_page_config(page_title="LangChain Demo", page_icon=":robot:")
# st.header("LangChain Demo")

# if "generated" not in st.session_state:
#     st.session_state["generated"] = []
# if "past" not in st.session_state:
#     st.session_state["past"] = []

# def get_text():
#     input_text = st.text_input("Human: ", "Hello, how are you?", key="input")
#     return input_text

# user_input = get_text()

# if user_input:
#     output = chain.predict(human_input=user_input)
#     st.session_state.past.append(user_input)
#     st.session_state.generated.append(output)

# if st.session_state["generated"]:
#     for i in range(len(st.session_state["generated"]) - 1, -1, -1):
#         message(st.session_state["generated"][i], key=str(i))
#         message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
st.title("ChatAlpaca") | |
if "history" not in st.session_state: | |
st.session_state.history = [] | |
st.session_state.history.append({"message": "Hey, I'm a Alpaca chatBot. Ask whatever you want!", "is_user": False}) | |
def generate_answer(): | |
user_message = st.session_state.input_text | |
inputs = tokenizer(st.session_state.input_text, return_tensors="pt") | |
result = model.generate(**inputs) | |
message_bot = tokenizer.decode(result[0], skip_special_tokens=True) # .replace("<s>", "").replace("</s>", "") | |
st.session_state.history.append({"message": user_message, "is_user": True}) | |
st.session_state.history.append({"message": message_bot, "is_user": False}) | |
st.text_input("Response", key="input_text", on_change=generate_answer) | |
for chat in st.session_state.history: | |
st_message(**chat) |
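# To try this locally (assuming the dependencies above are installed):
#   streamlit run app.py
# The filename "app.py" is an assumption; use whatever this file is called in the Space.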