File size: 3,661 Bytes
a7d1c2d
 
 
 
 
 
 
 
1f8e434
a7d1c2d
 
1f8e434
 
 
a7d1c2d
1f8e434
 
a7d1c2d
1f8e434
 
a7d1c2d
1f8e434
 
a7d1c2d
 
9bfc13e
 
1f8e434
a7d1c2d
 
1f8e434
a7d1c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f8e434
a7d1c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
import time
import os

import pandas as pd
import requests
from datasets import load_dataset, concatenate_datasets

import argilla as rg
from argilla.listeners import listener

### Configuration section ###

# needed for pushing the validated data to HUB_DATASET_NAME
# (read from the environment; None when unset, which disables the push in
# `save_validated_to_hub`)
HF_TOKEN = os.environ.get("HF_TOKEN")

# The source dataset to read Alpaca translated examples
SOURCE_DATASET = "LEL-A/translated_german_alpaca"

# The name of the dataset in Argilla
RG_DATASET_NAME = "translated-german-alpaca"

# The name of the Hub dataset to push the validations every 20 min and keep the dataset synced
# (defaults to "<SOURCE_DATASET>_validation" when HUB_DATASET_NAME is not set in the environment)
HUB_DATASET_NAME = os.environ.get('HUB_DATASET_NAME', f"{SOURCE_DATASET}_validation")

# The labels for the task (they can be extended if needed)
LABELS = ["BAD INSTRUCTION", "INAPPROPRIATE", "ALL GOOD", "NOT SURE", "WRONG LANGUAGE"]

@listener(
    dataset=RG_DATASET_NAME,
    query="status:Validated",
    execution_interval_in_seconds=1200,  # interval to check the execution of `save_validated_to_hub`
)
def save_validated_to_hub(records, ctx):
    """Periodically push validated Argilla records to the Hub.

    Runs every 20 minutes via the `@listener` decorator, collecting records
    matching ``status:Validated`` from RG_DATASET_NAME.

    Args:
        records: Argilla records matching the listener query.
        ctx: Listener context supplied by Argilla (unused).
    """
    if not records:
        print("NO RECORDS found")
        return
    # Convert the validated records into a HF dataset for pushing.
    ds = rg.DatasetForTextClassification(records=records).to_datasets()
    if HF_TOKEN:
        print("Pushing the dataset")
        print(ds)
        ds.push_to_hub(HUB_DATASET_NAME, token=HF_TOKEN)
    else:
        print("SET HF_TOKEN and HUB_DATASET_NAME TO SYNC YOUR DATASET!!!")

class LoadDatasets:
    """Load the source dataset into Argilla and start the Hub-sync listener."""

    def __init__(self, api_key, workspace="team"):
        # Authenticate against the local Argilla server.
        rg.init(api_key=api_key, workspace=workspace)

    @staticmethod
    def load_somos(workspace="team"):
        """Load SOURCE_DATASET into Argilla, merging any prior validations.

        Args:
            workspace: Argilla workspace to configure the dataset in.
                Defaults to "team" for backward compatibility.
        """
        # Read any previously validated dataset from the Hub (best effort:
        # on first run it does not exist yet, so failure is tolerated).
        try:
            print(f"Trying to sync with {HUB_DATASET_NAME}")
            old_ds = load_dataset(HUB_DATASET_NAME, split="train")
        except Exception as e:
            print(f"Not possible to sync with {HUB_DATASET_NAME}")
            print(e)
            old_ds = None

        print(f"Loading dataset: {SOURCE_DATASET}")
        dataset = load_dataset(SOURCE_DATASET, split="train")

        # `is not None` rather than truthiness: an empty (falsy) dataset of
        # prior validations should still be concatenated, not discarded.
        if old_ds is not None:
            print("Concatenating datasets")
            dataset = concatenate_datasets([dataset, old_ds])
            print("Concatenated dataset is:")
            print(dataset)

        # The "metrics" column is not part of the annotation schema.
        dataset = dataset.remove_columns("metrics")
        records = rg.DatasetForTextClassification.from_datasets(dataset)

        settings = rg.TextClassificationSettings(
            label_schema=LABELS
        )

        print(f"Configuring dataset: {RG_DATASET_NAME}")
        rg.configure_dataset(name=RG_DATASET_NAME, settings=settings, workspace=workspace)

        # Log the dataset
        print(f"Logging dataset: {RG_DATASET_NAME}")
        rg.log(
            records,
            name=RG_DATASET_NAME,
            tags={"description": "Alpaca dataset to clean up"},
            batch_size=200
        )

        # run listener
        save_validated_to_hub.start()
if __name__ == "__main__":
    # Usage: python script.py <API_KEY> <LOAD_DATASETS>
    # Guard against missing CLI arguments instead of raising IndexError.
    if len(sys.argv) < 3:
        print("Usage: python <script> <API_KEY> <LOAD_DATASETS|none>")
        sys.exit(1)

    API_KEY = sys.argv[1]
    LOAD_DATASETS = sys.argv[2]

    if LOAD_DATASETS.lower() == "none":
        print("No datasets being loaded")
    else:
        # Poll until the local Argilla server is reachable, then load once.
        while True:
            try:
                response = requests.get("http://0.0.0.0:6900/")
                if response.status_code == 200:
                    ld = LoadDatasets(API_KEY)
                    ld.load_somos()
                    break

            except requests.exceptions.ConnectionError:
                # Server not up yet; retry after the sleep below.
                pass
            except Exception as e:
                print(e)
                time.sleep(10)

            time.sleep(5)
    # Keep the process alive so the background listener keeps running.
    while True:
        time.sleep(60)