Skip to content

Commit

Permalink
Refactor SqlAlchemy session.execute() calls to 2.0 style in case of p…
Browse files Browse the repository at this point in the history
…lain text SQL queries
  • Loading branch information
moiseenkov committed Aug 21, 2023
1 parent a1e6cd4 commit 6c4010d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
5 changes: 2 additions & 3 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,9 +1222,8 @@ def _create_table_as(
)
else:
# Postgres and SQLite both support the same "CREATE TABLE a AS SELECT ..." syntax
session.execute(
f"CREATE TABLE {target_table_name} AS {source_query.selectable.compile(bind=session.get_bind())}"
)
select_table = source_query.selectable.compile(bind=session.get_bind())
session.execute(text(f"CREATE TABLE {target_table_name} AS {select_table}"))


def _move_dangling_data_to_new_table(
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_db_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pendulum
import pytest
from pytest import param
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.declarative import DeclarativeMeta

Expand Down Expand Up @@ -211,7 +212,7 @@ def test__build_query(self, table_name, date_add_kwargs, expected_to_delete, ext
)
stmt = CreateTableAs(target_table_name, query.selectable)
session.execute(stmt)
res = session.execute(f"SELECT COUNT(1) FROM {target_table_name}")
res = session.execute(text(f"SELECT COUNT(1) FROM {target_table_name}"))
for row in res:
assert row[0] == expected_to_delete

Expand Down
13 changes: 7 additions & 6 deletions tests/utils/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pytest
from kubernetes.client import models as k8s
from pytest import param
from sqlalchemy import text
from sqlalchemy.exc import StatementError

from airflow import settings
Expand Down Expand Up @@ -54,7 +55,7 @@ def setup_method(self):
# make sure NOT to run in UTC. Only postgres supports storing
# timezone information in the datetime field
if session.bind.dialect.name == "postgresql":
session.execute("SET timezone='Europe/Amsterdam'")
session.execute(text("SET timezone='Europe/Amsterdam'"))

self.session = session

Expand Down Expand Up @@ -208,17 +209,17 @@ def test_with_row_locks(

def test_prohibit_commit(self):
with prohibit_commit(self.session) as guard:
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
with pytest.raises(RuntimeError):
self.session.commit()
self.session.rollback()

self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
guard.commit()

# Check the expected_commit is reset
with pytest.raises(RuntimeError):
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
self.session.commit()

def test_prohibit_commit_specific_session_only(self):
Expand All @@ -233,12 +234,12 @@ def test_prohibit_commit_specific_session_only(self):
assert other_session is not self.session

with prohibit_commit(self.session):
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
with pytest.raises(RuntimeError):
self.session.commit()
self.session.rollback()

other_session.execute("SELECT 1")
other_session.execute(text("SELECT 1"))
other_session.commit()

def teardown_method(self):
Expand Down

0 comments on commit 6c4010d

Please sign in to comment.