AidenYan commited on
Commit
c9afddd
·
verified ·
1 Parent(s): d0db2d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -19
app.py CHANGED
@@ -1,29 +1,33 @@
1
  import streamlit as st
2
- from transformers import pipeline as transformers_pipeline, AutoTokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
3
  import pandas as pd
4
  import torch
 
 
 
5
 
6
- # Load the tokenizer and models
7
  similarity_tokenizer = AutoTokenizer.from_pretrained("AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27")
8
  similarity_model = AutoModelForSequenceClassification.from_pretrained("AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27")
9
-
10
  story_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/genre-story-generator-v2")
11
  story_model = AutoModelForCausalLM.from_pretrained("pranavpsv/genre-story-generator-v2")
12
 
13
- labels_df = pd.read_csv("labels_to_image_urls.csv") # Update this path
 
14
 
15
  def image_to_text_from_url(image_url):
16
  """
17
- Generates a caption from the image at the given URL.
18
  """
19
  image_to_text_pipeline = transformers_pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
20
  return image_to_text_pipeline(image_url)[0]['generated_text']
21
 
22
  def generate_mask_from_result(input_text):
23
  """
24
- Simulate generating a mask from the result. Replace with actual logic.
25
  """
26
- return "This is a placeholder text based on the input: " + input_text
 
27
 
28
  def generate_story_from_text(input_text):
29
  """
@@ -35,7 +39,7 @@ def generate_story_from_text(input_text):
35
 
36
  def select_closest_sentence(generated_text):
37
  """
38
- Predict the similarity label for the generated text.
39
  """
40
  inputs = similarity_tokenizer(generated_text, return_tensors="pt")
41
  outputs = similarity_model(**inputs)
@@ -45,7 +49,7 @@ def select_closest_sentence(generated_text):
45
 
46
  def get_image_url_for_label(label):
47
  """
48
- Returns the image URL for a given label.
49
  """
50
  row = labels_df[labels_df['Label'] == label]
51
  if not row.empty:
@@ -53,29 +57,47 @@ def get_image_url_for_label(label):
53
  else:
54
  return None
55
 
 
 
 
 
 
 
 
 
 
 
 
56
  def main():
57
  st.title("SmartCart (Product Recommender)")
58
 
 
59
  input_option = st.radio("Select input option:", ("Text", "URL"))
60
 
 
61
  if input_option == "Text":
62
  text_input = st.text_input("Enter the text:")
63
  if st.button("Generate Story and Image") and text_input:
64
- generated_text = generate_mask_from_result(text_input)
65
- story_text = generate_story_from_text(generated_text)
66
- st.success(f'Generated Caption: {text_input}')
67
- st.success(f'Generated Text: {generated_text}')
68
- st.text_area('Generated Story:', story_text, height=200)
 
 
69
 
 
70
  elif input_option == "URL":
71
  image_url = st.text_input("Enter the image URL:")
72
  if st.button("Generate Story and Image") and image_url:
73
  image_text = image_to_text_from_url(image_url)
74
- generated_text = generate_mask_from_result(image_text)
75
- story_text = generate_story_from_text(generated_text)
76
- st.success(f'Generated Caption: {image_text}')
77
- st.success(f'Generated Text: {generated_text}')
78
- st.text_area('Generated Story:', story_text, height=200)
 
 
79
 
80
  if __name__ == "__main__":
81
  main()
 
1
  import streamlit as st
2
+ from transformers import pipeline as transformers_pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
3
  import pandas as pd
4
  import torch
5
+ import requests
6
+ from PIL import Image
7
+ import io
8
 
9
+ # Load tokenizer and models for similarity and story generation
10
  similarity_tokenizer = AutoTokenizer.from_pretrained("AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27")
11
  similarity_model = AutoModelForSequenceClassification.from_pretrained("AidenYan/MiniLM_L6_v2_finetuned_ISOM5240_Group27")
 
12
  story_tokenizer = AutoTokenizer.from_pretrained("pranavpsv/genre-story-generator-v2")
13
  story_model = AutoModelForCausalLM.from_pretrained("pranavpsv/genre-story-generator-v2")
14
 
15
+ # Load the CSV file into a dataframe
16
+ labels_df = pd.read_csv("path_to_your_csv_file/labels_to_image_urls.csv") # Make sure to update this path
17
 
18
  def image_to_text_from_url(image_url):
19
  """
20
+ Generates a caption from the image at the given URL using an image-to-text pipeline.
21
  """
22
  image_to_text_pipeline = transformers_pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
23
  return image_to_text_pipeline(image_url)[0]['generated_text']
24
 
25
  def generate_mask_from_result(input_text):
26
  """
27
+ Placeholder for generating a mask from the result. This should be replaced with your actual logic.
28
  """
29
+ # Placeholder logic, replace with actual text processing if needed
30
+ return "Processed input: " + input_text
31
 
32
  def generate_story_from_text(input_text):
33
  """
 
39
 
40
  def select_closest_sentence(generated_text):
41
  """
42
+ Predicts the similarity label for the generated text using the similarity model.
43
  """
44
  inputs = similarity_tokenizer(generated_text, return_tensors="pt")
45
  outputs = similarity_model(**inputs)
 
49
 
50
  def get_image_url_for_label(label):
51
  """
52
+ Returns the image URL for a given label from the labels dataframe.
53
  """
54
  row = labels_df[labels_df['Label'] == label]
55
  if not row.empty:
 
57
  else:
58
  return None
59
 
60
+ def display_image_from_url(image_url):
61
+ """
62
+ Displays an image in the Streamlit app given its URL.
63
+ """
64
+ try:
65
+ response = requests.get(image_url)
66
+ image = Image.open(io.BytesIO(response.content))
67
+ st.image(image, use_column_width=True)
68
+ except Exception as e:
69
+ st.error(f"Failed to load image from URL: {e}")
70
+
71
  def main():
72
  st.title("SmartCart (Product Recommender)")
73
 
74
+ # User input for text or URL
75
  input_option = st.radio("Select input option:", ("Text", "URL"))
76
 
77
+ # Handling input via text
78
  if input_option == "Text":
79
  text_input = st.text_input("Enter the text:")
80
  if st.button("Generate Story and Image") and text_input:
81
+ processed_text = generate_mask_from_result(text_input)
82
+ story_text = generate_story_from_text(processed_text)
83
+ st.text_area('Generated Story:', story_text, height=300)
84
+ closest_label = select_closest_sentence(processed_text)
85
+ image_url = get_image_url_for_label(closest_label)
86
+ if image_url:
87
+ display_image_from_url(image_url)
88
 
89
+ # Handling input via image URL
90
  elif input_option == "URL":
91
  image_url = st.text_input("Enter the image URL:")
92
  if st.button("Generate Story and Image") and image_url:
93
  image_text = image_to_text_from_url(image_url)
94
+ processed_text = generate_mask_from_result(image_text)
95
+ story_text = generate_story_from_text(processed_text)
96
+ st.text_area('Generated Story:', story_text, height=300)
97
+ closest_label = select_closest_sentence(processed_text)
98
+ mapped_image_url = get_image_url_for_label(closest_label)
99
+ if mapped_image_url:
100
+ display_image_from_url(mapped_image_url)
101
 
102
  if __name__ == "__main__":
103
  main()