Ffftdtd5dtft commited on
Commit
25c2784
verified
1 Parent(s): 2779600

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -39
app.py CHANGED
@@ -1,5 +1,3 @@
1
- !pip install redis diffusers transformers accelerate torch gradio audiocraft huggingface_hub
2
-
3
  import redis
4
  import pickle
5
  import torch
@@ -23,17 +21,15 @@ redis_password = os.getenv("REDIS_PASSWORD")
23
  HfFolder.save_token(hf_token)
24
 
25
  def connect_to_redis():
26
- max_retries = 5
27
- retry_delay = 1
28
- for attempt in range(max_retries):
29
  try:
30
  redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
31
- redis_client.ping()
 
32
  return redis_client
33
  except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
34
- print(f"Attempt {attempt + 1}: Connection to Redis failed: {e}. Retrying in {retry_delay} seconds...")
35
- time.sleep(retry_delay)
36
- raise ConnectionError("Failed to connect to Redis after multiple retries.")
37
 
38
  def reconnect_if_needed(redis_client):
39
  try:
@@ -46,59 +42,111 @@ def reconnect_if_needed(redis_client):
46
  def load_object_from_redis(key):
47
  redis_client = connect_to_redis()
48
  redis_client = reconnect_if_needed(redis_client)
49
- obj_data = redis_client.get(key)
50
- return pickle.loads(obj_data) if obj_data else None
 
 
 
 
51
 
52
  def save_object_to_redis(key, obj):
53
  redis_client = connect_to_redis()
54
  redis_client = reconnect_if_needed(redis_client)
55
- redis_client.set(key, pickle.dumps(obj))
 
 
 
 
 
56
 
57
  def get_model_or_download(model_id, redis_key, loader_func):
58
  model = load_object_from_redis(redis_key)
59
- if not model:
60
- model = loader_func(model_id, torch_dtype=torch.float16)
61
- save_object_to_redis(redis_key, model)
 
 
 
62
  return model
63
 
64
  def generate_image(prompt):
65
- return text_to_image_pipeline(prompt).images[0]
 
 
 
 
 
66
 
67
  def edit_image_with_prompt(image, prompt, strength=0.75):
68
- return img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
 
 
 
 
 
69
 
70
  def generate_song(prompt, duration=10):
71
- return music_gen.generate(prompt, duration=duration)
 
 
 
 
 
72
 
73
  def generate_text(prompt):
74
- return text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
 
 
 
 
 
75
 
76
  def generate_flux_image(prompt):
77
- return flux_pipeline(
78
- prompt,
79
- guidance_scale=0.0,
80
- num_inference_steps=4,
81
- max_sequence_length=256,
82
- generator=torch.Generator("cpu").manual_seed(0)
83
- ).images[0]
 
 
 
 
 
84
 
85
  def generate_code(prompt):
86
- inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to("cuda")
87
- outputs = starcoder_model.generate(inputs)
88
- return starcoder_tokenizer.decode(outputs[0])
 
 
 
 
 
89
 
90
  def generate_video(prompt):
91
- pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
92
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
93
- pipe.enable_model_cpu_offload()
94
- return export_to_video(pipe(prompt, num_inference_steps=25).frames)
 
 
 
 
 
95
 
96
  def test_model_meta_llama():
