Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@
TypedDict,
)

from librt.internal import cache_version

# from librt.internal import cache_version
# from mypy.cache import CACHE_VERSION as cache_version
def cache_version() -> int:
from mypy.cache import CACHE_VERSION

return CACHE_VERSION


import mypy.semanal_main
from mypy.cache import (
Expand Down
59 changes: 54 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

# sys.exit(1)
import enum
import itertools
import time
Expand Down Expand Up @@ -1749,7 +1750,41 @@ def check_callable_call(
callee = callee.copy_modified(ret_type=fresh_ret_type)

if callee.is_generic():
# import sys
# sys.stderr.write(f"DEBUG: Checking generic callee: {callee}\n")
callee = freshen_function_type_vars(callee)
ctx = self.type_context[-1]
if ctx:
p_ctx = get_proper_type(ctx)
if isinstance(p_ctx, UnionType):
# sys.stderr.write(f"DEBUG: Union Context found: {ctx}\n")
candidates = []
for item in p_ctx.items:
candidate = self.infer_function_type_arguments_using_context(
callee, context, type_context=item
)
# Filter out candidates that did not respect the context (e.g. remained generic
# or inferred something incompatible).
if is_subtype(candidate.ret_type, item):
candidates.append(candidate)

if candidates:
# We use 'None' context to prevent infinite recursion when checking overloads
# provided one of the candidates remains generic.
self.type_context.append(None)
try:
return self.check_overload_call(
Overloaded(candidates),
args,
arg_kinds,
arg_names,
callable_name,
object_type,
context,
)
finally:
self.type_context.pop()

callee = self.infer_function_type_arguments_using_context(callee, context)

formal_to_actual = map_actuals_to_formals(
Expand Down Expand Up @@ -1982,19 +2017,23 @@ def infer_arg_types_in_context(
return cast(list[Type], res)

def infer_function_type_arguments_using_context(
self, callable: CallableType, error_context: Context
self, callable: CallableType, error_context: Context, type_context: Type | None = None
) -> CallableType:
"""Unify callable return type to type context to infer type vars.

For example, if the return type is set[t] where 't' is a type variable
of callable, and if the context is set[int], return callable modified
by substituting 't' with 'int'.
"""
ctx = self.type_context[-1]
ctx: Type | None
if type_context:
ctx = type_context
else:
ctx = self.type_context[-1]
if not ctx:
return callable
# The return type may have references to type metavariables that
# we are inferring right now. We must consider them as indeterminate
# we are inferred right now. We must consider them as indeterminate
# and they are not potential results; thus we replace them with the
# special ErasedType type. On the other hand, class type variables are
# valid results.
Expand Down Expand Up @@ -3143,13 +3182,23 @@ def type_overrides_set(
) -> Iterator[None]:
"""Set _temporary_ type overrides for given expressions."""
assert len(exprs) == len(overrides)
# Use a dict to store original values. This handles duplicates in exprs automatically
# by only storing the original value for the first occurrence (since we iterate and
# populate if not present).
original_values: dict[Expression, Type | None] = {}
for expr, typ in zip(exprs, overrides):
if expr not in original_values:
original_values[expr] = self.type_overrides.get(expr)
self.type_overrides[expr] = typ
try:
yield
finally:
for expr in exprs:
del self.type_overrides[expr]
for expr, prev in original_values.items():
if prev is None:
if expr in self.type_overrides:
del self.type_overrides[expr]
else:
self.type_overrides[expr] = prev

def combine_function_signatures(self, types: list[ProperType]) -> AnyType | CallableType:
"""Accepts a list of function signatures and attempts to combine them together into a
Expand Down
Loading