import abc
import logging
import re
from typing import Any

import torch
from diffusers import AudioLDM2Pipeline, AutoPipelineForText2Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SAMPLE_RATE = 16000

class BaseHint(BaseModel, abc.ABC):
    configs: dict
    hints: list = []
    model: Any = None

    @abc.abstractmethod
    def initialize(self):
        """Initialize the hint model."""

    @abc.abstractmethod
    def generate_hint(self, country: str, n_hints: int):
        """Generate hints.

        Args:
            country (str): Country the hints are based on
            n_hints (int): Number of hints to generate
        """

class TextHint(BaseHint):
    tokenizer: Any = None

    def initialize(self):
        logger.info(
            f"""Initializing text hint with model '{self.configs["model_id"]}'"""
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.configs["model_id"],
            token=self.configs["hf_access_token"],
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.configs["model_id"],
            torch_dtype=torch.float16,
            token=self.configs["hf_access_token"],
        ).to(self.configs["device"])
        logger.info("Initialization finished")

    def generate_hint(self, country: str, n_hints: int):
        logger.info(f"Generating {n_hints} text hints")
        generation_config = GenerationConfig(
            do_sample=True,
            max_new_tokens=self.configs["max_output_tokens"],
            top_k=self.configs["top_k"],
            top_p=self.configs["top_p"],
            temperature=self.configs["temperature"],
        )
        prompt = [
            f'Describe the country "{country}" without mentioning its name\n'
            for _ in range(n_hints)
        ]
        input_ids = self.tokenizer(prompt, return_tensors="pt")
        text_hints = self.model.generate(
            **input_ids.to(self.configs["device"]),
            generation_config=generation_config,
        )
        for idx, text_hint in enumerate(text_hints):
            # Decode the generated tokens, drop the echoed prompt, and mask any
            # occurrence of the country name so the hint does not give it away.
            text_hint = (
                self.tokenizer.decode(text_hint, skip_special_tokens=True)
                .strip()
                .replace(prompt[idx], "")
                .strip()
            )
            text_hint = re.sub(
                re.escape(country), "***", text_hint, flags=re.IGNORECASE
            )
            self.hints.append({"text": text_hint})
        logger.info(f"Successfully generated {n_hints} text hints")

class ImageHint(BaseHint):
    def initialize(self):
        logger.info(
            f"""Initializing image hint with model '{self.configs["model_id"]}'"""
        )
        self.model = AutoPipelineForText2Image.from_pretrained(
            self.configs["model_id"],
            # torch_dtype=torch.float16,
            variant="fp16",
        ).to(self.configs["device"])
        logger.info("Initialization finished")

    def generate_hint(self, country: str, n_hints: int):
        logger.info(f"Generating {n_hints} image hints")
        prompt = [f"An image related to the country {country}" for _ in range(n_hints)]
        img_hints = self.model(
            prompt=prompt,
            num_inference_steps=self.configs["num_inference_steps"],
            guidance_scale=self.configs["guidance_scale"],
        ).images
        self.hints = [{"image": img_hint} for img_hint in img_hints]
        logger.info(f"Successfully generated {n_hints} image hints")

class AudioHint(BaseHint):
    def initialize(self):
        logger.info(
            f"""Initializing audio hint with model '{self.configs["model_id"]}'"""
        )
        self.model = AudioLDM2Pipeline.from_pretrained(
            self.configs["model_id"],
            # torch_dtype=torch.float16,  # Not working on macOS
        ).to(self.configs["device"])
        logger.info("Initialization finished")

    def generate_hint(self, country: str, n_hints: int):
        logger.info(f"Generating {n_hints} audio hints")
        prompt = f"A sound that resembles the country of {country}"
        negative_prompt = "Low quality"
        audio_hints = self.model(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=self.configs["num_inference_steps"],
            audio_length_in_s=self.configs["audio_length_in_s"],
            num_waveforms_per_prompt=n_hints,
        ).audios
        for audio_hint in audio_hints:
            self.hints.append(
                {
                    "audio": audio_hint,
                    "sample_rate": SAMPLE_RATE,
                }
            )
        logger.info(f"Successfully generated {n_hints} audio hints")