barunsaha committed
Commit 44d6df8
1 Parent(s): cf45a37

Add support for Cohere

Files changed (5):
  1. app.py +8 -6
  2. global_config.py +10 -3
  3. helpers/llm_helper.py +19 -1
  4. helpers/pptx_helper.py +27 -22
  5. requirements.txt +1 -0
app.py CHANGED
@@ -138,11 +138,11 @@ with st.sidebar:
     api_key_token = st.text_input(
         label=(
             '3: Paste your API key/access token:\n\n'
-            '*Optional* if an HF Mistral LLM is selected from the list but still encouraged.\n\n'
+            '*Mandatory* for Cohere and Gemini LLMs.'
+            ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
         ),
         type='password',
     )
-    st.caption('(Wrong HF access token will lead to validation error)')
 
 
 def build_ui():
@@ -152,9 +152,9 @@ def build_ui():
 
     st.title(APP_TEXT['app_name'])
     st.subheader(APP_TEXT['caption'])
-    # st.markdown(
-    #     '![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbarunsaha%2Fslide-deck-ai&countColor=%23263759)'  # noqa: E501
-    # )
+    st.markdown(
+        '![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbarunsaha%2Fslide-deck-ai&countColor=%23263759)'  # noqa: E501
+    )
 
     with st.expander('Usage Policies and Limitations'):
         st.text(APP_TEXT['tos'] + '\n\n' + APP_TEXT['tos2'])
@@ -267,7 +267,9 @@ def set_up_chat_ui():
             handle_error(
                 f'An unexpected error occurred while generating the content: {ex}'
                 '\nPlease try again later, possibly with different inputs.'
-                ' Alternatively, try selecting a different LLM from the dropdown list.',
+                ' Alternatively, try selecting a different LLM from the dropdown list.'
+                ' If you are using Cohere or Gemini models, make sure that you have provided'
+                ' a correct API key.',
                 True
             )
             return
global_config.py CHANGED
@@ -17,10 +17,16 @@ class GlobalConfig:
     A data class holding the configurations.
     """
 
-    PROVIDER_HUGGING_FACE = 'hf'
+    PROVIDER_COHERE = 'co'
     PROVIDER_GOOGLE_GEMINI = 'gg'
-    VALID_PROVIDERS = {PROVIDER_HUGGING_FACE, PROVIDER_GOOGLE_GEMINI}
+    PROVIDER_HUGGING_FACE = 'hf'
+    VALID_PROVIDERS = {PROVIDER_COHERE, PROVIDER_GOOGLE_GEMINI, PROVIDER_HUGGING_FACE}
     VALID_MODELS = {
+        '[co]command-r-08-2024': {
+            'description': 'simpler, slower',
+            'max_new_tokens': 4096,
+            'paid': True,
+        },
         '[gg]gemini-1.5-flash-002': {
             'description': 'faster response',
             'max_new_tokens': 8192,
@@ -39,10 +45,11 @@ class GlobalConfig:
     }
     LLM_PROVIDER_HELP = (
         'LLM provider codes:\n\n'
+        '- **[co]**: Cohere\n'
         '- **[gg]**: Google Gemini API\n'
         '- **[hf]**: Hugging Face Inference Endpoint\n'
     )
-    DEFAULT_MODEL_INDEX = 1
+    DEFAULT_MODEL_INDEX = 2
     LLM_MODEL_TEMPERATURE = 0.2
     LLM_MODEL_MIN_OUTPUT_LENGTH = 100
     LLM_MODEL_MAX_INPUT_LENGTH = 400  # characters
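
Two details worth noting: VALID_MODELS keys encode the provider as a bracketed prefix, so the new Cohere entry follows the existing '[provider]model' convention, and DEFAULT_MODEL_INDEX moves from 1 to 2 because the '[co]' entry now sits first and shifts the existing models down by one position. A minimal sketch of how such a key could be split into its parts (a hypothetical helper for illustration only; the app's actual parsing may differ):

import re

# Hypothetical illustration of the '[provider]model' key convention
# used by VALID_MODELS; not necessarily the app's real parser.
def parse_provider_model(key: str) -> tuple:
    match = re.match(r'\[(.+?)\](.+)', key)
    return (match.group(1), match.group(2)) if match else (None, None)

print(parse_provider_model('[co]command-r-08-2024'))  # ('co', 'command-r-08-2024')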
helpers/llm_helper.py CHANGED
@@ -9,7 +9,6 @@ from typing import Tuple, Union
 import requests
 from requests.adapters import HTTPAdapter
 from urllib3.util import Retry
-from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain_core.language_models import BaseLLM
 
 sys.path.append('..')
@@ -22,6 +21,8 @@ HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
 REQUEST_TIMEOUT = 35
 
 logger = logging.getLogger(__name__)
+logging.getLogger('httpx').setLevel(logging.WARNING)
+logging.getLogger('httpcore').setLevel(logging.WARNING)
 
 retries = Retry(
     total=5,
@@ -93,6 +94,8 @@ def get_langchain_llm(
     """
 
     if provider == GlobalConfig.PROVIDER_HUGGING_FACE:
+        from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+
         logger.debug('Getting LLM via HF endpoint: %s', model)
         return HuggingFaceEndpoint(
             repo_id=model,
@@ -111,6 +114,7 @@ def get_langchain_llm(
         from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
         from langchain_google_genai import GoogleGenerativeAI
 
+        logger.debug('Getting LLM via Google Gemini: %s', model)
         return GoogleGenerativeAI(
             model=model,
             temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
@@ -128,11 +132,25 @@ def get_langchain_llm(
             }
         )
 
+    if provider == GlobalConfig.PROVIDER_COHERE:
+        from langchain_cohere.llms import Cohere
+
+        logger.debug('Getting LLM via Cohere: %s', model)
+        return Cohere(
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            max_tokens=max_new_tokens,
+            timeout_seconds=None,
+            max_retries=2,
+            cohere_api_key=api_key,
+            streaming=True,
+        )
+
     return None
 
 
 if __name__ == '__main__':
     inputs = [
+        '[co]Cohere',
         '[hf]mistralai/Mistral-7B-Instruct-v0.2',
         '[gg]gemini-1.5-flash-002'
     ]
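
For a quick sanity check of the new Cohere branch, here is a minimal usage sketch. The keyword signature of get_langchain_llm and the COHERE_API_KEY environment variable are assumptions inferred from the parameters visible in the diff above:

import os

from global_config import GlobalConfig
from helpers.llm_helper import get_langchain_llm

# Assumed signature: get_langchain_llm(provider, model, max_new_tokens, api_key).
# The app itself passes the key typed into the sidebar, not an env variable.
llm = get_langchain_llm(
    provider=GlobalConfig.PROVIDER_COHERE,
    model='command-r-08-2024',
    max_new_tokens=GlobalConfig.VALID_MODELS['[co]command-r-08-2024']['max_new_tokens'],
    api_key=os.environ.get('COHERE_API_KEY', ''),
)

if llm:
    # streaming=True, so stream() yields text chunks as they arrive
    for chunk in llm.stream('List three key points about solar energy.'):
        print(chunk, end='')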
helpers/pptx_helper.py CHANGED
@@ -115,37 +115,42 @@ def generate_powerpoint_presentation(
 
     # Add content in a loop
     for a_slide in parsed_data['slides']:
-        is_processing_done = _handle_icons_ideas(
-            presentation=presentation,
-            slide_json=a_slide,
-            slide_width_inch=slide_width_inch,
-            slide_height_inch=slide_height_inch
-        )
-
-        if not is_processing_done:
-            is_processing_done = _handle_double_col_layout(
+        try:
+            is_processing_done = _handle_icons_ideas(
                 presentation=presentation,
                 slide_json=a_slide,
                 slide_width_inch=slide_width_inch,
                 slide_height_inch=slide_height_inch
             )
 
-        if not is_processing_done:
-            is_processing_done = _handle_step_by_step_process(
-                presentation=presentation,
-                slide_json=a_slide,
-                slide_width_inch=slide_width_inch,
-                slide_height_inch=slide_height_inch
-            )
+            if not is_processing_done:
+                is_processing_done = _handle_double_col_layout(
+                    presentation=presentation,
+                    slide_json=a_slide,
+                    slide_width_inch=slide_width_inch,
+                    slide_height_inch=slide_height_inch
+                )
 
-        if not is_processing_done:
-            _handle_default_display(
-                presentation=presentation,
-                slide_json=a_slide,
-                slide_width_inch=slide_width_inch,
-                slide_height_inch=slide_height_inch
+            if not is_processing_done:
+                is_processing_done = _handle_step_by_step_process(
+                    presentation=presentation,
+                    slide_json=a_slide,
+                    slide_width_inch=slide_width_inch,
+                    slide_height_inch=slide_height_inch
+                )
+
+            if not is_processing_done:
+                _handle_default_display(
+                    presentation=presentation,
+                    slide_json=a_slide,
+                    slide_width_inch=slide_width_inch,
+                    slide_height_inch=slide_height_inch
             )
 
+        except Exception:
+            # In case of any unforeseen error, try to salvage what is available
+            continue
+
     # The thank-you slide
     last_slide_layout = presentation.slide_layouts[0]
     slide = presentation.slides.add_slide(last_slide_layout)
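
The restructured loop is a fallback chain wrapped in a per-slide salvage guard: the specialized handlers are tried in order, the default display is the last resort, and a slide whose handler raises is skipped instead of aborting the whole deck. Schematically (a sketch with invented names, not the module's actual code):

from typing import Callable, Iterable

def render_slides(
        slides: Iterable[dict],
        handlers: list,
        default_handler: Callable,
):
    # Each handler returns True if it recognized and rendered the slide.
    for slide in slides:
        try:
            if not any(handler(slide) for handler in handlers):
                default_handler(slide)
        except Exception:
            # Salvage what is available: skip only the failing slide
            continue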
requirements.txt CHANGED
@@ -11,6 +11,7 @@ langchain~=0.3.7
 langchain-core~=0.3.0
 langchain-community==0.3.0
 langchain-google-genai==2.0.6
+langchain-cohere==0.3.3
 streamlit~=1.38.0
 
 python-pptx
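
With the new pin in place, the Cohere dependency installs together with the rest of the stack via pip install -r requirements.txt.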