1010from chdb import session
1111from urllib .request import urlretrieve
1212
13- if os .path .exists (".test_chdb_arrow_table" ):
14- shutil .rmtree (".test_chdb_arrow_table" , ignore_errors = True )
15- sess = session .Session (".test_chdb_arrow_table" )
13+ # Clean up and create session in the test methods instead of globally
1614
1715class TestChDBArrowTable (unittest .TestCase ):
1816 @classmethod
@@ -33,11 +31,16 @@ def setUpClass(cls):
3331
3432 print (f"Loaded Arrow table: { cls .num_rows } rows, { cls .num_columns } columns, { cls .table_size } bytes" )
3533
34+ if os .path .exists (".test_chdb_arrow_table" ):
35+ shutil .rmtree (".test_chdb_arrow_table" , ignore_errors = True )
36+ cls .sess = session .Session (".test_chdb_arrow_table" )
37+
3638 @classmethod
3739 def tearDownClass (cls ):
3840 # Clean up session directory
3941 if os .path .exists (".test_chdb_arrow_table" ):
4042 shutil .rmtree (".test_chdb_arrow_table" , ignore_errors = True )
43+ cls .sess .close ()
4144
4245 def setUp (self ):
4346 pass
@@ -54,23 +57,23 @@ def test_arrow_table_basic_info(self):
5457 def test_arrow_table_count (self ):
5558 """Test counting rows in Arrow table"""
5659 my_arrow_table = self .arrow_table
57- result = sess .query ("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)" , "CSV" )
60+ result = self . sess .query ("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)" , "CSV" )
5861 lines = str (result ).strip ().split ('\n ' )
5962 count = int (lines [0 ])
6063 self .assertEqual (count , self .num_rows , f"Count should match table rows: { self .num_rows } " )
6164
6265 def test_arrow_table_schema (self ):
6366 """Test querying Arrow table schema information"""
6467 my_arrow_table = self .arrow_table
65- result = sess .query ("DESCRIBE Python(my_arrow_table)" , "CSV" )
68+ result = self . sess .query ("DESCRIBE Python(my_arrow_table)" , "CSV" )
6669 # print(result)
6770 self .assertIn ('WatchID' , str (result ))
6871 self .assertIn ('URLHash' , str (result ))
6972
7073 def test_arrow_table_limit (self ):
7174 """Test LIMIT queries on Arrow table"""
7275 my_arrow_table = self .arrow_table
73- result = sess .query ("SELECT * FROM Python(my_arrow_table) LIMIT 5" , "CSV" )
76+ result = self . sess .query ("SELECT * FROM Python(my_arrow_table) LIMIT 5" , "CSV" )
7477 lines = str (result ).strip ().split ('\n ' )
7578 self .assertEqual (len (lines ), 5 , "Should have 5 data rows" )
7679
@@ -82,7 +85,7 @@ def test_arrow_table_select_columns(self):
8285 first_col = schema .field (0 ).name
8386 second_col = schema .field (1 ).name if len (schema ) > 1 else first_col
8487
85- result = sess .query (f"SELECT { first_col } , { second_col } FROM Python(my_arrow_table) LIMIT 3" , "CSV" )
88+ result = self . sess .query (f"SELECT { first_col } , { second_col } FROM Python(my_arrow_table) LIMIT 3" , "CSV" )
8689 lines = str (result ).strip ().split ('\n ' )
8790 self .assertEqual (len (lines ), 3 , "Should have 3 data rows" )
8891
@@ -96,7 +99,7 @@ def test_arrow_table_where_clause(self):
9699 numeric_col = field .name
97100 break
98101
99- result = sess .query (f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE { numeric_col } > 1" , "CSV" )
102+ result = self . sess .query (f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE { numeric_col } > 1" , "CSV" )
100103 lines = str (result ).strip ().split ('\n ' )
101104 count = int (lines [0 ])
102105 self .assertEqual (count , 1000000 )
@@ -111,7 +114,7 @@ def test_arrow_table_group_by(self):
111114 string_col = field .name
112115 break
113116
114- result = sess .query (f"SELECT { string_col } , COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY { string_col } ORDER BY cnt DESC LIMIT 5" , "CSV" )
117+ result = self . sess .query (f"SELECT { string_col } , COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY { string_col } ORDER BY cnt DESC LIMIT 5" , "CSV" )
115118 lines = str (result ).strip ().split ('\n ' )
116119 self .assertEqual (len (lines ), 5 )
117120
@@ -125,7 +128,7 @@ def test_arrow_table_aggregations(self):
125128 numeric_col = field .name
126129 break
127130
128- result = sess .query (f"SELECT AVG({ numeric_col } ) as avg_val, MIN({ numeric_col } ) as min_val, MAX({ numeric_col } ) as max_val FROM Python(my_arrow_table)" , "CSV" )
131+ result = self . sess .query (f"SELECT AVG({ numeric_col } ) as avg_val, MIN({ numeric_col } ) as min_val, MAX({ numeric_col } ) as max_val FROM Python(my_arrow_table)" , "CSV" )
129132 lines = str (result ).strip ().split ('\n ' )
130133 self .assertEqual (len (lines ), 1 )
131134
@@ -135,14 +138,14 @@ def test_arrow_table_order_by(self):
135138 # Use first column for ordering
136139 first_col = self .arrow_table .schema .field (0 ).name
137140
138- result = sess .query (f"SELECT { first_col } FROM Python(my_arrow_table) ORDER BY { first_col } LIMIT 10" , "CSV" )
141+ result = self . sess .query (f"SELECT { first_col } FROM Python(my_arrow_table) ORDER BY { first_col } LIMIT 10" , "CSV" )
139142 lines = str (result ).strip ().split ('\n ' )
140143 self .assertEqual (len (lines ), 10 )
141144
142145 def test_arrow_table_subquery (self ):
143146 """Test subqueries with Arrow table"""
144147 my_arrow_table = self .arrow_table
145- result = sess .query ("""
148+ result = self . sess .query ("""
146149 SELECT COUNT(*) as total_count
147150 FROM (
148151 SELECT * FROM Python(my_arrow_table)
@@ -161,7 +164,7 @@ def test_arrow_table_multiple_tables(self):
161164 # Create a smaller subset table
162165 subset_table = my_arrow_table .slice (0 , min (100 , my_arrow_table .num_rows ))
163166
164- result = sess .query ("""
167+ result = self . sess .query ("""
165168 SELECT
166169 (SELECT COUNT(*) FROM Python(my_arrow_table)) as full_count,
167170 (SELECT COUNT(*) FROM Python(subset_table)) as subset_count
0 commit comments