97
- messages = [
98
- {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
99
- {"role": "user", "content": "Who are you?"}
100
- ]
101
- return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
 
 
 
 
 
102
 
103
  def train_model(model, dataset, epochs, batch_size, learning_rate):
104
  output_dir = io.BytesIO()
@@ -139,7 +187,7 @@ music_gen = load_object_from_redis("music_gen") or MusicGen.get_pretrained('melo
139
  save_object_to_redis("music_gen", music_gen)
140
  text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
141
  "text-generation",
142
- model="google/gemma-2-2b-it",
143
  model_kwargs={"torch_dtype": torch.bfloat16},
144
  device=device,
145
  use_auth_token=hf_token,
 
 
 
1
  import redis
2
  import pickle
3
  import torch
 
21
  HfFolder.save_token(hf_token)
22
 
23
  def connect_to_redis():
24
+ while True:
 
 
25
  try:
26
  redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
27
+ redis_client.ping() # Verifica si la conexi贸n est谩 activa
28
+ print("Connected to Redis successfully.")
29
  return redis_client
30
  except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, BrokenPipeError) as e:
31
+ print(f"Connection to Redis failed: {e}. Retrying in 1 second...")
32
+ time.sleep(1)
 
33
 
34
  def reconnect_if_needed(redis_client):
35
  try:
 
42
  def load_object_from_redis(key):
43
  redis_client = connect_to_redis()
44
  redis_client = reconnect_if_needed(redis_client)
45
+ try:
46
+ obj_data = redis_client.get(key)
47
+ return pickle.loads(obj_data) if obj_data else None
48
+ except (pickle.PickleError, redis.exceptions.RedisError) as e:
49
+ print(f"Failed to load object from Redis: {e}")
50
+ return None
51
 
52
  def save_object_to_redis(key, obj):
53
  redis_client = connect_to_redis()
54
  redis_client = reconnect_if_needed(redis_client)
55
+ try:
56
+ if not redis_client.exists(key): # Solo guarda si no existe
57
+ redis_client.set(key, pickle.dumps(obj))
58
+ print(f"Object saved to Redis: {key}")
59
+ except redis.exceptions.RedisError as e:
60
+ print(f"Failed to save object to Redis: {e}")
61
 
62
  def get_model_or_download(model_id, redis_key, loader_func):
63
  model = load_object_from_redis(redis_key)
64
+ if model:
65
+ print(f"Model loaded from Redis: {redis_key}")
66
+ return model
67
+ model = loader_func(model_id, torch_dtype=torch.float16)
68
+ save_object_to_redis(redis_key, model)
69
+ print(f"Model downloaded and saved to Redis: {redis_key}")
70
  return model
71
 
72
  def generate_image(prompt):
73
+ redis_key = f"generated_image_{prompt}"
74
+ image = load_object_from_redis(redis_key)
75
+ if not image:
76
+ image = text_to_image_pipeline(prompt).images[0]
77
+ save_object_to_redis(redis_key, image)
78
+ return image
79
 
80
  def edit_image_with_prompt(image, prompt, strength=0.75):
81
+ redis_key = f"edited_image_{prompt}_{strength}"
82
+ edited_image = load_object_from_redis(redis_key)
83
+ if not edited_image:
84
+ edited_image = img2img_pipeline(prompt=prompt, init_image=image.convert("RGB"), strength=strength).images[0]
85
+ save_object_to_redis(redis_key, edited_image)
86
+ return edited_image
87
 
88
  def generate_song(prompt, duration=10):
89
+ redis_key = f"generated_song_{prompt}_{duration}"
90
+ song = load_object_from_redis(redis_key)
91
+ if not song:
92
+ song = music_gen.generate(prompt, duration=duration)
93
+ save_object_to_redis(redis_key, song)
94
+ return song
95
 
96
  def generate_text(prompt):
97
+ redis_key = f"generated_text_{prompt}"
98
+ text = load_object_from_redis(redis_key)
99
+ if not text:
100
+ text = text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"].strip()
101
+ save_object_to_redis(redis_key, text)
102
+ return text
103
 
104
  def generate_flux_image(prompt):
105
+ redis_key = f"generated_flux_image_{prompt}"
106
+ flux_image = load_object_from_redis(redis_key)
107
+ if not flux_image:
108
+ flux_image = flux_pipeline(
109
+ prompt,
110
+ guidance_scale=0.0,
111
+ num_inference_steps=4,
112
+ max_sequence_length=256,
113
+ generator=torch.Generator("cpu").manual_seed(0)
114
+ ).images[0]
115
+ save_object_to_redis(redis_key, flux_image)
116
+ return flux_image
117
 
118
  def generate_code(prompt):
119
+ redis_key = f"generated_code_{prompt}"
120
+ code = load_object_from_redis(redis_key)
121
+ if not code:
122
+ inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to("cuda")
123
+ outputs = starcoder_model.generate(inputs)
124
+ code = starcoder_tokenizer.decode(outputs[0])
125
+ save_object_to_redis(redis_key, code)
126
+ return code
127
 
128
  def generate_video(prompt):
129
+ redis_key = f"generated_video_{prompt}"
130
+ video = load_object_from_redis(redis_key)
131
+ if not video:
132
+ pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
133
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
134
+ pipe.enable_model_cpu_offload()
135
+ video = export_to_video(pipe(prompt, num_inference_steps=25).frames)
136
+ save_object_to_redis(redis_key, video)
137
+ return video
138
 
139
  def test_model_meta_llama():
140
+ redis_key = "meta_llama_test_response"
141
+ response = load_object_from_redis(redis_key)
142
+ if not response:
143
+ messages = [
144
+ {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
145
+ {"role": "user", "content": "Who are you?"}
146
+ ]
147
+ response = meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"].strip()
148
+ save_object_to_redis(redis_key, response)
149
+ return response
150
 
151
  def train_model(model, dataset, epochs, batch_size, learning_rate):
152
  output_dir = io.BytesIO()
 
187
  save_object_to_redis("music_gen", music_gen)
188
  text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
189
  "text-generation",
190
+ model="google/gemini-2-2b-it",
191
  model_kwargs={"torch_dtype": torch.bfloat16},
192
  device=device,
193
  use_auth_token=hf_token,