diff --git a/mypy/build.py b/mypy/build.py index c2e378290274..3ba7d4a6c0c9 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -40,7 +40,14 @@ TypedDict, ) -from librt.internal import cache_version + +# from librt.internal import cache_version +# from mypy.cache import CACHE_VERSION as cache_version +def cache_version() -> int: + from mypy.cache import CACHE_VERSION + + return CACHE_VERSION + import mypy.semanal_main from mypy.cache import ( diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9990caaeb7a1..a503930b50cb 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2,6 +2,7 @@ from __future__ import annotations +# sys.exit(1) import enum import itertools import time @@ -1749,7 +1750,41 @@ def check_callable_call( callee = callee.copy_modified(ret_type=fresh_ret_type) if callee.is_generic(): + # import sys + # sys.stderr.write(f"DEBUG: Checking generic callee: {callee}\n") callee = freshen_function_type_vars(callee) + ctx = self.type_context[-1] + if ctx: + p_ctx = get_proper_type(ctx) + if isinstance(p_ctx, UnionType): + # sys.stderr.write(f"DEBUG: Union Context found: {ctx}\n") + candidates = [] + for item in p_ctx.items: + candidate = self.infer_function_type_arguments_using_context( + callee, context, type_context=item + ) + # Filter out candidates that did not respect the context (e.g. remained generic + # or inferred something incompatible). + if is_subtype(candidate.ret_type, item): + candidates.append(candidate) + + if candidates: + # We use 'None' context to prevent infinite recursion when checking overloads + # provided one of the candidates remains generic. + self.type_context.append(None) + try: + return self.check_overload_call( + Overloaded(candidates), + args, + arg_kinds, + arg_names, + callable_name, + object_type, + context, + ) + finally: + self.type_context.pop() + callee = self.infer_function_type_arguments_using_context(callee, context) formal_to_actual = map_actuals_to_formals( @@ -1982,7 +2017,7 @@ def infer_arg_types_in_context( return cast(list[Type], res) def infer_function_type_arguments_using_context( - self, callable: CallableType, error_context: Context + self, callable: CallableType, error_context: Context, type_context: Type | None = None ) -> CallableType: """Unify callable return type to type context to infer type vars. @@ -1990,11 +2025,15 @@ def infer_function_type_arguments_using_context( of callable, and if the context is set[int], return callable modified by substituting 't' with 'int'. """ - ctx = self.type_context[-1] + ctx: Type | None + if type_context: + ctx = type_context + else: + ctx = self.type_context[-1] if not ctx: return callable # The return type may have references to type metavariables that - # we are inferring right now. We must consider them as indeterminate + # we are inferred right now. We must consider them as indeterminate # and they are not potential results; thus we replace them with the # special ErasedType type. On the other hand, class type variables are # valid results. @@ -3143,13 +3182,23 @@ def type_overrides_set( ) -> Iterator[None]: """Set _temporary_ type overrides for given expressions.""" assert len(exprs) == len(overrides) + # Use a dict to store original values. This handles duplicates in exprs automatically + # by only storing the original value for the first occurrence (since we iterate and + # populate if not present). + original_values: dict[Expression, Type | None] = {} for expr, typ in zip(exprs, overrides): + if expr not in original_values: + original_values[expr] = self.type_overrides.get(expr) self.type_overrides[expr] = typ try: yield finally: - for expr in exprs: - del self.type_overrides[expr] + for expr, prev in original_values.items(): + if prev is None: + if expr in self.type_overrides: + del self.type_overrides[expr] + else: + self.type_overrides[expr] = prev def combine_function_signatures(self, types: list[ProperType]) -> AnyType | CallableType: """Accepts a list of function signatures and attempts to combine them together into a