mrcuddle commited on
Commit
2491ca5
1 Parent(s): b7b912a

Update hf_merge.py

Browse files
Files changed (1) hide show
  1. hf_merge.py +4 -1
hf_merge.py CHANGED
@@ -5,6 +5,7 @@ import requests
5
  from tqdm import tqdm
6
  from huggingface_hub import HfApi, hf_hub_download
7
  from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
 
8
  import logging
9
 
10
  # Set up logging
@@ -92,8 +93,10 @@ class ModelMerger:
92
  repo_manager.delete_repo(repo_path)
93
  repo_manager.download_repo(repo_name, repo_path)
94
 
 
 
95
  try:
96
- self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val)
97
  logging.info(f"Merged {repo_name}")
98
  except Exception as e:
99
  logging.error(f"Error merging {repo_name}: {e}")
 
5
  from tqdm import tqdm
6
  from huggingface_hub import HfApi, hf_hub_download
7
  from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
8
+ import torch
9
  import logging
10
 
11
  # Set up logging
 
93
  repo_manager.delete_repo(repo_path)
94
  repo_manager.download_repo(repo_name, repo_path)
95
 
96
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
97
+
98
  try:
99
+ self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val, device)
100
  logging.info(f"Merged {repo_name}")
101
  except Exception as e:
102
  logging.error(f"Error merging {repo_name}: {e}")