Skip to content

Commit

Permalink
Upgrade FAB to 4.1.1
Browse files Browse the repository at this point in the history
The Flask Application Builder have been updated recently to
support a number of newer dependencies. This PR is the
attempt to migrate FAB to newer version.

This includes:

* update setup.py and setup.cfg upper and lower bounds to
  account for proper version of dependencies that
  FAB < 4.0.0 was blocking from upgrade
* added typed Flask application retrieval with a custom
  application fields available for MyPy typing checks.
* fix typing to account for typing hints added in multiple
  upgraded libraries optional values and content of request
  returned as Mapping
* switch to PyJWT 2.* by using non-deprecated "required" claim as
  list rather than separate fields
* add possibiliyt to install providers without constraints
  so that we could avoid errors on conflicting constraints when
  upgrade-to-newer-dependencies is used
* add pre-commit to check that 2.4+ only get_airflow_app is not
  used in providers
* avoid Bad Request in case the request sent to Flask 2.0 is not
  JSon content type
* switch imports of internal classes to direct packages
  where classes are available rather than from "airflow.models" to
  satisfy MyPY
* synchronize changes of FAB Security Manager 4.1.1 with our copy
  of the Security Manager.
* add error handling for a few "None" cases detected by MyPY
* corrected test cases that were broken by immutability of
  Flask 2 objects and better escaping done by Flask 2
* updated test cases to account for redirection to "path" rather
  than full URL by Flask2

Fixes: apache#22397
  • Loading branch information
potiuk committed Jun 19, 2022
1 parent 5674491 commit 22652d1
Show file tree
Hide file tree
Showing 51 changed files with 489 additions and 325 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,8 @@ ${{ hashFiles('.pre-commit-config.yaml') }}"
run: >
breeze verify-provider-packages --use-airflow-version wheel --use-packages-from-dist
--package-format wheel
env:
SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgradeToNewerDependencies }}"
- name: "Remove airflow package and replace providers with 2.2-compliant versions"
run: |
rm -vf dist/apache_airflow-*.whl \
Expand Down Expand Up @@ -878,6 +880,8 @@ ${{ hashFiles('.pre-commit-config.yaml') }}"
run: >
breeze verify-provider-packages --use-airflow-version sdist --use-packages-from-dist
--package-format sdist
env:
SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgradeToNewerDependencies }}"
- name: "Fix ownership"
run: breeze fix-ownership
if: always()
Expand Down
36 changes: 27 additions & 9 deletions Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -686,29 +686,47 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
echo "${COLOR_BLUE}Uninstalling airflow and providers"
echo
uninstall_airflow_and_providers
echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
echo
install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "none"
else
echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
fi
uninstall_providers
elif [[ ${USE_AIRFLOW_VERSION} == "sdist" ]]; then
echo
echo "${COLOR_BLUE}Uninstalling airflow and providers"
echo
uninstall_airflow_and_providers
echo
echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
echo
install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "none"
else
echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
fi
uninstall_providers
else
echo
echo "${COLOR_BLUE}Uninstalling airflow and providers"
echo
uninstall_airflow_and_providers
echo
echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
echo
install_released_airflow_version "${USE_AIRFLOW_VERSION}" "none"
else
echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
fi
fi
if [[ ${USE_PACKAGES_FROM_DIST=} == "true" ]]; then
echo
Expand Down
5 changes: 3 additions & 2 deletions airflow/api/auth/backend/basic_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from functools import wraps
from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast

from flask import Response, current_app, request
from flask import Response, request
from flask_appbuilder.const import AUTH_LDAP
from flask_login import login_user

from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.www.fab_security.sqla.models import User

CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None
Expand All @@ -40,7 +41,7 @@ def auth_current_user() -> Optional[User]:
if auth is None or not auth.username or not auth.password:
return None

ab_security_manager = current_app.appbuilder.sm
ab_security_manager = get_airflow_app().appbuilder.sm
user = None
if ab_security_manager.auth_type == AUTH_LDAP:
user = ab_security_manager.auth_user_ldap(auth.username, auth.password)
Expand Down
9 changes: 5 additions & 4 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Collection, Optional

from connexion import NoContent
from flask import current_app, g, request
from flask import g, request
from marshmallow import ValidationError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import or_
Expand All @@ -38,6 +38,7 @@
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models.dag import DagModel, DagTag
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session


Expand All @@ -56,7 +57,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
def get_dag_details(*, dag_id: str) -> APIResponse:
"""Get details of DAG."""
dag: DAG = current_app.dag_bag.get_dag(dag_id)
dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found")
return dag_detail_schema.dump(dag)
Expand All @@ -83,7 +84,7 @@ def get_dags(
if dag_id_pattern:
dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))

readable_dags = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)

dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags))
if tags:
Expand Down Expand Up @@ -143,7 +144,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
if dag_id_pattern == '~':
dag_id_pattern = '%'
dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)
editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user)

dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
if tags:
Expand Down
26 changes: 14 additions & 12 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import pendulum
from connexion import NoContent
from flask import current_app, g, request
from flask import g
from marshmallow import ValidationError
from sqlalchemy import or_
from sqlalchemy.orm import Query, Session
Expand All @@ -30,6 +30,7 @@
set_dag_run_state_to_success,
)
from airflow.api_connexion import security
from airflow.api_connexion.endpoints.mapping_from_request import get_mapping_from_request
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
Expand All @@ -47,6 +48,7 @@
from airflow.api_connexion.types import APIResponse
from airflow.models import DagModel, DagRun
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
Expand Down Expand Up @@ -167,7 +169,7 @@ def get_dag_runs(

# This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
if dag_id == "~":
appbuilder = current_app.appbuilder
appbuilder = get_airflow_app().appbuilder
query = query.filter(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
else:
query = query.filter(DagRun.dag_id == dag_id)
Expand Down Expand Up @@ -199,13 +201,13 @@ def get_dag_runs(
@provide_session
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
"""Get list of DAG Runs"""
body = request.get_json()
body = get_mapping_from_request()
try:
data = dagruns_batch_form_schema.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))

appbuilder = current_app.appbuilder
appbuilder = get_airflow_app().appbuilder
readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
query = session.query(DagRun)
if data.get("dag_ids"):
Expand Down Expand Up @@ -252,7 +254,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
detail=f"DAG with dag_id: '{dag_id}' has import errors",
)
try:
post_body = dagrun_schema.load(request.json, session=session)
post_body = dagrun_schema.load(get_mapping_from_request(), session=session)
except ValidationError as err:
raise BadRequest(detail=str(err))

Expand All @@ -268,7 +270,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
)
if not dagrun_instance:
try:
dag = current_app.dag_bag.get_dag(dag_id)
dag = get_airflow_app().dag_bag.get_dag(dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
Expand All @@ -277,7 +279,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
state=DagRunState.QUEUED,
conf=post_body.get("conf"),
external_trigger=True,
dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
)
return dagrun_schema.dump(dag_run)
except ValueError as ve:
Expand Down Expand Up @@ -310,12 +312,12 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW
error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
raise NotFound(error_message)
try:
post_body = set_dagrun_state_form_schema.load(request.json)
post_body = set_dagrun_state_form_schema.load(get_mapping_from_request())
except ValidationError as err:
raise BadRequest(detail=str(err))

state = post_body['state']
dag = current_app.dag_bag.get_dag(dag_id)
dag = get_airflow_app().dag_bag.get_dag(dag_id)
if state == DagRunState.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True)
elif state == DagRunState.QUEUED:
Expand All @@ -339,15 +341,15 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
)
if dag_run is None:
error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
raise NotFound(error_message)
try:
post_body = clear_dagrun_form_schema.load(request.json)
post_body = clear_dagrun_form_schema.load(get_mapping_from_request())
except ValidationError as err:
raise BadRequest(detail=str(err))

dry_run = post_body.get('dry_run', False)
dag = current_app.dag_bag.get_dag(dag_id)
dag = get_airflow_app().dag_bag.get_dag(dag_id)
start_date = dag_run.logical_date
end_date = dag_run.logical_date

Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/extra_link_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

from flask import current_app
from sqlalchemy.orm.session import Session

from airflow import DAG
Expand All @@ -25,6 +24,7 @@
from airflow.exceptions import TaskNotFound
from airflow.models.dagbag import DagBag
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session


Expand All @@ -46,7 +46,7 @@ def get_extra_links(
"""Get extra links for task instance"""
from airflow.models.taskinstance import TaskInstance

dagbag: DagBag = current_app.dag_bag
dagbag: DagBag = get_airflow_app().dag_bag
dag: DAG = dagbag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found", detail=f'DAG with ID = "{dag_id}" not found')
Expand Down
11 changes: 6 additions & 5 deletions airflow/api_connexion/endpoints/log_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Optional

from flask import Response, current_app, request
from flask import Response, request
from itsdangerous.exc import BadSignature
from itsdangerous.url_safe import URLSafeSerializer
from sqlalchemy.orm.session import Session
Expand All @@ -29,6 +28,7 @@
from airflow.exceptions import TaskNotFound
from airflow.models import TaskInstance
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.log.log_reader import TaskLogReader
from airflow.utils.session import NEW_SESSION, provide_session

Expand All @@ -52,7 +52,7 @@ def get_log(
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get logs for specific task instance"""
key = current_app.config["SECRET_KEY"]
key = get_airflow_app().config["SECRET_KEY"]
if not token:
metadata = {}
else:
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_log(
metadata['end_of_log'] = True
raise NotFound(title="TaskInstance not found")

dag = current_app.dag_bag.get_dag(dag_id)
dag = get_airflow_app().dag_bag.get_dag(dag_id)
if dag:
try:
ti.task = dag.get_task(ti.task_id)
Expand All @@ -101,7 +101,8 @@ def get_log(
if return_type == 'application/json' or return_type is None: # default
logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata)
logs = logs[0] if task_try_number is not None else logs
token = URLSafeSerializer(key).dumps(metadata)
# we must have token here, so we can safely ignore it
token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment]
return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))
# text/plain. Stream
logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)
Expand Down
24 changes: 24 additions & 0 deletions airflow/api_connexion/endpoints/mapping_from_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Mapping, cast


def get_mapping_from_request() -> Mapping[str, Any]:
from flask import request

return cast(Mapping[str, Any], request.get_json())
Loading

0 comments on commit 22652d1

Please sign in to comment.