ynhe committed on
Commit
d9d19ec
·
verified ·
1 Parent(s): 06c102b

Create repository.py

Browse files
Files changed (1) hide show
  1. repository.py +1477 -0
repository.py ADDED
@@ -0,0 +1,1477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import os
3
+ import re
4
+ import subprocess
5
+ import threading
6
+ import time
7
+ from contextlib import contextmanager
8
+ from pathlib import Path
9
+ from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypedDict, Union
10
+ from urllib.parse import urlparse
11
+
12
+ from huggingface_hub import constants
13
+ from huggingface_hub.repocard import metadata_load, metadata_save
14
+
15
+ from .hf_api import HfApi, repo_type_and_id_from_hf_id
16
+ from .lfs import LFS_MULTIPART_UPLOAD_COMMAND
17
+ from .utils import (
18
+ SoftTemporaryDirectory,
19
+ get_token,
20
+ logging,
21
+ run_subprocess,
22
+ tqdm,
23
+ validate_hf_hub_args,
24
+ )
25
+ from .utils._deprecation import _deprecate_method
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
class CommandInProgress:
    """
    Handle on a command that was launched asynchronously.

    Bundles the `subprocess.Popen` object with the callables used to poll it
    and accumulates everything read so far from its output streams.
    """

    def __init__(
        self,
        title: str,
        is_done_method: Callable,
        status_method: Callable,
        process: subprocess.Popen,
        post_method: Optional[Callable] = None,
    ):
        self.title = title
        self._process = process
        self._is_done = is_done_method
        self._status = status_method
        self._post_method = post_method
        # Text already consumed from the process streams.
        self._stdout = ""
        self._stderr = ""

    @property
    def is_done(self) -> bool:
        """
        Whether the process is done.

        The first time the command is seen as done, the optional post-method
        is invoked and then discarded so that it runs only once.
        """
        finished = self._is_done()

        if not finished or self._post_method is None:
            return finished

        self._post_method()
        self._post_method = None
        return finished

    @property
    def status(self) -> int:
        """
        The exit code/status of the current action. Will return `0` if the
        command has completed successfully, and a number between 1 and 255 if
        the process errored-out.

        Will return -1 if the command is still ongoing.
        """
        return self._status()

    @property
    def failed(self) -> bool:
        """
        Whether the process errored-out (strictly positive status).
        """
        return self.status > 0

    @property
    def stderr(self) -> str:
        """
        Everything read so far from the standard error of the process.
        """
        stream = self._process.stderr
        if stream is not None:
            self._stderr += stream.read()
        return self._stderr

    @property
    def stdout(self) -> str:
        """
        Everything read so far from the standard output of the process.
        """
        stream = self._process.stdout
        if stream is not None:
            self._stdout += stream.read()
        return self._stdout

    def __repr__(self):
        code = self.status
        shown_status = "running" if code == -1 else code
        progress = "finished." if self.is_done else "in progress."
        return f"[{self.title} command, status code: {shown_status}, {progress} PID: {self._process.pid}]"
