|
import time |
|
import torch |
|
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, |
|
OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, |
|
XLNetLMHeadModel, XLNetTokenizer, |
|
TransfoXLLMHeadModel, TransfoXLTokenizer, |
|
CTRLLMHeadModel, CTRLTokenizer) |
|
|
|
# Registry of every servable model. Each entry maps a public identifier to:
#   "tokenizer"/"model":  transformers classes used to load it
#   "size":               estimated GPU memory footprint in MB (used for placement)
#   "checkpoint":         checkpoint name/path passed to from_pretrained
#   "identifier":         always equal to the dict key
#   "configuration_options" (optional): config class + overrides applied at load
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small",
    },
    "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-community/openai-gpt",
        "identifier": "gpt",
    },
    "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet",
    },
    "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        # NOTE(review): not a Hub path like the others — presumably a local
        # checkpoint directory; confirm it resolves at load time.
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp",
    },
    "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "gpt2/medium",
    },
    "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "gpt2/large",
    },
    "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small",
    },
    "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "Salesforce/ctrl",
        "identifier": "ctrl",
    },
    "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "openai-community/gpt2-xl",
        "identifier": "gpt2/xl",
    },
    # BUG FIX: the original literal defined "pplm" twice (first as gpt2-large
    # with size 3000, then as this gpt2-medium entry). Python keeps only the
    # last duplicate key, so the first entry was dead code and has been
    # removed. The effective mapping is unchanged.
    "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "pplm",
        "configuration_options": {
            "config": GPT2Config,
            # PPLM needs access to intermediate activations.
            "options": {
                "output_hidden_states": True,
            },
        },
    },
}
|
|
|
# Fixed per-GPU memory overhead in MB (counted once per GPU in
# GPU.total_memory_used()) — presumably reserved for the CUDA context and
# framework runtime; TODO confirm the 500 MB estimate against real usage.
memory_overhead = 500
|
|
|
class GPU:
    """Bookkeeping for one CUDA device: which models live on it and how much
    of its memory budget (in MB) they are estimated to consume."""

    def __init__(self, id):
        """Record the device id and query its usable memory.

        The raw byte count from torch is converted to MB and reduced by a
        1 GB safety margin.
        """
        device = "cuda:{}".format(id)
        self.id = id
        self.models = []
        props = torch.cuda.get_device_properties(device)
        self.total_memory = props.total_memory / 1_000_000 - 1_000

        print("INIT GPU WITH DEVICE", device)

    def register_model(self, model, cached_path=None):
        """Try to place `model` on this GPU.

        Returns True and records the model (stamping its "device" and,
        when given, "cached_path") if its estimated size fits in the
        remaining budget; returns False otherwise.
        """
        # Guard clause: reject when the model would not fit.
        if self.total_memory_used() + model["size"] >= self.total_memory:
            return False

        model["device"] = "cuda:{}".format(self.id)
        if cached_path:
            model["cached_path"] = cached_path

        self.models.append(model)
        return True

    def total_memory_used(self):
        """Estimated MB in use: registered model sizes plus fixed overhead."""
        used = memory_overhead
        for registered in self.models:
            used += registered["size"]
        return used

    def __repr__(self):
        entries = [(m["checkpoint"], m["size"]) for m in self.models]
        usage = str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"
        return str(entries + [usage] + ["cuda:{}".format(self.id)])
|
|
|
|
|
class GPUHandler:
    """Distributes the requested models across the available GPUs using a
    first-fit strategy, after a dry-run sanity check that they all fit."""

    def __init__(self, ids, model_list, gpu_ids, cached_models=None):
        """Build one GPU wrapper per id in `gpu_ids` and place every model.

        ids: unused — kept only for backward compatibility with existing
            callers (TODO: deprecate).
        model_list: iterable of keys into `model_metadata`.
        gpu_ids: CUDA device indices to manage.
        cached_models: optional mapping model key -> local checkpoint path.
        """
        if cached_models is None:
            cached_models = {}

        self.gpus = [GPU(id) for id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))

        self.sanity_check([model_metadata[model] for model in model_list])

        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        """Place `model` on the first GPU with room.

        Raises ValueError when no GPU can accommodate it.
        """
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model, "in GPU", gpu)
                break
        else:
            # BUG FIX: the original checked `index >= len(self.gpus)` after
            # the loop, which can never be true (enumerate's index tops out
            # at len - 1), so a model that fit nowhere was silently dropped.
            # for/else raises reliably and also covers an empty GPU list.
            raise ValueError("Could not load model", model["checkpoint"])

    def sanity_check(self, model_list):
        """Dry-run first-fit placement on throwaway GPU wrappers; raises
        RuntimeError if some model cannot be placed anywhere.

        NOTE(review): the temp GPUs use indices 0..n-1 rather than the real
        gpu ids, so their memory budgets may differ from the actual devices —
        confirm this is acceptable. Registration also stamps "device" on the
        shared metadata dicts; the real placement later overwrites it.
        """
        temp_gpus = [GPU(id) for id in range(len(self.gpus))]

        for model in model_list:

            current_gpu_index = 0
            # Walk the GPUs until one accepts the model.
            while current_gpu_index < len(temp_gpus):
                if not temp_gpus[current_gpu_index].register_model(model):
                    current_gpu_index += 1
                else:
                    break

            if current_gpu_index >= len(temp_gpus):
                raise RuntimeError("SANITY CHECK FAILED")

        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"
|
|