CRISPRTool / cas9on.py
supercat666's picture
fixed cas9on
ce4236e
raw
history blame
6.03 kB
import requests
import tensorflow as tf
import pandas as pd
import numpy as np
from operator import add
from functools import reduce
from keras.models import load_model
import random
# GPU setup: enable memory growth on every GPU TensorFlow detects, then make
# only the first detected GPU visible to this process.
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, enable=True)
_detected_gpus = tf.config.list_physical_devices('GPU')
if _detected_gpus:
    tf.config.experimental.set_visible_devices(_detected_gpus[0], 'GPU')
# One-hot code for each nucleotide: the i-th base of 'ACGT' gets a 1 in slot i.
ntmap = {
    base: tuple(1 if slot == position else 0 for slot in range(4))
    for position, base in enumerate('ACGT')
}
# Binary code for an epigenetic-track symbol: 'A' maps to 1 and 'N' maps to 0.
epimap = dict(A=1, N=0)
def get_seqcode(seq):
    """One-hot encode a nucleotide string using ``ntmap``.

    The sequence is upper-cased before lookup; a character that is not a key
    of ``ntmap`` raises ``KeyError``.

    Returns:
        An array reshaped to ``(1, len(seq), -1)``; with the 4-entry tuples in
        ``ntmap`` the last axis has size 4.
    """
    flat_code = reduce(add, (ntmap[base] for base in seq.upper()))
    return np.array(flat_code).reshape((1, len(seq), -1))
def get_epicode(eseq):
    """Encode an epigenetic-track string symbol by symbol using ``epimap``.

    A symbol that is not a key of ``epimap`` raises ``KeyError`` (the lookup
    is case-sensitive).

    Returns:
        An array reshaped to ``(1, len(eseq), -1)``.
    """
    codes = [epimap[symbol] for symbol in eseq]
    return np.array(codes).reshape(1, len(eseq), -1)
class Episgt:
    """Reader for a tab-separated, header-less table of sgRNA records.

    Only the trailing columns of the file are used, in this order: one
    sequence column, ``num_epi_features`` epigenetic-track columns and, when
    ``with_y`` is true, one final label column.
    """

    def __init__(self, fpath, num_epi_features, with_y=True):
        """Load ``fpath`` and keep the trailing columns described on the class."""
        self._fpath = fpath
        self._ori_df = pd.read_csv(fpath, sep='\t', index_col=None, header=None)
        self._num_epi_features = num_epi_features
        self._with_y = with_y
        # One sequence column, plus one label column when labels are present.
        non_epi_cols = 2 if with_y else 1
        self._num_cols = num_epi_features + non_epi_cols
        self._cols = list(self._ori_df.columns)[-self._num_cols:]
        self._df = self._ori_df[self._cols]

    @property
    def length(self):
        """Number of records (rows) in the table."""
        return self._df.shape[0]

    def get_dataset(self, x_dtype=np.float32, y_dtype=np.float32):
        """Encode the table as model inputs (and labels when available).

        Returns:
            ``x`` alone when the instance was built with ``with_y=False``,
            otherwise the tuple ``(x, y)``. ``x`` stacks the sequence code and
            the epigenetic codes along the last axis and is then transposed
            with ``(0, 2, 1)``, i.e. to ``(records, channels, positions)``.
        """
        seq_col = self._cols[0]
        epi_cols = self._cols[1: 1 + self._num_epi_features]

        x_seq = np.concatenate([get_seqcode(seq) for seq in self._df[seq_col]])
        per_track = [
            np.concatenate([get_epicode(track) for track in self._df[col]])
            for col in epi_cols
        ]
        x_epis = np.concatenate(per_track, axis=-1)

        x = np.concatenate([x_seq, x_epis], axis=-1).astype(x_dtype)
        x = x.transpose(0, 2, 1)
        if not self._with_y:
            return x
        y = np.array(self._df[self._cols[-1]]).astype(y_dtype)
        return x, y
from keras.models import load_model
class DCModelOntar:
    """Thin wrapper around a saved Keras model used for on-target prediction."""

    def __init__(self, ontar_model_dir, is_reg=False):
        """Load the model stored at ``ontar_model_dir``.

        ``is_reg`` is accepted for interface compatibility; this constructor
        does not use it.
        """
        self.model = load_model(ontar_model_dir)

    def ontar_predict(self, x, channel_first=True):
        """Run the model on ``x`` and return the predictions as a flat array.

        When ``channel_first`` is true, the axes of the 4-D input are
        reordered with ``[0, 2, 3, 1]`` (axis 1 moves to the end) before the
        model is called; otherwise ``x`` is passed through unchanged.
        """
        model_input = x.transpose([0, 2, 3, 1]) if channel_first else x
        return self.model.predict(model_input).ravel()