112
+
113
+
114
def is_git_repo(folder: Union[str, Path]) -> bool:
    """
    Check if the folder is the root of a git repository.

    Args:
        folder (`str` or `Path`):
            The folder in which to run the command.

    Returns:
        `bool`: `True` if `folder` contains a `.git` entry and `git branch`
        succeeds when run inside it, `False` otherwise.
    """
    # Without a `.git` entry the answer is already `False`: return early instead of
    # spawning `git` (which would also raise if `folder` does not exist).
    if not os.path.exists(os.path.join(folder, ".git")):
        return False

    git_branch = subprocess.run(["git", "branch"], cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return git_branch.returncode == 0
129
+
130
+
131
def is_local_clone(folder: Union[str, Path], remote_url: str) -> bool:
    """
    Tell whether `folder` is a local clone of `remote_url`.

    Credentials embedded in `https://` URLs are ignored on both sides of the
    comparison.

    Args:
        folder (`str` or `Path`):
            The folder in which to run the command.
        remote_url (`str`):
            The url of a git repository.

    Returns:
        `bool`: `True` if the repository is a local clone of the remote
        repository specified, `False` otherwise.
    """
    if not is_git_repo(folder):
        return False

    def _without_credentials(url: str) -> str:
        # Drop the `user:token@` part so that URLs compare equal with or without a token.
        return re.sub(r"https://.*@", "https://", url)

    listed = run_subprocess("git remote -v", folder).stdout
    known_urls = [_without_credentials(item) for item in listed.split()]
    return _without_credentials(remote_url) in known_urls
154
+
155
+
156
def is_tracked_with_lfs(filename: Union[str, Path]) -> bool:
    """
    Tell whether the given file is tracked with git-lfs.

    The file counts as tracked when its git attributes set each of `diff`,
    `merge` and `filter` on a line that also mentions `lfs`.

    Args:
        filename (`str` or `Path`):
            The filename to check.

    Returns:
        `bool`: `True` if the file passed is tracked with git-lfs, `False`
        otherwise.
    """
    path = Path(filename)
    folder, basename = path.parent, path.name

    try:
        attributes = run_subprocess("git check-attr -a".split() + [basename], folder).stdout.strip()
    except subprocess.CalledProcessError as exc:
        # A failure outside of a git repository simply means "not tracked".
        if is_git_repo(folder):
            raise OSError(exc.stderr)
        return False

    if not attributes:
        return False

    lfs_lines = [line for line in attributes.split("\n") if "lfs" in line]
    return all(any(tag in line for line in lfs_lines) for tag in ("diff", "merge", "filter"))
191
+
192
+
193
def is_git_ignored(filename: Union[str, Path]) -> bool:
    """
    Tell whether a file is git-ignored. Supports nested .gitignore files.

    Args:
        filename (`str` or `Path`):
            The filename to check.

    Returns:
        `bool`: `True` if the file passed is ignored by `git`, `False`
        otherwise.
    """
    target = Path(filename)

    try:
        completed = run_subprocess(["git", "check-ignore", target.name], target.parent, check=False)
    except subprocess.CalledProcessError as exc:
        raise OSError(exc.stderr)

    # `git check-ignore` exits with 0 when the path is ignored and 1 when it is not.
    return completed.returncode == 0
216
+
217
+
218
def is_binary_file(filename: Union[str, Path]) -> bool:
    """
    Tell whether a file looks like a binary file.

    At most the first 10MB are inspected; the file is reported as binary as
    soon as that prefix contains a byte outside of the "text" byte set.

    Args:
        filename (`str` or `Path`):
            The filename to check.

    Returns:
        `bool`: `True` if the file passed is a binary file, `False` otherwise.
    """
    # Heuristic taken from the following stack overflow thread
    # https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391
    control_codes = {7, 8, 9, 10, 12, 13, 27}
    printable_codes = set(range(0x20, 0x100)) - {0x7F}
    text_chars = bytearray(control_codes | printable_codes)
    max_bytes = 10 * (1024**2)  # Read a maximum of 10MB

    try:
        with open(filename, "rb") as stream:
            head = stream.read(max_bytes)

        # Whatever survives the deletion of every "text" byte is binary data.
        leftover = head.translate(None, text_chars)
        return len(leftover) > 0
    except UnicodeDecodeError:
        return True
239
+
240
+
241
def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> List[str]:
    """
    List the filenames that are to be staged.

    Relies on `git ls-files --exclude-standard -mo` restricted to `pattern`.

    Args:
        pattern (`str` or `Path`):
            The pattern of filenames to check. Put `.` to get all files.
        folder (`str` or `Path`):
            The folder in which to run the command.

    Returns:
        `List[str]`: List of files that are to be staged.
    """
    command = ["git", "ls-files", "--exclude-standard", "-mo", pattern]

    try:
        listing = run_subprocess(command, folder).stdout.strip()
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    # One filename per line; an empty output means there is nothing to stage.
    return listing.split("\n") if listing else []
264
+
265
+
266
def is_tracked_upstream(folder: Union[str, Path]) -> bool:
    """
    Tell whether the currently checked-out branch is tracked upstream.

    Args:
        folder (`str` or `Path`):
            The folder in which to run the command.

    Returns:
        `bool`: `True` if the current checked-out branch is tracked upstream,
        `False` otherwise.

    Raises:
        `OSError`: if the error reported by git mentions `HEAD`.
    """
    upstream_query = "git rev-parse --symbolic-full-name --abbrev-ref @{u}"

    try:
        run_subprocess(upstream_query, folder)
    except subprocess.CalledProcessError as exc:
        if "HEAD" in exc.stderr:
            raise OSError("No branch checked out")
        return False
    else:
        return True
286
+
287
+
288
def commits_to_push(folder: Union[str, Path], upstream: Optional[str] = None) -> int:
    """
    Count the commits that would be pushed upstream.

    Args:
        folder (`str` or `Path`):
            The folder in which to run the command.
        upstream (`str`, *optional*):
            The name of the upstream repository with which the comparison should be
            made.

    Returns:
        `int`: Number of commits that would be pushed upstream were a `git
        push` to proceed.
    """
    command = f"git cherry -v {upstream or ''}"

    try:
        completed = run_subprocess(command, folder)
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    # The count is the number of newline characters in the output of `git cherry -v`.
    return completed.stdout.count("\n")
308
+
309
+
310
class PbarT(TypedDict):
    # Used to store an opened progress bar in `_lfs_log_progress`, together with
    # the last byte count that was reported for it.
    bar: tqdm  # the opened progress bar
    past_bytes: int  # byte count already accounted for in `bar`
314
+
315
+
316
@contextmanager
def _lfs_log_progress():
    """
    This is a context manager that will log the Git LFS progress of cleaning,
    smudging, pulling and pushing.

    While the context is active, the `GIT_LFS_PROGRESS` environment variable
    points to a file in a temporary directory and a daemon thread tails that
    file, showing one progress bar per `(state, filename)` pair. The previous
    value of `GIT_LFS_PROGRESS` is put back when the context exits.

    Nothing is displayed when the effective logging level is `ERROR` or above.
    """

    if logger.getEffectiveLevel() >= logging.ERROR:
        # Progress reporting is disabled at this verbosity.
        # NOTE(review): exceptions raised by the wrapped block are swallowed on this path,
        # which differs from the path below - kept as is, confirm it is intended.
        try:
            yield
        except Exception:
            pass
        return

    def output_progress(stopping_event: threading.Event):
        """
        To be launched as a separate thread with an event meaning it should stop
        the tail.
        """
        # Key is tuple(state, filename), value is a dict(tqdm bar and a previous value)
        pbars: Dict[Tuple[str, str], PbarT] = {}

        def close_pbars():
            # Fill every bar up to its total before closing it.
            for pbar in pbars.values():
                pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"])
                pbar["bar"].refresh()
                pbar["bar"].close()

        def tail_file(filename) -> Iterator[str]:
            """
            Creates a generator to be iterated through, which will return each
            line one by one. Will stop tailing the file if the stopping_event is
            set.
            """
            with open(filename, "r") as file:
                current_line = ""
                while True:
                    if stopping_event.is_set():
                        close_pbars()
                        break

                    line_bit = file.readline()
                    if line_bit.strip():
                        # A line may be written in several chunks: only yield once it is complete.
                        current_line += line_bit
                        if current_line.endswith("\n"):
                            yield current_line
                            current_line = ""
                    else:
                        time.sleep(1)

        # If the file isn't created yet, wait for a few seconds before trying again.
        # Can be interrupted with the stopping_event.
        while not os.path.exists(os.environ["GIT_LFS_PROGRESS"]):
            if stopping_event.is_set():
                close_pbars()
                return

            time.sleep(2)

        for line in tail_file(os.environ["GIT_LFS_PROGRESS"]):
            try:
                # Exactly four fields are expected. `maxsplit=3` keeps everything after the third
                # space in `filename`, so filenames containing spaces do not break the unpacking.
                state, file_progress, byte_progress, filename = line.split(" ", 3)
            except ValueError as error:
                # Try/except to ease debugging. See https://github.com/huggingface/huggingface_hub/issues/1373.
                raise ValueError(f"Cannot unpack LFS progress line:\n{line}") from error
            description = f"{state.capitalize()} file {filename}"

            current_bytes, total_bytes = byte_progress.split("/")
            current_bytes_int = int(current_bytes)
            total_bytes_int = int(total_bytes)

            pbar = pbars.get((state, filename))
            if pbar is None:
                # Initialize progress bar
                pbars[(state, filename)] = {
                    "bar": tqdm(
                        desc=description,
                        initial=current_bytes_int,
                        total=total_bytes_int,
                        unit="B",
                        unit_scale=True,
                        unit_divisor=1024,
                        name="huggingface_hub.lfs_upload",
                    ),
                    "past_bytes": current_bytes_int,
                }
            else:
                # Update progress bar
                pbar["bar"].update(current_bytes_int - pbar["past_bytes"])
                pbar["past_bytes"] = current_bytes_int

    current_lfs_progress_value = os.environ.get("GIT_LFS_PROGRESS", "")

    with SoftTemporaryDirectory() as tmpdir:
        os.environ["GIT_LFS_PROGRESS"] = os.path.join(tmpdir, "lfs_progress")
        logger.debug(f"Following progress in {os.environ['GIT_LFS_PROGRESS']}")

        exit_event = threading.Event()
        x = threading.Thread(target=output_progress, args=(exit_event,), daemon=True)
        x.start()

        try:
            yield
        finally:
            exit_event.set()
            x.join()

            # Restore the previous value even when the wrapped block raised, so that the variable
            # never keeps pointing to the deleted temporary directory.
            os.environ["GIT_LFS_PROGRESS"] = current_lfs_progress_value
424
+
425
+
426
+ class Repository:
427
+ """
428
+ Helper class to wrap the git and git-lfs commands.
429
+
430
+ The aim is to facilitate interacting with huggingface.co hosted model or
431
+ dataset repos, though not a lot here (if any) is actually specific to
432
+ huggingface.co.
433
+
434
+ <Tip warning={true}>
435
+
436
+ [`Repository`] is deprecated in favor of the http-based alternatives implemented in
437
+ [`HfApi`]. Given its large adoption in legacy code, the complete removal of
438
+ [`Repository`] will only happen in release `v1.0`. For more details, please read
439
+ https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.
440
+
441
+ </Tip>
442
+ """
443
+
444
+ command_queue: List[CommandInProgress]
445
+
446
@validate_hf_hub_args
@_deprecate_method(
    version="1.0",
    message=(
        "Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete"
        " removal is only planned on next major release.\nFor more details, please read"
        " https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http."
    ),
)
def __init__(
    self,
    local_dir: Union[str, Path],
    clone_from: Optional[str] = None,
    repo_type: Optional[str] = None,
    token: Union[bool, str] = True,
    git_user: Optional[str] = None,
    git_email: Optional[str] = None,
    revision: Optional[str] = None,
    skip_lfs_files: bool = False,
    client: Optional[HfApi] = None,
):
    """
    Instantiate a local clone of a git repo.

    If `clone_from` is set, the repo will be cloned from an existing remote repository.
    If the remote repo does not exist, an `EnvironmentError` exception will be thrown.
    Please create the remote repo first using [`create_repo`].

    `Repository` uses the local git credentials by default. If explicitly set, the `token`
    or the `git_user`/`git_email` pair will be used instead.

    Args:
        local_dir (`str` or `Path`):
            path (e.g. `'my_trained_model/'`) to the local directory, where
            the `Repository` will be initialized. The directory is created
            if it does not exist.
        clone_from (`str`, *optional*):
            Either a repository url or `repo_id`.
            Example:
            - `"https://huggingface.co/philschmid/playground-tests"`
            - `"philschmid/playground-tests"`
        repo_type (`str`, *optional*):
            To set when cloning a repo from a repo_id. Default is model.
        token (`bool` or `str`, *optional*):
            A valid authentication token (see https://huggingface.co/settings/token).
            If `None` or `True` and machine is logged in (through `huggingface-cli login`
            or [`~huggingface_hub.login`]), token will be retrieved from the cache.
            If `False`, token is not sent in the request header.
        git_user (`str`, *optional*):
            will override the `git config user.name` for committing and
            pushing files to the hub.
        git_email (`str`, *optional*):
            will override the `git config user.email` for committing and
            pushing files to the hub.
        revision (`str`, *optional*):
            Revision to checkout after initializing the repository. If the
            revision doesn't exist, a branch will be created with that
            revision name from the default branch's current HEAD.
        skip_lfs_files (`bool`, *optional*, defaults to `False`):
            whether to skip git-LFS files or not.
        client (`HfApi`, *optional*):
            Instance of [`HfApi`] to use when calling the HF Hub API. A new
            instance will be created if this is left to `None`.

    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If the remote repository set in `clone_from` does not exist.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `clone_from` is not set and `local_dir` is not a valid git clone.
    """
    if isinstance(local_dir, Path):
        local_dir = str(local_dir)
    os.makedirs(local_dir, exist_ok=True)
    # `os.path.join` keeps `local_dir` unchanged when it is already an absolute path.
    self.local_dir = os.path.join(os.getcwd(), local_dir)
    self._repo_type = repo_type
    self.command_queue = []
    self.skip_lfs_files = skip_lfs_files
    self.client = client if client is not None else HfApi()

    # Fail early if `git` or `git-lfs` cannot be run.
    self.check_git_versions()

    if isinstance(token, str):
        self.huggingface_token: Optional[str] = token
    elif token is False:
        self.huggingface_token = None
    else:
        # if `True` -> explicit use of the cached token
        # if `None` -> implicit use of the cached token
        self.huggingface_token = get_token()

    if clone_from is not None:
        self.clone_from(repo_url=clone_from)
    else:
        if is_git_repo(self.local_dir):
            logger.debug("[Repository] is a valid git repo")
        else:
            raise ValueError("If not specifying `clone_from`, you need to pass Repository a valid git clone.")

    # When a token is available, fill in whichever of user/email was not given explicitly
    # from the account the token belongs to.
    if self.huggingface_token is not None and (git_email is None or git_user is None):
        user = self.client.whoami(self.huggingface_token)

        if git_email is None:
            git_email = user.get("email")

        if git_user is None:
            git_user = user.get("fullname")

    if git_user is not None or git_email is not None:
        self.git_config_username_and_email(git_user, git_email)

    self.lfs_enable_largefiles()
    self.git_credential_helper_store()

    if revision is not None:
        self.git_checkout(revision, create_branch_ok=True)

    # This ensures that all commands exit before exiting the Python runtime.
    # This will ensure all pushes register on the hub, even if other errors happen in subsequent operations.
    atexit.register(self.wait_for_commands)
562
+
563
@property
def current_branch(self) -> str:
    """
    Name of the branch that is currently checked out.

    Returns:
        `str`: Current checked out branch.
    """
    try:
        completed = run_subprocess("git rev-parse --abbrev-ref HEAD", self.local_dir)
        branch_name = completed.stdout.strip()
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    return branch_name
577
+
578
def check_git_versions(self):
    """
    Checks that `git` and `git-lfs` can be run, and logs their versions.

    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `git` or `git-lfs` are not installed.
    """

    def _version_of(command: str, missing_message: str) -> str:
        # A missing executable surfaces as `FileNotFoundError` when the command is launched.
        try:
            return run_subprocess(command, self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(missing_message)

    git_version = _version_of(
        "git --version",
        "Looks like you do not have git installed, please install.",
    )
    lfs_version = _version_of(
        "git-lfs --version",
        "Looks like you do not have git-lfs installed, please install."
        " You can install from https://git-lfs.github.com/."
        " Then run `git lfs install` (you only have to do this once).",
    )
    logger.info(git_version + "\n" + lfs_version)
600
+
601
@validate_hf_hub_args
def clone_from(self, repo_url: str, token: Union[bool, str, None] = None):
    """
    Clone from a remote. If the folder already exists, will try to clone the
    repository within it.

    If the folder is not empty, it must already be a git repository that is
    a clone of the same remote: nothing is cloned in that case and a warning
    reminds to pull the latest changes.

    Args:
        repo_url (`str`):
            The URL from which to clone the repository. A `repo_id` is
            also accepted and is expanded to a URL on the Hub endpoint of
            the client.
        token (`Union[str, bool]`, *optional*):
            Whether to use the authentication token. It can be:
             - a string which is the token itself
             - `False`, which would not use the authentication token
             - `True`, which would fetch the authentication token from the
               local folder and use it (you should be logged in for this to
               work).
             - `None`, which would retrieve the value of
               `self.huggingface_token`.

    <Tip>

    Raises the following error:

        - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
          if an organization token (starts with "api_org") is passed. You must use
          your own personal access token (see https://hf.co/settings/tokens).

        - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
          if you are trying to clone the repository in a non-empty folder, or if the
          `git` operations raise errors.

    </Tip>
    """
    token = (
        token  # str -> use it
        if isinstance(token, str)
        else (
            None  # `False` -> explicit no token
            if token is False
            else self.huggingface_token  # `None` or `True` -> use default
        )
    )
    if token is not None and token.startswith("api_org"):
        raise ValueError(
            "You must use your personal access token, not an Organization token"
            " (see https://hf.co/settings/tokens)."
        )

    hub_url = self.client.endpoint
    # Hub URLs and bare repo ids are rebuilt as `<endpoint>/<type prefix><repo_id>`;
    # any other URL is passed to git untouched.
    if hub_url in repo_url or ("http" not in repo_url and len(repo_url.split("/")) <= 2):
        repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(repo_url, hub_url=hub_url)
        repo_id = f"{namespace}/{repo_name}" if namespace is not None else repo_name

        if repo_type is not None:
            self._repo_type = repo_type

        repo_url = hub_url + "/"

        if self._repo_type in constants.REPO_TYPES_URL_PREFIXES:
            repo_url += constants.REPO_TYPES_URL_PREFIXES[self._repo_type]

        if token is not None:
            # Add token in git url when provided
            scheme = urlparse(repo_url).scheme
            repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@")

        repo_url += repo_id

    # For error messages, it's cleaner to show the repo url without the token.
    clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url)
    try:
        run_subprocess("git lfs install", self.local_dir)

        # checks if repository is initialized in a empty repository or in one with files
        if len(os.listdir(self.local_dir)) == 0:
            logger.warning(f"Cloning {clean_repo_url} into local empty directory.")

            with _lfs_log_progress():
                env = os.environ.copy()

                if self.skip_lfs_files:
                    env.update({"GIT_LFS_SKIP_SMUDGE": "1"})

                run_subprocess(
                    # 'git lfs clone' is deprecated (will display a warning in the terminal)
                    # but we still use it as it provides a nicer UX when downloading large
                    # files (shows progress).
                    f"{'git clone' if self.skip_lfs_files else 'git lfs clone'} {repo_url} .",
                    self.local_dir,
                    env=env,
                )
        else:
            # Check if the folder is the root of a git repository
            if not is_git_repo(self.local_dir):
                raise EnvironmentError(
                    "Tried to clone a repository in a non-empty folder that isn't"
                    f" a git repository ('{self.local_dir}'). If you really want to"
                    f" do this, do it manually:\n cd {self.local_dir} && git init"
                    " && git remote add origin && git pull origin main\n or clone"
                    " repo to a new folder and move your existing files there"
                    " afterwards."
                )

            if is_local_clone(self.local_dir, repo_url):
                logger.warning(
                    f"{self.local_dir} is already a clone of {clean_repo_url}."
                    " Make sure you pull the latest changes with"
                    " `repo.git_pull()`."
                )
            else:
                # `check=False`: a missing `origin` remote must not hide the error raised below.
                output = run_subprocess("git remote get-url origin", self.local_dir, check=False)

                error_msg = (
                    f"Tried to clone {clean_repo_url} in an unrelated git"
                    " repository.\nIf you believe this is an error, please add"
                    f" a remote with the following URL: {clean_repo_url}."
                )
                if output.returncode == 0:
                    clean_local_remote_url = re.sub(r"https://.*@", "https://", output.stdout)
                    error_msg += f"\nLocal path has its origin defined as: {clean_local_remote_url}"
                raise EnvironmentError(error_msg)

    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)
