Panchovix Thireus commited on
Commit
b0c3040
1 Parent(s): 926e34f

Update bin2safetensors/convert.py (#2)

Browse files

- Update bin2safetensors/convert.py (62bc6924b6d614a9e1b4ad5f3f1dd90b06573eab)


Co-authored-by: None <Thireus@users.noreply.huggingface.co>

Files changed (1) hide show
  1. bin2safetensors/convert.py +20 -2
bin2safetensors/convert.py CHANGED
@@ -312,7 +312,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
312
  return new_pr, errors
313
 
314
 
315
- def main(input_directory, output_directory):
316
  # Get a list of all files in the input directory
317
  files = os.listdir(input_directory)
318
 
@@ -360,11 +360,29 @@ def main(input_directory, output_directory):
360
  output_filename = os.path.join(output_directory, f"model-{i:05d}-of-{yyyyy:05d}.safetensors")
361
  convert_file(input_filename, output_filename)
362
  print(f"Converted {input_filename} to {output_filename}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  if __name__ == "__main__":
365
  parser = argparse.ArgumentParser(description="Convert pytorch_model model to safetensor and copy JSON and .model files.")
366
  parser.add_argument("input_directory", help="Path to the input directory containing pytorch_model files")
367
  parser.add_argument("output_directory", help="Path to the output directory for converted safetensor files")
 
 
368
  args = parser.parse_args()
369
 
370
- main(args.input_directory, args.output_directory)
 
312
  return new_pr, errors
313
 
314
 
315
+ def main(input_directory, output_directory, delete_files, delete_input_directory):
316
  # Get a list of all files in the input directory
317
  files = os.listdir(input_directory)
318
 
 
360
  output_filename = os.path.join(output_directory, f"model-{i:05d}-of-{yyyyy:05d}.safetensors")
361
  convert_file(input_filename, output_filename)
362
  print(f"Converted {input_filename} to {output_filename}")
363
+
364
+ # Delete the pytorch_model file if the delete_files flag or delete_input_directory flag are set
365
+ if delete_files or delete_input_directory:
366
+ os.remove(input_filename)
367
+ print(f"Deleted {input_filename}")
368
+
369
+ # Check if there are any remaining pytorch_model files in the input directory
370
+ remaining_model_files = [file for file in os.listdir(input_directory) if re.match(r'pytorch_model-\d{5}-of-\d{5}\.bin', file)]
371
+
372
+ if len(remaining_model_files) == 0:
373
+ # Delete the input directory if all files have been converted successfully
374
+ if delete_input_directory:
375
+ shutil.rmtree(input_directory)
376
+ print(f"Deleted input directory {input_directory}")
377
+ else:
378
+ print("Warning: Input directory still contains pytorch_model files and won't be deleted.")
379
 
380
  if __name__ == "__main__":
381
  parser = argparse.ArgumentParser(description="Convert pytorch_model model to safetensor and copy JSON and .model files.")
382
  parser.add_argument("input_directory", help="Path to the input directory containing pytorch_model files")
383
  parser.add_argument("output_directory", help="Path to the output directory for converted safetensor files")
384
+ parser.add_argument("-d", "--delete", action="store_true", help="Delete pytorch_model files after conversion")
385
+ parser.add_argument("-D", "--delete-input", action="store_true", help="Delete pytorch_model files after conversion as well as the input directory after all files are converted")
386
  args = parser.parse_args()
387
 
388
+ main(args.input_directory, args.output_directory, args.delete, args.delete_input)