Skip to content

Commit 41af34e

Browse files
riemanli and The Meridian Authors
authored and committed
Add data-to-parameter ratio check to EDA engine.
PiperOrigin-RevId: 893707590
1 parent 703ea65 commit 41af34e

4 files changed

Lines changed: 155 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Add EDA check for data-to-parameter ratio (DATA_ADEQUACY).
2627
* Add JAX 64-bit precision opt-in configuration.
2728
* Ensure consistent float precision across tensors, NumPy arrays, and prior
2829
distributions.

meridian/model/eda/eda_engine.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from collections.abc import Collection, Sequence
2020
import dataclasses
2121
import functools
22+
import textwrap
2223
import typing
2324
from typing import Protocol
2425
import warnings
@@ -2492,3 +2493,66 @@ def check_population_corr_raw_media(
24922493
explanation=eda_constants.POPULATION_CORRELATION_RAW_MEDIA_INFO,
24932494
check_name='check_population_corr_raw_media',
24942495
)
2496+
2497+
def check_data_param_ratio(
    self,
) -> eda_outcome.EDAOutcome[eda_outcome.DataParameterRatioArtifact]:
  """Reports how many data points are available per model parameter.

  The number of data points is `n_geos * n_times`. The number of parameters
  is `(n_geos - 1) + n_knots + n_controls + n_treatments`, where
  `n_treatments` adds up the media, RF, organic media, organic RF and
  non-media channel counts read from the model context. If the parameter
  count is not positive, the ratio is reported as infinity.

  Returns:
    An EDAOutcome of type DATA_ADEQUACY carrying one INFO finding and the
    DataParameterRatioArtifact that the finding refers to.
  """
  ctx = self._model_context
  n_geos = ctx.n_geos
  n_times = ctx.n_times
  n_knots = ctx.knot_info.n_knots
  n_controls = ctx.n_controls

  # Every paid, organic and non-media channel counts as one treatment.
  n_treatments = sum((
      ctx.n_media_channels,
      ctx.n_rf_channels,
      ctx.n_organic_media_channels,
      ctx.n_organic_rf_channels,
      ctx.n_non_media_channels,
  ))

  param_count = (n_geos - 1) + n_knots + n_controls + n_treatments
  obs_count = n_geos * n_times

  # Guard the division: a non-positive parameter count maps to infinity.
  if param_count > 0:
    ratio = obs_count / param_count
  else:
    ratio = float('inf')

  ratio_artifact = eda_outcome.DataParameterRatioArtifact(
      level=eda_outcome.AnalysisLevel.OVERALL,
      n_parameters=param_count,
      n_data_points=obs_count,
      ratio=ratio,
  )

  # The explicit `\n` escapes below insert blank lines between paragraphs.
  explanation = textwrap.dedent(f"""\
      As a rough guidance, please review the ratio of data points to
      parameters, where
      * the number of data points = n_geos * n_times,
      * the number of parameters = (n_geos-1) + n_knots + n_controls + n_treatments.\n
      A very small ratio indicates insufficient data for estimation.
      In that case, consider dropping or combining channels,
      or reducing the number of knots with `knots` argument in `ModelSpec`.
      For more details, please refer to this documentation page:
      https://developers.google.com/meridian/docs/pre-modeling/amount-data-needed.\n
      This ratio is {ratio:.2f} for your dataset, where
      * n_geos = {n_geos}
      * n_times = {n_times}
      * n_knots = {n_knots}
      * n_controls = {n_controls}
      * n_treatments = {n_treatments}.""")

  info_finding = eda_outcome.EDAFinding(
      severity=eda_outcome.EDASeverity.INFO,
      explanation=explanation,
      finding_cause=eda_outcome.FindingCause.NONE,
      associated_artifact=ratio_artifact,
  )

  return eda_outcome.EDAOutcome(
      check_type=eda_outcome.EDACheckType.DATA_ADEQUACY,
      findings=[info_finding],
      analysis_artifacts=[ratio_artifact],
  )

meridian/model/eda/eda_engine_test.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from meridian import constants
2222
from meridian.backend import test_utils
2323
from meridian.model import context
24+
from meridian.model import knots
2425
from meridian.model import model_test_data
2526
from meridian.model import spec as model_spec
2627
from meridian.model.eda import constants as eda_constants
@@ -8060,6 +8061,76 @@ def test_check_population_corr_raw_media_returns_info_finding(self):
80608061
with self.subTest("finding_associated_artifact"):
80618062
self.assertIs(finding.associated_artifact, outcome.analysis_artifacts[0])
80628063

