# Hugging Face Spaces - status: Sleeping
import random
from functools import lru_cache
from functools import reduce
from operator import add

import numpy as np
import pandas as pd
import requests
import tensorflow as tf
from keras.models import load_model
# Configure GPUs: enable memory growth on every physical GPU, then make only
# the first GPU (when at least one exists) visible to TensorFlow.
_physical_gpus = tf.config.list_physical_devices('GPU')
for gpu in _physical_gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)
if _physical_gpus:
    tf.config.experimental.set_visible_devices(_physical_gpus[0], 'GPU')
# One-hot code per nucleotide; channel order is A, C, G, T.
ntmap = {base: tuple(int(col == row) for col in range(4)) for row, base in enumerate('ACGT')}
# Binary code per epigenetic state symbol: 'A' -> 1, 'N' -> 0.
epimap = dict(A=1, N=0)
def get_seqcode(seq, mapping=None):
    """One-hot encode a nucleotide string as an array of shape ``(1, len(seq), width)``.

    Args:
        seq: nucleotide string; upper-cased before lookup.
        mapping: symbol -> code tuple. Defaults to the module-level ``ntmap``
            (A, C, G, T one-hot, width 4).

    Returns:
        Integer array with positions on axis 1 and code channels on axis 2.
        An empty *seq* yields shape ``(1, 0, width)``.

    Raises:
        KeyError: if *seq* contains a symbol missing from *mapping* (e.g. 'N'
            with the default ``ntmap``).
    """
    if mapping is None:
        mapping = ntmap
    # One lookup per position; builds an (n, width) array directly instead of
    # concatenating tuples, which was quadratic in len(seq).
    rows = [mapping[c] for c in seq.upper()]
    if not rows:
        # reshape(..., -1) is ambiguous for a zero-size array, so spell out the width.
        width = len(next(iter(mapping.values()), ()))
        return np.zeros((1, 0, width), dtype=int)
    return np.array(rows).reshape((1, len(rows), -1))
def get_epicode(eseq, mapping=None):
    """Encode an epigenetic state string as an array of shape ``(1, len(eseq), k)``.

    Args:
        eseq: string of state symbols, looked up as-is (case-sensitive).
        mapping: symbol -> code. Defaults to the module-level ``epimap``
            ('A' -> 1, 'N' -> 0), which gives ``k == 1``.

    Returns:
        Array with positions on axis 1. An empty *eseq* yields an integer
        array of shape ``(1, 0, 1)``.

    Raises:
        KeyError: if *eseq* contains a symbol missing from *mapping*.
    """
    if mapping is None:
        mapping = epimap
    codes = [mapping[c] for c in eseq]
    if not codes:
        # reshape(..., -1) cannot infer a dimension for a zero-size array.
        return np.zeros((1, 0, 1), dtype=int)
    return np.array(codes).reshape(1, len(codes), -1)
class Episgt:
    """Dataset read from a headerless tab-separated file.

    Only the last ``num_epi_features + 2`` columns are used (``+ 1`` when
    ``with_y`` is false). Within those columns the code treats the first as the
    nucleotide sequence and the next ``num_epi_features`` as epigenetic state
    strings. When ``with_y`` is true, the final column is the label.
    """

    def __init__(self, fpath, num_epi_features, with_y=True):
        self._fpath = fpath
        self._ori_df = pd.read_csv(fpath, sep='\t', index_col=None, header=None)
        self._num_epi_features = num_epi_features
        self._with_y = with_y
        self._num_cols = num_epi_features + 2 if with_y else num_epi_features + 1
        self._cols = list(self._ori_df.columns)[-self._num_cols:]
        self._df = self._ori_df[self._cols]

    def length(self):
        """Return the number of rows in the dataset."""
        return len(self._df)

    def get_dataset(self, x_dtype=np.float32, y_dtype=np.float32):
        """Encode the dataset as model inputs (and labels when ``with_y``).

        Returns:
            ``x`` of shape ``(rows, channels, positions)`` where channels are
            the sequence code channels followed by one channel per epigenetic
            column. Returns ``(x, y)`` when the instance was built with
            ``with_y=True``, otherwise ``x`` alone.
        """
        seq_col = self._cols[0]
        epi_cols = self._cols[1:1 + self._num_epi_features]
        # Gather every channel group and concatenate once along the channel
        # axis; this also works when there are no epigenetic columns.
        channel_groups = [np.concatenate([get_seqcode(s) for s in self._df[seq_col]])]
        channel_groups.extend(
            np.concatenate([get_epicode(e) for e in self._df[col]]) for col in epi_cols
        )
        x = np.concatenate(channel_groups, axis=-1).astype(x_dtype)
        # (rows, positions, channels) -> (rows, channels, positions)
        x = x.transpose(0, 2, 1)
        if self._with_y:
            y = np.array(self._df[self._cols[-1]]).astype(y_dtype)
            return x, y
        return x
