14
14
# KIND, either express or implied. See the License for the
15
15
# specific language governing permissions and limitations
16
16
# under the License.
17
+ import ctypes
17
18
import datetime
18
19
import os
19
20
import re
20
21
import threading
21
22
import time
22
- import ctypes
23
23
from typing import Any
24
24
25
25
import pyarrow as pa
@@ -1299,7 +1299,6 @@ def test_collect_partitioned():
1299
1299
assert [[batch ]] == ctx .create_dataframe ([[batch ]]).collect_partitioned ()
1300
1300
1301
1301
1302
-
1303
1302
def test_union (ctx ):
1304
1303
batch = pa .RecordBatch .from_arrays (
1305
1304
[pa .array ([1 , 2 , 3 ]), pa .array ([4 , 5 , 6 ])],
@@ -1917,7 +1916,7 @@ def test_fill_null_date32_column(null_df):
1917
1916
dates = result .column (4 ).to_pylist ()
1918
1917
assert dates [0 ] == datetime .date (2000 , 1 , 1 ) # Original value
1919
1918
assert dates [1 ] == epoch_date # Filled value
1920
- assert dates [2 ] == datetime .date (2022 , 1 , 1 ) # Original value
1919
+ assert dates [2 ] == datetime .date (2022 , 1 , 1 ) # Original value
1921
1920
assert dates [3 ] == epoch_date # Filled value
1922
1921
1923
1922
# Other date column should be unchanged
@@ -2068,13 +2067,13 @@ def test_fill_null_all_null_column(ctx):
2068
2067
2069
2068
def test_collect_interrupted ():
2070
2069
"""Test that a long-running query can be interrupted with Ctrl-C.
2071
-
2070
+
2072
2071
This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
2073
2072
exception in the main thread during a long-running query execution.
2074
2073
"""
2075
2074
# Create a context and a DataFrame with a query that will run for a while
2076
2075
ctx = SessionContext ()
2077
-
2076
+
2078
2077
# Create a recursive computation that will run for some time
2079
2078
batches = []
2080
2079
for i in range (10 ):
@@ -2086,49 +2085,49 @@ def test_collect_interrupted():
2086
2085
names = ["a" , "b" ],
2087
2086
)
2088
2087
batches .append (batch )
2089
-
2088
+
2090
2089
# Register tables
2091
2090
ctx .register_record_batches ("t1" , [batches ])
2092
2091
ctx .register_record_batches ("t2" , [batches ])
2093
-
2092
+
2094
2093
# Create a large join operation that will take time to process
2095
2094
df = ctx .sql ("""
2096
2095
WITH t1_expanded AS (
2097
- SELECT
2098
- a,
2099
- b,
2096
+ SELECT
2097
+ a,
2098
+ b,
2100
2099
CAST(a AS DOUBLE) / 1.5 AS c,
2101
2100
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
2102
2101
FROM t1
2103
2102
CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
2104
2103
),
2105
2104
t2_expanded AS (
2106
- SELECT
2105
+ SELECT
2107
2106
a,
2108
2107
b,
2109
2108
CAST(a AS DOUBLE) * 2.5 AS e,
2110
2109
CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
2111
2110
FROM t2
2112
2111
CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
2113
2112
)
2114
- SELECT
2115
- t1.a, t1.b, t1.c, t1.d,
2113
+ SELECT
2114
+ t1.a, t1.b, t1.c, t1.d,
2116
2115
t2.a AS a2, t2.b AS b2, t2.e, t2.f
2117
2116
FROM t1_expanded t1
2118
2117
JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
2119
2118
WHERE t1.a > 100 AND t2.a > 100
2120
2119
""" )
2121
-
2120
+
2122
2121
# Flag to track if the query was interrupted
2123
2122
interrupted = False
2124
2123
interrupt_error = None
2125
2124
main_thread = threading .main_thread ()
2126
-
2125
+
2127
2126
# Shared flag to indicate query execution has started
2128
2127
query_started = threading .Event ()
2129
2128
max_wait_time = 5.0 # Maximum wait time in seconds
2130
-
2131
- # This function will be run in a separate thread and will raise
2129
+
2130
+ # This function will be run in a separate thread and will raise
2132
2131
# KeyboardInterrupt in the main thread
2133
2132
def trigger_interrupt ():
2134
2133
"""Poll for query start, then raise KeyboardInterrupt in the main thread"""
@@ -2139,31 +2138,33 @@ def trigger_interrupt():
2139
2138
if time .time () - start_time > max_wait_time :
2140
2139
msg = f"Query did not start within { max_wait_time } seconds"
2141
2140
raise RuntimeError (msg )
2142
-
2141
+
2143
2142
# Check if thread ID is available
2144
2143
thread_id = main_thread .ident
2145
2144
if thread_id is None :
2146
2145
msg = "Cannot get main thread ID"
2147
2146
raise RuntimeError (msg )
2148
-
2147
+
2149
2148
# Use ctypes to raise exception in main thread
2150
2149
exception = ctypes .py_object (KeyboardInterrupt )
2151
2150
res = ctypes .pythonapi .PyThreadState_SetAsyncExc (
2152
- ctypes .c_long (thread_id ), exception )
2151
+ ctypes .c_long (thread_id ), exception
2152
+ )
2153
2153
if res != 1 :
2154
2154
# If res is 0, the thread ID was invalid
2155
2155
# If res > 1, we modified multiple threads
2156
2156
ctypes .pythonapi .PyThreadState_SetAsyncExc (
2157
- ctypes .c_long (thread_id ), ctypes .py_object (0 ))
2157
+ ctypes .c_long (thread_id ), ctypes .py_object (0 )
2158
+ )
2158
2159
msg = "Failed to raise KeyboardInterrupt in main thread"
2159
2160
raise RuntimeError (msg )
2160
-
2161
+
2161
2162
# Start a thread to trigger the interrupt
2162
2163
interrupt_thread = threading .Thread (target = trigger_interrupt )
2163
- # we mark as daemon so the test process can exit even if this thread doesn’ t finish
2164
+ # we mark as daemon so the test process can exit even if this thread doesn' t finish
2164
2165
interrupt_thread .daemon = True
2165
2166
interrupt_thread .start ()
2166
-
2167
+
2167
2168
# Execute the query and expect it to be interrupted
2168
2169
try :
2169
2170
# Signal that we're about to start the query
@@ -2173,10 +2174,10 @@ def trigger_interrupt():
2173
2174
interrupted = True
2174
2175
except Exception as e :
2175
2176
interrupt_error = e
2176
-
2177
+
2177
2178
# Assert that the query was interrupted properly
2178
2179
if not interrupted :
2179
2180
pytest .fail (f"Query was not interrupted; got error: { interrupt_error } " )
2180
-
2181
+
2181
2182
# Make sure the interrupt thread has finished
2182
2183
interrupt_thread .join (timeout = 1.0 )
0 commit comments