Imane Momayiz committed
Commit c437348
1 Parent(s): ae41ba2

test commitscheduler

Files changed (1)
  1. src/components.py +140 -0
src/components.py ADDED
@@ -0,0 +1,140 @@
+ from huggingface_hub import HfApi, CommitScheduler
+ from typing import Any, Dict, List, Optional, Union
+ import uuid
+ from pathlib import Path
+ import json
+ import tempfile
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+
+
+ # ParquetScheduler: a CommitScheduler that uploads appended rows to the Hub as Parquet files
+ class ParquetScheduler(CommitScheduler):
+     """
+     Usage: configure the scheduler with a repo id. Once started, you can add data to be uploaded to the Hub.
+     One `.append` call results in one row in your final dataset.
+     ```py
+     # Start scheduler
+     >>> scheduler = ParquetScheduler(repo_id="my-parquet-dataset")
+     # Append some data to be uploaded
+     >>> scheduler.append({...})
+     >>> scheduler.append({...})
+     >>> scheduler.append({...})
+     ```
+     The scheduler will automatically infer the schema from the data it pushes.
+     Optionally, you can manually set the schema yourself:
+     ```py
+     >>> scheduler = ParquetScheduler(
+     ...     repo_id="my-parquet-dataset",
+     ...     schema={
+     ...         "prompt": {"_type": "Value", "dtype": "string"},
+     ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
+     ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
+     ...         "image": {"_type": "Image"},
+     ...     },
+     ... )
+     ```
+     See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
+     possible values.
+     """
+
+     def __init__(
+         self,
+         *,
+         repo_id: str,
+         schema: Optional[Dict[str, Dict[str, str]]] = None,
+         every: Union[int, float] = 5,
+         path_in_repo: Optional[str] = "data",
+         repo_type: Optional[str] = "dataset",
+         revision: Optional[str] = None,
+         private: bool = False,
+         token: Optional[str] = None,
+         allow_patterns: Union[List[str], str, None] = None,
+         ignore_patterns: Union[List[str], str, None] = None,
+         hf_api: Optional[HfApi] = None,
+     ) -> None:
+         super().__init__(
+             repo_id=repo_id,
+             folder_path="dummy",  # not used by the scheduler
+             every=every,
+             path_in_repo=path_in_repo,
+             repo_type=repo_type,
+             revision=revision,
+             private=private,
+             token=token,
+             allow_patterns=allow_patterns,
+             ignore_patterns=ignore_patterns,
+             hf_api=hf_api,
+         )
+
+         self._rows: List[Dict[str, Any]] = []
+         self._schema = schema
+
+     def append(self, row: Dict[str, Any]) -> None:
+         """Add a new item to be uploaded."""
+         with self.lock:
+             self._rows.append(row)
+
+     def push_to_hub(self):
+         # Check for new rows to push
+         with self.lock:
+             rows = self._rows
+             self._rows = []
+         if not rows:
+             return
+         print(f"Got {len(rows)} item(s) to commit.")
+
+         # Load images + create 'features' config for datasets library
+         schema: Dict[str, Dict] = self._schema or {}
+         path_to_cleanup: List[Path] = []
+         for row in rows:
+             for key, value in row.items():
+                 # Infer schema (for `datasets` library)
+                 if key not in schema:
+                     schema[key] = _infer_schema(key, value)
+
+                 # Load binary files if necessary
+                 if schema[key]["_type"] in ("Image", "Audio"):
+                     # It's an image or audio: we load the bytes and remember to cleanup the file
+                     file_path = Path(value)
+                     if file_path.is_file():
+                         row[key] = {
+                             "path": file_path.name,
+                             "bytes": file_path.read_bytes(),
+                         }
+                         path_to_cleanup.append(file_path)
+
+         # Complete rows if needed
+         for row in rows:
+             for feature in schema:
+                 if feature not in row:
+                     row[feature] = None
+
+         # Export items to Arrow format
+         table = pa.Table.from_pylist(rows)
+
+         # Add metadata (used by datasets library)
+         table = table.replace_schema_metadata(
+             {"huggingface": json.dumps({"info": {"features": schema}})}
+         )
+
+         # Write to parquet file
+         archive_file = tempfile.NamedTemporaryFile()
+         pq.write_table(table, archive_file.name)
+
+         # Upload
+         self.api.upload_file(
+             repo_id=self.repo_id,
+             repo_type=self.repo_type,
+             revision=self.revision,
+             path_in_repo=f"{uuid.uuid4()}.parquet",
+             path_or_fileobj=archive_file.name,
+         )
+         print("Commit completed.")
+
+         # Cleanup
+         archive_file.close()
+         for path in path_to_cleanup:
+             path.unlink(missing_ok=True)
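Note: `push_to_hub` calls an `_infer_schema` helper that does not appear anywhere in this diff, so the module as committed would raise a `NameError` unless that helper is defined elsewhere in the repo. A minimal sketch of what such a helper might look like, following the `datasets` feature-type convention referenced in the docstring (the key-name heuristics below are assumptions, not part of this commit):

```py
# Hypothetical helper, not part of this commit: maps a row key/value pair to a
# `datasets`-style feature descriptor, as expected by `push_to_hub` above.
def _infer_schema(key: str, value: Any) -> Dict[str, str]:
    """Infer the feature type for the `datasets` library from a sample value."""
    if "image" in key:
        return {"_type": "Image"}
    if "audio" in key:
        return {"_type": "Audio"}
    if isinstance(value, bool):  # check bool before int (bool is a subclass of int)
        return {"_type": "Value", "dtype": "bool"}
    if isinstance(value, int):
        return {"_type": "Value", "dtype": "int64"}
    if isinstance(value, float):
        return {"_type": "Value", "dtype": "float64"}
    if isinstance(value, bytes):
        return {"_type": "Value", "dtype": "binary"}
    # Fall back to string for anything else
    return {"_type": "Value", "dtype": "string"}
```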