Files changed (1)
  1. app.py +124 -59
app.py CHANGED
@@ -1,76 +1,141 @@
  import streamlit as st
  from PIL import Image
- from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer
- import itertools
- from nltk.corpus import stopwords
- import nltk
- import easyocr
- import torch
  import numpy as np
  nltk.download('stopwords')

- # load the model and tokenizer
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

  tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
- reader = easyocr.Reader(['en'])

- # set up Streamlit app
- st.set_page_config(layout='wide', page_title='Image Hashtag Recommender')

- def generate_hashtags(image_file):
-     # get image and convert to RGB mode
-     image = Image.open(image_file).convert('RGB')

-     # extract image features
-     inputs = processor(image, return_tensors="pt")

-     output_ids = model.generate(**inputs)
-     # out_text = processor.decode(out[0], skip_special_tokens=True)

-     # decode the model output to text and extract caption words
-     output_text = processor.decode(output_ids[0], skip_special_tokens=True)
-     caption_words = [word.lower() for word in output_text.split() if not word.startswith("#")]

-     # remove stop words from caption words
-     stop_words = set(stopwords.words('english'))
-     caption_words = [word for word in caption_words if word not in stop_words]

-     # use easyocr to extract text from the image
-     text = reader.readtext(np.array(image))
-     detected_text = " ".join([item[1] for item in text])

-     # combine caption words and detected text
-     all_words = caption_words + detected_text.split()

-     # generate combinations of words for hashtags
-     hashtags = []
-     for n in range(1, 4):
-         word_combinations = list(itertools.combinations(all_words, n))
-         for combination in word_combinations:
-             hashtag = "#" + "".join(combination)
-             hashtags.append(hashtag)
-
-     # return top 10 hashtags by frequency
-     top_hashtags = [tag for tag in sorted(set(hashtags), key=hashtags.count, reverse=True) if tag != "#"]
-     return [top_hashtags[:10], output_text]
-
- st.title("Image Caption and HashTag Recommender")
-
- image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
-
- # if the user has submitted an image, generate hashtags
- if image_file is not None:
-     try:
-         hashtags = generate_hashtags(image_file)
-         if len(hashtags) > 0:
-             st.write(f"Caption : {hashtags[1]}")
-             st.write("Top 10 hashtags for this image:")
-             for tag in hashtags[0]:
-                 st.write(tag)
-         else:
-             st.write("No hashtags found for this image.")
-     except Exception as e:
-         st.write(f"Error: {e}")
  import streamlit as st
  from PIL import Image
  import numpy as np
+ import nltk
  nltk.download('stopwords')
+ nltk.download('punkt')
+ import pandas as pd
+ import random
+ import easyocr
+ import re
+ from nltk.corpus import stopwords
+ from nltk.tokenize import word_tokenize
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel

+ # Load the pretrained ViT-GPT2 captioning model from the Hugging Face Hub
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

+ # Load the matching feature extractor and tokenizer
+ feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
  tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
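Note: Streamlit re-executes the whole script on every widget interaction, so these three `from_pretrained` calls run on each rerun. A minimal sketch of a cached loader, assuming a Streamlit version with `st.cache_resource` (≥ 1.18); the helper name is illustrative, not from this PR:

```python
import streamlit as st
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel

@st.cache_resource
def load_captioning_pipeline():
    # Loaded once per process, then reused across reruns and sessions
    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    return model, feature_extractor, tokenizer

model, feature_extractor, tokenizer = load_captioning_pipeline()
```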
+ def generate_captions(image):
+     image = Image.open(image).convert("RGB")
+     pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to("cpu")
+     generated_caption = tokenizer.decode(model.generate(pixel_values)[0])
+     # strip the GPT-2 end-of-text marker that decode() leaves in the string
+     generated_caption = generated_caption.replace("<|endoftext|>", "")
+     return generated_caption
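The manual `<|endoftext|>` cleanup can also be delegated to the tokenizer. A sketch of the equivalent decode step; `skip_special_tokens=True` is standard `transformers` API, and the generation settings shown are illustrative, not values from this PR:

```python
# Let the tokenizer drop special tokens such as <|endoftext|> while decoding
output_ids = model.generate(pixel_values, max_length=32, num_beams=4)
caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
```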
+ # use easyocr to extract text from the image
+ def image_text(image):
+     img_np = np.array(image)
+     reader = easyocr.Reader(['en'])
+     text = reader.readtext(img_np)
+
+     # Turn each detected string into a tag: strip, lowercase, squash spaces, prefix "#"
+     detected_text = ['#' + entry[1].strip().lower().replace(" ", "") for entry in text]
+     return detected_text
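`readtext` returns one `(bbox, text, confidence)` tuple per detected region, so low-confidence OCR noise can be filtered out before it becomes a hashtag. A small sketch; the 0.5 threshold is an arbitrary illustration, not part of this PR:

```python
# Keep only detections easyocr is reasonably confident about
results = reader.readtext(img_np)  # each item: (bbox, text, confidence)
confident_text = [text for _bbox, text, conf in results if conf > 0.5]
```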
+ # Load NLTK stopwords for filtering
+ stop_words = set(stopwords.words('english'))

+ # Add hashtags to keywords, which have been generated from image captioning
+ def add_hashtags(keywords):
+     hashtags = []
+     for keyword in keywords:
+         # Generate a hashtag from the keyword (modify as needed)
+         hashtag = '#' + keyword.lower()
+         hashtags.append(hashtag)
+     return hashtags
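Together with the stopword filtering used in the app below, a caption flows through these helpers like this (a worked example, assuming the 'punkt' and 'stopwords' resources downloaded above):

```python
caption = "a dog running on the beach"
tokens = word_tokenize(caption)                        # ['a', 'dog', 'running', 'on', 'the', 'beach']
keywords = [t for t in tokens if t not in stop_words]  # ['dog', 'running', 'beach']
print(add_hashtags(keywords))                          # ['#dog', '#running', '#beach']
```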
+ def trending_hashtags(caption):
+     # Read trending hashtags from a file separated by commas
+     with open("hashies.txt", "r") as file:
+         hashtags_string = file.read()
+
+     # Split the hashtags by commas and remove any leading/trailing spaces
+     trending_hashtags = [hashtag.strip() for hashtag in hashtags_string.split(',')]
+
+     # Create a DataFrame from the hashtags
+     df = pd.DataFrame(trending_hashtags, columns=["Hashtags"])
+
+     # Function to extract keywords from a given text
+     def extract_keywords(caption):
+         tokens = word_tokenize(caption)
+         keywords = [token.lower() for token in tokens if token.lower() not in stop_words]
+         return keywords
+
+     # Extract keywords from caption and trending hashtags
+     caption_keywords = extract_keywords(caption)
+     hashtag_keywords = [extract_keywords(hashtag) for hashtag in df["Hashtags"]]
+
+     # Function to calculate cosine similarity between two strings
+     def calculate_similarity(text1, text2):
+         tfidf_vectorizer = TfidfVectorizer()
+         tfidf_matrix = tfidf_vectorizer.fit_transform([text1, text2])
+         similarity_matrix = cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])
+         return similarity_matrix[0][0]
+
+     # Calculate similarity between caption and each trending hashtag
+     similarities = [calculate_similarity(' '.join(caption_keywords), ' '.join(keywords)) for keywords in hashtag_keywords]
+
+     # Sort trending hashtags based on similarity in descending order
+     sorted_hashtags = [hashtag for _, hashtag in sorted(zip(similarities, df["Hashtags"]), reverse=True)]
+
+     # Select top k relevant hashtags (e.g., top 5) without duplicates
+     selected_hashtags = list(set(sorted_hashtags[:5]))
+
+     selected_hashtag = [word.strip("'") for word in selected_hashtags]
+
+     return selected_hashtag
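`calculate_similarity` fits a fresh TF-IDF vectorizer on just the two input strings and reads the cosine similarity out of the 1×1 result matrix. A standalone check of that behavior (standard scikit-learn API; the example strings are made up):

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

m = TfidfVectorizer().fit_transform(["dog beach sunset", "beach sunset"])
# cosine_similarity on two single-row matrices returns a 1x1 array
score = cosine_similarity(m[0], m[1])[0][0]
print(score)  # in (0, 1]; higher means more shared weighted terms
```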
+ # create the Streamlit app
+ def app():
+     st.title('Image from your Side, Trending Hashtags from our Side')
+
+     st.write('Upload an image to see what we have in store.')
+
+     # create file uploader
+     uploaded_file = st.file_uploader("Got You Covered, Upload your wish!, magic on the Way! ", type=["jpg", "jpeg", "png"])
+
+     # check if a file has been uploaded
+     if uploaded_file is not None:
+         # load the image
+         image = Image.open(uploaded_file).convert("RGB")
+
+         # caption the image, then turn caption keywords into hashtags
+         string = generate_captions(uploaded_file)
+         tokens = word_tokenize(string)
+         keywords = [token.lower() for token in tokens if token.lower() not in stop_words]
+         hashtags = add_hashtags(keywords)
+
+         # hashtags from text detected in the image
+         extracted_text = image_text(image)
+
+         # trending hashtags ranked against the caption
+         web_hashtags = trending_hashtags(string)
+
+         combined_hashtags = hashtags + extracted_text + web_hashtags
+
+         # Shuffle the list randomly, then keep up to 15 tags that don't end in a digit
+         random.shuffle(combined_hashtags)
+         combined_hashtags = list(set(item for item in combined_hashtags[:15] if not re.search(r'\d$', item)))
+
+         # display the image, the caption, and the hashtags
+         st.image(image, caption='The Uploaded File')
+         st.write("First is first captions for your Photo : ", string)
+         st.write("Magical hashies have arrived : ", combined_hashtags)

+ # run the app
+ if __name__ == '__main__':
+     app()
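One note on the final selection step: `list(set(...))` applied after the `[:15]` slice discards the shuffled order and can yield fewer than 15 tags. If a stable count matters, dedupe first and sample afterwards, as in this sketch (illustrative, not part of the PR):

```python
# Dedupe while preserving order (dict.fromkeys), drop tags ending in a digit,
# then sample up to 15 of what's left
unique_tags = [t for t in dict.fromkeys(combined_hashtags) if not re.search(r'\d$', t)]
final_tags = random.sample(unique_tags, k=min(15, len(unique_tags)))
```

The app also expects a comma-separated `hashies.txt` next to `app.py`; with that in place it runs with the usual `streamlit run app.py`.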