tfrere commited on
Commit
287d153
·
1 Parent(s): 0146535
client/src/App.jsx CHANGED
@@ -1,4 +1,4 @@
1
- import { useState, useEffect } from "react";
2
  import {
3
  Container,
4
  Paper,
@@ -20,6 +20,29 @@ function App() {
20
  const [storySegments, setStorySegments] = useState([]);
21
  const [currentChoices, setCurrentChoices] = useState([]);
22
  const [isLoading, setIsLoading] = useState(false);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  const handleStoryAction = async (action, choiceId = null) => {
25
  setIsLoading(true);
@@ -29,12 +52,20 @@ function App() {
29
  choice_id: choiceId,
30
  });
31
 
 
 
 
 
 
 
 
32
  if (action === "restart") {
33
  setStorySegments([
34
  {
35
  text: response.data.story_text,
36
  isChoice: false,
37
  isDeath: response.data.is_death,
 
38
  },
39
  ]);
40
  } else {
@@ -44,6 +75,7 @@ function App() {
44
  text: response.data.story_text,
45
  isChoice: false,
46
  isDeath: response.data.is_death,
 
47
  },
48
  ]);
49
  }
@@ -62,8 +94,11 @@ function App() {
62
 
63
  // Start the story when the component mounts
64
  useEffect(() => {
65
- handleStoryAction("restart");
66
- }, []);
 
 
 
67
 
68
  const handleChoice = async (choiceId) => {
69
  // Add the chosen option to the story
@@ -114,6 +149,8 @@ function App() {
114
  sx={{
115
  justifyContent: segment.isChoice ? "flex-end" : "flex-start",
116
  display: "flex",
 
 
117
  }}
118
  >
119
  <Paper
@@ -146,6 +183,19 @@ function App() {
146
  color: segment.isChoice ? "inherit" : "primary",
147
  }}
148
  />
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  </Paper>
150
  </ListItem>
151
  ))}
 
1
+ import { useState, useEffect, useRef } from "react";
2
  import {
3
  Container,
4
  Paper,
 
20
  const [storySegments, setStorySegments] = useState([]);
21
  const [currentChoices, setCurrentChoices] = useState([]);
22
  const [isLoading, setIsLoading] = useState(false);
23
+ const isInitializedRef = useRef(false);
24
+
25
+ const generateImageForStory = async (storyText) => {
26
+ try {
27
+ console.log("Generating image for story:", storyText);
28
+ const response = await axios.post(`${API_URL}/api/generate-image`, {
29
+ prompt: `Comic book style scene: ${storyText}`,
30
+ width: 512,
31
+ height: 512,
32
+ });
33
+
34
+ console.log("Image generation response:", response.data);
35
+
36
+ if (response.data.success) {
37
+ console.log("Image URL length:", response.data.image_base64.length);
38
+ return response.data.image_base64;
39
+ }
40
+ return null;
41
+ } catch (error) {
42
+ console.error("Error generating image:", error);
43
+ return null;
44
+ }
45
+ };
46
 
47
  const handleStoryAction = async (action, choiceId = null) => {
48
  setIsLoading(true);
 
52
  choice_id: choiceId,
53
  });
54
 
55
+ // Générer l'image pour ce segment
56
+ const imageUrl = await generateImageForStory(response.data.story_text);
57
+ console.log(
58
+ "Generated image URL:",
59
+ imageUrl ? "Image received" : "No image"
60
+ );
61
+
62
  if (action === "restart") {
63
  setStorySegments([
64
  {
65
  text: response.data.story_text,
66
  isChoice: false,
67
  isDeath: response.data.is_death,
68
+ imageUrl: imageUrl,
69
  },
70
  ]);
71
  } else {
 
75
  text: response.data.story_text,
76
  isChoice: false,
77
  isDeath: response.data.is_death,
78
+ imageUrl: imageUrl,
79
  },
80
  ]);
81
  }
 
94
 
95
  // Start the story when the component mounts
