diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 816b7a7e8d0c..8f9d64b32aab 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -5,7 +5,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator from datetime import datetime from decimal import Decimal from types import TracebackType -from typing import Any, ClassVar, Final, Literal, NamedTuple, NoReturn, TypeVar, overload, type_check_only +from typing import Any, ClassVar, Final, Generic, Literal, NamedTuple, NoReturn, TypeVar, overload, type_check_only from typing_extensions import Self, TypeIs from uuid import UUID @@ -18,6 +18,7 @@ def reraise(tp: Unused, value: BaseException, tb: TracebackType | None = None) - _T = TypeVar("_T") _VT = TypeVar("_VT") _F = TypeVar("_F", bound=Callable[..., Any]) +_TModel = TypeVar("_TModel", bound=Model) class attrdict(dict[str, _VT]): def __getattr__(self, attr: str) -> _VT: ... @@ -1658,9 +1659,9 @@ class Model(metaclass=ModelBase): @classmethod def validate_model(cls) -> None: ... @classmethod - def alias(cls, alias=None) -> ModelAlias: ... + def alias(cls, alias=None) -> ModelAlias[Self]: ... @classmethod - def select(cls, *fields) -> ModelSelect: ... + def select(cls, *fields) -> ModelSelect[Self]: ... @classmethod def update(cls, data=None, /, **update) -> ModelUpdate: ... @classmethod @@ -1684,7 +1685,7 @@ class Model(metaclass=ModelBase): @classmethod def bulk_update(cls, model_list, fields, batch_size=None): ... @classmethod - def noop(cls) -> NoopModelSelect: ... + def noop(cls) -> NoopModelSelect[Self]: ... @classmethod def get(cls, *query, **filters): ... @classmethod @@ -1729,12 +1730,12 @@ class Model(metaclass=ModelBase): @classmethod def add_index(cls, *fields, **kwargs) -> None: ... -class ModelAlias(Node): - def __init__(self, model, alias=None) -> None: ... +class ModelAlias(Node, Generic[_TModel]): + def __init__(self, model: type[_TModel], alias=None) -> None: ... def __getattr__(self, attr: str): ... def __setattr__(self, attr: str, value) -> None: ... def get_field_aliases(self) -> list[Incomplete]: ... - def select(self, *selection) -> ModelSelect: ... + def select(self, *selection) -> ModelSelect[_TModel]: ... def __call__(self, **kwargs): ... def __sql__(self, ctx): ... @@ -1782,11 +1783,11 @@ class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): # type: i model: Incomplete def __init__(self, model, *args, **kwargs) -> None: ... -class ModelSelect(BaseModelSelect, Select): # type: ignore[misc] - model: Incomplete - def __init__(self, model, fields_or_models, is_default: bool = False) -> None: ... +class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): # type: ignore[misc] + model: type[_TModel] + def __init__(self, model: type[_TModel], fields_or_models, is_default: bool = False) -> None: ... def clone(self) -> Self: ... - def select(self, *fields_or_models): ... + def select(self, *fields_or_models) -> ModelSelect[_TModel]: ... def select_extend(self, *columns): ... def switch(self, ctx=None) -> Self: ... def join(self, dest, join_type="INNER JOIN", on=None, src=None, attr=None) -> Self: ... # type: ignore[override] @@ -1798,7 +1799,7 @@ class ModelSelect(BaseModelSelect, Select): # type: ignore[misc] def create_table(self, name, safe: bool = True, **meta): ... def __sql_selection__(self, ctx, is_subquery: bool = False): ... -class NoopModelSelect(ModelSelect): +class NoopModelSelect(ModelSelect[_TModel]): def __sql__(self, ctx): ... class _ModelWriteQueryHelper(_ModelQueryHelper): @@ -1817,7 +1818,7 @@ class ModelInsert(_ModelWriteQueryHelper, Insert): # type: ignore[misc] class ModelDelete(_ModelWriteQueryHelper, Delete): ... # type: ignore[misc] -class ManyToManyQuery(ModelSelect): +class ManyToManyQuery(ModelSelect[_TModel]): def __init__(self, instance, accessor, rel, *args, **kwargs) -> None: ... def add(self, value, clear_existing: bool = False) -> None: ... def remove(self, value): ...