Spaces:
Runtime error
Runtime error
Create gen.py
Browse files
gen.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
|
10 |
+
from app.llama import ModelArgs, Transformer, Tokenizer, LLaMA
|
11 |
+
|
12 |
+
from google.cloud import storage
|
13 |
+
|
14 |
+
# Name of the GCS bucket to read from; None when GCS_BUCKET is not set.
bucket_name = os.getenv("GCS_BUCKET")

# Local directories that receive the downloaded model and tokenizer files.
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"
|
18 |
+
|
19 |
+
def setup_model_parallel() -> Tuple[int, int]:
    """Initialise the distributed process group and model parallelism.

    ``LOCAL_RANK`` and ``WORLD_SIZE`` are read from the environment; each
    falls back to ``-1`` when the variable is not set. The process group is
    created with the ``"nccl"`` backend, model parallelism is initialised
    with the world size, and the CUDA device is set to the local rank.

    Returns:
        The ``(local_rank, world_size)`` pair read from the environment.
    """
    world_size = int(os.getenv("WORLD_SIZE", -1))
    local_rank = int(os.getenv("LOCAL_RANK", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # All processes must share one seed, so it is fixed here.
    torch.manual_seed(1)
    return local_rank, world_size
|
30 |
+
|
31 |
+
def download_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str
):
    """Download checkpoint and tokenizer files from the configured GCS bucket.

    Every blob listed under ``{ckpt_path}/`` is written into
    ``llama_weight_path`` and every blob under ``{tokenizer_path}/`` into
    ``tokenizer_weight_path``. Files are stored flat, named after the last
    component of the blob name.

    Args:
        ckpt_path: Bucket prefix (no trailing slash) holding the checkpoint
            files.
        tokenizer_path: Bucket prefix (no trailing slash) holding the
            tokenizer files.
    """
    # exist_ok: calling this twice (or from a second process on the same
    # machine) must not fail just because the directories already exist.
    os.makedirs(llama_weight_path, exist_ok=True)
    os.makedirs(tokenizer_weight_path, exist_ok=True)

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    targets = (
        (ckpt_path, llama_weight_path),
        (tokenizer_path, tokenizer_weight_path),
    )
    for prefix, local_dir in targets:
        for blob in bucket.list_blobs(prefix=f"{prefix}/"):
            # Take the last path component rather than a fixed index, so a
            # prefix that itself contains "/" still yields the file name.
            filename = blob.name.rsplit("/", 1)[-1]
            if not filename:
                # Blob names ending in "/" are folder placeholders; there is
                # no file to write for them.
                continue
            blob.download_to_filename(f"{local_dir}/{filename}")
|
50 |
+
|
51 |
+
def get_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int) -> LLaMA:
    """Download the weights, then build a ``LLaMA`` generator from them.

    The ``*.pth`` files in ``llama_weight_path`` are sorted and the one at
    index ``local_rank`` is loaded; model hyper-parameters come from
    ``params.json`` in the same directory.

    Args:
        ckpt_path: Bucket prefix of the checkpoint files to download.
        tokenizer_path: Bucket prefix of the tokenizer files to download.
        local_rank: Index of the checkpoint file this process loads.
        world_size: Accepted but not used by this function.

    Returns:
        A ``LLaMA`` wrapping the constructed model and tokenizer.
    """
    download_pretrained_models(ckpt_path, tokenizer_path)

    start_time = time.time()
    weights_dir = Path(llama_weight_path)
    shard_files = list(weights_dir.glob("*.pth"))
    shard_files.sort()

    shard_file = shard_files[local_rank]
    print("Loading")
    # Every storage in the checkpoint is moved to CUDA device 0 on load.
    state_dict = torch.load(
        shard_file,
        map_location=lambda tensor_storage, _location: tensor_storage.cuda(0),
    )
    with (weights_dir / "params.json").open("r") as params_file:
        params = json.load(params_file)

    model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=1, **params)
    tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # The default tensor type is CUDA half only while the Transformer is
    # being constructed; it is reset to FloatTensor right after.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(state_dict, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
|
79 |
+
|
80 |
+
def get_output(
    generator: LLaMA,
    prompt: str,
    temperature: float = 0.8,
    top_p: float = 0.95,
    max_gen_len: int = 256):
    """Generate a completion for a single prompt.

    Args:
        generator: The ``LLaMA`` instance whose ``generate`` method is called.
        prompt: The prompt text; it is passed on as a one-element batch.
        temperature: Forwarded to ``generator.generate``.
        top_p: Forwarded to ``generator.generate``.
        max_gen_len: Forwarded to ``generator.generate``. Defaults to 256,
            the value that used to be hard-coded here.

    Returns:
        Whatever ``generator.generate`` returns for the one-element batch.
    """
    return generator.generate(
        [prompt],
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
|