728
+
729
def git_config_username_and_email(self, git_user: Optional[str] = None, git_email: Optional[str] = None):
    """
    Sets git username and email (only in the current repo).

    Args:
        git_user (`str`, *optional*):
            The username to register through `git`. Left untouched if `None`.
        git_email (`str`, *optional*):
            The email to register through `git`. Left untouched if `None`.

    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If one of the `git config` commands fails.
    """
    try:
        if git_user is not None:
            run_subprocess("git config user.name".split() + [git_user], self.local_dir)

        if git_email is not None:
            # Pass the value as its own argument (as done for the user name) so that it is
            # never split on whitespace.
            run_subprocess("git config user.email".split() + [git_email], self.local_dir)
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)
747
+
748
def git_credential_helper_store(self):
    """
    Configures `store` as the git credential helper of the repository.
    """
    command = "git config credential.helper store"

    try:
        run_subprocess(command, self.local_dir)
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)
756
+
757
def git_head_hash(self) -> str:
    """
    Get the sha of the commit that `HEAD` points to.

    Returns:
        `str`: The current checked out commit SHA.
    """
    try:
        completed = run_subprocess("git rev-parse HEAD", self.local_dir)
        sha = completed.stdout.strip()
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    return sha
769
+
770
def git_remote_url(self) -> str:
    """
    Get the URL of the `origin` remote, without basic auth info.

    Returns:
        `str`: The URL of the `origin` remote.
    """
    try:
        completed = run_subprocess("git config --get remote.origin.url", self.local_dir)
        origin = completed.stdout.strip()
        # Strip basic auth info.
        sanitized = re.sub(r"https://.*@", "https://", origin)
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    return sanitized
784
+
785
def git_head_commit_url(self) -> str:
    """
    Build the URL of the last commit on HEAD. The commit is assumed to have
    been pushed, and the url scheme to be the one used by GitHub or
    HuggingFace.

    Returns:
        `str`: The URL to the current checked-out commit.
    """
    sha = self.git_head_hash()
    remote = self.git_remote_url()
    # Drop a single trailing slash so the path is not doubled.
    base = remote[:-1] if remote.endswith("/") else remote
    return f"{base}/commit/{sha}"
