11from typing import override
22
33from sqlalchemy import URL
4+ from sqlalchemy .dialects .postgresql import UUID as PostgresUUID
45from sqlalchemy .exc import SQLAlchemyError
6+ from sqlalchemy .sql .functions import GenericFunction
7+ from starrocks .dialect import StarRocksSQLCompiler , StarRocksTypeCompiler
58
69from archipy .adapters .base .sqlalchemy .session_managers import (
710 AsyncBaseSQLAlchemySessionManager ,
1316from archipy .models .errors import DatabaseConnectionError
1417
1518
19+ # Patch the StarRocks type compiler to map UUID to VARCHAR at module level
20+ # This ensures the patch is applied before any engines are created
21+ def _patch_starrocks_uuid_mapping () -> None :
22+ """Patch the StarRocks type compiler to map UUID to VARCHAR.
23+
24+ StarRocks doesn't support UUID type natively, so we need to map it to VARCHAR(36).
25+ This is patched at module level to ensure it's applied before engine creation.
26+ """
27+
28+ def visit_UUID (self : StarRocksTypeCompiler , type_ : PostgresUUID , ** kw : object ) -> str : # noqa: ARG001
29+ """Map PostgreSQL UUID to VARCHAR(36) for StarRocks."""
30+ return "VARCHAR(36)"
31+
32+ # Patch the type compiler class
33+ StarRocksTypeCompiler .visit_UUID = visit_UUID # ty:ignore[invalid-assignment]
34+
35+
36+ def _patch_starrocks_now_function () -> None :
37+ """Patch the StarRocks SQL compiler to map func.now() to CURRENT_TIMESTAMP.
38+
39+ StarRocks doesn't support now() function, it requires CURRENT_TIMESTAMP instead.
40+ This is patched at module level to ensure it's applied before engine creation.
41+ """
42+ # Store original visit_function if it exists
43+ original_visit_function = getattr (StarRocksSQLCompiler , "visit_function" , None )
44+
45+ def visit_function (self : StarRocksSQLCompiler , func_ : GenericFunction , ** kw : object ) -> str :
46+ """Map func.now() to CURRENT_TIMESTAMP for StarRocks."""
47+ # Check if this is func.now()
48+ if func_ .name == "now" :
49+ return "CURRENT_TIMESTAMP"
50+ # For other functions, use the original handler if it exists
51+ if original_visit_function :
52+ return original_visit_function (self , func_ , ** kw )
53+ # Fallback to default behavior
54+ return f"{ func_ .name } ()"
55+
56+ # Patch the SQL compiler class
57+ StarRocksSQLCompiler .visit_function = visit_function # ty:ignore[invalid-assignment]
58+
59+
60+ # Apply the patches when the module is imported
61+ _patch_starrocks_uuid_mapping ()
62+ _patch_starrocks_now_function ()
63+
64+
1665class StarRocksSQlAlchemySessionManager (BaseSQLAlchemySessionManager [StarRocksSQLAlchemyConfig ], metaclass = Singleton ):
1766 """Synchronous SQLAlchemy session manager for StarRocks.
1867
@@ -122,18 +171,27 @@ def _get_database_name(self) -> str:
122171 def _create_url (self , configs : StarRocksSQLAlchemyConfig ) -> URL :
123172 """Create an async StarRocks connection URL.
124173
174+ For async operations, StarRocks requires an async driver (mysql+aiomysql)
175+ instead of the sync driver (mysql+pymysql).
176+
125177 Args:
126178 configs: StarRocks configuration.
127179
128180 Returns:
129- A SQLAlchemy URL object for StarRocks.
181+ A SQLAlchemy URL object for StarRocks with async driver .
130182
131183 Raises:
132184 DatabaseConnectionError: If there's an error creating the URL.
133185 """
134186 try :
187+ # For async operations, use mysql+aiomysql driver
188+ # If the driver is mysql+pymysql or starrocks, replace with mysql+aiomysql
189+ async_driver = configs .DRIVER_NAME
190+ if async_driver in ("mysql+pymysql" , "starrocks" , "mysql" ):
191+ async_driver = "mysql+aiomysql"
192+
135193 return URL .create (
136- drivername = configs . DRIVER_NAME ,
194+ drivername = async_driver ,
137195 username = configs .USERNAME ,
138196 password = configs .PASSWORD ,
139197 host = configs .HOST ,
@@ -144,3 +202,15 @@ def _create_url(self, configs: StarRocksSQLAlchemyConfig) -> URL:
144202 raise DatabaseConnectionError (
145203 database = self ._get_database_name (),
146204 ) from e
205+
206+ @override
207+ def _get_connect_args (self ) -> dict :
208+ """Return connection arguments for async StarRocks to ensure proper transaction support.
209+
210+ StarRocks (using MySQL protocol) requires autocommit to be explicitly disabled
211+ to ensure transactions work properly with rollback support.
212+
213+ Returns:
214+ A dictionary with autocommit=False to ensure transaction support.
215+ """
216+ return {"autocommit" : False }
0 commit comments