반응형

I. 개요

일련의 작업 진행 시 상황에 따라 다른 작업으로 이어져야 하는 경우는 굉장히 빈번하게 발생한다.

Airflow 는 기본적으로 DAG 으로 작업을 구조화해서 작업을 진행하기 때문에, 자동화할 때 이러한 조건부 작업을 구현하지 못한다면 매번 실패 후 재처리하는 작업이 필요하다.

기본적이지만 자주 사용되는 Branch task 인 BranchPythonOperator 와 BranchSQLOperator 의 사용법과 예제를 기록해둔다.

II. Branch Task

1. BranchPythonOperator

PythonOperator 기반으로 구성되어 task_id(s) 를 output 으로 하는 Python callable 을 통해 바로 다음에 이어지는 작업 요소를 결정한다.

BranchPythonOperator(
    python_callable : method, 
    op_args : dict, 
    op_kwargs : dict,
    templates_dict : dict [optional],
    templates_exts : list [optional]
)
  • 환경에 따라 다른 모델의 학습을 실행시킨다고 가정할 때 아래와 같이 작성할 수 있다.
    • 실행시키는 task 는 Dummy task 로 대체

from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from utils.util import get_airflow_env # 현재 동작중인 airflow 환경을 가져오는 custom function 
from datetime import datetime

def branch_callable(env):
    if env == "production" :
        return "production_task"
    else :
        return "staging_task"

env = get_airflow_env()

with DAG(
    dag_id='example_BranchPythonOperator',
    start_date=datetime(2022, 7, 15),
    schedule_interval= '* * * * *'
    ) as dag:

    branch_task = BranchPythonOperator(
        task_id = "env_branching",
        python_callable = branch_callable,
        op_kwargs = {"env" : env}
        )

    production_task = DummyOperator(
            task_id = "production_task"
    )

    staging_task = DummyOperator(
            task_id = "staging_task"
    )

    branch_task >> [production_task, staging_task]

2. BranchSQLOperator

ETL/ELT 작업중에 무결성 검증을 진행하거나 DB에 저장된 데이터에 따라 다른 작업이 진행돼야 하는 등 SQL 의 결과를 이용해 분기 작업이 필요할 수 있다. 이럴 때 BaseSQLOperator 기반의 BranchSQLOperator를 사용할 수 있다.

  • BranchPythonOperator 와 다르게 결과에 따라 어떤 task 를 실행시킬지 “명시적으로 선언한다”
  • 실행 결과는 반드시 Boolean (True/False), integer (0 = False, Otherwise = 1) , string (true/y/yes/1/on/false/n/no/0/off) 이어야 한다.
BranchSQLOperator(
    sql : str or templete ends with .sql ,  
    follow_task_ids_if_true : str, 
    follow_task_ids_if_false : str,
    conn_id : str,
    database : str,
    parameters : mapping/iter [optional]
)
  • DW 에 저장된 데이터를 기반으로 새로운 table 을 만드는 태스크가 있다고 할 때 아래와 같이 작업할 수 있다.
import os 
from datetime import datetime
from dependencies import utils
from airflow import models
from airflow.contrib.operators.bigquery_operator import BigQueryOperator
from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor
from airflow.providers.google.cloud.operators.bigquery import BigQueryValueCheckOperator
from airflow.operators.sql import BranchSQLOperator
from airflow.operators.python_operator import PythonOperator

yesterday = "{{ macros.ds_add(ds, -0) }}"
yesterday_suffix = "{{ macros.ds_format(macros.ds_add(ds, -0), '%Y-%m-%d', '%Y%m%d') }}"

with models.DAG(
        dag_id='sql_branch_example',
        description='sql_branch_example',
        schedule_interval='0 1 * * *',
        start_date=datetime(2022, 7, 15),
    ) as dag:

        # 원천 테이블의 존재를 확인한다
    check_origin_table_update = BigQueryTableExistenceSensor(
        task_id="check_origin_table_update", 
        project_id={project_id}, 
        dataset_id={dataset}, 
        table_id={table_name},
        gcp_conn_id={your_conn_key},     
    )

        # 원천 테이블 테이블 수준 정합성을 확인한다.
    check_table_level_quality_of_origin_table = BigQueryValueCheckOperator(
        task_id="check_table_level_quality_of_origin_table",
        sql=f"SELECT COUNT(DISTINCT type) FROM `project_id.dataset_id.table_name` WHERE date_kr = {yesterday}",
        pass_value=1,
        use_legacy_sql=False,
    )

        # ELT 테이블이 이미 생성되어 있다면 task 를 종료하기 위한 branch
    check_target_table_exsits = BranchSQLOperator(
        task_id = "check_target_table_exsits",
        conn_id = {your_conn_key},
        sql = f'SELECT IF(COUNT(1) > 0, True, False) FROM `project_id.dataset_id.target_table_name` WHERE date_kr = "{yesterday}"',
        follow_task_ids_if_true = "pass_update_target_table",
        follow_task_ids_if_false = "update_target_table_task",
        parameters={"use_legacy_sql":False} #optional
    )

    pass_update_target_table = PythonOperator(
        task_id='pass_update_target_table',
        python_callable = utils.pass_update_table_data_callable,
    )

    update_target_table_task = BigQueryOperator(
        task_id=f"update_target_table_task",
        sql=utils.read_sql(os.path.join(os.environ['DAGS_FOLDER'], 'sql/example.sql')).format(execute_date=yesterday_suffix),
        bigquery_conn_id={your_conn_key},
        use_legacy_sql=False,
        destination_dataset_table=f'project_id.dataset_id.target_table_name${yesterday_suffix}',
        write_disposition='WRITE_TRUNCATE',
        time_partitioning={'type': 'DAY', 'field': 'date_kr'},
    )

    check_origin_table_update >> check_table_level_quality_of_origin_table >> check_target_table_exsits >> [update_target_table_task, pass_update_target_table]

  • BigQueryTableExistenceSensor 는 task 실행 시 1분 주기로 target table 의 존재를 확인한다. 이때 별도의 escape 처리를 하지 않으면 존재가 확인되기 전까지 계속 실행상태를 유지한다.
  • BigQueryValueCheckOperator 는 실행된 sql 의 결과가 pass_value 와 일치하는지 확인한다.
반응형
복사했습니다!