
Commit 7348cdd

Fix mixed-dtype autocast regressions

1 parent: e2c9b01

3 files changed

Lines changed: 157 additions & 19 deletions


invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py

Lines changed: 28 additions & 18 deletions
@@ -58,34 +58,40 @@ def autocast_linear_forward_sidecar_patches(
 
     # Finally, apply any remaining patches.
    if len(unprocessed_patches_and_weights) > 0:
+        weight, bias = orig_module._cast_weight_bias_for_input(input)
        # Prepare the original parameters for the patch aggregation.
-        orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
+        orig_params = {"weight": weight, "bias": bias}
        # Filter out None values.
        orig_params = {k: v for k, v in orig_params.items() if v is not None}
 
        aggregated_param_residuals = orig_module._aggregate_patch_parameters(
            unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
        )
-        output += torch.nn.functional.linear(
-            input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
-        )
+        residual_weight = orig_module._cast_tensor_for_input(aggregated_param_residuals["weight"], input)
+        residual_bias = orig_module._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input)
+        assert residual_weight is not None
+        output += torch.nn.functional.linear(input, residual_weight, residual_bias)
 
    return output
 
 
 class CustomLinear(torch.nn.Linear, CustomModuleMixin):
-    def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
-        weight = cast_to_device(self.weight, input.device)
-        bias = cast_to_device(self.bias, input.device)
+    def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
+        tensor = cast_to_device(tensor, input.device)
        if (
-            input.is_floating_point()
-            and weight.is_floating_point()
-            and not isinstance(weight, GGMLTensor)
-            and weight.dtype != input.dtype
+            tensor is not None
+            and input.is_floating_point()
+            and tensor.is_floating_point()
+            and not isinstance(tensor, GGMLTensor)
+            and tensor.dtype != input.dtype
        ):
-            weight = weight.to(dtype=input.dtype)
-            if bias is not None and not isinstance(bias, GGMLTensor):
-                bias = bias.to(dtype=input.dtype)
+            tensor = tensor.to(dtype=input.dtype)
+        return tensor
+
+    def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
+        weight = self._cast_tensor_for_input(self.weight, input)
+        bias = self._cast_tensor_for_input(self.bias, input)
+        assert weight is not None
        return weight, bias
 
    def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
@@ -100,10 +106,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self._autocast_forward_with_patches(input)
        elif self._device_autocasting_enabled:
            return self._autocast_forward(input)
-        elif (
-            input.is_floating_point()
-            and self.weight.is_floating_point()
-            and self.weight.dtype != input.dtype
+        elif input.is_floating_point() and (
+            (self.weight.is_floating_point() and self.weight.dtype != input.dtype)
+            or (
+                self.bias is not None
+                and self.bias.is_floating_point()
+                and not isinstance(self.bias, GGMLTensor)
+                and self.bias.dtype != input.dtype
+            )
        ):
            weight, bias = self._cast_weight_bias_for_input(input)
            return torch.nn.functional.linear(input, weight, bias)
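
Note: the following is a minimal standalone sketch (not the project's code) of the dtype-matching rule that the new _cast_tensor_for_input helper applies to both weight and bias; device movement and the GGMLTensor special case are omitted. It illustrates the bias-only mismatch the previous code could miss, since the bias cast was gated on the weight's dtype differing from the input's.

import torch

def cast_for_input(tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
    # Align a floating-point parameter's dtype with the activation dtype so that
    # torch.nn.functional.linear never sees a mixed-dtype weight/bias pair.
    if tensor is None:
        return None
    if input.is_floating_point() and tensor.is_floating_point() and tensor.dtype != input.dtype:
        tensor = tensor.to(dtype=input.dtype)
    return tensor

# bfloat16 activations against a float32 bias: the weight already matches, so only
# the bias needs casting, and the output dtype follows the input.
x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.randn(16, 8, dtype=torch.bfloat16)
b = torch.randn(16, dtype=torch.float32)
y = torch.nn.functional.linear(x, cast_for_input(w, x), cast_for_input(b, x))
assert y.dtype == torch.bfloat16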

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py

Lines changed: 3 additions & 1 deletion
@@ -49,7 +49,9 @@ def _aggregate_patch_parameters(
        # parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail.
        for param_name in orig_params.keys():
            param = orig_params[param_name]
-            if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
+            if isinstance(param, torch.nn.Parameter) and type(param.data) is torch.Tensor:
+                pass
+            elif type(param) is torch.Tensor:
                pass
            elif type(param) is GGMLTensor:
                # Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /
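
A note on the relaxed check above, as a plain sketch (the Parameter subclass below is hypothetical and only stands in for wrapper parameter types): isinstance accepts subclasses where an exact type(...) is comparison would reject them, and the new elif branch admits the plain tensors that custom_linear.py now passes in as orig_params.

import torch

class WrappedParameter(torch.nn.Parameter):
    """Hypothetical Parameter subclass standing in for any wrapped parameter type."""

p = WrappedParameter(torch.randn(4))

# The exact-type check rejects the subclass; isinstance accepts it.
assert type(p) is not torch.nn.Parameter
assert isinstance(p, torch.nn.Parameter)

# Pre-cast weights/biases arrive as plain tensors, which fall through to the
# new `elif type(param) is torch.Tensor` branch.
t = torch.randn(4)
assert not isinstance(t, torch.nn.Parameter) and type(t) is torch.Tensor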

tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py

Lines changed: 126 additions & 0 deletions
@@ -1,4 +1,5 @@
 import copy
+from collections.abc import Callable
 
 import gguf
 import pytest
@@ -124,6 +125,67 @@ def unwrap_single_custom_layer(layer: torch.nn.Module):
    return unwrap_custom_layer(layer, orig_layer_type)
 
 
+class ZeroParamPatch(BaseLayerPatch):
+    """A minimal parameter patch that exercises the aggregated sidecar patch path."""
+
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
+        return {name: torch.zeros_like(param) for name, param in orig_parameters.items()}
+
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        return self
+
+    def calc_size(self) -> int:
+        return 0
+
+
+def _cpu_dtype_supported(
+    layer_factory: Callable[[], torch.nn.Module],
+    input_factory: Callable[[torch.dtype], torch.Tensor],
+    dtype: torch.dtype,
+) -> bool:
+    try:
+        layer = layer_factory().to(dtype=dtype)
+        input_tensor = input_factory(dtype)
+        with torch.no_grad():
+            _ = layer(input_tensor)
+        return True
+    except (RuntimeError, TypeError, NotImplementedError):
+        return False
+
+
+def _cpu_dtype_param(
+    dtype: torch.dtype,
+    layer_factory: Callable[[], torch.nn.Module],
+    input_factory: Callable[[torch.dtype], torch.Tensor],
+):
+    supported = _cpu_dtype_supported(layer_factory, input_factory, dtype)
+    return pytest.param(
+        dtype,
+        id=str(dtype).removeprefix("torch."),
+        marks=pytest.mark.skipif(not supported, reason=f"CPU {dtype} is not supported for this op"),
+    )
+
+
+LINEAR_CPU_MIXED_DTYPE_PARAMS = [
+    _cpu_dtype_param(torch.bfloat16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)),
+    _cpu_dtype_param(torch.float16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)),
+]
+
+
+CONV2D_CPU_MIXED_DTYPE_PARAMS = [
+    _cpu_dtype_param(
+        torch.bfloat16,
+        lambda: torch.nn.Conv2d(8, 16, 3),
+        lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype),
+    ),
+    _cpu_dtype_param(
+        torch.float16,
+        lambda: torch.nn.Conv2d(8, 16, 3),
+        lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype),
+    ),
+]
+
+
 def test_isinstance(layer_under_test: LayerUnderTest):
    """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
    orig_layer, _, _ = layer_under_test
@@ -550,3 +612,67 @@ def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
 
    # Assert that the outputs with and without autocasting are the same.
    assert torch.allclose(expected_output, autocast_output, atol=1e-6)
+
+
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_inference_without_patches(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Linear(8, 16))
+    input = torch.randn(2, 8, dtype=dtype)
+
+    output = layer(input)
+
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16)
+
+
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_inference_without_patches_bias_only_mismatch(dtype: torch.dtype):
+    layer = torch.nn.Linear(8, 16).to(dtype=dtype)
+    layer.bias = torch.nn.Parameter(layer.bias.detach().to(torch.float32))
+    layer = wrap_single_custom_layer(layer)
+    input = torch.randn(2, 8, dtype=dtype)
+
+    output = layer(input)
+
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16)
+
+
+@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_conv2d_mixed_dtype_inference_without_patches(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
+    input = torch.randn(2, 8, 5, 5, dtype=dtype)
+
+    output = layer(input)
+
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16, 3, 3)
+
+
+@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_linear_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Linear(8, 16))
+    layer.add_patch(ZeroParamPatch(), 1.0)
+    input = torch.randn(2, 8, dtype=dtype)
+
+    output = layer(input)
+
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16)
+
+
+@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
+@torch.no_grad()
+def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
+    layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
+    layer.add_patch(ZeroParamPatch(), 1.0)
+    input = torch.randn(2, 8, 5, 5, dtype=dtype)
+
+    output = layer(input)
+
+    assert output.dtype == input.dtype
+    assert output.shape == (2, 16, 3, 3)
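
One way to run just the new coverage locally, assuming a repo checkout with pytest installed; the -k mixed_dtype filter simply matches the test names added above, and this call is equivalent to the pytest CLI.

# Hypothetical invocation from the repo root.
import pytest

pytest.main([
    "tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py",
    "-k", "mixed_dtype",
    "-q",
])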
