Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update code style for airflow db commands to SQLAlchemy 2.0 style #31486

Merged
merged 5 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Any

from sqlalchemy import MetaData, String
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import registry

from airflow.configuration import conf

Expand All @@ -45,8 +45,9 @@ def _get_schema():


metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
mapper_registry = registry(metadata=metadata)

Base: Any = declarative_base(metadata=metadata)
Base: Any = mapper_registry.generate_base()

ID_LEN = 250

Expand Down
77 changes: 35 additions & 42 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@

from airflow.models.base import Base


log = logging.getLogger(__name__)

REVISION_HEADS_MAP = {
Expand Down Expand Up @@ -686,21 +685,28 @@ def create_default_connections(session: Session = NEW_SESSION):
)


def _create_db_from_orm(session):
from alembic import command
def _get_flask_db(sql_database_uri):
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

from airflow.www.session import AirflowDatabaseSessionInterface

flask_app = Flask(__name__)
flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(flask_app)
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
return db


def _create_db_from_orm(session):
from alembic import command

from airflow.models.base import Base
from airflow.www.fab_security.sqla.models import Model
from airflow.www.session import AirflowDatabaseSessionInterface

def _create_flask_session_tbl(sql_database_uri):
flask_app = Flask(__name__)
flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(flask_app)
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
db = _get_flask_db(sql_database_uri)
db.create_all()

with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
Expand Down Expand Up @@ -1004,15 +1010,16 @@ def reflect_tables(tables: list[Base | str] | None, session):
"""
import sqlalchemy.schema

metadata = sqlalchemy.schema.MetaData(session.bind)
bind = session.bind
metadata = sqlalchemy.schema.MetaData()

if tables is None:
metadata.reflect(resolve_fks=False)
metadata.reflect(bind=bind, resolve_fks=False)
else:
for tbl in tables:
try:
table_name = tbl if isinstance(tbl, str) else tbl.__tablename__
metadata.reflect(only=[table_name], extend_existing=True, resolve_fks=False)
metadata.reflect(bind=bind, only=[table_name], extend_existing=True, resolve_fks=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the bind=bind part still needed? You already set metadata.bind above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. It was surprising to me too. Here's the warning that it throws:

/opt/airflow/airflow/utils/db.py:1023 RemovedIn20Warning: The ``bind`` argument for schema methods that invoke SQL against an engine or connection will be required in SQLAlchemy 2.0. (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the metadata.bind line also then? Seems odd to need both.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have removed it. I remember having some issues around the metadata but let's see if the CI passes cause I can't reproduce locally

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happened was that the metadata that was returned by the function was not bound to any database. The other binds are for the reflect methods. Despite the session being bound to metadata, they still want the bind to happen on its method

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe instead we can add bind=bind explicitly to the method calls instead? Not sure how feasible that is without actually digging into the call stack.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seeing this now...taking a look

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See 8a08c05, let me know what you think

except exc.InvalidRequestError:
continue
return metadata
Expand Down Expand Up @@ -1633,8 +1640,9 @@ def resetdb(session: Session = NEW_SESSION, skip_init: bool = False):
connection = settings.engine.connect()

with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
drop_airflow_models(connection)
drop_airflow_moved_tables(session)
with connection.begin():
uranusjr marked this conversation as resolved.
Show resolved Hide resolved
drop_airflow_models(connection)
drop_airflow_moved_tables(connection)

if not skip_init:
initdb(session=session)
Expand Down Expand Up @@ -1701,27 +1709,12 @@ def drop_airflow_models(connection):
:return: None
"""
from airflow.models.base import Base

# Drop connection and chart - those tables have been deleted and in case you
# run resetdb on schema with chart or users table will fail
chart = Table("chart", Base.metadata)
chart.drop(settings.engine, checkfirst=True)
user = Table("user", Base.metadata)
user.drop(settings.engine, checkfirst=True)
users = Table("users", Base.metadata)
users.drop(settings.engine, checkfirst=True)
dag_stats = Table("dag_stats", Base.metadata)
dag_stats.drop(settings.engine, checkfirst=True)
session = Table("session", Base.metadata)
session.drop(settings.engine, checkfirst=True)
from airflow.www.fab_security.sqla.models import Model

Base.metadata.drop_all(connection)
# we remove the Tables here so that if resetdb is run metadata does not keep the old tables.
Base.metadata.remove(session)
Base.metadata.remove(dag_stats)
Base.metadata.remove(users)
Base.metadata.remove(user)
Base.metadata.remove(chart)
Model.metadata.drop_all(connection)
db = _get_flask_db(connection.engine.url)
db.drop_all()
# alembic adds significant import time, so we import it lazily
from alembic.migration import MigrationContext

Expand All @@ -1731,11 +1724,11 @@ def drop_airflow_models(connection):
version.drop(connection)


def drop_airflow_moved_tables(session):
def drop_airflow_moved_tables(connection):
from airflow.models.base import Base
from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX

tables = set(inspect(session.get_bind()).get_table_names())
tables = set(inspect(connection).get_table_names())
to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)]
for tbl in to_delete:
tbl.drop(settings.engine, checkfirst=False)
Expand All @@ -1749,7 +1742,7 @@ def check(session: Session = NEW_SESSION):

:param session: session of the sqlalchemy
"""
session.execute("select 1 as is_alive;")
session.execute(text("select 1 as is_alive;"))
log.info("Connection successful.")


Expand Down Expand Up @@ -1780,23 +1773,23 @@ def create_global_lock(
dialect = conn.dialect
try:
if dialect.name == "postgresql":
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), timeout=lock_timeout)
conn.execute(text("SELECT pg_advisory_lock(:id)"), id=lock.value)
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), id=str(lock), timeout=lock_timeout)
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})
elif dialect.name == "mssql":
# TODO: make locking work for MSSQL
pass

yield
finally:
if dialect.name == "postgresql":
conn.execute("SET LOCK_TIMEOUT TO DEFAULT")
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), id=lock.value).fetchone()
conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
if not unlocked:
raise RuntimeError("Error releasing DB lock!")
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
conn.execute(text("select RELEASE_LOCK(:id)"), id=str(lock))
conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})
elif dialect.name == "mssql":
# TODO: make locking work for MSSQL
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def drop_tables_with_prefix(prefix):
metadata = reflect_tables(None, session)
for table_name, table in metadata.tables.items():
if table_name.startswith(prefix):
table.drop()
table.drop(session.bind)


def clear_db_serialized_dags():
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_resetdb(
session_mock = MagicMock()
resetdb(session_mock, skip_init=skip_init)
mock_drop_airflow.assert_called_once_with(mock_connect.return_value)
mock_drop_moved.assert_called_once_with(session_mock)
mock_drop_moved.assert_called_once_with(mock_connect.return_value)
if skip_init:
mock_init.assert_not_called()
else:
Expand Down