# ์—ฌ๋Ÿฌ GPU๋ฅผ ์‚ฌ์šฉํ•œ ๋ถ„์‚ฐ ์ถ”๋ก 

๋ถ„์‚ฐ ์„ค์ •์—์„œ๋Š” ์—ฌ๋Ÿฌ ๊ฐœ์˜ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋™์‹œ์— ์ƒ์„ฑํ•  ๋•Œ ์œ ์šฉํ•œ ๐Ÿค— [Accelerate](https://huggingface.co/docs/accelerate/index) ๋˜๋Š” [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html)๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์—ฌ๋Ÿฌ GPU์—์„œ ์ถ”๋ก ์„ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๋ถ„์‚ฐ ์ถ”๋ก ์„ ์œ„ํ•ด ๐Ÿค— Accelerate์™€ PyTorch Distributed๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ๋“œ๋ฆฝ๋‹ˆ๋‹ค.

## ๐Ÿค— Accelerate

๐Ÿค— [Accelerate](https://huggingface.co/docs/accelerate/index)๋Š” ๋ถ„์‚ฐ ์„ค์ •์—์„œ ์ถ”๋ก ์„ ์‰ฝ๊ฒŒ ํ›ˆ๋ จํ•˜๊ฑฐ๋‚˜ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ๋„๋ก ์„ค๊ณ„๋œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์ž…๋‹ˆ๋‹ค. ๋ถ„์‚ฐ ํ™˜๊ฒฝ ์„ค์ • ํ”„๋กœ์„ธ์Šค๋ฅผ ๊ฐ„์†Œํ™”ํ•˜์—ฌ PyTorch ์ฝ”๋“œ์— ์ง‘์ค‘ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•ด์ค๋‹ˆ๋‹ค.

์‹œ์ž‘ํ•˜๋ ค๋ฉด Python ํŒŒ์ผ์„ ์ƒ์„ฑํ•˜๊ณ  [`accelerate.PartialState`]๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜์—ฌ ๋ถ„์‚ฐ ํ™˜๊ฒฝ์„ ์ƒ์„ฑํ•˜๋ฉด, ์„ค์ •์ด ์ž๋™์œผ๋กœ ๊ฐ์ง€๋˜๋ฏ€๋กœ `rank` ๋˜๋Š” `world_size`๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ์ •์˜ํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ['DiffusionPipeline`]์„ `distributed_state.device`๋กœ ์ด๋™ํ•˜์—ฌ ๊ฐ ํ”„๋กœ์„ธ์Šค์— GPU๋ฅผ ํ• ๋‹นํ•ฉ๋‹ˆ๋‹ค.

์ด์ œ ์ปจํ…์ŠคํŠธ ๊ด€๋ฆฌ์ž๋กœ [`~accelerate.PartialState.split_between_processes`] ์œ ํ‹ธ๋ฆฌํ‹ฐ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ”„๋กœ์„ธ์Šค ์ˆ˜์— ๋”ฐ๋ผ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž๋™์œผ๋กœ ๋ถ„๋ฐฐํ•ฉ๋‹ˆ๋‹ค.


```py
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipeline.to(distributed_state.device)

with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    result = pipeline(prompt).images[0]
    result.save(f"result_{distributed_state.process_index}.png")
```

Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:

```bash
accelerate launch --num_processes=2 run_distributed.py
```
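Under the hood, `split_between_processes` chunks the input list as evenly as possible across processes. The following pure-Python sketch illustrates one plausible partitioning scheme, including the uneven case where some processes receive one extra item (this is an illustration only, not Accelerate's actual implementation):

```python
# Illustrative sketch (not the actual Accelerate implementation):
# partition a list of prompts across processes as evenly as possible.
def split_prompts(prompts, num_processes, process_index):
    # Each process gets `base` items; the first `extra` processes get one more
    base, extra = divmod(len(prompts), num_processes)
    start = process_index * base + min(process_index, extra)
    end = start + base + (1 if process_index < extra else 0)
    return prompts[start:end]

# With 3 prompts on 2 processes, process 0 gets two and process 1 gets one
print(split_prompts(["a dog", "a cat", "a bird"], 2, 0))  # ['a dog', 'a cat']
print(split_prompts(["a dog", "a cat", "a bird"], 2, 1))  # ['a bird']
```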

<Tip>์ž์„ธํ•œ ๋‚ด์šฉ์€ [๐Ÿค— Accelerate๋ฅผ ์‚ฌ์šฉํ•œ ๋ถ„์‚ฐ ์ถ”๋ก ](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.

</Tip>

## PyTorch Distributed

PyTorch๋Š” ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋ฅผ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•˜๋Š” [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.

์‹œ์ž‘ํ•˜๋ ค๋ฉด Python ํŒŒ์ผ์„ ์ƒ์„ฑํ•˜๊ณ  `torch.distributed` ๋ฐ `torch.multiprocessing`์„ ์ž„ํฌํŠธํ•˜์—ฌ ๋ถ„์‚ฐ ํ”„๋กœ์„ธ์Šค ๊ทธ๋ฃน์„ ์„ค์ •ํ•˜๊ณ  ๊ฐ GPU์—์„œ ์ถ”๋ก ์šฉ ํ”„๋กœ์„ธ์Šค๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  [`DiffusionPipeline`]๋„ ์ดˆ๊ธฐํ™”ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

ํ™•์‚ฐ ํŒŒ์ดํ”„๋ผ์ธ์„ `rank`๋กœ ์ด๋™ํ•˜๊ณ  `get_rank`๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ ํ”„๋กœ์„ธ์Šค์— GPU๋ฅผ ํ• ๋‹นํ•˜๋ฉด ๊ฐ ํ”„๋กœ์„ธ์Šค๊ฐ€ ๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค:

```py
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers import DiffusionPipeline

sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```

์‚ฌ์šฉํ•  ๋ฐฑ์—”๋“œ ์œ ํ˜•, ํ˜„์žฌ ํ”„๋กœ์„ธ์Šค์˜ `rank`, `world_size` ๋˜๋Š” ์ฐธ์—ฌํ•˜๋Š” ํ”„๋กœ์„ธ์Šค ์ˆ˜๋กœ ๋ถ„์‚ฐ ํ™˜๊ฒฝ ์ƒ์„ฑ์„ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜[`init_process_group`]๋ฅผ ๋งŒ๋“ค์–ด ์ถ”๋ก ์„ ์‹คํ–‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

2๊ฐœ์˜ GPU์—์„œ ์ถ”๋ก ์„ ๋ณ‘๋ ฌ๋กœ ์‹คํ–‰ํ•˜๋Š” ๊ฒฝ์šฐ `world_size`๋Š” 2์ž…๋‹ˆ๋‹ค.

```py
def run_inference(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    sd.to(rank)

    if torch.distributed.get_rank() == 0:
        prompt = "a dog"
    elif torch.distributed.get_rank() == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    image.save(f"./{'_'.join(prompt.split())}.png")
```
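The hardcoded `if`/`elif` above only covers two ranks. One way to scale the same script to more GPUs is to index a prompt list by rank; this is a sketch under the assumption of one prompt per process (the `prompts` list and `prompt_for_rank` helper are hypothetical names, not part of the Diffusers API):

```python
# Hypothetical helper: pick a prompt by process rank instead of
# hardcoding one branch per rank, wrapping around if world_size
# exceeds the number of prompts.
prompts = ["a dog", "a cat", "a bird", "a fish"]

def prompt_for_rank(rank):
    return prompts[rank % len(prompts)]

print(prompt_for_rank(0))  # a dog
print(prompt_for_rank(3))  # a fish
```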

๋ถ„์‚ฐ ์ถ”๋ก ์„ ์‹คํ–‰ํ•˜๋ ค๋ฉด [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn)์„ ํ˜ธ์ถœํ•˜์—ฌ `world_size`์— ์ •์˜๋œ GPU ์ˆ˜์— ๋Œ€ํ•ด `run_inference` ํ•จ์ˆ˜๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค:

```py
def main():
    world_size = 2
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
```
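Note that `mp.spawn` calls the target function once per process, automatically prepending each process index to the supplied `args` tuple, which is why `run_inference` receives `rank` even though `args` only contains `world_size`. A minimal pure-Python sketch of that calling convention (an illustration only, not the actual implementation, which launches separate processes):

```python
# Illustrative sketch of mp.spawn's calling convention: the target
# function receives the process index as its first argument,
# followed by the contents of `args`.
def fake_spawn(fn, args, nprocs):
    # The real mp.spawn runs these calls in separate processes
    return [fn(rank, *args) for rank in range(nprocs)]

def run_inference(rank, world_size):
    return f"process {rank} of {world_size}"

print(fake_spawn(run_inference, args=(2,), nprocs=2))
# ['process 0 of 2', 'process 1 of 2']
```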

์ถ”๋ก  ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์™„๋ฃŒํ–ˆ์œผ๋ฉด `--nproc_per_node` ์ธ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์‚ฌ์šฉํ•  GPU ์ˆ˜๋ฅผ ์ง€์ •ํ•˜๊ณ  `torchrun`์„ ํ˜ธ์ถœํ•˜์—ฌ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค:

```bash
torchrun --nproc_per_node=2 run_distributed.py
```