acecalisto3 committed on
Commit
cdf8fef
1 Parent(s): b15fec4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -49
app.py CHANGED
@@ -1,54 +1,35 @@
1
- import os
2
- import gradio as gr
3
- from transformers import AutoModel, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- os.environ["GRADIO_SERVER_PORT"] = "8507"
6
-
7
- def get_code_generative_models():
8
- models_dir = os.path.join(os.getcwd(), "models")
9
- if not os.path.exists(models_dir):
10
- os.makedirs(models_dir)
11
-
12
- models = []
13
- for model_name in os.listdir(models_dir):
14
- model_path = os.path.join(models_dir, model_name)
15
- if os.path.isdir(model_path):
16
- model_info = AutoModel.from_pretrained(model_path)
17
- if "config.json" in [f.name for f in model_info.files]:
18
- models.append((model_name, model_path))
19
- return models
20
-
21
- def model_inference(model_name, model_path, input_data):
22
- tokenizer = AutoTokenizer.from_pretrained(model_path)
23
- model = AutoModel.from_pretrained(model_path)
24
- inputs = tokenizer(input_data, return_tensors="pt")
25
- outputs = model(**inputs)
26
- result = outputs.last_hidden_state[:, 0, :]
27
- return result.tolist()
28
-
29
- def main():
30
- models = get_code_generative_models()
31
- with gr.Blocks() as demo:
32
- gr.Markdown("### Select Model and Input")
33
- with gr.Row():
34
- model_name = gr.Dropdown(label="Model", choices=[m[0] for m in models])
35
- input_data = gr.Textbox(label="Input")
36
-
37
- model_path = gr.State(None)
38
-
39
- def update_model_path(model_name):
40
- model_path.set(next(filter(lambda m: m[0] == model_name, models))[1])
41
-
42
- input_data.change(update_model_path, inputs=model_name, outputs=model_path)
43
-
44
- output = gr.Textbox(label="Output")
45
 
46
- def infer(model_name, input_data):
47
- return model_inference(model_name, model_path, input_data)
48
 
49
- output.change(fn=infer, inputs=[model_name, input_data], outputs=output)
 
50
 
51
- interface = demo.launch()
 
52
 
53
- if __name__ == "__main__":
54
- main()
 
1
+ class WebAppTemplate:
2
+ def __init__(self, template: str, output_path: str):
3
+ self.template = template
4
+ self.output_path = output_path
5
+
6
+ def generate_code(self, user_input: dict) -> str:
7
+ # Generate the code based on the user_input and the app template
8
+ tokenizer, model = load_model(user_input["model_name"], user_input["model_path"])
9
+ # Use the tokenizer and model to generate the code
10
+ pass
11
+
12
+ def load_model(self, model_name: str, model_path: str) -> Any:
13
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
14
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
15
+ return tokenizer, model
16
+
17
+ def main(self, port: int = 8000, debug: bool = False) -> None:
18
+ # Implement the main function that creates the Gradio interface and launches the app
19
+ pass
20
 
21
+ if __name__ == "__main__":
22
+ # Initialize the app template
23
+ app_template = WebAppTemplate("template.txt", "output_path")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # Get user input
26
+ user_input = get_user_input()
27
 
28
+ # Generate the code
29
+ generated_code = app_template.generate_code(user_input)
30
 
31
+ # Save the generated code
32
+ save_generated_code(generated_code)
33
 
34
+ # Launch the app
35
+ app_template.main(port=8000, debug=False)