hannabaker committed
Commit 09b134c · verified · 1 Parent(s): a641094

Update app.py

Files changed (1)
  1. app.py +178 -0
app.py CHANGED
@@ -1,13 +1,191 @@
  from fastapi import FastAPI, Request
  import requests, os, random
+ import psycopg2
+ import logging
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from runware import Runware, IImageInference
+ from dotenv import load_dotenv
+ from openai import OpenAI
  from imdb import Cinemagoer
  IMGBB_API_KEY = os.getenv("IMGBB_API_KEY")
  GIST_URL = os.getenv("GIST_URL")
  TMDB_API_KEY = os.getenv("TMDB_API_KEY")
+ RUNWARE_API_KEY = os.getenv("RUNWARE_API_KEY")
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
+
+ client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
+
  DATA_DIR = "/data"

  app = FastAPI()

+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allow requests from Next.js dev server
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ db_params = {
+     'dbname': os.getenv('DB_NAME'),
+     'user': os.getenv('DB_USER'),
+     'password': os.getenv('DB_PASSWORD'),
+     'host': os.getenv('DB_HOST'),
+     'port': os.getenv('DB_PORT'),
+     'sslmode': 'require'
+ }
+
+ class ImageRequest(BaseModel):
+     prompt: str
+     width: int
+     height: int
+     model: str
+     number_results: int = 1
+
+ def insert_batch(prompt: str, width: int, height: int, model: str, urls: list[str]) -> int:
+     conn = None
+     try:
+         conn = psycopg2.connect(**db_params)
+         cur = conn.cursor()
+         cur.execute(
+             "INSERT INTO batches (prompt, width, height, model) VALUES (%s, %s, %s, %s) RETURNING id",
+             (prompt, width, height, model)
+         )
+         batch_id = cur.fetchone()[0]
+
+         for url in urls:
+             cur.execute(
+                 "INSERT INTO images (batch_id, url) VALUES (%s, %s)",
+                 (batch_id, url)
+             )
+
+         conn.commit()
+         cur.close()
+         return batch_id
+     except (Exception, psycopg2.DatabaseError) as error:
+         print(f"Error inserting batch: {error}")
+         raise
+     finally:
+         if conn is not None:
+             conn.close()
+
+ @app.post("/generate-image")
+ async def generate_image(request: ImageRequest):
+     try:
+         runware = Runware(api_key=RUNWARE_API_KEY)
+         await runware.connect()
+
+         request_image = IImageInference(
+             positivePrompt=request.prompt,
+             model=request.model,
+             numberResults=request.number_results,
+             height=request.height,
+             width=request.width,
+         )
+
+         images = await runware.imageInference(requestImage=request_image)
+         image_urls = [image.imageURL for image in images]
+
+         batch_id = insert_batch(request.prompt, request.width, request.height, request.model, image_urls)
+
+         response = {"batch": {"id": batch_id, "prompt": request.prompt, "width": request.width, "height": request.height, "model": request.model, "images": [{"url": url} for url in image_urls]}}
+         return response
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Failed to generate image: {str(e)}")
+
+ @app.get("/get-batches")
+ async def get_batches():
+     conn = None
+     try:
+         conn = psycopg2.connect(**db_params)
+         cur = conn.cursor()
+         cur.execute("""
+             SELECT b.id, b.prompt, b.width, b.height, b.model, array_agg(i.url) as image_urls, b.created_at
+             FROM batches b
+             JOIN images i ON i.batch_id = b.id
+             GROUP BY b.id, b.prompt, b.width, b.height, b.model, b.created_at
+             ORDER BY b.created_at DESC
+             LIMIT 5
+         """)
+         rows = cur.fetchall()
+
+         batches = []
+         for row in rows:
+             created_at = row[6]
+             created_at_iso = created_at.isoformat() if created_at else None
+
+             batch = {
+                 "id": row[0],
+                 "prompt": row[1],
+                 "width": row[2],
+                 "height": row[3],
+                 "model": row[4],
+                 "images": [{"url": url} for url in row[5]],
+                 "createdAt": created_at_iso
+             }
+             batches.append(batch)
+
+         return {"batches": batches}
+     except (Exception, psycopg2.DatabaseError) as error:
+         raise HTTPException(status_code=500, detail=str(error))
+     finally:
+         if conn is not None:
+             conn.close()
+
+ def delete_batch(batch_id: int):
+     conn = None
+     try:
+         conn = psycopg2.connect(**db_params)
+         cur = conn.cursor()
+
+         # Delete associated images first
+         cur.execute("DELETE FROM images WHERE batch_id = %s", (batch_id,))
+
+         # Then delete the batch
+         cur.execute("DELETE FROM batches WHERE id = %s", (batch_id,))
+
+         conn.commit()
+         return True
+     except (Exception, psycopg2.DatabaseError) as error:
+         print(f"Error deleting batch: {error}")
+         return False
+     finally:
+         if conn is not None:
+             conn.close()
+
+ @app.delete("/delete-batch")
+ async def delete_batch_route(id: int):
+     success = delete_batch(id)
+     if success:
+         return {"message": "Batch deleted successfully"}
+     else:
+         raise HTTPException(status_code=500, detail="Failed to delete batch")
+
+ @app.post("/enhance-prompt")
+ async def enhance_prompt(request: dict):
+     try:
+         prompt = request.get("prompt")
+         if not prompt:
+             raise HTTPException(status_code=400, detail="Prompt is required")
+
+         response = client.chat.completions.create(
+             model="gemini-1.5-flash-latest",
+             messages=[
+                 {"role": "system", "content": "You are an AI assistant that enhances image generation prompts. Your task is to take a user's prompt and make it more detailed and descriptive, suitable for high-quality image generation."},
+                 {"role": "user", "content": f"Enhance this image generation prompt: {prompt}. Reply with the enhanced prompt only."}
+             ]
+         )
+
+         enhanced_prompt = response.choices[0].message.content
+         return {"enhancedPrompt": enhanced_prompt}
+     except Exception as e:
+         logging.error(f"Error enhancing prompt: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Failed to enhance prompt: {str(e)}")
+

  @app.get("/")
  def greet_json():
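
Note: the new insert_batch, get_batches, and delete_batch helpers expect a Postgres database with batches and images tables, but this commit does not include their schema. The sketch below is one schema that satisfies the queries in the diff; the table and column names are taken from the SQL above, while the column types, the SERIAL keys, and the standalone create_tables() helper are assumptions for illustration, not part of the commit.

# create_tables.py -- hypothetical one-off setup script, not included in this commit.
import os
import psycopg2

# Assumed schema derived from the INSERT/SELECT statements in app.py;
# the column types are guesses that satisfy those queries.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS batches (
    id         SERIAL PRIMARY KEY,
    prompt     TEXT NOT NULL,
    width      INTEGER NOT NULL,
    height     INTEGER NOT NULL,
    model      TEXT NOT NULL,
    created_at TIMESTAMPTZ DEFAULT now()
);
CREATE TABLE IF NOT EXISTS images (
    id       SERIAL PRIMARY KEY,
    batch_id INTEGER REFERENCES batches(id),
    url      TEXT NOT NULL
);
"""

def create_tables() -> None:
    # Reuses the same DB_* environment variables as app.py.
    conn = psycopg2.connect(
        dbname=os.getenv("DB_NAME"),
        user=os.getenv("DB_USER"),
        password=os.getenv("DB_PASSWORD"),
        host=os.getenv("DB_HOST"),
        port=os.getenv("DB_PORT"),
        sslmode="require",
    )
    try:
        # The connection context manager commits on success and rolls back on error.
        with conn, conn.cursor() as cur:
            cur.execute(SCHEMA_SQL)
    finally:
        conn.close()

if __name__ == "__main__":
    create_tables()

Once the tables exist, the new endpoint can be exercised with, for example, requests.post("http://localhost:8000/generate-image", json={"prompt": "a lighthouse at dusk", "width": 512, "height": 512, "model": "<runware model id>"}); the host, port, and model id here are placeholders, not values from the commit.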