8064+
def test_check_data_param_ratio(self):
  """Verifies the outcome type, finding text and artifact of the ratio check."""
  n_knots = 5
  channel_counts = {
      "n_media_channels": 3,
      "n_rf_channels": 1,
      "n_organic_media_channels": 0,
      "n_organic_rf_channels": 0,
      "n_non_media_channels": 0,
  }
  dimension_counts = {"n_geos": 10, "n_times": 20, "n_controls": 2}

  mock_model_context = mock.create_autospec(
      context.ModelContext, instance=True, spec_set=True
  )
  for attr_name, attr_value in {**dimension_counts, **channel_counts}.items():
    setattr(mock_model_context, attr_name, attr_value)
  mock_model_context.knot_info = knots.KnotInfo(
      n_knots=n_knots, knot_locations=np.array([1]), weights=np.array([1])
  )
  # Sealing makes any attribute that was not configured above raise on access.
  mock.seal(mock_model_context)

  engine = eda_engine.EDAEngine(model_context=mock_model_context)
  outcome = engine.check_data_param_ratio()

  # Recompute the expected values independently of the engine.
  expected_n_treatments = sum(channel_counts.values())
  expected_n_parameters = (
      (dimension_counts["n_geos"] - 1)
      + n_knots
      + dimension_counts["n_controls"]
      + expected_n_treatments
  )
  expected_n_data_points = (
      dimension_counts["n_geos"] * dimension_counts["n_times"]
  )
  expected_ratio = expected_n_data_points / expected_n_parameters

  with self.subTest(name="check_type"):
    self.assertEqual(
        outcome.check_type, eda_outcome.EDACheckType.DATA_ADEQUACY
    )

  with self.subTest(name="finding_details"):
    self.assertLen(outcome.findings, 1)
    finding = outcome.findings[0]
    self.assertEqual(finding.severity, eda_outcome.EDASeverity.INFO)
    expected_snippets = (
        "As a rough guidance, please review the ratio of data points to",
        f"This ratio is {expected_ratio:.2f} for your dataset",
        f"* n_treatments = {expected_n_treatments}.",
    )
    for snippet in expected_snippets:
      self.assertIn(snippet, finding.explanation)

  with self.subTest(name="artifact_details"):
    self.assertLen(outcome.analysis_artifacts, 1)
    artifact = outcome.analysis_artifacts[0]
    self.assertIsInstance(artifact, eda_outcome.DataParameterRatioArtifact)
    self.assertEqual(artifact.n_parameters, expected_n_parameters)
    self.assertEqual(artifact.n_data_points, expected_n_data_points)
    self.assertAlmostEqual(artifact.ratio, expected_ratio)
8133+
80638134

80648135
class HelpersTest(test_utils.MeridianTestCase):
80658136

@@ -8174,6 +8245,5 @@ def test_get_triangle_corr_mat_with_geo(self, lower, expected_values):
81748245
actual = eda_engine.get_triangle_corr_mat(da, lower=lower)
81758246
np.testing.assert_allclose(actual.values, expected_values)
81768247

8177-
81788248
if __name__ == "__main__":
81798249
absltest.main()

meridian/model/eda/eda_outcome.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323

2424
__all__ = (
2525
"EDASeverity",
26-
"EDAFinding",
26+
"FindingCause",
2727
"AnalysisLevel",
2828
"AnalysisArtifact",
29-
"FindingCause",
29+
"EDAFinding",
3030
"PairwiseCorrArtifact",
3131
"StandardDeviationArtifact",
3232
"VIFArtifact",
@@ -35,8 +35,8 @@
3535
"VariableGeoTimeCollinearityArtifact",
3636
"PopulationCorrelationArtifact",
3737
"PriorProbabilityArtifact",
38+
"DataParameterRatioArtifact",
3839
"EDACheckType",
39-
"ArtifactType",
4040
"EDAOutcome",
4141
"CriticalCheckEDAOutcomes",
4242
)
@@ -260,6 +260,21 @@ class PriorProbabilityArtifact(AnalysisArtifact):
260260
mean_prior_contribution_da: xr.DataArray
261261

262262

263+
@dataclasses.dataclass(frozen=True)
class DataParameterRatioArtifact(AnalysisArtifact):
  """Immutable artifact holding the data-to-parameter ratio figures.

  Attributes:
    n_parameters: Count of model parameters.
    n_data_points: Count of data points.
    ratio: Data points per parameter.
  """

  n_parameters: int
  n_data_points: int
  ratio: float
276+
277+
263278
@enum.unique
264279
class EDACheckType(enum.Enum):
265280
"""Enumeration for the type of an EDA check."""
@@ -272,6 +287,7 @@ class EDACheckType(enum.Enum):
272287
VARIABLE_GEO_TIME_COLLINEARITY = enum.auto()
273288
POPULATION_CORRELATION = enum.auto()
274289
PRIOR_PROBABILITY = enum.auto()
290+
DATA_ADEQUACY = enum.auto()
275291

276292

277293
ArtifactType = typing.TypeVar("ArtifactType", bound=AnalysisArtifact)

0 commit comments

Comments
 (0)