awacke1 commited on
Commit
e16f68e
·
verified ·
1 Parent(s): cf0cb2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -5
app.py CHANGED
@@ -176,15 +176,38 @@ def create_file(filename, prompt, response, is_image=False):
176
  with open(filename, "w", encoding="utf-8") as f:
177
  f.write(prompt + "\n\n" + response)
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  # Now filename length protected for linux and windows filename lengths
181
  def save_image(image, filename):
182
  max_filename_length = 250
183
  filename_stem, extension = os.path.splitext(filename)
184
  truncated_stem = filename_stem[:max_filename_length - len(extension)] if len(filename) > max_filename_length else filename_stem
185
- filename = f"{truncated_stem}{extension}"
186
- with open(filename, "wb") as f:
187
- f.write(image.getbuffer())
 
 
 
 
188
  return filename
189
 
190
  def extract_boldface_terms(text):
@@ -212,9 +235,14 @@ def process_image(image_input, user_prompt):
212
  data=False
213
  else:
214
  #image_file_name = image_input.name
215
- image_file_name = image_input.filename
216
- image_input = image_input.read()
217
  SaveNewFile=True
 
 
 
 
 
 
218
 
219
  st.markdown('Processing image: ' + image_file_name)
220
  base64_image = base64.b64encode(image_input).decode("utf-8")
 
176
  with open(filename, "w", encoding="utf-8") as f:
177
  f.write(prompt + "\n\n" + response)
178
 
179
+ def sanitize_filename(filename):
180
+ import string
181
+ # Characters not allowed in Windows filenames
182
+ windows_disallowed_chars = '<>:"\\|?*'
183
+
184
+ # Characters not allowed in Unix/Linux filenames
185
+ linux_disallowed_chars = '/'
186
+
187
+ # Additional disallowed characters (non-printable ASCII characters)
188
+ additional_disallowed_chars = ''.join(chr(i) for i in range(32))
189
+
190
+ # Combined set of disallowed characters
191
+ disallowed_chars = windows_disallowed_chars + linux_disallowed_chars + additional_disallowed_chars
192
+
193
+ # Remove disallowed characters
194
+ sanitized_filename = ''.join(c for c in filename if c not in disallowed_chars and c in string.printable)
195
+
196
+ return sanitized_filename
197
+
198
 
199
  # Now filename length protected for linux and windows filename lengths
200
  def save_image(image, filename):
201
  max_filename_length = 250
202
  filename_stem, extension = os.path.splitext(filename)
203
  truncated_stem = filename_stem[:max_filename_length - len(extension)] if len(filename) > max_filename_length else filename_stem
204
+ filename = f"{truncated_s tem}{extension}"
205
+ filename = sanitize_filename(filename)
206
+ try:
207
+ with open(filename, "wb") as f:
208
+ f.write(image.getbuffer())
209
+ except:
210
+ errored=True
211
  return filename
212
 
213
  def extract_boldface_terms(text):
 
235
  data=False
236
  else:
237
  #image_file_name = image_input.name
238
+ image_bytes = image_input.read()
 
239
  SaveNewFile=True
240
+ try:
241
+ if (image_input.filename is not None):
242
+ image_file_name = image_input.filename
243
+ except:
244
+ image_file_name = image_input.name
245
+ image_input = image_bytes # this should allow new posts to ssave and to flow through bytes
246
 
247
  st.markdown('Processing image: ' + image_file_name)
248
  base64_image = base64.b64encode(image_input).decode("utf-8")