Skip to content

Commit c39fba2

Browse files
authored
[tests] fix autoencoderdc tests (#13424)
* fix autoencoderdc tests * up
1 parent 24b4c25 commit c39fba2

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

tests/models/autoencoders/test_models_autoencoder_dc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828

2929

3030
class AutoencoderDCTesterConfig(BaseModelTesterConfig):
31+
@property
32+
def main_input_name(self):
33+
return "sample"
34+
3135
@property
3236
def model_class(self):
3337
return AutoencoderDC
@@ -77,6 +81,12 @@ def get_dummy_inputs(self):
7781
class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
7882
base_precision = 1e-2
7983

84+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
85+
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
86+
if dtype == torch.bfloat16 and IS_GITHUB_ACTIONS:
87+
pytest.skip("Skipping bf16 test inside GitHub Actions environment")
88+
super().test_from_save_pretrained_dtype_inference(tmp_path, dtype)
89+
8090

8191
class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
8292
"""Training tests for AutoencoderDC."""

0 commit comments

Comments
 (0)