798
+
799
def list_deleted_files(self) -> List[str]:
    """
    Returns a list of the files that are deleted in the working directory or
    index.

    The list is extracted from the output of `git status -s`: every entry
    whose status code contains `D` is reported.

    Returns:
        `List[str]`: A list of files that have been deleted in the working
        directory or index.

    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `git status` fails.
    """
    try:
        git_status = run_subprocess("git status -s", self.local_dir).stdout.strip()
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    if not git_status:
        return []

    # Receives a status like the following
    #  D .gitignore
    # D new_file.json
    # AD new_file1.json
    # ?? new_file2.json
    # ?? new_file4.json

    deleted_files = []
    for line in git_status.split("\n"):
        # The first token is the status code, everything after it is the path.
        parts = line.strip().split(maxsplit=1)
        if len(parts) < 2 or "D" not in parts[0]:
            continue

        # Renames are displayed as `old -> new`: keep the last path, as before.
        path = parts[1].rsplit(" -> ", 1)[-1].strip()

        # git wraps paths containing whitespace in double quotes: remove them so the whole
        # filename is returned instead of its last word.
        if len(path) >= 2 and path[0] == path[-1] == '"':
            path = path[1:-1]

        deleted_files.append(path)

    return deleted_files
833
+
834
def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False):
    """
    Tell git-lfs to track files according to a pattern.

    Setting the `filename` argument to `True` will treat the arguments as
    literal filenames, not as patterns. Any special glob characters in the
    filename will be escaped when writing to the `.gitattributes` file.

    Args:
        patterns (`Union[str, List[str]]`):
            The pattern, or list of patterns, to track with git-lfs.
        filename (`bool`, *optional*, defaults to `False`):
            Whether to use the patterns as literal filenames.

    Raises:
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If one of the `git lfs track` commands fails.
    """
    if isinstance(patterns, str):
        patterns = [patterns]

    command = ["git", "lfs", "track"]
    if filename:
        command.append("--filename")

    try:
        for pattern in patterns:
            # Each pattern is passed as its own argument (same approach as `lfs_untrack`), so a
            # pattern containing whitespace is not broken into several arguments.
            run_subprocess(command + [pattern], self.local_dir)
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)
858
+
859
def lfs_untrack(self, patterns: Union[str, List[str]]):
    """
    Stop tracking the given pattern(s) with git-lfs.

    Runs `git lfs untrack <pattern>` once per pattern in the local directory.

    Args:
        patterns (`Union[str, List[str]]`):
            The pattern, or list of patterns, to untrack with git-lfs.

    Raises:
        `EnvironmentError`: If a `git lfs untrack` command fails.
    """
    # A single pattern is handled as a one-element list.
    targets = [patterns] if isinstance(patterns, str) else patterns
    untrack_command = ["git", "lfs", "untrack"]

    try:
        for target in targets:
            run_subprocess(untrack_command + [target], self.local_dir)
    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)
