import os
import sys
import uuid

import bitsandbytes as bnb
import pandas as pd
import requests
import streamlit as st
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
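
# Third-party dependencies used below; a minimal install (package names taken
# from the imports above, versions left unpinned as an assumption) would be:
#   pip install streamlit requests peft bitsandbytes pandas torch transformers datasets huggingface_hub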


USER_ICON = "images/user-icon.png"
AI_ICON = "images/ai-icon.png"
MAX_HISTORY_LENGTH = 5

if 'user_id' in st.session_state:
    user_id = st.session_state['user_id']
else:
    user_id = str(uuid.uuid4())
    st.session_state['user_id'] = user_id

if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

if "chats" not in st.session_state:
    st.session_state.chats = [
        {
            'id': 0,
            'question': '',
            'answer': ''
        }
    ]

if "questions" not in st.session_state:
    st.session_state.questions = []

if "answers" not in st.session_state:
    st.session_state.answers = []

if "input" not in st.session_state:
    st.session_state.input = ""

st.markdown("""
<style>
.block-container {
    padding-top: 32px;
    padding-bottom: 32px;
    padding-left: 0;
    padding-right: 0;
}
.element-container img {
    background-color: #000000;
}

.main-header {
    font-size: 24px;
}
</style>
""", unsafe_allow_html=True)

def write_top_bar():
    col1, col2, col3 = st.columns([1, 10, 2])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        header = "Cogwise Intelligent Assistant"
        st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
    with col3:
        clear = st.button("Clear Chat")
    return clear


clear = write_top_bar()

if clear:
    st.session_state.questions = []
    st.session_state.answers = []
    st.session_state.input = ""
    st.session_state["chat_history"] = []

def handle_input():
    input = st.session_state.input
    question_with_id = {
        'question': input,
        'id': len(st.session_state.questions)
    }
    st.session_state.questions.append(question_with_id)

    chat_history = st.session_state["chat_history"]
    # Trim the stored history in place so it never grows past
    # MAX_HISTORY_LENGTH turns (drop the oldest turn first).
    if len(chat_history) >= MAX_HISTORY_LENGTH:
        chat_history.pop(0)

    # Pin execution to the first GPU (os is imported at the top of the file).
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # Dataset the adapter was fine-tuned on; loaded here for reference but not
    # used directly by the chat flow.
    dataset_name = "nisaar/Lawyer_GPT_India"
    dataset = load_dataset(dataset_name, split="train")

    # 4-bit NF4 quantization so the 7B base model fits in GPU memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load the quantized base model and attach the fine-tuned LoRA adapter.
    # NOTE: this runs on every submitted question; caching the loaded model
    # (for example with st.cache_resource) would avoid reloading it each time.
    peft_model_id = "nisaar/falcon7b-Indian_Law_150Prompts"
    config = PeftConfig.from_pretrained(peft_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    model = PeftModel.from_pretrained(model, peft_model_id)

    # Inference
    # The trained model (here, the adapter loaded from the 🤗 Hub) can be used
    # for generation just as you usually would in `transformers`.

    generation_config = model.generation_config
    generation_config.max_new_tokens = 200
    generation_config.temperature = 1
    generation_config.top_p = 0.7
    generation_config.num_return_sequences = 1
    generation_config.pad_token_id = tokenizer.eos_token_id
    generation_config.eos_token_id = tokenizer.eos_token_id

    DEVICE = "cuda:0"

    def generate_response(question: str) -> str:
        prompt = f"""
<human>: {question}
<assistant>:
""".strip()
        encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=encoding.input_ids,
                attention_mask=encoding.attention_mask,
                generation_config=generation_config,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the assistant's part of the decoded conversation.
        assistant_start = '<assistant>:'
        response_start = response.find(assistant_start)
        return response[response_start + len(assistant_start):].strip()

    prompt = input
    answer = generate_response(prompt)
    print(answer)

    chat_history.append((input, answer))

    st.session_state.answers.append({
        'answer': answer,
        'id': len(st.session_state.questions)
    })
    st.session_state.input = ""

def write_user_message(md):
    col1, col2 = st.columns([1, 12])

    with col1:
        st.image(USER_ICON, use_column_width='always')
    with col2:
        st.warning(md['question'])


def render_answer(answer):
    col1, col2 = st.columns([1, 12])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        st.info(answer)


def write_chat_message(md, q):
    chat = st.container()
    with chat:
        render_answer(md['answer'])


with st.container():
    for (q, a) in zip(st.session_state.questions, st.session_state.answers):
        write_user_message(q)
        write_chat_message(a, q)

st.markdown('---')
input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
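
# To launch the chat UI (assuming this script is saved as app.py):
#   streamlit run app.py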
|