barunsaha commited on
Commit
cf45a37
·
1 Parent(s): 80a7ca8

Add support for Gemini 1.5 Flash via Gemini API

Browse files
Files changed (5) hide show
  1. app.py +122 -110
  2. global_config.py +18 -3
  3. helpers/llm_helper.py +50 -31
  4. requirements.txt +1 -1
  5. strings.json +2 -1
app.py CHANGED
@@ -5,7 +5,6 @@ import datetime
5
  import logging
6
  import pathlib
7
  import random
8
- import sys
9
  import tempfile
10
  from typing import List, Union
11
 
@@ -17,9 +16,6 @@ from langchain_community.chat_message_histories import StreamlitChatMessageHisto
17
  from langchain_core.messages import HumanMessage
18
  from langchain_core.prompts import ChatPromptTemplate
19
 
20
- sys.path.append('..')
21
- sys.path.append('../..')
22
-
23
  from global_config import GlobalConfig
24
  from helpers import llm_helper, pptx_helper, text_helper
25
 
@@ -54,6 +50,60 @@ def _get_prompt_template(is_refinement: bool) -> str:
54
  return template
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  APP_TEXT = _load_strings()
58
 
59
  # Session variables
@@ -80,11 +130,8 @@ with st.sidebar:
80
  llm_provider_to_use = st.sidebar.selectbox(
81
  label='2: Select an LLM to use:',
82
  options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
83
- index=0,
84
- help=(
85
- 'LLM provider codes:\n\n'
86
- '- **[hf]**: Hugging Face Inference Endpoint\n'
87
- ),
88
  ).split(' ')[0]
89
 
90
  # The API key/access token
@@ -123,53 +170,28 @@ def set_up_chat_ui():
123
  with st.expander('Usage Instructions'):
124
  st.markdown(GlobalConfig.CHAT_USAGE_INSTRUCTIONS)
125
 
126
- st.info(
127
- 'If you like SlideDeck AI, please consider leaving a heart ❤️ on the'
128
- ' [Hugging Face Space](https://huggingface.co/spaces/barunsaha/slide-deck-ai/) or'
129
- ' a star ⭐ on [GitHub](https://github.com/barun-saha/slide-deck-ai).'
130
- ' Your [feedback](https://forms.gle/JECFBGhjvSj7moBx9) is appreciated.'
131
- )
132
-
133
- # view_messages = st.expander('View the messages in the session state')
134
-
135
- st.chat_message('ai').write(
136
- random.choice(APP_TEXT['ai_greetings'])
137
- )
138
 
139
  history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
140
-
141
- if _is_it_refinement():
142
- template = _get_prompt_template(is_refinement=True)
143
- else:
144
- template = _get_prompt_template(is_refinement=False)
145
-
146
- prompt_template = ChatPromptTemplate.from_template(template)
147
 
148
  # Since Streamlit app reloads at every interaction, display the chat history
149
  # from the save session state
150
  for msg in history.messages:
151
- msg_type = msg.type
152
- if msg_type == 'user':
153
- st.chat_message(msg_type).write(msg.content)
154
- else:
155
- st.chat_message(msg_type).code(msg.content, language='json')
156
 
157
  if prompt := st.chat_input(
158
  placeholder=APP_TEXT['chat_placeholder'],
159
  max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
160
  ):
161
- if not text_helper.is_valid_prompt(prompt):
162
- st.error(
163
- 'Not enough information provided!'
164
- ' Please be a little more descriptive and type a few words'
165
- ' with a few characters :)'
166
- )
167
- return
168
-
169
  provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
170
 
171
- if not provider or not llm_name:
172
- st.error('No valid LLM provider and/or model name found!')
173
  return
174
 
175
  logger.info(
@@ -178,72 +200,76 @@ def set_up_chat_ui():
178
  )
179
  st.chat_message('user').write(prompt)
180
 
181
- user_messages = _get_user_messages()
182
- user_messages.append(prompt)
183
- list_of_msgs = [
184
- f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
185
- ]
186
- list_of_msgs = '\n'.join(list_of_msgs)
187
-
188
  if _is_it_refinement():
 
 
 
 
 
189
  formatted_template = prompt_template.format(
190
  **{
191
- 'instructions': list_of_msgs,
192
  'previous_content': _get_last_response(),
193
  }
194
  )
195
  else:
196
- formatted_template = prompt_template.format(
197
- **{
198
- 'question': prompt,
199
- }
200
- )
201
 
