heaversm commited on
Commit
81e44ac
·
1 Parent(s): dc1d5f9

add max generations functionality and do enforce pw functionality

Browse files
Files changed (1) hide show
  1. app.py +63 -7
app.py CHANGED
@@ -29,9 +29,16 @@ credentials = service_account.Credentials.from_service_account_info(service_acco
29
  project="pdr-imagen"
30
  aiplatform.init(project=project, credentials=credentials)
31
 
 
 
 
 
 
 
 
32
  def generate_image(pw,prompt,model_name):
33
 
34
- if pw != os.getenv("PW"):
35
  raise gr.Error("Invalid password. Please try again.")
36
 
37
  try:
@@ -46,16 +53,59 @@ def generate_image(pw,prompt,model_name):
46
 
47
  except Exception as e:
48
  print(e)
49
- raise gr.Error(f"An error occurred while generating the image for: {entry}")
50
  return image_url
51
 
52
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  gr.Markdown("# <center>Google Vertex Imagen Generator</center>")
55
-
56
  #password
57
- pw = gr.Textbox(label="Password", type="password", placeholder="Enter the password to unlock the service")
58
- gr.Markdown("Need access? Send a DM to @HeaversMike on Twitter or send me an email / Slack msg.")
59
 
60
  #instructions
61
  with gr.Accordion("Instructions & Tips",label="instructions",open=False):
@@ -69,13 +119,19 @@ with gr.Blocks() as demo:
69
  model = gr.Dropdown(choices=["imagegeneration@002", "imagegeneration@005"], label="Model", value="imagegeneration@005")
70
 
71
  with gr.Row():
72
- btn = gr.Button("Generate Images")
73
 
74
  #output
75
  with gr.Accordion("Image Output",label="Image Output",open=True):
76
  output_image = gr.Image(label="Image")
77
 
 
 
 
78
  btn.click(fn=generate_image, inputs=[pw,text, model ], outputs=output_image, api_name=False)
79
  text.submit(fn=generate_image, inputs=[pw,text, model ], outputs=output_image, api_name="generate_image") # Generate an api endpoint in Gradio / HF
80
 
 
 
 
81
  demo.launch(share=False)
 
29
  project="pdr-imagen"
30
  aiplatform.init(project=project, credentials=credentials)
31
 
32
+ # enforce password is True if DO_ENFORCE_PW is set to "true"
33
+ DO_ENFORCE_PW = os.getenv("DO_ENFORCE_PW")
34
+
35
+
36
+ def trigger_max_gens():
37
+ gr.Warning("🖼️ Max Image Generations Reached! 🖼️")
38
+
39
  def generate_image(pw,prompt,model_name):
40
 
41
+ if pw != os.getenv("PW") and DO_ENFORCE_PW == "true":
42
  raise gr.Error("Invalid password. Please try again.")
43
 
44
  try:
 
53
 
54
  except Exception as e:
55
  print(e)
56
+ raise gr.Error(f"An error occurred while generating the image")
57
  return image_url
58
 
59
+ custom_js = """
60
+ function customJS() {
61
+ //Limit Image Generation
62
+ const MAX_GENERATIONS = 10;
63
+ const DO_ENFORCE_MAX_GENERATIONS = true;
64
+
65
+ disableGenerateButton = function() {
66
+ const btn = document.getElementById('btn_generate-images');
67
+ btn.disabled = true;
68
+ btn.classList.add('not-visible');
69
+ }
70
+
71
+ triggerMaxGenerationsToast = function() {
72
+ const trigger_max_gens_btn = document.getElementById('trigger-max-gens-btn');
73
+ trigger_max_gens_btn.click();
74
+ }
75
+
76
+ setCurrentGenerations = function() {
77
+ if (!DO_ENFORCE_MAX_GENERATIONS) {
78
+ return;
79
+ }
80
+ const curGenerations = localStorage.getItem('currentGenerations');
81
+ console.log(`${curGenerations} / ${MAX_GENERATIONS}`)
82
+ if (curGenerations) {
83
+ if (curGenerations >= MAX_GENERATIONS) {
84
+ triggerMaxGenerationsToast();
85
+ disableGenerateButton();
86
+ } else {
87
+ localStorage.setItem('currentGenerations', parseInt(curGenerations) + 1);
88
+ }
89
+ } else {
90
+ localStorage.setItem('currentGenerations', 1);
91
+ }
92
+ }
93
+
94
+ setCurrentGenerations();
95
+
96
+ document.getElementById('btn_generate-images').addEventListener('click', function() {
97
+ setCurrentGenerations();
98
+ });
99
+
100
+ }
101
+ """
102
+
103
+ with gr.Blocks(js=custom_js) as demo:
104
 
105
  gr.Markdown("# <center>Google Vertex Imagen Generator</center>")
 
106
  #password
107
+ pw = gr.Textbox(label="Password", type="password", placeholder="Enter the password to unlock the service",visible=False if DO_ENFORCE_PW == "false" else True)
108
+ gr.Markdown("Need access? Send a DM to @HeaversMike on Twitter or send me an email / Slack msg.",visible=False if DO_ENFORCE_PW == "false" else True)
109
 
110
  #instructions
111
  with gr.Accordion("Instructions & Tips",label="instructions",open=False):
 
119
  model = gr.Dropdown(choices=["imagegeneration@002", "imagegeneration@005"], label="Model", value="imagegeneration@005")
120
 
121
  with gr.Row():
122
+ btn = gr.Button("Generate Images", variant="primary", elem_id="btn_generate-images")
123
 
124
  #output
125
  with gr.Accordion("Image Output",label="Image Output",open=True):
126
  output_image = gr.Image(label="Image")
127
 
128
+ with gr.Row():
129
+ trigger_max_gens_btn = gr.Button(value="Show Max Gens Reached",visible=False,elem_id="trigger-max-gens-btn")
130
+
131
  btn.click(fn=generate_image, inputs=[pw,text, model ], outputs=output_image, api_name=False)
132
  text.submit(fn=generate_image, inputs=[pw,text, model ], outputs=output_image, api_name="generate_image") # Generate an api endpoint in Gradio / HF
133
 
134
+ #js-triggered functionality
135
+ trigger_max_gens_btn.click(trigger_max_gens, None, None)
136
+
137
  demo.launch(share=False)