GlastonR commited on
Commit
dec38ea
1 Parent(s): 32b54e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -13
app.py CHANGED
@@ -2,33 +2,35 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
  # Load the model and tokenizer
5
- model_name = "Tommy0303000/Llama-2-7b-sql"
6
  model = AutoModelForCausalLM.from_pretrained(model_name)
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
- # Function to generate SQL from natural language
10
- def generate_sql_query(query):
11
- # Tokenize the input query
12
- inputs = tokenizer(query, return_tensors="pt")
13
 
14
- # Generate SQL query using the model
 
 
 
15
  outputs = model.generate(inputs["input_ids"], max_length=100)
16
 
17
- # Decode the generated output
18
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
19
  return sql_query
20
 
21
  # Streamlit app UI
22
  def main():
23
- st.title("Text-to-SQL with Llama-2-7B Model")
24
- st.write("This app generates SQL queries based on natural language input.")
25
 
26
- # Input field for user's question
27
- query_input = st.text_input("Enter your question:")
28
 
29
- if query_input:
30
  # Generate the SQL query
31
- sql_query = generate_sql_query(query_input)
32
 
33
  # Display the generated SQL query
34
  st.write("Generated SQL Query:")
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
  # Load the model and tokenizer
5
+ model_name = "premai-io/prem-1B-SQL"
6
  model = AutoModelForCausalLM.from_pretrained(model_name)
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
+ # Function to generate SQL from the user's input
10
+ def generate_sql_query(question):
11
+ input_text = f"Question: {question} SQL Query:"
 
12
 
13
+ # Tokenize the input
14
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
15
+
16
+ # Generate the SQL query
17
  outputs = model.generate(inputs["input_ids"], max_length=100)
18
 
19
+ # Decode the output
20
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
21
  return sql_query
22
 
23
  # Streamlit app UI
24
  def main():
25
+ st.title("Text-to-SQL with prem-1B-SQL Model")
26
+ st.write("This app generates SQL queries based on your natural language question.")
27
 
28
+ # Input for the user's question
29
+ question_input = st.text_input("Enter your question:")
30
 
31
+ if question_input:
32
  # Generate the SQL query
33
+ sql_query = generate_sql_query(question_input)
34
 
35
  # Display the generated SQL query
36
  st.write("Generated SQL Query:")