diff --git a/auditor/scripts/revert_encoding/revert_encodings.py b/auditor/scripts/revert_encoding/revert_encodings.py index 0eba67450..87d373caa 100644 --- a/auditor/scripts/revert_encoding/revert_encodings.py +++ b/auditor/scripts/revert_encoding/revert_encodings.py @@ -36,6 +36,8 @@ def decode_record(record_id, meta): def main(): BATCH_SIZE = 1000 offset = 0 + conn = None + cursor = None try: conn = psycopg2.connect(**DB_CONFIG) diff --git a/auditor/scripts/revert_encoding/test_script.py b/auditor/scripts/revert_encoding/test_script.py index 719369c76..4faca2105 100644 --- a/auditor/scripts/revert_encoding/test_script.py +++ b/auditor/scripts/revert_encoding/test_script.py @@ -1,15 +1,17 @@ import json import unittest -from unittest.mock import MagicMock, patch from urllib.parse import quote +import psycopg2 +from unittest.mock import MagicMock, patch + from revert_encodings import decode_record, main class TestDecodeRecord(unittest.TestCase): def test_decode_record_success(self): # Test for successful decoding - record_id = quote("test_record_id/") + record_id = quote("test_record_id/", safe="") meta = { "key1": [quote("value1*"), quote("value2%")], "key2": [quote("value3!")], @@ -34,48 +36,68 @@ def test_decode_record_failure(self): self.assertIn("Error decoding meta", str(context.exception)) -@patch("main_script.psycopg2.connect") -def test_main(mock_connect): - # Mock database connection and cursor - mock_conn = MagicMock() - mock_cursor = MagicMock() - - # Setup mock connection - mock_connect.return_value = mock_conn - mock_conn.cursor.return_value = mock_cursor +class TestDatabaseUpdate(unittest.TestCase): + def setUp(self): + """Set up test cases""" + self.fetch_query = "SELECT id, record_id, meta FROM auditor_accounting ORDER BY id LIMIT 1000 OFFSET 0;" + self.update_query = """ + UPDATE auditor_accounting + SET record_id = %s, meta = %s + WHERE id = %s; + """ + + # Sample encoded data + self.sample_id = 1 + self.encoded_record_id = quote("test/record/1", safe="") + self.encoded_meta = {"key1": [quote("value1")]} + + # Expected decoded data + self.expected_record_id = "test/record/1" + self.expected_meta = json.dumps({"key1": ["value1"]}) + + @patch("psycopg2.connect") + def test_main_success(self, mock_connect): + """Test successful execution of main function""" + # Set up mock connection and cursor + mock_cursor = MagicMock() + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + # Mock the database responses + mock_cursor.fetchall.side_effect = [ + [ + (self.sample_id, self.encoded_record_id, self.encoded_meta) + ], # First batch + [], # Empty result to end the loop + ] + + # Run the main function + main() - # Mock database rows - rows = [ - (1, quote("record1/"), {"key1": [quote("value1/")]}), - (2, quote("record2/"), {"key2": [quote("value2/")]}), - ] + # Verify the correct SQL queries were executed + mock_cursor.execute.assert_any_call(self.fetch_query) + mock_cursor.execute.assert_any_call( + self.update_query, + (self.expected_record_id, self.expected_meta, self.sample_id), + ) - mock_cursor.fetchall.side_effect = [rows, []] # Return rows and then stop + # Verify proper cleanup + mock_conn.commit.assert_called_once() + mock_cursor.close.assert_called_once() + mock_conn.close.assert_called_once() - def mock_decode_record(record_id, meta): - return record_id.replace("%20", " "), json.dumps(meta) + @patch("psycopg2.connect") + def test_main_database_error(self, mock_connect): + """Test database error handling""" + # Configure the mock to raise an exception + mock_connect.side_effect = Exception("Database connection failed") - with patch("revert_encodings.decode_record", side_effect=mock_decode_record): + # Run the main function - should handle error gracefully main() - # Assertions - mock_cursor.execute.assert_any_call( - "SELECT id, record_id, meta FROM auditor_accounting ORDER BY id LIMIT 1000 OFFSET 0;" - ) - - for row in rows: - mock_cursor.execute.assert_any_call( - """ - UPDATE auditor_accounting - SET record_id = %s, meta = %s - WHERE id = %s; - """, - (row[1].replace("%20", " "), json.dumps(row[2]), row[0]), - ) - - mock_conn.commit.assert_called() - mock_cursor.close.assert_called() - mock_conn.close.assert_called() + # Verify the connection attempt was made + mock_connect.assert_called_once() if __name__ == "__main__":