import json
import os
import time
from pathlib import Path
from typing import Tuple

import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from google.cloud import storage

from llama.generation import LLaMA
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

# GCS bucket holding the weights; read from the environment once, at import time.
bucket_name = os.environ.get("GCS_BUCKET")
# Local directories the model shards / tokenizer files are downloaded into.
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"


def setup_model_parallel() -> Tuple[int, int]:
    """Initialise the NCCL process group and fairscale model parallelism.

    Reads ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment (each defaults
    to -1 when unset -- presumably they are set by a launcher such as torchrun;
    confirm for your deployment), pins this process to GPU ``local_rank`` and
    seeds the torch RNG identically on every process.

    Returns:
        ``(local_rank, world_size)`` as read from the environment.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def _download_prefix(bucket, prefix: str, dest_dir: Path) -> None:
    """Download every object under ``prefix/`` in ``bucket`` into ``dest_dir``.

    Object names are made relative to the prefix, so a nested prefix such as
    ``"models/7B"`` still lands files directly in ``dest_dir``.
    """
    blob_prefix = f"{prefix.rstrip('/')}/"
    for blob in bucket.list_blobs(prefix=blob_prefix):
        relative_name = blob.name[len(blob_prefix):]
        # Skip "folder" placeholder objects (name == prefix or ending in "/");
        # they carry no file content and would otherwise target a directory.
        if not relative_name or relative_name.endswith("/"):
            continue
        destination = dest_dir / relative_name
        destination.parent.mkdir(parents=True, exist_ok=True)
        blob.download_to_filename(str(destination))


def download_pretrained_models(ckpt_path: str, tokenizer_path: str):
    """Fetch model shards and tokenizer files from the ``GCS_BUCKET`` bucket.

    Args:
        ckpt_path: Object prefix in the bucket holding ``*.pth`` shards and
            ``params.json``; downloaded into ``llama_weight_path``.
        tokenizer_path: Object prefix holding ``tokenizer.model``; downloaded
            into ``tokenizer_weight_path``.

    Raises:
        ValueError: If the ``GCS_BUCKET`` environment variable was not set.

    NOTE(review): if every model-parallel rank calls this, all ranks write the
    same files concurrently; consider downloading on one rank and using a
    ``torch.distributed.barrier()``.
    """
    if not bucket_name:
        raise ValueError("GCS_BUCKET environment variable is not set")

    weight_dir = Path(llama_weight_path)
    tokenizer_dir = Path(tokenizer_weight_path)
    # exist_ok: a re-run (or a second call) must not fail with FileExistsError.
    weight_dir.mkdir(parents=True, exist_ok=True)
    tokenizer_dir.mkdir(parents=True, exist_ok=True)

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    _download_prefix(bucket, ckpt_path, weight_dir)
    _download_prefix(bucket, tokenizer_path, tokenizer_dir)


def get_pretrained_models(
    ckpt_path: str, tokenizer_path: str, local_rank: int, world_size: int
) -> LLaMA:
    """Download weights, load this rank's shard and build a ``LLaMA`` generator.

    Args:
        ckpt_path: Bucket prefix of the checkpoint shards (see
            ``download_pretrained_models``).
        tokenizer_path: Bucket prefix of the tokenizer files.
        local_rank: Index of the ``*.pth`` shard (in sorted order) to load.
        world_size: Number of model-parallel processes; must equal the number
            of ``*.pth`` shards found.

    Returns:
        A ``LLaMA`` generator wrapping an fp16 CUDA ``Transformer``.

    Raises:
        FileNotFoundError: If no ``*.pth`` shards were downloaded.
        ValueError: If the shard count does not match ``world_size`` or
            ``local_rank`` is not a valid shard index.
    """
    download_pretrained_models(ckpt_path, tokenizer_path)

    start_time = time.time()
    checkpoints = sorted(Path(llama_weight_path).glob("*.pth"))
    if not checkpoints:
        raise FileNotFoundError(f"No *.pth checkpoint files found in {llama_weight_path}")
    # One *.pth file per model-parallel shard: the process count must match.
    if len(checkpoints) != world_size:
        raise ValueError(
            f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
        )
    if not 0 <= local_rank < len(checkpoints):
        raise ValueError(
            f"local_rank {local_rank} is out of range for {len(checkpoints)} checkpoint shards"
        )
    llama_ckpt_path = checkpoints[local_rank]

    print("Loading")
    # Load on CPU; load_state_dict below copies tensors onto the model's GPU.
    # (Mapping to cuda:0 would pile every rank's shard onto the same device.)
    checkpoint = torch.load(llama_ckpt_path, map_location="cpu")
    with open(Path(llama_weight_path) / "params.json", "r", encoding="utf-8") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=1, **params)
    tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # Allocate parameters directly as fp16 CUDA tensors; always restore the
    # global default tensor type, even if model construction fails.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    try:
        model = Transformer(model_args).cuda().half()
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)
    del checkpoint  # drop the CPU copy of the shard once it is on the model

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


def get_output(
    generator: LLaMA,
    prompt: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_gen_len: int = 256,
):
    """Generate a completion for a single prompt.

    Args:
        generator: Generator returned by ``get_pretrained_models``.
        prompt: Input text.
        temperature: Sampling temperature passed to ``generator.generate``.
        top_p: Nucleus-sampling threshold passed to ``generator.generate``.
        max_gen_len: Maximum number of tokens to generate (default 256).

    Returns:
        Whatever ``generator.generate`` returns for a one-element batch
        (presumably a list with one decoded string -- confirm against LLaMA).
    """
    prompts = [prompt]
    results = generator.generate(
        prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
    )
    return results