"""Helpers for capturing cross-attention maps, caching remote JSON, and parsing prompts."""

import re
from datetime import datetime, timedelta

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image

# Attention maps captured by the forward hooks, keyed by the module's name.
attn_maps = {}

# Matches "[key:value]"; neither side may contain "]".
_KEY_VALUE_PATTERN = re.compile(r"\[([^\]]+):([^\]]+)\]")


def hook_fn(name):
    """Return a forward hook that stores a module's attention map under *name*.

    When the hooked module's ``processor`` carries an ``attn_map`` attribute,
    the hook moves it into the module-level ``attn_maps`` dict and deletes it
    from the processor so the processor does not keep holding the tensor.

    :param name: Key under which the captured map is stored in ``attn_maps``.
    :return: A callable with the ``(module, input, output)`` hook signature.
    """

    def forward_hook(module, _input, _output):
        processor = module.processor
        if hasattr(processor, "attn_map"):
            attn_maps[name] = processor.attn_map
            del processor.attn_map

    return forward_hook


def register_cross_attention_hook(unet):
    """Attach ``hook_fn`` hooks to every sub-module whose last name part starts with ``attn2``.

    :param unet: Module exposing ``named_modules()``.
    :return: The same ``unet`` object, to allow call chaining.
    """
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet


def upscale(attn_map, target_size):
    """Average, reshape and bilinearly resize an attention map to *target_size*.

    The map is averaged over dim 0, transposed, reshaped to a 2D grid per
    leading entry, interpolated to ``target_size`` and finally soft-maxed over
    dim 0.

    :param attn_map: 3-D tensor; presumably ``(heads, pixels, tokens)`` -
        TODO confirm against the attention processor that produces it.
    :param target_size: ``(height, width)`` of the output image.
    :return: Float32 tensor of shape ``(attn_map.shape[-1], height, width)``.
    :raises AssertionError: If no power-of-two downscale (1..16) of
        ``target_size`` matches the number of positions in the map.
    """
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)

    # Find the downscale factor at which the map was produced. The "* 64" and
    # "scale * 8" account for an extra 8x reduction per side (presumably the
    # VAE latent factor - confirm).
    temp_size = None
    for i in range(5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    # Raised explicitly (instead of ``assert``) so the check survives ``python -O``
    # while keeping the exception type and message callers already see.
    if temp_size is None:
        raise AssertionError("temp_size cannot is None")

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Combine all captured attention maps into one averaged, upscaled map.

    Every tensor in ``attn_maps`` is split into ``batch_size`` chunks along
    dim 0; one chunk is selected, upscaled with ``upscale`` and the results
    are averaged across all hooked modules.

    :param image_size: ``(height, width)`` passed to ``upscale``.
    :param batch_size: Number of chunks each captured map is split into.
    :param instance_or_negative: Select chunk 0 when true, chunk 1 otherwise
        (presumably negative vs. positive prompt - verify against caller).
    :param detach: When true, move each map to the CPU before processing.
    :return: Tensor holding the mean of the upscaled maps.
    """
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for attn_map in attn_maps.values():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    return net_attn_maps


def attnmaps2images(net_attn_maps):
    """Convert each attention map into an 8-bit grayscale PIL image.

    Each map is min-max normalised to the 0..255 range independently.

    :param net_attn_maps: Iterable of 2-D tensors.
    :return: List of ``PIL.Image`` objects, one per input map.
    """
    images = []

    for attn_map in net_attn_maps:
        # ``detach`` lets this work on tensors that still require grad.
        attn_map = attn_map.detach().cpu().numpy()

        lowest = np.min(attn_map)
        value_range = np.max(attn_map) - lowest
        if value_range > 0:
            normalized_attn_map = (attn_map - lowest) / value_range * 255
        else:
            # A constant map would divide by zero and yield NaNs; render it black.
            normalized_attn_map = np.zeros_like(attn_map)

        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        images.append(Image.fromarray(normalized_attn_map))

    return images


def is_torch2_available():
    """Return True when ``torch.nn.functional`` has ``scaled_dot_product_attention``."""
    return hasattr(F, "scaled_dot_product_attention")


class RemoteJson:
    """Lazily fetched, time-based cache of a remote JSON document."""

    def __init__(self, url, refresh_gap_seconds=3600, processor=None, timeout=30):
        """
        Initialize the RemoteJson cache.

        :param url: The URL of the remote JSON file.
        :param refresh_gap_seconds: Time in seconds after which the JSON should be refreshed.
        :param processor: Optional callback function to process the JSON after it's loaded successfully.
        :param timeout: Seconds to wait for the HTTP request before giving up.
        """
        self.url = url
        self.refresh_gap_seconds = refresh_gap_seconds
        self.processor = processor
        self.timeout = timeout
        self.json_data = None
        self.last_updated = None

    def _load_json(self):
        """
        Load JSON from the remote URL.
        If the request fails or the body is not valid JSON, return None.
        """
        try:
            # Without a timeout ``requests`` may block forever on a stalled server.
            response = requests.get(self.url, timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decoding errors on ``requests`` versions
            # whose decode error is not a RequestException.
            print(f"Failed to fetch JSON: {e}")
            return None

    def _should_refresh(self):
        """
        Check whether the JSON should be refreshed based on the time gap.
        """
        if not self.last_updated:
            return True  # If no last update, always refresh
        return datetime.now() - self.last_updated > timedelta(seconds=self.refresh_gap_seconds)

    def _update_json(self):
        """
        Fetch and load the JSON from the remote URL.
        If it fails, keep the previous data.
        """
        new_json = self._load_json()
        # ``is None`` rather than truthiness: an empty object/array is valid JSON.
        if new_json is None:
            print("Failed to update JSON. Keeping the previous version.")
            return

        # Run the processor before committing, so a processor error cannot
        # leave raw, unprocessed data cached as if it were fresh.
        if self.processor is not None:
            new_json = self.processor(new_json)

        self.json_data = new_json
        self.last_updated = datetime.now()
        print("JSON updated successfully.")

    def get(self):
        """
        Get the JSON, checking whether it needs to be refreshed.
        If refresh is required, it fetches the new data and applies the processor.
        """
        if self._should_refresh():
            print("Refreshing JSON...")
            self._update_json()
        else:
            print("Using cached JSON.")

        return self.json_data


def extract_key_value_pairs(input_string):
    """Extract every ``[key:value]`` group from *input_string*.

    The key is matched greedily, so for ``[a:b:c]`` the key is ``a:b`` and the
    value is ``c``.

    :param input_string: Text to scan.
    :return: List of dicts with ``key``, ``value`` and ``raw`` (the full
        bracketed match) entries, in order of appearance.
    """
    return [
        {"key": match.group(1), "value": match.group(2), "raw": match.group(0)}
        for match in _KEY_VALUE_PATTERN.finditer(input_string)
    ]


def extract_characters(prefix, input_string):
    """Extract placeholders that start with *prefix* from *input_string*.

    A placeholder is *prefix* followed by one or more characters other than
    whitespace, ``,`` or ``$``, and must be followed by whitespace, a comma or
    the end of the string.

    :param prefix: Literal marker that introduces a placeholder (e.g. ``"@"``).
    :param input_string: Text to scan.
    :return: List of dicts with ``raw`` (prefix + name) and ``key`` (name).
    """
    # Escape the prefix so markers such as "$" or "+" are matched literally
    # instead of being interpreted as regex syntax.
    pattern = re.escape(prefix) + r"([^\s,$]+)(?=\s|,|$)"

    matches = re.findall(pattern, input_string)

    return [{"raw": f"{prefix}{match}", "key": match} for match in matches]