Geek7 commited on
Commit
65f3fc3
1 Parent(s): 7eec8a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -25
app.py CHANGED
@@ -1,14 +1,21 @@
1
  import gradio as gr
2
  from random import randint
3
- from all_models import models
4
  from externalmod import gr_Interface_load
5
  import asyncio
6
  import os
7
  from threading import RLock
 
 
 
8
 
9
  lock = RLock()
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
11
 
 
 
 
 
12
  def load_fn(models):
13
  global models_load
14
  models_load = {}
@@ -24,11 +31,12 @@ def load_fn(models):
24
 
25
  load_fn(models)
26
 
27
- num_models = 6
28
  MAX_SEED = 3999999999
29
- default_models = models[:num_models]
30
  inference_timeout = 600
31
 
 
32
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
  kwargs = {"seed": seed}
34
  task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
@@ -43,30 +51,33 @@ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
43
  result = None
44
  if task.done() and result is not None:
45
  with lock:
46
- png_path = "image.png"
47
- result.save(png_path)
48
- return png_path
49
- return None
50
-
51
- # Expose Gradio API
52
- def generate_api(model_str, prompt, seed=1):
53
- result = asyncio.run(infer(model_str, prompt, seed))
54
- if result:
55
- return result # Path to generated image
56
  return None
57
 
58
- from gradio_client import Client
59
-
60
- client = Client("Geek7/mdztxi2")
61
- result = client.predict(
62
- model_str=model_str,
63
- prompt=prompt,
64
- seed=seed,
65
- api_name="/generate_api "
66
- )
67
 
 
 
 
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- # Launch Gradio API without frontend
71
- iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
72
- iface.launch(show_api=True, share=True)
 
1
  import gradio as gr
2
  from random import randint
3
+ from all_models import models # Import the list of available models
4
  from externalmod import gr_Interface_load
5
  import asyncio
6
  import os
7
  from threading import RLock
8
+ from flask import Flask, request, jsonify, send_file
9
+ from flask_cors import CORS
10
+ import tempfile
11
 
12
  lock = RLock()
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
 
15
+ app = Flask(__name__)
16
+ CORS(app) # Enable CORS for all routes
17
+
18
+ # Function to load models
19
  def load_fn(models):
20
  global models_load
21
  models_load = {}
 
31
 
32
  load_fn(models)
33
 
34
+ num_models = 6 # Number of models to load initially
35
  MAX_SEED = 3999999999
36
+ default_models = models[:num_models] # Load the first few models for inference
37
  inference_timeout = 600
38
 
39
+ # Asynchronous function to perform inference
40
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
41
  kwargs = {"seed": seed}
42
  task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
 
51
  result = None
52
  if task.done() and result is not None:
53
  with lock:
54
+ temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
55
+ result.save(temp_image.name) # Save result as a temporary file
56
+ return temp_image.name # Return the path to the saved image
 
 
 
 
 
 
 
57
  return None
58
 
59
+ # Flask route for the API endpoint
60
+ @app.route('/generate_api', methods=['POST'])
61
+ def generate_api():
62
+ data = request.get_json()
 
 
 
 
 
63
 
64
+ # Extract required fields from the request
65
+ model_str = data.get('model_str', default_models[0]) # Default to first model if not provided
66
+ prompt = data.get('prompt', '')
67
+ seed = data.get('seed', 1)
68
 
69
+ if not prompt:
70
+ return jsonify({"error": "Prompt is required"}), 400
71
+
72
+ try:
73
+ # Call the async inference function
74
+ result_path = asyncio.run(infer(model_str, prompt, seed))
75
+ if result_path:
76
+ return send_file(result_path, mimetype='image/png') # Send back the generated image file
77
+ else:
78
+ return jsonify({"error": "Failed to generate image"}), 500
79
+ except Exception as e:
80
+ return jsonify({"error": str(e)}), 500
81
 
82
+ if __name__ == '__main__':
83
+ app.run(debug=True)