|
33 | 33 | from google.cloud import bigquery |
34 | 34 | from google.cloud.bigquery import StandardSqlDataType |
35 | 35 | from google.cloud.bigquery.client import Client as BigQueryClient |
| 36 | + from google.cloud.bigquery.job import QueryJob |
36 | 37 | from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult |
37 | 38 | from google.cloud.bigquery.table import Table as BigQueryTable |
38 | 39 |
|
@@ -186,6 +187,31 @@ def query_factory() -> Query: |
186 | 187 | ) |
187 | 188 | ] |
188 | 189 |
|
| 190 | + def close(self) -> t.Any: |
| 191 | + # Cancel all pending query jobs across all threads |
| 192 | + all_query_jobs = self._connection_pool.get_all_attributes("query_job") |
| 193 | + for query_job in all_query_jobs: |
| 194 | + if query_job: |
| 195 | + try: |
| 196 | + if not self._db_call(query_job.done): |
| 197 | + self._db_call(query_job.cancel) |
| 198 | + logger.debug( |
| 199 | + "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", |
| 200 | + query_job.project, |
| 201 | + query_job.location, |
| 202 | + query_job.job_id, |
| 203 | + ) |
| 204 | + except Exception as ex: |
| 205 | + logger.debug( |
| 206 | + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", |
| 207 | + query_job.project, |
| 208 | + query_job.location, |
| 209 | + query_job.job_id, |
| 210 | + str(ex), |
| 211 | + ) |
| 212 | + |
| 213 | + return super().close() |
| 214 | + |
189 | 215 | def _begin_session(self, properties: SessionProperties) -> None: |
190 | 216 | from google.cloud.bigquery import QueryJobConfig |
191 | 217 |
|
@@ -318,7 +344,10 @@ def create_mapping_schema( |
318 | 344 | if len(table.parts) == 3 and "." in table.name: |
319 | 345 | # The client's `get_table` method can't handle paths with >3 identifiers |
320 | 346 | self.execute(exp.select("*").from_(table).limit(0)) |
321 | | - query_results = self._query_job._query_results |
| 347 | + query_job = self._query_job |
| 348 | + assert query_job is not None |
| 349 | + |
| 350 | + query_results = query_job._query_results |
322 | 351 | columns = create_mapping_schema(query_results.schema) |
323 | 352 | else: |
324 | 353 | bq_table = self._get_table(table) |
@@ -717,7 +746,9 @@ def _fetch_native_df( |
717 | 746 | self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False |
718 | 747 | ) -> DF: |
719 | 748 | self.execute(query, quote_identifiers=quote_identifiers) |
720 | | - return self._query_job.to_dataframe() |
| 749 | + query_job = self._query_job |
| 750 | + assert query_job is not None |
| 751 | + return query_job.to_dataframe() |
721 | 752 |
|
722 | 753 | def _create_column_comments( |
723 | 754 | self, |
@@ -1021,20 +1052,23 @@ def _execute( |
1021 | 1052 | job_config=job_config, |
1022 | 1053 | timeout=self._extra_config.get("job_creation_timeout_seconds"), |
1023 | 1054 | ) |
| 1055 | + query_job = self._query_job |
| 1056 | + assert query_job is not None |
1024 | 1057 |
|
1025 | 1058 | logger.debug( |
1026 | 1059 | "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", |
1027 | | - self._query_job.project, |
1028 | | - self._query_job.location, |
1029 | | - self._query_job.job_id, |
| 1060 | + query_job.project, |
| 1061 | + query_job.location, |
| 1062 | + query_job.job_id, |
1030 | 1063 | ) |
1031 | 1064 |
|
1032 | 1065 | results = self._db_call( |
1033 | | - self._query_job.result, |
| 1066 | + query_job.result, |
1034 | 1067 | timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore |
1035 | 1068 | ) |
| 1069 | + |
1036 | 1070 | self._query_data = iter(results) if results.total_rows else iter([]) |
1037 | | - query_results = self._query_job._query_results |
| 1071 | + query_results = query_job._query_results |
1038 | 1072 | self.cursor._set_rowcount(query_results) |
1039 | 1073 | self.cursor._set_description(query_results.schema) |
1040 | 1074 |
|
@@ -1198,23 +1232,23 @@ def _query_data(self) -> t.Any: |
1198 | 1232 |
|
1199 | 1233 | @_query_data.setter |
1200 | 1234 | def _query_data(self, value: t.Any) -> None: |
1201 | | - return self._connection_pool.set_attribute("query_data", value) |
| 1235 | + self._connection_pool.set_attribute("query_data", value) |
1202 | 1236 |
|
1203 | 1237 | @property |
1204 | | - def _query_job(self) -> t.Any: |
| 1238 | + def _query_job(self) -> t.Optional[QueryJob]: |
1205 | 1239 | return self._connection_pool.get_attribute("query_job") |
1206 | 1240 |
|
1207 | 1241 | @_query_job.setter |
1208 | 1242 | def _query_job(self, value: t.Any) -> None: |
1209 | | - return self._connection_pool.set_attribute("query_job", value) |
| 1243 | + self._connection_pool.set_attribute("query_job", value) |
1210 | 1244 |
|
1211 | 1245 | @property |
1212 | 1246 | def _session_id(self) -> t.Any: |
1213 | 1247 | return self._connection_pool.get_attribute("session_id") |
1214 | 1248 |
|
1215 | 1249 | @_session_id.setter |
1216 | 1250 | def _session_id(self, value: t.Any) -> None: |
1217 | | - return self._connection_pool.set_attribute("session_id", value) |
| 1251 | + self._connection_pool.set_attribute("session_id", value) |
1218 | 1252 |
|
1219 | 1253 |
|
1220 | 1254 | class _ErrorCounter: |
|
0 commit comments