From 84ff913e65732fa6f153dbb8b41df1e932694c21 Mon Sep 17 00:00:00 2001 From: Alek Petuskey Date: Wed, 20 May 2026 12:59:40 -0700 Subject: [PATCH 1/2] Preserve literal string types in cond --- .../src/reflex_components_core/core/cond.py | 14 ++++++++++ tests/units/components/core/test_cond.py | 28 +++++++++++++------ 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/cond.py b/packages/reflex-components-core/src/reflex_components_core/core/cond.py index a35443cb17c..5f808f6317d 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/cond.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/cond.py @@ -14,6 +14,7 @@ from reflex_base.vars import VarData from reflex_base.vars.base import LiteralVar, Var from reflex_base.vars.number import ternary_operation +from typing_extensions import LiteralString from reflex_components_core.base.bare import Bare from reflex_components_core.base.fragment import Fragment @@ -131,6 +132,19 @@ def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright T = TypeVar("T", covariant=True) U = TypeVar("U", covariant=True) +S = TypeVar("S", bound=LiteralString) + + +@overload +def cond(condition: Any, c1: S, c2: S, /) -> Var[S]: ... # pyright: ignore [reportOverlappingOverload] + + +@overload +def cond(condition: Any, c1: S, c2: Var[U], /) -> Var[S | U]: ... # pyright: ignore [reportOverlappingOverload] + + +@overload +def cond(condition: Any, c1: Var[T], c2: S, /) -> Var[T | S]: ... # pyright: ignore [reportOverlappingOverload] @overload diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index a3c45875417..2b905605fc7 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -181,20 +181,32 @@ def test_cond_assert_types() -> None: # non-component, Component -> Component _ = assert_type(cond(True, "hello", text_comp), Component) - # T, T -> Var[T] - _ = assert_type(cond(True, "hello", "world"), Var[str]) + # literal str, literal str -> Var[Literal[...]] + _ = assert_type(cond(True, "hello", "world"), Var[Literal["hello", "world"]]) # T, U -> Var[T | U] _ = assert_type(cond(True, "hello", 3), Var[str | int]) - # T, Var[T] -> Var[T] - _ = assert_type(cond(True, "hello", var_str), Var[str]) + # literal str, literal Var[str] -> Var[Literal[...]] + _ = assert_type(cond(True, "hello", var_str), Var[Literal["hello", "a"]]) - # Var[T], T -> Var[T] - _ = assert_type(cond(True, var_str, "world"), Var[str]) + # literal str, literal Var[str] -> Var[Literal[...]] + _ = assert_type( + cond(True, "hello", LiteralVar.create("world")), + Var[Literal["hello", "world"]], + ) + + # literal Var[str], literal str -> Var[Literal[...]] + _ = assert_type(cond(True, var_str, "world"), Var[Literal["a", "world"]]) + + # literal Var[str], literal str -> Var[Literal[...]] + _ = assert_type( + cond(True, LiteralVar.create("hello"), "world"), + Var[Literal["hello", "world"]], + ) - # T, Var[U] -> Var[T | U] - _ = assert_type(cond(True, "hello", var_int), Var[str | int]) + # literal str, Var[U] -> Var[Literal[...] | U] + _ = assert_type(cond(True, "hello", var_int), Var[int | Literal["hello"]]) # Var[T], U -> Var[T | U] _ = assert_type(cond(True, var_str, 3), Var[int | Literal["a"]]) From 6a66f73187584115fe79a0775bdc396f291b3295 Mon Sep 17 00:00:00 2001 From: Alek Petuskey Date: Wed, 20 May 2026 13:26:18 -0700 Subject: [PATCH 2/2] Address cond typing review feedback --- .../src/reflex_components_core/core/cond.py | 14 ++++++--- tests/units/components/core/test_cond.py | 31 ++++++++----------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/cond.py b/packages/reflex-components-core/src/reflex_components_core/core/cond.py index 5f808f6317d..b6bed3f308c 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/cond.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/cond.py @@ -132,19 +132,25 @@ def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright T = TypeVar("T", covariant=True) U = TypeVar("U", covariant=True) -S = TypeVar("S", bound=LiteralString) +LITERAL_STRING_S = TypeVar("LITERAL_STRING_S", bound=LiteralString) @overload -def cond(condition: Any, c1: S, c2: S, /) -> Var[S]: ... # pyright: ignore [reportOverlappingOverload] +def cond( + condition: Any, c1: LITERAL_STRING_S, c2: LITERAL_STRING_S, / +) -> Var[LITERAL_STRING_S]: ... # pyright: ignore [reportOverlappingOverload] @overload -def cond(condition: Any, c1: S, c2: Var[U], /) -> Var[S | U]: ... # pyright: ignore [reportOverlappingOverload] +def cond( + condition: Any, c1: LITERAL_STRING_S, c2: Var[U], / +) -> Var[LITERAL_STRING_S | U]: ... # pyright: ignore [reportOverlappingOverload] @overload -def cond(condition: Any, c1: Var[T], c2: S, /) -> Var[T | S]: ... # pyright: ignore [reportOverlappingOverload] +def cond( + condition: Any, c1: Var[T], c2: LITERAL_STRING_S, / +) -> Var[T | LITERAL_STRING_S]: ... # pyright: ignore [reportOverlappingOverload] @overload diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index 2b905605fc7..8ae0613c333 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -1,5 +1,5 @@ import json -from typing import Any, Literal +from typing import Any, Literal, cast import pytest from reflex_base.components.component import Component @@ -167,7 +167,8 @@ def test_cond_assert_types() -> None: text_comp = Text.create("hello") text_comp2 = Text.create("world") var_int: Var[int] = LiteralVar.create(1) - var_str: Var[str] = LiteralVar.create("a") + literal_var_str = LiteralVar.create("a") + widened_var_str = cast(Var[str], literal_var_str) # Component, Component -> Component _ = assert_type(cond(True, text_comp, text_comp2), Component) @@ -188,28 +189,22 @@ def test_cond_assert_types() -> None: _ = assert_type(cond(True, "hello", 3), Var[str | int]) # literal str, literal Var[str] -> Var[Literal[...]] - _ = assert_type(cond(True, "hello", var_str), Var[Literal["hello", "a"]]) - - # literal str, literal Var[str] -> Var[Literal[...]] - _ = assert_type( - cond(True, "hello", LiteralVar.create("world")), - Var[Literal["hello", "world"]], - ) - - # literal Var[str], literal str -> Var[Literal[...]] - _ = assert_type(cond(True, var_str, "world"), Var[Literal["a", "world"]]) + _ = assert_type(cond(True, "hello", literal_var_str), Var[Literal["hello", "a"]]) # literal Var[str], literal str -> Var[Literal[...]] - _ = assert_type( - cond(True, LiteralVar.create("hello"), "world"), - Var[Literal["hello", "world"]], - ) + _ = assert_type(cond(True, literal_var_str, "world"), Var[Literal["a", "world"]]) # literal str, Var[U] -> Var[Literal[...] | U] _ = assert_type(cond(True, "hello", var_int), Var[int | Literal["hello"]]) + # Var[T], literal str -> Var[T] + _ = assert_type(cond(True, widened_var_str, "world"), Var[str]) + + # literal str, Var[U] -> Var[U] + _ = assert_type(cond(True, "hello", widened_var_str), Var[str]) + # Var[T], U -> Var[T | U] - _ = assert_type(cond(True, var_str, 3), Var[int | Literal["a"]]) + _ = assert_type(cond(True, widened_var_str, 3), Var[str | int]) # Var[T], Var[U] -> Var[T | U] - _ = assert_type(cond(True, var_int, var_str), Var[int | Literal["a"]]) + _ = assert_type(cond(True, var_int, widened_var_str), Var[int | str])