barunsaha commited on
Commit
89a3160
·
unverified ·
2 Parent(s): 89c5253 50f37bd

Merge pull request #62 from barun-saha/ollama

Browse files
Files changed (5) hide show
  1. README.md +29 -0
  2. app.py +59 -26
  3. global_config.py +22 -2
  4. helpers/llm_helper.py +30 -9
  5. requirements.txt +6 -1
README.md CHANGED
@@ -47,6 +47,8 @@ Different LLMs offer different styles of content generation. Use one of the foll
47
 
48
  The Mistral models do not mandatorily require an access token. However, you are encouraged to get and use your own Hugging Face access token.
49
 
 
 
50
 
51
  # Icons
52
 
@@ -62,6 +64,33 @@ To run this project by yourself, you need to provide the `HUGGINGFACEHUB_API_TOK
62
  for example, in a `.env` file. Alternatively, you can provide the access token in the app's user interface itself (UI). For other LLM providers, the API key can only be specified in the UI. For image search, the `PEXEL_API_KEY` should be made available as an environment variable.
63
  Visit the respective websites to obtain the API keys.
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Live Demo
67
 
 
47
 
48
  The Mistral models do not mandatorily require an access token. However, you are encouraged to get and use your own Hugging Face access token.
49
 
50
+ In addition, offline LLMs provided by Ollama can be used. Read below to know more.
51
+
52
 
53
  # Icons
54
 
 
64
  for example, in a `.env` file. Alternatively, you can provide the access token in the app's user interface itself (UI). For other LLM providers, the API key can only be specified in the UI. For image search, the `PEXEL_API_KEY` should be made available as an environment variable.
65
  Visit the respective websites to obtain the API keys.
66
 
67
+ ## Offline LLMs Using Ollama
68
+
69
+ SlideDeck AI allows the use of offline LLMs to generate the contents of the slide decks. This is typically suitable for individuals or organizations who would like to use self-hosted LLMs for privacy concerns, for example.
70
+
71
+ Offline LLMs are made available via Ollama. Therefore, a pre-requisite here is to have [Ollama installed](https://ollama.com/download) on the system and the desired [LLM](https://ollama.com/search) pulled locally.
72
+
73
+ In addition, the `RUN_IN_OFFLINE_MODE` environment variable needs to be set to `True` to enable the offline mode. This, for example, can be done using a `.env` file or from the terminal. The typical steps to use SlideDeck AI in offline mode (in a `bash` shell) are as follows:
74
+
75
+ ```bash
76
+ ollama list # View locally available LLMs
77
+ export RUN_IN_OFFLINE_MODE=True # Enable the offline mode to use Ollama
78
+ git clone https://github.com/barun-saha/slide-deck-ai.git
79
+ cd slide-deck-ai
80
+ python -m venv venv # Create a virtual environment
81
+ source venv/bin/activate # On a Linux system
82
+ pip install -r requirements.txt
83
+ streamlit run ./app.py # Run the application
84
+ ```
85
+
86
+ The `.env` file should be created inside the `slide-deck-ai` directory.
87
+
88
+ The UI is similar to the online mode. However, rather than selecting an LLM from a list, one has to write the name of the Ollama model to be used in a textbox. There is no API key asked here.
89
+
90
+ The online and offline modes are mutually exclusive. So, setting `RUN_IN_OFFLINE_MODE` to `False` will make SlideDeck AI use the online LLMs (i.e., the "original mode."). By default, `RUN_IN_OFFLINE_MODE` is set to `False`.
91
+
92
+ Finally, the focus is on using offline LLMs, not going completely offline. So, Internet connectivity would still be required to fetch the images from Pexels.
93
+
94
 
95
  # Live Demo
96
 
app.py CHANGED
@@ -3,23 +3,34 @@ Streamlit app containing the UI and the application logic.
3
  """
4
  import datetime
5
  import logging
 
6
  import pathlib
7
  import random
8
  import tempfile
9
  from typing import List, Union
10
 
 
11
  import huggingface_hub
12
  import json5
 
13
  import requests
14
  import streamlit as st
 
15
  from langchain_community.chat_message_histories import StreamlitChatMessageHistory
16
  from langchain_core.messages import HumanMessage
17
  from langchain_core.prompts import ChatPromptTemplate
18
 
 
19
  from global_config import GlobalConfig
20
  from helpers import llm_helper, pptx_helper, text_helper
21
 
22
 
 
 
 
 
 
 
23
  @st.cache_data
24
  def _load_strings() -> dict:
25
  """
@@ -135,25 +146,36 @@ with st.sidebar:
135
  horizontal=True
136
  )
137
 
138
- # The LLMs
139
- llm_provider_to_use = st.sidebar.selectbox(
140
- label='2: Select an LLM to use:',
141
- options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
142
- index=GlobalConfig.DEFAULT_MODEL_INDEX,
143
- help=GlobalConfig.LLM_PROVIDER_HELP,
144
- on_change=reset_api_key
145
- ).split(' ')[0]
146
-
147
- # The API key/access token
148
- api_key_token = st.text_input(
149
- label=(
150
- '3: Paste your API key/access token:\n\n'
151
- '*Mandatory* for Cohere and Gemini LLMs.'
152
- ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
153
- ),
154
- type='password',
155
- key='api_key_input'
156
- )
 
 
 
 
 
 
 
 
 
 
 
157
 
158
 
159
  def build_ui():
@@ -200,7 +222,11 @@ def set_up_chat_ui():
200
  placeholder=APP_TEXT['chat_placeholder'],
201
  max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
202
  ):
203
- provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
 
 
 
 
204
 
205
  if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
206
  return
@@ -233,7 +259,7 @@ def set_up_chat_ui():
233
  llm = llm_helper.get_langchain_llm(
234
  provider=provider,
235
  model=llm_name,
236
- max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
237
  api_key=api_key_token.strip(),
238
  )
239
 
@@ -252,18 +278,17 @@ def set_up_chat_ui():
252
  # Update the progress bar with an approx progress percentage
253
  progress_bar.progress(
254
  min(
255
- len(response) / GlobalConfig.VALID_MODELS[
256
- llm_provider_to_use
257
- ]['max_new_tokens'],
258
  0.95
259
  ),
260
  text='Streaming content...this might take a while...'
261
  )
262
- except requests.exceptions.ConnectionError:
263
  handle_error(
264
  'A connection error occurred while streaming content from the LLM endpoint.'
265
  ' Unfortunately, the slide deck cannot be generated. Please try again later.'
266
- ' Alternatively, try selecting a different LLM from the dropdown list.',
 
267
  True
268
  )
269
  return
@@ -274,6 +299,14 @@ def set_up_chat_ui():
274
  True
275
  )
276
  return
 
 
 
 
 
 
 
 
277
  except Exception as ex:
278
  handle_error(
279
  f'An unexpected error occurred while generating the content: {ex}'
 
3
  """
4
  import datetime
5
  import logging
6
+ import os
7
  import pathlib
8
  import random
9
  import tempfile
10
  from typing import List, Union
11
 
12
+ import httpx
13
  import huggingface_hub
14
  import json5
15
+ import ollama
16
  import requests
17
  import streamlit as st
18
+ from dotenv import load_dotenv
19
  from langchain_community.chat_message_histories import StreamlitChatMessageHistory
20
  from langchain_core.messages import HumanMessage
21
  from langchain_core.prompts import ChatPromptTemplate
22
 
23
+ import global_config as gcfg
24
  from global_config import GlobalConfig
25
  from helpers import llm_helper, pptx_helper, text_helper
26
 
27
 
28
+ load_dotenv()
29
+
30
+
31
+ RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true'
32
+
33
+
34
  @st.cache_data
35
  def _load_strings() -> dict:
36
  """
 
146
  horizontal=True
147
  )
148
 
149
+ if RUN_IN_OFFLINE_MODE:
150
+ llm_provider_to_use = st.text_input(
151
+ label='2: Enter Ollama model name to use:',
152
+ help=(
153
+ 'Specify a correct, locally available LLM, found by running `ollama list`, for'
154
+ ' example `mistral:v0.2` and `mistral-nemo:latest`. Having an Ollama-compatible'
155
+ ' and supported GPU is strongly recommended.'
156
+ )
157
+ )
158
+ api_key_token: str = ''
159
+ else:
160
+ # The LLMs
161
+ llm_provider_to_use = st.sidebar.selectbox(
162
+ label='2: Select an LLM to use:',
163
+ options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
164
+ index=GlobalConfig.DEFAULT_MODEL_INDEX,
165
+ help=GlobalConfig.LLM_PROVIDER_HELP,
166
+ on_change=reset_api_key
167
+ ).split(' ')[0]
168
+
169
+ # The API key/access token
170
+ api_key_token = st.text_input(
171
+ label=(
172
+ '3: Paste your API key/access token:\n\n'
173
+ '*Mandatory* for Cohere and Gemini LLMs.'
174
+ ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
175
+ ),
176
+ type='password',
177
+ key='api_key_input'
178
+ )
179
 
180
 
181
  def build_ui():
 
222
  placeholder=APP_TEXT['chat_placeholder'],
223
  max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
224
  ):
225
+ provider, llm_name = llm_helper.get_provider_model(
226
+ llm_provider_to_use,
227
+ use_ollama=RUN_IN_OFFLINE_MODE
228
+ )
229
+ print(f'{llm_provider_to_use=}, {provider=}, {llm_name=}, {api_key_token=}')
230
 
231
  if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
232
  return
 
259
  llm = llm_helper.get_langchain_llm(
260
  provider=provider,
261
  model=llm_name,
262
+ max_new_tokens=gcfg.get_max_output_tokens(llm_provider_to_use),
263
  api_key=api_key_token.strip(),
264
  )
265
 
 
278
  # Update the progress bar with an approx progress percentage
279
  progress_bar.progress(
280
  min(
281
+ len(response) / gcfg.get_max_output_tokens(llm_provider_to_use),
 
 
282
  0.95
283
  ),
284
  text='Streaming content...this might take a while...'
285
  )
286
+ except (httpx.ConnectError, requests.exceptions.ConnectionError):
287
  handle_error(
288
  'A connection error occurred while streaming content from the LLM endpoint.'
289
  ' Unfortunately, the slide deck cannot be generated. Please try again later.'
290
+ ' Alternatively, try selecting a different LLM from the dropdown list. If you are'
291
+ ' using Ollama, make sure that Ollama is already running on your system.',
292
  True
293
  )
294
  return
 
299
  True
300
  )
301
  return
302
+ except ollama.ResponseError:
303
+ handle_error(
304
+ f'The model `{llm_name}` is unavailable with Ollama on your system.'
305
+ f' Make sure that you have provided the correct LLM name or pull it using'
306
+ f' `ollama pull {llm_name}`. View LLMs available locally by running `ollama list`.',
307
+ True
308
+ )
309
+ return
310
  except Exception as ex:
311
  handle_error(
312
  f'An unexpected error occurred while generating the content: {ex}'
global_config.py CHANGED
@@ -20,7 +20,13 @@ class GlobalConfig:
20
  PROVIDER_COHERE = 'co'
21
  PROVIDER_GOOGLE_GEMINI = 'gg'
22
  PROVIDER_HUGGING_FACE = 'hf'
23
- VALID_PROVIDERS = {PROVIDER_COHERE, PROVIDER_GOOGLE_GEMINI, PROVIDER_HUGGING_FACE}
 
 
 
 
 
 
24
  VALID_MODELS = {
25
  '[co]command-r-08-2024': {
26
  'description': 'simpler, slower',
@@ -47,7 +53,7 @@ class GlobalConfig:
47
  'LLM provider codes:\n\n'
48
  '- **[co]**: Cohere\n'
49
  '- **[gg]**: Google Gemini API\n'
50
- '- **[hf]**: Hugging Face Inference Endpoint\n'
51
  )
52
  DEFAULT_MODEL_INDEX = 2
53
  LLM_MODEL_TEMPERATURE = 0.2
@@ -125,3 +131,17 @@ logging.basicConfig(
125
  format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
126
  datefmt='%Y-%m-%d %H:%M:%S'
127
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  PROVIDER_COHERE = 'co'
21
  PROVIDER_GOOGLE_GEMINI = 'gg'
22
  PROVIDER_HUGGING_FACE = 'hf'
23
+ PROVIDER_OLLAMA = 'ol'
24
+ VALID_PROVIDERS = {
25
+ PROVIDER_COHERE,
26
+ PROVIDER_GOOGLE_GEMINI,
27
+ PROVIDER_HUGGING_FACE,
28
+ PROVIDER_OLLAMA
29
+ }
30
  VALID_MODELS = {
31
  '[co]command-r-08-2024': {
32
  'description': 'simpler, slower',
 
53
  'LLM provider codes:\n\n'
54
  '- **[co]**: Cohere\n'
55
  '- **[gg]**: Google Gemini API\n'
56
+ '- **[hf]**: Hugging Face Inference API\n'
57
  )
58
  DEFAULT_MODEL_INDEX = 2
59
  LLM_MODEL_TEMPERATURE = 0.2
 
131
  format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
132
  datefmt='%Y-%m-%d %H:%M:%S'
133
  )
134
+
135
+
136
+ def get_max_output_tokens(llm_name: str) -> int:
137
+ """
138
+ Get the max output tokens value configured for an LLM. Return a default value if not configured.
139
+
140
+ :param llm_name: The name of the LLM.
141
+ :return: Max output tokens or a default count.
142
+ """
143
+
144
+ try:
145
+ return GlobalConfig.VALID_MODELS[llm_name]['max_new_tokens']
146
+ except KeyError:
147
+ return 2048
helpers/llm_helper.py CHANGED
@@ -17,8 +17,9 @@ from global_config import GlobalConfig
17
 
18
 
19
  LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
 
20
  # 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores
21
- API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9\-_]{6,64}$')
22
  HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
23
  REQUEST_TIMEOUT = 35
24
 
@@ -39,20 +40,28 @@ http_session.mount('https://', adapter)
39
  http_session.mount('http://', adapter)
40
 
41
 
42
- def get_provider_model(provider_model: str) -> Tuple[str, str]:
43
  """
44
  Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
45
 
46
  :param provider_model: The provider, model name string from `GlobalConfig`.
47
- :return: The provider and the model name.
 
48
  """
49
 
50
- match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
51
 
52
- if match:
53
- inside_brackets = match.group(1)
54
- outside_brackets = match.group(2)
55
- return inside_brackets, outside_brackets
 
 
 
 
 
 
 
56
 
57
  return '', ''
58
 
@@ -152,6 +161,18 @@ def get_langchain_llm(
152
  streaming=True,
153
  )
154
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  return None
156
 
157
 
@@ -163,4 +184,4 @@ if __name__ == '__main__':
163
  ]
164
 
165
  for text in inputs:
166
- print(get_provider_model(text))
 
17
 
18
 
19
  LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
20
+ OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')
21
  # 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores
22
+ API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,64}$')
23
  HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
24
  REQUEST_TIMEOUT = 35
25
 
 
40
  http_session.mount('http://', adapter)
41
 
42
 
43
+ def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]:
44
  """
45
  Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
46
 
47
  :param provider_model: The provider, model name string from `GlobalConfig`.
48
+ :param use_ollama: Whether Ollama is used (i.e., running in offline mode).
49
+ :return: The provider and the model name; empty strings in case no matching pattern found.
50
  """
51
 
52
+ provider_model = provider_model.strip()
53
 
54
+ if use_ollama:
55
+ match = OLLAMA_MODEL_REGEX.match(provider_model)
56
+ if match:
57
+ return GlobalConfig.PROVIDER_OLLAMA, match.group(0)
58
+ else:
59
+ match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
60
+
61
+ if match:
62
+ inside_brackets = match.group(1)
63
+ outside_brackets = match.group(2)
64
+ return inside_brackets, outside_brackets
65
 
66
  return '', ''
67
 
 
161
  streaming=True,
162
  )
163
 
164
+ if provider == GlobalConfig.PROVIDER_OLLAMA:
165
+ from langchain_ollama.llms import OllamaLLM
166
+
167
+ logger.debug('Getting LLM via Ollama: %s', model)
168
+ return OllamaLLM(
169
+ model=model,
170
+ temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
171
+ num_predict=max_new_tokens,
172
+ format='json',
173
+ streaming=True,
174
+ )
175
+
176
  return None
177
 
178
 
 
184
  ]
185
 
186
  for text in inputs:
187
+ print(get_provider_model(text, use_ollama=False))
requirements.txt CHANGED
@@ -12,9 +12,10 @@ langchain-core~=0.3.0
12
  langchain-community==0.3.0
13
  langchain-google-genai==2.0.6
14
  langchain-cohere==0.3.3
 
15
  streamlit~=1.38.0
16
 
17
- python-pptx
18
  # metaphor-python
19
  json5~=0.9.14
20
  requests~=2.32.3
@@ -32,3 +33,7 @@ certifi==2024.8.30
32
  urllib3==2.2.3
33
 
34
  anyio==4.4.0
 
 
 
 
 
12
  langchain-community==0.3.0
13
  langchain-google-genai==2.0.6
14
  langchain-cohere==0.3.3
15
+ langchain-ollama==0.2.1
16
  streamlit~=1.38.0
17
 
18
+ python-pptx~=0.6.21
19
  # metaphor-python
20
  json5~=0.9.14
21
  requests~=2.32.3
 
33
  urllib3==2.2.3
34
 
35
  anyio==4.4.0
36
+
37
+ httpx~=0.27.2
38
+ huggingface-hub~=0.24.5
39
+ ollama~=0.4.3