"""Get/put submission results concerning attention from/on COS."""
import os
import json
import dill
import logging
import numpy as np
from typing import Iterable

from configuration import GENES
from cos import (
    RESULTS_PREFIX,
    bytes_from_key,
    string_from_key,
    bytes_to_key,
)
from utils import Drug
from plots import embed_barplot
from smiles import smiles_attention_to_svg

logger = logging.getLogger("openapi_server:attention")


def download_attention(workspace_id: str, task_id: str, sample_name: str) -> dict:
    """
    Download attention figures and related data.

    The drug description is read from
    ``<RESULTS_PREFIX>/<workspace_id>/<task_id>/drug.json`` and the attention
    arrays from ``<RESULTS_PREFIX>/<workspace_id>/<task_id>/<sample_name>/``
    (``gene_attention.pkl`` and ``smiles_attention.pkl``).

    Args:
        workspace_id (str): workspace identifier.
        task_id (str): task identifier.
        sample_name (str): name of the sample.

    Returns:
        dict: attention figures and related data, with keys ``drug``,
            ``sample_name``, ``sample_drug_attention_svg``,
            ``sample_drug_color_bar_svg``, ``sample_gene_attention_js`` and
            ``sample_gene_attention_html``.

    Raises:
        ValueError: number of gene attention values does not match the
            number of genes in GENES.
    """

    def _remote_to_bytes(basename: str) -> bytes:
        """Fetch the raw bytes of a per-sample result object from COS."""
        object_name = os.path.join(workspace_id, task_id, sample_name, basename)
        key = os.path.join(RESULTS_PREFIX, object_name)
        return bytes_from_key(key)

    drug_path = os.path.join(workspace_id, task_id, "drug.json")
    key = os.path.join(RESULTS_PREFIX, drug_path)
    drug = Drug(**json.loads(string_from_key(key)))
    logger.debug("download attention results from COS for %s.", drug.smiles)

    # NOTE(security): dill.loads can execute arbitrary code embedded in the
    # payload. This is only safe as long as the objects under RESULTS_PREFIX
    # are written exclusively by trusted code.

    # omic
    logger.debug("gene attention.")
    gene_attention = dill.loads(_remote_to_bytes("gene_attention.pkl"))
    genes = np.array(GENES)
    # Without this check a shorter gene_attention would silently pick the
    # wrong gene labels through genes[order].
    if len(gene_attention) != len(genes):
        raise ValueError(
            f"length of gene_attention {len(gene_attention)} does not "
            f"match number of genes {len(genes)}"
        )
    order = gene_attention.argsort()[::-1]  # descending
    gene_attention_js, gene_attention_html = embed_barplot(
        genes[order], gene_attention[order]
    )
    logger.debug("gene attention plots created.")

    # smiles
    logger.debug("SMILES attention.")
    smiles_attention = dill.loads(_remote_to_bytes("smiles_attention.pkl"))
    drug_attention_svg, drug_color_bar_svg = smiles_attention_to_svg(
        drug.smiles, smiles_attention
    )
    logger.debug("SMILES attention plots created.")

    return {
        "drug": drug,
        "sample_name": sample_name,
        "sample_drug_attention_svg": drug_attention_svg,
        "sample_drug_color_bar_svg": drug_color_bar_svg,
        "sample_gene_attention_js": gene_attention_js,
        "sample_gene_attention_html": gene_attention_html,
    }


def _upload_ndarray(sample_prefix: str, array: np.ndarray, filename: str) -> None:
    """
    Serialize an array with dill and store it on COS.

    Args:
        sample_prefix (str): key prefix the object is stored under.
        array (np.ndarray): array to serialize.
        filename (str): object basename; ``.pkl`` is appended.
    """
    bytes_to_key(dill.dumps(array), os.path.join(sample_prefix, f"{filename}.pkl"))


def upload_attention(
    prefix: str,
    sample_names: Iterable[str],
    omic_attention: np.ndarray,
    smiles_attention: np.ndarray,
) -> dict:
    """
    Validate attention profiles and aggregate them across samples.

    NOTE(review): despite its name, this function currently writes nothing to
    COS and ``prefix`` is unused. It only validates the inputs and returns the
    attention averaged over the sample axis (axis 0).

    Args:
        prefix (str): base prefix used as a root (currently unused).
        sample_names (Iterable[str]): name of the samples; any iterable is
            accepted and consumed once.
        omic_attention (np.ndarray): attention values for genes, indexed as
            ``[sample, gene]``.
        smiles_attention (np.ndarray): attention values for SMILES, one entry
            per sample along axis 0.

    Returns:
        dict: ``{"gene_attention": ..., "smiles_attention": ...}`` holding the
            mean attention over all samples.

    Raises:
        ValueError: no samples were given.
        ValueError: omic_attention is not two-dimensional.
        ValueError: mismatch in sample names and gene attention.
        ValueError: mismatch in sample names and SMILES attention.
        ValueError: mismatch in number of genes and gene attention.
    """
    # Materialize once so len() works for any iterable (e.g. generators),
    # not only for sized sequences.
    sample_names = list(sample_names)
    omic_entities = np.array(GENES)

    # sanity checks
    if not sample_names:
        raise ValueError(
            "sample_names is empty; cannot average attention over zero samples"
        )
    if omic_attention.ndim != 2:
        raise ValueError(
            f"omic_attention must be 2-dimensional, got {omic_attention.ndim} "
            "dimension(s)"
        )
    if len(sample_names) != omic_attention.shape[0]:
        raise ValueError(
            f"length of sample_names {len(sample_names)} does not "
            f"match omic_attention {omic_attention.shape[0]}"
        )
    if len(sample_names) != len(smiles_attention):
        raise ValueError(
            f"length of sample_names {len(sample_names)} does not "
            f"match smiles_attention {len(smiles_attention)}"
        )
    if len(omic_entities) != omic_attention.shape[1]:
        raise ValueError(
            f"length of omic_entities {len(omic_entities)} "
            f"does not match omic_attention.shape[1] {omic_attention.shape[1]}"
        )

    # TODO(review): per-sample results and the actual upload (e.g. through
    # _upload_ndarray under `prefix`) are not implemented yet.
    return {
        "gene_attention": omic_attention.mean(axis=0),
        "smiles_attention": smiles_attention.mean(axis=0),
    }