File size: 6,019 Bytes
3fa3954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# -*- coding: utf-8 -*-
"""S22.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1pq0UO46D0emoqF8rPuD4cUznmYVSMESO
"""

# Commented out IPython magic to ensure Python compatibility.
# %pip install lightning -q



import torch
# Leftover from the Colab notebook, where this expression's value was shown as
# cell output; in a script the returned bool is simply discarded.
torch.cuda.is_available()

import glob
import math
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union

import lightning as L
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy

from tsai_gpt.model import GPT, Block, Config
from tsai_gpt.packed_dataset import CombinedDataset, PackedDataset
from tsai_gpt.speed_monitor import SpeedMonitorBase, estimate_flops, measure_flops
from tsai_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor
from tsai_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, load_checkpoint
import os
import pickle
from contextlib import nullcontext
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tsai_gpt.tokenizer import Tokenizer
import gradio as gr

# --- Model / run configuration ------------------------------------------------
model_name = "pythia-160m"
name = "redpajama"
out_dir = Path("out") / name
# Passed to CSVLogger as flush_logs_every_n_steps; it must be defined before
# the logger is constructed below.
log_interval = 1

# Snapshot of the simple scalar/string settings defined so far (module-level
# locals() is the module namespace).
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
logger = CSVLogger("out", name, flush_logs_every_n_steps=log_interval)

fabric = L.Fabric(devices=1, strategy='auto', precision=None, loggers=logger)

# Build the architecture from its named config, then load trained weights into it.
checkpoint_path = Path("out/redpajama/iter-023999-ckpt.pth")
config = Config.from_name(model_name)
model = GPT(config)

load_checkpoint(fabric, model, checkpoint_path)

def generate(model, config, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

    Each step feeds the running sequence (cropped to the last
    ``config.block_size`` tokens) through ``model``. It takes the logits at the
    final position, optionally keeps only the ``top_k`` largest, and samples
    the next token from the softmax of what remains. Put the model in
    ``eval()`` mode before calling.

    Args:
        model: Callable mapping a ``(b, t)`` LongTensor of token ids to logits
            of shape ``(b, t, vocab)``.
        config: Object exposing ``block_size``, the longest context fed to
            ``model``; older tokens are cropped away beyond it.
        idx: Conditioning token ids, shape ``(t,)`` or ``(b, t)``. A 1-D tensor
            is treated as a batch of one.
        max_new_tokens: Number of tokens to append.
        temperature: Logits are divided by this before the softmax
            (< 1.0 sharpens, > 1.0 flattens the distribution).
        top_k: If not None, only the ``top_k`` most likely tokens can be sampled.

    Returns:
        LongTensor of shape ``(b, t + max_new_tokens)``: the input followed by
        the sampled tokens.
    """
    if idx.dim() == 1:
        idx = idx.unsqueeze(dim=0)
    for _ in range(max_new_tokens):
        # If the sequence context grows too long, crop it to the last block_size tokens.
        idx_cond = idx if idx.size(1) <= config.block_size else idx[:, -config.block_size:]
        logits = model(idx_cond)
        # Keep only the prediction for the final position and apply temperature.
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Mask everything below the k-th largest logit so it gets zero probability.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx



# The tokenizer is constructed from the Llama-2-7b-chat checkpoint directory;
# the path is relative to the current working directory.
checkpoint_dir = Path(".") / "checkpoints" / "meta-llama" / "Llama-2-7b-chat-hf"
token = Tokenizer(checkpoint_dir=checkpoint_dir)

def tsaigpt(start: str, model=model, max_new_tokens=300, num_samples=2, tokeniser=token):
    """Generate a continuation of ``start`` with the loaded TSAI-GPT model.

    Used as the Gradio ``fn`` callback for the interface in this script.

    Args:
        start: Prompt text to condition generation on.
        model: Model to sample from; defaults to the module-level checkpointed
            ``model``.
        max_new_tokens: Number of tokens to sample after the prompt. Coerced to
            ``int`` in case the UI delivers a float.
        num_samples: Currently unused; kept so existing callers keep working.
        tokeniser: Tokenizer used to encode the prompt and decode the output;
            defaults to the module-level ``token``.

    Returns:
        The decoded prompt followed by the generated continuation.
    """
    # -----------------------------------------------------------------------------
    temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
    top_k = 200  # retain only the top_k most likely tokens, clamp others to have 0 probability
    seed = 1337
    device = 'cpu'  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'  # 'float32' or 'bfloat16' or 'float16'
    use_compile = False  # use PyTorch 2.0 to compile the model to be faster
    # -----------------------------------------------------------------------------

    # Re-seeding on every call means repeated calls with the same prompt and
    # settings sample the same tokens (on the same device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
    # Autocast only applies on CUDA; on CPU this is a no-op context.
    ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

    model.eval()
    model.to(device)
    if use_compile:
        model = torch.compile(model)  # requires PyTorch 2.0 (optional)

    start_ids = tokeniser.encode(start).to(device)

    # Sample with the temperature/top_k configured above rather than hard-coded values.
    with torch.no_grad(), ctx:
        y = generate(
            model=model,
            config=config,
            idx=start_ids,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_k=top_k,
        )
        output = tokeniser.decode(y[0])
    return output

# Build the Gradio UI around `tsaigpt`: a prompt textbox and a slider for the
# number of new tokens, producing the generated text.
INTERFACE = gr.Interface(
    fn=tsaigpt,
    inputs=[
        gr.Textbox(label="Prompt", value='All that glisters is not gold.'),
        gr.Slider(minimum=300, maximum=500, value=300, label="Maximum number of tokens to be generated"),
    ],
    outputs=gr.Text(label="Generated Text"),
    title="TSAI_GPT",
    description="TSAIGPT is a transformer-based language model with only 0.16 billion parameters, trained on RedPajama 1T Sample.",
    examples=[
        ['We know what we are, but know not what we may be', 300],
        ['Sweet are the uses of adversity which, like the toad, ugly and venomous, wears yet a precious jewel in his head', 300],
    ],
)
# Launch separately so INTERFACE keeps referring to the Interface object itself,
# not to launch()'s return value.
INTERFACE.launch(debug=True)