# Build a random placeholder epigenetic track
def generate_random_epigenetic_data(length):
    """Return a string of ``length`` symbols, each drawn with ``random.choice`` from 'A' and 'N'."""
    symbols = [random.choice('AN') for _ in range(length)]
    return ''.join(symbols)
def _to_channel_first(code):
    """Turn a ``(1, length, channels)`` code into ``(1, channels, 1, length)``."""
    # A plain reshape would reinterpret the flat buffer and mix values from
    # different positions into one channel; the channel axis has to be moved
    # with a transpose before the singleton axis is inserted.
    return np.expand_dims(code.transpose(0, 2, 1), axis=2)


# Function to predict on-target efficiency and format output
def format_prediction_output(gRNA_sites, gene_id, model_path):
    """Predict on-target efficiency for each gRNA site and build result rows.

    The model at ``model_path`` is loaded once per call. Every site is one-hot
    encoded with ``get_seqcode`` and stacked with four randomly generated
    placeholder epigenetic tracks (CTCF, DNase, H3K4me3, RRBS). Because the
    tracks are random, predictions are not reproducible unless the ``random``
    module is seeded by the caller.

    Args:
        gRNA_sites: iterable of target strings (protospacer plus PAM).
        gene_id: identifier copied into the first field of every row.
        model_path: path handed to ``DCModelOntar``.

    Returns:
        A list with one row per site:
        ``[gene_id, "start_pos", "end_pos", "strand", gRNA, ctcf, dnase,
        h3k4me3, rrbs, prediction]``. The position and strand fields are
        literal placeholder strings, and the four track fields are the encoded
        arrays of shape ``(1, 1, 1, len(gRNA))``.
    """
    dcModel = DCModelOntar(model_path)
    formatted_data = []
    for gRNA in gRNA_sites:
        site_len = len(gRNA)

        # get_seqcode yields (1, length, 4). Moving the channel axis gives the
        # same per-position layout that Episgt.get_dataset produces with
        # transpose(0, 2, 1), and does not hard-code the site length.
        encoded_seq = _to_channel_first(get_seqcode(gRNA))

        # Generate random epigenetic features (as placeholders)
        ctcf = _to_channel_first(get_epicode(generate_random_epigenetic_data(site_len)))
        dnase = _to_channel_first(get_epicode(generate_random_epigenetic_data(site_len)))
        h3k4me3 = _to_channel_first(get_epicode(generate_random_epigenetic_data(site_len)))
        rrbs = _to_channel_first(get_epicode(generate_random_epigenetic_data(site_len)))

        # Stack sequence and tracks along the channel axis; ontar_predict then
        # moves that axis to the end before calling the model.
        model_input = np.concatenate((encoded_seq, ctcf, dnase, h3k4me3, rrbs), axis=1)
        prediction = dcModel.ontar_predict(model_input)

        # Format output
        formatted_data.append(
            [gene_id, "start_pos", "end_pos", "strand", gRNA, ctcf, dnase, h3k4me3, rrbs, prediction[0]]
        )
    return formatted_data
def fetch_ensembl_transcripts(gene_symbol, timeout=30):
    """Look up a human gene symbol on the Ensembl REST API and return its transcripts.

    Args:
        gene_symbol: gene symbol to look up; it is percent-encoded before
            being placed in the URL path.
        timeout: seconds to wait for the server before giving up, so a stalled
            connection cannot block the caller indefinitely.

    Returns:
        The value of the ``'Transcript'`` field of the JSON response, or
        ``None`` (after printing a message) when the request fails, the server
        answers with a non-200 status, or the field is absent.
    """
    from urllib.parse import quote

    symbol_in_path = quote(str(gene_symbol), safe='')
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{symbol_in_path}?expand=1;content-type=application/json"
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection problems and timeouts are reported the same way as HTTP errors.
        print(f"Error fetching gene data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching gene data from Ensembl: {response.text}")
        return None

    gene_data = response.json()
    if 'Transcript' not in gene_data:
        print("No transcripts found for gene:", gene_symbol)
        return None
    return gene_data['Transcript']
