yonikremer commited on
Commit
b1dd47e
·
1 Parent(s): 15bf463

Added if check that the model is supported

Browse files
Files changed (4) hide show
  1. app.py +8 -8
  2. hanlde_form_submit.py +10 -0
  3. requirements.txt +4 -3
  4. supported_models.py +33 -40
app.py CHANGED
@@ -12,7 +12,6 @@ from on_server_start import main as on_server_start_main
12
 
13
  on_server_start_main()
14
 
15
- AVAILABLE_MODEL_NAMES = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
16
 
17
  st.title("Grouped Sampling Demo")
18
 
@@ -22,8 +21,6 @@ with st.form("request_form"):
22
  label="Model name",
23
  value="gpt2",
24
  help=f"The name of the model to use."
25
- f" Must be a model from this list:"
26
- f" {AVAILABLE_MODEL_NAMES}"
27
  )
28
 
29
  output_length: int = st.number_input(
@@ -35,18 +32,21 @@ with st.form("request_form"):
35
  )
36
 
37
  submitted_prompt: str = st.text_area(
38
- label="Input for the model",
39
  help="Enter the prompt for the model. The model will generate a response based on this prompt.",
40
  max_chars=16384,
 
41
  )
42
 
43
  submitted: bool = st.form_submit_button(
44
  label="Generate",
45
  help="Generate the output text.",
46
- disabled=False
47
-
48
  )
49
 
50
  if submitted:
51
- output = on_form_submit(selected_model_name, output_length, submitted_prompt)
52
- st.write(f"Generated text: {output}")
 
 
 
 
12
 
13
  on_server_start_main()
14
 
 
15
 
16
  st.title("Grouped Sampling Demo")
17
 
 
21
  label="Model name",
22
  value="gpt2",
23
  help=f"The name of the model to use."
 
 
24
  )
25
 
26
  output_length: int = st.number_input(
 
32
  )
33
 
34
  submitted_prompt: str = st.text_area(
35
+ label="Input for the model, It is highly recommended to write an English prompt.",
36
  help="Enter the prompt for the model. The model will generate a response based on this prompt.",
37
  max_chars=16384,
38
+ min_chars=16,
39
  )
40
 
41
  submitted: bool = st.form_submit_button(
42
  label="Generate",
43
  help="Generate the output text.",
44
+ disabled=False,
 
45
  )
46
 
47
  if submitted:
48
+ try:
49
+ output = on_form_submit(selected_model_name, output_length, submitted_prompt)
50
+ st.write(f"Generated text: {output}")
51
+ except ValueError as e:
52
+ st.error(e)
hanlde_form_submit.py CHANGED
@@ -2,6 +2,12 @@ import streamlit as st
2
  from grouped_sampling import GroupedSamplingPipeLine
3
 
4
 
 
 
 
 
 
 
5
  def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
6
  """
7
  Creates a pipeline with the given model name and group size.
@@ -25,6 +31,10 @@ def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
25
  :param prompt: The prompt to use.
26
  :return: The output of the model.
27
  """
 
 
 
 
