daljeetsingh committed on
Commit
a90affe
·
1 Parent(s): 481b009
Files changed (3) hide show
  1. app.py +71 -4
  2. example_queries.py +129 -0
  3. requirements.txt +5 -0
app.py CHANGED
@@ -1,7 +1,74 @@
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ ## https://www.kaggle.com/code/unravel/fine-tuning-of-a-sql-model
2
+
3
+ import spaces
4
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
  import gradio as gr
6
+ import torch
7
+ from transformers.utils import logging
8
+ from example_queries import small_query, long_query
9
+
10
# Raise the transformers log level to INFO and keep a handle on its logger.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Baseline checkpoint; its tokenizer is the one used by the rest of the app.
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
original_model.to("cuda")

# Text-to-SQL fine-tuned checkpoint, loaded with the same dtype and device.
ft_model_name = "cssupport/t5-small-awesome-text-to-sql"
ft_model = AutoModelForSeq2SeqLM.from_pretrained(
    ft_model_name,
    torch_dtype=torch.bfloat16,
)
ft_model.to("cuda")
21
+
22
@spaces.GPU
def translate_text(text):
    """Run the prompt through both the base and the fine-tuned model.

    Args:
        text: Prompt string taken from the "Prompt" textbox.

    Returns:
        A two-item list ``[original_output, ft_output]``, one entry per
        output textbox wired to this callback. If generation fails, both
        entries hold the same ``"Error: ..."`` message so the number of
        returned values still matches the number of output components.
    """
    inputs = tokenizer(text, return_tensors="pt").to("cuda")

    def _generate(model):
        # Decode the first (and only) generated sequence of the batch.
        token_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=200,
        )[0]
        return tokenizer.decode(token_ids, skip_special_tokens=True)

    try:
        return [_generate(original_model), _generate(ft_model)]
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs and surface
        # the message in BOTH output boxes. Returning a bare string here would
        # give Gradio one value for two outputs.
        logger.exception("Text generation failed")
        message = f"Error: {str(e)}"
        return [message, message]
46
+
47
+
48
# Two-column layout: prompt and trigger on the left, both model outputs on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter prompt...",
                lines=8,
                value=small_query,
            )
            submit_btn = gr.Button(value="Generate")
        with gr.Column():
            orig_output = gr.Textbox(lines=2, label="OriginalModel")
            ft_output = gr.Textbox(lines=8, label="FTModel")

    # One click fills both output boxes; the handler is not exposed as an API endpoint.
    submit_btn.click(
        translate_text,
        inputs=[prompt],
        outputs=[orig_output, ft_output],
        api_name=False,
    )

    # Clickable sample prompts that populate the prompt textbox.
    examples = gr.Examples(
        inputs=[prompt],
        examples=[[small_query], [long_query]],
    )

demo.launch(show_api=False, share=True, debug=True)
 
74
 
 
 
example_queries.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ small_query=""" Tables:
2
+ CREATE TABLE table_name_11 (date VARCHAR, away_team VARCHAR)
3
+
4
+ Question:
5
+ On what Date did the Away team essendon play?
6
+
7
+ Answer:"""
8
+ long_query="""Tables:
9
+
10
+ CREATE TABLE employees (
11
+
12
+ EMPLOYEE_ID decimal(6,0),
13
+
14
+ FIRST_NAME varchar(20),
15
+
16
+ LAST_NAME varchar(25),
17
+
18
+ EMAIL varchar(25),
19
+
20
+ PHONE_NUMBER varchar(20),
21
+
22
+ HIRE_DATE date,
23
+
24
+ JOB_ID varchar(10),
25
+
26
+ SALARY decimal(8,2),
27
+
28
+ COMMISSION_PCT decimal(2,2),
29
+
30
+ MANAGER_ID decimal(6,0),
31
+
32
+ DEPARTMENT_ID decimal(4,0)
33
+
34
+ )
35
+
36
+
37
+
38
+ CREATE TABLE jobs (
39
+
40
+ JOB_ID varchar(10),
41
+
42
+ JOB_TITLE varchar(35),
43
+
44
+ MIN_SALARY decimal(6,0),
45
+
46
+ MAX_SALARY decimal(6,0)
47
+
48
+ )
49
+
50
+
51
+
52
+ CREATE TABLE locations (
53
+
54
+ LOCATION_ID decimal(4,0),
55
+
56
+ STREET_ADDRESS varchar(40),
57
+
58
+ POSTAL_CODE varchar(12),
59
+
60
+ CITY varchar(30),
61
+
62
+ STATE_PROVINCE varchar(25),
63
+
64
+ COUNTRY_ID varchar(2)
65
+
66
+ )
67
+
68
+
69
+
70
+ CREATE TABLE countries (
71
+
72
+ COUNTRY_ID varchar(2),
73
+
74
+ COUNTRY_NAME varchar(40),
75
+
76
+ REGION_ID decimal(10,0)
77
+
78
+ )
79
+
80
+
81
+
82
+ CREATE TABLE job_history (
83
+
84
+ EMPLOYEE_ID decimal(6,0),
85
+
86
+ START_DATE date,
87
+
88
+ END_DATE date,
89
+
90
+ JOB_ID varchar(10),
91
+
92
+ DEPARTMENT_ID decimal(4,0)
93
+
94
+ )
95
+
96
+
97
+
98
+ CREATE TABLE regions (
99
+
100
+ REGION_ID decimal(5,0),
101
+
102
+ REGION_NAME varchar(25)
103
+
104
+ )
105
+
106
+
107
+
108
+ CREATE TABLE departments (
109
+
110
+ DEPARTMENT_ID decimal(4,0),
111
+
112
+ DEPARTMENT_NAME varchar(30),
113
+
114
+ MANAGER_ID decimal(6,0),
115
+
116
+ LOCATION_ID decimal(4,0)
117
+
118
+ )
119
+
120
+
121
+
122
+ Question:
123
+
124
+ For those employees who did not have any job in the past, give me the comparison about the amount of job_id over the job_id , and group by attribute job_id, and list from low to high by the JOB_ID please.
125
+
126
+
127
+
128
+ Answer:
129
+ """
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ huggingface_hub==0.22.2
2
+ diffusers
3
+ transformers
4
+ accelerate
5
+ openai