@@ -1,4 +1,5 @@
 import copy
+from collections.abc import Callable
 
 import gguf
 import pytest
@@ -124,6 +125,67 @@ def unwrap_single_custom_layer(layer: torch.nn.Module):
     return unwrap_custom_layer(layer, orig_layer_type)
 
 
+class ZeroParamPatch(BaseLayerPatch):
+    """A minimal parameter patch that exercises the aggregated sidecar patch path."""
+
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
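+        # Return zeroed copies of the original parameters. The values themselves
+        # don't matter for these tests; the patch only needs to force the
+        # aggregated sidecar patch path.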
+        return {name: torch.zeros_like(param) for name, param in orig_parameters.items()}
+
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        return self
+
+    def calc_size(self) -> int:
+        return 0
+
+
+def _cpu_dtype_supported(
+    layer_factory: Callable[[], torch.nn.Module],
+    input_factory: Callable[[torch.dtype], torch.Tensor],
+    dtype: torch.dtype,
+) -> bool:
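+    # Probe support with a tiny forward pass: some half-precision ops are not
+    # implemented on CPU in every torch build, and those dtypes should be
+    # skipped rather than fail the suite.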
+    try:
+        layer = layer_factory().to(dtype=dtype)
+        input_tensor = input_factory(dtype)
+        with torch.no_grad():
+            _ = layer(input_tensor)
+        return True
+    except (RuntimeError, TypeError, NotImplementedError):
+        return False
+
+
+def _cpu_dtype_param(
+    dtype: torch.dtype,
+    layer_factory: Callable[[], torch.nn.Module],
+    input_factory: Callable[[torch.dtype], torch.Tensor],
+):
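+    # Wrap the dtype in a pytest.param carrying a skipif mark, so unsupported
+    # CPU dtypes are skipped at collection time instead of erroring at runtime.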
+    supported = _cpu_dtype_supported(layer_factory, input_factory, dtype)
+    return pytest.param(
+        dtype,
+        id=str(dtype).removeprefix("torch."),
+        marks=pytest.mark.skipif(not supported, reason=f"CPU {dtype} is not supported for this op"),
+    )
+
+
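+# Half-precision dtypes to exercise for each op; entries are skipped
+# automatically when the local torch build lacks CPU support for them.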
+LINEAR_CPU_MIXED_DTYPE_PARAMS = [
+    _cpu_dtype_param(torch.bfloat16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)),
+    _cpu_dtype_param(torch.float16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)),
+]
+
+
+CONV2D_CPU_MIXED_DTYPE_PARAMS = [
+    _cpu_dtype_param(
+        torch.bfloat16,
+        lambda: torch.nn.Conv2d(8, 16, 3),
+        lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype),
+    ),
+    _cpu_dtype_param(
+        torch.float16,
+        lambda: torch.nn.Conv2d(8, 16, 3),
+        lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype),
+    ),
+]
+
+
 def test_isinstance(layer_under_test: LayerUnderTest):
     """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
     orig_layer, _, _ = layer_under_test
@@ -550,3 +612,67 @@ def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
 
     # Assert that the outputs with and without autocasting are the same.
     assert torch.allclose(expected_output, autocast_output, atol=1e-6)
+
+
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_inference_without_patches(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Linear(8, 16))
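+    # float32 weights vs. half-precision input: the custom layer is expected
+    # to resolve the mismatch so the output follows the input dtype.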
+    input_tensor = torch.randn(2, 8, dtype=dtype)
+
+    output = layer(input_tensor)
+
+    assert output.dtype == input_tensor.dtype
+    assert output.shape == (2, 16)
+
+
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_inference_without_patches_bias_only_mismatch(dtype: torch.dtype):
+    layer = torch.nn.Linear(8, 16).to(dtype=dtype)
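+    # Upcast only the bias so the half-precision weight and float32 bias
+    # disagree, exercising the bias-only dtype-mismatch path.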
+    layer.bias = torch.nn.Parameter(layer.bias.detach().to(torch.float32))
+    layer = wrap_single_custom_layer(layer)
+    input_tensor = torch.randn(2, 8, dtype=dtype)
+
+    output = layer(input_tensor)
+
+    assert output.dtype == input_tensor.dtype
+    assert output.shape == (2, 16)
+
+
+@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_conv2d_mixed_dtype_inference_without_patches(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
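+    # A 5x5 input through a 3x3 kernel with no padding yields a 3x3 spatial output.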
+    input_tensor = torch.randn(2, 8, 5, 5, dtype=dtype)
+
+    output = layer(input_tensor)
+
+    assert output.dtype == input_tensor.dtype
+    assert output.shape == (2, 16, 3, 3)
+
+
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Linear(8, 16))
+    layer.add_patch(ZeroParamPatch(), 1.0)
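+    # With a patch attached, the forward pass goes through the aggregated
+    # sidecar patch path, which must also tolerate mixed dtypes.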
+    input_tensor = torch.randn(2, 8, dtype=dtype)
+
+    output = layer(input_tensor)
+
+    assert output.dtype == input_tensor.dtype
+    assert output.shape == (2, 16)
+
+
+@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
+    layer.add_patch(ZeroParamPatch(), 1.0)
+    input_tensor = torch.randn(2, 8, 5, 5, dtype=dtype)
+
+    output = layer(input_tensor)
+
+    assert output.dtype == input_tensor.dtype
+    assert output.shape == (2, 16, 3, 3)