acecalisto3 committed on
Commit
afe9aee
1 Parent(s): a02dd09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -102
app.py CHANGED
@@ -6,10 +6,7 @@ import json
6
  import re
7
  import torch
8
  import tempfile
9
- import subprocess
10
- import ast
11
  import os
12
- import dataclasses
13
  from pathlib import Path
14
  from typing import Dict, List, Tuple, Optional, Any, Union
15
  from dataclasses import dataclass, field
@@ -19,7 +16,6 @@ from sentence_transformers import SentenceTransformer
19
  import faiss
20
  import numpy as np
21
  from PIL import Image
22
- from templates import TemplateManager, Template # Import TemplateManager and Template
23
 
24
  # Configure logging
25
  logging.basicConfig(
@@ -37,72 +33,112 @@ DEFAULT_PORT = 7860
37
  MODEL_CACHE_DIR = Path("model_cache")
38
  TEMPLATE_DIR = Path("templates")
39
  TEMP_DIR = Path("temp")
40
- DATABASE_PATH = Path("code_database.json") # Path for our simple database
41
 
42
  # Ensure directories exist
43
  for directory in [MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR]:
44
  directory.mkdir(exist_ok=True, parents=True)
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  class RAGSystem:
47
  def __init__(self, model_name: str = "gpt2", device: str = "cuda" if torch.cuda.is_available() else "cpu", embedding_model="all-mpnet-base-v2"):
 
 
 
 
 
 
 
48
  try:
49
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
50
- self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
51
- self.device = device
52
  self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, device=self.device)
53
  self.embedding_model = SentenceTransformer(embedding_model)
54
  self.load_database()
55
  logger.info("RAG system initialized successfully.")
56
  except Exception as e:
57
  logger.error(f"Error loading language model or embedding model: {e}. Falling back to placeholder generation.")
58
- self.pipe = None
59
- self.embedding_model = None
60
- self.code_embeddings = None
61
 
62
  def load_database(self):
63
- """Loads or creates the code database"""
64
  if DATABASE_PATH.exists():
65
  try:
66
  with open(DATABASE_PATH, 'r', encoding='utf-8') as f:
67
  self.database = json.load(f)
68
  self.code_embeddings = np.array(self.database['embeddings'])
69
  logger.info("Loaded code database from file.")
 
70
  except (json.JSONDecodeError, KeyError) as e:
71
  logger.error(f"Error loading code database: {e}. Creating new database.")
72
  self.database = {'codes': [], 'embeddings': []}
73
  self.code_embeddings = np.array([])
74
-
75
  else:
76
  logger.info("Code database does not exist. Creating new database.")
77
  self.database = {'codes': [], 'embeddings': []}
78
  self.code_embeddings = np.array([])
 
79
 
80
  if self.embedding_model and len(self.database['codes']) != len(self.database['embeddings']):
81
  logger.warning("Mismatch between number of codes and embeddings, rebuilding embeddings.")
82
  self.rebuild_embeddings()
83
  elif self.embedding_model is None:
84
- logger.warning("Embeddings are not supported in this context.")
85
 
86
- # Index the embeddings for efficient searching
87
  if len(self.code_embeddings) > 0 and self.embedding_model:
88
  self.index = faiss.IndexFlatL2(self.code_embeddings.shape[1]) # L2 distance
89
  self.index.add(self.code_embeddings)
90
 
91
  def add_to_database(self, code: str):
92
- """Adds a code snippet to the database"""
93
  try:
 
 
94
  embedding = self.embedding_model.encode(code)
95
  self.database['codes'].append(code)
96
  self.database['embeddings'].append(embedding.tolist())
97
- self.code_embeddings = np.vstack((self.code_embeddings, embedding))
98
- self.index.add(np.array([embedding])) # update FAISS index
99
  self.save_database()
100
  logger.info(f"Added code snippet to database. Total size: {len(self.database['codes'])}.")
101
  except Exception as e:
102
  logger.error(f"Error adding to database: {e}")
103
 
104
  def save_database(self):
105
- """Saves the database to a file"""
106
  try:
107
  with open(DATABASE_PATH, 'w', encoding='utf-8') as f:
108
  json.dump(self.database, f, indent=2)
@@ -111,22 +147,21 @@ class RAGSystem:
111
  logger.error(f"Error saving database: {e}")
112
 
113
  def rebuild_embeddings(self):
114
- """Rebuilds embeddings from the codes"""
115
  try:
 
 