874
+
875
def lfs_enable_largefiles(self):
    """
    HF-specific. This enables upload support of files >5GB.

    Sets the `lfs.customtransfer.multipart.path` and
    `lfs.customtransfer.multipart.args` git config entries of the local
    repository.

    Raises:
        `EnvironmentError`: If one of the `git config` commands fails.
    """
    config_prefix = "git config lfs.customtransfer.multipart"

    try:
        # First the executable, then the arguments it is invoked with.
        for key, value in (
            ("path", "huggingface-cli"),
            ("args", LFS_MULTIPART_UPLOAD_COMMAND),
        ):
            run_subprocess(f"{config_prefix}.{key} {value}", self.local_dir)
    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)
888
+
889
def auto_track_binary_files(self, pattern: str = ".") -> List[str]:
    """
    Automatically track binary files with git-lfs.

    Every file about to be staged for `pattern` is inspected, except files
    that are deleted, already tracked with git-lfs, or git-ignored. Files
    detected as binary are tracked with git-lfs. Deleted files are untracked
    afterwards.

    Args:
        pattern (`str`, *optional*, defaults to "."):
            The pattern with which to track files that are binary.

    Returns:
        `List[str]`: List of filenames that are now tracked due to being
        binary files
    """
    newly_tracked = []

    deleted_files = self.list_deleted_files()

    for filename in files_to_be_staged(pattern, folder=self.local_dir):
        # A deleted file cannot be inspected on disk.
        if filename in deleted_files:
            continue

        path_to_file = os.path.join(os.getcwd(), self.local_dir, filename)

        # Nothing to do for files git-lfs already handles or git ignores.
        if is_tracked_with_lfs(path_to_file) or is_git_ignored(path_to_file):
            continue

        if os.path.getsize(path_to_file) / (1024 * 1024) >= 10:
            logger.warning(
                "Parsing a large file to check if binary or not. Tracking large"
                " files using `repository.auto_track_large_files` is"
                " recommended so as to not load the full file in memory."
            )

        if is_binary_file(path_to_file):
            self.lfs_track(filename)
            newly_tracked.append(filename)

    # Cleanup the .gitattributes if files were deleted
    self.lfs_untrack(deleted_files)

    return newly_tracked
