From 348fd715e48966d5958fd66b400253b8e426e01a Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Wed, 10 Jun 2026 11:25:59 +0200 Subject: [PATCH 01/21] Add configuration models and tests. --- pyaml/validation/models.py | 37 ++++++++++++++ tests/validation/test_models.py | 87 +++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 pyaml/validation/models.py create mode 100644 tests/validation/test_models.py diff --git a/pyaml/validation/models.py b/pyaml/validation/models.py new file mode 100644 index 000000000..978d12dbc --- /dev/null +++ b/pyaml/validation/models.py @@ -0,0 +1,37 @@ +"""Base datamodels for configuration.""" + +from pydantic import BaseModel, ConfigDict, Field + + +class PyAMLBaseModel(BaseModel): + """ + Base model for pyAML. + + Overrides ``model_dump()`` and ``model_dump_json()`` to enable + ``serialize_as_any=True`` by default. This ensures that fields are + serialized according to their runtime type rather than their declared + annotation type. + """ + + def model_dump(self, **kwargs): + kwargs.setdefault("serialize_as_any", True) + return super().model_dump(**kwargs) + + def model_dump_json(self, **kwargs): + kwargs.setdefault("serialize_as_any", True) + return super().model_dump_json(**kwargs) + + +class ConfigurationSchema(PyAMLBaseModel): + """ + Base model for configuration schemas. + + Includes mandatory fields and functionality for all schemas which is to be registered in the :class:`SchemaRegistry`. + """ + + model_config = ConfigDict(validate_by_name=True, validate_by_alias=True, arbitrary_types_allowed=False, extra="forbid") + + class_path: str = Field( + description="Fully qualified class path.", + alias="class", + ) diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py new file mode 100644 index 000000000..e39290565 --- /dev/null +++ b/tests/validation/test_models.py @@ -0,0 +1,87 @@ +"""Tests of the configuration models.""" + +import json + +import pytest +from pydantic import ValidationError +from pydantic.errors import PydanticSchemaGenerationError + +from pyaml.validation import ConfigurationSchema +from pyaml.validation.models import PyAMLBaseModel + + +def test_model_dump_serializes_subclass_fields(): + class Device(PyAMLBaseModel): + name: str + + class Magnet(Device): + type: str + + class Accelerator(PyAMLBaseModel): + device: Device + + accelerator = Accelerator(device=Magnet(name="QF", type="Quadrupole")) + + dumped = accelerator.model_dump() + + assert dumped == {"device": {"name": "QF", "type": "Quadrupole"}} + + +def test_model_dump_json_serializes_subclass_fields(): + class Device(PyAMLBaseModel): + name: str + + class Magnet(Device): + type: str + + class Accelerator(PyAMLBaseModel): + device: Device + + accelerator = Accelerator(device=Magnet(name="QF", type="Quadrupole")) + + dumped_json = accelerator.model_dump_json() + dumped = json.loads(dumped_json) + + assert dumped == {"device": {"name": "QF", "type": "Quadrupole"}} + + +def test_configuration_schema_accepts_alias_class(): + schema = ConfigurationSchema.model_validate({"class": "pkg.module.Class"}) + + assert schema.class_path == "pkg.module.Class" + + +def test_configuration_schema_accepts_field_name_class_path(): + schema = ConfigurationSchema.model_validate({"class_path": "pkg.module.Class"}) + + assert schema.class_path == "pkg.module.Class" + + +def test_configuration_schema_forbids_extra_fields(): + with pytest.raises(ValidationError) as exc_info: + ConfigurationSchema.model_validate( + { + "class": "pkg.module.Class", + "unexpected": "value", + } + ) + + assert "extra_forbidden" in str(exc_info.value) + + +def test_models_do_not_allow_arbitrary_types(): + class ArbitraryType: + pass + + with pytest.raises(PydanticSchemaGenerationError): + + class TestModel(ConfigurationSchema): + value: ArbitraryType + + +def test_configuration_schema_dump_uses_alias_when_requested(): + schema = ConfigurationSchema.model_validate({"class": "pkg.module.Class"}) + + dumped = schema.model_dump(by_alias=True) + + assert dumped == {"class": "pkg.module.Class"} From e54c4ff36ef8d0b820e70d2631832267673b83da Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Wed, 10 Jun 2026 11:44:20 +0200 Subject: [PATCH 02/21] Add schema registry and tests. --- pyaml/validation/__init__.py | 12 ++ pyaml/validation/registry.py | 345 +++++++++++++++++++++++++++++ tests/validation/test_registry.py | 347 ++++++++++++++++++++++++++++++ 3 files changed, 704 insertions(+) create mode 100644 pyaml/validation/__init__.py create mode 100644 pyaml/validation/registry.py create mode 100644 tests/validation/test_registry.py diff --git a/pyaml/validation/__init__.py b/pyaml/validation/__init__.py new file mode 100644 index 000000000..744ee008b --- /dev/null +++ b/pyaml/validation/__init__.py @@ -0,0 +1,12 @@ +""" +PyAML validation subpackage. +""" + +from .models import ConfigurationSchema +from .registry import SchemaRegistry, register_schema + +__all__ = [ + "ConfigurationSchema", + "SchemaRegistry", + "register_schema", +] diff --git a/pyaml/validation/registry.py b/pyaml/validation/registry.py new file mode 100644 index 000000000..b9f39cb69 --- /dev/null +++ b/pyaml/validation/registry.py @@ -0,0 +1,345 @@ +"""Registry for schemas.""" + +import importlib +import logging +import pkgutil +from collections.abc import ItemsView, Iterator, KeysView, ValuesView +from typing import Callable, Type, TypeVar + +from .models import ConfigurationSchema + +logger = logging.getLogger(__name__) + + +class SchemaRegistry: + """ + Singleton registry for dynamically registered schemas. + + The registry is used to validate data and produce + jsonschemas for dynamic nested models. + """ + + _instance: "SchemaRegistry | None" = None + _schemas: dict[str, Type[ConfigurationSchema]] + + def __new__(cls) -> "SchemaRegistry": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._schemas = {} + return cls._instance + + # ========================================================== + # Registration + # ========================================================== + + def register( + self, + class_path: str, + schema: type[ConfigurationSchema], + ) -> None: + """Register a schema for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + schema : type[ConfigurationSchema] + Schema class used for validation. Must inherit from + :class:`ConfigurationSchema`. + + Raises + ------ + TypeError + If ``schema`` is not a subclass of + :class:`ConfigurationSchema`. + ValueError + If a different schema has already been registered for + ``class_path``. + """ + existing = self._schemas.get(class_path) + if existing is not None and existing is not schema: + raise ValueError(f"{class_path} already registered with a different schema.") + + if not isinstance(schema, type) or not issubclass(schema, ConfigurationSchema): + raise TypeError(f"{schema!r} must inherit from ConfigurationSchema.") + + self._schemas[class_path] = schema + + def discover(self) -> None: + """Discover and register schemas. + + This imports modules in the package so classes decorated with + :func:`register_schema` are registered, then registers legacy + schemas from ``pyproject.toml``. + """ + + # Import package modules so schema registration runs. + root_package = __package__.split(".")[0] + package = importlib.import_module(root_package) + for _, module_name, _ in pkgutil.walk_packages( + package.__path__, + package.__name__ + ".", + ): + importlib.import_module(module_name) + + def unregister( + self, + class_path: str, + ) -> None: + """Unregister a schema. + + Parameters + ---------- + class_path : str + Fully qualified class path. + + Raises + ------ + KeyError + If no schema has been registered for ``class_path``. + """ + try: + del self._schemas[class_path] + + except KeyError: + raise KeyError(f"No schema registered for '{class_path}'") from None + + def clear(self) -> None: + """Remove all registered schemas. + + This clears the registry in place. + """ + self._schemas.clear() + + def __repr__( + self, + ) -> str: + """Return a string representation of the registry.""" + if not self._schemas: + return f"{self.__class__.__name__}({{}})" + + lines = [f"{self.__class__.__name__}("] + + for class_path, schema in sorted(self._schemas.items()): + lines.append(f" {class_path!r}: {schema.__module__}.{schema.__name__},") + + lines.append(")") + + return "\n".join(lines) + + # ========================================================== + # Lookup + # ========================================================== + + def __getitem__( + self, + class_path: str, + ) -> Type[ConfigurationSchema]: + """Return the registered schema for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + + Returns + ------- + Type[ConfigurationSchema] + Registered schema class. + + Raises + ------ + KeyError + If no schema has been registered for ``class_path``. + """ + + try: + return self._schemas[class_path] + + except KeyError: + raise KeyError(f"No schema registered for '{class_path}.'") from None + + def get( + self, + class_path: str, + ) -> type[ConfigurationSchema] | None: + """Return the registered schema for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + + Returns + ------- + type[ConfigurationSchema] | None + Registered schema class, or ``None`` if no schema is + registered for ``class_path``. + """ + return self._schemas.get(class_path) + + # ========================================================== + # Contents + # ========================================================== + + def __contains__( + self, + class_path: str, + ) -> bool: + """Return whether a schema is registered for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + + Returns + ------- + bool + ``True`` if a schema is registered for ``class_path``, + otherwise ``False``. + """ + return class_path in self._schemas + + def items( + self, + ) -> ItemsView[str, Type[ConfigurationSchema]]: + """Return a view of registered schema items. + + Returns + ------- + ItemsView[str, Type[ConfigurationSchema]] + View of registered ``(class_path, schema)`` pairs. + """ + return self._schemas.items() + + def keys( + self, + ) -> KeysView[str]: + """Return a view of registered class paths. + + Returns + ------- + KeysView[str] + View of registered class paths. + """ + return self._schemas.keys() + + def values( + self, + ) -> ValuesView[Type[ConfigurationSchema]]: + """Return a view of registered schemas. + + Returns + ------- + ValuesView[Type[ConfigurationSchema]] + View of registered schema classes. + """ + return self._schemas.values() + + def __len__( + self, + ) -> int: + """Return the number of registered schemas. + + Returns + ------- + int + Number of registered schemas. + """ + return len(self._schemas) + + def __iter__( + self, + ) -> Iterator[str]: + """Iterate over registered class paths. + + Returns + ------- + Iterator[str] + Iterator over registered class paths. + """ + return iter(self._schemas) + + # ========================================================== + # Updating + # ========================================================== + + def update( + self, + class_path: str, + schema: type[ConfigurationSchema], + ) -> None: + """Replace the schema registered for a class path. + + Parameters + ---------- + class_path : str + Fully qualified class path. + schema : type[ConfigurationSchema] + Schema class used for validation. Must inherit from + :class:`ConfigurationSchema`. + + Raises + ------ + TypeError + If ``schema`` is not a subclass of + :class:`ConfigurationSchema`. + KeyError + If no schema has been registered for ``class_path``. + """ + if not isinstance(schema, type) or not issubclass(schema, ConfigurationSchema): + raise TypeError(f"{schema!r} must inherit from ConfigurationSchema.") + + if class_path not in self._schemas: + raise KeyError(f"{class_path} is not registered.") + + self._schemas[class_path] = schema + + +# ========================================================== +# Decorator to register schemas +# ========================================================== + +ModelT = TypeVar("ModelT", bound=ConfigurationSchema) +ClassT = TypeVar("ClassT") + + +def register_schema( + schema: Type[ModelT], +) -> Callable[[Type[ClassT]], Type[ClassT]]: + """Register a runtime class with a Pydantic schema. + + Parameters + ---------- + schema : Type[ModelT] + Schema class to register. Must inherit from + :class:`ConfigurationSchema`. + + Returns + ------- + Callable[[Type[ClassT]], Type[ClassT]] + Decorator that registers the decorated class with ``schema``. + + Examples + -------- + >>> @register_schema(MySchema) + ... class MyClass: + ... pass + """ + + registry = SchemaRegistry() + + def decorator( + cls: Type[ClassT], + ) -> Type[ClassT]: + class_path = f"{cls.__module__}.{cls.__name__}" + + registry.register( + class_path=class_path, + schema=schema, + ) + + return cls + + return decorator diff --git a/tests/validation/test_registry.py b/tests/validation/test_registry.py new file mode 100644 index 000000000..ce68aa822 --- /dev/null +++ b/tests/validation/test_registry.py @@ -0,0 +1,347 @@ +"""Tests of the schema registry.""" + +import re +from collections.abc import Generator + +import pytest + +from pyaml.validation import ConfigurationSchema, SchemaRegistry, register_schema + +# ========================================================== +# Dummy schemas +# ========================================================== + + +class DummySchema(ConfigurationSchema): + pass + + +class OtherSchema(ConfigurationSchema): + pass + + +# ========================================================== +# Fixtures +# ========================================================== + + +@pytest.fixture(autouse=True) +def clear_registry() -> Generator[None, None, None]: + """Ensure the registry is empty for each test.""" + + registry = SchemaRegistry() + registry.clear() + yield + registry.clear() + + +@pytest.fixture +def registry() -> SchemaRegistry: + return SchemaRegistry() + + +# ========================================================== +# Singleton behaviour +# ========================================================== + + +def test_singleton_returns_same_instance(): + assert SchemaRegistry() is SchemaRegistry() + + +# ========================================================== +# Registration +# ========================================================== + + +def test_register_stores_schema(registry: SchemaRegistry): + registry.register("pkg.module.Class", DummySchema) + + assert registry["pkg.module.Class"] is DummySchema + + +def test_register_allows_same_schema_for_existing_class_path(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + registry.register(class_path, DummySchema) + registry.register(class_path, DummySchema) + + assert registry[class_path] is DummySchema + assert len(registry) == 1 + + +def test_register_raises_valueerror_for_different_schema(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + registry.register(class_path, DummySchema) + + with pytest.raises( + ValueError, + match=re.escape(f"{class_path} already registered with a different schema."), + ): + registry.register(class_path, OtherSchema) + + +def test_register_raises_typeerror_for_invalid_schema(registry: SchemaRegistry): + with pytest.raises( + TypeError, + match=re.escape("must inherit from ConfigurationSchema"), + ): + registry.register( + "pkg.module.Class", + object, # type: ignore[arg-type] + ) + + +# ========================================================== +# Unregistering +# ========================================================== + + +def test_unregister_removes_registered_schema(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + registry.register(class_path, DummySchema) + + registry.unregister(class_path) + + assert class_path not in registry + + +def test_unregister_raises_clean_keyerror_for_missing_schema( + registry: SchemaRegistry, +): + class_path = "pkg.module.Class" + + with pytest.raises( + KeyError, + match=re.escape(f"No schema registered for '{class_path}'"), + ): + registry.unregister(class_path) + + +def test_unregister_removes_only_requested_schema(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + registry.unregister("pkg.module.ClassA") + + assert "pkg.module.ClassA" not in registry + assert registry["pkg.module.ClassB"] is OtherSchema + assert len(registry) == 1 + + +# ========================================================== +# Clearing +# ========================================================== + + +def test_clear_removes_all_registered_schemas(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + registry.clear() + + assert len(registry) == 0 + assert "pkg.module.ClassA" not in registry + assert "pkg.module.ClassB" not in registry + + +def test_clear_on_empty_registry_keeps_registry_empty( + registry: SchemaRegistry, +): + registry.clear() + + assert len(registry) == 0 + + +def test_clear_allows_new_registrations_afterwards( + registry: SchemaRegistry, +): + registry.register("pkg.module.Class", DummySchema) + + registry.clear() + + registry.register("pkg.module.OtherClass", OtherSchema) + + assert len(registry) == 1 + assert registry["pkg.module.OtherClass"] is OtherSchema + + +# ========================================================== +# Lookup +# ========================================================== + + +def test_getitem_raises_clean_keyerror_for_missing_schema(registry: SchemaRegistry): + with pytest.raises(KeyError, match=r"No schema registered for 'pkg\.module\.Class.'"): + _ = registry["pkg.module.Class"] + + +def test_get_returns_registered_schema(registry: SchemaRegistry): + registry.register("pkg.module.Class", DummySchema) + + assert registry.get("pkg.module.Class") is DummySchema + + +def test_get_returns_none_for_missing_schema(registry: SchemaRegistry): + assert registry.get("pkg.module.Class") is None + + +# ========================================================== +# Contents +# ========================================================== + + +def test_items_returns_registered_items(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + items = registry.items() + + assert ("pkg.module.ClassA", DummySchema) in items + assert ("pkg.module.ClassB", OtherSchema) in items + assert len(items) == 2 + + +def test_keys_returns_registered_class_paths(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + keys = registry.keys() + + assert "pkg.module.ClassA" in keys + assert "pkg.module.ClassB" in keys + assert len(keys) == 2 + + +def test_values_returns_registered_schemas(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + values = registry.values() + + assert DummySchema in values + assert OtherSchema in values + assert len(values) == 2 + + +def test_iter_returns_registered_class_paths(registry: SchemaRegistry): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + class_paths = list(iter(registry)) + + assert "pkg.module.ClassA" in class_paths + assert "pkg.module.ClassB" in class_paths + assert len(class_paths) == 2 + + +# ========================================================== +# Updating +# ========================================================== + + +def test_update_replaces_registered_schema(registry: SchemaRegistry): + registry.register("pkg.module.Class", DummySchema) + + registry.update("pkg.module.Class", OtherSchema) + + assert registry["pkg.module.Class"] is OtherSchema + + +def test_update_raises_keyerror_for_missing_schema(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + with pytest.raises( + KeyError, + match=re.escape(f"{class_path} is not registered."), + ): + registry.update(class_path, DummySchema) + + +def test_update_raises_typeerror_for_invalid_schema(registry: SchemaRegistry): + with pytest.raises( + TypeError, + match=r"must inherit from ConfigurationSchema", + ): + registry.update( + "pkg.module.Class", + object, # type: ignore[arg-type] + ) + + +# ========================================================== +# Representation +# ========================================================== + + +def test_repr_returns_empty_registry_representation( + registry: SchemaRegistry, +): + assert repr(registry) == "SchemaRegistry({})" + + +def test_repr_returns_registered_schemas( + registry: SchemaRegistry, +): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + result = repr(registry) + + assert result.startswith("SchemaRegistry(") + assert "'pkg.module.ClassA'" in result + assert "'pkg.module.ClassB'" in result + + assert f"{DummySchema.__module__}.{DummySchema.__name__}" in result + assert f"{OtherSchema.__module__}.{OtherSchema.__name__}" in result + + assert result.endswith(")") + + +def test_repr_sorts_registered_class_paths( + registry: SchemaRegistry, +): + registry.register("pkg.module.ZClass", DummySchema) + registry.register("pkg.module.AClass", OtherSchema) + + result = repr(registry) + + assert result.index("'pkg.module.AClass'") < result.index("'pkg.module.ZClass'") + + +# ========================================================== +# Register schema decorator +# ========================================================== + + +def test_register_schema_registers_the_decorated_class( + registry: SchemaRegistry, +): + @register_schema(DummySchema) + class DecoratedClass: + pass + + class_path = f"{DecoratedClass.__module__}.{DecoratedClass.__name__}" + + assert registry[class_path] is DummySchema + + +def test_register_schema_can_register_multiple_classes_with_same_schema( + registry: SchemaRegistry, +): + @register_schema(DummySchema) + class FirstClass: + pass + + @register_schema(DummySchema) + class SecondClass: + pass + + first_path = f"{FirstClass.__module__}.{FirstClass.__name__}" + second_path = f"{SecondClass.__module__}.{SecondClass.__name__}" + + assert registry[first_path] is DummySchema + assert registry[second_path] is DummySchema + assert len(registry) == 2 From 32184519e8aa364908154f8fe1010374a7046058 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Wed, 10 Jun 2026 16:09:02 +0200 Subject: [PATCH 03/21] Add schema validator. --- pyaml/validation/__init__.py | 2 + pyaml/validation/validator.py | 144 ++++++++++++++++++++++ tests/validation/test_validator.py | 192 +++++++++++++++++++++++++++++ 3 files changed, 338 insertions(+) create mode 100644 pyaml/validation/validator.py create mode 100644 tests/validation/test_validator.py diff --git a/pyaml/validation/__init__.py b/pyaml/validation/__init__.py index 744ee008b..ab02d64d1 100644 --- a/pyaml/validation/__init__.py +++ b/pyaml/validation/__init__.py @@ -4,9 +4,11 @@ from .models import ConfigurationSchema from .registry import SchemaRegistry, register_schema +from .validator import SchemaValidator __all__ = [ "ConfigurationSchema", "SchemaRegistry", + "SchemaValidator", "register_schema", ] diff --git a/pyaml/validation/validator.py b/pyaml/validation/validator.py new file mode 100644 index 000000000..da5ed5500 --- /dev/null +++ b/pyaml/validation/validator.py @@ -0,0 +1,144 @@ +"""Module for schema validation.""" + +import logging +import warnings +from typing import Any + +from pydantic import ValidationError + +from .models import ConfigurationSchema +from .registry import SchemaRegistry + +logger = logging.getLogger(__name__) + + +class SchemaValidator: + """Recursive validator for configuration dictionaries. + + The validator traverses nested configuration data structures and + converts dictionaries representing configuration objects into + validated Pydantic schema models. + + Validation is performed recursively: + + - Lists are traversed element-by-element + - Dictionaries are recursively validated + - Dictionaries matching configuration schemas are converted into + validated schema models + - Dictionaries with unknown schemas are left unchanged + + Schema lookup is performed through the :class:`SchemaRegistry`. + """ + + _registry = SchemaRegistry() + + @classmethod + def validate( + cls, + data: dict[str, Any], + ) -> ConfigurationSchema: + """Validate configuration data recursively. + + Parameters + ---------- + data : dict[str, Any] + Configuration dictionary to validate. + + Returns + ------- + ConfigurationSchema + Fully validated top-level configuration model. + + Raises + ------ + TypeError + If the validated top-level object is not a + :class:`ConfigurationSchema`. + """ + validated = cls._recursive_validate(data) + + if not isinstance(validated, ConfigurationSchema): + raise TypeError("Top-level configuration did not validate to a ConfigurationSchema.") + + return validated + + @classmethod + def _recursive_validate(cls, obj: Any) -> Any: + """Recursively validate nested configuration objects. + + Lists are traversed recursively element-by-element. Dictionaries + are recursively traversed and then interpreted as configuration + objects when possible. + + If a dictionary corresponds to a registered configuration schema, + it is converted into a validated schema model. Otherwise, the + dictionary is returned unchanged. + + Parameters + ---------- + obj : Any + Object to validate recursively. + + Returns + ------- + Any + Validated object. This may be: + + - A validated configuration model + - A recursively validated list + - A recursively validated dictionary + - The original object if no validation applies + """ + if isinstance(obj, list): + return [cls._recursive_validate(item) for item in obj] + + if not isinstance(obj, dict): + return obj + + logger.debug("Validating dict with keys: %s", list(obj)) + validated_dict = {key: cls._recursive_validate(value) for key, value in obj.items()} + + # Check if the dict is a configuration object + config = cls._parse_configuration(validated_dict) + if config is None: + return validated_dict + + class_path = config.class_path + schema = cls._registry.get(class_path) + + if schema is None: + warnings.warn( + f"Unknown schema for '{class_path}' so cannot validate. Leaving data as raw dict.", + stacklevel=2, + ) + return validated_dict + + return schema.model_validate(validated_dict) + + @classmethod + def _parse_configuration( + cls, + validated_dict: dict[str, Any], + ) -> ConfigurationSchema | None: + """Parse a dictionary as configuration metadata. + + Parameters + ---------- + validated_dict : dict[str, Any] + Dictionary to interpret as configuration metadata. + + Returns + ------- + ConfigurationSchema | None + Parsed configuration model if validation succeeds, + otherwise ``None``. + """ + try: + return ConfigurationSchema.model_validate( + validated_dict, + extra="allow", + ) + except ValidationError: + logger.debug("Could not validate against ConfigurationSchema.") + + return None diff --git a/tests/validation/test_validator.py b/tests/validation/test_validator.py new file mode 100644 index 000000000..e7eec6e95 --- /dev/null +++ b/tests/validation/test_validator.py @@ -0,0 +1,192 @@ +"""Tests of the schema validator.""" + +from collections.abc import Generator + +import pytest + +from pyaml.validation import ( + ConfigurationSchema, + SchemaRegistry, + SchemaValidator, +) + +# ========================================================== +# Dummy schemas +# ========================================================== + + +class DummySchema(ConfigurationSchema): + value: int | None = None + + +class OtherSchema(ConfigurationSchema): + name: str | None = None + children: list[DummySchema] | None = None + + +# ========================================================== +# Fixtures +# ========================================================== + + +@pytest.fixture(autouse=True) +def clear_registry() -> Generator[None, None, None]: + """Ensure the registry is empty for each test.""" + + registry = SchemaRegistry() + registry.clear() + yield + registry.clear() + + +@pytest.fixture +def registry() -> SchemaRegistry: + return SchemaRegistry() + + +# ========================================================== +# Recursive validation +# ========================================================== + + +def test_recursive_validate_returns_validated_schema( + registry: SchemaRegistry, +): + registry.register("pkg.module.Class", DummySchema) + + data = { + "class_path": "pkg.module.Class", + "value": 42, + } + + result = SchemaValidator._recursive_validate(data) + + assert isinstance(result, DummySchema) + assert result.class_path == "pkg.module.Class" + assert result.value == 42 + + +def test_recursive_validate_recurses_through_nested_lists_and_dicts( + registry: SchemaRegistry, +): + registry.register("pkg.module.ClassA", DummySchema) + registry.register("pkg.module.ClassB", OtherSchema) + + data = { + "class_path": "pkg.module.ClassB", + "name": "dummy", + "children": [ + { + "class_path": "pkg.module.ClassA", + "value": "42", + }, + { + "class_path": "pkg.module.ClassA", + "value": "73", + }, + ], + } + + result = SchemaValidator._recursive_validate(data) + + assert isinstance(result, OtherSchema) + + assert isinstance(result.children[0], DummySchema) + assert result.children[0].value == 42 + + assert isinstance(result.children[1], DummySchema) + assert result.children[1].value == 73 + + +def test_recursive_validate_leaves_plain_dicts_unchanged(): + data = { + "plain": "dict", + } + + result = SchemaValidator._recursive_validate(data) + + assert result == data + + +def test_recursive_validate_leaves_non_container_values_unchanged(): + assert SchemaValidator._recursive_validate("text") == "text" + assert SchemaValidator._recursive_validate(123) == 123 + assert SchemaValidator._recursive_validate(True) is True + assert SchemaValidator._recursive_validate(None) is None + + +def test_recursive_validate_warns_for_unknown_schema( + registry: SchemaRegistry, +): + data = { + "class_path": "pkg.module.Unknown", + "value": 42, + } + + with pytest.warns( + UserWarning, + match=r"Unknown schema for 'pkg\.module\.Unknown' so cannot validate\. Leaving data as raw dict\.", + ): + result = SchemaValidator._recursive_validate(data) + + assert result == data + + +# ========================================================== +# Configuration parsing +# ========================================================== + + +def test_parse_configuration_returns_configuration_schema(): + data = { + "class_path": "pkg.module.Class", + } + + result = SchemaValidator._parse_configuration(data) + + assert isinstance(result, ConfigurationSchema) + assert result.class_path == "pkg.module.Class" + + +def test_parse_configuration_returns_none_for_non_configuration_dict(): + data = { + "plain": "dict", + } + + result = SchemaValidator._parse_configuration(data) + + assert result is None + + +# ========================================================== +# Top-level validation +# ========================================================== + + +def test_validate_returns_validated_configuration_schema( + registry: SchemaRegistry, +): + registry.register("pkg.module.Class", DummySchema) + + data = { + "class_path": "pkg.module.Class", + "value": 42, + } + + result = SchemaValidator.validate(data) + + assert isinstance(result, DummySchema) + assert result.class_path == "pkg.module.Class" + assert result.value == 42 + + +def test_validate_raises_typeerror_for_non_configuration_dict(): + data = { + "plain": "dict", + } + + with pytest.raises( + TypeError, + match=r"Top-level configuration did not validate to a ConfigurationSchema\.", + ): + SchemaValidator.validate(data) From 61f57c4e0178bf254e60ae4191476dc0c2ce7441 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Wed, 10 Jun 2026 16:39:29 +0200 Subject: [PATCH 04/21] Add schema generator and tests. --- pyaml/validation/__init__.py | 2 + pyaml/validation/generator.py | 290 +++++++++++++++++++++++++++++ tests/validation/test_generator.py | 204 ++++++++++++++++++++ 3 files changed, 496 insertions(+) create mode 100644 pyaml/validation/generator.py create mode 100644 tests/validation/test_generator.py diff --git a/pyaml/validation/__init__.py b/pyaml/validation/__init__.py index ab02d64d1..54710e235 100644 --- a/pyaml/validation/__init__.py +++ b/pyaml/validation/__init__.py @@ -2,6 +2,7 @@ PyAML validation subpackage. """ +from .generator import SchemaGenerator from .models import ConfigurationSchema from .registry import SchemaRegistry, register_schema from .validator import SchemaValidator @@ -10,5 +11,6 @@ "ConfigurationSchema", "SchemaRegistry", "SchemaValidator", + "SchemaGenerator", "register_schema", ] diff --git a/pyaml/validation/generator.py b/pyaml/validation/generator.py new file mode 100644 index 000000000..8cdaeced0 --- /dev/null +++ b/pyaml/validation/generator.py @@ -0,0 +1,290 @@ +"""Module for generating JSON Schema from registered configuration schemas.""" + +import json +import logging +from copy import deepcopy +from pathlib import Path +from typing import Any + +from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue +from pydantic_core import core_schema + +from .models import ConfigurationSchema +from .registry import SchemaRegistry + +logger = logging.getLogger(__name__) + + +METADATA_KEYS = ( + "title", + "description", + "examples", + "deprecated", + "readOnly", + "writeOnly", +) + +CLASS_ALIAS = "class" + + +class SchemaGenerator: + """ + Generate JSON Schemas for registered configuration models. + + This class provides convenience methods for generating and exporting + JSON Schemas from models registered in the ``SchemaRegistry``. Schema + generation is delegated to a custom Pydantic JSON Schema generator + that adds support for registry-aware polymorphism. + + Configuration base classes with registered subclasses are represented + as ``oneOf`` unions over their concrete implementations, allowing + generated schemas to describe all valid registered configuration types. + + Primitive unions such as ``str | None`` are emitted using compact + ``type: [...]`` representations when supported by Pydantic. + """ + + _registry = SchemaRegistry() + + @classmethod + def generate(cls, class_path: str) -> dict[str, Any]: + """ + Generate a JSON Schema for a registered configuration schema. + + The schema is generated using a custom Pydantic JSON Schema generator + that expands registered configuration subclasses into ``oneOf`` unions + and preserves compact representations for primitive unions such as + ``str | None``. + + Parameters + ---------- + class_path : str + Registry key identifying the configuration schema class. + + Returns + ------- + dict[str, Any] + Generated JSON Schema for the requested configuration schema. + + Raises + ------ + KeyError + If no schema is registered for the given class path. + """ + + schema_cls = cls._registry.get(class_path) + + logger.debug("Generating schema for %s.", schema_cls) + + if schema_cls is None: + raise KeyError(f"No schema registered for '{class_path}'") + + return schema_cls.model_json_schema( + by_alias=True, + union_format="primitive_type_array", + schema_generator=RegistryJsonSchema, + ) + + @classmethod + def save( + cls, + class_path: str, + filename: str | Path, + *, + indent: int = 2, + ) -> Path: + """ + Generate JSON Schema and save it to a file. + + Parameters + ---------- + class_path : str + Registered class path to generate schema for. + filename : str or Path + Output filename. + indent : int, optional + JSON indentation level. Default: 2. + + Returns + ------- + Path + Path to the written file. + """ + schema = cls.generate(class_path) + + path = Path(filename) + + with path.open("w", encoding="utf-8") as file: + json.dump(schema, file, indent=indent) + + return path + + +class RegistryJsonSchema(GenerateJsonSchema): + """ + Custom Pydantic JSON Schema generator for configuration schemas. + + This generator extends the default Pydantic schema generation to support + registry-aware polymorphism for ``ConfigurationSchema`` subclasses. + + For configuration base classes with registered subclasses, the generated + schema is replaced by a ``oneOf`` union over all registered concrete + subclasses. Human-facing schema metadata such as titles and descriptions + are preserved from the original schema. + + In addition, all generated schema unions are normalized to use ``oneOf`` + instead of ``anyOf`` for improved compatibility with downstream tooling. + Primitive unions such as ``str | None`` continue to use compact + ``type: [...]`` representations when supported by Pydantic. + """ + + _registry = SchemaRegistry() + + def model_schema(self, schema: core_schema.ModelSchema) -> dict[str, Any]: + """ + Generate a JSON Schema for a Pydantic model. + + For ``ConfigurationSchema`` subclasses, the generated schema may be + transformed into a polymorphic schema based on the registered schema + registry: + + - If the model defines a ``class`` field, all registered aliases + corresponding to the model are added as allowed literal values. + - If registered subclasses exist, the schema is replaced by an + ``anyOf`` union containing the schemas of all registered subclasses. + + Metadata fields from the original schema, such as titles and + descriptions, are preserved in the merged schema. + + Parameters + ---------- + schema : core_schema.ModelSchema + Pydantic core schema describing the model. + + Returns + ------- + dict[str, Any] + Generated JSON Schema for the model or polymorphic union schema. + + Notes + ----- + The generated polymorphic schema uses ``anyOf`` instead of ``oneOf`` + because nested ``oneOf`` unions may lead to ambiguous validation in + downstream JSON Schema tooling when subclass schemas contain nullable + or overlapping branches. + """ + + base_schema = super().model_schema(schema) + model_cls = schema.get("cls") + logging.debug(f"Base schema is extracted from {model_cls}.") + + if not isinstance(model_cls, type) or not issubclass(model_cls, ConfigurationSchema): + return base_schema + + # If the baseschema has a class field, add literal for all keys. + properties = base_schema.get("properties") + if isinstance(properties, dict) and CLASS_ALIAS in properties and isinstance(properties[CLASS_ALIAS], dict): + logging.debug(f"Adding list of classes to: {model_cls}.") + + # Find keys that correspond to the same schema + base_keys = sorted(key for key, schema_cls in self._registry.items() if schema_cls is model_cls) + + base_schema = deepcopy(base_schema) + properties = base_schema["properties"] + self._add_literals_to_class_path(properties[CLASS_ALIAS], base_keys) + + # Get subclasses in registry sorted by module name + subclasses = sorted( + { + schema_cls + for _, schema_cls in self._registry.items() + if isinstance(schema_cls, type) and issubclass(schema_cls, model_cls) and schema_cls is not model_cls + }, + key=lambda cls: f"{cls.__module__}.{cls.__name__}", + ) + logging.debug(f"Subclasses found in registry: {subclasses}.") + + if not subclasses: + return base_schema + + # Generate schemas of subclasses + subschemas = [self.generate_inner(item.__pydantic_core_schema__) for item in subclasses] + + # TODO: get the schemas to work when using oneOf instead + merged: dict[str, Any] = {"anyOf": subschemas} + + for key in METADATA_KEYS: + if key in base_schema and key not in merged: + merged[key] = deepcopy(base_schema[key]) + + return merged + + def get_union_of_schemas(self, schemas: list[JsonSchemaValue]) -> JsonSchemaValue: + """ + Combine multiple JSON Schemas into a union schema. + + This override normalizes generated union schemas to use ``oneOf`` + instead of ``anyOf`` for improved downstream compatibility. + + Parameters + ---------- + schemas : list[JsonSchemaValue] + JSON Schemas to combine into a union. + + Returns + ------- + JsonSchemaValue + Combined union schema. + """ + + schema = super().get_union_of_schemas(schemas) + logging.debug(f"Modifying union schema for {schema}.") + + # TODO: make is possible to use oneOf + # if "anyOf" in schema: + # schema = deepcopy(schema) + # schema["oneOf"] = schema.pop("anyOf") + + return schema + + @staticmethod + def _add_literals_to_class_path(schema: dict[str, Any], literals: list[str]) -> None: + """ + Add allowed literal values to a JSON Schema string field. + + The provided literals are merged with any existing ``enum`` values in + the schema while preserving insertion order and removing duplicates. + + If the resulting set contains only a single value, the schema is + simplified by replacing ``enum`` with ``const``. + + The schema is modified in place. + + Parameters + ---------- + schema : dict[str, Any] + JSON Schema fragment representing a string-like field. + literals : list[str] + Literal values to add to the schema. + + Notes + ----- + Only schemas representing string values or existing enumerations are + modified. Empty literal lists are ignored. + """ + + if not literals: + return + + # Add registry keys as literals + if schema.get("type") == "string" or "enum" in schema: + existing = schema.get("enum", []) + merged = list(dict.fromkeys([*existing, *literals])) + schema["enum"] = merged + + # If only one value exists use const + if len(merged) == 1: + schema["const"] = merged[0] + schema.pop("enum", None) + + return diff --git a/tests/validation/test_generator.py b/tests/validation/test_generator.py new file mode 100644 index 000000000..e32521f96 --- /dev/null +++ b/tests/validation/test_generator.py @@ -0,0 +1,204 @@ +"""Tests of the schema generator.""" + +import json +import re +from collections.abc import Generator +from pathlib import Path + +import pytest +from pydantic import Field + +from pyaml.validation import ( + ConfigurationSchema, + SchemaGenerator, + SchemaRegistry, +) +from pyaml.validation.generator import RegistryJsonSchema + +# ========================================================== +# Dummy schemas +# ========================================================== + + +class DummySchema(ConfigurationSchema): + pass + + +class OtherSchema(ConfigurationSchema): + pass + + +class ParentSchema(ConfigurationSchema): + """Parent schema used to test inheritance.""" + + pass + + +class ChildSchemaA(ParentSchema): + a: int = 1 + + +class ChildSchemaB(ParentSchema): + b: str = "x" + + +class ContainerSchema(ConfigurationSchema): + model: ChildSchemaA | None = Field( + default=None, + description="Container schema used for testing.", + ) + + +# ========================================================== +# Fixtures +# ========================================================== + + +@pytest.fixture(autouse=True) +def clear_registry() -> Generator[None, None, None]: + """Ensure the registry is empty for each test.""" + + registry = SchemaRegistry() + registry.clear() + yield + registry.clear() + + +@pytest.fixture +def registry() -> SchemaRegistry: + return SchemaRegistry() + + +# ========================================================== +# Generate +# ========================================================== + + +def test_generate_raises_clean_keyerror_for_missing_schema(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + with pytest.raises( + KeyError, + match=re.escape(f"No schema registered for '{class_path}'"), + ): + SchemaGenerator.generate(class_path) + + +def test_generate_returns_schema_for_registered_class(registry: SchemaRegistry): + class_path = "pkg.module.Class" + + registry.register(class_path, DummySchema) + + schema = SchemaGenerator.generate(class_path) + + assert schema["title"] == "DummySchema" + + +# ========================================================== +# Save +# ========================================================== + + +def test_save_writes_schema_to_file(registry: SchemaRegistry, tmp_path: Path): + class_path = "pkg.module.Class" + registry.register(class_path, DummySchema) + + filename = tmp_path / "schema.json" + + result = SchemaGenerator.save(class_path, filename, indent=2) + + assert result == filename + assert json.loads(filename.read_text(encoding="utf-8")) == SchemaGenerator.generate(class_path) + + +# ========================================================== +# Registry-aware polymorphism +# ========================================================== + + +def test_generate_replaces_parent_schema_with_registered_subclasses( + registry: SchemaRegistry, +): + registry.register("pkg.module.Parent", ParentSchema) + registry.register("pkg.module.ChildA", ChildSchemaA) + registry.register("pkg.module.ChildB", ChildSchemaB) + + schema = SchemaGenerator.generate("pkg.module.Parent") + + child_refs = {item["$ref"] for item in schema.get("anyOf", [])} + + assert "#/$defs/ChildSchemaA" in child_refs + assert "#/$defs/ChildSchemaB" in child_refs + + +def test_model_schema_preserves_metadata_from_parent_schema( + registry: SchemaRegistry, +): + registry.register("pkg.module.Parent", ParentSchema) + registry.register("pkg.module.ChildA", ChildSchemaA) + registry.register("pkg.module.ChildB", ChildSchemaB) + + generator = RegistryJsonSchema() + base_schema = ParentSchema.__pydantic_core_schema__ + + schema = generator.model_schema(base_schema) + + assert schema["title"] == "ParentSchema" + + +# ========================================================== +# Literals +# ========================================================== + + +def test_add_literals_to_class_path_ignores_empty_literal_list(): + schema = {"type": "string"} + + RegistryJsonSchema._add_literals_to_class_path(schema, []) + + assert schema == {"type": "string"} + + +def test_add_literals_to_class_path_replaces_single_value_with_const(): + schema = {"type": "string"} + + RegistryJsonSchema._add_literals_to_class_path( + schema, + ["pkg.module.Parent"], + ) + + assert schema["const"] == "pkg.module.Parent" + assert "enum" not in schema + + +def test_add_literals_to_class_path_merges_existing_enum_and_removes_duplicates(): + schema = {"type": "string", "enum": ["pkg.module.Parent", "pkg.module.ChildA"]} + + RegistryJsonSchema._add_literals_to_class_path( + schema, + ["pkg.module.ChildA", "pkg.module.ChildB", "pkg.module.Parent"], + ) + + assert schema["enum"] == [ + "pkg.module.Parent", + "pkg.module.ChildA", + "pkg.module.ChildB", + ] + assert "const" not in schema + + +def test_add_literals_to_class_path_does_not_modify_non_string_schema_without_enum(): + schema = {"type": "integer"} + + RegistryJsonSchema._add_literals_to_class_path(schema, ["pkg.module.Parent"]) + + assert schema == {"type": "integer"} + + +def test_add_literals_to_class_path_updates_existing_enum_even_without_string_type(): + schema = {"enum": ["pkg.module.Parent"]} + + RegistryJsonSchema._add_literals_to_class_path(schema, ["pkg.module.ChildA"]) + + assert schema["enum"] == ["pkg.module.Parent", "pkg.module.ChildA"] + assert "const" not in schema From 1192097c6aeb4957b2eee6b7e251a60da7ca2a36 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Wed, 10 Jun 2026 16:48:39 +0200 Subject: [PATCH 05/21] Remove unused parts from schema generator. --- pyaml/validation/generator.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/pyaml/validation/generator.py b/pyaml/validation/generator.py index 8cdaeced0..874ee4afa 100644 --- a/pyaml/validation/generator.py +++ b/pyaml/validation/generator.py @@ -219,34 +219,6 @@ def model_schema(self, schema: core_schema.ModelSchema) -> dict[str, Any]: return merged - def get_union_of_schemas(self, schemas: list[JsonSchemaValue]) -> JsonSchemaValue: - """ - Combine multiple JSON Schemas into a union schema. - - This override normalizes generated union schemas to use ``oneOf`` - instead of ``anyOf`` for improved downstream compatibility. - - Parameters - ---------- - schemas : list[JsonSchemaValue] - JSON Schemas to combine into a union. - - Returns - ------- - JsonSchemaValue - Combined union schema. - """ - - schema = super().get_union_of_schemas(schemas) - logging.debug(f"Modifying union schema for {schema}.") - - # TODO: make is possible to use oneOf - # if "anyOf" in schema: - # schema = deepcopy(schema) - # schema["oneOf"] = schema.pop("anyOf") - - return schema - @staticmethod def _add_literals_to_class_path(schema: dict[str, Any], literals: list[str]) -> None: """ From a0da9912dfb42fd11f86bb1c63417b08cbed3e39 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Wed, 10 Jun 2026 16:56:41 +0200 Subject: [PATCH 06/21] Add error of no class is given in the register schema decorator. --- pyaml/validation/registry.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyaml/validation/registry.py b/pyaml/validation/registry.py index b9f39cb69..ea1239f75 100644 --- a/pyaml/validation/registry.py +++ b/pyaml/validation/registry.py @@ -328,6 +328,9 @@ def register_schema( ... pass """ + if not (isinstance(schema, type) and issubclass(schema, ConfigurationSchema)): + raise TypeError("register_schema must be called with a schema class, e.g. @register_schema(MySchema)") + registry = SchemaRegistry() def decorator( @@ -335,6 +338,8 @@ def decorator( ) -> Type[ClassT]: class_path = f"{cls.__module__}.{cls.__name__}" + logger.debug("Register schema for %s.", class_path) + registry.register( class_path=class_path, schema=schema, From ead20b0c113bf0ec86864641e236c24938298e2e Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Mon, 15 Jun 2026 15:05:27 +0200 Subject: [PATCH 07/21] Added validation at object creation. --- pyaml/validation/__init__.py | 4 +- pyaml/validation/models.py | 186 +++++++++++++++++++++++++++++++- tests/validation/test_models.py | 162 +++++++++++++++++++++++++++- 3 files changed, 346 insertions(+), 6 deletions(-) diff --git a/pyaml/validation/__init__.py b/pyaml/validation/__init__.py index 54710e235..1d90a08b6 100644 --- a/pyaml/validation/__init__.py +++ b/pyaml/validation/__init__.py @@ -3,14 +3,16 @@ """ from .generator import SchemaGenerator -from .models import ConfigurationSchema +from .models import ConfigurationSchema, DynamicValidation, StaticValidation from .registry import SchemaRegistry, register_schema from .validator import SchemaValidator __all__ = [ "ConfigurationSchema", + "DynamicValidation", "SchemaRegistry", "SchemaValidator", "SchemaGenerator", + "StaticValidation", "register_schema", ] diff --git a/pyaml/validation/models.py b/pyaml/validation/models.py index 978d12dbc..e561e509e 100644 --- a/pyaml/validation/models.py +++ b/pyaml/validation/models.py @@ -1,6 +1,12 @@ """Base datamodels for configuration.""" -from pydantic import BaseModel, ConfigDict, Field +import inspect +import logging +from typing import Any, get_type_hints + +from pydantic import BaseModel, ConfigDict, Field, create_model + +logger = logging.getLogger(__name__) class PyAMLBaseModel(BaseModel): @@ -26,7 +32,7 @@ class ConfigurationSchema(PyAMLBaseModel): """ Base model for configuration schemas. - Includes mandatory fields and functionality for all schemas which is to be registered in the :class:`SchemaRegistry`. + Provides common fields and functionality for schemas which are to be registered in the :class:`SchemaRegistry`. """ model_config = ConfigDict(validate_by_name=True, validate_by_alias=True, arbitrary_types_allowed=False, extra="forbid") @@ -35,3 +41,179 @@ class ConfigurationSchema(PyAMLBaseModel): description="Fully qualified class path.", alias="class", ) + + +class ValidationSchema(PyAMLBaseModel): + """ + Base model for validation schemas. + + Provides common fields and functionality for schemas used to validate arguments during object creation. + """ + + model_config = ConfigDict(arbitrary_types_allowed=False, extra="forbid") + + +class ValidationMeta(type): + """ + Metaclass that validates constructor arguments using a Pydantic model. + + Classes using this metaclass must define a ``validation_model`` + attribute containing a subclass of :class:`pydantic.BaseModel`. + Before an instance is created, the supplied arguments are bound to + the ``__init__`` signature and validated against the model. + + Both positional and keyword arguments are validated before + ``__init__`` is executed. + """ + + def __call__(cls, *args: Any, **kwargs: Any): + """ + Create an instance after validating constructor arguments. + + The supplied arguments are bound to the class ``__init__`` signature, + default values are applied, and the resulting argument mapping is + validated using ``validation_model``. The validated values are then + passed to the constructor. + + Raises + ------ + TypeError + If the class does not define ``validation_model``. + + ValidationError + If the supplied arguments do not conform to the validation + model. + """ + + validation_model = getattr(cls, "validation_model", None) + + if validation_model is None: + raise TypeError(f"{cls.__name__} must define validation_model.") + + # Inspect the signature of the class + signature = inspect.signature(cls.__init__) + + # Map arguments to parameters + bound = signature.bind(None, *args, **kwargs) + + # Include default arguments + bound.apply_defaults() + + # Remove self from list + bound.arguments.pop("self", None) + arguments = dict(bound.arguments) + + # Validate the model + logger.debug("Validating input against schema: %s", validation_model.model_fields) + validated = validation_model.model_validate(arguments) + + # Return the object + return super().__call__(**validated.model_dump()) + + +class DynamicValidation(metaclass=ValidationMeta): + """ + Base class that generates a validation schema from the constructor + signature. + + When a subclass is defined, a schema derived from + :class:`ValidationSchema` is generated automatically and assigned to + ``validation_model``. The generated schema is used by + :class:`ValidationMeta` to validate constructor arguments before + instance creation. + + Subclasses must not define ``validation_model`` manually. + """ + + validation_model: type[ValidationSchema] | None = None + + def __init_subclass__(cls, **kwargs): + """ + Generate and attach a validation schema for the subclass. + + A schema derived from :class:`ValidationSchema` is created from the + subclass's ``__init__`` signature and assigned to + ``validation_model``. Defining ``validation_model`` explicitly is + not permitted and results in a :class:`TypeError`. + """ + + super().__init_subclass__(**kwargs) + + if getattr(cls, "validation_model", None) is not None: + raise TypeError(f"{cls.__name__} may not define validation_model manually.") + + cls.validation_model = cls._build_validation_model() + + @classmethod + def _build_validation_model(cls) -> type[ValidationSchema]: + """ + Build a validation schema from the constructor signature. + + The generated schema contains one field for each parameter in the + subclass's ``__init__`` method, excluding ``self``, ``*args`` and + ``**kwargs``. Field types are obtained from the constructor's type + annotations and default values are preserved. + + Returns + ------- + type[ValidationSchema] + A dynamically generated subclass of :class:`ValidationSchema` + representing the constructor arguments accepted by the subclass. + """ + + logger.debug("Building validation schema for %s.", f"{cls.__module__}.{cls.__name__}") + + signature = inspect.signature(cls.__init__) + type_hints = get_type_hints(cls.__init__) + + fields: dict[str, tuple[Any, Any]] = {} + + # Skip *args and **kwargs + for name, param in signature.parameters.items(): + if name == "self": + continue + + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + + annotation = type_hints.get(name, Any) + default = param.default if param.default is not inspect._empty else ... + fields[name] = (annotation, default) + + model = create_model(f"{cls.__name__}ValidationSchema", **fields, __base__=ValidationSchema) + + logger.debug("Created model: %s", model.model_fields) + + return model + + +class StaticValidation(metaclass=ValidationMeta): + """ + Base class for explicit constructor validation. + + Subclasses must define a ``validation_model`` attribute containing a + subclass of :class:`pydantic.BaseModel`. The model is used by + :class:`ValidationMeta` to validate constructor arguments before + instance creation. + """ + + validation_model: type[BaseModel] + + def __init_subclass__(cls, **kwargs): + """ + Verify that the subclass defines a validation model. + + Raises + ------ + TypeError + If the subclass does not define a ``validation_model`` + attribute. + """ + + super().__init_subclass__(**kwargs) + + if getattr(cls, "validation_model", None) is None: + raise TypeError(f"{cls.__name__} must define validation_model.") diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py index e39290565..e89d7675a 100644 --- a/tests/validation/test_models.py +++ b/tests/validation/test_models.py @@ -3,11 +3,15 @@ import json import pytest -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from pydantic.errors import PydanticSchemaGenerationError -from pyaml.validation import ConfigurationSchema -from pyaml.validation.models import PyAMLBaseModel +from pyaml.validation import ConfigurationSchema, DynamicValidation, StaticValidation +from pyaml.validation.models import PyAMLBaseModel, ValidationSchema + +# ========================================================== +# PyAMLBaseModel +# ========================================================== def test_model_dump_serializes_subclass_fields(): @@ -45,6 +49,11 @@ class Accelerator(PyAMLBaseModel): assert dumped == {"device": {"name": "QF", "type": "Quadrupole"}} +# ========================================================== +# ConfigurationSchema +# ========================================================== + + def test_configuration_schema_accepts_alias_class(): schema = ConfigurationSchema.model_validate({"class": "pkg.module.Class"}) @@ -85,3 +94,150 @@ def test_configuration_schema_dump_uses_alias_when_requested(): dumped = schema.model_dump(by_alias=True) assert dumped == {"class": "pkg.module.Class"} + + +# ========================================================== +# DynamicValidation +# ========================================================== + + +def test_dynamic_validation_builds_schema_from_init_signature(): + class MyClass(DynamicValidation): + def __init__(self, name: str, count: int = 0): + self.name = name + self.count = count + + assert issubclass(MyClass.validation_model, ValidationSchema) + assert list(MyClass.validation_model.model_fields) == ["name", "count"] + + name_field = MyClass.validation_model.model_fields["name"] + count_field = MyClass.validation_model.model_fields["count"] + + assert name_field.annotation is str + assert name_field.is_required() + + assert count_field.annotation is int + assert count_field.default == 0 + + +def test_dynamic_validation_accepts_positional_and_keyword_arguments(): + class MyClass(DynamicValidation): + def __init__(self, name: str, count: int = 0): + self.name = name + self.count = count + + obj1 = MyClass("test", 1) + obj2 = MyClass(name="test", count=1) + obj3 = MyClass("test") + + assert obj1.name == "test" + assert obj1.count == 1 + + assert obj2.name == "test" + assert obj2.count == 1 + + assert obj3.name == "test" + assert obj3.count == 0 + + +def test_dynamic_validation_coerces_and_rejects_invalid_input(): + class MyClass(DynamicValidation): + def __init__(self, name: str, count: int): + self.name = name + self.count = count + + obj = MyClass(name="test", count="12") + assert obj.count == 12 + + with pytest.raises(ValidationError): + MyClass(name="test", count="not-an-int") + + +def test_dynamic_validation_rejects_manual_validation_model(): + class ManualModel(BaseModel): + name: str + + with pytest.raises(TypeError, match="may not define validation_model manually"): + + class Broken(DynamicValidation): + validation_model = ManualModel + + def __init__(self, name: str): + self.name = name + + +# ========================================================== +# StaticValidation +# ========================================================== + + +def test_static_validation_accepts_explicit_basemodel(): + class ExampleSchema(BaseModel): + name: str + count: int = 0 + + class Example(StaticValidation): + validation_model = ExampleSchema + + def __init__(self, name: str, count: int = 0): + self.name = name + self.count = count + + obj1 = Example("test", 1) + obj2 = Example(name="test", count=1) + obj3 = Example("test") + + assert obj1.name == "test" + assert obj1.count == 1 + + assert obj2.name == "test" + assert obj2.count == 1 + + assert obj3.name == "test" + assert obj3.count == 0 + + +def test_static_validation_validates_and_coerces_input(): + class ExampleSchema(BaseModel): + name: str + count: int + + class Example(StaticValidation): + validation_model = ExampleSchema + + def __init__(self, name: str, count: int): + self.name = name + self.count = count + + obj = Example(name="test", count="12") + assert obj.count == 12 + + with pytest.raises(ValidationError): + Example(name="test", count="not-an-int") + + +def test_static_validation_inherits_validation_model(): + class ParentSchema(BaseModel): + name: str + + class Parent(StaticValidation): + validation_model = ParentSchema + + def __init__(self, name: str): + self.name = name + + class Child(Parent): + def __init__(self, name: str): + super().__init__(name) + + obj = Child("test") + assert obj.name == "test" + assert Child.validation_model is ParentSchema + + +def test_static_validation_requires_a_validation_model(): + with pytest.raises(TypeError, match="must define validation_model"): + + class Broken(StaticValidation): + def __init__(self, name: str): + self.name = name From ebe239d3bd1b3d72331f1adbbc37ed887209262f Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Tue, 16 Jun 2026 17:31:23 +0200 Subject: [PATCH 08/21] Add arbitrary types allowed on validation schema since that was a bug. --- pyaml/validation/models.py | 2 +- tests/validation/test_models.py | 28 +++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/pyaml/validation/models.py b/pyaml/validation/models.py index e561e509e..43cb70695 100644 --- a/pyaml/validation/models.py +++ b/pyaml/validation/models.py @@ -50,7 +50,7 @@ class ValidationSchema(PyAMLBaseModel): Provides common fields and functionality for schemas used to validate arguments during object creation. """ - model_config = ConfigDict(arbitrary_types_allowed=False, extra="forbid") + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") class ValidationMeta(type): diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py index e89d7675a..b2ad86db2 100644 --- a/tests/validation/test_models.py +++ b/tests/validation/test_models.py @@ -78,7 +78,7 @@ def test_configuration_schema_forbids_extra_fields(): assert "extra_forbidden" in str(exc_info.value) -def test_models_do_not_allow_arbitrary_types(): +def test_configuration_schema_do_not_allow_arbitrary_types(): class ArbitraryType: pass @@ -96,6 +96,32 @@ def test_configuration_schema_dump_uses_alias_when_requested(): assert dumped == {"class": "pkg.module.Class"} +# ========================================================== +# ValidationSchema +# ========================================================== + + +class DummyDevice: + pass + + +class DummySchema(ValidationSchema): + device: DummyDevice + + +def test_validation_schema_allows_arbitrary_types(): + device = DummyDevice() + + schema = DummySchema.model_validate({"device": device}) + + assert schema.device is device + + +def test_validation_schema_forbids_extra_fields(): + with pytest.raises(ValidationError): + DummySchema.model_validate({"device": DummyDevice(), "extra_field": 123}) + + # ========================================================== # DynamicValidation # ========================================================== From e07b73a9f9a69da581374bd75327e780e17b6f74 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Tue, 16 Jun 2026 18:27:20 +0200 Subject: [PATCH 09/21] Change RFTransmitter to not have ConfigModel and include dynamic validation. --- pyaml/common/element.py | 119 +++++++++++++++++++++++---------- pyaml/configuration/factory.py | 14 ++-- pyaml/control/controlsystem.py | 4 +- pyaml/lattice/simulator.py | 4 +- pyaml/rf/rf_plant.py | 8 +-- pyaml/rf/rf_transmitter.py | 78 ++++++++++----------- 6 files changed, 135 insertions(+), 92 deletions(-) diff --git a/pyaml/common/element.py b/pyaml/common/element.py index 11ab1f02e..15e566faa 100644 --- a/pyaml/common/element.py +++ b/pyaml/common/element.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict @@ -12,21 +12,45 @@ def __pyaml_repr__(obj): """ Returns a string representation of a pyaml object """ - if hasattr(obj, "_cfg"): + + cls_name = obj.__class__.__name__ + + # Keep the old behavior when _cfg exists + cfg = getattr(obj, "_cfg", None) + if cfg is not None: if isinstance(obj, Element): - return repr(obj._cfg).replace( + return repr(cfg).replace( "ConfigModel(", - obj.__class__.__name__ + "(peer='" + obj.attached_to() + "', ", + f"{cls_name}(peer={obj.attached_to()!r}, ", + 1, ) - else: - # no peer - return repr(obj._cfg).replace("ConfigModel", obj.__class__.__name__) - else: - # Object is not yet fully constructed - if isinstance(obj, Element): - return f"{obj.__class__.__name__}: {obj.get_name()}" - else: - return f"{obj.__class__.__name__}" + return repr(cfg).replace("ConfigModel", cls_name, 1) + + # Generic fallback when there is no _cfg + attrs = {} + + # Instance attributes + for k, v in obj.__dict__.items(): + # Exclude private attributes + if not k.startswith("_"): + attrs[k] = v + + # Properties + for name, attr in vars(type(obj)).items(): + if isinstance(attr, property): + try: + attrs[name] = getattr(obj, name) + except Exception as e: + attrs[name] = f"" + + if isinstance(obj, Element) and "name" not in attrs: + try: + attrs["name"] = obj.get_name() + except Exception as e: + attrs["name"] = f"" + + parts = ", ".join(f"{k}={v!r}" for k, v in attrs.items()) + return f"{cls_name}({parts})" if parts else cls_name class ElementConfigModel(BaseModel): @@ -57,39 +81,64 @@ class ElementConfigModel(BaseModel): lattice_names: str | None = None -class Element(object): +class Element: """ Class providing access to one element of a physical or simulated lattice - - Attributes: - name: str - The unique name identifying the element in the configuration file """ - def __init__(self, name: str): - self._name: str = name - self._peer: "ElementHolder" = None # Peer: ControlSystem, Simulator + def __init__( + self, + name: str, + lattice_names: str | None = None, + description: str | None = None, + ): + self._name = name + self._lattice_names = lattice_names + self._description = description + self._peer: ElementHolder | None = None - def get_name(self) -> str: + def _cfg_value(self, attr: str, fallback: Any) -> Any: """ - Returns the name of the element + Return an attribute from _cfg if available, otherwise fallback. """ - return self._name + cfg = getattr(self, "_cfg", None) + if cfg is not None: + value = getattr(cfg, attr, None) + if value is not None: + return value + return fallback - def get_lattice_names(self) -> str: - """ - Returns the name of associated lattice element(s) - """ - if not hasattr(self, "_cfg"): - return self._name - else: - return self._cfg.lattice_names + @property + def name(self) -> str: + return self._cfg_value("name", self._name) + + @property + def lattice_names(self) -> str: + cfg = getattr(self, "_cfg", None) + + if cfg is not None and cfg.lattice_names is not None: + return cfg.lattice_names + + if self._lattice_names is not None: + return self._lattice_names + + return self.name - def get_description(self) -> str: + @property + def description(self) -> str | None: + return self._cfg_value("description", self._description) + + def get_name(self) -> str: """ - Returns the description of the element + Returns the name of the element """ - return self._cfg.description + return self.name + + def get_lattice_names(self) -> str | None: + return self.lattice_names + + def get_description(self) -> str | None: + return self.description def set_energy(self, E: float): """ diff --git a/pyaml/configuration/factory.py b/pyaml/configuration/factory.py index df63fbe78..e41d28345 100644 --- a/pyaml/configuration/factory.py +++ b/pyaml/configuration/factory.py @@ -161,7 +161,7 @@ class BuildInfo: ---------- module : ModuleType Imported module containing the object class and validation model. - config_cls : type[BaseModel] + config_cls : type[BaseModel], optional Pydantic model used to validate the configuration. class_str : str Name of the class to instantiate. @@ -174,7 +174,7 @@ class BuildInfo: """ module: ModuleType - config_cls: type[BaseModel] + config_cls: type[BaseModel] | None class_str: str field_locations: dict | None location_str: str @@ -248,8 +248,6 @@ def resolve_build_info(data: dict, ignore_external: bool) -> BuildInfo | None: # Get the validation class config_cls = getattr(module, validation_class_str, None) - if config_cls is None: - raise PyAMLConfigException(f"No validation class for '{module.__name__}.{class_str}' {location_str}") return BuildInfo( module=module, @@ -456,6 +454,8 @@ def _construct_element( try: if control_modes is None: + if isinstance(cfg, dict): + return elem_cls(**cfg) return elem_cls(cfg) return UnboundElement(elem_cls, module_name, control_modes, cfg) @@ -495,9 +495,11 @@ def build_object(self, data: dict, ignore_external: bool = False): cleaned_data, control_modes = self._strip_build_metadata(data) - # Validate the model try: - cfg = config_cls.model_validate(cleaned_data) + if config_cls is not None: + cfg = config_cls.model_validate(cleaned_data) + else: + cfg = cleaned_data except ValidationError as e: handle_validation_error(e, module.__name__, location_str, field_locations) diff --git a/pyaml/control/controlsystem.py b/pyaml/control/controlsystem.py index a622ebd8c..90bf84f8d 100644 --- a/pyaml/control/controlsystem.py +++ b/pyaml/control/controlsystem.py @@ -193,8 +193,8 @@ def fill_device(self, elements: list[Element]): attachedTrans: list[RFTransmitter] = [] if e._cfg.transmitters: for t in e._cfg.transmitters: - vDev = self.get_device_access(t._cfg.voltage) - pDev = self.get_device_access(t._cfg.phase) + vDev = self.get_device_access(t.voltage_str) + pDev = self.get_device_access(t.phase_str) voltage = RWRFVoltageScalar(t, vDev) phase = RWRFPhaseScalar(t, pDev) nt = t.attach(self, voltage, phase) diff --git a/pyaml/lattice/simulator.py b/pyaml/lattice/simulator.py index a877588aa..81a95d1ba 100644 --- a/pyaml/lattice/simulator.py +++ b/pyaml/lattice/simulator.py @@ -206,7 +206,7 @@ def fill_device(self, elements: list[Element]): attachedTrans: list[RFTransmitter] = [] for t in e._cfg.transmitters: cavsPerTrans: list[at.Element] = [] - for c in t._cfg.cavities: + for c in t.cavities: # Expect unique name for cavities cav = self.get_at_elems(Element(c)) if len(cav) > 1: @@ -214,7 +214,7 @@ def fill_device(self, elements: list[Element]): if len(cav) == 0: raise PyAMLException(f"RF transmitter {t.get_name()}, No cavity found") cavsPerTrans.append(cav[0]) - harmonics.append(t._cfg.harmonic) + harmonics.append(t.harmonic) voltage = RWRFVoltageScalar(cavsPerTrans) phase = RWRFPhaseScalar(cavsPerTrans) nt = t.attach(self, voltage, phase) diff --git a/pyaml/rf/rf_plant.py b/pyaml/rf/rf_plant.py index 87e8706a7..bba953b7f 100644 --- a/pyaml/rf/rf_plant.py +++ b/pyaml/rf/rf_plant.py @@ -76,19 +76,19 @@ def get(self) -> float: sum = 0 # Count only fundamental harmonic for t in self.__trans: - if t._cfg.harmonic == 1.0: + if t.harmonic == 1.0: sum += t.voltage.get() return sum def set(self, value: float): # Assume that sum of transmitter (fundamental harmonic) distribution is 1 for t in self.__trans: - if t._cfg.harmonic == 1.0: - v = value * t._cfg.distribution + if t.harmonic == 1.0: + v = value * t.distribution t.voltage.set(v) def set_and_wait(self, value: float): raise NotImplementedError("Not implemented yet.") def unit(self) -> str: - return self.__trans[0]._cfg.phase.unit() + return self.__trans[0].phase_device_access.unit() diff --git a/pyaml/rf/rf_transmitter.py b/pyaml/rf/rf_transmitter.py index 8eb942a6b..57ccb6494 100644 --- a/pyaml/rf/rf_transmitter.py +++ b/pyaml/rf/rf_transmitter.py @@ -1,56 +1,39 @@ -import numpy as np -from pydantic import BaseModel, ConfigDict - -try: - from typing import Self # Python 3.11+ -except ImportError: - from typing_extensions import Self # Python 3.10 and earlier +from copy import copy +from typing import Self from .. import PyAMLException from ..common import abstract -from ..common.element import Element, ElementConfigModel +from ..common.element import Element from ..control.deviceaccess import DeviceAccess +from ..validation import DynamicValidation # Define the main class name for this module PYAMLCLASS = "RFTransmitter" -class ConfigModel(ElementConfigModel): - """ - Configuration model for RF Transmitter. - - Attributes - ---------- - voltage : str or None, optional - Device to apply cavity voltage - phase : str or None, optional - Device to apply cavity phase - cavities : list[str] - List of cavity names connected to this transmitter - harmonic : float, optional - Harmonic frequency ratio, 1.0 for main frequency, by default 1.0 - distribution : float, optional - RF distribution (Part of the total RF voltage powered by this transmitter), - by default 1.0 - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - voltage: str | None = None - phase: str | None = None - cavities: list[str] - harmonic: float = 1.0 - distribution: float = 1.0 - - -class RFTransmitter(Element): +class RFTransmitter(Element, DynamicValidation): """ Class that handle a RF transmitter """ - def __init__(self, cfg: ConfigModel): - super().__init__(cfg.name) - self._cfg = cfg + def __init__( + self, + name: str, + cavities: list[str], + voltage: str | None = None, + phase: str | None = None, + harmonic: float = 1.0, + distribution: float = 1.0, + lattice_names: str | None = None, + description: str | None = None, + ): + super().__init__(name, lattice_names, description) + self.voltage_str = voltage + self.phase_str = phase + self.cavities = cavities + self.harmonic = harmonic + self.distribution = distribution + self.__voltage = None self.__phase = None @@ -70,7 +53,7 @@ def voltage(self) -> abstract.ReadWriteFloatScalar: If transmitter is unattached or has no voltage device defined """ if self.__voltage is None: - raise PyAMLException(f"{str(self)} is unattached or has no voltage device defined") + raise PyAMLException(f"{str(self.name)} is unattached or has no voltage device defined") return self.__voltage @property @@ -89,7 +72,7 @@ def phase(self) -> abstract.ReadWriteFloatScalar: If transmitter is unattached or has no phase device defined """ if self.__phase is None: - raise PyAMLException(f"{str(self)} is unattached or has no phase device defined") + raise PyAMLException(f"{str(self.name)} is unattached or has no phase device defined") return self.__phase def attach( @@ -116,7 +99,16 @@ def attach( A new attached instance of RFTransmitter """ # Attach voltage and phase attribute and returns a new reference - obj = self.__class__(self._cfg) + obj = type(self)( + name=self.name, + cavities=self.cavities, + voltage=self.voltage_str, + phase=self.phase_str, + harmonic=self.harmonic, + distribution=self.distribution, + lattice_names=self.lattice_names, + description=self.description, + ) obj.__voltage = voltage obj.__phase = phase obj._peer = peer From 24c2f0a51d5880ca72fd3244b2d15a934d0c8ec7 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Wed, 17 Jun 2026 18:19:47 +0200 Subject: [PATCH 10/21] Move RFPlant to version without ConfigModel. --- pyaml/common/element_holder.py | 2 +- pyaml/control/controlsystem.py | 10 +++---- pyaml/lattice/simulator.py | 6 ++--- pyaml/rf/rf_plant.py | 48 ++++++++++++++++++---------------- tests/rf/test_rf.py | 2 +- 5 files changed, 35 insertions(+), 33 deletions(-) diff --git a/pyaml/common/element_holder.py b/pyaml/common/element_holder.py index 22df8bc13..a06a06935 100644 --- a/pyaml/common/element_holder.py +++ b/pyaml/common/element_holder.py @@ -272,7 +272,7 @@ def get_rf_plant(self, name: str) -> RFPlant: def add_rf_plant(self, rf: RFPlant): self.__add(self.__RFPLANT, rf) - def add_rf_transnmitter(self, rf: RFTransmitter): + def add_rf_transmitter(self, rf: RFTransmitter): self.__add(self.__RFTRANSMITTER, rf) def get_rf_trasnmitter(self, name: str) -> RFTransmitter: diff --git a/pyaml/control/controlsystem.py b/pyaml/control/controlsystem.py index 90bf84f8d..a6d77310d 100644 --- a/pyaml/control/controlsystem.py +++ b/pyaml/control/controlsystem.py @@ -191,19 +191,19 @@ def fill_device(self, elements: list[Element]): elif isinstance(e, RFPlant): attachedTrans: list[RFTransmitter] = [] - if e._cfg.transmitters: - for t in e._cfg.transmitters: + if e.transmitters: + for t in e.transmitters: vDev = self.get_device_access(t.voltage_str) pDev = self.get_device_access(t.phase_str) voltage = RWRFVoltageScalar(t, vDev) phase = RWRFPhaseScalar(t, pDev) nt = t.attach(self, voltage, phase) - self.add_rf_transnmitter(nt) + self.add_rf_transmitter(nt) attachedTrans.append(nt) - fDev = self.get_device_access(e._cfg.masterclock) + fDev = self.get_device_access(e.masterclock) frequency = RWRFFrequencyScalar(e, fDev) - voltage = RWTotalVoltage(attachedTrans) if e._cfg.transmitters else None + voltage = RWTotalVoltage(attachedTrans) if e.transmitters else None ne = e.attach(self, frequency, voltage) self.add_rf_plant(ne) diff --git a/pyaml/lattice/simulator.py b/pyaml/lattice/simulator.py index 81a95d1ba..44a92df92 100644 --- a/pyaml/lattice/simulator.py +++ b/pyaml/lattice/simulator.py @@ -200,11 +200,11 @@ def fill_device(self, elements: list[Element]): self.add_bpm(e) elif isinstance(e, RFPlant): - if e._cfg.transmitters: + if e.transmitters: cavs: list[at.Element] = [] harmonics: list[float] = [] attachedTrans: list[RFTransmitter] = [] - for t in e._cfg.transmitters: + for t in e.transmitters: cavsPerTrans: list[at.Element] = [] for c in t.cavities: # Expect unique name for cavities @@ -218,7 +218,7 @@ def fill_device(self, elements: list[Element]): voltage = RWRFVoltageScalar(cavsPerTrans) phase = RWRFPhaseScalar(cavsPerTrans) nt = t.attach(self, voltage, phase) - self.add_rf_transnmitter(nt) + self.add_rf_transmitter(nt) cavs.extend(cavsPerTrans) attachedTrans.append(nt) diff --git a/pyaml/rf/rf_plant.py b/pyaml/rf/rf_plant.py index bba953b7f..69e8c7cee 100644 --- a/pyaml/rf/rf_plant.py +++ b/pyaml/rf/rf_plant.py @@ -1,49 +1,45 @@ -import numpy as np -from pydantic import BaseModel, ConfigDict - -try: - from typing import Self # Python 3.11+ -except ImportError: - from typing_extensions import Self # Python 3.10 and earlier +from typing import Self from .. import PyAMLException from ..common import abstract -from ..common.element import Element, ElementConfigModel -from ..control.deviceaccess import DeviceAccess +from ..common.element import Element +from ..validation import DynamicValidation from .rf_transmitter import RFTransmitter # Define the main class name for this module PYAMLCLASS = "RFPlant" -class ConfigModel(ElementConfigModel): - masterclock: str | None = None - """Device to apply main RF frequency""" - transmitters: list[RFTransmitter] | None = None - """List of RF trasnmitters""" - - -class RFPlant(Element): +class RFPlant(Element, DynamicValidation): """ Main RF object """ - def __init__(self, cfg: ConfigModel): - super().__init__(cfg.name) - self._cfg = cfg + def __init__( + self, + name: str, + masterclock: str | None = None, + transmitters: list[RFTransmitter] | None = None, + lattice_names: str | None = None, + description: str | None = None, + ): + super().__init__(name, lattice_names, description) + + self.masterclock = masterclock + self.transmitters = transmitters self.__frequency = None self.__voltage = None @property def frequency(self) -> abstract.ReadWriteFloatScalar: if self.__frequency is None: - raise PyAMLException(f"{str(self)} has no masterclock device defined") + raise PyAMLException(f"{str(self.name)} has no masterclock device defined") return self.__frequency @property def voltage(self) -> abstract.ReadWriteFloatScalar: if self.__voltage is None: - raise PyAMLException(f"{str(self)} has no trasmitter device defined") + raise PyAMLException(f"{str(self.name)} has no transmitter device defined") return self.__voltage def attach( @@ -53,7 +49,13 @@ def attach( voltage: abstract.ReadWriteFloatScalar, ) -> Self: # Attach frequency attribute and returns a new reference - obj = self.__class__(self._cfg) + obj = self.__class__( + name=self.name, + masterclock=self.masterclock, + transmitters=self.transmitters, + lattice_names=self.lattice_names, + description=self.description, + ) obj.__frequency = frequency obj.__voltage = voltage obj._peer = peer diff --git a/tests/rf/test_rf.py b/tests/rf/test_rf.py index 5752d4d07..964e25e7b 100644 --- a/tests/rf/test_rf.py +++ b/tests/rf/test_rf.py @@ -107,7 +107,7 @@ def test_rf_multi_notrans(install_test_package): RF.frequency.set(3.523e8) with pytest.raises(PyAMLException) as exc: RF.voltage.set(10e6) - assert "has no trasmitter device defined" in str(exc) + assert "has no transmitter device defined" in str(exc) # Check that frequency and voltage has been applied on the masterclock device assert np.isclose(RF.frequency.get(), 3.523e8) From c8224a93ae4fd3d33fb4e4012fd3e28c7a1147ea Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Wed, 17 Jun 2026 18:20:10 +0200 Subject: [PATCH 11/21] Remove unused import from rf_transmitter. --- pyaml/rf/rf_transmitter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyaml/rf/rf_transmitter.py b/pyaml/rf/rf_transmitter.py index 57ccb6494..fbbae6a61 100644 --- a/pyaml/rf/rf_transmitter.py +++ b/pyaml/rf/rf_transmitter.py @@ -4,7 +4,6 @@ from .. import PyAMLException from ..common import abstract from ..common.element import Element -from ..control.deviceaccess import DeviceAccess from ..validation import DynamicValidation # Define the main class name for this module From bd284920f3c67fb354a2c25d5506e2ec7276c01d Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 18 Jun 2026 10:17:00 +0200 Subject: [PATCH 12/21] Change attach to use copy for RF transmitter and plant. --- pyaml/rf/rf_plant.py | 9 ++------- pyaml/rf/rf_transmitter.py | 13 ++----------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/pyaml/rf/rf_plant.py b/pyaml/rf/rf_plant.py index 69e8c7cee..a9ad488b9 100644 --- a/pyaml/rf/rf_plant.py +++ b/pyaml/rf/rf_plant.py @@ -1,3 +1,4 @@ +import copy from typing import Self from .. import PyAMLException @@ -49,13 +50,7 @@ def attach( voltage: abstract.ReadWriteFloatScalar, ) -> Self: # Attach frequency attribute and returns a new reference - obj = self.__class__( - name=self.name, - masterclock=self.masterclock, - transmitters=self.transmitters, - lattice_names=self.lattice_names, - description=self.description, - ) + obj = copy.copy(self) obj.__frequency = frequency obj.__voltage = voltage obj._peer = peer diff --git a/pyaml/rf/rf_transmitter.py b/pyaml/rf/rf_transmitter.py index fbbae6a61..e8378e9fc 100644 --- a/pyaml/rf/rf_transmitter.py +++ b/pyaml/rf/rf_transmitter.py @@ -1,4 +1,4 @@ -from copy import copy +import copy from typing import Self from .. import PyAMLException @@ -98,16 +98,7 @@ def attach( A new attached instance of RFTransmitter """ # Attach voltage and phase attribute and returns a new reference - obj = type(self)( - name=self.name, - cavities=self.cavities, - voltage=self.voltage_str, - phase=self.phase_str, - harmonic=self.harmonic, - distribution=self.distribution, - lattice_names=self.lattice_names, - description=self.description, - ) + obj = copy.copy(self) obj.__voltage = voltage obj.__phase = phase obj._peer = peer From 6825fb5fb8467886c5f9d56701e85f4ab6a576bf Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 18 Jun 2026 10:20:19 +0200 Subject: [PATCH 13/21] Change voltage_str and phase_str to voltage_name and phase_name in RF transmitter. --- pyaml/control/controlsystem.py | 4 ++-- pyaml/rf/rf_transmitter.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyaml/control/controlsystem.py b/pyaml/control/controlsystem.py index a6d77310d..fdde41718 100644 --- a/pyaml/control/controlsystem.py +++ b/pyaml/control/controlsystem.py @@ -193,8 +193,8 @@ def fill_device(self, elements: list[Element]): attachedTrans: list[RFTransmitter] = [] if e.transmitters: for t in e.transmitters: - vDev = self.get_device_access(t.voltage_str) - pDev = self.get_device_access(t.phase_str) + vDev = self.get_device_access(t.voltage_name) + pDev = self.get_device_access(t.phase_name) voltage = RWRFVoltageScalar(t, vDev) phase = RWRFPhaseScalar(t, pDev) nt = t.attach(self, voltage, phase) diff --git a/pyaml/rf/rf_transmitter.py b/pyaml/rf/rf_transmitter.py index e8378e9fc..2dcdd5098 100644 --- a/pyaml/rf/rf_transmitter.py +++ b/pyaml/rf/rf_transmitter.py @@ -27,8 +27,8 @@ def __init__( description: str | None = None, ): super().__init__(name, lattice_names, description) - self.voltage_str = voltage - self.phase_str = phase + self.voltage_name = voltage + self.phase_name = phase self.cavities = cavities self.harmonic = harmonic self.distribution = distribution From 0f627547464f7b1bf19490e3e106c1a8dce2079c Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 25 Jun 2026 10:55:21 +0200 Subject: [PATCH 14/21] Separate models into two modules and rename ValidationSchema to ValidationModel. --- pyaml/validation/__init__.py | 3 +- pyaml/validation/configuration_models.py | 101 +++++++++++ pyaml/validation/models.py | 219 ----------------------- pyaml/validation/validation_models.py | 166 +++++++++++++++++ tests/validation/test_models.py | 33 ++-- 5 files changed, 286 insertions(+), 236 deletions(-) create mode 100644 pyaml/validation/configuration_models.py delete mode 100644 pyaml/validation/models.py create mode 100644 pyaml/validation/validation_models.py diff --git a/pyaml/validation/__init__.py b/pyaml/validation/__init__.py index 1d90a08b6..be0ca3a05 100644 --- a/pyaml/validation/__init__.py +++ b/pyaml/validation/__init__.py @@ -2,9 +2,10 @@ PyAML validation subpackage. """ +from .configuration_models import ConfigurationSchema from .generator import SchemaGenerator -from .models import ConfigurationSchema, DynamicValidation, StaticValidation from .registry import SchemaRegistry, register_schema +from .validation_models import DynamicValidation, StaticValidation from .validator import SchemaValidator __all__ = [ diff --git a/pyaml/validation/configuration_models.py b/pyaml/validation/configuration_models.py new file mode 100644 index 000000000..dc9bc1048 --- /dev/null +++ b/pyaml/validation/configuration_models.py @@ -0,0 +1,101 @@ +"""Datamodels for configuration.""" + +import importlib +import logging +from typing import ClassVar + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field + +logger = logging.getLogger(__name__) + + +class PyAMLBaseModel(BaseModel): + """ + Base model for pyAML schemas. + + Overrides ``model_dump()`` and ``model_dump_json()`` to enable + ``serialize_as_any=True`` by default. This ensures that fields are + serialized according to their runtime type rather than their declared + annotation type. + """ + + def model_dump(self, **kwargs): + kwargs.setdefault("serialize_as_any", True) + return super().model_dump(**kwargs) + + def model_dump_json(self, **kwargs): + kwargs.setdefault("serialize_as_any", True) + return super().model_dump_json(**kwargs) + + +class ConfigurationSchema(PyAMLBaseModel): + """ + Base model for validating externally supplied configuration data. + + Each configuration schema defines the expected input for configuring a + specific object and includes a ``class`` field containing its fully + qualified class path. + """ + + model_config = ConfigDict(validate_by_name=True, validate_by_alias=True, arbitrary_types_allowed=False, extra="forbid") + + class_path: str = Field( + description="Fully qualified class path.", + alias="class", + ) + + +class ModuleConfigurationSchema(PyAMLBaseModel): + """ + Base model for validating externally supplied configuration data. + + This schema exists to support legacy module-based configurations. It + defines the expected input for configuring a specific object, with the + target class resolved from the module's ``PYAMLCLASS`` attribute. + """ + + model_config = ConfigDict(validate_by_name=True, validate_by_alias=True, extra="forbid") + + MODULE_PATH_ALIASES: ClassVar[tuple[str, ...]] = ("module", "type") + + module_path: str = Field( + description="Fully qualified module path.", + alias=AliasChoices(*MODULE_PATH_ALIASES), + ) + + def to_configuration(self) -> ConfigurationSchema: + """ + Convert the module-based configuration to a ``ConfigurationSchema``. + + Imports the referenced module, resolves the target class from its + ``PYAMLCLASS`` attribute, and returns an equivalent + :class:`ConfigurationSchema`. Any additional configuration fields are + preserved. + + Returns + ------- + ConfigurationSchema + Configuration schema with the resolved fully qualified class path. + + Raises + ------ + ImportError + If the referenced module cannot be imported. + ValueError + If the module does not define ``PYAMLCLASS``. + """ + + module = importlib.import_module(self.module_path) + + try: + class_name = module.PYAMLCLASS + except AttributeError as e: + raise ValueError(f"Module '{self.module_path}' does not define PYAMLCLASS.") from e + + return ConfigurationSchema.model_validate( + { + "class_path": f"{self.module_path}.{class_name}", + **(self.model_extra or {}), + }, + extra="allow", + ) diff --git a/pyaml/validation/models.py b/pyaml/validation/models.py deleted file mode 100644 index 43cb70695..000000000 --- a/pyaml/validation/models.py +++ /dev/null @@ -1,219 +0,0 @@ -"""Base datamodels for configuration.""" - -import inspect -import logging -from typing import Any, get_type_hints - -from pydantic import BaseModel, ConfigDict, Field, create_model - -logger = logging.getLogger(__name__) - - -class PyAMLBaseModel(BaseModel): - """ - Base model for pyAML. - - Overrides ``model_dump()`` and ``model_dump_json()`` to enable - ``serialize_as_any=True`` by default. This ensures that fields are - serialized according to their runtime type rather than their declared - annotation type. - """ - - def model_dump(self, **kwargs): - kwargs.setdefault("serialize_as_any", True) - return super().model_dump(**kwargs) - - def model_dump_json(self, **kwargs): - kwargs.setdefault("serialize_as_any", True) - return super().model_dump_json(**kwargs) - - -class ConfigurationSchema(PyAMLBaseModel): - """ - Base model for configuration schemas. - - Provides common fields and functionality for schemas which are to be registered in the :class:`SchemaRegistry`. - """ - - model_config = ConfigDict(validate_by_name=True, validate_by_alias=True, arbitrary_types_allowed=False, extra="forbid") - - class_path: str = Field( - description="Fully qualified class path.", - alias="class", - ) - - -class ValidationSchema(PyAMLBaseModel): - """ - Base model for validation schemas. - - Provides common fields and functionality for schemas used to validate arguments during object creation. - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - -class ValidationMeta(type): - """ - Metaclass that validates constructor arguments using a Pydantic model. - - Classes using this metaclass must define a ``validation_model`` - attribute containing a subclass of :class:`pydantic.BaseModel`. - Before an instance is created, the supplied arguments are bound to - the ``__init__`` signature and validated against the model. - - Both positional and keyword arguments are validated before - ``__init__`` is executed. - """ - - def __call__(cls, *args: Any, **kwargs: Any): - """ - Create an instance after validating constructor arguments. - - The supplied arguments are bound to the class ``__init__`` signature, - default values are applied, and the resulting argument mapping is - validated using ``validation_model``. The validated values are then - passed to the constructor. - - Raises - ------ - TypeError - If the class does not define ``validation_model``. - - ValidationError - If the supplied arguments do not conform to the validation - model. - """ - - validation_model = getattr(cls, "validation_model", None) - - if validation_model is None: - raise TypeError(f"{cls.__name__} must define validation_model.") - - # Inspect the signature of the class - signature = inspect.signature(cls.__init__) - - # Map arguments to parameters - bound = signature.bind(None, *args, **kwargs) - - # Include default arguments - bound.apply_defaults() - - # Remove self from list - bound.arguments.pop("self", None) - arguments = dict(bound.arguments) - - # Validate the model - logger.debug("Validating input against schema: %s", validation_model.model_fields) - validated = validation_model.model_validate(arguments) - - # Return the object - return super().__call__(**validated.model_dump()) - - -class DynamicValidation(metaclass=ValidationMeta): - """ - Base class that generates a validation schema from the constructor - signature. - - When a subclass is defined, a schema derived from - :class:`ValidationSchema` is generated automatically and assigned to - ``validation_model``. The generated schema is used by - :class:`ValidationMeta` to validate constructor arguments before - instance creation. - - Subclasses must not define ``validation_model`` manually. - """ - - validation_model: type[ValidationSchema] | None = None - - def __init_subclass__(cls, **kwargs): - """ - Generate and attach a validation schema for the subclass. - - A schema derived from :class:`ValidationSchema` is created from the - subclass's ``__init__`` signature and assigned to - ``validation_model``. Defining ``validation_model`` explicitly is - not permitted and results in a :class:`TypeError`. - """ - - super().__init_subclass__(**kwargs) - - if getattr(cls, "validation_model", None) is not None: - raise TypeError(f"{cls.__name__} may not define validation_model manually.") - - cls.validation_model = cls._build_validation_model() - - @classmethod - def _build_validation_model(cls) -> type[ValidationSchema]: - """ - Build a validation schema from the constructor signature. - - The generated schema contains one field for each parameter in the - subclass's ``__init__`` method, excluding ``self``, ``*args`` and - ``**kwargs``. Field types are obtained from the constructor's type - annotations and default values are preserved. - - Returns - ------- - type[ValidationSchema] - A dynamically generated subclass of :class:`ValidationSchema` - representing the constructor arguments accepted by the subclass. - """ - - logger.debug("Building validation schema for %s.", f"{cls.__module__}.{cls.__name__}") - - signature = inspect.signature(cls.__init__) - type_hints = get_type_hints(cls.__init__) - - fields: dict[str, tuple[Any, Any]] = {} - - # Skip *args and **kwargs - for name, param in signature.parameters.items(): - if name == "self": - continue - - if param.kind in ( - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - ): - continue - - annotation = type_hints.get(name, Any) - default = param.default if param.default is not inspect._empty else ... - fields[name] = (annotation, default) - - model = create_model(f"{cls.__name__}ValidationSchema", **fields, __base__=ValidationSchema) - - logger.debug("Created model: %s", model.model_fields) - - return model - - -class StaticValidation(metaclass=ValidationMeta): - """ - Base class for explicit constructor validation. - - Subclasses must define a ``validation_model`` attribute containing a - subclass of :class:`pydantic.BaseModel`. The model is used by - :class:`ValidationMeta` to validate constructor arguments before - instance creation. - """ - - validation_model: type[BaseModel] - - def __init_subclass__(cls, **kwargs): - """ - Verify that the subclass defines a validation model. - - Raises - ------ - TypeError - If the subclass does not define a ``validation_model`` - attribute. - """ - - super().__init_subclass__(**kwargs) - - if getattr(cls, "validation_model", None) is None: - raise TypeError(f"{cls.__name__} must define validation_model.") diff --git a/pyaml/validation/validation_models.py b/pyaml/validation/validation_models.py new file mode 100644 index 000000000..129710428 --- /dev/null +++ b/pyaml/validation/validation_models.py @@ -0,0 +1,166 @@ +"""Classes for validation during object creation.""" + +import inspect +import logging +from typing import Any + +from pydantic import BaseModel, ConfigDict, create_model + +from .configuration_models import PyAMLBaseModel +from .schema_builder import _fields_from_constructor_signature + +logger = logging.getLogger(__name__) + + +class ValidationModel(PyAMLBaseModel): + """ + Base model for validating object constructor arguments. + + Each validation model defines the expected input for constructing a + specific object. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + +class ValidationMeta(type): + """ + Metaclass that validates constructor arguments before object creation. + + Classes using this metaclass must define a ``validation_model`` + attribute containing a subclass of :class:`pydantic.BaseModel`. + Whenever an instance is created, the supplied constructor arguments + are validated against the model before the class constructor is + invoked. + """ + + def __call__(cls, *args: Any, **kwargs: Any): + """ + Create an instance after validating constructor arguments. + + The supplied arguments are bound to the class ``__init__`` signature, + default values are applied, and the resulting argument mapping is + validated using ``validation_model``. The validated values are then + passed to the constructor. + + Raises + ------ + TypeError + If the class does not define ``validation_model``. + + ValidationError + If the supplied arguments do not conform to the validation + model. + """ + + validation_model = getattr(cls, "validation_model", None) + + if validation_model is None: + raise TypeError(f"{cls.__name__} must define validation_model.") + + # Inspect the signature of the class + signature = inspect.signature(cls.__init__) + + # Map arguments to parameters + bound = signature.bind(None, *args, **kwargs) + + # Include default arguments + bound.apply_defaults() + + # Remove self from list + bound.arguments.pop("self", None) + arguments = dict(bound.arguments) + + # Validate the model + logger.debug("Validating input against schema: %s", validation_model.model_fields) + validated = validation_model.model_validate(arguments) + + # Return the object + return super().__call__(**validated.model_dump()) + + +class DynamicValidation(metaclass=ValidationMeta): + """ + Base class for automatic constructor argument validation. + + When a subclass is defined, a validation model is generated + automatically from its constructor signature and assigned to + ``validation_model``. The generated model is then used to validate + constructor arguments before object creation. + + Subclasses must not define ``validation_model`` manually. + """ + + validation_model: type[ValidationModel] | None = None + + def __init_subclass__(cls, **kwargs): + """ + Generate and attach a validation model for the subclass. + + A validation model is generated from the subclass's constructor + signature and assigned to ``validation_model``. Defining + ``validation_model`` explicitly is not permitted and results in a + :class:`TypeError`. + """ + + super().__init_subclass__(**kwargs) + + if getattr(cls, "validation_model", None) is not None: + raise TypeError(f"{cls.__name__} may not define validation_model manually.") + + cls.validation_model = cls._build_validation_model() + + @classmethod + def _build_validation_model(cls) -> type[ValidationModel]: + """ + Generate a validation model from the constructor signature. + + The generated model contains one field for each parameter in the + subclass's ``__init__`` method, excluding ``self``, ``*args``, and + ``**kwargs``. Field types are obtained from the constructor's type + annotations and default values are preserved. + + Returns + ------- + type[ValidationModel] + A dynamically generated subclass of :class:`ValidationModel` + representing the constructor arguments accepted by the subclass. + """ + + logger.debug("Building validation model for %s.", f"{cls.__module__}.{cls.__name__}") + + fields = _fields_from_constructor_signature(cls, expand_arbitrary_types=False) + + model = create_model(f"{cls.__name__}ValidationModel", **fields, __base__=ValidationModel) + + logger.debug("Created model: %s", model.model_fields) + + return model + + +class StaticValidation(metaclass=ValidationMeta): + """ + Base class for explicit constructor argument validation. + + Subclasses must define a ``validation_model`` attribute containing a + subclass of :class:`pydantic.BaseModel`. The model is used to validate + constructor arguments before object creation. + """ + + validation_model: type[BaseModel] + + def __init_subclass__(cls, **kwargs): + """ + Verify that the subclass defines a validation model. + + Raises + ------ + TypeError + If the subclass does not define a ``validation_model`` + attribute. + """ + + super().__init_subclass__(**kwargs) + + if getattr(cls, "validation_model", None) is None: + raise TypeError(f"{cls.__name__} must define validation_model.") diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py index b2ad86db2..6bab84987 100644 --- a/tests/validation/test_models.py +++ b/tests/validation/test_models.py @@ -7,7 +7,8 @@ from pydantic.errors import PydanticSchemaGenerationError from pyaml.validation import ConfigurationSchema, DynamicValidation, StaticValidation -from pyaml.validation.models import PyAMLBaseModel, ValidationSchema +from pyaml.validation.configuration_models import PyAMLBaseModel +from pyaml.validation.validation_models import ValidationModel # ========================================================== # PyAMLBaseModel @@ -97,7 +98,7 @@ def test_configuration_schema_dump_uses_alias_when_requested(): # ========================================================== -# ValidationSchema +# ValidationModel # ========================================================== @@ -105,21 +106,21 @@ class DummyDevice: pass -class DummySchema(ValidationSchema): +class DummyModel(ValidationModel): device: DummyDevice -def test_validation_schema_allows_arbitrary_types(): +def test_validation_model_allows_arbitrary_types(): device = DummyDevice() - schema = DummySchema.model_validate({"device": device}) + schema = DummyModel.model_validate({"device": device}) assert schema.device is device -def test_validation_schema_forbids_extra_fields(): +def test_validation_model_forbids_extra_fields(): with pytest.raises(ValidationError): - DummySchema.model_validate({"device": DummyDevice(), "extra_field": 123}) + DummyModel.model_validate({"device": DummyDevice(), "extra_field": 123}) # ========================================================== @@ -127,13 +128,13 @@ def test_validation_schema_forbids_extra_fields(): # ========================================================== -def test_dynamic_validation_builds_schema_from_init_signature(): +def test_dynamic_validation_builds_model_from_init_signature(): class MyClass(DynamicValidation): def __init__(self, name: str, count: int = 0): self.name = name self.count = count - assert issubclass(MyClass.validation_model, ValidationSchema) + assert issubclass(MyClass.validation_model, ValidationModel) assert list(MyClass.validation_model.model_fields) == ["name", "count"] name_field = MyClass.validation_model.model_fields["name"] @@ -198,12 +199,12 @@ def __init__(self, name: str): def test_static_validation_accepts_explicit_basemodel(): - class ExampleSchema(BaseModel): + class ExampleModel(BaseModel): name: str count: int = 0 class Example(StaticValidation): - validation_model = ExampleSchema + validation_model = ExampleModel def __init__(self, name: str, count: int = 0): self.name = name @@ -224,12 +225,12 @@ def __init__(self, name: str, count: int = 0): def test_static_validation_validates_and_coerces_input(): - class ExampleSchema(BaseModel): + class ExampleModel(BaseModel): name: str count: int class Example(StaticValidation): - validation_model = ExampleSchema + validation_model = ExampleModel def __init__(self, name: str, count: int): self.name = name @@ -243,11 +244,11 @@ def __init__(self, name: str, count: int): def test_static_validation_inherits_validation_model(): - class ParentSchema(BaseModel): + class ParentModel(BaseModel): name: str class Parent(StaticValidation): - validation_model = ParentSchema + validation_model = ParentModel def __init__(self, name: str): self.name = name @@ -258,7 +259,7 @@ def __init__(self, name: str): obj = Child("test") assert obj.name == "test" - assert Child.validation_model is ParentSchema + assert Child.validation_model is ParentModel def test_static_validation_requires_a_validation_model(): From 50b8cc569cd6d330c7794ba1236aea8fa7c6e6f8 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 25 Jun 2026 11:16:04 +0200 Subject: [PATCH 15/21] Add functionality to build schemas dynamically. --- pyaml/validation/schema_builder.py | 320 ++++++++++++++++++++++++ tests/validation/test_schema_builder.py | 173 +++++++++++++ 2 files changed, 493 insertions(+) create mode 100644 pyaml/validation/schema_builder.py create mode 100644 tests/validation/test_schema_builder.py diff --git a/pyaml/validation/schema_builder.py b/pyaml/validation/schema_builder.py new file mode 100644 index 000000000..3281091cf --- /dev/null +++ b/pyaml/validation/schema_builder.py @@ -0,0 +1,320 @@ +"""Functionality for dynamically generating configuration schemas.""" + +import inspect +import types +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from functools import reduce +from pathlib import Path +from typing import Annotated, Any, Literal, Union, get_args, get_origin, get_type_hints +from uuid import UUID + +from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined + +from .configuration_models import ConfigurationSchema +from .registry import SchemaRegistry + +RESERVED_CONFIGURATION_FIELDS = {"class_path"} + +SUPPORTED_TYPES = ( + int, + str, + float, + bool, + bytes, + dict, + list, + tuple, + set, + frozenset, + type(None), + datetime, + date, + time, + timedelta, + Decimal, + Path, + UUID, +) + + +def generate_configuration_schema(source: type) -> type[ConfigurationSchema]: + """ + Generate a configuration schema for a class or Pydantic model. + + If the source defines a ``validation_model`` attribute containing a + Pydantic model, the configuration schema is generated from that model. + Otherwise, the schema is generated from the source class constructor + signature. + + Generated schemas are registered in the :class:`SchemaRegistry` and + reused on subsequent requests. + """ + + if not isinstance(source, type): + raise TypeError("Source must be a class.") + + registry = SchemaRegistry() + class_path = f"{source.__module__}.{source.__name__}" + + existing = registry.get(class_path) + if existing is not None: + return existing + + # Check if the class has a validation model + validation_model = getattr(source, "validation_model", None) + + if isinstance(validation_model, type) and issubclass(validation_model, BaseModel): + schema = _configuration_schema_from_basemodel( + validation_model, + _generate_schema_name(source), + source.__module__, + ) + else: + schema = _configuration_schema_from_constructor(source) + + registry.register(class_path, schema) + + return schema + + +def _generate_schema_name(source: type) -> str: + """ + Generate the default configuration schema name for a source class. + """ + return f"{source.__name__}ConfigurationSchema" + + +def _configuration_schema_from_basemodel( + validation_model: type[BaseModel], + schema_name: str, + module_name: str, +) -> type[ConfigurationSchema]: + """ + Generate a configuration schema from a Pydantic model. + + The resulting schema contains one field for each field defined on the + validation model, with field metadata and defaults preserved. + """ + + if not isinstance(validation_model, type) or not issubclass(validation_model, BaseModel): + raise TypeError("validation_model must be a subclass of pydantic.BaseModel.") + + fields: dict[str, tuple[object, object]] = {} + + for field_name, field_info in validation_model.model_fields.items(): + if field_name in RESERVED_CONFIGURATION_FIELDS: + raise ValueError( + f"{validation_model.__name__} defines reserved field {field_name!r}, which is owned by ConfigurationSchema." + ) + + fields[field_name] = _field_definition_from_field_info(field_name, field_info) + + return create_model( + schema_name, + __base__=ConfigurationSchema, + __module__=module_name, + **fields, + ) + + +def _field_definition_from_field_info(field_name: str, field_info: FieldInfo) -> tuple[object, object]: + """ + Convert a Pydantic field definition into a ``create_model`` field tuple. + """ + + # Get the annotation + annotation = _resolve_annotation(field_info.annotation) + + # Handle default + default = field_info.default + if default is PydanticUndefined: + default = ... + + # Handle default factory + default_factory = field_info.default_factory + + # Collect metadata + field_kwargs = _field_kwargs(field_name, field_info) + + if default_factory is not None: + return annotation, Field(default_factory=default_factory, **field_kwargs) + elif field_kwargs: + return annotation, Field(default, **field_kwargs) + else: + return annotation, default + + +def _resolve_annotation(annotation: object) -> object: + """ + Resolve an annotation into a schema-friendly type. + + Supported annotations are passed through directly, while supported + generic types are resolved recursively into equivalent type hints. + """ + + if annotation is inspect._empty: + return Any + + if annotation is None: + return type(None) + + if isinstance(annotation, str): + raise TypeError(f"Unable to resolve annotation {annotation!r}. Forward references are not currently supported.") + + # Check if is a generic type + origin = get_origin(annotation) + + # If not a generic type + if origin is None: + if isinstance(annotation, type): + if annotation in SUPPORTED_TYPES: + return annotation + + if issubclass(annotation, Enum): + return annotation + + if issubclass(annotation, ConfigurationSchema): + return annotation + + return generate_configuration_schema(annotation) + + raise TypeError(f"Unsupported annotation: {annotation!r}") + + # If is generic type + args = get_args(annotation) + + if origin is Annotated: + return Annotated[_resolve_annotation(args[0]), *args[1:]] # type: ignore[index] + + elif origin is Literal: + return annotation + + elif origin is list: + return list[_resolve_annotation(args[0])] + + elif origin is dict: + return dict[_resolve_annotation(args[0]), _resolve_annotation(args[1])] + + elif origin is tuple: + # Handle variable length tuple + if len(args) == 2 and args[1] is Ellipsis: + return tuple[_resolve_annotation(args[0]), ...] + + resolved_args = tuple(_resolve_annotation(arg) for arg in args) + return tuple[resolved_args] + + elif origin is set: + return set[_resolve_annotation(args[0])] + + elif origin is frozenset: + return frozenset[_resolve_annotation(args[0])] + + elif origin is Union or origin is types.UnionType: + resolved_args = tuple(_resolve_annotation(arg) for arg in args) + return reduce(lambda a, b: a | b, resolved_args[1:], resolved_args[0]) + + else: + raise TypeError(f"Unsupported generic annotation: {annotation!r}") + + +def _field_kwargs(field_name: str, field_info: object) -> dict[str, object]: + """ + Collect supported field metadata for ``pydantic.create_model``. + """ + + kwargs: dict[str, object] = {} + + description = getattr(field_info, "description", None) + if description is not None: + kwargs["description"] = description + + title = getattr(field_info, "title", None) + if title is not None: + kwargs["title"] = title + + examples = getattr(field_info, "examples", None) + if examples: + kwargs["examples"] = examples + + deprecated = getattr(field_info, "deprecated", None) + if deprecated is not None: + kwargs["deprecated"] = deprecated + + alias = getattr(field_info, "alias", None) + if alias and alias != field_name: + kwargs["alias"] = alias + + return kwargs + + +def _configuration_schema_from_constructor(cls: type) -> type[ConfigurationSchema]: + """ + Generate a configuration schema from a class constructor signature. + + The resulting schema contains one field for each constructor parameter, + with annotations and default values preserved. + """ + + if not isinstance(cls, type): + raise TypeError("cls must be a class.") + + fields = _fields_from_constructor_signature( + cls, + expand_arbitrary_types=True, + ) + + return create_model( + _generate_schema_name(cls), + __base__=ConfigurationSchema, + __module__=cls.__module__, + **fields, + ) + + +def _fields_from_constructor_signature(cls: type, expand_arbitrary_types: bool = False) -> dict[str, tuple[object, object]]: + """ + Extract field definitions from a class constructor signature. + + Parameters + ---------- + cls + Class whose ``__init__`` method should be inspected. + expand_arbitrary_types + If true, unsupported annotations are expanded recursively into + schema-friendly types. + """ + + signature = inspect.signature(cls.__init__) + type_hints = get_type_hints(cls.__init__, include_extras=True) + + fields: dict[str, tuple[Any, Any]] = {} + + # Skip *args and **kwargs + for name, param in signature.parameters.items(): + if name == "self": + continue + + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + + if name in RESERVED_CONFIGURATION_FIELDS: + raise ValueError( + f"{cls.__name__}.__init__ defines reserved parameter {name!r}, which is owned by ConfigurationSchema." + ) + + annotation = type_hints.get(name, Any) + + if expand_arbitrary_types: + annotation = _resolve_annotation(annotation) + + default = param.default if param.default is not inspect._empty else ... + fields[name] = (annotation, default) + + return fields diff --git a/tests/validation/test_schema_builder.py b/tests/validation/test_schema_builder.py new file mode 100644 index 000000000..8c84dd13f --- /dev/null +++ b/tests/validation/test_schema_builder.py @@ -0,0 +1,173 @@ +"""Tests of the schema builder.""" + +import inspect +from collections.abc import Generator +from enum import Enum +from typing import Annotated, Any, Literal, get_args, get_origin + +import pytest +from pydantic import BaseModel, Field + +from pyaml.validation.configuration_models import ConfigurationSchema +from pyaml.validation.registry import SchemaRegistry +from pyaml.validation.schema_builder import ( + _configuration_schema_from_basemodel, + _field_definition_from_field_info, + _fields_from_constructor_signature, + _generate_schema_name, + _resolve_annotation, + generate_configuration_schema, +) + + +@pytest.fixture(autouse=True) +def clear_registry() -> Generator[None, None, None]: + """Ensure the registry is empty for each test.""" + + registry = SchemaRegistry() + registry.clear() + yield + registry.clear() + + +class Color(Enum): + RED = "red" + + +def test_generate_schema_name(): + class Magnet: + pass + + assert _generate_schema_name(Magnet) == "MagnetConfigurationSchema" + + +def test_generate_configuration_schema_requires_class(): + with pytest.raises(TypeError, match="Source must be a class"): + generate_configuration_schema(1) + + +def test_generate_configuration_schema_from_validation_model(): + class ValidationModel(BaseModel): + strength: int + name: str = "Q1" + + class Magnet: + validation_model = ValidationModel + + schema = generate_configuration_schema(Magnet) + + assert issubclass(schema, ConfigurationSchema) + assert schema.__name__ == "MagnetConfigurationSchema" + assert schema.model_fields["strength"].annotation is int + assert schema.model_fields["name"].default == "Q1" + + registry = SchemaRegistry() + assert registry.get(f"{Magnet.__module__}.{Magnet.__name__}") is schema + + +def test_generate_configuration_schema_from_constructor(): + class Magnet: + def __init__(self, strength: int, name: str = "Q1"): + pass + + schema = generate_configuration_schema(Magnet) + + assert issubclass(schema, ConfigurationSchema) + assert schema.model_fields["strength"].annotation is int + assert schema.model_fields["name"].default == "Q1" + + registry = SchemaRegistry() + assert registry.get(f"{Magnet.__module__}.{Magnet.__name__}") is schema + + +def test_generate_configuration_schema_is_cached(): + class Magnet: + def __init__(self, strength: int): + pass + + schema1 = generate_configuration_schema(Magnet) + schema2 = generate_configuration_schema(Magnet) + + assert schema1 is schema2 + + +def test_configuration_schema_from_basemodel_rejects_reserved_field(): + class ValidationModel(BaseModel): + class_path: str + + with pytest.raises(ValueError, match="reserved field"): + _configuration_schema_from_basemodel( + ValidationModel, + "BadConfigurationSchema", + __name__, + ) + + +def test_fields_from_constructor_signature(): + class Example: + def __init__(self, x: int, y: str = "a", *args, **kwargs): + pass + + fields = _fields_from_constructor_signature(Example) + + assert fields["x"] == (int, ...) + assert fields["y"] == (str, "a") + assert set(fields) == {"x", "y"} + + +def test_fields_from_constructor_signature_rejects_reserved_parameter(): + class Example: + def __init__(self, class_path: str): + pass + + with pytest.raises(ValueError, match="reserved parameter"): + _fields_from_constructor_signature(Example) + + +def test_resolve_annotation(): + assert _resolve_annotation(inspect._empty) is Any + assert _resolve_annotation(None) is type(None) + assert _resolve_annotation(int) is int + assert _resolve_annotation(Color) is Color + assert _resolve_annotation(list[int]) == list[int] + assert _resolve_annotation(dict[str, int]) == dict[str, int] + assert _resolve_annotation(tuple[int, ...]) == tuple[int, ...] + assert _resolve_annotation(tuple[int, str]) == tuple[int, str] + assert _resolve_annotation(set[int]) == set[int] + assert _resolve_annotation(frozenset[int]) == frozenset[int] + assert _resolve_annotation(int | str) == (int | str) + assert _resolve_annotation(Literal["a", "b"]) == Literal["a", "b"] + + annotated = _resolve_annotation(Annotated[int, "meta"]) + assert get_origin(annotated) is Annotated + assert get_args(annotated) == (int, "meta") + + +def test_resolve_annotation_rejects_forward_reference(): + with pytest.raises(TypeError, match="Forward references"): + _resolve_annotation("SomeClass") + + +def test_field_definition_from_field_info(): + class Model(BaseModel): + x: int = Field( + default_factory=lambda: 5, + description="desc", + title="title", + examples=[1], + deprecated=True, + alias="alias", + ) + + annotation, field = _field_definition_from_field_info( + "x", + Model.model_fields["x"], + ) + + assert annotation is int + assert field.default_factory is not None + assert field.alias == "alias" + assert field.description == "desc" + assert field.title == "title" + assert field.examples == [1] + assert field.deprecated is True From 0fd5c53afbd4fbd622ba57d4fadb1a5ddbfcbfee Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 25 Jun 2026 11:22:40 +0200 Subject: [PATCH 16/21] Add separate module formatting of validation errors. --- pyaml/validation/errors.py | 113 +++++++++++++++++++++ tests/validation/test_validation_errors.py | 104 +++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 pyaml/validation/errors.py create mode 100644 tests/validation/test_validation_errors.py diff --git a/pyaml/validation/errors.py b/pyaml/validation/errors.py new file mode 100644 index 000000000..8386d072f --- /dev/null +++ b/pyaml/validation/errors.py @@ -0,0 +1,113 @@ +"""Functionality for attaching location information to validation errors.""" + +from dataclasses import dataclass +from typing import Any + +from pydantic import ValidationError + +from ..common.exception import PyAMLConfigException + + +@dataclass(frozen=True) +class Location: + """ + Source location within a configuration file. + + Stores the file name together with the line and column at which a + configuration object or field was defined. + """ + + file: str + line: int + column: int + + def __str__(self) -> str: + return f"{self.file} at line {self.line}, column {self.column}." + + +@dataclass(frozen=True) +class LocationMetadata: + """ + Location metadata extracted from configuration data. + + Stores the source location of a configuration object together with + optional locations for individual configuration fields. + """ + + location: Location | None + field_locations: dict[str, Location] | None = None + + +def extract_location_metadata(data: dict[str, Any]) -> tuple[dict[str, Any], LocationMetadata]: + """ + Extract loader-added location metadata from configuration data. + + Returns a copy of the configuration dictionary with the metadata + removed together with the extracted location information. + """ + + cleaned = dict(data) + + # Get the location + raw_location = cleaned.pop("__location__", None) + location = Location(*raw_location) if raw_location is not None else None + + # Get the field locations + raw_field_locations = cleaned.pop("__fieldlocations__", None) + field_locations = ( + {field: Location(*raw_loc) for field, raw_loc in raw_field_locations.items()} + if raw_field_locations is not None + else None + ) + + return cleaned, LocationMetadata( + location=location, + field_locations=field_locations, + ) + + +def raise_validation_error( + exc: ValidationError, + class_path: str, + location_metadata: LocationMetadata | None = None, +) -> None: + """ + Raise a configuration exception from a Pydantic validation error. + + Validation messages are formatted into a human-readable error message. + If location metadata is available, source locations for the + configuration object and its fields are included in the reported + error. + """ + + messages: list[str] = [] + + for err in exc.errors(): + loc = err.get("loc", ()) + msg = err["msg"] + + if len(loc) == 2: + field, field_idx = loc + message = f"'{field}.{field_idx}': {msg}" + field_name = field + elif len(loc) == 1: + field_name = loc[0] + message = f"'{field_name}': {msg}" + else: + field_name = None + message = f"{loc}: {msg}" + + if ( + location_metadata is not None + and location_metadata.field_locations is not None + and field_name in location_metadata.field_locations + ): + message += f" ({location_metadata.field_locations[field_name]})" + + messages.append(message) + + location_str = "" + if location_metadata is not None and location_metadata.location is not None: + location_str = f" ({location_metadata.location})" + + raise PyAMLConfigException(f"{'; '.join(messages)} for class: '{class_path}'{location_str}") from None diff --git a/tests/validation/test_validation_errors.py b/tests/validation/test_validation_errors.py new file mode 100644 index 000000000..f0049d343 --- /dev/null +++ b/tests/validation/test_validation_errors.py @@ -0,0 +1,104 @@ +"""Tests of validation errors.""" + +import pytest +from pydantic import BaseModel, ValidationError + +from pyaml.common.exception import PyAMLConfigException +from pyaml.validation.errors import ( + Location, + LocationMetadata, + extract_location_metadata, + raise_validation_error, +) + + +def test_location_str_formats_readably(): + loc = Location(file="config.yaml", line=12, column=4) + assert str(loc) == "config.yaml at line 12, column 4." + + +def test_extract_location_metadata_removes_metadata_and_converts_values(): + data = { + "__location__": ("config.yaml", 22, 3), + "__fieldlocations__": { + "class": ("config.yaml", 22, 10), + "name": ("config.yaml", 23, 9), + }, + "class": "pkg.module.Class", + "name": "test_device", + } + + cleaned, metadata = extract_location_metadata(data) + + assert cleaned == { + "class": "pkg.module.Class", + "name": "test_device", + } + assert metadata.location == Location("config.yaml", 22, 3) + assert metadata.field_locations == { + "class": Location("config.yaml", 22, 10), + "name": Location("config.yaml", 23, 9), + } + + +def test_extract_location_metadata_without_metadata_returns_clean_data(): + data = {"class": "pkg.module.Class", "name": "test_device"} + + cleaned, metadata = extract_location_metadata(data) + + assert cleaned == data + assert metadata.location is None + assert metadata.field_locations is None + + +class SimpleModel(BaseModel): + age: int + + +class DeepNestedModel(BaseModel): + items: list[SimpleModel] + + +def get_validation_error(model: type[BaseModel], payload: dict) -> ValidationError: + with pytest.raises(ValidationError) as exc: + model.model_validate(payload) + return exc.value + + +def test_raise_validation_error_formats_error_with_location_metadata(): + exc = get_validation_error(SimpleModel, {"age": "not-an-int"}) + + metadata = LocationMetadata( + location=Location("config.yaml", 20, 1), + field_locations={ + "age": Location("config.yaml", 21, 7), + }, + ) + + with pytest.raises(PyAMLConfigException) as err: + raise_validation_error( + exc, + class_path="pkg.module.Class", + location_metadata=metadata, + ) + + message = str(err.value) + + assert "'age':" in message + assert "for class: 'pkg.module.Class'" in message + assert "config.yaml at line 21, column 7." in message + assert "config.yaml at line 20, column 1." in message + + +def test_raise_validation_error_formats_deep_nested_error_tuple_repr(): + exc = get_validation_error(DeepNestedModel, {"items": [{"age": "nope"}]}) + + with pytest.raises(PyAMLConfigException) as err: + raise_validation_error( + exc, + class_path="demo.DeepNestedModel", + ) + + message = str(err.value) + assert "('items', 0, 'age'):" in message + assert "for class: 'demo.DeepNestedModel'" in message From ac4fde989bbf81fdfc0660205e69654a9a4b29ff Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 25 Jun 2026 11:40:59 +0200 Subject: [PATCH 17/21] Add error formatting and translation from legacy config format to schema validator. --- pyaml/validation/validator.py | 82 +++++++++++++++++++++-------- tests/validation/test_validator.py | 83 +++++++++++++++++++++++++++--- 2 files changed, 138 insertions(+), 27 deletions(-) diff --git a/pyaml/validation/validator.py b/pyaml/validation/validator.py index da5ed5500..2032ff38f 100644 --- a/pyaml/validation/validator.py +++ b/pyaml/validation/validator.py @@ -6,7 +6,8 @@ from pydantic import ValidationError -from .models import ConfigurationSchema +from .configuration_models import ConfigurationSchema, ModuleConfigurationSchema +from .errors import extract_location_metadata, raise_validation_error from .registry import SchemaRegistry logger = logging.getLogger(__name__) @@ -62,6 +63,18 @@ def validate( return validated + @classmethod + def validate_to_dict( + cls, + data: dict[str, Any], + ) -> dict: + """ + Validate configuration data recursively and return it as a dictionary. + """ + + validated = cls.validate(data) + return validated.model_dump() + @classmethod def _recursive_validate(cls, obj: Any) -> Any: """Recursively validate nested configuration objects. @@ -95,15 +108,23 @@ def _recursive_validate(cls, obj: Any) -> Any: if not isinstance(obj, dict): return obj + # Remove loader metadata and keep it for error reporting. + cleaned_dict, location_metadata = extract_location_metadata(obj) + logger.debug("Validating dict with keys: %s", list(obj)) - validated_dict = {key: cls._recursive_validate(value) for key, value in obj.items()} + validated_dict = {key: cls._recursive_validate(value) for key, value in cleaned_dict.items()} # Check if the dict is a configuration object - config = cls._parse_configuration(validated_dict) - if config is None: + # If it follows the old convention for configuration + # translate the data to the new one + parsed = cls._parse_configuration(validated_dict) + + if parsed is None: return validated_dict + else: + validated_dict = parsed - class_path = config.class_path + class_path = validated_dict["class_path"] schema = cls._registry.get(class_path) if schema is None: @@ -113,32 +134,51 @@ def _recursive_validate(cls, obj: Any) -> Any: ) return validated_dict - return schema.model_validate(validated_dict) + try: + return schema.model_validate(validated_dict) + except ValidationError as exc: + raise_validation_error( + exc, + class_path=class_path, + location_metadata=location_metadata, + ) @classmethod def _parse_configuration( cls, validated_dict: dict[str, Any], - ) -> ConfigurationSchema | None: - """Parse a dictionary as configuration metadata. - - Parameters - ---------- - validated_dict : dict[str, Any] - Dictionary to interpret as configuration metadata. + ) -> dict[str, Any] | None: + """ + Interpret a dictionary as configuration metadata. - Returns - ------- - ConfigurationSchema | None - Parsed configuration model if validation succeeds, - otherwise ``None``. + Returns the dictionary unchanged if it already matches + :class:`ConfigurationSchema`. If it matches + :class:`ModuleConfigurationSchema`, the legacy module-based format is + rewritten into the modern ``class_path`` form. Otherwise returns + ``None``. """ + try: - return ConfigurationSchema.model_validate( + ConfigurationSchema.model_validate(validated_dict, extra="allow") + return validated_dict + except ValidationError: + logger.debug("Could not validate against ConfigurationSchema.") + + try: + module_config = ModuleConfigurationSchema.model_validate( validated_dict, extra="allow", ) except ValidationError: - logger.debug("Could not validate against ConfigurationSchema.") + logger.debug("Could not validate against ModuleConfigurationSchema; returning raw dict.") + return None + + class_config = module_config.to_configuration() + + rewritten = dict(validated_dict) + rewritten["class_path"] = class_config.class_path + rewritten.pop("type", None) + rewritten.pop("module", None) - return None + logger.debug("Configuration transformed from legacy configuration format.") + return rewritten diff --git a/tests/validation/test_validator.py b/tests/validation/test_validator.py index e7eec6e95..ad6ff995e 100644 --- a/tests/validation/test_validator.py +++ b/tests/validation/test_validator.py @@ -1,9 +1,12 @@ """Tests of the schema validator.""" +import sys from collections.abc import Generator +from types import ModuleType import pytest +from pyaml.common.exception import PyAMLConfigException from pyaml.validation import ( ConfigurationSchema, SchemaRegistry, @@ -137,25 +140,47 @@ def test_recursive_validate_warns_for_unknown_schema( # ========================================================== -def test_parse_configuration_returns_configuration_schema(): +def test_parse_configuration_returns_none_for_non_configuration_dict(): + data = { + "plain": "dict", + } + + result = SchemaValidator._parse_configuration(data) + + assert result is None + + +def test_parse_configuration_accepts_modern_configuration() -> None: data = { "class_path": "pkg.module.Class", + "value": 42, } result = SchemaValidator._parse_configuration(data) - assert isinstance(result, ConfigurationSchema) - assert result.class_path == "pkg.module.Class" + assert result == data -def test_parse_configuration_returns_none_for_non_configuration_dict(): +def test_parse_configuration_translates_legacy_module_configuration( + monkeypatch: pytest.MonkeyPatch, +) -> None: + module_name = "legacy_test_module" + + module = ModuleType(module_name) + module.PYAMLCLASS = "LegacyClass" + monkeypatch.setitem(sys.modules, module_name, module) + data = { - "plain": "dict", + "module": module_name, + "value": 42, } result = SchemaValidator._parse_configuration(data) - assert result is None + assert result == { + "class_path": f"{module_name}.LegacyClass", + "value": 42, + } # ========================================================== @@ -190,3 +215,49 @@ def test_validate_raises_typeerror_for_non_configuration_dict(): match=r"Top-level configuration did not validate to a ConfigurationSchema\.", ): SchemaValidator.validate(data) + + +def test_validate_to_dict_returns_dict(registry: SchemaRegistry) -> None: + registry.register("pkg.module.Class", DummySchema) + + data = { + "class_path": "pkg.module.Class", + "value": 42, + } + + result = SchemaValidator.validate_to_dict(data) + + assert result == { + "class_path": "pkg.module.Class", + "value": 42, + } + + +# ========================================================== +# Error handling +# ========================================================== + + +def test_recursive_validate_includes_location_metadata_in_error( + registry: SchemaRegistry, +) -> None: + registry.register("pkg.module.Class", DummySchema) + + data = { + "__location__": ("config.yaml", 10, 4), + "__fieldlocations__": { + "value": ("config.yaml", 11, 8), + }, + "class_path": "pkg.module.Class", + "value": "not-an-int", + } + + with pytest.raises(PyAMLConfigException) as exc_info: + SchemaValidator._recursive_validate(data) + + message = str(exc_info.value) + + assert "pkg.module.Class" in message + assert "config.yaml at line 10, column 4." in message + assert "config.yaml at line 11, column 8." in message + assert "'value'" in message From 0525f73e9f54f642576ef26cca76b727e7eb0642 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 25 Jun 2026 16:14:48 +0200 Subject: [PATCH 18/21] Add option to register dynamically generated schema. --- pyaml/validation/registry.py | 89 ++++++++++++++++++------------ pyaml/validation/schema_builder.py | 4 ++ tests/validation/test_registry.py | 31 +++++++++++ 3 files changed, 90 insertions(+), 34 deletions(-) diff --git a/pyaml/validation/registry.py b/pyaml/validation/registry.py index ea1239f75..4ba9bca29 100644 --- a/pyaml/validation/registry.py +++ b/pyaml/validation/registry.py @@ -4,9 +4,9 @@ import logging import pkgutil from collections.abc import ItemsView, Iterator, KeysView, ValuesView -from typing import Callable, Type, TypeVar +from typing import Callable, Type, TypeVar, overload -from .models import ConfigurationSchema +from .configuration_models import ConfigurationSchema logger = logging.getLogger(__name__) @@ -301,50 +301,71 @@ def update( # Decorator to register schemas # ========================================================== -ModelT = TypeVar("ModelT", bound=ConfigurationSchema) ClassT = TypeVar("ClassT") -def register_schema( - schema: Type[ModelT], -) -> Callable[[Type[ClassT]], Type[ClassT]]: - """Register a runtime class with a Pydantic schema. +@overload +def register_schema(cls: type[ClassT]) -> type[ClassT]: ... - Parameters - ---------- - schema : Type[ModelT] - Schema class to register. Must inherit from - :class:`ConfigurationSchema`. - Returns - ------- - Callable[[Type[ClassT]], Type[ClassT]] - Decorator that registers the decorated class with ``schema``. +@overload +def register_schema(schema: type[ConfigurationSchema]) -> Callable[[type[ClassT]], type[ClassT]]: ... - Examples - -------- - >>> @register_schema(MySchema) - ... class MyClass: - ... pass + +@overload +def register_schema() -> Callable[[type[ClassT]], type[ClassT]]: ... + + +def register_schema(arg: type | None = None): """ + Register a configuration schema for a class. - if not (isinstance(schema, type) and issubclass(schema, ConfigurationSchema)): - raise TypeError("register_schema must be called with a schema class, e.g. @register_schema(MySchema)") + This decorator supports three forms: - registry = SchemaRegistry() + - ``@register_schema``: generate and register a configuration schema + from the decorated class. + - ``@register_schema()``: equivalent to ``@register_schema``. + - ``@register_schema(MyConfigurationSchema)``: register an explicit + :class:`ConfigurationSchema` subclass for the decorated class. - def decorator( - cls: Type[ClassT], - ) -> Type[ClassT]: - class_path = f"{cls.__module__}.{cls.__name__}" + Automatically generated schemas are created using + :func:`generate_configuration_schema` and registered in the + :class:`SchemaRegistry`. + """ - logger.debug("Register schema for %s.", class_path) + from .schema_builder import generate_configuration_schema - registry.register( - class_path=class_path, - schema=schema, - ) + registry = SchemaRegistry() + def _generate_and_register_schema(cls: type[ClassT]) -> type[ClassT]: + generate_configuration_schema(cls) return cls - return decorator + # Used as: @register_schema(schema) + # Explicit registration of schema + if isinstance(arg, type) and issubclass(arg, ConfigurationSchema): + schema = arg + + def decorator(cls: type[ClassT]) -> type[ClassT]: + class_path = f"{cls.__module__}.{cls.__name__}" + logger.debug("Register schema for %s.", class_path) + registry.register(class_path=class_path, schema=schema) + return cls + + return decorator + + # Used as: @register_schema() + # Registration is done when generating the schema + if arg is None: + + def decorator(cls: type[ClassT]) -> type[ClassT]: + return _generate_and_register_schema(cls) + + return decorator + + # Used as: @register_schema + # Registration is done when generating the schema + if isinstance(arg, type): + return _generate_and_register_schema(arg) + + raise TypeError("register_schema must be used as a decorator or decorator factory.") diff --git a/pyaml/validation/schema_builder.py b/pyaml/validation/schema_builder.py index 3281091cf..652a30fbe 100644 --- a/pyaml/validation/schema_builder.py +++ b/pyaml/validation/schema_builder.py @@ -1,6 +1,7 @@ """Functionality for dynamically generating configuration schemas.""" import inspect +import logging import types from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -17,6 +18,8 @@ from .configuration_models import ConfigurationSchema from .registry import SchemaRegistry +logger = logging.getLogger(__name__) + RESERVED_CONFIGURATION_FIELDS = {"class_path"} SUPPORTED_TYPES = ( @@ -76,6 +79,7 @@ def generate_configuration_schema(source: type) -> type[ConfigurationSchema]: else: schema = _configuration_schema_from_constructor(source) + logger.debug("Register schema for %s.", class_path) registry.register(class_path, schema) return schema diff --git a/tests/validation/test_registry.py b/tests/validation/test_registry.py index ce68aa822..df5b68e6f 100644 --- a/tests/validation/test_registry.py +++ b/tests/validation/test_registry.py @@ -345,3 +345,34 @@ class SecondClass: assert registry[first_path] is DummySchema assert registry[second_path] is DummySchema assert len(registry) == 2 + + +def test_register_schema_generates_and_registers_schema_for_class(registry: SchemaRegistry): + @register_schema + class DecoratedClass: + def __init__(self, value: int): + pass + + class_path = f"{DecoratedClass.__module__}.{DecoratedClass.__name__}" + + schema = registry[class_path] + assert schema.__name__ == "DecoratedClassConfigurationSchema" + assert "value" in schema.model_fields + + +def test_register_schema_with_empty_call_generates_and_registers_schema( + registry: SchemaRegistry, +): + @register_schema() + class DecoratedClass: + def __init__(self, value: int): + pass + + class_path = f"{DecoratedClass.__module__}.{DecoratedClass.__name__}" + + assert class_path in registry + + +def test_register_schema_rejects_non_schema_explicit_argument(): + with pytest.raises(TypeError): + register_schema(123) From 78aca2631bbd1ad1e1bf24aa3f25f373d2e48494 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 25 Jun 2026 16:15:55 +0200 Subject: [PATCH 19/21] Update import in schema generator after changes. --- pyaml/validation/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyaml/validation/generator.py b/pyaml/validation/generator.py index 874ee4afa..d890b046c 100644 --- a/pyaml/validation/generator.py +++ b/pyaml/validation/generator.py @@ -9,7 +9,7 @@ from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import core_schema -from .models import ConfigurationSchema +from .configuration_models import ConfigurationSchema from .registry import SchemaRegistry logger = logging.getLogger(__name__) From a71f35e43aab5947de8dec31cd70be9b9be58c65 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Thu, 25 Jun 2026 16:19:07 +0200 Subject: [PATCH 20/21] Changes to RF to remove config models. --- pyaml/rf/rf_plant.py | 3 ++- pyaml/rf/rf_transmitter.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pyaml/rf/rf_plant.py b/pyaml/rf/rf_plant.py index a9ad488b9..1f835d905 100644 --- a/pyaml/rf/rf_plant.py +++ b/pyaml/rf/rf_plant.py @@ -4,13 +4,14 @@ from .. import PyAMLException from ..common import abstract from ..common.element import Element -from ..validation import DynamicValidation +from ..validation import DynamicValidation, register_schema from .rf_transmitter import RFTransmitter # Define the main class name for this module PYAMLCLASS = "RFPlant" +@register_schema class RFPlant(Element, DynamicValidation): """ Main RF object diff --git a/pyaml/rf/rf_transmitter.py b/pyaml/rf/rf_transmitter.py index 2dcdd5098..8c298a907 100644 --- a/pyaml/rf/rf_transmitter.py +++ b/pyaml/rf/rf_transmitter.py @@ -4,12 +4,13 @@ from .. import PyAMLException from ..common import abstract from ..common.element import Element -from ..validation import DynamicValidation +from ..validation import DynamicValidation, register_schema # Define the main class name for this module PYAMLCLASS = "RFTransmitter" +@register_schema class RFTransmitter(Element, DynamicValidation): """ Class that handle a RF transmitter From f814df4faaf5f9890ffa8118d586a4b406ed8547 Mon Sep 17 00:00:00 2001 From: Teresia Olsson Date: Fri, 26 Jun 2026 17:52:52 +0200 Subject: [PATCH 21/21] Make validation at object creation optional. --- pyaml/validation/validation_models.py | 20 +++++++++++++++++--- tests/validation/test_models.py | 12 ++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/pyaml/validation/validation_models.py b/pyaml/validation/validation_models.py index 129710428..afecfb361 100644 --- a/pyaml/validation/validation_models.py +++ b/pyaml/validation/validation_models.py @@ -32,16 +32,26 @@ class ValidationMeta(type): Whenever an instance is created, the supplied constructor arguments are validated against the model before the class constructor is invoked. + + Pass `validate=False` to skip validation for a single construction. """ def __call__(cls, *args: Any, **kwargs: Any): """ - Create an instance after validating constructor arguments. + Create an instance after optionally validating constructor arguments. The supplied arguments are bound to the class ``__init__`` signature, default values are applied, and the resulting argument mapping is - validated using ``validation_model``. The validated values are then - passed to the constructor. + validated using ``validation_model`` unless ``validate=False`` is + passed to the constructor. The validated values are then passed to the + constructor. + + Parameters + ---------- + validate + If ``True`` (default), validate constructor arguments before + instantiation. If ``False``, skip validation and pass the supplied + arguments directly to the constructor. Raises ------ @@ -52,6 +62,10 @@ def __call__(cls, *args: Any, **kwargs: Any): If the supplied arguments do not conform to the validation model. """ + validate = kwargs.pop("validate", True) + + if not validate: + return super().__call__(*args, **kwargs) validation_model = getattr(cls, "validation_model", None) diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py index 6bab84987..45ddbc068 100644 --- a/tests/validation/test_models.py +++ b/tests/validation/test_models.py @@ -193,6 +193,18 @@ def __init__(self, name: str): self.name = name +def test_dynamic_validation_can_be_disabled_per_instance(): + class MyClass(DynamicValidation): + def __init__(self, name: str, count: int): + self.name = name + self.count = count + + obj = MyClass(name="test", count="12", validate=False) + + assert obj.name == "test" + assert obj.count == "12" + + # ========================================================== # StaticValidation # ==========================================================