jhansi1 committed · verified
Commit afd92b9 · 1 Parent(s): a5046b2

Update app.py

Files changed (1): app.py +85 -56
app.py CHANGED
@@ -1,58 +1,87 @@
  import streamlit as st
- import pandas as pd
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
-
- # Define paths for the dataset splits
- splits = {
-     'train': 'data/train-00000-of-00001.parquet',
-     'validation': 'data/validation-00000-of-00001.parquet',
-     'test': 'data/test-00000-of-00001.parquet'
- }
-
- # Load the dataset
- @st.cache_resource
- def load_dataset(split="train"):
-     return pd.read_parquet(f"hf://datasets/BEE-spoke-data/survivorslib-law-books/{splits[split]}")
-
- # Initialize the model and tokenizer
- @st.cache_resource
- def load_model():
-     model_name = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(model_name)
-     return pipeline("text-generation", model=model, tokenizer=tokenizer)
-
- # Streamlit interface
- st.title("Legal Text Generator with NVIDIA Llama")
- st.write("Generate text based on the Survivorslib Legal Dataset and the NVIDIA Llama model.")
-
- # Load dataset and model pipeline
- st.sidebar.title("Options")
- split_option = st.sidebar.selectbox("Select dataset split", ["train", "validation", "test"])
- dataset = load_dataset(split=split_option)
- text_generator = load_model()
-
- # Show sample data from the dataset
- st.subheader(f"Sample Data from {split_option.capitalize()} Split")
- st.write(dataset.head())  # Display the first few rows of the selected dataset split
-
- # Prompt input
- prompt = st.text_area("Enter your prompt:", placeholder="Type a legal prompt or select a sample text...")
-
- # Optional: Select sample text from the dataset to use as a prompt
- if st.button("Use Sample Text"):
-     if 'content' in dataset.columns:
-         prompt = dataset['content'].iloc[0]
-         st.write(f"Using sample text from dataset: {prompt}")
-     else:
-         st.write("Dataset does not contain a 'content' column with text data.")
-
- # Generate text based on the prompt
- if st.button("Generate Response"):
-     if prompt:
-         with st.spinner("Generating response..."):
-             generated_text = text_generator(prompt, max_length=100, do_sample=True, temperature=0.7)[0]["generated_text"]
-         st.write("**Generated Text:**")
-         st.write(generated_text)
      else:
-         st.write("Please enter a prompt to generate a response.")
+ # app.py
+
+ import gradio as gr
  import streamlit as st
+ from transformers import pipeline
+ from datasets import load_dataset
+
+ # Initialize text-generation pipeline with the model
+ model_name = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
+ pipe = pipeline("text-generation", model=model_name)
+
+ # Load the dataset from the cloned local directory
+ ds = load_dataset("./canadian-legal-data", split="train")
+
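+ # Note: the "./canadian-legal-data" path is resolved relative to the current
+ # working directory, so the app must be launched from the folder that holds
+ # the cloned dataset repo.
+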
+ # Gradio Interface setup
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     # Rebuild the conversation in the chat format the pipeline expects
+     messages = [{"role": "system", "content": system_message}]
+
+     for user_turn, assistant_turn in history:
+         if user_turn:
+             messages.append({"role": "user", "content": user_turn})
+         if assistant_turn:
+             messages.append({"role": "assistant", "content": assistant_turn})
+
+     messages.append({"role": "user", "content": message})
+
+     # The pipeline generates the whole reply in one call and returns the
+     # conversation extended with the assistant turn; yield it so
+     # gr.ChatInterface can consume this function as a generator.
+     output = pipe(
+         messages,
+         max_new_tokens=max_tokens,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+     )
+     response = output[0]["generated_text"][-1]["content"]
+     yield response
+
+ # Streamlit Interface setup
+ def streamlit_interface():
+     st.title("Canadian Legal Text Generator")
+     st.write("Enter a prompt related to Canadian legal data and generate text using Llama-3.1.")
+
+     # Show dataset sample
+     st.subheader("Sample Data from Canadian Legal Dataset:")
+     st.write(ds[:5])  # Display the first 5 rows of the dataset
+
+     # Prompt input
+     prompt = st.text_area("Enter your prompt:", placeholder="Type something...")
+
+     if st.button("Generate Response"):
+         if prompt:
+             # Generate text based on the prompt
+             with st.spinner("Generating response..."):
+                 generated_text = pipe(prompt, max_length=100, do_sample=True, temperature=0.7)[0]["generated_text"]
+             st.write("**Generated Text:**")
+             st.write(generated_text)
+         else:
+             st.write("Please enter a prompt to generate a response.")
+
+
+ # Running Gradio and Streamlit interfaces
+ if __name__ == "__main__":
+     st.sidebar.title("Choose an Interface")
+     interface = st.sidebar.radio("Select", ("Streamlit", "Gradio"))
+
+     if interface == "Streamlit":
+         streamlit_interface()
      else:
+         demo = gr.ChatInterface(
+             respond,
+             additional_inputs=[
+                 gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+                 gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+                 gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+                 gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+             ],
+         )
+         demo.launch()
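
For reference, `respond` relies on the chat mode of the transformers text-generation pipeline (available in recent transformers releases): given a list of role/content dicts, the pipeline applies the model's chat template and returns the conversation with the new assistant turn appended last. A minimal sketch of that call pattern; the SmolLM2 checkpoint is a stand-in chosen only so the snippet runs on modest hardware:

from transformers import pipeline

# Small instruct checkpoint as a stand-in; the app itself loads
# "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", which needs far more memory.
pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")

messages = [
    {"role": "system", "content": "You are a friendly Chatbot."},
    {"role": "user", "content": "In one sentence, what is a contract?"},
]

# Chat-format input: the pipeline applies the model's chat template and
# returns the conversation with the generated assistant turn appended last.
out = pipe(messages, max_new_tokens=64, do_sample=True, temperature=0.7, top_p=0.95)
print(out[0]["generated_text"][-1]["content"])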