class DCModelOntar:
    """Wrapper around a saved Keras model restored with ``load_model``.

    ``load_model`` is the name imported from ``keras.models`` at the top of the file.
    """

    def __init__(self, ontar_model_dir, is_reg=False):
        # ``is_reg`` is accepted for call compatibility; it is not used.
        self.model = load_model(ontar_model_dir)

    def ontar_predict(self, x, channel_first=True):
        """Run the model on *x* and return its predictions flattened to 1-D.

        With ``channel_first`` true, a 4-D *x* is permuted from axes
        ``(0, 1, 2, 3)`` to ``(0, 2, 3, 1)``, moving axis 1 to the end, before
        prediction. Otherwise *x* is passed through unchanged.
        """
        batch = x.transpose([0, 2, 3, 1]) if channel_first else x
        return self.model.predict(batch).ravel()
def generate_random_epigenetic_data(length):
    """Return a placeholder epigenetic string of *length* random 'A'/'N' symbols.

    Each symbol comes from its own ``random.choice('AN')`` call on the global
    ``random`` generator, so seeding ``random`` makes the result reproducible.
    A non-positive *length* gives an empty string.
    """
    symbols = []
    for _ in range(length):
        symbols.append(random.choice('AN'))
    return ''.join(symbols)
@lru_cache(maxsize=8)
def _load_ontar_model(model_path):
    """Return a ``DCModelOntar`` for *model_path*, loading it from disk once per path."""
    return DCModelOntar(model_path)


def _encode_ontar_input(gRNA):
    """Encode one gRNA plus four random epigenetic tracks as a model input.

    Returns:
        ``(model_input, tracks)``. ``model_input`` has shape
        ``(1, 8, 1, len(gRNA))``: four nucleotide channels followed by the
        CTCF, DNase, H3K4me3 and RRBS channels. ``tracks`` lists those four
        ``(1, 1, 1, len(gRNA))`` epigenetic arrays in the same order.
    """
    length = len(gRNA)
    # get_seqcode returns (1, length, 4) with positions on axis 1, but axis 1
    # of the model input holds channels (Episgt.get_dataset and
    # DCModelOntar.ontar_predict both treat it that way). The nucleotide axis
    # must therefore be transposed to the front; a bare reshape to
    # (1, 4, 1, length) would interleave positions and bases.
    seq_channels = get_seqcode(gRNA).transpose(0, 2, 1).reshape(-1, 4, 1, length)
    # NOTE(review): the epigenetic features are random placeholders, so the
    # predictions change between runs unless the caller seeds ``random``.
    # Drawn in CTCF, DNase, H3K4me3, RRBS order.
    tracks = [
        get_epicode(generate_random_epigenetic_data(length)).reshape(-1, 1, 1, length)
        for _ in range(4)
    ]
    return np.concatenate([seq_channels, *tracks], axis=1), tracks


def format_prediction_output(gRNA_sites, gene_id, model_path):
    """Predict on-target efficiency for each gRNA and build one output row per site.

    Args:
        gRNA_sites: iterable of gRNA strings. All must have the same length
            because they are scored in a single batch.
        gene_id: value placed in the first column of every row.
        model_path: path handed to ``DCModelOntar``. The loaded model is
            cached per path, so repeated calls do not reload it.

    Returns:
        A list of rows ``[gene_id, "start_pos", "end_pos", "strand", gRNA,
        ctcf, dnase, h3k4me3, rrbs, prediction]``. The three quoted fields are
        literal placeholder strings. The four epigenetic entries are the
        ``(1, 1, 1, L)`` arrays fed to the model. ``prediction`` is the first
        model output value for that site. An empty *gRNA_sites* returns
        ``[]`` without loading the model.
    """
    sites = list(gRNA_sites)
    if not sites:
        return []
    dc_model = _load_ontar_model(model_path)
    encoded = [_encode_ontar_input(gRNA) for gRNA in sites]
    batch = np.concatenate([model_input for model_input, _ in encoded], axis=0)
    # One predict call for the whole batch instead of one per site. Keep only
    # the first output value of each sample.
    predictions = np.asarray(dc_model.ontar_predict(batch)).reshape(len(sites), -1)[:, 0]

    formatted_data = []
    for gRNA, (_, tracks), prediction in zip(sites, encoded, predictions):
        ctcf, dnase, h3k4me3, rrbs = tracks
        formatted_data.append([gene_id, "start_pos", "end_pos", "strand", gRNA,
                               ctcf, dnase, h3k4me3, rrbs, prediction])
    return formatted_data