def fetch_ensembl_sequence(transcript_id, timeout=30):
    """Fetch the sequence for an Ensembl identifier from the Ensembl REST API.

    Args:
        transcript_id: identifier to look up; it is percent-encoded before
            being placed in the URL path.
        timeout: seconds to wait for the server before giving up, so a stalled
            connection cannot block the caller indefinitely.

    Returns:
        The value of the ``'seq'`` field of the JSON response, or ``None``
        (after printing a message) when the request fails, the server answers
        with a non-200 status, or the field is absent.
    """
    from urllib.parse import quote

    id_in_path = quote(str(transcript_id), safe='')
    url = f"https://rest.ensembl.org/sequence/id/{id_in_path}?content-type=application/json"
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection problems and timeouts are reported the same way as HTTP errors.
        print(f"Error fetching sequence data from Ensembl: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence data from Ensembl: {response.text}")
        return None

    sequence_data = response.json()
    if 'seq' not in sequence_data:
        print("No sequence found for transcript:", transcript_id)
        return None
    return sequence_data['seq']
# IUPAC ambiguity codes that may appear in a PAM pattern, mapped to the bases
# they stand for. None marks a position that accepts any character.
_PAM_CODE_BASES = {
    'N': None,
    'R': 'AG', 'Y': 'CT', 'S': 'CG', 'W': 'AT', 'K': 'GT', 'M': 'AC',
    'B': 'CGT', 'D': 'AGT', 'H': 'ACT', 'V': 'ACG',
}


def find_crispr_targets(sequence, pam="NGG", target_length=20):
    """Find every protospacer + PAM site in ``sequence``.

    A site is reported when the PAM pattern matches at some position and at
    least ``target_length`` characters precede it. Overlapping sites are all
    returned. Matching is case-sensitive.

    Args:
        sequence: sequence string to scan.
        pam: PAM pattern of any length. ``N`` matches any character, the other
            IUPAC ambiguity codes (R, Y, S, W, K, M, B, D, H, V) match their
            base sets, and any other character must match exactly.
        target_length: number of characters taken upstream of the PAM.

    Returns:
        A list of strings, each ``target_length + len(pam)`` characters long,
        in order of position.
    """
    pam_len = len(pam)
    # Per-position acceptance: None = wildcard, otherwise a string of accepted bases.
    accepted = [_PAM_CODE_BASES.get(code, code) for code in pam]

    targets = []
    first_start = max(target_length, 0)
    for pam_start in range(first_start, len(sequence) - pam_len + 1):
        pam_end = pam_start + pam_len
        site = sequence[pam_start:pam_end]
        if all(options is None or base in options
               for base, options in zip(site, accepted)):
            targets.append(sequence[pam_start - target_length:pam_end])
    return targets
def process_gene(gene_symbol, model_path):
    """Run the full pipeline for one gene symbol.

    For every transcript returned by ``fetch_ensembl_transcripts`` the
    sequence is fetched, candidate sites are located with
    ``find_crispr_targets`` and scored with ``format_prediction_output``.
    Transcripts without a sequence or without any site are skipped.

    Returns:
        The concatenated rows for all transcripts; an empty list when no
        transcripts are found.
    """
    all_data = []
    transcripts = fetch_ensembl_transcripts(gene_symbol)
    if not transcripts:
        return all_data

    for transcript in transcripts:
        transcript_id = transcript['id']
        gene_sequence = fetch_ensembl_sequence(transcript_id)
        if not gene_sequence:
            continue
        gRNA_sites = find_crispr_targets(gene_sequence)
        if not gRNA_sites:
            continue
        rows = format_prediction_output(gRNA_sites, transcript_id, model_path)
        all_data.extend(rows)
    return all_data
# Write prediction rows to a CSV file
def save_to_csv(data, filename="crispr_results.csv"):
    """Save result rows to ``filename`` as CSV with a fixed ten-column header.

    Args:
        data: rows with ten fields each, in the header's column order.
        filename: destination path; the row index is not written.
    """
    column_names = [
        "Gene ID", "Start Pos", "End Pos", "Strand", "gRNA",
        "CTCF", "Dnase", "H3K4me3", "RRBS", "Prediction",
    ]
    pd.DataFrame(data, columns=column_names).to_csv(filename, index=False)