John6666 commited on
Commit
19ddc2b
Β·
verified Β·
1 Parent(s): 8ae0bed

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +13 -12
  2. app.py +105 -0
  3. requirements.txt +10 -0
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: Trtest1
3
- emoji: πŸ‘
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: test train
3
+ emoji: πŸ™„
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ import torchvision
5
+ from PIL import Image
6
+ import numpy as np
7
+ import os
8
+ from huggingface_hub import HfApi, HfFolder, Repository
9
+ from transformers import ViTForImageClassification, Trainer, TrainingArguments
10
+ from datasets import load_dataset
11
+ from sklearn.metrics import accuracy_score
12
+
13
+
14
+ @spaces.GPU
15
+ def dummy_gpu():
16
+ pass
17
+
18
+
19
+ HF_MODEL = "google/vit-base-patch16-224"
20
+ HF_DATASET = "verytuffcat/recaptcha-dataset"
21
+
22
+
23
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
24
+ if os.getenv("HF_REPO"): HF_REPO = os.getenv("HF_REPO")
25
+ if os.getenv("HF_DATASET"): HF_DATASET = os.getenv("HF_DATASET")
26
+ if os.getenv("HF_MODEL"): HF_MODEL = os.getenv("HF_MODEL")
27
+ OUT_DIR = "./new_model"
28
+
29
+
30
+ def compute_metrics(eval_pred):
31
+ predictions, labels = eval_pred
32
+ predictions = np.argmax(predictions, axis=1)
33
+ return dict(accuracy=accuracy_score(predictions, labels))
34
+
35
+
36
+ def collate_fn(batch):
37
+ pixel_values = torch.stack([torchvision.transforms.functional.to_tensor(x["image"].convert("RGB").resize((224, 224), Image.BICUBIC)) for x in batch])
38
+ labels = torch.tensor([x["label"] for x in batch])
39
+ return {"pixel_values": pixel_values, "labels": labels}
40
+
41
+
42
+ def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, log_md: str, progress=gr.Progress(track_tqdm=True)):
43
+ try:
44
+ if not model_id or not dataset_id or not repo_id: raise gr.Error("Fill fields.")
45
+ if not hf_token: hf_token = HF_TOKEN
46
+ if not hf_token: raise gr.Error("Input HF token.")
47
+ HfFolder.save_token(hf_token)
48
+
49
+ model = ViTForImageClassification.from_pretrained(model_id)
50
+ dataset = load_dataset(dataset_id, split="train")
51
+
52
+ training_args = TrainingArguments(
53
+ output_dir=OUT_DIR,
54
+ use_cpu=True,
55
+ no_cuda=True,
56
+ fp16=True,
57
+ optim="adamw_torch",
58
+ lr_scheduler_type="linear",
59
+ learning_rate=0.00005,
60
+ per_device_train_batch_size=8,
61
+ num_train_epochs=3,
62
+ gradient_accumulation_steps=1,
63
+ use_ipex=True,
64
+ eval_strategy="no",
65
+ logging_strategy="no",
66
+ remove_unused_columns=False,
67
+ push_to_hub=False,
68
+ save_total_limit=2,
69
+ report_to="none"
70
+ )
71
+ trainer = Trainer(
72
+ model=model,
73
+ args=training_args,
74
+ data_collator=collate_fn,
75
+ compute_metrics=compute_metrics,
76
+ train_dataset=dataset,
77
+ eval_dataset=None,
78
+ )
79
+ trainer.train()
80
+ trainer.save_model(OUT_DIR)
81
+
82
+ api = HfApi(token=hf_token)
83
+ api.create_repo(repo_id=repo_id, private=True, token=hf_token)
84
+ repo = Repository(local_dir=OUT_DIR, clone_from=repo_id, use_auth_token=hf_token)
85
+ repo.push_to_hub()
86
+
87
+ return log_md
88
+ except Exception as e:
89
+ raise gr.Error(f"Error occured: {e}")
90
+
91
+
92
+ with gr.Blocks() as demo:
93
+ with gr.Row():
94
+ model_id = gr.Textbox(label="Source model", value=HF_MODEL, lines=1)
95
+ dataset_id = gr.Textbox(label="Source dataset", value=HF_DATASET, lines=1)
96
+ with gr.Row():
97
+ repo_id = gr.Textbox(label="Output repo", value=HF_REPO, lines=1)
98
+ hf_token = gr.Textbox(label="HF write token", value="", lines=1)
99
+ train_btn = gr.Button("Train")
100
+ log_md = gr.Markdown(label="Log", value="<br><br>")
101
+
102
+ train_btn.click(train, [model_id, dataset_id, repo_id, hf_token, log_md], [log_md])
103
+
104
+
105
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ transformers
3
+ torch
4
+ torchvision
5
+ numpy<2
6
+ scikit-learn
7
+ accelerate
8
+ optimum[ipex]
9
+ intel-extension-for-pytorch
10
+ #oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/