Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,400 Bytes
87d40d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
# ์ฌ๋ฌ GPU๋ฅผ ์ฌ์ฉํ ๋ถ์ฐ ์ถ๋ก
๋ถ์ฐ ์ค์ ์์๋ ์ฌ๋ฌ ๊ฐ์ ํ๋กฌํํธ๋ฅผ ๋์์ ์์ฑํ ๋ ์ ์ฉํ ๐ค [Accelerate](https://huggingface.co/docs/accelerate/index) ๋๋ [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html)๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ๋ฌ GPU์์ ์ถ๋ก ์ ์คํํ ์ ์์ต๋๋ค.
์ด ๊ฐ์ด๋์์๋ ๋ถ์ฐ ์ถ๋ก ์ ์ํด ๐ค Accelerate์ PyTorch Distributed๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ๋๋ฆฝ๋๋ค.
## ๐ค Accelerate
๐ค [Accelerate](https://huggingface.co/docs/accelerate/index)๋ ๋ถ์ฐ ์ค์ ์์ ์ถ๋ก ์ ์ฝ๊ฒ ํ๋ จํ๊ฑฐ๋ ์คํํ ์ ์๋๋ก ์ค๊ณ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์
๋๋ค. ๋ถ์ฐ ํ๊ฒฝ ์ค์ ํ๋ก์ธ์ค๋ฅผ ๊ฐ์ํํ์ฌ PyTorch ์ฝ๋์ ์ง์คํ ์ ์๋๋ก ํด์ค๋๋ค.
์์ํ๋ ค๋ฉด Python ํ์ผ์ ์์ฑํ๊ณ [`accelerate.PartialState`]๋ฅผ ์ด๊ธฐํํ์ฌ ๋ถ์ฐ ํ๊ฒฝ์ ์์ฑํ๋ฉด, ์ค์ ์ด ์๋์ผ๋ก ๊ฐ์ง๋๋ฏ๋ก `rank` ๋๋ `world_size`๋ฅผ ๋ช
์์ ์ผ๋ก ์ ์ํ ํ์๊ฐ ์์ต๋๋ค. ['DiffusionPipeline`]์ `distributed_state.device`๋ก ์ด๋ํ์ฌ ๊ฐ ํ๋ก์ธ์ค์ GPU๋ฅผ ํ ๋นํฉ๋๋ค.
์ด์ ์ปจํ
์คํธ ๊ด๋ฆฌ์๋ก [`~accelerate.PartialState.split_between_processes`] ์ ํธ๋ฆฌํฐ๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ก์ธ์ค ์์ ๋ฐ๋ผ ํ๋กฌํํธ๋ฅผ ์๋์ผ๋ก ๋ถ๋ฐฐํฉ๋๋ค.
```py
from accelerate import PartialState
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipeline.to(distributed_state.device)
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
result = pipeline(prompt).images[0]
result.save(f"result_{distributed_state.process_index}.png")
```
Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:
```bash
accelerate launch run_distributed.py --num_processes=2
```
<Tip>์์ธํ ๋ด์ฉ์ [๐ค Accelerate๋ฅผ ์ฌ์ฉํ ๋ถ์ฐ ์ถ๋ก ](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) ๊ฐ์ด๋๋ฅผ ์ฐธ์กฐํ์ธ์.
</Tip>
## Pytoerch ๋ถ์ฐ
PyTorch๋ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๋ [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)์ ์ง์ํฉ๋๋ค.
์์ํ๋ ค๋ฉด Python ํ์ผ์ ์์ฑํ๊ณ `torch.distributed` ๋ฐ `torch.multiprocessing`์ ์ํฌํธํ์ฌ ๋ถ์ฐ ํ๋ก์ธ์ค ๊ทธ๋ฃน์ ์ค์ ํ๊ณ ๊ฐ GPU์์ ์ถ๋ก ์ฉ ํ๋ก์ธ์ค๋ฅผ ์์ฑํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ [`DiffusionPipeline`]๋ ์ด๊ธฐํํด์ผ ํฉ๋๋ค:
ํ์ฐ ํ์ดํ๋ผ์ธ์ `rank`๋ก ์ด๋ํ๊ณ `get_rank`๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ํ๋ก์ธ์ค์ GPU๋ฅผ ํ ๋นํ๋ฉด ๊ฐ ํ๋ก์ธ์ค๊ฐ ๋ค๋ฅธ ํ๋กฌํํธ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค:
```py
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from diffusers import DiffusionPipeline
sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```
์ฌ์ฉํ ๋ฐฑ์๋ ์ ํ, ํ์ฌ ํ๋ก์ธ์ค์ `rank`, `world_size` ๋๋ ์ฐธ์ฌํ๋ ํ๋ก์ธ์ค ์๋ก ๋ถ์ฐ ํ๊ฒฝ ์์ฑ์ ์ฒ๋ฆฌํ๋ ํจ์[`init_process_group`]๋ฅผ ๋ง๋ค์ด ์ถ๋ก ์ ์คํํด์ผ ํฉ๋๋ค.
2๊ฐ์ GPU์์ ์ถ๋ก ์ ๋ณ๋ ฌ๋ก ์คํํ๋ ๊ฒฝ์ฐ `world_size`๋ 2์
๋๋ค.
```py
def run_inference(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
sd.to(rank)
if torch.distributed.get_rank() == 0:
prompt = "a dog"
elif torch.distributed.get_rank() == 1:
prompt = "a cat"
image = sd(prompt).images[0]
image.save(f"./{'_'.join(prompt)}.png")
```
๋ถ์ฐ ์ถ๋ก ์ ์คํํ๋ ค๋ฉด [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn)์ ํธ์ถํ์ฌ `world_size`์ ์ ์๋ GPU ์์ ๋ํด `run_inference` ํจ์๋ฅผ ์คํํฉ๋๋ค:
```py
def main():
world_size = 2
mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)
if __name__ == "__main__":
main()
```
์ถ๋ก ์คํฌ๋ฆฝํธ๋ฅผ ์๋ฃํ์ผ๋ฉด `--nproc_per_node` ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ์ฉํ GPU ์๋ฅผ ์ง์ ํ๊ณ `torchrun`์ ํธ์ถํ์ฌ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํฉ๋๋ค:
```bash
torchrun run_distributed.py --nproc_per_node=2
``` |