Spaces:
Sleeping
Sleeping
import re | |
from multiprocessing import cpu_count | |
from keras.src.saving import load_model | |
import pandas as pd | |
from keras.src.utils import set_random_seed | |
from numpy import int64 | |
from pandarallel import pandarallel | |
from sklearn.preprocessing import RobustScaler | |
import gradio as gr | |
# Seed the random number generators before anything else runs so repeated
# launches of the app behave the same way.
set_random_seed(65536)

# Start pandarallel with one worker per CPU core; the memory file system is
# explicitly disabled.
pandarallel.initialize(nb_workers=cpu_count(), use_memory_fs=False)

# Load the trained Keras model once at import time; every request reuses it.
model = load_model('./sqid.keras')
# Ordered (needle, replacement) rewrites applied to the raw query before it is
# split on whitespace.  The order is significant: e.g. '=' is padded before
# '!=' is looked for, so by then '!=' has already become '! = '.
_SQL_REWRITES = (
    ('`', ' '),
    ('%20', ' '),
    ('=', ' = '),
    ('((', ' (( '),
    ('))', ' )) '),
    ('(', ' ( '),
    (')', ' ) '),
    ('||', ' || '),
    (',', ''),
    ('--', ' -- '),
    (':', ' : '),
    ('%23', ' # '),
    ('+', ' + '),
    ('!=', ' != '),
    ('"', ' " '),
    ('%26', ' and '),
    ('$', ' $ '),
    ('%28', ' ( '),
    ('%2A', ' * '),
    ('%7C', ' | '),
    ('&', ' & '),
    (']', ' ] '),
    ('[', ' [ '),
    (';', ' ; '),
    ('/*', ' /* '),
)

# Tokens whose upper-cased form appears here are passed through untouched.
# Because the lookup upper-cases the token first, the lower-case entries
# ('char', 'substr', ...) can never match; they are kept as-is on purpose so
# the tokenizer output stays exactly what it has always been.
_SQL_RESERVED = frozenset({
    'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'ORDER', 'BY', 'GROUP', 'HAVING',
    'LIMIT', 'BETWEEN', 'IS', 'NULL', '%', 'LIKE', 'MIN', 'MAX', 'AS', 'UPPER', 'LOWER', 'TO_DATE',
    '=', '>', '<', '>=', '<=', '!=', '<>', 'BETWEEN', 'LIKE', 'EXISTS', 'JOIN', 'UNION', 'ALL',
    'ASC', 'DESC', '||', 'AVG', 'LIMIT', 'EXCEPT', 'INTERSECT', 'CASE', 'WHEN', 'THEN', 'IF',
    'IF', 'ANY', 'CAST', 'CONVERT', 'COALESCE', 'NULLIF', 'INNER', 'OUTER', 'LEFT', 'RIGHT', 'FULL',
    'CROSS', 'OVER', 'PARTITION', 'SUM', 'COUNT', 'WITH', 'INTERVAL', 'WINDOW', 'OVER',
    'ROW_NUMBER', 'RANK',
    'DENSE_RANK', 'NTILE', 'FIRST_VALUE', 'LAST_VALUE', 'LAG', 'LEAD', 'DISTINCT', 'COMMENT',
    'INSERT',
    'UPDATE', 'DELETED', 'MERGE', '*', 'generate_series', 'char', 'chr', 'substr', 'lpad',
    'extract',
    'year', 'month', 'day', 'timestamp', 'number', 'string', 'concat', 'INFORMATION_SCHEMA',
    "SQLITE_MASTER", 'TABLES', 'COLUMNS', 'CUBE', 'ROLLUP', 'RECURSIVE', 'FILTER', 'EXCLUDE',
    'AUTOINCREMENT', 'WITHOUT', 'ROWID', 'VIRTUAL', 'INDEXED', 'UNINDEXED', 'SERIAL',
    'DO', 'RETURNING', 'ILIKE', 'ARRAY', 'ANYARRAY', 'JSONB', 'TSQUERY', 'SEQUENCE',
    'SYNONYM', 'CONNECT', 'START', 'LEVEL', 'ROWNUM', 'NOCOPY', 'MINUS', 'AUTO_INCREMENT', 'BINARY',
    'ENUM', 'REPLACE', 'SET', 'SHOW', 'DESCRIBE', 'USE', 'EXPLAIN', 'STORED', 'VIRTUAL', 'RLIKE',
    'MD5', 'SLEEP', 'BENCHMARK', '@@VERSION', 'VERSION', '@VERSION', 'CONVERT', 'NVARCHAR', '#',
    '##', 'INJECTX',
    'DELAY', 'WAITFOR', 'RAND',
})

