import json
import os
import time
from pathlib import Path
from typing import Tuple

import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from google.cloud import storage

from llama.generation import LLaMA
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer
# GCS bucket holding the pretrained weights, and the local download targets.
bucket_name = os.environ.get("GCS_BUCKET")
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"
def setup_model_parallel() -> Tuple[int, int]:
    """Initialize torch.distributed and fairscale model parallelism."""
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # The seed must be the same in all processes so sampling stays in sync.
    torch.manual_seed(1)
    return local_rank, world_size
def download_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
):
    """Download the model and tokenizer weights from a public GCS bucket."""
    # exist_ok avoids a crash when the directories survive a restart.
    os.makedirs(llama_weight_path, exist_ok=True)
    os.makedirs(tokenizer_weight_path, exist_ok=True)

    # The bucket is public, so an anonymous client is sufficient.
    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    for blob in bucket.list_blobs(prefix=f"{ckpt_path}/"):
        filename = blob.name.split("/")[-1]
        if filename:  # skip directory placeholder blobs
            blob.download_to_filename(f"{llama_weight_path}/{filename}")

    for blob in bucket.list_blobs(prefix=f"{tokenizer_path}/"):
        filename = blob.name.split("/")[-1]
        if filename:
            blob.download_to_filename(f"{tokenizer_weight_path}/{filename}")
def get_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
) -> LLaMA:
    download_pretrained_models(ckpt_path, tokenizer_path)

    start_time = time.time()
    # Each rank loads its own shard of the model-parallel checkpoint.
    checkpoints = sorted(Path(llama_weight_path).glob("*.pth"))
    llama_ckpt_path = checkpoints[local_rank]

    print("Loading")
    checkpoint = torch.load(llama_ckpt_path, map_location="cpu")
    with open(Path(llama_weight_path) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(max_seq_len=512, max_batch_size=1, **params)
    tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # Build the model in fp16 on the GPU, then restore the default tensor type.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args).cuda().half()
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
def get_output(
    generator: LLaMA,
    prompt: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
):
    """Generate a completion for a single prompt."""
    prompts = [prompt]
    results = generator.generate(
        prompts, max_gen_len=256, temperature=temperature, top_p=top_p
    )
    return results
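

# A minimal driver sketch, assuming the script is launched with torchrun (or
# another launcher that sets LOCAL_RANK and WORLD_SIZE) and that the bucket
# stores weights under "ckpt" and "tokenizer" prefixes. Both prefix names and
# the example prompt are illustrative assumptions, not taken from the original.
if __name__ == "__main__":
    local_rank, world_size = setup_model_parallel()
    generator = get_pretrained_models("ckpt", "tokenizer", local_rank, world_size)
    results = get_output(generator, "The capital of France is")
    # Print from rank 0 only, so model-parallel runs do not duplicate output.
    if local_rank == 0:
        for result in results:
            print(result)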