931
+
932
def auto_track_large_files(self, pattern: str = ".") -> List[str]:
    """
    Automatically track large files (files that weigh more than 10MBs) with
    git-lfs.

    Deleted files are skipped during the scan and untracked from git-lfs
    afterwards.

    Args:
        pattern (`str`, *optional*, defaults to "."):
            The pattern with which to track files that are above 10MBs.

    Returns:
        `List[str]`: List of filenames that are now tracked due to their
        size.
    """
    newly_tracked = []

    deleted_files = self.list_deleted_files()

    for filename in files_to_be_staged(pattern, folder=self.local_dir):
        # A deleted file has no size on disk.
        if filename in deleted_files:
            continue

        path_to_file = os.path.join(os.getcwd(), self.local_dir, filename)
        size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024)

        # Small files stay in regular git.
        if size_in_mb < 10:
            continue

        # Already handled by git-lfs, or ignored by git altogether.
        if is_tracked_with_lfs(path_to_file) or is_git_ignored(path_to_file):
            continue

        self.lfs_track(filename)
        newly_tracked.append(filename)

    # Cleanup the .gitattributes if files were deleted
    self.lfs_untrack(deleted_files)

    return newly_tracked
964
+
965
def lfs_prune(self, recent=False):
    """
    git lfs prune

    Runs the command in the local directory and logs its standard output.

    Args:
        recent (`bool`, *optional*, defaults to `False`):
            Whether to prune files even if they were referenced by recent
            commits. See the following
            [link](https://github.com/git-lfs/git-lfs/blob/f3d43f0428a84fc4f1e5405b76b5a73ec2437e65/docs/man/git-lfs-prune.1.ronn#recent-files)
            for more information.

    Raises:
        `EnvironmentError`: If the `git lfs prune` command fails.
    """
    recent_flag = "--recent" if recent else ""

    try:
        with _lfs_log_progress():
            outcome = run_subprocess(f"git lfs prune {recent_flag}", self.local_dir)
            logger.info(outcome.stdout)
    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)
982
+
983
def git_pull(self, rebase: bool = False, lfs: bool = False):
    """
    git pull

    Runs the command in the local directory and logs its standard output.

    Args:
        rebase (`bool`, *optional*, defaults to `False`):
            Whether to rebase the current branch on top of the upstream
            branch after fetching.
        lfs (`bool`, *optional*, defaults to `False`):
            Whether to fetch the LFS files too. This option only changes the
            behavior when a repository was cloned without fetching the LFS
            files; calling `repo.git_pull(lfs=True)` will then fetch the LFS
            file from the remote repository.

    Raises:
        `EnvironmentError`: If the pull command fails.
    """
    pull_command = "git lfs pull" if lfs else "git pull"

    # NOTE(review): `--rebase` is appended for `git lfs pull` as well; confirm
    # that git-lfs accepts this flag.
    if rebase:
        pull_command = pull_command + " --rebase"

    try:
        with _lfs_log_progress():
            outcome = run_subprocess(pull_command, self.local_dir)
            logger.info(outcome.stdout)
    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)
1006
+
1007
def git_add(self, pattern: str = ".", auto_lfs_track: bool = False):
    """
    git add

    Setting the `auto_lfs_track` parameter to `True` will automatically
    track files that are larger than 10MB, as well as binary files, with
    `git-lfs` before they are added to the index.

    Args:
        pattern (`str`, *optional*, defaults to "."):
            The pattern with which to add files to staging.
        auto_lfs_track (`bool`, *optional*, defaults to `False`):
            Whether to automatically track large and binary files with
            git-lfs. Any file over 10MB in size, or in binary format, will
            be automatically tracked.

    Raises:
        `EnvironmentError`: If the `git add` command fails.
    """
    if auto_lfs_track:
        # Size-based tracking (>=10MB) runs first, then the remaining files
        # are read and tracked if they are binary.
        large_files = self.auto_track_large_files(pattern)
        binary_files = self.auto_track_binary_files(pattern)
        tracked_files = large_files + binary_files

        if tracked_files:
            logger.warning(
                f"Adding files tracked by Git LFS: {tracked_files}. This may take a"
                " bit of time if the files are large."
            )

    add_command = ["git", "add", "-v", pattern]

    try:
        outcome = run_subprocess(add_command, self.local_dir)
        logger.info(f"Adding to index:\n{outcome.stdout}\n")
    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)
1040
+
1041
def git_commit(self, commit_message: str = "commit files to HF hub"):
    """
    git commit

    Args:
        commit_message (`str`, *optional*, defaults to "commit files to HF hub"):
            The message attributed to the commit.

    Raises:
        `EnvironmentError`: If the `git commit` command fails. The error
        carries the command's stderr, or its stdout when stderr is empty.
    """
    commit_command = ["git", "commit", "-v", "-m", commit_message]

    try:
        outcome = run_subprocess(commit_command, self.local_dir)
        logger.info(f"Committed:\n{outcome.stdout}\n")
    except subprocess.CalledProcessError as error:
        # Fall back to stdout when stderr carries no text.
        details = error.stderr if len(error.stderr) > 0 else error.stdout
        raise EnvironmentError(details)
1057
+
1058
def git_push(
    self,
    upstream: Optional[str] = None,
    blocking: bool = True,
    auto_lfs_prune: bool = False,
) -> Union[str, Tuple[str, CommandInProgress]]:
    """
    git push

    With `blocking=True` (the default), waits for the push to finish and
    returns the url to the commit on the remote repo. With `blocking=False`,
    returns immediately with a tuple containing the url to the commit and the
    command object to follow for information about the process.

    Args:
        upstream (`str`, *optional*):
            Upstream to which this should push. If not specified, will push
            to the lastly defined upstream or to the default one (`origin
            main`).
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the push has
            finished. Setting this to `False` will return an
            `CommandInProgress` object which has an `is_done` property. This
            property will be set to `True` when the push is finished.
        auto_lfs_prune (`bool`, *optional*, defaults to `False`):
            Whether to automatically prune files once they have been pushed
            to the remote.

    Raises:
        `EnvironmentError`: If a blocking push exits with a non-zero code.
    """
    push_command = "git push" + (f" --set-upstream {upstream}" if upstream else "")

    pending_commits = commits_to_push(self.local_dir, upstream)

    if pending_commits > 1:
        logger.warning(f"Several commits ({pending_commits}) will be pushed upstream.")
        if blocking:
            logger.warning("The progress bars may be unreliable.")

    try:
        with _lfs_log_progress():
            process = subprocess.Popen(
                push_command.split(),
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                encoding="utf-8",
                cwd=self.local_dir,
            )

            if blocking:
                # Wait for completion and collect both output streams.
                stdout, stderr = process.communicate()
                return_code = process.poll()
                process.kill()

                if len(stderr):
                    logger.warning(stderr)

                if return_code:
                    raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr)

    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)

    if blocking:
        if auto_lfs_prune:
            self.lfs_prune()

        return self.git_head_commit_url()

    # Non-blocking push: hand the running process over to a
    # `CommandInProgress` object and register it in the command queue.
    def status_method():
        # -1 is reported while the process has no exit code yet.
        exit_code = process.poll()
        return -1 if exit_code is None else exit_code

    command_in_progress = CommandInProgress(
        "push",
        is_done_method=lambda: process.poll() is not None,
        status_method=status_method,
        process=process,
        post_method=self.lfs_prune if auto_lfs_prune else None,
    )

    self.command_queue.append(command_in_progress)

    return self.git_head_commit_url(), command_in_progress