96
  useEffect(() => {
97
+ if (!isInitializedRef.current) {
98
+ handleStoryAction("restart");
99
+ isInitializedRef.current = true;
100
+ }
101
+ }, []); // Empty dependency array since we're using a ref
102
 
103
  const handleChoice = async (choiceId) => {
104
  // Add the chosen option to the story
 
149
  sx={{
150
  justifyContent: segment.isChoice ? "flex-end" : "flex-start",
151
  display: "flex",
152
+ flexDirection: "column",
153
+ alignItems: segment.isChoice ? "flex-end" : "flex-start",
154
  }}
155
  >
156
  <Paper
 
183
  color: segment.isChoice ? "inherit" : "primary",
184
  }}
185
  />
186
+ {!segment.isChoice && segment.imageUrl && (
187
+ <Box sx={{ mt: 2, width: "100%", textAlign: "center" }}>
188
+ <img
189
+ src={segment.imageUrl}
190
+ alt="Story scene"
191
+ style={{
192
+ maxWidth: "100%",
193
+ height: "auto",
194
+ borderRadius: "4px",
195
+ }}
196
+ />
197
+ </Box>
198
+ )}
199
  </Paper>
200
  </ListItem>
201
  ))}
server/.env.example CHANGED
@@ -1 +1,2 @@
1
- MISTRAL_API_KEY=your-mistral-api-key-here
 
 
1
+ MISTRAL_API_KEY=your-mistral-api-key-here
2
+ HF_API_KEY=your-hf-api-key-here
server/poetry.lock CHANGED
@@ -814,13 +814,13 @@ trio = ["trio (>=0.22.0,<1.0)"]
814
 
815
  [[package]]
816
  name = "httpx"
817
- version = "0.26.0"
818
  description = "The next generation HTTP client."
819
  optional = false
820
  python-versions = ">=3.8"
821
  files = [
822
- {file = "httpx-0.26.0-py3-none-any.whl", hash = "sha256:8915f5a3627c4d47b73e8202457cb28f1266982d1159bd5779d86a80c0eab1cd"},
823
- {file = "httpx-0.26.0.tar.gz", hash = "sha256:451b55c30d5185ea6b23c2c793abf9bb237d2a7dfb901ced6ff69ad37ec1dfaf"},
824
  ]
825
 
826
  [package.dependencies]
@@ -828,13 +828,13 @@ anyio = "*"
828
  certifi = "*"
829
  httpcore = "==1.*"
830
  idna = "*"
831
- sniffio = "*"
832
 
833
  [package.extras]
834
  brotli = ["brotli", "brotlicffi"]
835
  cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
836
  http2 = ["h2 (>=3,<5)"]
837
  socks = ["socksio (==1.*)"]
 
838
 
839
  [[package]]
840
  name = "httpx-sse"
@@ -2551,4 +2551,4 @@ cffi = ["cffi (>=1.11)"]
2551
  [metadata]
2552
  lock-version = "2.0"
2553
  python-versions = "^3.9"
2554
- content-hash = "6cd85934aee7e38dc9ec3499bf8cf48722604a929b63397da9fa88f4090bbde9"
 
814
 
815
  [[package]]
816
  name = "httpx"
817
+ version = "0.28.1"
818
  description = "The next generation HTTP client."
819
  optional = false
820
  python-versions = ">=3.8"
