diff --git a/packages/reflex-components-core/src/reflex_components_core/core/cond.py b/packages/reflex-components-core/src/reflex_components_core/core/cond.py index a35443cb17c..b6bed3f308c 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/cond.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/cond.py @@ -14,6 +14,7 @@ from reflex_base.vars import VarData from reflex_base.vars.base import LiteralVar, Var from reflex_base.vars.number import ternary_operation +from typing_extensions import LiteralString from reflex_components_core.base.bare import Bare from reflex_components_core.base.fragment import Fragment @@ -131,6 +132,25 @@ def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright T = TypeVar("T", covariant=True) U = TypeVar("U", covariant=True) +LITERAL_STRING_S = TypeVar("LITERAL_STRING_S", bound=LiteralString) + + +@overload +def cond( + condition: Any, c1: LITERAL_STRING_S, c2: LITERAL_STRING_S, / +) -> Var[LITERAL_STRING_S]: ... # pyright: ignore [reportOverlappingOverload] + + +@overload +def cond( + condition: Any, c1: LITERAL_STRING_S, c2: Var[U], / +) -> Var[LITERAL_STRING_S | U]: ... # pyright: ignore [reportOverlappingOverload] + + +@overload +def cond( + condition: Any, c1: Var[T], c2: LITERAL_STRING_S, / +) -> Var[T | LITERAL_STRING_S]: ... # pyright: ignore [reportOverlappingOverload] @overload diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index a3c45875417..8ae0613c333 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -1,5 +1,5 @@ import json -from typing import Any, Literal +from typing import Any, Literal, cast import pytest from reflex_base.components.component import Component @@ -167,7 +167,8 @@ def test_cond_assert_types() -> None: text_comp = Text.create("hello") text_comp2 = Text.create("world") var_int: Var[int] = LiteralVar.create(1) - var_str: Var[str] = LiteralVar.create("a") + literal_var_str = LiteralVar.create("a") + widened_var_str = cast(Var[str], literal_var_str) # Component, Component -> Component _ = assert_type(cond(True, text_comp, text_comp2), Component) @@ -181,23 +182,29 @@ def test_cond_assert_types() -> None: # non-component, Component -> Component _ = assert_type(cond(True, "hello", text_comp), Component) - # T, T -> Var[T] - _ = assert_type(cond(True, "hello", "world"), Var[str]) + # literal str, literal str -> Var[Literal[...]] + _ = assert_type(cond(True, "hello", "world"), Var[Literal["hello", "world"]]) # T, U -> Var[T | U] _ = assert_type(cond(True, "hello", 3), Var[str | int]) - # T, Var[T] -> Var[T] - _ = assert_type(cond(True, "hello", var_str), Var[str]) + # literal str, literal Var[str] -> Var[Literal[...]] + _ = assert_type(cond(True, "hello", literal_var_str), Var[Literal["hello", "a"]]) - # Var[T], T -> Var[T] - _ = assert_type(cond(True, var_str, "world"), Var[str]) + # literal Var[str], literal str -> Var[Literal[...]] + _ = assert_type(cond(True, literal_var_str, "world"), Var[Literal["a", "world"]]) - # T, Var[U] -> Var[T | U] - _ = assert_type(cond(True, "hello", var_int), Var[str | int]) + # literal str, Var[U] -> Var[Literal[...] | U] + _ = assert_type(cond(True, "hello", var_int), Var[int | Literal["hello"]]) + + # Var[T], literal str -> Var[T] + _ = assert_type(cond(True, widened_var_str, "world"), Var[str]) + + # literal str, Var[U] -> Var[U] + _ = assert_type(cond(True, "hello", widened_var_str), Var[str]) # Var[T], U -> Var[T | U] - _ = assert_type(cond(True, var_str, 3), Var[int | Literal["a"]]) + _ = assert_type(cond(True, widened_var_str, 3), Var[str | int]) # Var[T], Var[U] -> Var[T | U] - _ = assert_type(cond(True, var_int, var_str), Var[int | Literal["a"]]) + _ = assert_type(cond(True, var_int, widened_var_str), Var[int | str])