def fetch_ensembl_transcripts(gene_symbol, timeout=30):
    """Look up a human gene symbol via Ensembl REST and return its transcripts.

    Args:
        gene_symbol: gene symbol to look up. It is URL-quoted before being
            placed in the request path.
        timeout: seconds to wait for the HTTP response (the requests
            ``timeout`` argument). ``None`` waits indefinitely.

    Returns:
        The ``'Transcript'`` value from the ``/lookup/symbol`` JSON response,
        or ``None`` (after printing a message) when the request does not
        return status 200 or the response has no ``'Transcript'`` key.

    Raises:
        requests.exceptions.RequestException: on network errors or timeout.
    """
    from urllib.parse import quote

    url = (f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{quote(str(gene_symbol), safe='')}"
           "?expand=1;content-type=application/json")
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None
    gene_data = response.json()
    if 'Transcript' in gene_data:
        return gene_data['Transcript']
    print("No transcripts found for gene:", gene_symbol)
    return None
def fetch_ensembl_sequence(transcript_id, timeout=30):
    """Fetch the sequence for an Ensembl ID from the REST ``/sequence/id`` endpoint.

    No sequence ``type`` is requested, so the endpoint's default applies
    (confirm against the Ensembl docs whether that is genomic or cDNA).

    Args:
        transcript_id: Ensembl stable ID. It is URL-quoted before being
            placed in the request path.
        timeout: seconds to wait for the HTTP response (the requests
            ``timeout`` argument). ``None`` waits indefinitely.

    Returns:
        The ``'seq'`` string from the JSON response, or ``None`` (after
        printing a message) when the request does not return status 200 or
        the response has no ``'seq'`` key.

    Raises:
        requests.exceptions.RequestException: on network errors or timeout.
    """
    from urllib.parse import quote

    url = (f"https://rest.ensembl.org/sequence/id/{quote(str(transcript_id), safe='')}"
           "?content-type=application/json")
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None
    sequence_data = response.json()
    if 'seq' in sequence_data:
        return sequence_data['seq']
    print("No sequence found for transcript:", transcript_id)
    return None
# IUPAC degenerate nucleotide codes allowed in a PAM. Any other PAM character
# must match literally, except 'N'/'n', which matches any character.
_PAM_IUPAC_CODES = {
    'R': 'AG', 'Y': 'CT', 'S': 'CG', 'W': 'AT', 'K': 'GT', 'M': 'AC',
    'B': 'CGT', 'D': 'AGT', 'H': 'ACT', 'V': 'ACG',
}


def find_crispr_targets(sequence, pam="NGG", target_length=20):
    """Return every ``target_length``-nt protospacer followed by a matching PAM.

    Only *sequence* as given is scanned; the reverse complement is not.
    Matching is case-sensitive.

    Args:
        sequence: nucleotide string to scan.
        pam: PAM pattern. 'N' matches any character and IUPAC codes
            (R, Y, S, W, K, M, B, D, H, V) match their base sets. With the
            default "NGG", a site is any position followed by "GG".
        target_length: protospacer length taken immediately 5' of the PAM.

    Returns:
        List of ``sequence[start - target_length:start + len(pam)]`` strings,
        one per PAM position ``start`` with a full-length protospacer before
        it, in left-to-right order. Overlapping sites are all reported.
    """
    pam_length = len(pam)
    # Per PAM position: None for the wildcard, else the characters it accepts.
    allowed = [None if code.upper() == 'N' else _PAM_IUPAC_CODES.get(code, code)
               for code in pam]
    targets = []
    for start in range(max(target_length, 0), len(sequence) - pam_length + 1):
        window = sequence[start:start + pam_length]
        if all(ok is None or base in ok for ok, base in zip(allowed, window)):
            targets.append(sequence[start - target_length:start + pam_length])
    return targets
def process_gene(gene_symbol, model_path):
    """Collect on-target prediction rows for every transcript of *gene_symbol*.

    For each transcript returned by ``fetch_ensembl_transcripts``, the
    transcript's sequence is fetched and scanned with ``find_crispr_targets``.
    Any sites found are scored with ``format_prediction_output``, using the
    transcript ID as the row identifier. Transcripts with no sequence or no
    sites are skipped.

    Returns:
        A flat list of all rows, in transcript order (empty if nothing was found).
    """
    rows = []
    for record in fetch_ensembl_transcripts(gene_symbol) or []:
        tid = record['id']
        seq = fetch_ensembl_sequence(tid)
        if not seq:
            continue
        sites = find_crispr_targets(seq)
        if not sites:
            continue
        rows.extend(format_prediction_output(sites, tid, model_path))
    return rows
def save_to_csv(data, filename="crispr_results.csv"):
    """Write prediction rows to *filename* as CSV with a header and no index column.

    Each row of *data* is expected to hold ten values, in header order:
    Gene ID, Start Pos, End Pos, Strand, gRNA, CTCF, Dnase, H3K4me3, RRBS,
    Prediction.
    """
    header = [
        "Gene ID", "Start Pos", "End Pos", "Strand", "gRNA",
        "CTCF", "Dnase", "H3K4me3", "RRBS", "Prediction",
    ]
    table = pd.DataFrame(data, columns=header)
    table.to_csv(filename, index=False)