ayush0504 committed on
Commit 48ceee6 · verified · 1 Parent(s): 042e364

Update app.py

Files changed (1)
  1. app.py +45 -41
app.py CHANGED
@@ -1,58 +1,62 @@
+import streamlit as st
 import torch
 from peft import AutoPeftModelForCausalLM
 from transformers import AutoTokenizer, TextStreamer
-import streamlit as st
-
-# Initialize Streamlit UI
-st.title("Legal Query Chatbot")
-st.write("Ask questions related to Indian traffic laws and get AI-generated responses.")
 
 # Load LoRA fine-tuned model and tokenizer
-model_path = "lora_model"
-load_in_4bit = True
+model_path = "lora_model"  # Your model folder path
+load_in_4bit = True  # Whether to load in 4-bit precision
 
 # Load the model
-model = AutoPeftModelForCausalLM.from_pretrained(
-    model_path,
-    torch_dtype=torch.float16 if not load_in_4bit else torch.float32,
-    load_in_4bit=load_in_4bit,
-    device_map="auto"
-)
+@st.cache_resource
+def load_model():
+    model = AutoPeftModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=torch.float16 if not load_in_4bit else torch.float32,
+        load_in_4bit=load_in_4bit,
+        device_map="auto"
+    )
+    model.eval()
+    return model
 
 # Load tokenizer
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-# Enable inference mode
-model.eval()
+@st.cache_resource
+def load_tokenizer():
+    return AutoTokenizer.from_pretrained(model_path)
 
-# Streamlit input for user prompt
-user_input = st.text_input("Enter your legal query:", "What are the penalties for breaking a red light in India?")
+model = load_model()
+tokenizer = load_tokenizer()
 
-if user_input:
-    # Prepare the prompt
-    messages = [{"role": "user", "content": user_input}]
-
-    # Tokenize input
+def generate_response(question):
+    messages = [{"role": "user", "content": question}]
     inputs = tokenizer.apply_chat_template(
         messages,
         tokenize=True,
         add_generation_prompt=True,
         return_tensors="pt"
     ).to("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Streamlit progress indicator
-    with st.spinner("Generating response..."):
-        # Use a text streamer for efficient streaming output
-        text_streamer = TextStreamer(tokenizer, skip_prompt=True)
-
-        # Generate response
-        output = model.generate(
-            input_ids=inputs,
-            streamer=text_streamer,
-            max_new_tokens=128,
-            use_cache=True,
-            temperature=1.5,
-            min_p=0.1
-        )
-
-    st.success("Generation Complete!")
+
+    text_streamer = TextStreamer(tokenizer, skip_prompt=True)
+    output = model.generate(
+        input_ids=inputs,
+        streamer=text_streamer,
+        max_new_tokens=1048,
+        use_cache=True,
+        temperature=0.7,
+        min_p=0.1
+    )
+
+    return tokenizer.decode(output[0], skip_special_tokens=True)
+
+# Streamlit UI
+st.title("Indian Penal Code AI Assistant")
+
+question = st.text_area("Ask a legal question:")
+if st.button("Generate Response"):
+    if question.strip():
+        with st.spinner("Generating response..."):
+            answer = generate_response(question)
+        st.subheader("Answer:")
+        st.write(answer)
+    else:
+        st.warning("Please enter a question.")
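The core of this commit is that model and tokenizer loading are wrapped in Streamlit's `@st.cache_resource`, so the LoRA model is loaded once per server process instead of on every rerun of the script. A minimal sketch of that caching pattern, separate from this repo's model (the `load_heavy_resource` name and the dictionary stand-in are illustrative, not part of app.py):

```python
import streamlit as st

@st.cache_resource  # cache the returned object across reruns and sessions
def load_heavy_resource():
    # Stand-in for an expensive load (e.g. a large model); this body
    # runs only the first time the function is called.
    return {"status": "loaded"}

resource = load_heavy_resource()  # later reruns reuse the cached object
st.write(resource["status"])
```

Without the decorator, every widget interaction (such as clicking "Generate Response") reruns the script top to bottom and would reload the model from disk each time.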