File size: 5,820 Bytes
59d06f0
 
058805f
 
 
fdd924a
51fe236
53552be
058805f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43d48c8
058805f
5c83c17
058805f
 
43d48c8
058805f
 
 
 
 
fdd924a
058805f
 
43d48c8
058805f
43d48c8
058805f
43d48c8
887a9b8
43d48c8
ea96ba1
 
 
43d48c8
ea96ba1
 
 
43d48c8
e66248b
ea96ba1
 
b32b280
ea96ba1
 
650e4ee
76a393b
ea96ba1
 
 
 
 
e66248b
058805f
51fe236
 
43d48c8
ea96ba1
 
058805f
446ce6d
fdd924a
058805f
fdd924a
1a3a8f4
fdd924a
b587781
7167899
4ca9f3c
7b6fa84
a302174
 
 
 
 
 
 
 
 
 
f964c3b
fdd924a
53552be
57354b1
 
e66248b
6af8341
51fe236
a0641a7
446ce6d
36bbb4d
058805f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr

# Using openai models ---------------------------------------------------------

from langchain_openai import OpenAI
import os
openai_api_key = os.getenv("OPENAI_API_KEY")

import io
import base64
import requests
import json

width = 800


# Function to call the API for image and get the response
def get_response_for_image(openai_api_key, image):
    """Caption an image via the OpenAI chat-completions API.

    Args:
        openai_api_key: API key placed in the Bearer auth header.
        image: Raw JPEG-encoded image bytes to caption.

    Returns:
        dict: The parsed JSON API response. The caption is expected in
        ``choices[0]["message"]["content"]`` as a JSON string with the
        key ``"Description"`` (per the prompt below).
    """
    # The API accepts images as base64 data URLs, not raw bytes.
    base64_image = base64.b64encode(image).decode('utf-8')
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}"
    }
    payload = {
        "model": "gpt-4o",
        "messages": [
          {
            "role": "user",
            "content": [
              {
                "type": "text",
                "text": '''Describe or caption the image within 20 words. Output in json format with key: Description'''
              },
              {
                "type": "image_url",
                "image_url": {
                  "url": f"data:image/jpeg;base64,{base64_image}",
                  # "low" detail keeps token usage (and cost) down.
                  "detail": "low"
                }
              }
            ]
          }
        ],
        "max_tokens": 200
    }
    # Fix: the original call had no timeout, so a stalled connection would
    # hang the Gradio worker forever. 60s is generous for a small payload.
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=60,
    )
    return response.json()


def generate_story(image, theme, genre, word_count):
    """Caption *image*, then generate a themed story about the caption.

    Args:
        image: PIL image selected in the Gradio UI.
        theme: Story theme (e.g. "Love and Loss").
        genre: Story genre (e.g. "Fantasy").
        word_count: Target story length in words.

    Returns:
        tuple[str, str]: ``(caption, story)`` on success. On failure the
        error message is returned for BOTH outputs so the two Gradio
        output boxes stay consistent.
    """
    try:
        # Convert PIL image to JPEG bytes for the captioning API.
        with io.BytesIO() as output:
            image.save(output, format="JPEG")
            image_bytes = output.getvalue()

        # The model wraps its JSON answer in ```json fences; strip them
        # before parsing, then pull out the caption text.
        caption_response = get_response_for_image(openai_api_key, image_bytes)
        json_str = caption_response['choices'][0]['message']['content']
        json_str = json_str.replace('```json', '').replace('```', '').strip()
        content_json = json.loads(json_str)
        caption_text = content_json['Description']

        # Generate story based on the caption.
        story_prompt = f"Write an interesting {theme} story in the {genre} genre about {caption_text}. The story should be within {word_count} words."

        llm = OpenAI(model_name="gpt-3.5-turbo-instruct", openai_api_key=openai_api_key, max_tokens=1000)
        story = llm.invoke(story_prompt)

        return caption_text, story
    except Exception as e:
        # Bug fix: the original returned a single string here, but the
        # Gradio interface declares two outputs (caption, story); a
        # one-element return breaks the UI on error. Return it twice.
        error_message = f"An error occurred during inference: {str(e)}"
        return error_message, error_message
    
    
# Using open source models ----------------------------------------------------

'''
from transformers import pipeline, AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel 

# Load text generation model

text_generation_model = pipeline("text-generation", model="distilbert/distilgpt2")

# Load image captioning model

encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"

feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint)

def generate_story(image, theme, genre, word_count):
    try:
        # Preprocess the image
        image = image.convert('RGB')
        image_features = feature_extractor(images=image, return_tensors="pt")
        
        # Generate image caption
        caption_ids = model.generate(image_features.pixel_values, max_length=50, num_beams=3, temperature=1.0)
        
        # Decode the caption
        caption_text = tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]
        
        # Generate story based on the caption
        story_prompt = f"Write an interesting {theme} story in the {genre} genre. The story should be within {word_count} words about {caption_text}."
        story = text_generation_model(story_prompt, max_length=150)[0]["generated_text"]
        
        return caption_text, story
        
    except Exception as e:
        return f"An error occurred during inference: {str(e)}"
'''


# ------------------------------------------------------------------------- 

# Gradio interface

# Dropdown choice lists, hoisted out for readability.
THEME_CHOICES = [
    "Love and Loss", "Identity and Self-Discovery", "Power and Corruption",
    "Redemption and Forgiveness", "Survival and Resilience",
    "Nature and the Environment", "Justice and Injustice",
    "Friendship and Loyalty", "Hope and Despair",
]
GENRE_CHOICES = [
    "Fantasy", "Science Fiction", "Poetry", "Mystery/Thriller", "Romance",
    "Historical Fiction", "Horror", "Adventure", "Drama", "Comedy",
]

# Input widgets: image, theme/genre selectors, and desired story length.
input_image = gr.Image(label="Select Image", type="pil")
input_theme = gr.Dropdown(THEME_CHOICES, label="Input Theme")
input_genre = gr.Dropdown(GENRE_CHOICES, label="Input Genre")
word_count_slider = gr.Slider(minimum=50, maximum=200, value=80, step=5, label="Word Count")

# Output widgets: the intermediate image caption and the final story.
output_caption = gr.Textbox(label="Image Caption", lines=3)
output_text = gr.Textbox(label="Generated Story", lines=20)

# Clickable example rows: [image path, theme, genre, word count].
examples = [
    ["example1.jpg", "Love and Loss", "Fantasy", 80],
    ["example2.jpg", "Identity and Self-Discovery", "Science Fiction", 100],
    ["example3.jpg", "Power and Corruption", "Mystery/Thriller", 120],
    ["example4.jpg", "Redemption and Forgiveness", "Romance", 80],
    ["example5.jpg", "Survival and Resilience", "Poetry", 150],
    ["example6.jpg", "Nature and the Environment", "Horror", 120],
    ["example7.jpg", "Justice and Injustice", "Adventure", 80],
    ["example8.jpg", "Friendship and Loyalty", "Drama", 100],
]

# Wire everything together and serve the app.
gr.Interface(
    fn=generate_story,
    inputs=[input_image, input_theme, input_genre, word_count_slider],
    theme='freddyaboulton/dracula_revamped',
    outputs=[output_caption, output_text],
    examples=examples,
    title="Image to Story Generator",
    description="Generate a story from an image taking theme and genre as input. It leverages image captioning and text generation models.",
).launch()