python - 根据需要仅继承父 pandera SchemaModel 的一些字段

我有输入和输出 pandera SchemaModels,输出继承了输入,它准确地表示输入模式的所有属性都在输出模式的范围内。

我想要避免的是根据需要(非可选)继承所有属性,因为它们正确地来自输入模式。相反,我想根据输入模式的需要保留它们,但定义其中哪些对于输出模式仍然是必需的,而其他继承的属性则变为可选。

这个pydantic https://stackoverflow.com/questions/61948723/how-to-extend-a-pydantic-object-and-change-some-fields-type 类似,并且有在父类中定义__init_subclass__ 方法的解决方案。但是,这对于 pandera 类来说并不是开箱即用的,我不确定它是否可实现或是否是正确的方法。

import pandera as pa
from typing import Optional
from pandera.typing import Index, DataFrame, Series, Category

class InputSchema(pa.SchemaModel):

    reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)

    def __init_subclass__(cls, optional_fields=None, **kwargs):
        super().__init_subclass__(**kwargs)
        if optional_fields:
            for field in optional_fields:
                cls.__fields__[field].outer_type_ = Optional
                cls.__fields__[field].required = False    


class OutputSchema(InputSchema, optional_fields=['reporting_date']):

    test: Series[str] = pa.Field()


@pa.check_types
def func(inputs: DataFrame[InputSchema]) -> DataFrame[OutputSchema]:
    inputs = inputs.drop(columns=['reporting_date'])
    inputs['test'] = 'a'
    return inputs

data = pd.DataFrame({'reporting_date': ['2023-01-11', '2023-01-12']})

func(data)

错误:

---> 18 class OutputSchema(InputSchema, optional_fields=['reporting_date']):
KeyError: 'reporting_date'

编辑:

期望的结果是能够设置继承模式中的哪些字段仍然是必需的,而其余的则成为可选的:

class InputSchema(pa.SchemaModel):

    reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)
    other_field: Series[str] = pa.Field()


class OutputSchema(InputSchema, required=['reporting_date'])

    test: Series[str] = pa.Field()

生成的 OutputSchema 具有必需的 reporting_datetestother_field 是可选的。

回答1

https://github.com/unionai-oss/pandera/discussions/990 上提出了一个类似的问题,https://github.com/unionai-oss/pandera/pull/1012/files 正在跟踪下一个 pandera 版本。没有干净的解决方案,但最简单的一种是通过重载 to_schema 来排除列:

import pandera as pa
from pandera.typing import Series

class InputSchema(pa.SchemaModel):
    reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)

class OutputSchema(InputSchema):
    test: Series[str]
    
    @classmethod
    def to_schema(cls) -> pa.DataFrameSchema:
        return super().to_schema().remove_columns(["reporting_date"])

这在没有 SchemaError 的情况下针对您的检查功能运行。

回答2

这是通过重用输入模式中的现有类型注释的解决方案:

import pandera as pa
import pandas as pd
from typing import Optional
from pandera.typing import Index, DataFrame, Series, Category
from pydantic import Field, BaseModel
from typing import Annotated, Type


def copy_field(from_model: Type[BaseModel], fname: str, annotations: dict[str, ...]):
    annotations[fname] = from_model.__annotations__[fname]


class InputSchema(pa.SchemaModel):
    reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)
    not_inherit: Series[str]

class OutputSchema(pa.SchemaModel):
    test: Series[str] = pa.Field()
    copy_field(InputSchema, "reporting_date", __annotations__)
    # reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)
    # not_inherit: Optional[Series[str]]


data = pd.DataFrame({
    'reporting_date': ['2023-01-11', '2023-01-12'],
    'not_inherit': ['a','a']
})

@pa.check_types
def func(
    inputs: DataFrame[InputSchema]
) -> DataFrame[OutputSchema]:
    inputs = inputs.drop(columns=['not_inherit'])
    inputs['test'] = 'a'
    return inputs

func(data)

相似文章

随机推荐

最新文章