Skip to content

Commit c1911c7

Browse files
authored
fix(config): Fix XDG_X_HOME env vars, add OPENML_CACHE_DIR env var (#1359)
* fix(config): Fix XDG_X_HOME env vars, add OPENML_CACHE_DIR env var * fix(config): Check correct backwards compat location * test: Add safe context manager for environ variable
1 parent 891f4a6 commit c1911c7

2 files changed

Lines changed: 151 additions & 16 deletions

File tree

openml/config.py

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging.handlers
99
import os
1010
import platform
11+
import shutil
1112
import warnings
1213
from io import StringIO
1314
from pathlib import Path
@@ -20,6 +21,8 @@
2021
console_handler: logging.StreamHandler | None = None
2122
file_handler: logging.handlers.RotatingFileHandler | None = None
2223

24+
OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR"
25+
2326

2427
class _Config(TypedDict):
2528
apikey: str
@@ -101,14 +104,50 @@ def set_file_log_level(file_output_level: int) -> None:
101104

102105
# Default values (see also https://github.com/openml/OpenML/wiki/Client-API-Standards)
103106
_user_path = Path("~").expanduser().absolute()
107+
108+
109+
def _resolve_default_cache_dir() -> Path:
    """Determine the default root cache directory for OpenML.

    Resolution order:

    1. The ``OPENML_CACHE_DIR`` environment variable, used verbatim if set.
    2. On non-Linux platforms: ``<user home>/.openml``.
    3. On Linux with ``XDG_CACHE_HOME`` unset or empty: ``~/.cache/openml``
       (returned unexpanded).
    4. On Linux with ``XDG_CACHE_HOME`` set: ``$XDG_CACHE_HOME/openml``,
       unless only the legacy ``$XDG_CACHE_HOME/org/openml`` directory
       exists, in which case a warning is logged and ``$XDG_CACHE_HOME``
       itself is returned.

    Returns
    -------
    Path
        The directory to use as the default ``cachedir``.
    """
    user_defined_cache_dir = os.environ.get(OPENML_CACHE_DIR_ENV_VAR)
    if user_defined_cache_dir is not None:
        return Path(user_defined_cache_dir)

    if platform.system().lower() != "linux":
        return _user_path / ".openml"

    xdg_cache_home = os.environ.get("XDG_CACHE_HOME")
    # The XDG Base Directory Specification treats an *empty* value the same as
    # an unset one; without this, "" would resolve to the relative path "openml".
    if not xdg_cache_home:
        return Path("~", ".cache", "openml")

    xdg_root = Path(xdg_cache_home)

    # This is the proper XDG_CACHE_HOME directory, but we unfortunately had a
    # problem where we used XDG_CACHE_HOME/org. We check heuristically whether
    # this old directory still exists and issue a warning if it does. There is
    # too much data to move to do this for the user.

    # The new cache directory exists.
    cache_dir = xdg_root / "openml"
    if cache_dir.exists():
        return cache_dir

    # The old cache directory *does not* exist.
    heuristic_dir_for_backwards_compat = xdg_root / "org" / "openml"
    if not heuristic_dir_for_backwards_compat.exists():
        return cache_dir

    # Only the legacy layout exists: warn and keep XDG_CACHE_HOME as the root.
    # NOTE(review): presumably this keeps the data under "org/" reachable,
    # since cache paths are built as "<root>/org/openml/..." -- confirm.
    root_dir_to_delete = xdg_root / "org"
    openml_logger.warning(
        "An old cache directory was found at '%s'. This directory is no longer used by "
        "OpenML-Python. To silence this warning you would need to delete the old cache "
        "directory. The cached files will then be located in '%s'.",
        root_dir_to_delete,
        cache_dir,
    )
    return xdg_root
145+
146+
104147
_defaults: _Config = {
105148
"apikey": "",
106149
"server": "https://www.openml.org/api/v1/xml",
107-
"cachedir": (
108-
Path(os.environ.get("XDG_CACHE_HOME", _user_path / ".cache" / "openml"))
109-
if platform.system() == "Linux"
110-
else _user_path / ".openml"
111-
),
150+
"cachedir": _resolve_default_cache_dir(),
112151
"avoid_duplicate_runs": True,
113152
"retry_policy": "human",
114153
"connection_n_retries": 5,
@@ -218,11 +257,66 @@ def stop_using_configuration_for_example(cls) -> None:
218257
cls._start_last_called = False
219258

220259

260+
def _handle_xdg_config_home_backwards_compatibility(
    xdg_home: str,
) -> Path:
    """Return the directory that should contain the OpenML ``config`` file.

    NOTE(eddiebergman): A previous bug resulted in the config file being
    located at ``${XDG_CONFIG_HOME}/config`` instead of
    ``${XDG_CONFIG_HOME}/openml/config``. To maintain backwards compatibility
    for users who may already have a configuration there, we copy it over and
    issue a warning until the old file is deleted. As a heuristic to ensure
    that it is "our" config file, we try to parse it first.

    Parameters
    ----------
    xdg_home : str
        The value of the ``XDG_CONFIG_HOME`` environment variable.

    Returns
    -------
    Path
        The *directory* in which the file named ``config`` is to be looked up:
        ``<xdg_home>/openml`` normally, or ``<xdg_home>`` itself when a legacy
        config file exists but could not be copied to the new location.
    """
    xdg_root = Path(xdg_home)
    config_dir = xdg_root / "openml"

    backwards_compat_config_file = xdg_root / "config"
    if not backwards_compat_config_file.exists():
        return config_dir

    # If parsing errors, that is a good sign the file is not ours and we can
    # safely ignore it. This is a heuristic.
    try:
        _parse_config(backwards_compat_config_file)
    except Exception:  # noqa: BLE001
        return config_dir

    # Looks like it is ours, let's try to copy it to the correct place.
    correct_config_location = config_dir / "config"
    try:
        # The target directory may not exist yet; without creating it the copy
        # would always fail on a fresh `${XDG_CONFIG_HOME}/openml`.
        config_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(backwards_compat_config_file, correct_config_location)
    except Exception as e:  # noqa: BLE001
        # We failed to copy and the file is ours: keep using the old one.
        # The whole message is a single format string; passing its second half
        # as a separate positional argument would make logging try to
        # %-format it into a message that has no placeholders.
        openml_logger.warning(
            "While attempting to perform a backwards compatible fix, we "
            "failed to copy the openml config file at "
            "'%s' to '%s'"
            "\n%s: %s"
            "\n\nTo silence this warning, please copy the file "
            "to the new location and delete the old file at "
            "'%s'.",
            backwards_compat_config_file,
            correct_config_location,
            type(e).__name__,
            e,
            backwards_compat_config_file,
        )
        # The caller appends "config" to the returned directory, so return the
        # directory holding the old file rather than the file itself.
        return xdg_root

    openml_logger.warning(
        "An openml configuration file was found at the old location "
        "at %s. We have copied it to the new "
        "location at %s. "
        "\nTo silence this warning please verify that the configuration file "
        "at %s is correct and delete the file at "
        "%s.",
        backwards_compat_config_file,
        correct_config_location,
        correct_config_location,
        backwards_compat_config_file,
    )
    return config_dir
308+
309+
221310
def determine_config_file_path() -> Path:
222-
if platform.system() == "Linux":
223-
config_dir = Path(os.environ.get("XDG_CONFIG_HOME", Path("~") / ".config" / "openml"))
311+
if platform.system().lower() == "linux":
312+
xdg_home = os.environ.get("XDG_CONFIG_HOME")
313+
if xdg_home is not None:
314+
config_dir = _handle_xdg_config_home_backwards_compatibility(xdg_home)
315+
else:
316+
config_dir = Path("~", ".config", "openml")
224317
else:
225318
config_dir = Path("~") / ".openml"
319+
226320
# Still use os.path.expanduser to trigger the mock in the unit test
227321
config_dir = Path(config_dir).expanduser().resolve()
228322
return config_dir / "config"
@@ -260,11 +354,15 @@ def _setup(config: _Config | None = None) -> None:
260354
apikey = config["apikey"]
261355
server = config["server"]
262356
show_progress = config["show_progress"]
263-
short_cache_dir = Path(config["cachedir"])
264357
n_retries = int(config["connection_n_retries"])
265358

266359
set_retry_policy(config["retry_policy"], n_retries)
267360

361+
user_defined_cache_dir = os.environ.get(OPENML_CACHE_DIR_ENV_VAR)
362+
if user_defined_cache_dir is not None:
363+
short_cache_dir = Path(user_defined_cache_dir)
364+
else:
365+
short_cache_dir = Path(config["cachedir"])
268366
_root_cache_directory = short_cache_dir.expanduser().resolve()
269367

270368
try:

tests/test_openml/test_config.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# License: BSD 3-Clause
22
from __future__ import annotations
33

4+
from contextlib import contextmanager
45
import os
56
import tempfile
67
import unittest.mock
78
from copy import copy
9+
from typing import Any, Iterator
810
from pathlib import Path
911

1012
import pytest
@@ -13,6 +15,24 @@
1315
import openml.testing
1416

1517

18+
@contextmanager
def safe_environ_patcher(key: str, value: Any) -> Iterator[None]:
    """Context manager to temporarily set an environment variable.

    On exit the variable is restored to its previous value, or removed if it
    was not set before. This happens whether or not the wrapped code raised;
    any exception from the wrapped code propagates unchanged.

    Parameters
    ----------
    key : str
        Name of the environment variable to set.
    value : Any
        Value to assign. ``os.environ`` only accepts strings, so anything
        else raises ``TypeError`` before the wrapped code runs.
    """
    _prev = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if _prev is None:
            # Tolerate the wrapped code having already removed the variable;
            # a KeyError here would mask the real error from the body.
            os.environ.pop(key, None)
        else:
            os.environ[key] = _prev
34+
35+
1636
class TestConfig(openml.testing.TestBase):
1737
@unittest.mock.patch("openml.config.openml_logger.warning")
1838
@unittest.mock.patch("openml.config._create_log_handlers")
@@ -29,15 +49,22 @@ def test_non_writable_home(self, log_handler_mock, warnings_mock):
2949
assert not log_handler_mock.call_args_list[0][1]["create_file_handler"]
3050
assert openml.config._root_cache_directory == Path(td) / "something-else"
3151

32-
@unittest.mock.patch("os.path.expanduser")
33-
def test_XDG_directories_do_not_exist(self, expanduser_mock):
52+
def test_XDG_directories_do_not_exist(self):
3453
with tempfile.TemporaryDirectory(dir=self.workdir) as td:
54+
# Save previous state
55+
path = Path(td) / "fake_xdg_cache_home"
56+
with safe_environ_patcher("XDG_CONFIG_HOME", str(path)):
57+
expected_config_dir = path / "openml"
58+
expected_determined_config_file_path = expected_config_dir / "config"
3559

36-
def side_effect(path_):
37-
return os.path.join(td, str(path_).replace("~/", ""))
60+
# Ensure that it correctly determines the path to the config file
61+
determined_config_file_path = openml.config.determine_config_file_path()
62+
assert determined_config_file_path == expected_determined_config_file_path
3863

39-
expanduser_mock.side_effect = side_effect
40-
openml.config._setup()
64+
# Ensure that setup will create the config folder as the configuration
65+
# will be written to that location.
66+
openml.config._setup()
67+
assert expected_config_dir.exists()
4168

4269
def test_get_config_as_dict(self):
4370
"""Checks if the current configuration is returned accurately as a dict."""
@@ -121,7 +148,7 @@ def test_example_configuration_start_twice(self):
121148

122149

123150
def test_configuration_file_not_overwritten_on_load():
124-
""" Regression test for #1337 """
151+
"""Regression test for #1337"""
125152
config_file_content = "apikey = abcd"
126153
with tempfile.TemporaryDirectory() as tmpdir:
127154
config_file_path = Path(tmpdir) / "config"
@@ -136,12 +163,22 @@ def test_configuration_file_not_overwritten_on_load():
136163
assert config_file_content == new_file_content
137164
assert "abcd" == read_config["apikey"]
138165

166+
139167
def test_configuration_loads_booleans(tmp_path):
140168
config_file_content = "avoid_duplicate_runs=true\nshow_progress=false"
141-
with (tmp_path/"config").open("w") as config_file:
169+
with (tmp_path / "config").open("w") as config_file:
142170
config_file.write(config_file_content)
143171
read_config = openml.config._parse_config(tmp_path)
144172

145173
# Explicit test to avoid truthy/falsy modes of other types
146174
assert True == read_config["avoid_duplicate_runs"]
147175
assert False == read_config["show_progress"]
176+
177+
178+
def test_openml_cache_dir_env_var(tmp_path: Path) -> None:
    """Check that ``OPENML_CACHE_DIR`` sets the root cache directory on setup."""
    cache_root = tmp_path / "test-cache"
    server_cache_dir = cache_root / "org" / "openml" / "www"

    with safe_environ_patcher("OPENML_CACHE_DIR", str(cache_root)):
        openml.config._setup()
        # The root is taken from the environment variable ...
        assert openml.config._root_cache_directory == cache_root
        # ... and the cache directory is derived beneath it.
        assert openml.config.get_cache_directory() == str(server_cache_dir)

0 commit comments

Comments
 (0)