Ffftdtd5dtft commited on
Commit
4e46091
·
verified ·
1 Parent(s): 838294b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -28
app.py CHANGED
@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download, HfApi, HfFolder
12
  import multiprocessing
13
  import io
14
  import time
 
15
 
16
  hf_token = os.getenv("HF_TOKEN")
17
  redis_host = os.getenv("REDIS_HOST")
@@ -61,18 +62,22 @@ def get_model_or_download(model_id, redis_key, loader_func):
61
  if model:
62
  return model
63
  try:
64
- model = loader_func(model_id, torch_dtype=torch.float16)
 
 
65
  save_object_to_redis(redis_key, model)
66
  except Exception as e:
67
  print(f"Failed to load or save model: {e}")
68
  return None
69
 
70
  def generate_image(prompt):
71
- redis_key = f"generated_image_{prompt}"
72
  image = load_object_from_redis(redis_key)
73
  if not image:
74
  try:
75
- image = text_to_image_pipeline(prompt).images[0]
 
 
76
  save_object_to_redis(redis_key, image)
77
  except Exception as e:
78
  print(f"Failed to generate image: {e}")
@@ -80,11 +85,13 @@ def generate_image(prompt):
80
  return image
81
 
82
  def edit_image_with_prompt(image, prompt, strength=0.75):
83
- redis_key = f"edited_image_{prompt}_{strength}"
84
  edited_image = load_object_from_redis(redis_key)
85
  if not edited_image:
86
  try:
87
- edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
 
 
88
  save_object_to_redis(redis_key, edited_image)
89
  except Exception as e:
90
  print(f"Failed to edit image: {e}")
@@ -92,11 +99,13 @@ def edit_image_with_prompt(image, prompt, strength=0.75):
92
  return edited_image
93
 
94
  def generate_song(prompt, duration=10):
95
- redis_key = f"generated_song_{prompt}_{duration}"
96
  song = load_object_from_redis(redis_key)
97
  if not song:
98
  try:
99
- song = music_gen.generate(prompt, duration=duration)
 
 
100
  save_object_to_redis(redis_key, song)
101
  except Exception as e:
102
  print(f"Failed to generate song: {e}")
@@ -104,11 +113,13 @@ def generate_song(prompt, duration=10):
104
  return song
105
 
106
  def generate_text(prompt):
107
- redis_key = f"generated_text_{prompt}"
108
  text = load_object_from_redis(redis_key)
109
  if not text:
110
  try:
111
- text = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
 
 
112
  save_object_to_redis(redis_key, text)
113
  except Exception as e:
114
  print(f"Failed to generate text: {e}")
@@ -116,17 +127,19 @@ def generate_text(prompt):
116
  return text
117
 
118
  def generate_flux_image(prompt):
119
- redis_key = f"generated_flux_image_{prompt}"
120
  flux_image = load_object_from_redis(redis_key)
121
  if not flux_image:
122
  try:
123
- flux_image = flux_pipeline(
124
- prompt,
125
- guidance_scale=0.0,
126
- num_inference_steps=4,
127
- max_sequence_length=256,
128
- generator=torch.Generator("cpu").manual_seed(0)
129
- ).images[0]
 
 
130
  save_object_to_redis(redis_key, flux_image)
131
  except Exception as e:
132
  print(f"Failed to generate flux image: {e}")
@@ -134,13 +147,15 @@ def generate_flux_image(prompt):
134
  return flux_image
135
 
136
  def generate_code(prompt):
137
- redis_key = f"generated_code_{prompt}"
138
  code = load_object_from_redis(redis_key)
139
  if not code:
140
  try:
141
- inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to("cuda")
142
- outputs = starcoder_model.generate(inputs)
143
- code = starcoder_tokenizer.decode(outputs[0])
 
 
144
  save_object_to_redis(redis_key, code)
145
  except Exception as e:
146
  print(f"Failed to generate code: {e}")
@@ -148,14 +163,16 @@ def generate_code(prompt):
148
  return code
149
 
150
  def generate_video(prompt):
151
- redis_key = f"generated_video_{prompt}"
152
  video = load_object_from_redis(redis_key)
153
  if not video:
154
  try:
155
- pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
156
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
157
- pipe.enable_model_cpu_offload()
158
- video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
 
 
159
  save_object_to_redis(redis_key, video)
160
  except Exception as e:
161
  print(f"Failed to generate video: {e}")
@@ -171,7 +188,9 @@ def test_model_meta_llama():
171
  {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
172
  {"role": "user", "content": "Who are you?"}
173
  ]
174
- response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
 
 
175
  save_object_to_redis(redis_key, response)
176
  except Exception as e:
177
  print(f"Failed to test Meta-Llama: {e}")
@@ -188,7 +207,9 @@ def train_model(model, dataset, epochs, batch_size, learning_rate):
188
  )
189
  trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
190
  try:
191
- trainer.train()
 
 
192
  save_object_to_redis("trained_model", model)
193
  save_object_to_redis("training_results", output_dir.getvalue())
194
  except Exception as e:
 
12
  import multiprocessing