1147
+
1148
def git_checkout(self, revision: str, create_branch_ok: bool = False):
    """
    git checkout a given revision

    Specifying `create_branch_ok` to `True` will create the branch to the
    given revision if that revision doesn't exist.

    Args:
        revision (`str`):
            The revision to checkout.
        create_branch_ok (`bool`, *optional*, defaults to `False`):
            Whether creating a branch named with the `revision` passed at
            the current checked-out reference if `revision` isn't an
            existing revision is allowed.

    Raises:
        `EnvironmentError`: If the checkout fails and `create_branch_ok` is
        `False`, or if creating the branch fails as well.
    """
    try:
        outcome = run_subprocess(f"git checkout {revision}", self.local_dir)
        # NOTE(review): `self.current_branch` is read after the checkout ran,
        # so this presumably logs the new branch rather than the previous one
        # - confirm.
        logger.warning(f"Checked out {revision} from {self.current_branch}.")
        logger.warning(outcome.stdout)
    except subprocess.CalledProcessError as checkout_error:
        if not create_branch_ok:
            raise EnvironmentError(checkout_error.stderr)

        # The plain checkout failed: fall back to creating the branch.
        try:
            outcome = run_subprocess(f"git checkout -b {revision}", self.local_dir)
            logger.warning(
                f"Revision `{revision}` does not exist. Created and checked out branch `{revision}`."
            )
            logger.warning(outcome.stdout)
        except subprocess.CalledProcessError as branch_error:
            raise EnvironmentError(branch_error.stderr)
1179
+
1180
def tag_exists(self, tag_name: str, remote: Optional[str] = None) -> bool:
    """
    Check if a tag exists or not.

    Args:
        tag_name (`str`):
            The name of the tag to check.
        remote (`str`, *optional*):
            Whether to check if the tag exists on a remote. This parameter
            should be the identifier of the remote.

    Returns:
        `bool`: Whether the tag exists.

    Raises:
        `EnvironmentError`: If the underlying git command fails.
    """
    if remote:
        try:
            # Query the remote that was passed in rather than a hard-coded
            # `origin`, so tags on other remotes are looked up correctly.
            result = run_subprocess(
                ["git", "ls-remote", remote, f"refs/tags/{tag_name}"],
                self.local_dir,
            ).stdout.strip()
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)

        # `git ls-remote` prints nothing when no ref matches.
        return len(result) != 0

    try:
        git_tags = run_subprocess("git tag", self.local_dir).stdout.strip()
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    # `git tag` lists one tag per line.
    return tag_name in git_tags.split("\n")
1209
+
1210
def delete_tag(self, tag_name: str, remote: Optional[str] = None) -> bool:
    """
    Delete a tag, both local and remote, if it exists

    Args:
        tag_name (`str`):
            The tag name to delete.
        remote (`str`, *optional*):
            The remote on which to delete the tag.

    Returns:
        `bool`: `True` if deleted, `False` if the tag didn't exist.
        If remote is not passed, will just be updated locally

    Raises:
        `EnvironmentError`: If one of the git commands fails.
    """
    exists_locally = self.tag_exists(tag_name)
    # The remote is only queried when one was given.
    exists_remotely = bool(remote) and self.tag_exists(tag_name, remote=remote)

    # Tracks whether anything was actually deleted, so the documented `False`
    # is returned when the tag did not exist anywhere.
    deleted = False

    if exists_locally:
        try:
            run_subprocess(["git", "tag", "-d", tag_name], self.local_dir)
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)
        deleted = True

    if exists_remotely:
        try:
            run_subprocess(["git", "push", remote, "--delete", tag_name], self.local_dir)
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)
        deleted = True

    return deleted
1246
+
1247
def add_tag(self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None):
    """
    Add a tag at the current head and push it

    If remote is None, will just be updated locally

    If no message is provided, the tag will be lightweight. If a message is
    provided, the tag will be annotated.

    Args:
        tag_name (`str`):
            The name of the tag to be added.
        message (`str`, *optional*):
            The message that accompanies the tag. The tag will turn into an
            annotated tag if a message is passed.
        remote (`str`, *optional*):
            The remote on which to add the tag.

    Raises:
        `EnvironmentError`: If creating or pushing the tag fails.
    """
    # Annotated tag when a message is given, lightweight tag otherwise.
    tag_command = ["git", "tag"]
    tag_command += ["-a", tag_name, "-m", message] if message else [tag_name]

    try:
        run_subprocess(tag_command, self.local_dir).stdout.strip()
    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)

    if not remote:
        return

    try:
        run_subprocess(f"git push {remote} {tag_name}", self.local_dir).stdout.strip()
    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)
1280
+
1281
def is_repo_clean(self) -> bool:
    """
    Return whether or not the git status is clean or not

    The repository is considered clean when `git status --porcelain` prints
    nothing.

    Returns:
        `bool`: `True` if the git status is clean, `False` otherwise.

    Raises:
        `EnvironmentError`: If the `git status` command fails.
    """
    try:
        outcome = run_subprocess("git status --porcelain", self.local_dir)
        porcelain_output = outcome.stdout.strip()
    except subprocess.CalledProcessError as error:
        raise EnvironmentError(error.stderr)

    return not porcelain_output
