|
import time |
|
from copy import deepcopy |
|
|
|
import colossalai |
|
import torch |
|
from colossalai.shardformer import ShardConfig, ShardFormer |
|
from colossalai.testing import spawn |
|
|
|
from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy |
|
from opensora.models.text_encoder.t5 import T5Embedder |
|
|
|
|
|
def run_t5_encoder(rank, world_size, port): |
|
colossalai.launch({}, rank=rank, world_size=world_size, port=port, host="localhost") |
|
|
|
|
|
t5_path = "./pretrained_models/t5_ckpts" |
|
hf_t5 = T5Embedder(device="cuda", local_cache=True, cache_dir=t5_path, torch_dtype=torch.float) |
|
sf_t5 = deepcopy(hf_t5) |
|
|
|
|
|
shard_config = ShardConfig( |
|
tensor_parallel_process_group=None, |
|
pipeline_stage_manager=None, |
|
enable_tensor_parallelism=False, |
|
enable_fused_normalization=False, |
|
enable_flash_attention=False, |
|
enable_jit_fused=True, |
|
enable_sequence_parallelism=False, |
|
enable_sequence_overlap=False, |
|
) |
|
shard_former = ShardFormer(shard_config=shard_config) |
|
sharded_model, _ = shard_former.optimize(sf_t5.model, policy=T5EncoderPolicy()) |
|
sf_t5.model = sharded_model |
|
|
|
|
|
texts = ["Who is the best player in the history of NBA?", "How to study computer science?"] |
|
for i in range(5): |
|
hf_embs, hf_masks = hf_t5.get_text_embeddings(texts) |
|
sf_embs, sf_masks = sf_t5.get_text_embeddings(texts) |
|
|
|
|
|
assert torch.allclose(hf_embs, sf_embs, rtol=1e-4, atol=1e-5), f"{hf_embs} \nvs\n{sf_embs}" |
|
assert torch.allclose(hf_masks, sf_masks), f"{hf_masks} \nvs\n{sf_masks}" |
|
|
|
|
|
torch.cuda.synchronize() |
|
hf_start = time.time() |
|
for i in range(20): |
|
hf_embs, hf_masks = hf_t5.get_text_embeddings(texts) |
|
torch.cuda.synchronize() |
|
hf_end = time.time() |
|
|
|
|
|
hf_t5.model = hf_t5.model.half() |
|
torch.cuda.synchronize() |
|
sf_start = time.time() |
|
for i in range(20): |
|
hf_embs, hf_masks = hf_t5.get_text_embeddings(texts) |
|
torch.cuda.synchronize() |
|
sf_end = time.time() |
|
|
|
print(f"[Performance] native: {hf_end - hf_start}s, shardformer: {sf_end - sf_start} s") |
|
|
|
|
|
def test_t5_encoder(): |
|
spawn(run_t5_encoder) |
|
|
|
|
|
if __name__ == "__main__": |
|
test_t5_encoder() |
|
|