202
  progress_bar = st.progress(0, 'Preparing to call LLM...')
203
  response = ''
204
 
205
  try:
206
- for chunk in llm_helper.get_langchain_llm(
207
- provider=provider,
208
- model=llm_name,
209
- max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
210
- api_key=api_key_token.strip(),
211
- ).stream(formatted_template):
212
- response += chunk
213
-
214
- # Update the progress bar
215
- progress_percentage = min(
216
- len(response) / GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'], 0.95
 
 
217
  )
 
 
 
 
 
 
218
  progress_bar.progress(
219
- progress_percentage,
 
 
 
 
 
220
  text='Streaming content...this might take a while...'
221
  )
222
  except requests.exceptions.ConnectionError:
223
- msg = (
224
  'A connection error occurred while streaming content from the LLM endpoint.'
225
  ' Unfortunately, the slide deck cannot be generated. Please try again later.'
226
- ' Alternatively, try selecting a different LLM from the dropdown list.'
 
227
  )
228
- logger.error(msg)
229
- st.error(msg)
230
  return
231
  except huggingface_hub.errors.ValidationError as ve:
232
- msg = (
233
  f'An error occurred while trying to generate the content: {ve}'
234
- '\nPlease try again with a significantly shorter input text.'
 
235
  )
236
- logger.error(msg)
237
- st.error(msg)
238
  return
239
  except Exception as ex:
240
- msg = (
241
  f'An unexpected error occurred while generating the content: {ex}'
242
  '\nPlease try again later, possibly with different inputs.'
243
- ' Alternatively, try selecting a different LLM from the dropdown list.'
 
244
  )
245
- logger.error(msg)
246
- st.error(msg)
247
  return
248
 
249
  history.add_user_message(prompt)
@@ -252,25 +278,20 @@ def set_up_chat_ui():
252
  # The content has been generated as JSON
253
  # There maybe trailing ``` at the end of the response -- remove them
254
  # To be careful: ``` may be part of the content as well when code is generated
255
- response_cleaned = text_helper.get_clean_json(response)
256
-
257
  logger.info(
258
- 'Cleaned JSON response:: original length: %d | cleaned length: %d',
259
- len(response), len(response_cleaned)
260
  )
261
- # logger.debug('Cleaned JSON: %s', response_cleaned)
262
 
263
  # Now create the PPT file
264
  progress_bar.progress(
265
  GlobalConfig.LLM_PROGRESS_MAX,
266
  text='Finding photos online and generating the slide deck...'
267
  )
268
- path = generate_slide_deck(response_cleaned)
269
  progress_bar.progress(1.0, text='Done!')
270
-
271
  st.chat_message('ai').code(response, language='json')
272
 
273
- if path:
274
  _display_download_button(path)
275
 
