azizinaghsh committed
Commit f5aaf3c · Parent(s): 727445c

add character position input

Files changed (1): app.py (+11, -15)
app.py CHANGED
@@ -94,7 +94,7 @@ def generate(
     prompt: str,
     seed: int,
     guidance_weight: float,
-    sample_label: str,
+    character_position: list,
     # ----------------------- #
     dataset: MultimodalDataset,
     device: torch.device,
@@ -110,10 +110,12 @@ def generate(
     diffuser.guidance_weight = guidance_weight
 
     # Inference
-    sample_id = SAMPLE_IDS[LABEL_TO_IDS[sample_label]]
+    sample_id = SAMPLE_IDS[0]  # Default to the first sample ID
     seq_feat = diffuser.net.model.clip_sequential
 
     batch = get_batch(prompt, sample_id, clip_model, dataset, seq_feat, device)
+    batch["character_position"] = torch.tensor(character_position, device=device)
+
     with torch.no_grad():
         out = diffuser.predict_step(batch, 0)
 
@@ -158,17 +160,17 @@ def launch_app(gen_fn: Callable):
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Column(scale=2):
-                    sample_str = gr.Dropdown(
-                        choices=["static", "right", "complex"],
-                        label="Character trajectory",
-                        value="right",
-                        interactive=True,
+                    char_position = gr.Textbox(
+                        placeholder="Enter character position as [x, y, z]",
+                        show_label=True,
+                        label="Character Position (3D vector)",
+                        value="[0.0, 0.0, 0.0]",
                     )
                     text = gr.Textbox(
                         placeholder="Type the camera motion you want to generate",
                         show_label=True,
                         label="Text prompt",
-                        value=DEFAULT_TEXT[LABEL_TO_IDS[sample_str.value]],
+                        value=DEFAULT_TEXT[0],
                     )
                     seed = gr.Number(value=33, label="Seed")
                     guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
@@ -189,13 +191,7 @@ def launch_app(gen_fn: Callable):
            processed_example = examples.non_none_processed_examples[example_id]
            return gr.utils.resolve_singleton(processed_example)
 
-        def change_fn(change):
-            sample_index = LABEL_TO_IDS[change]
-            return gr.update(value=DEFAULT_TEXT[sample_index])
-
-        sample_str.change(fn=change_fn, inputs=[sample_str], outputs=[text])
-
-        inputs = [text, seed, guidance, sample_str]
+        inputs = [text, seed, guidance, char_position]
         examples.dataset.click(
             load_example,
             inputs=[examples.dataset],
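
Note on the new input: gr.Textbox hands its value to the callback as a raw string (here "[0.0, 0.0, 0.0]"), so the string would need to be parsed into a list of floats before torch.tensor(character_position, device=device) can consume it. Below is a minimal sketch of such a parsing step, assuming the bracketed format shown in the textbox default; parse_character_position is a hypothetical helper and is not part of this commit.

import ast

import torch


def parse_character_position(raw: str, device: torch.device) -> torch.Tensor:
    # The gr.Textbox delivers its value as a string such as "[0.0, 0.0, 0.0]",
    # so it has to be turned into numbers before building the tensor.
    values = ast.literal_eval(raw)           # e.g. [0.0, 0.0, 0.0]
    position = [float(v) for v in values]    # coerce to floats, fail early on bad input
    if len(position) != 3:
        raise ValueError("Character position must be a 3D vector [x, y, z]")
    return torch.tensor(position, device=device)

Under this assumption, the batch line in generate() would become batch["character_position"] = parse_character_position(character_position, device).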
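
For context, the inputs list assembled above fixes the argument order the UI passes to the generate callback. A self-contained sketch of that wiring pattern under stated assumptions: the gr.Blocks layout, the Generate button, and the stand-in fake_generate callback are illustrative and not taken from app.py.

import gradio as gr


def fake_generate(prompt, seed, guidance, character_position):
    # Stand-in for gen_fn: echo the arguments it would receive from the UI.
    return f"prompt={prompt!r}, seed={seed}, guidance={guidance}, position={character_position}"


with gr.Blocks() as demo:
    char_position = gr.Textbox(value="[0.0, 0.0, 0.0]", label="Character Position (3D vector)")
    text = gr.Textbox(label="Text prompt")
    seed = gr.Number(value=33, label="Seed")
    guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
    output = gr.Textbox(label="Output")
    gr.Button("Generate").click(
        fake_generate, inputs=[text, seed, guidance, char_position], outputs=[output]
    )

# demo.launch()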