Traffine I/O


2023-01-14

Kedro Hooks

Hooks

Hooks are a mechanism for adding extra processing to Kedro's main execution flow. Each hook fires at one of the following points in that flow:

  • after_catalog_created
  • before_node_run
  • after_node_run
  • on_node_error
  • before_pipeline_run
  • after_pipeline_run
  • on_pipeline_error
  • before_dataset_loaded
  • after_dataset_loaded
  • before_dataset_saved
  • after_dataset_saved
  • after_context_created

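These names follow the convention <before/after>_<noun>_<past_participle>. For example, a before_node_run hook runs before a node is executed. The error hooks (on_node_error, on_pipeline_error) follow on_<noun>_error instead and fire when that component raises an error.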
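
The following toy simulation shows, in plain Python, roughly the order in which the non-error hooks fire during a simple sequential run with one node. It is a simplified model that assumes the behaviour of Kedro 0.18's SequentialRunner, not Kedro code: Recorder, run_pipeline and the tuple-based nodes are hypothetical, and each hook receives only a subset of its real arguments.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical node: (name, function, input dataset names, output dataset name)
ToyNode = Tuple[str, Callable[..., object], List[str], str]


class Recorder:
    """Hook object that records the name of every hook call, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def __getattr__(self, hook_name: str) -> Callable[..., None]:
        # Only called for attributes that do not exist, i.e. the hook names:
        # return a function that records which hook fired.
        def _record(**kwargs: object) -> None:
            self.calls.append(hook_name)

        return _record


def run_pipeline(nodes: List[ToyNode], catalog: Dict[str, object], hooks: Recorder) -> None:
    """Toy sequential run that fires hooks around loading, running and saving."""
    hooks.after_context_created(context=None)
    hooks.after_catalog_created(catalog=catalog)
    hooks.before_pipeline_run(run_params={})
    for name, func, input_names, output_name in nodes:
        inputs = {}
        for ds in input_names:
            hooks.before_dataset_loaded(dataset_name=ds)
            inputs[ds] = catalog[ds]
            hooks.after_dataset_loaded(dataset_name=ds, data=inputs[ds])
        hooks.before_node_run(node=name, inputs=inputs)
        result = func(*inputs.values())
        hooks.after_node_run(node=name, inputs=inputs, outputs={output_name: result})
        hooks.before_dataset_saved(dataset_name=output_name, data=result)
        catalog[output_name] = result
        hooks.after_dataset_saved(dataset_name=output_name, data=result)
    hooks.after_pipeline_run(run_params={})


recorder = Recorder()
catalog = {"raw": 21}
run_pipeline([("double", lambda x: x * 2, ["raw"], "doubled")], catalog, recorder)
print(recorder.calls)
```

The printed list starts with after_context_created and after_catalog_created, then brackets the node with the dataset-load hooks before and the dataset-save hooks after before_node_run/after_node_run, and ends with after_pipeline_run.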

How to use Hooks

Set up a hook with the following steps:

  1. Define the hook in src/<project_name>/hooks.py
  2. Register it in HOOKS in src/<project_name>/settings.py

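In src/<project_name>/hooks.py, declare a hook implementation with the @hook_impl decorator. The following code declares a hook that runs at after_catalog_created, that is, after the DataCatalog has been created, and logs the names of all datasets in the catalog.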

src/<project_name>/hooks.py
import logging

from kedro.framework.hooks import hook_impl
from kedro.io import DataCatalog


class DataCatalogHooks:
    @property
    def _logger(self):
        return logging.getLogger(self.__class__.__name__)

    @hook_impl
    def after_catalog_created(self, catalog: DataCatalog) -> None:
        self._logger.info(catalog.list())

Updating src/<project_name>/settings.py as follows registers the hook.

src/<project_name>/settings.py
   from <project_name>.hooks import ProjectHooks, DataCatalogHooks

-  HOOKS = (ProjectHooks(),)
+  HOOKS = (ProjectHooks(), DataCatalogHooks())
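
Kedro dispatches hooks through the pluggy plugin framework (its hook_impl decorator is a pluggy marker), and pluggy passes each implementation only the arguments it declares. That is why after_catalog_created above can accept just catalog even though the hook specification defines more arguments. The sketch below imitates this with the standard library only. call_hook, ToyCatalogHooks and ToyNodeHooks are hypothetical; @hook_impl is omitted because the toy dispatcher needs no markers, the return values exist only for the demonstration, and pluggy features such as call ordering are ignored.

```python
import inspect
from typing import Any, Dict, Iterable, List


def call_hook(hooks: Iterable[object], hook_name: str, **kwargs: Any) -> List[Any]:
    """Call hook_name on every hook object that implements it, passing only
    the keyword arguments that the implementation's signature declares."""
    results = []
    for hook in hooks:
        impl = getattr(hook, hook_name, None)
        if impl is None:
            continue  # this object does not implement the hook
        accepted = inspect.signature(impl).parameters  # bound method: no self
        results.append(impl(**{k: v for k, v in kwargs.items() if k in accepted}))
    return results


class ToyCatalogHooks:
    def after_catalog_created(self, catalog: Dict[str, Any]) -> List[str]:
        # Declares only catalog; the extra arguments are filtered out
        return sorted(catalog)


class ToyNodeHooks:
    def before_node_run(self, node: str) -> str:
        return node


# As in settings.py, HOOKS holds instances, not classes
HOOKS = (ToyCatalogHooks(), ToyNodeHooks())

print(call_hook(HOOKS, "after_catalog_created",
                catalog={"shuttles": None, "companies": None},
                conf_catalog={}, save_version=None))
# → [['companies', 'shuttles']]
```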

Hook examples

Tracking memory consumption

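memory-profiler can be used to track memory consumption. The hooks below measure the process's memory usage just before and just after each dataset is loaded and log the difference.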

$ pip install memory_profiler
src/<project_name>/hooks.py
import logging

from kedro.framework.hooks import hook_impl
from memory_profiler import memory_usage


def _normalise_mem_usage(mem_usage):
    # memory_profiler < 0.56.0 returns list instead of float
    return mem_usage[0] if isinstance(mem_usage, (list, tuple)) else mem_usage


class MemoryProfilingHooks:
    def __init__(self):
        self._mem_usage = {}

    @property
    def _logger(self):
        return logging.getLogger(self.__class__.__name__)

    @hook_impl
    def before_dataset_loaded(self, dataset_name: str) -> None:
        before_mem_usage = memory_usage(
            -1,
            interval=0.1,
            max_usage=True,
            retval=True,
            include_children=True,
        )
        before_mem_usage = _normalise_mem_usage(before_mem_usage)
        self._mem_usage[dataset_name] = before_mem_usage


    @hook_impl
    def after_dataset_loaded(self, dataset_name: str) -> None:
        after_mem_usage = memory_usage(
            -1,
            interval=0.1,
            max_usage=True,
            retval=True,
            include_children=True,
        )
        # memory_profiler < 0.56.0 returns list instead of float
        after_mem_usage = _normalise_mem_usage(after_mem_usage)

        self._logger.info(
            "Loading %s consumed %2.2fMiB memory",
            dataset_name,
            after_mem_usage - self._mem_usage[dataset_name],
        )

Edit HOOKS in src/<project_name>/settings.py as follows.

src/<project_name>/settings.py
+  from <project_name>.hooks import MemoryProfilingHooks

-  HOOKS = (ProjectHooks(),)
+  HOOKS = (MemoryProfilingHooks(),)

Running kedro run now logs how much memory loading each dataset consumed.

$ kedro run

2021-10-05 12:02:34,946 - kedro.io.data_catalog - INFO - Loading data from `shuttles` (ExcelDataSet)...
2021-10-05 12:02:43,358 - MemoryProfilingHooks - INFO - Loading shuttles consumed 82.67MiB memory
2021-10-05 12:02:43,358 - kedro.pipeline.node - INFO - Running node: preprocess_shuttles_node: preprocess_shuttles([shuttles]) -> [preprocessed_shuttles]
2021-10-05 12:02:43,440 - kedro.io.data_catalog - INFO - Saving data to `preprocessed_shuttles` (MemoryDataSet)...
2021-10-05 12:02:43,446 - kedro.runner.sequential_runner - INFO - Completed 1 out of 2 tasks
2021-10-05 12:02:43,559 - kedro.io.data_catalog - INFO - Loading data from `companies` (CSVDataSet)...
2021-10-05 12:02:43,727 - MemoryProfilingHooks - INFO - Loading companies consumed 4.16MiB memory
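
If installing memory_profiler is not an option, a comparable hook can be sketched with the standard library's tracemalloc module. This swaps in a different technique: tracemalloc counts only memory allocated by Python objects during the load, not the whole process's memory as memory_profiler reports, so the numbers will differ. TracemallocHooks is a hypothetical name, and its methods would need @hook_impl in a real project; here they are called directly.

```python
import logging
import tracemalloc

logging.basicConfig(level=logging.INFO)


class TracemallocHooks:
    def __init__(self) -> None:
        self._logger = logging.getLogger(self.__class__.__name__)
        self._start = {}    # dataset name -> traced bytes before loading
        self.consumed = {}  # dataset name -> bytes allocated while loading

    # In Kedro, decorate with @hook_impl
    def before_dataset_loaded(self, dataset_name: str) -> None:
        # Tracing starts on first use and stays on, adding overhead afterwards
        if not tracemalloc.is_tracing():
            tracemalloc.start()
        current, _peak = tracemalloc.get_traced_memory()
        self._start[dataset_name] = current

    # In Kedro, decorate with @hook_impl
    def after_dataset_loaded(self, dataset_name: str) -> None:
        current, _peak = tracemalloc.get_traced_memory()
        self.consumed[dataset_name] = current - self._start.pop(dataset_name)
        self._logger.info(
            "Loading %s allocated %.2f MiB",
            dataset_name,
            self.consumed[dataset_name] / 2**20,
        )


hooks = TracemallocHooks()
hooks.before_dataset_loaded("shuttles")
data = [0] * 1_000_000  # stand-in for loading a dataset
hooks.after_dataset_loaded("shuttles")
```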

Data validation

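Great Expectations can be used to validate the inputs and outputs of nodes. The hooks below validate a node's inputs before it runs and its outputs after it runs, for every dataset that has an expectation suite assigned in DATASET_EXPECTATION_MAPPING. Note that the code uses Great Expectations' older batch-kwargs API (get_batch, run_validation_operator), so it may need adapting for recent releases.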

$ pip install great-expectations
src/<project_name>/hooks.py
from typing import Any, Dict

from kedro.framework.hooks import hook_impl
from kedro.io import DataCatalog

import great_expectations as ge


class DataValidationHooks:

    # Map expectation to dataset
    DATASET_EXPECTATION_MAPPING = {
        "companies": "raw_companies_dataset_expectation",
        "preprocessed_companies": "preprocessed_companies_dataset_expectation",
    }

    @hook_impl
    def before_node_run(
        self, catalog: DataCatalog, inputs: Dict[str, Any], session_id: str
    ) -> None:
        """Validate a node's input data with Great Expectations
        if an expectation suite is defined in ``DATASET_EXPECTATION_MAPPING``.
        """
        self._run_validation(catalog, inputs, session_id)

    @hook_impl
    def after_node_run(
        self, catalog: DataCatalog, outputs: Dict[str, Any], session_id: str
    ) -> None:
        """Validate a node's output data with Great Expectations
        if an expectation suite is defined in ``DATASET_EXPECTATION_MAPPING``.
        """
        self._run_validation(catalog, outputs, session_id)

    def _run_validation(
        self, catalog: DataCatalog, data: Dict[str, Any], session_id: str
    ):
        for dataset_name, dataset_value in data.items():
            if dataset_name not in self.DATASET_EXPECTATION_MAPPING:
                continue

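            # Note: _get_dataset and _filepath are private Kedro attributes
            # and may change between Kedro versions.
            dataset = catalog._get_dataset(dataset_name)
            dataset_path = str(dataset._filepath)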
            expectation_suite = self.DATASET_EXPECTATION_MAPPING[dataset_name]

            expectation_context = ge.data_context.DataContext()
            batch = expectation_context.get_batch(
                {"path": dataset_path, "datasource": "files_datasource"},
                expectation_suite,
            )
            expectation_context.run_validation_operator(
                "action_list_operator",
                assets_to_validate=[batch],
                session_id=session_id,
            )

Edit HOOKS in src/<project_name>/settings.py as follows.

src/<project_name>/settings.py
+  from <project_name>.hooks import DataValidationHooks

-  HOOKS = (ProjectHooks(),)
+  HOOKS = (DataValidationHooks(),)
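
The core of DataValidationHooks is the name-based lookup: only datasets listed in the mapping are validated, all others are skipped. Below is a dependency-free sketch of the same pattern in which plain Python check functions stand in for Great Expectations suites (a swapped-in technique). SimpleValidationHooks, CHECKS, _has_rows and _no_missing_ids are hypothetical, datasets are assumed to be lists of dicts, and the hook methods would need @hook_impl in a real project.

```python
from typing import Any, Callable, Dict, List

Record = Dict[str, Any]


def _has_rows(rows: List[Record]) -> None:
    if not rows:
        raise ValueError("dataset is empty")


def _no_missing_ids(rows: List[Record]) -> None:
    if any(row.get("id") is None for row in rows):
        raise ValueError("found a row without an id")


class SimpleValidationHooks:
    # Map dataset name -> checks to run (hypothetical checks for illustration)
    CHECKS: Dict[str, List[Callable[[List[Record]], None]]] = {
        "companies": [_has_rows, _no_missing_ids],
        "preprocessed_companies": [_has_rows],
    }

    # In Kedro, decorate with @hook_impl
    def before_node_run(self, inputs: Dict[str, Any]) -> None:
        self._run_validation(inputs)

    # In Kedro, decorate with @hook_impl
    def after_node_run(self, outputs: Dict[str, Any]) -> None:
        self._run_validation(outputs)

    def _run_validation(self, data: Dict[str, Any]) -> None:
        for dataset_name, value in data.items():
            # Datasets without registered checks are skipped, as in the original hook
            for check in self.CHECKS.get(dataset_name, []):
                check(value)
```

In this sketch a failed check raises ValueError; whether a failed validation should abort the run or only be reported is a design choice.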

Tracking metrics

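MLflow can be used to add tracking to a pipeline run. The hooks below start an MLflow run before the pipeline runs, log the data-split ratio after the split_data node and the trained model and its parameters after the train_model node, and end the run after the pipeline finishes.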

$ pip install mlflow
src/<project_name>/hooks.py
from typing import Any, Dict

import mlflow
import mlflow.sklearn
from kedro.framework.hooks import hook_impl
from kedro.pipeline.node import Node


class ModelTrackingHooks:
    """Namespace for grouping all model-tracking hooks with MLflow together."""

    @hook_impl
    def before_pipeline_run(self, run_params: Dict[str, Any]) -> None:
        """Hook implementation to start an MLflow run
        with the session_id of the Kedro pipeline run.
        """
        mlflow.start_run(run_name=run_params["session_id"])
        mlflow.log_params(run_params)

    @hook_impl
    def after_node_run(
        self, node: Node, outputs: Dict[str, Any], inputs: Dict[str, Any]
    ) -> None:
        """Hook implementation to add model tracking after some node runs.
        In this example, we will:
        * Log the parameters after the data splitting node runs.
        * Log the model and its parameters after the model training node runs.
        """
        if node._func_name == "split_data":
            mlflow.log_params(
                {"split_data_ratio": inputs["params:example_test_data_ratio"]}
            )

        elif node._func_name == "train_model":
            model = outputs["example_model"]
            mlflow.sklearn.log_model(model, "model")
            mlflow.log_params(inputs["parameters"])

    @hook_impl
    def after_pipeline_run(self) -> None:
        """Hook implementation to end the MLflow run
        after the Kedro pipeline finishes.
        """
        mlflow.end_run()

Edit HOOKS in src/<project_name>/settings.py as follows.

src/<project_name>/settings.py
+  from <project_name>.hooks import ModelTrackingHooks

-  HOOKS = (ProjectHooks(),)
+  HOOKS = (ModelTrackingHooks(),)
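
Stripped of MLflow, ModelTrackingHooks follows a start / log / end lifecycle around one pipeline run. The sketch below records the same information into an in-memory dict so the control flow can be exercised without MLflow. InMemoryTrackingHooks and its dict layout are hypothetical, nodes are identified by a function-name string instead of a Kedro Node object, and @hook_impl is omitted.

```python
from typing import Any, Dict, List


class InMemoryTrackingHooks:
    def __init__(self) -> None:
        self.runs: List[Dict[str, Any]] = []  # finished runs
        self._active: Dict[str, Any] = {}     # run currently being recorded

    def before_pipeline_run(self, run_params: Dict[str, Any]) -> None:
        # Counterpart of mlflow.start_run + mlflow.log_params
        self._active = {
            "name": run_params["session_id"],
            "params": dict(run_params),
            "models": {},
        }

    def after_node_run(
        self, func_name: str, outputs: Dict[str, Any], inputs: Dict[str, Any]
    ) -> None:
        # Branch on the node's function name, as the MLflow hook does
        if func_name == "split_data":
            self._active["params"]["split_data_ratio"] = inputs["params:example_test_data_ratio"]
        elif func_name == "train_model":
            self._active["models"]["model"] = outputs["example_model"]
            self._active["params"].update(inputs["parameters"])

    def after_pipeline_run(self) -> None:
        # Counterpart of mlflow.end_run
        self.runs.append(self._active)
        self._active = {}
```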

References

https://kedro.readthedocs.io/en/stable/hooks/introduction.html
https://kedro.readthedocs.io/en/stable/kedro.framework.hooks.html

Ryusei Kakujo


Focusing on data science for mobility

Bench Press 100kg!