276
  logger.info(
@@ -291,44 +312,35 @@ def generate_slide_deck(json_str: str) -> Union[pathlib.Path, None]:
291
  try:
292
  parsed_data = json5.loads(json_str)
293
  except ValueError:
294
- st.error(
295
- 'Encountered error while parsing JSON...will fix it and retry'
296
- )
297
- logger.error(
298
- 'Caught ValueError: trying again after repairing JSON...'
299
  )
300
  try:
301
  parsed_data = json5.loads(text_helper.fix_malformed_json(json_str))
302
  except ValueError:
303
- st.error(
304
  'Encountered an error again while fixing JSON...'
305
  'the slide deck cannot be created, unfortunately ☹'
306
- '\nPlease try again later.'
 
307
  )
308
- logger.error(
309
- 'Caught ValueError: failed to repair JSON!'
310
- )
311
-
312
  return None
313
  except RecursionError:
314
- st.error(
315
- 'Encountered an error while parsing JSON...'
316
  'the slide deck cannot be created, unfortunately ☹'
317
- '\nPlease try again later.'
 
318
  )
319
- logger.error('Caught RecursionError while parsing JSON. Cannot generate the slide deck!')
320
-
321
  return None
322
  except Exception:
323
- st.error(
324
  'Encountered an error while parsing JSON...'
325
  'the slide deck cannot be created, unfortunately ☹'
326
- '\nPlease try again later.'
327
- )
328
- logger.error(
329
- 'Caught ValueError: failed to parse JSON!'
330
  )
331
-
332
  return None
333
 
334
  if DOWNLOAD_FILE_KEY in st.session_state:
 
5
  import logging
6
  import pathlib
7
  import random
 
8
  import tempfile
9
  from typing import List, Union
10
 
 
16
  from langchain_core.messages import HumanMessage
17
  from langchain_core.prompts import ChatPromptTemplate
18
 
 
 
 
19
  from global_config import GlobalConfig
20
  from helpers import llm_helper, pptx_helper, text_helper
21
 
 
50
  return template
51
 
52
 
53
+ def are_all_inputs_valid(
54
+ user_prompt: str,
55
+ selected_provider: str,
56
+ selected_model: str,
57
+ user_key: str,
58
+ ) -> bool:
59
+ """
60
+ Validate user input and LLM selection.
61
+
62
+ :param user_prompt: The prompt.
63
+ :param selected_provider: The LLM provider.
64
+ :param selected_model: Name of the model.
65
+ :param user_key: User-provided API key.
66
+ :return: `True` if all inputs "look" OK; `False` otherwise.
67
+ """
68
+
69
+ if not text_helper.is_valid_prompt(user_prompt):
70
+ handle_error(
71
+ 'Not enough information provided!'
72
+ ' Please be a little more descriptive and type a few words'
73
+ ' with a few characters :)',
74
+ False
75
+ )
76
+ return False
77
+
78
+ if not selected_provider or not selected_model:
79
+ handle_error('No valid LLM provider and/or model name found!', False)
80
+ return False
81
+
82
+ if not llm_helper.is_valid_llm_provider_model(selected_provider, selected_model, user_key):
83
+ handle_error(
84
+ 'The LLM settings do not look correct. Make sure that an API key/access token'
85
+ ' is provided if the selected LLM requires it.',
86
+ False
87
+ )
88
+ return False
89
+
90
+ return True
91
+
92
+
93
+ def handle_error(error_msg: str, should_log: bool):
94
+ """
95
+ Display an error message in the app.
96
+
97
+ :param error_msg: The error message to be displayed.
98
+ :param should_log: If `True`, log the message.
99
+ """
100
+
101
+ if should_log:
102
+ logger.error(error_msg)
103
+
104
+ st.error(error_msg)
105
+
106
+
107
  APP_TEXT = _load_strings()
108
 
109
  # Session variables
 
130
  llm_provider_to_use = st.sidebar.selectbox(
131
  label='2: Select an LLM to use:',
132
  options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
133
+ index=GlobalConfig.DEFAULT_MODEL_INDEX,
134
+ help=GlobalConfig.LLM_PROVIDER_HELP,
 
 
 
135
  ).split(' ')[0]
136
 
137
  # The API key/access token
 
170
  with st.expander('Usage Instructions'):
171
  st.markdown(GlobalConfig.CHAT_USAGE_INSTRUCTIONS)
172
 
173
+ st.info(APP_TEXT['like_feedback'])
174
+ st.chat_message('ai').write(random.choice(APP_TEXT['ai_greetings']))
 
 
 
 
 
 
 
 
 
 
175
 
176
  history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
177
+ prompt_template = ChatPromptTemplate.from_template(
178
+ _get_prompt_template(
179
+ is_refinement=_is_it_refinement()
180
+ )
181
+ )
 
 
182
 
183
  # Since Streamlit app reloads at every interaction, display the chat history
184
  # from the save session state
185
  for msg in history.messages:
186
+ st.chat_message(msg.type).code(msg.content, language='json')
 
 
 
 
187
 
188
  if prompt := st.chat_input(
189
  placeholder=APP_TEXT['chat_placeholder'],
190
  max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
191
  ):
 
 
 
 
 
 
 
 
192
  provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
193
 
194
+ if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
 
195
  return
196
 
197
  logger.info(
 
200
  )
201
  st.chat_message('user').write(prompt)
202
 
 
 
 
 
 
 
 
203
  if _is_it_refinement():
204
+ user_messages = _get_user_messages()
205
+ user_messages.append(prompt)
206
+ list_of_msgs = [
207
+ f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
208
+ ]
209
  formatted_template = prompt_template.format(
210
  **{
211
+ 'instructions': '\n'.join(list_of_msgs),
212
  'previous_content': _get_last_response(),
213
  }
214
  )
215
  else:
216
+ formatted_template = prompt_template.format(**{'question': prompt})
 
 
 
 
217
 
218
  progress_bar = st.progress(0, 'Preparing to call LLM...')
219
  response = ''
220
 
221
  try:
222
+ llm = llm_helper.get_langchain_llm(
223
+ provider=provider,
224
+ model=llm_name,
225
+ max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
226
+ api_key=api_key_token.strip(),
227
+ )
228
+
229
+ if not llm:
230
+ handle_error(
231
+ 'Failed to create an LLM instance! Make sure that you have selected the'
232
+ ' correct model from the dropdown list and have provided correct API key'
233
+ ' or access token.',
234
+ False
235
  )
236
+ return
237
+
238
+ for _ in llm.stream(formatted_template):
239
+ response += _
240
+
241
+ # Update the progress bar with an approx progress percentage
242
  progress_bar.progress(
243
+ min(
244
+ len(response) / GlobalConfig.VALID_MODELS[
245
+ llm_provider_to_use
246
+ ]['max_new_tokens'],
247
+ 0.95
248
+ ),
249
  text='Streaming content...this might take a while...'
250
  )
251
  except requests.exceptions.ConnectionError:
252
+ handle_error(
253
  'A connection error occurred while streaming content from the LLM endpoint.'
254
  ' Unfortunately, the slide deck cannot be generated. Please try again later.'
255
+ ' Alternatively, try selecting a different LLM from the dropdown list.',
256
+ True
257
  )
 
 
258
  return
259
  except huggingface_hub.errors.ValidationError as ve:
260
+ handle_error(
261
  f'An error occurred while trying to generate the content: {ve}'
262
+ '\nPlease try again with a significantly shorter input text.',
263
+ True
264
  )
 
 
265
  return
266
  except Exception as ex:
267
+ handle_error(
268
  f'An unexpected error occurred while generating the content: {ex}'
269
  '\nPlease try again later, possibly with different inputs.'
270
+ ' Alternatively, try selecting a different LLM from the dropdown list.',
271
+ True
272
  )
 
 
273
  return
274
 
275
  history.add_user_message(prompt)
 
278
  # The content has been generated as JSON
279
  # There maybe trailing ``` at the end of the response -- remove them
280
  # To be careful: ``` may be part of the content as well when code is generated
281
+ response = text_helper.get_clean_json(response)
 
282
  logger.info(
283
+ 'Cleaned JSON length: %d', len(response)
 
284
  )
 
285
 
286
  # Now create the PPT file
287
  progress_bar.progress(
288
  GlobalConfig.LLM_PROGRESS_MAX,
289
  text='Finding photos online and generating the slide deck...'
290
  )
 
291
  progress_bar.progress(1.0, text='Done!')
 
292
  st.chat_message('ai').code(response, language='json')
293
 
294
+ if path := generate_slide_deck(response):
295
  _display_download_button(path)
296
 
297
  logger.info(
 
312
  try:
313
  parsed_data = json5.loads(json_str)
314
  except ValueError:
315
+ handle_error(
316
+ 'Encountered error while parsing JSON...will fix it and retry',
317
+ True
 
 
318
  )
319
  try:
320
  parsed_data = json5.loads(text_helper.fix_malformed_json(json_str))
321
  except ValueError:
322
+ handle_error(
323
  'Encountered an error again while fixing JSON...'
324
  'the slide deck cannot be created, unfortunately ☹'
325
+ '\nPlease try again later.',
326
+ True
327
  )
 
 
 
 
328
  return None
329
  except RecursionError:
330
+ handle_error(
331
+ 'Encountered a recursion error while parsing JSON...'
332
  'the slide deck cannot be created, unfortunately ☹'
333
+ '\nPlease try again later.',
334
+ True
335
  )
 
 
336
  return None
337
  except Exception:
338
+ handle_error(
339
  'Encountered an error while parsing JSON...'
340
  'the slide deck cannot be created, unfortunately ☹'
341
+ '\nPlease try again later.',
342
+ True
 
 
343
  )
 
344
  return None
345
 
346
  if DOWNLOAD_FILE_KEY in st.session_state:
global_config.py CHANGED
@@ -17,17 +17,32 @@ class GlobalConfig:
17
  A data class holding the configurations.
18
  """
19
 
20
- VALID_PROVIDERS = {'hf'}
 
 
21
  VALID_MODELS = {
 
 
 
 
 
22
  '[hf]mistralai/Mistral-7B-Instruct-v0.2': {
23
  'description': 'faster, shorter',
24
- 'max_new_tokens': 8192
 
25
  },
26
  '[hf]mistralai/Mistral-Nemo-Instruct-2407': {
27
  'description': 'longer response',
28
- 'max_new_tokens': 12228
 
29
  },
30
  }
 
 
 
 
 
 
31
  LLM_MODEL_TEMPERATURE = 0.2
32
  LLM_MODEL_MIN_OUTPUT_LENGTH = 100
33
  LLM_MODEL_MAX_INPUT_LENGTH = 400 # characters
 
17
  A data class holding the configurations.
18
  """
19
 
20
+ PROVIDER_HUGGING_FACE = 'hf'
21
+ PROVIDER_GOOGLE_GEMINI = 'gg'
22
+ VALID_PROVIDERS = {PROVIDER_HUGGING_FACE, PROVIDER_GOOGLE_GEMINI}
23
  VALID_MODELS = {
24
+ '[gg]gemini-1.5-flash-002': {
25
+ 'description': 'faster response',
26
+ 'max_new_tokens': 8192,
27
+ 'paid': True,
28
+ },
29
  '[hf]mistralai/Mistral-7B-Instruct-v0.2': {
30
  'description': 'faster, shorter',
31
+ 'max_new_tokens': 8192,
32
+ 'paid': False,
33
  },
34
  '[hf]mistralai/Mistral-Nemo-Instruct-2407': {
35
  'description': 'longer response',
36
+ 'max_new_tokens': 10240,
37
+ 'paid': False,
38
  },
39
  }
40
+ LLM_PROVIDER_HELP = (
41
+ 'LLM provider codes:\n\n'
42
+ '- **[gg]**: Google Gemini API\n'
43
+ '- **[hf]**: Hugging Face Inference Endpoint\n'
44
+ )
45
+ DEFAULT_MODEL_INDEX = 1
46
  LLM_MODEL_TEMPERATURE = 0.2
47
  LLM_MODEL_MIN_OUTPUT_LENGTH = 100
48
  LLM_MODEL_MAX_INPUT_LENGTH = 400 # characters
helpers/llm_helper.py CHANGED
@@ -1,13 +1,18 @@
 
 
 
1
  import logging
2
  import re
 
3
  from typing import Tuple, Union
4
 
5
  import requests
6
  from requests.adapters import HTTPAdapter
7
  from urllib3.util import Retry
8
-
9
  from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
10
- from langchain_core.language_models import LLM
 
 
11
 
12
  from global_config import GlobalConfig
13
 
@@ -49,30 +54,26 @@ def get_provider_model(provider_model: str) -> Tuple[str, str]:
49
  return '', ''
50
 
51
 
52
- def get_hf_endpoint(repo_id: str, max_new_tokens: int, api_key: str = '') -> LLM:
53
  """
54
- Get an LLM via the HuggingFaceEndpoint of LangChain.
55
-
56
- :param repo_id: The model name.
57
- :param max_new_tokens: The max new tokens to generate.
58
- :param api_key: [Optional] Hugging Face access token.
59
- :return: The HF LLM inference endpoint.
 
 
60
  """
61
 
62
- logger.debug('Getting LLM via HF endpoint: %s', repo_id)
 
 
 
 
 
63
 
64
- return HuggingFaceEndpoint(
65
- repo_id=repo_id,
66
- max_new_tokens=max_new_tokens,
67
- top_k=40,
68
- top_p=0.95,
69
- temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
70
- repetition_penalty=1.03,
71
- streaming=True,
72
- huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
73
- return_full_text=False,
74
- stop_sequences=['</s>'],
75
- )
76
 
77
 
78
  def get_langchain_llm(
@@ -80,22 +81,19 @@ def get_langchain_llm(
80
  model: str,
81
  max_new_tokens: int,
82
  api_key: str = ''
83
- ) -> Union[LLM, None]:
84
  """
85
  Get an LLM based on the provider and model specified.
86
 
87
  :param provider: The LLM provider. Valid values are `hf` for Hugging Face.
88
- :param model:
89
- :param max_new_tokens:
90
- :param api_key:
91
- :return:
92
  """
93
- if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
94
- return None
95
 
96
- if provider == 'hf':
97
  logger.debug('Getting LLM via HF endpoint: %s', model)
98
-
99
  return HuggingFaceEndpoint(
100
  repo_id=model,
101
  max_new_tokens=max_new_tokens,
@@ -109,6 +107,27 @@ def get_langchain_llm(
109
  stop_sequences=['</s>'],
110
  )
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return None
113
 
114
 
 
1
+ """
2
+ Helper functions to access LLMs.
3
+ """
4
  import logging
5
  import re
6
+ import sys
7
  from typing import Tuple, Union
8
 
9
  import requests
10
  from requests.adapters import HTTPAdapter
11
  from urllib3.util import Retry
 
12
  from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
13
+ from langchain_core.language_models import BaseLLM
14
+
15
+ sys.path.append('..')
16
 
17
  from global_config import GlobalConfig
18
 
 
54
  return '', ''
55
 
56
 
57
+ def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool:
58
  """
59
+ Verify whether LLM settings are proper.
60
+ This function does not verify whether `api_key` is correct. It only confirms that the key has
61
+ at least five characters. Key verification is done when the LLM is created.
62
+
63
+ :param provider: Name of the LLM provider.
64
+ :param model: Name of the model.
65
+ :param api_key: The API key or access token.
66
+ :return: `True` if the settings "look" OK; `False` otherwise.
67
  """
68
 
69
+ if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
70
+ return False
71
+
72
+ if provider in [GlobalConfig.PROVIDER_GOOGLE_GEMINI, ]:
73
+ if not api_key or len(api_key) < 5:
74
+ return False
75
 
76
+ return True
 
 
 
 
 
 
 
 
 
 
 
77
 
78
 
79
  def get_langchain_llm(
 
81
  model: str,
82
  max_new_tokens: int,
83
  api_key: str = ''
84
+ ) -> Union[BaseLLM, None]:
85
  """
86
  Get an LLM based on the provider and model specified.
87
 
88
  :param provider: The LLM provider. Valid values are `hf` for Hugging Face.
89
+ :param model: The name of the LLM.
90
+ :param max_new_tokens: The maximum number of tokens to generate.
91
+ :param api_key: API key or access token to use.
92
+ :return: An instance of the LLM or `None` in case of any error.
93
  """
 
 
94
 
95
+ if provider == GlobalConfig.PROVIDER_HUGGING_FACE:
96
  logger.debug('Getting LLM via HF endpoint: %s', model)
 
97
  return HuggingFaceEndpoint(
98
  repo_id=model,
99
  max_new_tokens=max_new_tokens,
 
107
  stop_sequences=['</s>'],
108
  )
109
 
110
+ if provider == GlobalConfig.PROVIDER_GOOGLE_GEMINI:
111
+ from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
112
+ from langchain_google_genai import GoogleGenerativeAI
113
+
114
+ return GoogleGenerativeAI(
115
+ model=model,
116
+ temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
117
+ max_tokens=max_new_tokens,
118
+ timeout=None,
119
+ max_retries=2,
120
+ google_api_key=api_key,
121
+ safety_settings={
122
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
123
+ HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
124
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
125
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
126
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
127
+ HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
128
+ }
129
+ )
130
+
131
  return None
132
 
133
 
requirements.txt CHANGED
@@ -10,6 +10,7 @@ pydantic==2.9.1
10
  langchain~=0.3.7
11
  langchain-core~=0.3.0
12
  langchain-community==0.3.0
 
13
  streamlit~=1.38.0
14
 
15
  python-pptx
@@ -19,7 +20,6 @@ requests~=2.32.3
19
 
20
  transformers~=4.44.0
21
  torch==2.4.0
22
- langchain-community
23
 
24
  urllib3~=2.2.1
25
  lxml~=4.9.3
 
10
  langchain~=0.3.7
11
  langchain-core~=0.3.0
12
  langchain-community==0.3.0
13
+ langchain-google-genai==2.0.6
14
  streamlit~=1.38.0
15
 
16
  python-pptx
 
20
 
21
  transformers~=4.44.0
22
  torch==2.4.0
 
23
 
24
  urllib3~=2.2.1
25
  lxml~=4.9.3
strings.json CHANGED
@@ -33,5 +33,6 @@
33
  "Looks like you have a looming deadline. Can I help you get started with your slide deck?",
34
  "Hello! What topic do you have on your mind today?"
35
  ],
36
- "chat_placeholder": "Write the topic or instructions here"
 
37
  }
 
33
  "Looks like you have a looming deadline. Can I help you get started with your slide deck?",
34
  "Hello! What topic do you have on your mind today?"
35
  ],
36
+ "chat_placeholder": "Write the topic or instructions here",
37
+ "like_feedback": "If you like SlideDeck AI, please consider leaving a heart ❤\uFE0F on the [Hugging Face Space](https://huggingface.co/spaces/barunsaha/slide-deck-ai/) or a star ⭐ on [GitHub](https://github.com/barun-saha/slide-deck-ai). Your [feedback](https://forms.gle/JECFBGhjvSj7moBx9) is appreciated."
38
  }