13
  import io
14
  import time
15
+ from tqdm import tqdm
16
 
17
  hf_token = os.getenv("HF_TOKEN")
18
  redis_host = os.getenv("REDIS_HOST")
 
62
  if model:
63
  return model
64
  try:
65
+ with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
66
+ model = loader_func(model_id, torch_dtype=torch.float16)
67
+ pbar.update(1)
68
  save_object_to_redis(redis_key, model)
69
  except Exception as e:
70
  print(f"Failed to load or save model: {e}")
71
  return None
72
 
73
  def generate_image(prompt):
74
+ redis_key = f"generated_image:{prompt}"
75
  image = load_object_from_redis(redis_key)
76
  if not image:
77
  try:
78
+ with tqdm(total=1, desc="Generating image") as pbar:
79
+ image = text_to_image_pipeline(prompt).images[0]
80
+ pbar.update(1)
81
  save_object_to_redis(redis_key, image)
82
  except Exception as e:
83
  print(f"Failed to generate image: {e}")
 
85
  return image
86
 
87
  def edit_image_with_prompt(image, prompt, strength=0.75):
88
+ redis_key = f"edited_image:{prompt}:{strength}"
89
  edited_image = load_object_from_redis(redis_key)
90
  if not edited_image:
91
  try:
92
+ with tqdm(total=1, desc="Editing image") as pbar:
93
+ edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
94
+ pbar.update(1)
95
  save_object_to_redis(redis_key, edited_image)
96
  except Exception as e:
97
  print(f"Failed to edit image: {e}")
 
99
  return edited_image
100
 
101
  def generate_song(prompt, duration=10):
102
+ redis_key = f"generated_song:{prompt}:{duration}"
103
  song = load_object_from_redis(redis_key)
104
  if not song:
105
  try:
106
+ with tqdm(total=1, desc="Generating song") as pbar:
107
+ song = music_gen.generate(prompt, duration=duration)
108
+ pbar.update(1)
109
  save_object_to_redis(redis_key, song)
110
  except Exception as e:
111
  print(f"Failed to generate song: {e}")
 
113
  return song
114
 
115
  def generate_text(prompt):
116
+ redis_key = f"generated_text:{prompt}"
117
  text = load_object_from_redis(redis_key)
118
  if not text:
119
  try:
120
+ with tqdm(total=1, desc="Generating text") as pbar:
121
+ text = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
122
+ pbar.update(1)
123
  save_object_to_redis(redis_key, text)
124
  except Exception as e:
125
  print(f"Failed to generate text: {e}")
 
127
  return text
128
 
129
  def generate_flux_image(prompt):
130
+ redis_key = f"generated_flux_image:{prompt}"
131
  flux_image = load_object_from_redis(redis_key)
132
  if not flux_image:
133
  try:
134
+ with tqdm(total=1, desc="Generating FLUX image") as pbar:
135
+ flux_image = flux_pipeline(
136
+ prompt,
137
+ guidance_scale=0.0,
138
+ num_inference_steps=4,
139
+ max_sequence_length=256,
140
+ generator=torch.Generator("cpu").manual_seed(0)
141
+ ).images[0]
142
+ pbar.update(1)
143
  save_object_to_redis(redis_key, flux_image)
144
  except Exception as e:
145
  print(f"Failed to generate flux image: {e}")
 
147
  return flux_image
148
 
149
  def generate_code(prompt):
150
+ redis_key = f"generated_code:{prompt}"
151
  code = load_object_from_redis(redis_key)
152
  if not code:
153
  try:
154
+ with tqdm(total=1, desc="Generating code") as pbar:
155
+ inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to("cuda")
156
+ outputs = starcoder_model.generate(inputs)
157
+ code = starcoder_tokenizer.decode(outputs[0])
158
+ pbar.update(1)
159
  save_object_to_redis(redis_key, code)
160
  except Exception as e:
161
  print(f"Failed to generate code: {e}")
 
163
  return code
164
 
165
  def generate_video(prompt):
166
+ redis_key = f"generated_video:{prompt}"
167
  video = load_object_from_redis(redis_key)
168
  if not video:
169
  try:
170
+ with tqdm(total=1, desc="Generating video") as pbar:
171
+ pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
172
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
173
+ pipe.enable_model_cpu_offload()
174
+ video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
175
+ pbar.update(1)
176
  save_object_to_redis(redis_key, video)
177
  except Exception as e:
178
  print(f"Failed to generate video: {e}")
 
188
  {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
189
  {"role": "user", "content": "Who are you?"}
190
  ]
191
+ with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
192
+ response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
193
+ pbar.update(1)
194
  save_object_to_redis(redis_key, response)
195
  except Exception as e:
196
  print(f"Failed to test Meta-Llama: {e}")
 
207
  )
208
  trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
209
  try:
210
+ with tqdm(total=epochs, desc="Training model") as pbar:
211
+ trainer.train()
212
+ pbar.update(epochs)
213
  save_object_to_redis("trained_model", model)
214
  save_object_to_redis("training_results", output_dir.getvalue())
215
  except Exception as e: