davanstrien HF staff commited on
Commit
5721477
·
1 Parent(s): b558f4f

schedule collections refresh

Browse files
Files changed (2) hide show
  1. create_collections.py +128 -0
  2. main.py +20 -0
create_collections.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Iterator, List
2
+
3
+ import requests
4
+ from huggingface_hub import add_collection_item, create_collection
5
+ from tqdm.auto import tqdm
6
+
7
+
8
+ class DatasetSearchClient:
9
+ def __init__(
10
+ self,
11
+ base_url: str = "https://librarian-bots-dataset-column-search-api.hf.space",
12
+ ):
13
+ self.base_url = base_url
14
+
15
+ def search(
16
+ self, columns: List[str], match_all: bool = False, page_size: int = 100
17
+ ) -> Iterator[Dict[str, Any]]:
18
+ """
19
+ Search datasets using the provided API, automatically handling pagination.
20
+
21
+ Args:
22
+ columns (List[str]): List of column names to search for.
23
+ match_all (bool, optional): If True, match all columns. If False, match any column. Defaults to False.
24
+ page_size (int, optional): Number of results per page. Defaults to 100.
25
+
26
+ Yields:
27
+ Dict[str, Any]: Each dataset result from all pages.
28
+
29
+ Raises:
30
+ requests.RequestException: If there's an error with the HTTP request.
31
+ ValueError: If the API returns an unexpected response format.
32
+ """
33
+ page = 1
34
+ total_results = None
35
+
36
+ while total_results is None or (page - 1) * page_size < total_results:
37
+ params = {
38
+ "columns": columns,
39
+ "match_all": str(match_all).lower(),
40
+ "page": page,
41
+ "page_size": page_size,
42
+ }
43
+
44
+ try:
45
+ response = requests.get(f"{self.base_url}/search", params=params)
46
+ response.raise_for_status()
47
+ data = response.json()
48
+
49
+ if not {"total", "page", "page_size", "results"}.issubset(data.keys()):
50
+ raise ValueError("Unexpected response format from the API")
51
+
52
+ if total_results is None:
53
+ total_results = data["total"]
54
+
55
+ yield from data["results"]
56
+ page += 1
57
+
58
+ except requests.RequestException as e:
59
+ raise requests.RequestException(
60
+ f"Error connecting to the API: {str(e)}"
61
+ ) from e
62
+ except ValueError as e:
63
+ raise ValueError(f"Error processing API response: {str(e)}") from e
64
+
65
+
66
+ # Create an instance of the client
67
+ client = DatasetSearchClient()
68
+
69
+
70
+ def update_collection_for_dataset(
71
+ collection_name: str = None,
72
+ dataset_columns: List[str] = None,
73
+ collection_description: str = None,
74
+ collection_namespace: str = None,
75
+ ):
76
+ if not collection_name:
77
+ collection = create_collection(
78
+ collection_name, exists_ok=True, description=collection_description
79
+ )
80
+ else:
81
+ collection = create_collection(
82
+ collection_name,
83
+ exists_ok=True,
84
+ description=collection_description,
85
+ namespace=collection_namespace,
86
+ )
87
+ results = list(
88
+ tqdm(
89
+ client.search(dataset_columns, match_all=True),
90
+ desc="Searching datasets...",
91
+ leave=False,
92
+ )
93
+ )
94
+ for result in tqdm(results, desc="Adding datasets to collection...", leave=False):
95
+ try:
96
+ add_collection_item(
97
+ collection.slug, result["hub_id"], item_type="dataset", exists_ok=True
98
+ )
99
+ except Exception as e:
100
+ print(
101
+ f"Error adding dataset {result['hub_id']} to collection {collection_name}: {str(e)}"
102
+ )
103
+ return f"https://huggingface.co/collections/{collection.slug}"
104
+
105
+
106
+ collections = [
107
+ {
108
+ "dataset_columns": ["chosen", "rejected", "prompt"],
109
+ "collection_description": "Datasets suitable for Direct Preference Optimization based on having 'chosen', 'rejected', and 'prompt' columns",
110
+ "collection_name": "Direct Preference Optimization Datasets",
111
+ },
112
+ {
113
+ "dataset_columns": ["image", "chosen", "rejected"],
114
+ "collection_description": "Datasets suitable for Image Preference Optimization based on having 'image','chosen', and 'rejected' columns",
115
+ "collection_name": "Image Preference Optimization Datasets",
116
+ },
117
+ {
118
+ "collection_name": "Alpaca Style Datasets",
119
+ "dataset_columns": ["instruction", "input", "output"],
120
+ "collection_description": "Datasets which follow the Alpaca Style format based on having 'instruction', 'input', and 'output' columns",
121
+ },
122
+ ]
123
+
124
+ results = [
125
+ update_collection_for_dataset(**collection, collection_namespace="librarian-bots")
126
+ for collection in collections
127
+ ]
128
+ print(results)
main.py CHANGED
@@ -17,6 +17,7 @@ from pandas import Timestamp
17
  from pydantic import BaseModel
18
  from starlette.responses import RedirectResponse
19
 
 
20
  from data_loader import refresh_data
21
 
22
  login(token=os.getenv("HF_TOKEN"))
@@ -163,6 +164,23 @@ async def update_database():
163
  logger.error(f"Error uploading database file to Hugging Face Hub: {str(e)}")
164
 
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  @asynccontextmanager
167
  async def lifespan(app: FastAPI):
168
  setup_database()
@@ -173,6 +191,8 @@ async def lifespan(app: FastAPI):
173
  scheduler = AsyncIOScheduler()
174
  # Schedule the update_database function using the UPDATE_SCHEDULE configuration
175
  scheduler.add_job(update_database, CronTrigger(**UPDATE_SCHEDULE))
 
 
176
  scheduler.start()
177
 
178
  yield
 
17
  from pydantic import BaseModel
18
  from starlette.responses import RedirectResponse
19
 
20
+ from create_collections import collections, update_collection_for_dataset
21
  from data_loader import refresh_data
22
 
23
  login(token=os.getenv("HF_TOKEN"))
 
164
  logger.error(f"Error uploading database file to Hugging Face Hub: {str(e)}")
165
 
166
 
167
+ async def update_collections():
168
+ logger.info("Starting scheduled collection update")
169
+ try:
170
+ for collection in collections:
171
+ result = await asyncio.get_event_loop().run_in_executor(
172
+ None,
173
+ update_collection_for_dataset,
174
+ collection["collection_name"],
175
+ collection["dataset_columns"],
176
+ collection["collection_description"],
177
+ "librarian-bots",
178
+ )
179
+ logger.info(f"Updated collection: {result}")
180
+ except Exception as e:
181
+ logger.error(f"Error during collection update: {str(e)}")
182
+
183
+
184
  @asynccontextmanager
185
  async def lifespan(app: FastAPI):
186
  setup_database()
 
191
  scheduler = AsyncIOScheduler()
192
  # Schedule the update_database function using the UPDATE_SCHEDULE configuration
193
  scheduler.add_job(update_database, CronTrigger(**UPDATE_SCHEDULE))
194
+ # Schedule the update_collections function to run daily at midnight
195
+ scheduler.add_job(update_collections, CronTrigger(hour=0, minute=0))
196
  scheduler.start()
197
 
198
  yield