|
|
|
"""Model Pull and Prompt.ipynb |
|
|
|
Automatically generated by Colab. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1Ap0yRsMk8on-NcFPSYay6W3Oble43kyi |
|
""" |
|
|
|
|
|
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub repository holding the LoRA adapter weights fine-tuned on Dolly.
peft_model_id = "vhs01/mistral-7b-dolly"

# The adapter config records which base checkpoint the adapter was trained
# against (config.base_model_name_or_path); that base model is loaded below.
config = PeftConfig.from_pretrained(peft_model_id)
|
|
|
|
|
|
|
|
|
|
|
from transformers import BitsAndBytesConfig

# Quantize the base model to 4-bit at load time so a 7B model fits in a
# Colab GPU. BitsAndBytesConfig was imported but unused in the original,
# which passed the deprecated bare load_in_4bit=True kwarg instead; the
# quantization_config path is the supported API.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on available devices
)
|
|
|
# AutoTokenizer is already imported at the top of the file; the original
# re-imported it here redundantly.
tokenizer = AutoTokenizer.from_pretrained(
    config.base_model_name_or_path,
    # NOTE(review): padding_side="right" + add_eos_token matches the
    # fine-tuning setup; for batched inference left padding is more common —
    # confirm this is intentional.
    padding_side="right",
    add_eos_token=True,
)
# Mistral's tokenizer ships without a pad token; reuse EOS so padding works.
tokenizer.pad_token = tokenizer.eos_token

# Attach the LoRA adapter weights on top of the quantized base model.
fine_tuned_model = PeftModel.from_pretrained(model, peft_model_id)
|
|
|
from transformers import pipeline, logging

# Silence transformers' warnings for cleaner notebook output.
logging.set_verbosity(logging.CRITICAL)

# Text-generation pipeline over the adapted model; generation stops at the
# base model's EOS token or after 500 freshly generated tokens.
generation_settings = {
    "task": "text-generation",
    "model": fine_tuned_model,
    "tokenizer": tokenizer,
    "eos_token_id": model.config.eos_token_id,
    "max_new_tokens": 500,
}
pipe = pipeline(**generation_settings)
|
|
|
prompt = """ |
|
What is a Python? Here is some context: Python is a high-level, general-purpose programming language. |
|
""" |
|
pipe = pipeline(task="text-generation", |
|
model=fine_tuned_model, |
|
tokenizer=tokenizer, |
|
eos_token_id=model.config.eos_token_id, |
|
max_new_tokens=500) |
|
|
|
result = pipe(f"<s>[INST] {prompt} [/INST]") |
|
generated = result[0]['generated_text'] |
|
print(generated[generated.find('[/INST]')+8:]) |
|
|
|
prompt = """ |
|
Please summarize what Linkedin does. Here is some context: LinkedIn is a business and employment-focused social media platform |
|
""" |
|
pipe = pipeline(task="text-generation", |
|
model=fine_tuned_model, |
|
tokenizer=tokenizer, |
|
eos_token_id=model.config.eos_token_id, |
|
max_new_tokens=500) |
|
|
|
result = pipe(f"<s>[INST] {prompt} [/INST]") |
|
generated = result[0]['generated_text'] |
|
print(generated[generated.find('[/INST]')+8:]) |
|
|
|
|
|
|
|
import gradio as gr


def answer(user_prompt: str) -> str:
    """Run one instruction through the fine-tuned pipeline and return only
    the model's reply (the text after the [/INST] tag)."""
    output = pipe(f"<s>[INST] {user_prompt} [/INST]")
    text = output[0]["generated_text"]
    _, marker, response = text.partition("[/INST]")
    return response.lstrip() if marker else text


# The original handed `pipe` directly to gr.Interface (it returns a list of
# dicts, not a string) and passed previously generated text into the output
# Textbox as its positional `value` argument instead of a label; it also
# computed an unused `outputs = pipe(...)` result. Wire a proper wrapper.
demo = gr.Interface(
    fn=answer,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Response"),
)

demo.launch(share=True)