116
  embeddings = self.embedding_model.encode(self.database['codes'])
117
  self.code_embeddings = embeddings
118
  self.database['embeddings'] = embeddings.tolist()
119
- self.index = faiss.IndexFlatL2(embeddings.shape[1]) # L2 distance
120
- self.index.add(embeddings)
121
  self.save_database()
122
  logger.info("Rebuilt and saved embeddings to the database.")
123
  except Exception as e:
124
  logger.error(f"Error rebuilding embeddings: {e}")
125
 
126
  def retrieve_similar_code(self, description: str, top_k: int = 3) -> List[str]:
127
- """Retrieves similar code snippets from the database"""
128
- if self.embedding_model is None:
129
- logger.warning("Embedding model is not available. Cannot retrieve similar code.")
130
  return []
131
  try:
132
  embedding = self.embedding_model.encode(description)
@@ -139,7 +174,7 @@ class RAGSystem:
139
 
140
  def generate_code(self, description: str, template_code: str) -> str:
141
  retrieved_codes = self.retrieve_similar_code(description)
142
- prompt = f"Description: {description}\nRetrieved Code Snippets:\n{''.join([f'```python\n{code}\n```\n' for code in retrieved_codes])}\nTemplate:\n```python\n{template_code}\n```\nGenerated Code:\n```python\n"
143
  if self.pipe:
144
  try:
145
  generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
@@ -151,58 +186,25 @@ class RAGSystem:
151
  return template_code
152
  else:
153
  logger.warning("Text generation pipeline is not available. Returning placeholder code.")
154
- return f"# Placeholder code generation. Description: {description}\n{template_code}"
155
-
156
- def generate_interface(self, screenshot: Optional[Image.Image], description: str) -> str:
157
- retrieved_codes = self.retrieve_similar_code(description)
158
- prompt = f"Create a Gradio interface based on this description: {description}\nRetrieved Code Snippets:\n{''.join([f'```python\n{code}\n```\n' for code in retrieved_codes])}"
159
- if screenshot:
160
- prompt += "\nThe interface should resemble the provided screenshot."
161
- prompt += "\n```python\n"
162
- if self.pipe:
163
- try:
164
- generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
165
- generated_code = generated_text.split("```")[1].strip()
166
- logger.info("Interface code generated successfully.")
167
- return generated_code
168
- except Exception as e:
169
- logger.error(f"Error generating interface with language model: {e}. Returning placeholder.")
170
- return "import gradio as gr\n\ndemo = gr.Interface(fn=lambda x:x, inputs='text', outputs='text')\ndemo.launch()"
171
- else:
172
- logger.warning("Text generation pipeline is not available. Returning placeholder interface code.")
173
- return "import gradio as gr\n\ndemo = gr.Interface(fn=lambda x:x, inputs='text', outputs='text')\ndemo.launch()"
174
-
175
- class PreviewManager:
176
- def __init__(self):
177
- self.preview_code = ""
178
-
179
- def update_preview(self, code: str):
180
- """Update the preview with the generated code."""
181
- self.preview_code = code
182
- logger.info("Preview updated with new code.")
183
 
184
  class GradioInterface:
185
  def __init__(self):
186
  self.template_manager = TemplateManager(TEMPLATE_DIR)
187
  self.template_manager.load_templates()
188
- self.current_code = ""
189
  self.rag_system = RAGSystem()
190
- self.preview_manager = PreviewManager()
191
 
192
  def _extract_components(self, code: str) -> List[str]:
193
- """Extract components from the code."""
194
  components = []
