Spaces:
Running
Running
# Test_SQLite_DB.py
# Description: Unit tests for SQLite_DB.py
#
# Usage: python -m unittest test_sqlite_db.py
#
# Imports
import os
import shutil
import sqlite3
import tempfile
import threading
import time
import unittest
from unittest.mock import patch
#
# Local Imports
from App_Function_Libraries.DB.SQLite_DB import Database, add_media_with_keywords, add_media_version, DatabaseError
#
#######################################################################################################################
#
# Test cases:
class TestDatabase(unittest.TestCase):
    """Tests for the ``Database`` connection pool and query helpers."""

    def setUp(self):
        # Use an in-memory database for testing so no file is created.
        self.db = Database(':memory:')

    def test_connection_management(self):
        """``get_connection`` yields a real sqlite3 connection and pools it."""
        with self.db.get_connection() as conn:
            self.assertIsInstance(conn, sqlite3.Connection)
        # NOTE(review): assumes the connection is handed back to the pool when
        # the ``with`` block exits, leaving exactly one pooled connection.
        self.assertEqual(len(self.db.pool), 1)

    def test_execute_query(self):
        """``execute_query`` runs DDL and parameterized DML."""
        self.db.execute_query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
        self.db.execute_query("INSERT INTO test (name) VALUES (?)", ("test_name",))
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM test")
            result = cursor.fetchone()
        self.assertEqual(result[0], "test_name")

    def test_execute_many(self):
        """``execute_many`` applies the statement once per parameter tuple."""
        self.db.execute_query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
        data = [("name1",), ("name2",), ("name3",)]
        self.db.execute_many("INSERT INTO test (name) VALUES (?)", data)
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM test")
            count = cursor.fetchone()[0]
        self.assertEqual(count, 3)

    def test_connection_retry(self):
        """A write blocked by another connection's exclusive lock raises DatabaseError."""
        # Every connection to ':memory:' opens its own private database, so a
        # lock held by one in-memory connection can never block another one.
        # Contention needs a file that both connections open.
        tmp_dir = tempfile.mkdtemp()
        # ignore_errors: pooled connections may still hold the file open
        # (this matters on Windows) when cleanup runs.
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        # NOTE(review): assumes Database(path) opens ``path`` as given — confirm.
        db = Database(os.path.join(tmp_dir, 'retry_test.db'))

        locked = threading.Event()   # set once the worker holds the lock
        release = threading.Event()  # set by the main thread when it is done

        def lock_database():
            with db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("BEGIN EXCLUSIVE TRANSACTION")
                locked.set()
                try:
                    # Hold the lock until the main thread has finished. The cap
                    # means a broken test cannot hang the suite.
                    release.wait(timeout=30)
                finally:
                    # Drop the lock before the connection goes back to the pool.
                    conn.rollback()

        thread = threading.Thread(target=lock_database, daemon=True)
        thread.start()
        try:
            self.assertTrue(locked.wait(timeout=5), "lock thread never acquired the lock")
            with self.assertRaises(DatabaseError):
                # A bare "SELECT 1" touches no table and takes no lock, so it
                # would succeed even here. Writing the schema must wait on the
                # exclusive lock. How long this takes depends on Database's
                # busy-timeout/retry settings.
                db.execute_query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
        finally:
            # Always release and reap the worker, even if the assertion failed.
            release.set()
            thread.join()
