{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import transformers\n",
    "import torch\n",
    "\n",
    "model = \"meta-llama/Llama-2-7b-chat-hf\" # meta-llama/Llama-2-7b-chat-hf\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=True)"
   ]
  },
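  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (a hypothetical extra step, not required by the rest of the notebook): encode a short string and confirm the end-of-sequence token that the generation call relies on later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode a sample string and inspect the EOS token used as a stop signal.\n",
    "sample_ids = tokenizer(\"Hello, Llama!\")[\"input_ids\"]\n",
    "print(\"Token ids:\", sample_ids)\n",
    "print(\"EOS token:\", tokenizer.eos_token, \"->\", tokenizer.eos_token_id)"
   ]
  },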
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "llama_pipeline = pipeline(\n",
    "    \"text-generation\",  # LLM task\n",
    "    model=model,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    ")"
   ]
  },
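  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal smoke test (an illustrative extra, assuming the ~14 GB of weights finished loading): run one short greedy completion to confirm the pipeline works before building the chat logic on top of it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One short deterministic completion as a smoke test.\n",
    "smoke = llama_pipeline(\"The capital of France is\", max_new_tokens=10, do_sample=False)\n",
    "print(smoke[0][\"generated_text\"])"
   ]
  },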
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = \"\"\"<s>[INST] <<SYS>>\n",
    "You are a helpful bot. Your answers are clear and concise.\n",
    "<</SYS>>\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "# Formatting function for message and history\n",
    "def format_message(message: str, history: list, memory_limit: int = 3) -> str:\n",
    "    \"\"\"\n",
    "    Formats the message and history for the Llama model.\n",
    "\n",
    "    Parameters:\n",
    "        message (str): Current message to send.\n",
    "        history (list): Past conversation history.\n",
    "        memory_limit (int): Limit on how many past interactions to consider.\n",
    "\n",
    "    Returns:\n",
    "        str: Formatted message string\n",
    "    \"\"\"\n",
    "    # always keep len(history) <= memory_limit\n",
    "    if len(history) > memory_limit:\n",
    "        history = history[-memory_limit:]\n",
    "\n",
    "    if len(history) == 0:\n",
    "        return SYSTEM_PROMPT + f\"{message} [/INST]\"\n",
    "\n",
    "    formatted_message = SYSTEM_PROMPT + f\"{history[0][0]} [/INST] {history[0][1]} </s>\"\n",
    "\n",
    "    # Handle conversation history\n",
    "    for user_msg, model_answer in history[1:]:\n",
    "        formatted_message += f\"<s>[INST] {user_msg} [/INST] {model_answer} </s>\"\n",
    "\n",
    "    # Handle the current message\n",
    "    formatted_message += f\"<s>[INST] {message} [/INST]\"\n",
    "\n",
    "    return formatted_message"
   ]
  },
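  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The exact prompt layout matters for Llama 2 chat models, so it is worth printing what `format_message` actually produces. The two-turn conversation below is a made-up example for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the assembled prompt for a toy two-turn history.\n",
    "toy_history = [\n",
    "    (\"Hi, who are you?\", \"I am a helpful bot.\"),\n",
    "    (\"Can you keep answers short?\", \"Yes, I keep them concise.\"),\n",
    "]\n",
    "print(format_message(\"What is the tallest mountain?\", toy_history))"
   ]
  },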
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a response from the Llama model\n",
    "def get_llama_response(message: str, history: list) -> str:\n",
    "    \"\"\"\n",
    "    Generates a conversational response from the Llama model.\n",
    "\n",
    "    Parameters:\n",
    "        message (str): User's input message.\n",
    "        history (list): Past conversation history.\n",
    "\n",
    "    Returns:\n",
    "        str: Generated response from the Llama model.\n",
    "    \"\"\"\n",
    "    query = format_message(message, history)\n",
    "    response = \"\"\n",
    "\n",
    "    sequences = llama_pipeline(\n",
    "        query,\n",
    "        do_sample=True,\n",
    "        top_k=10,\n",
    "        num_return_sequences=1,\n",
    "        eos_token_id=tokenizer.eos_token_id,\n",
    "        max_length=1024,\n",
    "    )\n",
    "\n",
    "    generated_text = sequences[0]['generated_text']\n",
    "    response = generated_text[len(query):]  # Remove the prompt from the output\n",
    "\n",
    "    print(\"Chatbot:\", response.strip())\n",
    "    return response.strip()"
   ]
  },
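  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function can also be called directly before wiring it into a UI. The question below is an arbitrary example; since `do_sample=True`, the reply will vary between runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direct call without the UI; history starts empty on the first turn.\n",
    "first_reply = get_llama_response(\"Explain what a tokenizer does in one sentence.\", [])"
   ]
  },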
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "\n",
    "gr.ChatInterface(get_llama_response).launch()\n"
   ]
  }
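  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`launch()` serves the app locally by default. A temporary public URL can be requested with `share=True`, a standard Gradio option; the line below is left commented out because it is an alternative to the launch call above, not an additional one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative launch: also expose a temporary public share link.\n",
    "# gr.ChatInterface(get_llama_response).launch(share=True)"
   ]
  }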
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "itam",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}