821
  files = [
822
+ {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
823
+ {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
824
  ]
825
 
826
  [package.dependencies]
 
828
  certifi = "*"
829
  httpcore = "==1.*"
830
  idna = "*"
 
831
 
832
  [package.extras]
833
  brotli = ["brotli", "brotlicffi"]
834
  cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
835
  http2 = ["h2 (>=3,<5)"]
836
  socks = ["socksio (==1.*)"]
837
+ zstd = ["zstandard (>=0.18.0)"]
838
 
839
  [[package]]
840
  name = "httpx-sse"
 
2551
  [metadata]
2552
  lock-version = "2.0"
2553
  python-versions = "^3.9"
2554
+ content-hash = "f1ac792e9026c6373be7fd6f4db3f418c7965c58ec1b0df8286a47c4d997b92b"
server/pyproject.toml CHANGED
@@ -12,6 +12,7 @@ python-dotenv = "^1.0.0"
12
  elevenlabs = "^0.2.26"
13
  langchain = "^0.3.15"
14
  langchain-mistralai = "^0.2.4"
 
15
 
16
  [tool.poetry.group.dev.dependencies]
17
  pytest = "^7.4.0"
 
12
  elevenlabs = "^0.2.26"
13
  langchain = "^0.3.15"
14
  langchain-mistralai = "^0.2.4"
15
+ requests = "^2.31.0"
16
 
17
  [tool.poetry.group.dev.dependencies]
18
  pytest = "^7.4.0"
server/server.py CHANGED
@@ -5,6 +5,9 @@ from pydantic import BaseModel
5
  from typing import List, Optional
6
  import os
7
  from dotenv import load_dotenv
 
 
 
8
 
9
  # Choose import based on environment
10
  if os.getenv("DOCKER_ENV"):
@@ -19,6 +22,8 @@ load_dotenv()
19
  API_HOST = os.getenv("API_HOST", "0.0.0.0")
20
  API_PORT = int(os.getenv("API_PORT", "8000"))
21
  STATIC_FILES_DIR = os.getenv("STATIC_FILES_DIR", "../client/dist")
 
 
22
 
23
  app = FastAPI(title="Echoes of Influence")
24
 
@@ -61,6 +66,17 @@ class ChatMessage(BaseModel):
61
  message: str
62
  choice_id: Optional[int] = None
63
 
 
 
 
 
 
 
 
 
 
 
 
64
  @app.get("/api/health")
65
  async def health_check():
66
  """Health check endpoint"""
@@ -75,6 +91,8 @@ async def health_check():
75
  @app.post("/api/chat", response_model=StoryResponse)
76
  async def chat_endpoint(chat_message: ChatMessage):
77
  try:
 
 
78
  # Handle restart
79
  if chat_message.message.lower() == "restart":
80
  game_state.reset()
@@ -82,19 +100,25 @@ async def chat_endpoint(chat_message: ChatMessage):
82
  else:
83
  previous_choice = f"Choice {chat_message.choice_id}" if chat_message.choice_id else "none"
84
 
 
 
85
  # Generate story segment
86
  story_segment = story_generator.generate_story_segment(game_state, previous_choice)
 
87
 
88
  # Update radiation level
89
  game_state.radiation_level += story_segment.radiation_increase
 
90
 
91
  # Check for radiation death
92
  if game_state.radiation_level >= MAX_RADIATION:
93
  story_segment = story_generator.process_radiation_death(story_segment)
 
94
 
95
  # Only increment story beat if not dead
96
  if not story_segment.is_death:
97
  game_state.story_beat += 1
 
98
 
99
  # Convert to response format
100
  choices = [] if story_segment.is_death else [
@@ -102,17 +126,149 @@ async def chat_endpoint(chat_message: ChatMessage):
102
  for i, choice in enumerate(story_segment.choices, 1)
103
  ]
104
 
105
- return StoryResponse(
106
  story_text=story_segment.story_text,
107
  choices=choices,
108
  is_death=story_segment.is_death,
109
  radiation_level=game_state.radiation_level
110
  )
 
 
111
 
112
  except Exception as e:
113
- print(f"Error: {str(e)}")
 
 
114
  raise HTTPException(status_code=500, detail=str(e))
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # Mount static files (this should be after all API routes)
117
  app.mount("/", StaticFiles(directory=STATIC_FILES_DIR, html=True), name="static")
118
 
 
5
  from typing import List, Optional
6
  import os
7
  from dotenv import load_dotenv
8
+ import requests
9
+ import base64
10
+ import time
11
 
12
  # Choose import based on environment
13
  if os.getenv("DOCKER_ENV"):
 
22
  API_HOST = os.getenv("API_HOST", "0.0.0.0")
23
  API_PORT = int(os.getenv("API_PORT", "8000"))
24
  STATIC_FILES_DIR = os.getenv("STATIC_FILES_DIR", "../client/dist")
25
+ HF_API_KEY = os.getenv("HF_API_KEY")
26
+ AWS_TOKEN = os.getenv("AWS_TOKEN", "VHVlIEZlYiAyNyAwOTowNzoyMiBDRVQgMjAyNA==") # Token par défaut pour le développement
27
 
28
  app = FastAPI(title="Echoes of Influence")
29
 
 
66
  message: str
67
  choice_id: Optional[int] = None
68
 
69
+ class ImageGenerationRequest(BaseModel):
70
+ prompt: str
71
+ negative_prompt: Optional[str] = None
72
+ width: Optional[int] = 1024
73
+ height: Optional[int] = 1024
74
+
75
+ class ImageGenerationResponse(BaseModel):
76
+ success: bool
77
+ image_base64: Optional[str] = None
78
+ error: Optional[str] = None
79
+
80
  @app.get("/api/health")
81
  async def health_check():
82
  """Health check endpoint"""
 
91
  @app.post("/api/chat", response_model=StoryResponse)
92
  async def chat_endpoint(chat_message: ChatMessage):
93
  try:
94
+ print("Received chat message:", chat_message)
95
+
96
  # Handle restart
97
  if chat_message.message.lower() == "restart":
98
  game_state.reset()
 
100
  else:
101
  previous_choice = f"Choice {chat_message.choice_id}" if chat_message.choice_id else "none"
102
 
103
+ print("Previous choice:", previous_choice)
104
+
105
  # Generate story segment
106
  story_segment = story_generator.generate_story_segment(game_state, previous_choice)
107
+ print("Generated story segment:", story_segment)
108
 
109
  # Update radiation level
110
  game_state.radiation_level += story_segment.radiation_increase
111
+ print("Updated radiation level:", game_state.radiation_level)
112
 
113
  # Check for radiation death
114
  if game_state.radiation_level >= MAX_RADIATION:
115
  story_segment = story_generator.process_radiation_death(story_segment)
116
+ print("Processed radiation death")
117
 
118
  # Only increment story beat if not dead
119
  if not story_segment.is_death:
120
  game_state.story_beat += 1
121
+ print("Incremented story beat to:", game_state.story_beat)
122
 
123
  # Convert to response format
124
  choices = [] if story_segment.is_death else [
 
126
  for i, choice in enumerate(story_segment.choices, 1)
127
  ]
128
 
129
+ response = StoryResponse(
130
  story_text=story_segment.story_text,
131
  choices=choices,
132
  is_death=story_segment.is_death,
133
  radiation_level=game_state.radiation_level
134
  )
135
+ print("Sending response:", response)
136
+ return response
137
 
138
  except Exception as e:
139
+ import traceback
140
+ print(f"Error in chat_endpoint: {str(e)}")
141
+ print("Traceback:", traceback.format_exc())
142
  raise HTTPException(status_code=500, detail=str(e))
143
 
144
+ async def transform_story_to_art_prompt(story_text: str) -> str:
145
+ try:
146
+ from langchain_mistralai.chat_models import ChatMistralAI
147
+ from langchain.schema import HumanMessage, SystemMessage
148
+
149
+ chat = ChatMistralAI(
150
+ api_key=mistral_api_key,
151
+ model="mistral-small"
152
+ )
153
+
154
+ messages = [
155
+ SystemMessage(content="""Tu es un expert en prompts pour la génération d'images.
156
+ Transforme l'histoire en un prompt court et précis.
157
+
158
+ Format strict:
159
+ "color comic panel, style of Hergé, [scène principale en 5-7 mots], french comic panel"
160
+
161
+ Exemple:
162
+ "color comic panel, style of Hergé, detective running through dark alley, french comic panel"
163
+
164
+ Règles:
165
+ - Maximum 20 mots pour décrire la scène
166
+ - Pas d'adjectifs superflus
167
+ - Capture l'action principale uniquement"""),
168
+ HumanMessage(content=f"Transforme en prompt court: {story_text}")
169
+ ]
170
+
171
+ response = chat.invoke(messages)
172
+ return response.content
173
+
174
+ except Exception as e:
175
+ print(f"Error transforming prompt: {str(e)}")
176
+ return story_text
177
+
178
+ @app.post("/api/generate-image", response_model=ImageGenerationResponse)
179
+ async def generate_image(request: ImageGenerationRequest):
180
+ try:
181
+ if not HF_API_KEY:
182
+ return ImageGenerationResponse(
183
+ success=False,
184
+ error="HF_API_KEY is not configured in .env file"
185
+ )
186
+
187
+ # Transformer le prompt en prompt artistique
188
+ original_prompt = request.prompt
189
+ # Enlever le préfixe pour la transformation
190
+ story_text = original_prompt.replace("moebius style scene: ", "").strip()
191
+ art_prompt = await transform_story_to_art_prompt(story_text)
192
+ # Réappliquer le préfixe
193
+ final_prompt = f"moebius style scene: {art_prompt}"
194
+ print("Original prompt:", original_prompt)
195
+ print("Transformed art prompt:", final_prompt)
196
+
197
+ # Paramètres de retry
198
+ max_retries = 3
199
+ retry_delay = 1 # secondes
200
+
201
+ for attempt in range(max_retries):
202
+ try:
203
+ # Appel à l'endpoint HF avec authentification
204
+ response = requests.post(
205
+ "https://tvsk4iu4ghzffi34.us-east-1.aws.endpoints.huggingface.cloud",
206
+ headers={
207
+ "Content-Type": "application/json",
208
+ "Accept": "image/jpeg",
209
+ "Authorization": f"Bearer {HF_API_KEY}"
210
+ },
211
+ json={
212
+ "inputs": final_prompt,
213
+ "parameters": {
214
+ "guidance_scale": 9.0, # Valeur du Comic Factory
215
+ "width": request.width or 1024,
216
+ "height": request.height or 1024,
217
+ "negative_prompt": "manga, anime, american comic, grayscale, monochrome, photo, painting, 3D render"
218
+ }
219
+ }
220
+ )
221
+
222
+ print(f"Attempt {attempt + 1} - API Response status:", response.status_code)
223
+ print("API Response headers:", dict(response.headers))
224
+
225
+ if response.status_code == 503:
226
+ if attempt < max_retries - 1:
227
+ print(f"Service unavailable, retrying in {retry_delay} seconds...")
228
+ time.sleep(retry_delay)
229
+ retry_delay *= 2 # Exponential backoff
230
+ continue
231
+ else:
232
+ return ImageGenerationResponse(
233
+ success=False,
234
+ error="Service is currently unavailable after multiple retries"
235
+ )
236
+
237
+ if response.status_code != 200:
238
+ error_msg = response.text if response.text else "Unknown error"
239
+ print("Error response:", error_msg)
240
+ return ImageGenerationResponse(
241
+ success=False,
242
+ error=f"API error: {error_msg}"
243
+ )
244
+
245
+ # L'API renvoie directement l'image en binaire
246
+ image_bytes = response.content
247
+ base64_image = base64.b64encode(image_bytes).decode('utf-8')
248
+
249
+ print("Base64 image length:", len(base64_image))
250
+
251
+ return ImageGenerationResponse(
252
+ success=True,
253
+ image_base64=f"data:image/jpeg;base64,{base64_image}"
254
+ )
255
+
256
+ except requests.exceptions.RequestException as e:
257
+ if attempt < max_retries - 1:
258
+ print(f"Request failed, retrying in {retry_delay} seconds... Error: {str(e)}")
259
+ time.sleep(retry_delay)
260
+ retry_delay *= 2
261
+ continue
262
+ else:
263
+ raise
264
+
265
+ except Exception as e:
266
+ print("Error in generate_image:", str(e))
267
+ return ImageGenerationResponse(
268
+ success=False,
269
+ error=f"Error generating image: {str(e)}"
270
+ )
271
+
272
  # Mount static files (this should be after all API routes)
273
  app.mount("/", StaticFiles(directory=STATIC_FILES_DIR, html=True), name="static")
274