ybelkada commited on
Commit
1fe1d3c
1 Parent(s): 9df608d

Add sharded support (#2)

Browse files

- Add sharded support (604cf791f3910fcc6adbf4ae648f623d1e3574b0)

Files changed (1) hide show
  1. app.py +54 -10
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
 
3
  import torch
@@ -6,14 +9,7 @@ import safetensors
6
  from safetensors.torch import save_file
7
  from huggingface_hub import hf_hub_download
8
 
9
- def run(pr_number, model_id):
10
- try:
11
- st_weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", revision=f"refs/pr/{pr_number}")
12
- torch_weights_path = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
13
- except Exception as e:
14
- return f"Error: {e} | \n Maybe you specified model ids or PRs that does not exist or does not contain any `model.safetensors` or `pytorch_model.bin` files"
15
-
16
-
17
  st_weights = safetensors.torch.load_file(st_weights_path)
18
  torch_weights = torch.load(torch_weights_path)
19
 
@@ -21,7 +17,7 @@ def run(pr_number, model_id):
21
  if st_weights.keys() != torch_weights.keys():
22
  # retrieve different keys
23
  unexpected_keys = st_weights.keys() - torch_weights.keys()
24
- return f"keys are not the same ! Conversion failed - unexpected keys are: {unexpected_keys}"
25
 
26
  total_errors = []
27
 
@@ -33,6 +29,54 @@ def run(pr_number, model_id):
33
  except Exception as e:
34
  total_errors.append(e)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  if len(total_errors) > 0:
38
  return f"weights are not the same ! Conversion failed - {len(total_errors)} errors : {total_errors}"
@@ -47,7 +91,7 @@ The steps are the following:
47
  - Click "Submit"
48
  - That's it! You'll get feedback if the user successfully converted a model in `safetensors` format or not!
49
 
50
- For now this app supports only `pytorch_model.bin` files, and we'll extend it in the future to support sharded formats.
51
  """
52
 
53
  demo = gr.Interface(
 
1
+ import json
2
+ import shutil
3
+ import gc
4
  import gradio as gr
5
 
6
  import torch
 
9
  from safetensors.torch import save_file
10
  from huggingface_hub import hf_hub_download
11
 
12
+ def check_simple_file(st_weights_path, torch_weights_path):
 
 
 
 
 
 
 
13
  st_weights = safetensors.torch.load_file(st_weights_path)
14
  torch_weights = torch.load(torch_weights_path)
15
 
 
17
  if st_weights.keys() != torch_weights.keys():
18
  # retrieve different keys
19
  unexpected_keys = st_weights.keys() - torch_weights.keys()
20
+ return f"keys are not the same ! Conversion failed - unexpected keys are: {unexpected_keys} for the file {st_weights_path}"
21
 
22
  total_errors = []
23
 
 
29
  except Exception as e:
30
  total_errors.append(e)
31
 
32
+ del st_weights
33
+ del torch_weights
34
+ gc.collect()
35
+
36
+ return total_errors
37
+
38
+ def run(pr_number, model_id):
39
+ is_sharded = False
40
+ try:
41
+ st_sharded_index_file = hf_hub_download(repo_id=model_id, filename="model.safetensors.index.json", revision=f"refs/pr/{pr_number}")
42
+ torch_sharded_index_file = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
43
+
44
+ is_sharded = True
45
+ except:
46
+ pass
47
+
48
+ if not is_sharded:
49
+ try:
50
+ st_weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", revision=f"refs/pr/{pr_number}")
51
+ torch_weights_path = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
52
+ except Exception as e:
53
+ return f"Error: {e} | \n Maybe you specified model ids or PRs that does not exist or does not contain any `model.safetensors` or `pytorch_model.bin` files"
54
+
55
+ total_errors = check_simple_file(st_weights_path, torch_weights_path)
56
+ else:
57
+ total_errors = []
58
+ total_st_files = set(json.load(open(st_sharded_index_file, "r"))["weight_map"].values())
59
+ total_pt_files = set(json.load(open(torch_sharded_index_file, "r"))["weight_map"].values())
60
+
61
+ if len(total_st_files) != len(total_pt_files):
62
+ return f"weights are not the same there are {len(total_st_files)} files in safetensors and {len(total_pt_files)} files in torch ! Conversion failed - {len(total_errors)} errors : {total_errors}"
63
+
64
+ # check if the mapping are correct
65
+ if not all([pt_file.replace("pytorch_model", "model").replace(".bin", ".safetensors") in total_st_files for pt_file in total_pt_files]):
66
+ return f"Conversion failed! Safetensors files are not the same as torch files - make sure you have the correct files in the PR"
67
+
68
+ for pt_file in total_pt_files:
69
+ st_file = pt_file.replace("pytorch_model", "model").replace(".bin", ".safetensors")
70
+
71
+ st_weights_path = hf_hub_download(repo_id=model_id, filename=st_file, revision=f"refs/pr/{pr_number}")
72
+ torch_weights_path = hf_hub_download(repo_id=model_id, filename=pt_file)
73
+
74
+ total_errors += check_simple_file(st_weights_path, torch_weights_path)
75
+
76
+ # remove files for memory optimization
77
+ shutil.rmtree(st_weights_path)
78
+ shutil.rmtree(torch_weights_path)
79
+
80
 
81
  if len(total_errors) > 0:
82
  return f"weights are not the same ! Conversion failed - {len(total_errors)} errors : {total_errors}"
 
91
  - Click "Submit"
92
  - That's it! You'll get feedback if the user successfully converted a model in `safetensors` format or not!
93
 
94
+ This checker also support sharded weights.
95
  """
96
 
97
  demo = gr.Interface(