Traffine I/O

日本語

2023-03-31

Pythonのプロトコル型

プロトコルとは

プロトコルは、Pythonの型システムの強力な機能であり、より柔軟で表現力豊かな型ヒントを実現します。typingモジュールの一部であり、Pythonコードで型を注釈するために使用されます。プロトコルは構造的型付けの概念に基づいており、型の明示的な継承階層ではなく、型の振る舞いや構造に焦点を当てています。これにより、クラスのインターフェースを定義し、より堅牢でメンテナンスしやすいコードを作成することができます。

Pythonでプロトコルを使用することで、コードの可読性とメンテナンス性を大幅に向上させることができます。関数やメソッドの予想される入力と出力の型について明示的な情報を提供することで、コードをより自己文書化し、他の開発者が理解しやすくなります。さらに、この情報はMypyなどの静的型チェッカーによって、ランタイムエラーが発生する前に潜在的なバグや型の不一致をキャッチするために利用できます。

プロトコルの使い方

プロトコルを使用するには、まずtyping.Protocolクラスを理解する必要があります。typing.Protocolクラスは、新しいプロトコルクラスを定義するために使用されるメタクラスです。カスタムプロトコルを作成するには、typing.Protocolクラスを継承した新しいクラスを定義し、新しいクラス内で必要な属性とメソッドを定義する必要があります。

カスタムプロトコルは、typing.Protocolから継承する新しいクラスを定義することで作成されます。このクラスには、望ましいメソッドとプロパティが含まれており、これらは、このプロトコルを実装する任意の型の期待される動作を示すために使用されます。

例えば、2次元点のカスタムプロトコルを定義するには、Point2Dという新しいクラスを作成し、typing.Protocolから継承します。

python
from typing import Protocol

class Point2D(Protocol):
    x: float
    y: float

    def distance_to_origin(self) -> float:
        ...

この例では、Point2Dプロトコルは、それを実装する任意の型がfloat型のxyプロパティを持ち、floatを返すdistance_to_originという名前のメソッドを持つ必要があることを指定しています。Point2Dプロトコルをコード内で型ヒントとして使用すると、実際の継承階層に関係なく、期待される型がこれらのプロパティとメソッドを持つことを示します。

カスタムプロトコルを作成して使用することで、Pythonコードでより表現力豊かで柔軟な型ヒントを書き、最終的にはより堅牢でメンテナンスしやすいプロジェクトを作成することができます。

抽象ベースクラス(abc)との比較

typing.Protocolと抽象ベースクラス(ABC)の両方がPythonクラスのインターフェースを定義するのに役立ちますが、いくつかの重要な点で異なります。

typing.Protocolは構造的部分型付けに焦点を当て、ABCは名目的部分型付けに重点を置いています。つまり、プロトコルは型の構造と振る舞いに焦点を当てており、ABCは明示的な継承関係に基づいています。

プロトコルはABCよりも柔軟であり、実装を強制しません。必要な属性とメソッドがあれば、クラスはプロトコルに準拠できますが、明示的に継承する必要はありません。

ABCはランタイムで抽象メソッドの実装を強制できますが、プロトコルにはデフォルトではこの機能がありません。ただし、プロトコルに対してruntime_checkableデコレータを使用して、ランタイムチェックを行うことができます。

プロトコルと抽象ベースクラス(abc)の使い分け

typing.Protocolは、継承関係を明示的に強制しないで、型の構造と振る舞いに焦点を当てたい場合に一般的に選択されます。サードパーティのライブラリのアダプタプロトコルの作成、ファイルライクなオブジェクトの定義、カスタムイテラブルやイテレータインターフェースの作成に特に有用です。

一方、ABCは、実行時に特定のメソッドの実装を強制し、明示的な継承階層を確立する場合に適しています。

typing.Protocolを使ったscikit-learnの例

Scikit-learnは、幅広い機械学習の推論器を提供しています。Scikit-learn互換の推論器に必要なメソッドを定義するカスタム推定器プロトコルを定義することができます。

python
from typing import Any, Protocol
import numpy as np

class SklearnEstimator(Protocol):
    def fit(self, X: np.ndarray, y: np.ndarray, **kwargs: Any) -> "SklearnEstimator":
        ...

    def predict(self, X: np.ndarray) -> np.ndarray:
        ...

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        ...

カスタム推論器プロトコルが用意されているため、このプロトコルに従う新しい推論器を実装できます。

python
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class CustomEstimator(BaseEstimator, ClassifierMixin):
    def fit(self, X: np.ndarray, y: np.ndarray) -> "CustomEstimator":
        # Implementation of the fit method
        ...

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Implementation of the predict method
        ...

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        # Implementation of the score method
        ...

UカスタムSklearnEstimatorプロトコルを利用することで、evaluate_estimators関数を作成し、Scikit-learn互換の推論器に対して適用できます。

python
from sklearn.model_selection import cross_val_score
from typing import List

def evaluate_estimators(estimators: List[SklearnEstimator], X: np.ndarray, y: np.ndarray) -> None:
    for estimator in estimators:
        scores = cross_val_score(estimator, X, y, cv=5)

    print(f"{estimator.__class__.__name__}:")
    print(f"  Mean cross-validation score: {scores.mean():.3f}")
    print(f"  Standard deviation: {scores.std():.3f}\n")

この例では、evaluate_estimators関数は、SklearnEstimatorプロトコルに準拠する推定器のリストを受け取り、Scikit-learnのcross_val_score関数を使用して各推論器のパフォーマンスを評価します。

カスタムSklearnEstimatorプロトコルを活用することで、evaluate_estimators関数に渡された推論器が、必要なfitpredictscoreメソッドを持っていることを確認し、コードをより堅牢でメンテナンスしやすくすることができます。

参考

https://peps.python.org/pep-0544/
https://mypy.readthedocs.io/en/stable/protocols.html

Ryusei Kakujo

researchgatelinkedingithub

Focusing on data science for mobility

Bench Press 100kg!