vama09 commited on
Commit
063fa1a
1 Parent(s): d5b87d7

Intial Commit

Browse files
Files changed (2) hide show
  1. app.py +75 -0
  2. requirments.txt +8 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
4
+ import itertools
5
+ from nltk.corpus import stopwords
6
+ import nltk
7
+ import easyocr
8
+ import numpy as np
9
+ nltk.download('stopwords')
10
+
11
+ # load the model and tokenizer
12
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
13
+
14
+ feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
15
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
16
+ reader = easyocr.Reader(['en'])
17
+
18
+ # set up Streamlit app
19
+ st.set_page_config(layout='wide', page_title='Image Hashtag Recommender')
20
+
21
+
22
+ # define function to extract image features and generate hashtags
23
+ def generate_hashtags(image_file):
24
+ # get image and convert to RGB mode
25
+ image = Image.open(image_file).convert('RGB')
26
+
27
+ # extract image features
28
+ pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values
29
+ output_ids = model.generate(pixel_values)
30
+
31
+ # decode the model output to text and extract caption words
32
+ output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
33
+ caption_words = [word.lower() for word in output_text.split() if not word.startswith("#")]
34
+
35
+ # remove stop words from caption words
36
+ stop_words = set(stopwords.words('english'))
37
+ caption_words = [word for word in caption_words if word not in stop_words]
38
+
39
+ # use easyocr to extract text from the image
40
+ text = reader.readtext(np.array(image))
41
+ detected_text = " ".join([item[1] for item in text])
42
+
43
+ # combine caption words and detected text
44
+ all_words = caption_words + detected_text.split()
45
+
46
+ # generate combinations of words for hashtags
47
+ hashtags = []
48
+ for n in range(1, 4):
49
+ word_combinations = list(itertools.combinations(all_words, n))
50
+ for combination in word_combinations:
51
+ hashtag = "#" + "".join(combination)
52
+ hashtags.append(hashtag)
53
+
54
+ # return top 10 hashtags by frequency
55
+ top_hashtags = [tag for tag in sorted(set(hashtags), key=hashtags.count, reverse=True) if tag != "#"]
56
+ return top_hashtags[:10]
57
+
58
+
59
+ # display the Streamlit app
60
+ st.title("Image Hashtag Recommender")
61
+
62
+ image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
63
+
64
+ # if the user has submitted an image, generate hashtags
65
+ if image_file is not None:
66
+ try:
67
+ hashtags = generate_hashtags(image_file)
68
+ if len(hashtags) > 0:
69
+ st.write("Top 10 hashtags for this image:")
70
+ for tag in hashtags:
71
+ st.write(tag)
72
+ else:
73
+ st.write("No hashtags found for this image.")
74
+ except Exception as e:
75
+ st.write(f"Error: {e}")
requirments.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ easyocr==1.6.2
2
+ nltk==3.8.1
3
+ numpy==1.23.5
4
+ Pillow==9.5.0
5
+ requests==2.28.1
6
+ streamlit==1.14.1
7
+ torch==2.0.0
8
+ transformers==4.28.1