28
  pipeline = create_pipeline(
29
  model_name,
30
  group_size,
 
2
  from grouped_sampling import GroupedSamplingPipeLine
3
 
4
 
5
+ from supported_models import get_supported_model_names
6
+
7
+
8
+ SUPPORTED_MODEL_NAMES = get_supported_model_names()
9
+
10
+
11
  def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
12
  """
13
  Creates a pipeline with the given model name and group size.
 
31
  :param prompt: The prompt to use.
32
  :return: The output of the model.
33
  """
34
+ if model_name not in SUPPORTED_MODEL_NAMES:
35
+ raise ValueError(f"The selected model {model_name} is not supported."
36
+ f"Supported models are all the models in:"
37
+ f" https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch")
38
  pipeline = create_pipeline(
39
  model_name,
40
  group_size,
requirements.txt CHANGED
@@ -1,7 +1,8 @@
1
  grouped-sampling>=1.0.4
2
  streamlit==1.17.0
3
  torch>1.12.1
4
- transformers
5
  hatchling
6
- beautifulsoup4
7
- urllib3
 
 
1
  grouped-sampling>=1.0.4
2
  streamlit==1.17.0
3
  torch>1.12.1
4
+ transformers~=4.26.0
5
  hatchling
6
+ beautifulsoup4~=4.11.2
7
+ urllib3
8
+ requests~=2.28.2
supported_models.py CHANGED
@@ -1,48 +1,41 @@
1
- from typing import List, Generator
2
 
3
- from bs4 import BeautifulSoup, Tag
4
- import urllib3
 
5
 
6
- SUPPORTED_MODEL_NAME_PAGES_FORMAT: str = "https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
 
7
 
8
 
9
- def get_model_name(model_card: Tag) -> str:
10
- """
11
- Gets the model name from the model card.
12
- :param model_card: The model card to get the model name from.
13
- :return: The model name.
14
- """
15
  h4_class = "text-md truncate font-mono text-black dark:group-hover:text-yellow-500 group-hover:text-indigo-600"
16
- h4_tag: Tag = model_card.find("h4", class_=h4_class)
17
  return h4_tag.text
18
 
19
 
20
- def get_soups() -> Generator[BeautifulSoup, None, None]:
21
- """
22
- Gets the pages to scrape.
23
- :return: A list of the pages to scrape.
24
- """
25
- curr_page_index = 0
26
- while True:
27
- curr_page_url = f"{SUPPORTED_MODEL_NAME_PAGES_FORMAT}&p={curr_page_index}"
28
- request = urllib3.PoolManager().request("GET", curr_page_url)
29
- if request.status != 200:
30
- return
31
- yield BeautifulSoup(request.data, "html.parser")
32
- curr_page_index += 1
33
-
34
-
35
- def get_supported_model_names() -> Generator[str, None, None]:
36
- """
37
- Scrapes the supported model names from the hugging face website.
38
- :return: A list of the supported model names.
39
- """
40
- for soup in get_soups():
41
- model_cards: List[Tag] = soup.find_all("article", class_="overview-card-wrapper group", recursive=True)
42
- for model_card in model_cards:
43
- yield get_model_name(model_card)
44
-
45
-
46
- if __name__ == "__main__":
47
- for model_name in get_supported_model_names():
48
- print(model_name)
 
1
+ from typing import Generator
2
 
3
+ import requests
4
+ from bs4 import BeautifulSoup
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
 
7
+ SUPPORTED_MODEL_NAME_PAGES_FORMAT = "https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
8
+ MAX_WORKERS = 10
9
 
10
 
11
+ def get_model_name(model_card: BeautifulSoup) -> str:
 
 
 
 
 
12
  h4_class = "text-md truncate font-mono text-black dark:group-hover:text-yellow-500 group-hover:text-indigo-600"
13
+ h4_tag = model_card.find("h4", class_=h4_class)
14
  return h4_tag.text
15
 
16
 
17
+ def get_page(page_index: int):
18
+ curr_page_url = f"{SUPPORTED_MODEL_NAME_PAGES_FORMAT}&p={page_index}"
19
+ response = requests.get(curr_page_url)
20
+ if response.status_code == 200:
21
+ soup = BeautifulSoup(response.content, "html.parser")
22
+ return soup
23
+ return None
24
+
25
+
26
+ def get_model_names(soup):
27
+ model_cards = soup.find_all("article", class_="overview-card-wrapper group", recursive=True)
28
+ return [get_model_name(model_card) for model_card in model_cards]
29
+
30
+
31
+ def generate_supported_model_names() -> Generator[str, None, None]:
32
+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
33
+ future_to_index = {executor.submit(get_page, index): index for index in range(100)}
34
+ for future in as_completed(future_to_index):
35
+ soup = future.result()
36
+ if soup:
37
+ yield from get_model_names(soup)
38
+
39
+
40
+ def get_supported_model_names() -> set[str]:
41
+ return set(generate_supported_model_names())