|
1 | 1 | from typing import Literal |
| 2 | +from urllib.parse import urlparse |
2 | 3 |
|
3 | | -from pydantic import BaseModel, Field, SecretStr |
| 4 | +from pydantic import BaseModel, Field, SecretStr, model_validator, PostgresDsn |
4 | 5 |
|
5 | 6 |
|
6 | 7 | class ElasticSearchConfig(BaseModel): |
@@ -216,6 +217,67 @@ class SqlAlchemyConfig(BaseModel): |
216 | 217 | PORT: int | None = 5432 |
217 | 218 | QUERY_CACHE_SIZE: int = 500 |
218 | 219 | USERNAME: str | None = None |
| 220 | + DB_URL: PostgresDsn | None = None |
| 221 | + |
| 222 | + @model_validator(mode="after") |
| 223 | + def build_connection_url(self) -> "SqlAlchemyConfig": |
| 224 | + """Build and populate DB_URL if not provided but all component parts are present.""" |
| 225 | + if self.DB_URL is not None: |
| 226 | + return self |
| 227 | + |
| 228 | + if all([self.USERNAME, self.HOST, self.PORT, self.DATABASE]): |
| 229 | + password_part = f":{self.PASSWORD}" if self.PASSWORD else "" |
| 230 | + url_str = f"{self.DRIVER_NAME}://{self.USERNAME}{password_part}@{self.HOST}:{self.PORT}/{self.DATABASE}" |
| 231 | + self.DB_URL = self.model_construct(DB_URL=url_str).DB_URL |
| 232 | + |
| 233 | + return self |
| 234 | + |
| 235 | + @model_validator(mode="after") |
| 236 | + def extract_connection_parts(self) -> "SqlAlchemyConfig": |
| 237 | + """Extract connection parts from DB_URL if provided but component parts are missing.""" |
| 238 | + if self.DB_URL is None: |
| 239 | + return self |
| 240 | + |
| 241 | + # Check if we need to extract components (if any are None) |
| 242 | + if any(x is None for x in [self.DRIVER_NAME, self.USERNAME, self.HOST, self.PORT, self.DATABASE]): |
| 243 | + url = str(self.DB_URL) |
| 244 | + parsed = urlparse(url) |
| 245 | + |
| 246 | + # Extract scheme/driver |
| 247 | + if self.DRIVER_NAME is None and parsed.scheme: |
| 248 | + self.DRIVER_NAME = parsed.scheme |
| 249 | + |
| 250 | + # Extract username and password |
| 251 | + if parsed.netloc: |
| 252 | + auth_part = parsed.netloc.split("@")[0] if "@" in parsed.netloc else "" |
| 253 | + if ":" in auth_part: |
| 254 | + username, password = auth_part.split(":", 1) |
| 255 | + if self.USERNAME is None: |
| 256 | + self.USERNAME = username |
| 257 | + if self.PASSWORD is None: |
| 258 | + self.PASSWORD = password |
| 259 | + elif auth_part and self.USERNAME is None: |
| 260 | + self.USERNAME = auth_part |
| 261 | + |
| 262 | + # Extract host and port |
| 263 | + host_part = parsed.netloc.split("@")[-1] if "@" in parsed.netloc else parsed.netloc |
| 264 | + if ":" in host_part: |
| 265 | + host, port_str = host_part.split(":", 1) |
| 266 | + if self.HOST is None: |
| 267 | + self.HOST = host |
| 268 | + if self.PORT is None: |
| 269 | + try: |
| 270 | + self.PORT = int(port_str) |
| 271 | + except ValueError: |
| 272 | + pass |
| 273 | + elif host_part and self.HOST is None: |
| 274 | + self.HOST = host_part |
| 275 | + |
| 276 | + # Extract database name |
| 277 | + if self.DATABASE is None and parsed.path and parsed.path.startswith("/"): |
| 278 | + self.DATABASE = parsed.path[1:] |
| 279 | + |
| 280 | + return self |
219 | 281 |
|
220 | 282 |
|
221 | 283 | class PrometheusConfig(BaseModel): |
|
0 commit comments