Skip to content

Commit 9dd89b0

Browse files
Bordashaypal5
andauthored
convert _default_params to dataclass (#237)
* convert `_default_params` to dataclass * lint * copy * set * keys * assert * global * global func * warning * fix import * lint --------- Co-authored-by: Shay Palachy-Affek <shaypal5@users.noreply.github.com>
1 parent 3bbe33b commit 9dd89b0

8 files changed

Lines changed: 92 additions & 65 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ repos:
4747
args: ["--number"]
4848

4949
- repo: https://github.com/pre-commit/mirrors-prettier
50-
rev: v4.0.0-alpha.8
50+
rev: v3.1.0
5151
hooks:
5252
- id: prettier
5353
files: \.(json|yml|yaml|toml)

src/cachier/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from .config import (
33
disable_caching,
44
enable_caching,
5-
get_default_params,
6-
set_default_params,
5+
get_global_params,
6+
set_global_params,
77
)
88
from .core import cachier
99

1010
__all__ = [
1111
"cachier",
12-
"set_default_params",
13-
"get_default_params",
12+
"set_global_params",
13+
"get_global_params",
1414
"enable_caching",
1515
"disable_caching",
1616
]

src/cachier/config.py

Lines changed: 62 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import hashlib
33
import os
44
import pickle
5-
from typing import Optional, TypedDict, Union
5+
from collections.abc import Mapping
6+
from dataclasses import dataclass, replace
7+
from typing import Optional, Union
68

79
from ._types import Backend, HashFunc, Mongetter
810

@@ -16,35 +18,24 @@ def _default_hash_func(args, kwds):
1618
return hashlib.sha256(serialized).hexdigest()
1719

1820

19-
class Params(TypedDict):
20-
"""Type definition for cachier parameters."""
21-
22-
caching_enabled: bool
23-
hash_func: HashFunc
24-
backend: Backend
25-
mongetter: Optional[Mongetter]
26-
stale_after: datetime.timedelta
27-
next_time: bool
28-
cache_dir: Union[str, os.PathLike]
29-
pickle_reload: bool
30-
separate_files: bool
31-
wait_for_calc_timeout: int
32-
allow_none: bool
33-
34-
35-
_default_params: Params = {
36-
"caching_enabled": True,
37-
"hash_func": _default_hash_func,
38-
"backend": "pickle",
39-
"mongetter": None,
40-
"stale_after": datetime.timedelta.max,
41-
"next_time": False,
42-
"cache_dir": "~/.cachier/",
43-
"pickle_reload": True,
44-
"separate_files": False,
45-
"wait_for_calc_timeout": 0,
46-
"allow_none": False,
47-
}
21+
@dataclass
22+
class Params:
23+
"""Default definition for cachier parameters."""
24+
25+
caching_enabled: bool = True
26+
hash_func: HashFunc = _default_hash_func
27+
backend: Backend = "pickle"
28+
mongetter: Optional[Mongetter] = None
29+
stale_after: datetime.timedelta = datetime.timedelta.max
30+
next_time: bool = False
31+
cache_dir: Union[str, os.PathLike] = "~/.cachier/"
32+
pickle_reload: bool = True
33+
separate_files: bool = False
34+
wait_for_calc_timeout: int = 0
35+
allow_none: bool = False
36+
37+
38+
_global_params = Params()
4839

4940

5041
def _update_with_defaults(
@@ -57,11 +48,25 @@ def _update_with_defaults(
5748
if kw_name in func_kwargs:
5849
return func_kwargs.pop(kw_name)
5950
if param is None:
60-
return cachier.config._default_params[name]
51+
return getattr(cachier.config._global_params, name)
6152
return param
6253

6354

64-
def set_default_params(**params):
55+
def set_default_params(**params: Mapping) -> None:
56+
"""Configure default parameters applicable to all memoized functions."""
57+
# It is kept for backwards compatibility with desperation warning
58+
import warnings
59+
60+
warnings.warn(
61+
"Called `set_default_params` is deprecated and will be removed."
62+
" Please use `set_global_params` instead.",
63+
DeprecationWarning,
64+
stacklevel=2,
65+
)
66+
set_global_params(**params)
67+
68+
69+
def set_global_params(**params: Mapping) -> None:
6570
"""Configure global parameters applicable to all memoized functions.
6671
6772
This function takes the same keyword parameters as the ones defined in the
@@ -76,28 +81,46 @@ def set_default_params(**params):
7681
"""
7782
import cachier
7883

79-
valid_params = (
80-
p for p in params.items() if p[0] in cachier.config._default_params
84+
valid_params = {
85+
k: v
86+
for k, v in params.items()
87+
if hasattr(cachier.config._global_params, k)
88+
}
89+
cachier.config._global_params = replace(
90+
cachier.config._global_params, **valid_params
91+
)
92+
93+
94+
def get_default_params() -> Params:
95+
"""Get current set of default parameters."""
96+
# It is kept for backwards compatibility with desperation warning
97+
import warnings
98+
99+
warnings.warn(
100+
"Called `get_default_params` is deprecated and will be removed."
101+
" Please use `get_global_params` instead.",
102+
DeprecationWarning,
103+
stacklevel=2,
81104
)
82-
_default_params.update(valid_params)
105+
return get_global_params()
83106

84107

85-
def get_default_params():
108+
def get_global_params() -> Params:
86109
"""Get current set of default parameters."""
87110
import cachier
88111

89-
return cachier.config._default_params
112+
return cachier.config._global_params
90113

91114

92115
def enable_caching():
93116
"""Enable caching globally."""
94117
import cachier
95118

96-
cachier.config._default_params["caching_enabled"] = True
119+
cachier.config._global_params.caching_enabled = True
97120

98121

99122
def disable_caching():
100123
"""Disable caching globally."""
101124
import cachier
102125

103-
cachier.config._default_params["caching_enabled"] = False
126+
cachier.config._global_params.caching_enabled = False

src/cachier/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Backend,
2222
HashFunc,
2323
Mongetter,
24-
_default_params,
2524
_update_with_defaults,
2625
)
2726
from .cores.base import RecalculationNeeded, _BaseCore
@@ -176,6 +175,8 @@ def cachier(
176175
None will not be cached and are recalculated every call.
177176
178177
"""
178+
from .config import _global_params
179+
179180
# Check for deprecated parameters
180181
if hash_params is not None:
181182
message = (
@@ -244,7 +245,7 @@ def func_wrapper(*args, **kwds):
244245
_print = lambda x: None # noqa: E731
245246
if verbose:
246247
_print = print
247-
if ignore_cache or not _default_params["caching_enabled"]:
248+
if ignore_cache or not _global_params.caching_enabled:
248249
return (
249250
func(args[0], **kwargs)
250251
if core.func_is_method

tests/test_core_lookup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import pytest
44

5-
from cachier import cachier, get_default_params
5+
from cachier import cachier, get_global_params
66
from cachier.cores.mongo import MissingMongetter
77

88

99
def test_get_default_params():
10-
params = get_default_params()
11-
assert tuple(sorted(params)) == (
10+
params = get_global_params()
11+
assert sorted(vars(params).keys()) == [
1212
"allow_none",
1313
"backend",
1414
"cache_dir",
@@ -20,7 +20,7 @@ def test_get_default_params():
2020
"separate_files",
2121
"stale_after",
2222
"wait_for_calc_timeout",
23-
)
23+
]
2424

2525

2626
def test_bad_name(name="nope"):

tests/test_defaults.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,23 @@
44
import random
55
import threading
66
import time
7+
from dataclasses import replace
78

89
import pytest
910

1011
import cachier
1112
from tests.test_mongo_core import _test_mongetter
1213

1314
MONGO_DELTA = datetime.timedelta(seconds=3)
14-
_default_params = cachier.get_default_params().copy()
15+
_copied_defaults = replace(cachier.get_global_params())
1516

1617

1718
def setup_function():
18-
cachier.set_default_params(**_default_params)
19+
cachier.set_global_params(**vars(_copied_defaults))
1920

2021

2122
def teardown_function():
22-
cachier.set_default_params(**_default_params)
23+
cachier.set_global_params(**vars(_copied_defaults))
2324

2425

2526
def test_hash_func_default_param():
@@ -30,7 +31,7 @@ def slow_hash_func(args, kwds):
3031
def fast_hash_func(args, kwds):
3132
return "hash"
3233

33-
cachier.set_default_params(hash_func=slow_hash_func)
34+
cachier.set_global_params(hash_func=slow_hash_func)
3435

3536
@cachier.cachier()
3637
def global_test_1():
@@ -51,7 +52,7 @@ def global_test_2():
5152

5253

5354
def test_backend_default_param():
54-
cachier.set_default_params(backend="memory")
55+
cachier.set_global_params(backend="memory")
5556

5657
@cachier.cachier()
5758
def global_test_1():
@@ -67,7 +68,7 @@ def global_test_2():
6768

6869
@pytest.mark.mongo
6970
def test_mongetter_default_param():
70-
cachier.set_default_params(mongetter=_test_mongetter)
71+
cachier.set_global_params(mongetter=_test_mongetter)
7172

7273
@cachier.cachier()
7374
def global_test_1():
@@ -82,7 +83,7 @@ def global_test_2():
8283

8384

8485
def test_cache_dir_default_param(tmpdir):
85-
cachier.set_default_params(cache_dir=tmpdir / "1")
86+
cachier.set_global_params(cache_dir=tmpdir / "1")
8687

8788
@cachier.cachier()
8889
def global_test_1():
@@ -97,7 +98,7 @@ def global_test_2():
9798

9899

99100
def test_separate_files_default_param(tmpdir):
100-
cachier.set_default_params(separate_files=True)
101+
cachier.set_global_params(separate_files=True)
101102

102103
@cachier.cachier(cache_dir=tmpdir / "1")
103104
def global_test_1(arg_1, arg_2):
@@ -117,7 +118,7 @@ def global_test_2(arg_1, arg_2):
117118

118119

119120
def test_allow_none_default_param(tmpdir):
120-
cachier.set_default_params(
121+
cachier.set_global_params(
121122
allow_none=True,
122123
separate_files=True,
123124
verbose_cache=True,
@@ -167,7 +168,7 @@ def _stale_after_test(arg_1, arg_2):
167168
"""Some function."""
168169
return random.random() + arg_1 + arg_2
169170

170-
cachier.set_default_params(stale_after=MONGO_DELTA)
171+
cachier.set_global_params(stale_after=MONGO_DELTA)
171172

172173
_stale_after_test.clear_cache()
173174
val1 = _stale_after_test(1, 2)
@@ -187,7 +188,7 @@ def _stale_after_next_time(arg_1, arg_2):
187188
"""Some function."""
188189
return random.random()
189190

190-
cachier.set_default_params(stale_after=NEXT_AFTER_DELTA, next_time=True)
191+
cachier.set_global_params(stale_after=NEXT_AFTER_DELTA, next_time=True)
191192

192193
_stale_after_next_time.clear_cache()
193194
val1 = _stale_after_next_time(1, 2)
@@ -217,7 +218,7 @@ def _calls_wait_for_calc_timeout_slow(res_queue):
217218
res = _wait_for_calc_timeout_slow(1, 2)
218219
res_queue.put(res)
219220

220-
cachier.set_default_params(wait_for_calc_timeout=2)
221+
cachier.set_global_params(wait_for_calc_timeout=2)
221222
_wait_for_calc_timeout_slow.clear_cache()
222223
res_queue = queue.Queue()
223224
thread1 = threading.Thread(

tests/test_general.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,17 @@ def test_separate_processes():
229229

230230
def test_global_disable():
231231
@cachier.cachier()
232-
def get_random():
232+
def get_random() -> float:
233233
return random()
234234

235235
get_random.clear_cache()
236236
result_1 = get_random()
237237
result_2 = get_random()
238238
cachier.disable_caching()
239+
assert cachier.config._global_params.caching_enabled is False
239240
result_3 = get_random()
240241
cachier.enable_caching()
242+
assert cachier.config._global_params.caching_enabled is True
241243
result_4 = get_random()
242244
assert result_1 == result_2 == result_4
243245
assert result_1 != result_3

tests/test_pickle_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import pandas as pd
3131

3232
from cachier import cachier
33-
from cachier.core import _default_params
33+
from cachier.config import _global_params
3434

3535

3636
def _get_decorated_func(func, **kwargs):
@@ -329,7 +329,7 @@ def _bad_cache(arg_1, arg_2):
329329
".tests.test_pickle_core._bad_cache_"
330330
f"{hashlib.sha256(pickle.dumps((0.13, 0.02))).hexdigest()}"
331331
)
332-
EXPANDED_CACHIER_DIR = os.path.expanduser(_default_params["cache_dir"])
332+
EXPANDED_CACHIER_DIR = os.path.expanduser(_global_params.cache_dir)
333333
_BAD_CACHE_FPATH = os.path.join(EXPANDED_CACHIER_DIR, _BAD_CACHE_FNAME)
334334
_BAD_CACHE_FPATH_SEPARATE_FILES = os.path.join(
335335
EXPANDED_CACHIER_DIR, _BAD_CACHE_FNAME_SEPARATE_FILES

0 commit comments

Comments
 (0)