{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"import transformers\n",
"import torch\n",
"\n",
"model = \"meta-llama/Llama-2-7b-chat-hf\" # meta-llama/Llama-2-7b-chat-hf\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=True)"
]
},
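{
"cell_type": "markdown",
"metadata": {},
"source": [
"Llama 2 checkpoints are gated on the Hugging Face Hub, so the download above only succeeds for an authenticated account that has been granted access to the repo. The next cell is a minimal, optional sketch of one way to log in from the notebook; it assumes the `huggingface_hub` package is available (it is installed alongside `transformers`) and can be skipped if a token is already configured in the environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: authenticate so the gated meta-llama/Llama-2-7b-chat-hf files can be downloaded.\n",
"from huggingface_hub import login\n",
"\n",
"login()  # prompts for a Hub access token that has access to the Llama 2 repo"
]
},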
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"llama_pipeline = pipeline(\n",
" \"text-generation\", # LLM task\n",
" model=model,\n",
" torch_dtype=torch.float16,\n",
" device_map=\"auto\",\n",
")"
]
},
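{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wiring up the chat formatting, it can help to smoke-test the raw pipeline with a single hand-written `[INST]` prompt. The prompt below is just an illustrative example; generation only runs once the model has loaded onto the GPU (or into enough CPU memory) in the previous cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check of the raw pipeline with a hand-written Llama 2 prompt.\n",
"test_sequences = llama_pipeline(\n",
"    \"<s>[INST] Say hello in one short sentence. [/INST]\",\n",
"    do_sample=True,\n",
"    top_k=10,\n",
"    num_return_sequences=1,\n",
"    eos_token_id=tokenizer.eos_token_id,\n",
"    max_length=128,\n",
")\n",
"print(test_sequences[0][\"generated_text\"])"
]
},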
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"SYSTEM_PROMPT = \"\"\"<s>[INST] <<SYS>>\n",
"You are a helpful bot. Your answers are clear and concise.\n",
"<</SYS>>\n",
"\n",
"\"\"\"\n",
"\n",
"# Formatting function for message and history\n",
"def format_message(message: str, history: list, memory_limit: int = 3) -> str:\n",
" \"\"\"\n",
" Formats the message and history for the Llama model.\n",
"\n",
" Parameters:\n",
" message (str): Current message to send.\n",
" history (list): Past conversation history.\n",
" memory_limit (int): Limit on how many past interactions to consider.\n",
"\n",
" Returns:\n",
" str: Formatted message string\n",
" \"\"\"\n",
" # always keep len(history) <= memory_limit\n",
" if len(history) > memory_limit:\n",
" history = history[-memory_limit:]\n",
"\n",
" if len(history) == 0:\n",
" return SYSTEM_PROMPT + f\"{message} [/INST]\"\n",
"\n",
" formatted_message = SYSTEM_PROMPT + f\"{history[0][0]} [/INST] {history[0][1]} </s>\"\n",
"\n",
" # Handle conversation history\n",
" for user_msg, model_answer in history[1:]:\n",
" formatted_message += f\"<s>[INST] {user_msg} [/INST] {model_answer} </s>\"\n",
"\n",
" # Handle the current message\n",
" formatted_message += f\"<s>[INST] {message} [/INST]\"\n",
"\n",
" return formatted_message"
]
},
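{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to check `format_message` is to print the prompt it builds for a small history. The question/answer pair below is a made-up placeholder, not model output; the printed string should show the `<s>[INST] ... [/INST] ... </s>` turns that Llama 2 chat models expect."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the prompt that format_message assembles for a toy (made-up) history.\n",
"example_history = [\n",
"    (\"What is the capital of France?\", \"The capital of France is Paris.\"),\n",
"]\n",
"print(format_message(\"And of Italy?\", example_history))"
]
},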
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate a response from the Llama model\n",
"def get_llama_response(message: str, history: list) -> str:\n",
" \"\"\"\n",
" Generates a conversational response from the Llama model.\n",
"\n",
" Parameters:\n",
" message (str): User's input message.\n",
" history (list): Past conversation history.\n",
"\n",
" Returns:\n",
" str: Generated response from the Llama model.\n",
" \"\"\"\n",
" query = format_message(message, history)\n",
" response = \"\"\n",
"\n",
" sequences = llama_pipeline(\n",
" query,\n",
" do_sample=True,\n",
" top_k=10,\n",
" num_return_sequences=1,\n",
" eos_token_id=tokenizer.eos_token_id,\n",
" max_length=1024,\n",
" )\n",
"\n",
" generated_text = sequences[0]['generated_text']\n",
" response = generated_text[len(query):] # Remove the prompt from the output\n",
"\n",
" print(\"Chatbot:\", response.strip())\n",
" return response.strip()"
]
},
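{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, the responder can be exercised once directly before launching the web UI. The question below is only an example; this runs a full generation, so it is slow without a GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional one-off test of the responder before starting the Gradio app.\n",
"_ = get_llama_response(\"Give me one fun fact about llamas.\", [])"
]
},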
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gradio as gr\n",
"\n",
"gr.ChatInterface(get_llama_response).launch()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "itam",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}