195
- function_matches = re.findall(r'def (\w+)', code components.extend(function_matches)
196
- class_matches = re.findall(r'class (\w+)', code)
 
197
  components.extend(class_matches)
198
  logger.info(f"Extracted components: {components}")
199
  return components
200
 
201
  def _get_template_choices(self) -> List[str]:
202
- """Get available template choices."""
203
- choices = list(self.template_manager.templates.keys())
204
- logger.info(f"Available template choices: {choices}")
205
- return choices
206
 
207
  def launch(self, **kwargs):
208
  with gr.Blocks() as interface:
@@ -212,59 +214,51 @@ class GradioInterface:
212
  generate_button = gr.Button("Generate Code")
213
  template_choice = gr.Dropdown(label="Select Template", choices=self._get_template_choices(), value=None)
214
  save_button = gr.Button("Save as Template")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  generate_button.click(
217
- fn=self.generate_code,
218
  inputs=[description_input, template_choice],
219
- outputs=code_output
220
  )
221
 
222
  save_button.click(
223
- fn=self.save_template,
224
  inputs=[code_output, template_choice, description_input],
225
- outputs=code_output
226
- )
227
-
228
- gr.Markdown("### Preview")
229
- preview_output = gr.Textbox(label="Preview", interactive=False)
230
- self.preview_manager.update_preview(code_output)
231
-
232
- generate_button.click(
233
- fn=lambda code: self.preview_manager.update_preview(code),
234
- inputs=code_output,
235
- outputs=preview_output
236
  )
237
 
238
  logger.info("Launching Gradio interface.")
239
  interface.launch(**kwargs)
240
 
241
- def generate_code(self, description: str, template_choice: Optional[str]) -> str:
242
- """Generate code based on the description and selected template."""
243
- template_code = self.template_manager.get_template(template_choice) if template_choice else "" # Get template code if selected
244
- logger.info(f"Generating code for description: {description} with template: {template_choice}")
245
- return self.rag_system.generate_code(description, template_code)
246
-
247
- def save_template(self, code: str, name: str, description: str) -> str:
248
- """Save the generated code as a template."""
249
- try:
250
- components = self._extract_components(code)
251
- template = Template(code=code, description=description, components=components)
252
- if self.template_manager.save_template(name, template):
253
- self.rag_system.add_to_database(code) # Add code to the database
254
- logger.info(f"Template '{name}' saved successfully.")
255
- return f"✅ Template '{name}' saved successfully."
256
- else:
257
- logger.error("Failed to save template.")
258
- return "❌ Failed to save template."
259
- except Exception as e:
260
- logger.error(f"Error saving template: {e}")
261
- return f"❌ Error saving template: {str(e)}"
262
-
263
  def main():
264
  logger.info("=== Application Startup ===")
265
-
266
  try:
267
- # Initialize and launch interface
268
  interface = GradioInterface()
269
  interface.launch(
270
  server_port=DEFAULT_PORT,
 
6
  import re
7
  import torch
8
  import tempfile
 
 
9
  import os
 
10
  from pathlib import Path
11
  from typing import Dict, List, Tuple, Optional, Any, Union
12
  from dataclasses import dataclass, field
 
16
  import faiss
17
  import numpy as np
18
  from PIL import Image
 
19
 
20
  # Configure logging
21
  logging.basicConfig(
 
33
  MODEL_CACHE_DIR = Path("model_cache")
34
  TEMPLATE_DIR = Path("templates")
35
  TEMP_DIR = Path("temp")
36
+ DATABASE_PATH = Path("code_database.json")
37
 
38
  # Ensure directories exist
39
  for directory in [MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR]:
40
  directory.mkdir(exist_ok=True, parents=True)
41
 
42
@dataclass
class Template:
    """A code template persisted as one JSON file in the template directory."""

    # Source code of the template.
    code: str
    # Free-text description; ``TemplateManager`` uses it as the lookup key.
    description: str
    # Names of the components (e.g. functions/classes) the code defines.
    components: List[str] = field(default_factory=list)


class TemplateManager:
    """Load, save and look up ``Template`` objects stored as JSON files."""

    def __init__(self, template_dir: Path):
        """Remember the directory to use; nothing is read until ``load_templates``."""
        self.template_dir = template_dir
        self.templates: Dict[str, Template] = {}

    def load_templates(self):
        """Populate ``self.templates`` from every ``*.json`` file in the directory.

        Templates are keyed by their ``description`` field. A file that is not
        valid JSON, or whose content does not match the ``Template`` fields,
        is logged and skipped so one bad file cannot abort start-up.
        """
        for file_path in self.template_dir.glob("*.json"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    template_data = json.load(f)
                template = Template(**template_data)
                self.templates[template.description] = template
            except json.JSONDecodeError as e:
                logger.error(f"Error loading template from {file_path}: {e}")
            except TypeError as e:
                # Raised by Template(**...) for missing/unexpected fields or
                # when the JSON document is not an object at all.
                logger.error(f"Invalid template structure in {file_path}: {e}")

    def save_template(self, name: str, template: Template) -> bool:
        """Write ``template`` to ``<template_dir>/<name>.json``.

        Returns:
            True when the file was written (the template is then also
            registered in ``self.templates``), False on any error.
        """
        # Imported locally so this method does not depend on a module-level
        # ``import dataclasses`` being present.
        from dataclasses import asdict

        file_path = self.template_dir / f"{name}.json"
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(asdict(template), f, indent=2)
        except Exception as e:
            logger.error(f"Error saving template to {file_path}: {e}")
            return False
        # Same key that load_templates() would use after a restart.
        self.templates[template.description] = template
        return True

    def get_template(self, name: str) -> Optional[str]:
        """Return the code of the template stored under ``name``, or ``""`` if unknown."""
        template = self.templates.get(name)
        return template.code if template is not None else ""
77
+
78
  class RAGSystem:
79
  def __init__(self, model_name: str = "gpt2", device: str = "cuda" if torch.cuda.is_available() else "cpu", embedding_model="all-mpnet-base-v2"):
80
+ self.device = device
81
+ self.embedding_model = None
82
+ self.code_embeddings = None
83
+ self.index = None
84
+ self.database = {'codes': [], 'embeddings': []}
85
+ self.pipe = None
86
+
87
  try:
88
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=MODEL_CACHE_DIR)
89
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=MODEL_CACHE_DIR).to(device)
 
90
  self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, device=self.device)
91
  self.embedding_model = SentenceTransformer(embedding_model)
92
  self.load_database()
93
  logger.info("RAG system initialized successfully.")
94
  except Exception as e:
95
  logger.error(f"Error loading language model or embedding model: {e}. Falling back to placeholder generation.")
 
 
 
96
 
97
  def load_database(self):
 
98
  if DATABASE_PATH.exists():
99
  try:
100
  with open(DATABASE_PATH, 'r', encoding='utf-8') as f:
101
  self.database = json.load(f)
102
  self.code_embeddings = np.array(self.database['embeddings'])
103
  logger.info("Loaded code database from file.")
104
+ self._build_index()
105
  except (json.JSONDecodeError, KeyError) as e:
106
  logger.error(f"Error loading code database: {e}. Creating new database.")
107
  self.database = {'codes': [], 'embeddings': []}
108
  self.code_embeddings = np.array([])
109
+ self._build_index()
110
  else:
111
  logger.info("Code database does not exist. Creating new database.")
112
  self.database = {'codes': [], 'embeddings': []}
113
  self.code_embeddings = np.array([])
114
+ self._build_index()
115
 
116
  if self.embedding_model and len(self.database['codes']) != len(self.database['embeddings']):
117
  logger.warning("Mismatch between number of codes and embeddings, rebuilding embeddings.")
118
  self.rebuild_embeddings()
119
  elif self.embedding_model is None:
120
+ logger.warning ("Embeddings are not supported in this context.")
121
 
122
def _build_index(self):
    """(Re)create the FAISS L2 index from ``self.code_embeddings``.

    When there is no embedding model or nothing to index, ``self.index`` is
    reset to ``None`` so that a previously built index cannot be queried
    against a database it no longer matches.
    """
    embeddings = self.code_embeddings
    if embeddings is None or len(embeddings) == 0 or not self.embedding_model:
        self.index = None
        return

    # FAISS expects contiguous float32 data; embeddings reloaded from JSON
    # would otherwise arrive as float64.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    self.index = faiss.IndexFlatL2(vectors.shape[1])  # L2 distance
    self.index.add(vectors)
126
 
127
  def add_to_database(self, code: str):
 
128
  try:
129
+ if self.embedding_model is None:
130
+ raise ValueError("Embedding model not loaded.")
131
  embedding = self.embedding_model.encode(code)
132
  self.database['codes'].append(code)
133
  self.database['embeddings'].append(embedding.tolist())
134
+ self.code_embeddings = np.vstack((self.code_embeddings, embedding)) if len(self.code_embeddings) > 0 else np.array([embedding])
135
+ self.index.add(np.array([embedding]))
136
  self.save_database()
137
  logger.info(f"Added code snippet to database. Total size: {len(self.database['codes'])}.")
138
  except Exception as e:
139
  logger.error(f"Error adding to database: {e}")
140
 
141
  def save_database(self):
 
142
  try:
143
  with open(DATABASE_PATH, 'w', encoding='utf-8') as f:
144
  json.dump(self.database, f, indent=2)
 
147
  logger.error(f"Error saving database: {e}")
148
 
149
  def rebuild_embeddings(self):
 
150
  try:
151
+ if self.embedding_model is None:
152
+ raise ValueError("Embedding model not loaded.")
153
  embeddings = self.embedding_model.encode(self.database['codes'])
154
  self.code_embeddings = embeddings
155
  self.database['embeddings'] = embeddings.tolist()
156
+ self._build_index()
 
157
  self.save_database()
158
  logger.info("Rebuilt and saved embeddings to the database.")
159
  except Exception as e:
160
  logger.error(f"Error rebuilding embeddings: {e}")
161
 
162
  def retrieve_similar_code(self, description: str, top_k: int = 3) -> List[str]:
163
+ if self.embedding_model is None or self.index is None:
164
+ logger.warning("Embedding model or index not available. Cannot retrieve similar code.")
 
165
  return []
166
  try:
167
  embedding = self.embedding_model.encode(description)
 
174
 
175
  def generate_code(self, description: str, template_code: str) -> str:
176
  retrieved_codes = self.retrieve_similar_code(description)
177
+ prompt = f"Description: {description} Retrieved Code Snippets: {''.join([f'```python {code} ```' for code in retrieved_codes])} Template: ```python {template_code} ``` Generated Code: ```python "
178
  if self.pipe:
179
  try:
180
  generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
 
186
  return template_code
187
  else:
188
  logger.warning("Text generation pipeline is not available. Returning placeholder code.")
189
+ return f"# Placeholder code generation. Description: {description} {template_code}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  class GradioInterface:
192
  def __init__(self):
193
  self.template_manager = TemplateManager(TEMPLATE_DIR)
194
  self.template_manager.load_templates()
 
195
  self.rag_system = RAGSystem()
 
196
 
197
  def _extract_components(self, code: str) -> List[str]:
 
198
  components = []
199
+ function_matches = re.findall(r'def (\w+)\(', code) # added parenthesis for more accuracy
200
+ components.extend(function_matches)
201
+ class_matches = re.findall(r'class (\w+)\:', code) # added colon for more accuracy
202
  components.extend(class_matches)
203
  logger.info(f"Extracted components: {components}")
204
  return components
205
 
206
  def _get_template_choices(self) -> List[str]:
207
+ return list(self.template_manager.templates.keys())
 
 
 
208
 
209
  def launch(self, **kwargs):
210
  with gr.Blocks() as interface:
 
214
  generate_button = gr.Button("Generate Code")
215
  template_choice = gr.Dropdown(label="Select Template", choices=self._get_template_choices(), value=None)
216
  save_button = gr.Button("Save as Template")
217
+ status_output = gr.Textbox(label="Status", interactive=False)
218
+
219
def generate_code_wrapper(description, template_choice):
    """Click handler: return ``(generated_code, status_message)``.

    Uses the selected template's code as the base when a template is
    chosen, otherwise an empty string. Any exception is reported in the
    status message with an empty code result.
    """
    try:
        if template_choice:
            base_code = self.template_manager.get_template(template_choice)
        else:
            base_code = ""
        result = self.rag_system.generate_code(description, base_code)
    except Exception as err:
        return "", f"Error generating code: {err}"
    return result, "Code generated successfully."
226
+
227
def save_template_wrapper(code, name, description):
    """Click handler: save ``code`` as template ``name``; return ``(code, status_message)``.

    ``description`` is accepted because the click handler supplies it, but
    it is not used: the template's description field is set to ``name``.
    The code is returned unchanged in every case; only the status differs.
    """
    try:
        if not name:
            status = "Template name cannot be empty."
        elif not code:
            status = "Code cannot be empty."
        else:
            template = Template(
                code=code,
                description=name,
                components=self._extract_components(code),
            )
            if not self.template_manager.save_template(name, template):
                status = "Failed to save template."
            else:
                self.rag_system.add_to_database(code)
                status = f"Template '{name}' saved successfully."
    except Exception as err:
        status = f"Error saving template: {err}"
    return code, status
243
 
244
  generate_button.click(
245
+ fn=generate_code_wrapper,
246
  inputs=[description_input, template_choice],
247
+ outputs=[code_output, status_output]
248
  )
249
 
250
  save_button.click(
251
+ fn=save_template_wrapper,
252
  inputs=[code_output, template_choice, description_input],
253
+ outputs=[code_output, status_output]
 
 
 
 
 
 
 
 
 
 
254
  )
255
 
256
  logger.info("Launching Gradio interface.")
257
  interface.launch(**kwargs)
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  def main():
260
  logger.info("=== Application Startup ===")
 
261
  try:
 
262
  interface = GradioInterface()
263
  interface.launch(
264
  server_port=DEFAULT_PORT,