# Characters outside this whitelist are deleted from every token.
_TOKEN_JUNK = re.compile(r"""[^*\w\s.=\-><_|()!"']""")
_IDENTIFIER = re.compile(r'^[a-zA-Z_.|][a-zA-Z0-9_.|]*$')
# Also matches the empty string, so a token that was reduced to nothing by the
# whitelist (e.g. a lone ';' or ':') ends up labelled as a timestamp.
_DIGITS_AND_COLONS = re.compile(r'^[\d:]*$')


def sql_tokenize(sql_query):
    """Normalise a raw SQL string into a space-separated token sequence.

    The query is first run through a fixed, ordered list of substring
    rewrites (padding operators and brackets with spaces, decoding a few
    URL-escapes, dropping commas), then split on whitespace.  Each token has
    non-whitelisted characters removed and is then classified:

    * reserved word/operator (case-insensitive lookup)  -> kept as written
    * ``str.isnumeric()``                               -> ``#NUMBER#``
    * identifier-shaped                                 -> ``#IDENTIFIER#``
    * only digits/colons (including empty)              -> ``#TIMESTAMP#``
    * anything else                                     -> kept as is

    :param sql_query: the raw query text.
    :return: the normalised tokens joined with single spaces.
    """
    for needle, replacement in _SQL_REWRITES:
        sql_query = sql_query.replace(needle, replacement)

    normalized = []
    for raw in sql_query.split():
        token = _TOKEN_JUNK.sub('', raw)
        bare = token.strip()

        if bare.upper() in _SQL_RESERVED:
            normalized.append(token)
        elif bare.isnumeric():
            normalized.append('#NUMBER#')
        elif _IDENTIFIER.match(bare):
            normalized.append('#IDENTIFIER#')
        elif _DIGITS_AND_COLONS.match(bare):
            normalized.append('#TIMESTAMP#')
        elif '%' in bare:
            # '%' is not in the character whitelist above, so tokens produced
            # by this function never contain it; the branch is retained to
            # mirror the established behaviour exactly.
            pieces = [
                part.strip() if part.strip() in ('%', "'", "'") else '#IDENTIFIER#'
                for part in bare.split('%')
            ]
            normalized.append(' '.join(pieces))
        else:
            normalized.append(token)

    return ' '.join(normalized)
def add_features(x):
    """Tokenise the ``Query`` column and append hand-crafted numeric features.

    The frame is modified in place (``Query`` is overwritten with its
    tokenised form and ten feature columns are added) and also returned.

    Added columns, all derived from the tokenised query:
    ``num_tables``, ``num_columns``, ``num_literals``, ``num_parentheses``,
    ``has_union``, ``depth_nested_queries``, ``num_join``, ``num_sp_chars``,
    ``has_mismatched_quotes``, ``has_tautology``.

    :param x: DataFrame with a string ``Query`` column.
    :return: the same DataFrame object, with the extra columns.
    """
    x['Query'] = x['Query'].copy().parallel_apply(sql_tokenize)

    # Every regex count below runs on the lower-cased tokenised text with
    # re.I, so upper-case patterns such as 'FROM' still match.
    lowered = x['Query'].str.lower()

    def occurrences(pattern):
        return lowered.str.count(pattern, flags=re.I)

    opening_parens = occurrences("\\(")
    closing_parens = occurrences('\\)')

    x['num_tables'] = occurrences(r'FROM\s+#IDENTIFIER#')
    x['num_columns'] = occurrences(r'SELECT\s+#IDENTIFIER#')
    # NOTE(review): the double-quote pattern has no '*', so it only counts
    # double-quoted literals that are exactly one character long.  Left
    # unchanged because the model consumes this feature -- confirm against
    # the training pipeline before touching it.
    x['num_literals'] = occurrences("'[^']*'") + occurrences('"[^"]"')
    x['num_parentheses'] = opening_parens + closing_parens
    x['has_union'] = (occurrences(" union |union all") > 0).astype(int64)
    # Despite the name this is the number of '(' characters, not a depth.
    x['depth_nested_queries'] = opening_parens
    x['num_join'] = occurrences(
        " join |inner join|outer join|full outer join|full inner join|cross join|left join|right join"
    )
    x['num_sp_chars'] = x['Query'].parallel_apply(
        lambda text: len(re.findall(r'[\'";\-*/%=><|#]', text)))
    x['has_mismatched_quotes'] = x['Query'].parallel_apply(
        lambda text: 1 if re.search(r"'.*[^']$|\".*[^\"]$", text) else 0)
    x['has_tautology'] = x['Query'].parallel_apply(
        lambda text: 1 if re.search(r"'[\s]*=[\s]*'", text) else 0)
    return x
