##
<pre>
import evaluate
import torch
+from accelerate import Accelerator
+accelerator = Accelerator()
+train_dataloader, eval_dataloader, model, optimizer, scheduler = (
+    accelerator.prepare(
+        train_dataloader, eval_dataloader, 
+        model, optimizer, scheduler
+    )
+)
metric = evaluate.load("accuracy")
for batch in train_dataloader:
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
-    loss.backward()
+    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
+    predictions, references = accelerator.gather_for_metrics(
+        (predictions, targets)
+    )
    metric.add_batch(
        predictions = predictions,
        references = references
    )
print(metric.compute())</pre>

##
When calculating metrics on a validation set, you can use the `Accelerator.gather_for_metrics`
method to gather the predictions and references from all processes and then compute the metric on the gathered values.
During distributed evaluation, the last batch may be padded with duplicate samples so that every process receives
a batch of the same size. `gather_for_metrics` *automatically* drops these duplicates from the gathered tensors,
so the metric is computed over each example in the dataset exactly once.
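To see why dropping the padded duplicates matters, the gather-and-truncate behavior can be modeled in plain Python. This is a simplified, hypothetical simulation, not Accelerate's implementation: `shard_with_padding` and `gather_for_metrics_sim` are illustrative names, and the sketch assumes each process receives one contiguous batch per round and that the final round is padded by cycling back to the start of the dataset.

```python
from typing import List


def shard_with_padding(dataset: List[int], num_processes: int, batch_size: int) -> List[List[List[int]]]:
    """Split `dataset` into rounds of per-process batches.

    Hypothetical simplified model: each round holds one batch per process,
    and the final round is padded by cycling back to the start of the
    dataset so every process receives a full batch.
    """
    step = num_processes * batch_size
    total = -(-len(dataset) // step) * step  # round up to a multiple of step
    padded = [dataset[i % len(dataset)] for i in range(total)]
    rounds = []
    for start in range(0, total, step):
        chunk = padded[start:start + step]
        # process p receives the p-th contiguous batch of this round
        rounds.append([chunk[p * batch_size:(p + 1) * batch_size] for p in range(num_processes)])
    return rounds


def gather_for_metrics_sim(rounds: List[List[List[int]]], dataset_len: int) -> List[int]:
    """Concatenate each round across processes (like an all-gather) and
    truncate the last round so the padded duplicates are dropped."""
    num_processes = len(rounds[0])
    batch_size = len(rounds[0][0])
    remainder = dataset_len % (num_processes * batch_size)
    gathered: List[int] = []
    for i, per_process in enumerate(rounds):
        batch = [x for proc_batch in per_process for x in proc_batch]
        if i == len(rounds) - 1 and remainder:
            batch = batch[:remainder]  # keep only the real samples
        gathered.extend(batch)
    return gathered


if __name__ == "__main__":
    data = list(range(10))  # sample ids 0..9
    rounds = shard_with_padding(data, num_processes=2, batch_size=3)

    # Naive gather keeps the padding: ids 0 and 1 appear twice.
    naive = [x for per_process in rounds for batch in per_process for x in batch]
    fixed = gather_for_metrics_sim(rounds, len(data))

    # Hypothetical model that is correct only on ids 0..4.
    def accuracy(ids: List[int]) -> float:
        return sum(1 for x in ids if x < 5) / len(ids)

    print(len(naive), f"{accuracy(naive):.3f}")  # prints: 12 0.583
    print(len(fixed), f"{accuracy(fixed):.3f}")  # prints: 10 0.500
```

With 10 samples split across 2 processes in batches of 3, the naive gather returns 12 predictions and inflates accuracy from 0.500 to 0.583 because two samples are counted twice. In Accelerate itself this bookkeeping is handled for you, so you only call `accelerator.gather_for_metrics` as in the evaluation loop above.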
##
To learn more, check out the related documentation:

- <a href="https://huggingface.co/docs/accelerate/en/quicktour#distributed-evaluation" target="_blank">Quicktour - Calculating metrics</a>
- <a href="https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.gather_for_metrics" target="_blank">API reference</a>
- <a href="https://github.com/huggingface/accelerate/blob/main/examples/by_feature/multi_process_metrics.py" target="_blank">Example script</a>