class TestAddMediaWithKeywords(unittest.TestCase):
    """Tests for ``add_media_with_keywords`` against a throwaway in-memory schema."""

    def setUp(self):
        # Build the tables these tests read back after calling the function.
        self.db = Database(':memory:')
        self.db.execute_query("""
            CREATE TABLE Media (
                id INTEGER PRIMARY KEY,
                url TEXT,
                title TEXT NOT NULL,
                type TEXT NOT NULL,
                content TEXT,
                author TEXT,
                ingestion_date TEXT,
                transcription_model TEXT
            )
        """)
        self.db.execute_query("CREATE TABLE Keywords (id INTEGER PRIMARY KEY, keyword TEXT NOT NULL UNIQUE)")
        self.db.execute_query("""
            CREATE TABLE MediaKeywords (
                id INTEGER PRIMARY KEY,
                media_id INTEGER NOT NULL,
                keyword_id INTEGER NOT NULL,
                FOREIGN KEY (media_id) REFERENCES Media(id),
                FOREIGN KEY (keyword_id) REFERENCES Keywords(id)
            )
        """)
        self.db.execute_query("""
            CREATE TABLE MediaModifications (
                id INTEGER PRIMARY KEY,
                media_id INTEGER NOT NULL,
                prompt TEXT,
                summary TEXT,
                modification_date TEXT,
                FOREIGN KEY (media_id) REFERENCES Media(id)
            )
        """)
        self.db.execute_query("""
            CREATE TABLE MediaVersion (
                id INTEGER PRIMARY KEY,
                media_id INTEGER NOT NULL,
                version INTEGER NOT NULL,
                prompt TEXT,
                summary TEXT,
                created_at TEXT NOT NULL,
                FOREIGN KEY (media_id) REFERENCES Media(id)
            )
        """)
        self.db.execute_query("CREATE VIRTUAL TABLE media_fts USING fts5(title, content)")

    # Without the @patch decorator, unittest would call these methods with
    # only ``self`` and they would fail with a missing-argument TypeError.
    # The patched object's ``get_connection`` is redirected to ``self.db``.
    # NOTE(review): assumes the function reaches the database through a
    # module-level ``db`` in SQLite_DB and only via ``db.get_connection()``;
    # other attributes of the mock stay MagicMocks — confirm.
    @patch('App_Function_Libraries.DB.SQLite_DB.db')
    def test_add_new_media(self, mock_db):
        """A first insert creates one media row plus its keywords, links, modification and version."""
        mock_db.get_connection = self.db.get_connection
        result = add_media_with_keywords(
            url="http://example.com",
            title="Test Title",
            media_type="article",
            content="Test content",
            keywords="test,keyword",
            prompt="Test prompt",
            summary="Test summary",
            transcription_model="Test model",
            author="Test Author",
            ingestion_date="2023-01-01"
        )
        self.assertIn("added/updated successfully", result)
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM Media")
            self.assertEqual(cursor.fetchone()[0], 1)
            cursor.execute("SELECT COUNT(*) FROM Keywords")
            self.assertEqual(cursor.fetchone()[0], 2)
            cursor.execute("SELECT COUNT(*) FROM MediaKeywords")
            self.assertEqual(cursor.fetchone()[0], 2)
            cursor.execute("SELECT COUNT(*) FROM MediaModifications")
            self.assertEqual(cursor.fetchone()[0], 1)
            cursor.execute("SELECT COUNT(*) FROM MediaVersion")
            self.assertEqual(cursor.fetchone()[0], 1)

    @patch('App_Function_Libraries.DB.SQLite_DB.db')
    def test_update_existing_media(self, mock_db):
        """Re-adding the same URL updates the row in place and adds new keywords/versions."""
        mock_db.get_connection = self.db.get_connection
        add_media_with_keywords(
            url="http://example.com",
            title="Test Title",
            media_type="article",
            content="Test content",
            keywords="test,keyword",
            prompt="Test prompt",
            summary="Test summary",
            transcription_model="Test model",
            author="Test Author",
            ingestion_date="2023-01-01"
        )
        result = add_media_with_keywords(
            url="http://example.com",
            title="Updated Title",
            media_type="article",
            content="Updated content",
            keywords="test,new",
            prompt="Updated prompt",
            summary="Updated summary",
            transcription_model="Updated model",
            author="Updated Author",
            ingestion_date="2023-01-02"
        )
        self.assertIn("added/updated successfully", result)
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            # Same URL -> still one media row, now carrying the new title.
            cursor.execute("SELECT COUNT(*) FROM Media")
            self.assertEqual(cursor.fetchone()[0], 1)
            cursor.execute("SELECT title FROM Media")
            self.assertEqual(cursor.fetchone()[0], "Updated Title")
            # Keywords "test", "keyword", "new" -> three distinct keywords and links.
            cursor.execute("SELECT COUNT(*) FROM Keywords")
            self.assertEqual(cursor.fetchone()[0], 3)
            cursor.execute("SELECT COUNT(*) FROM MediaKeywords")
            self.assertEqual(cursor.fetchone()[0], 3)
            # One modification and one version row per call.
            cursor.execute("SELECT COUNT(*) FROM MediaModifications")
            self.assertEqual(cursor.fetchone()[0], 2)
            cursor.execute("SELECT COUNT(*) FROM MediaVersion")
            self.assertEqual(cursor.fetchone()[0], 2)
if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's TestCases.
    unittest.main()
#
# End of File
#######################################################################################################################