def is_malicious_sql(sql, threshold, scaler=None):
    """Classify a single SQL string as ``'Malicious'`` or ``'Safe'``.

    The query is tokenised and featurised with ``add_features``; the
    tokenised text and the scaled numeric features are passed to the
    module-level ``model`` and the first output value is compared with
    ``threshold``.

    :param sql: raw SQL text to score.
    :param threshold: decision threshold; anything ``float()`` accepts.
    :param scaler: optional *already fitted* scaler (for example the
        ``RobustScaler`` used when the model was trained).  When given, its
        ``transform`` is applied to the numeric features.  When omitted the
        historical behaviour is kept (see the note in the body).
    :return: ``'Malicious'`` if the model output is strictly greater than
        ``threshold``, otherwise ``'Safe'``.
    """
    input_df = add_features(pd.DataFrame([sql], columns=['Query']))
    numeric_features = ["num_tables", "num_columns", "num_literals", "num_parentheses", "has_union",
                        "depth_nested_queries", "num_join", "num_sp_chars", "has_mismatched_quotes", "has_tautology"]

    if scaler is None:
        # NOTE(review): fitting a brand-new RobustScaler on the one row being
        # scored centres every feature on its own value, so the scaled
        # numeric input is all zeros regardless of the query.  Kept as the
        # default for backward compatibility; pass the scaler fitted at
        # training time via ``scaler`` to feed the model real feature values.
        x_in = RobustScaler().fit_transform(input_df[numeric_features])
    else:
        x_in = scaler.transform(input_df[numeric_features])

    score = model.predict([input_df['Query'], x_in]).tolist()[0][0]
    return 'Malicious' if score > float(threshold) else 'Safe'
def respond(message, history, threshold):
    """Chat callback: return the verdict for ``message``.

    If the same message (compared lower-cased and stripped) already appears
    in the recent history, the verdict stored there is returned instead of
    running the model again.  Otherwise the message is classified with
    ``is_malicious_sql`` and the (normalised message, verdict) pair is
    printed before the verdict is returned.

    NOTE(review): a verdict reused from history was produced with whatever
    threshold was active at that time; it is returned even if ``threshold``
    has since been changed.
    NOTE(review): assumes each history entry is an indexable
    (user_message, bot_reply) pair -- confirm against the installed gradio
    version's history format.

    :param message: the text the user just submitted.
    :param history: previous exchanges of the conversation.
    :param threshold: detection threshold forwarded to ``is_malicious_sql``.
    :return: the verdict string.
    """
    # With more than five past exchanges the oldest one is ignored.
    recent = history[1:] if len(history) > 5 else history
    wanted = message.lower().strip()

    for exchange in recent:
        if exchange[0].lower().strip() == wanted:
            return exchange[1]

    verdict = is_malicious_sql(message, threshold)
    print((wanted, verdict))
    return verdict
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
# Slider shown next to the chat box; its current value is handed to
# `respond` as the `threshold` argument.
threshold_slider = gr.Slider(
    label="Detection Probability Threshold ",
    minimum=0.01,
    maximum=0.99,
    step=0.01,
    value=0.75,
)

demo = gr.ChatInterface(respond, additional_inputs=[threshold_slider])

# Only start the web UI when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()