如何使类泛型并正确推断其方法的返回类型

2024-10-02 06:30:04 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试用python中的SOLID原则重构一些类,我有一个关于如何将SOLID与python类型混合的问题

假设我有这些课程:

from asyncpg import Pool


class PGQuery:
    async def execute(self, connection: Pool):
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: Pool):
        self._connection = connection

    async def execute(self, query: PGQuery):
        return await query.execute(self._connection)

from pydantic import BaseModel, parse_obj_as


class QualitySummary(BaseModel):
    count: int
    score: float


class PGQueryQualitySummary(PGQuery):
    def __init__(self, node: str):
        self.node = node

    async def execute(self, connection: Pool) -> QualitySummary:
        result = await connection.fetchrow(...)

        return parse_obj_as(QualitySummary, result)

用法示例:

pgqueryexecutor = PGQueryExecutor(...)
result = await pgqueryexecutor.execute(PGQueryQualitySummary(...))

问题是result的推断类型是Any,它是基类PGQuery之一。我希望PGQuery的每个子类(可能使用泛型?)通过PGQueryExecutorexecute方法正确地推断出它自己的execute方法和它自己的返回类型,因此PGQueryExecutorexecute方法的返回值:

pgqueryexecutor.execute(PGQueryAnySubclass(...))

正是PGQueryAnySubclass.execute的返回类型

我怎样才能做到这一点


Tags: selfnode类型executeasyncdefresultawait
1条回答
网友
1楼 · 发布于 2024-10-02 06:30:04

您可以通过使用generic protocol来实现这一点,您可以从中继承特定的查询类。下面提供了一个例子。考虑到类型变量被协变地用作返回类型,因此我们将其标记为covariant=True(关于上面链接的更多详细信息)

from abc import abstractmethod
from typing import TypeVar, Protocol

from pydantic import BaseModel, parse_obj_as

T = TypeVar('T', covariant=True)


class Pool:
    ...


class PGQuery(Protocol[T]):
    @abstractmethod
    async def execute(self, connection: Pool) -> T:
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: Pool):
        self._connection = connection

    async def execute(self, query: PGQuery[T]) -> T:
        return await query.execute(self._connection)


class QualitySummary(BaseModel):
    count: int
    score: float


class PGQueryQualitySummary(PGQuery[QualitySummary]):
    def __init__(self, node: str):
        self.node = node

    async def execute(self, connection: Pool) -> QualitySummary:
        # ...
        return parse_obj_as(QualitySummary, {})


async def main() -> None:
    q = PGQueryQualitySummary("node")
    ex = PGQueryExecutor(Pool())
    reveal_type(await ex.execute(q))  # revealed type QualitySummary


相关问题 更多 >

    热门问题