1294
+
1295
def push_to_hub(
    self,
    commit_message: str = "commit files to HF hub",
    blocking: bool = True,
    clean_ok: bool = True,
    auto_lfs_prune: bool = False,
) -> Union[None, str, Tuple[str, CommandInProgress]]:
    """
    Helper to add, commit, and push files to remote repository on the
    HuggingFace Hub. Will automatically track large files (>10MB).

    Args:
        commit_message (`str`):
            Message to use for the commit.
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the `git push` has
            finished.
        clean_ok (`bool`, *optional*, defaults to `True`):
            If True, this function will return None if the repo is
            untouched. If False, the add/commit/push sequence is attempted
            even on a clean repo.
        auto_lfs_prune (`bool`, *optional*, defaults to `False`):
            Whether to automatically prune files once they have been pushed
            to the remote.

    Returns:
        `None` when nothing was pushed because the repo is clean, otherwise
        the value returned by `git_push`.
    """
    nothing_to_push = clean_ok and self.is_repo_clean()
    if nothing_to_push:
        logger.info("Repo currently clean. Ignoring push_to_hub")
        return None

    # Stage everything (with automatic LFS tracking), commit, then push to
    # the current branch on `origin`.
    self.git_add(auto_lfs_track=True)
    self.git_commit(commit_message)

    target_upstream = f"origin {self.current_branch}"
    return self.git_push(upstream=target_upstream, blocking=blocking, auto_lfs_prune=auto_lfs_prune)
1330
+
1331
@contextmanager
def commit(
    self,
    commit_message: str,
    branch: Optional[str] = None,
    track_large_files: bool = True,
    blocking: bool = True,
    auto_lfs_prune: bool = False,
):
    """
    Context manager utility to handle committing to a repository. This
    automatically tracks large files (>10Mb) with git-lfs. Set the
    `track_large_files` argument to `False` if you wish to ignore that
    behavior.

    The body of the `with` statement runs with the current working directory
    set to the local repository. When the body exits (normally or with an
    exception), the changes are added, committed and pushed, and the previous
    working directory is always restored.

    Args:
        commit_message (`str`):
            Message to use for the commit.
        branch (`str`, *optional*):
            The branch on which the commit will appear. This branch will be
            checked-out before any operation.
        track_large_files (`bool`, *optional*, defaults to `True`):
            Whether to automatically track large files or not. Will do so by
            default.
        blocking (`bool`, *optional*, defaults to `True`):
            Whether the function should return only when the `git push` has
            finished.
        auto_lfs_prune (`bool`, *optional*, defaults to `False`):
            Whether to automatically prune files once they have been pushed
            to the remote.

    Raises:
        `OSError`: If committing fails for a reason other than "nothing to
        commit", or if pushing fails.

    Examples:

    ```python
    >>> with Repository(
    ...     "text-files",
    ...     clone_from="<user>/text-files",
    ...     token=True,
    >>> ).commit("My first file :)"):
    ...     with open("file.txt", "w+") as f:
    ...         f.write(json.dumps({"hey": 8}))

    >>> import torch

    >>> model = torch.nn.Transformer()
    >>> with Repository(
    ...     "torch-model",
    ...     clone_from="<user>/torch-model",
    ...     token=True,
    >>> ).commit("My cool model :)"):
    ...     torch.save(model.state_dict(), "model.pt")
    ```

    """

    files_to_stage = files_to_be_staged(".", folder=self.local_dir)

    if len(files_to_stage):
        # Only the first five files are listed to keep the message short.
        files_in_msg = (
            (str(files_to_stage[:5])[:-1] + ", ...]") if len(files_to_stage) > 5 else str(files_to_stage)
        )
        logger.error(
            "There exists some updated files in the local repository that are not"
            f" committed: {files_in_msg}. This may lead to errors if checking out"
            " a branch. These files and their modifications will be added to the"
            " current commit."
        )

    if branch is not None:
        self.git_checkout(branch, create_branch_ok=True)

    if is_tracked_upstream(self.local_dir):
        logger.warning("Pulling changes ...")
        self.git_pull(rebase=True)
    else:
        logger.warning(f"The current branch has no upstream branch. Will push to 'origin {self.current_branch}'")

    current_working_directory = os.getcwd()
    os.chdir(os.path.join(current_working_directory, self.local_dir))

    try:
        yield self
    finally:
        # The inner try/finally guarantees that the working directory is
        # restored even when add/commit/push raises.
        try:
            self.git_add(auto_lfs_track=track_large_files)

            try:
                self.git_commit(commit_message)
            except OSError as e:
                # If no changes are detected, there is nothing to commit.
                if "nothing to commit" not in str(e):
                    raise

            try:
                self.git_push(
                    upstream=f"origin {self.current_branch}",
                    blocking=blocking,
                    auto_lfs_prune=auto_lfs_prune,
                )
            except OSError as e:
                # Turn a credential prompt failure into an actionable message.
                if "could not read Username" in str(e):
                    raise OSError("Couldn't authenticate user for push. Did you set `token` to `True`?") from e
                raise
        finally:
            os.chdir(current_working_directory)
1435
+
1436
def repocard_metadata_load(self) -> Optional[Dict]:
    """
    Load the metadata of the repo card stored in the local directory.

    Returns:
        `Optional[Dict]`: The result of `metadata_load` on the repo card
        file, or `None` if that file does not exist.
    """
    card_path = os.path.join(self.local_dir, constants.REPOCARD_NAME)
    if not os.path.isfile(card_path):
        return None
    return metadata_load(card_path)
1441
+
1442
def repocard_metadata_save(self, data: Dict) -> None:
    """
    Save `data` as the metadata of the repo card stored in the local
    directory.

    Args:
        data (`Dict`):
            The metadata handed over to `metadata_save`.
    """
    card_path = os.path.join(self.local_dir, constants.REPOCARD_NAME)
    return metadata_save(card_path, data)
1444
+
1445
@property
def commands_failed(self):
    """
    Returns the asynchronous commands that failed.

    A command counts as failed when its `status` is strictly positive.
    """
    failed = []
    for command in self.command_queue:
        if command.status > 0:
            failed.append(command)
    return failed
1451
+
1452
@property
def commands_in_progress(self):
    """
    Returns the asynchronous commands that are currently in progress.

    A command is in progress as long as its `is_done` attribute is falsy.
    """
    pending = []
    for command in self.command_queue:
        if command.is_done:
            continue
        pending.append(command)
    return pending
1458
+
1459
def wait_for_commands(self):
    """
    Blocking method: blocks all subsequent execution until all commands have
    been processed.

    Commands that already failed are logged first. The method then polls the
    commands in progress once per second, logging a reminder every tenth
    iteration.
    """
    for failed in self.commands_failed:
        logger.error(f"The {failed.title} command with PID {failed._process.pid} failed.")
        logger.error(failed.stderr)

    polls = 0
    while self.commands_in_progress:
        # Only log on the first and then every tenth iteration.
        if polls % 10 == 0:
            logger.warning(
                f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}."
            )

        polls += 1

        time.sleep(1)