|
21 | 21 | from meridian import constants |
22 | 22 | from meridian.backend import test_utils |
23 | 23 | from meridian.model import context |
| 24 | +from meridian.model import knots |
24 | 25 | from meridian.model import model_test_data |
25 | 26 | from meridian.model import spec as model_spec |
26 | 27 | from meridian.model.eda import constants as eda_constants |
@@ -8060,6 +8061,76 @@ def test_check_population_corr_raw_media_returns_info_finding(self): |
8060 | 8061 | with self.subTest("finding_associated_artifact"): |
8061 | 8062 | self.assertIs(finding.associated_artifact, outcome.analysis_artifacts[0]) |
8062 | 8063 |
|
| 8064 | + def test_check_data_param_ratio(self): |
| 8065 | + mock_model_context = mock.create_autospec( |
| 8066 | + context.ModelContext, instance=True, spec_set=True |
| 8067 | + ) |
| 8068 | + |
| 8069 | + mock_model_context.n_geos = 10 |
| 8070 | + mock_model_context.n_times = 20 |
| 8071 | + mock_model_context.n_controls = 2 |
| 8072 | + mock_model_context.n_media_channels = 3 |
| 8073 | + mock_model_context.n_rf_channels = 1 |
| 8074 | + mock_model_context.n_organic_media_channels = 0 |
| 8075 | + mock_model_context.n_organic_rf_channels = 0 |
| 8076 | + mock_model_context.n_non_media_channels = 0 |
| 8077 | + n_knots = 5 |
| 8078 | + mock_model_context.knot_info = knots.KnotInfo( |
| 8079 | + n_knots=n_knots, knot_locations=np.array([1]), weights=np.array([1]) |
| 8080 | + ) |
| 8081 | + mock.seal(mock_model_context) |
| 8082 | + |
| 8083 | + engine = eda_engine.EDAEngine(model_context=mock_model_context) |
| 8084 | + outcome = engine.check_data_param_ratio() |
| 8085 | + |
| 8086 | + expected_n_treatments = ( |
| 8087 | + mock_model_context.n_media_channels |
| 8088 | + + mock_model_context.n_rf_channels |
| 8089 | + + mock_model_context.n_organic_media_channels |
| 8090 | + + mock_model_context.n_organic_rf_channels |
| 8091 | + + mock_model_context.n_non_media_channels |
| 8092 | + ) |
| 8093 | + expected_n_parameters = ( |
| 8094 | + (mock_model_context.n_geos - 1) |
| 8095 | + + n_knots |
| 8096 | + + mock_model_context.n_controls |
| 8097 | + + expected_n_treatments |
| 8098 | + ) |
| 8099 | + expected_n_data_points = ( |
| 8100 | + mock_model_context.n_geos * mock_model_context.n_times |
| 8101 | + ) |
| 8102 | + expected_ratio = expected_n_data_points / expected_n_parameters |
| 8103 | + |
| 8104 | + with self.subTest(name="check_type"): |
| 8105 | + self.assertEqual( |
| 8106 | + outcome.check_type, eda_outcome.EDACheckType.DATA_ADEQUACY |
| 8107 | + ) |
| 8108 | + |
| 8109 | + with self.subTest(name="finding_details"): |
| 8110 | + self.assertLen(outcome.findings, 1) |
| 8111 | + finding = outcome.findings[0] |
| 8112 | + self.assertEqual(finding.severity, eda_outcome.EDASeverity.INFO) |
| 8113 | + self.assertIn( |
| 8114 | + "As a rough guidance, please review the ratio of data points to", |
| 8115 | + finding.explanation, |
| 8116 | + ) |
| 8117 | + self.assertIn( |
| 8118 | + f"This ratio is {expected_ratio:.2f} for your dataset", |
| 8119 | + finding.explanation, |
| 8120 | + ) |
| 8121 | + self.assertIn( |
| 8122 | + f"* n_treatments = {expected_n_treatments}.", |
| 8123 | + finding.explanation, |
| 8124 | + ) |
| 8125 | + |
| 8126 | + with self.subTest(name="artifact_details"): |
| 8127 | + self.assertLen(outcome.analysis_artifacts, 1) |
| 8128 | + artifact = outcome.analysis_artifacts[0] |
| 8129 | + self.assertIsInstance(artifact, eda_outcome.DataParameterRatioArtifact) |
| 8130 | + self.assertEqual(artifact.n_parameters, expected_n_parameters) |
| 8131 | + self.assertEqual(artifact.n_data_points, expected_n_data_points) |
| 8132 | + self.assertAlmostEqual(artifact.ratio, expected_ratio) |
| 8133 | + |
8063 | 8134 |
|
8064 | 8135 | class HelpersTest(test_utils.MeridianTestCase): |
8065 | 8136 |
|
@@ -8174,6 +8245,5 @@ def test_get_triangle_corr_mat_with_geo(self, lower, expected_values): |
8174 | 8245 | actual = eda_engine.get_triangle_corr_mat(da, lower=lower) |
8175 | 8246 | np.testing.assert_allclose(actual.values, expected_values) |
8176 | 8247 |
|
8177 | | - |
8178 | 8248 | if __name__ == "__main__": |
8179 | 8249 | absltest.main() |
0 commit comments