def cache_file(
    cache_dir: Optional[os.PathLike | str] = None,
    backend: str = "pickle",
    file_args: Optional[List[str]] = None,
    ignore_args: Optional[List[str]] = None,
    file_pattern: Optional[str] = None,
    env_vars: Optional[List[str]] = None,
    algo: str = "xxhash64",
    version: Optional[str] = None,
    depends_on_files: Optional[List[str]] = None,
    depends_on_vars: Optional[Dict[str, Any]] = None,
    verbose: bool = False,
    hash_file_paths: bool = True,
) -> CacheDecorator:
    """
    Disk-backed caching decorator (Python analogue of R's cacheFile).

    Each call to a decorated function is keyed by a master hash built from:
    the bound call arguments (defaults applied, control/ignored args removed),
    the function's recursive closure hash, the on-disk state of any
    file/directory paths discovered in the arguments or in the function body,
    selected environment variables, detected package versions, and the
    explicit ``version`` / ``depends_on_files`` / ``depends_on_vars`` inputs.
    Results are pickled to ``{cache_dir}/{fname}.{hash}.pkl`` together with
    the hash components (``meta``) so later misses can report *what* changed.

    Parameters
    ----------
    cache_dir : path-like, optional
        Directory for cache files; defaults to ``cache_default_dir()``.
    backend : str
        Serialization backend; only ``"pickle"`` is currently supported.
    file_args : list of str, optional
        Argument names whose values are filesystem paths. Restricts path
        scanning to these arguments and enables the modified-during-execution
        warning for them.
    ignore_args : list of str, optional
        Argument names excluded from the hash entirely.
    file_pattern : str, optional
        Regex applied to file *names* when hashing a directory's contents.
    env_vars : list of str, optional
        Environment variable names whose current values join the hash.
    algo : str
        Hash algorithm name forwarded to the hashing helpers.
    version : str, optional
        Manual cache-busting token.
    depends_on_files : list of str, optional
        Extra file paths whose content state joins the hash.
    depends_on_vars : dict, optional
        Extra named values folded into the hash as-is.
    verbose : bool
        Log cache hits/misses, including a best-effort diagnosis of which
        hash component changed on a miss.
    hash_file_paths : bool
        If True, ``file_args`` values are normalized to their resolved path
        string; if False (and the path exists), their *content* hash is used
        instead, making the cache location-independent.

    The decorated wrapper accepts three keyword-only control arguments that
    are never hashed: ``_load`` (set False to skip reading an existing cache
    entry), ``_force`` (recompute even on a hit), ``_skip_save`` (compute but
    do not write the result).

    Returns a :class:`CacheDecorator` that can be reused across functions::
        cf = cache_file("/tmp/cache")
        @cf
        def step1(x): ...
        @cf(verbose=True)
        def step2(x): ...
    Or used as a one-shot decorator::
        @cache_file("/tmp/cache")
        def f(x, y=1): ...
    """
    if cache_dir is None:
        cache_dir_path = cache_default_dir()
    else:
        cache_dir_path = Path(cache_dir)
    # attempt to create directory (race-safe: exist_ok tolerates a concurrent mkdir)
    try:
        cache_dir_path.mkdir(parents=True, exist_ok=True)
    except Exception:
        # best-effort only; a later save attempt will surface its own warning
        logger.warning("cache_file: could not create cache directory %s", cache_dir_path)
    cache_dir_path = cache_dir_path.resolve()
    backend = backend.lower()
    if backend not in {"pickle"}:
        raise ValueError("backend must be 'pickle'")
    ext = "pkl"
    # static path specs from function body (stubbed for now)
    path_specs = _find_path_specs  # function; we will call inside decorator

    def decorator(f: Callable) -> Callable:
        sig = inspect.signature(f)
        # static path references found in f's source (AST scan, done once at
        # decoration time): string literals and global symbol names
        ps = path_specs(f)
        static_dirs_lit: List[str] = ps.get("literals", [])
        static_dirs_sym: List[str] = ps.get("symbols", [])
        # Detect import names at decoration time (AST is static)
        _import_names = _detect_import_names(f)

        def _get_path_hash(path: os.PathLike | str) -> str:
            """Hash the on-disk state of a file or directory.

            Directories are hashed by (relative name, content hash) of each
            contained file — capturing both structure and content; the
            sentinel "empty_dir" marks a directory with no matching files.
            A nonexistent path hashes to "" so missing inputs still produce
            a stable key.
            """
            p = Path(path).resolve()
            if p.is_dir():
                # list files recursively, optional regex filter on file names
                files = []
                for sub in sorted(p.rglob("*")):
                    if sub.is_file():
                        if file_pattern is not None:
                            if not re.search(file_pattern, sub.name):
                                continue
                        files.append(sub)
                if not files:
                    return "empty_dir"
                # hash (relative name, content hash) for structure + content
                file_entries = []
                for sub in files:
                    rel = str(sub.relative_to(p))
                    file_entries.append((rel, fast_file_hash(sub, algo=algo)))
                return _digest_obj(file_entries, algo=algo)
            elif p.is_file():
                return fast_file_hash(p, algo=algo)
            else:
                return ""

        def _atomic_save(obj: Any, path: Path) -> None:
            """
            Atomic write:
              - write to temp file in same dir (random suffix avoids collisions)
              - os.replace() to target (atomic on POSIX; may fail cross-device
                or on some platforms, in which case fall back to copy+unlink)
            Failures are logged, never raised: caching is best-effort.
            """
            path = Path(path)
            tmp_name = f"{path.name}.tmp.{''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=8))}"
            tmp_path = path.with_name(tmp_name)
            try:
                # 'f2' deliberately avoids shadowing the decorated function 'f'
                with tmp_path.open("wb") as f2:
                    pickle.dump(obj, f2, protocol=pickle.HIGHEST_PROTOCOL)
                try:
                    os.replace(tmp_path, path)
                except OSError:
                    # fallback: copy+unlink (not atomic, but leaves a valid file)
                    import shutil
                    shutil.copy2(tmp_path, path)
                    tmp_path.unlink(missing_ok=True)
                # best-effort permissions (like 0664) so group members can share the cache
                if os.name == "posix":
                    try:
                        os.chmod(path, 0o664)
                    except OSError:
                        pass
            except Exception as e:
                logger.warning("cache_file: failed to save cache file %s: %s", path, e)
                try:
                    tmp_path.unlink(missing_ok=True)
                except Exception:
                    pass

        def _safe_load(path: Path) -> Any:
            """Load a cache file and return just the payload.

            NOTE: pickle.load is only safe because cache files are produced
            by this module — never point cache_dir at untrusted data.
            """
            with path.open("rb") as f2:
                obj = pickle.load(f2)
            # expect {"dat": value, "meta": {...}}; tolerate legacy bare values
            if isinstance(obj, dict) and "dat" in obj:
                return obj["dat"]
            return obj

        def _safe_load_full(path: Path) -> Optional[Dict[str, Any]]:
            """Load a cache file and return the raw dict (with 'dat' and 'meta').

            Returns None on any read/unpickle error or if the file predates
            the {'dat', 'meta'} envelope format.
            """
            try:
                with path.open("rb") as f2:
                    obj = pickle.load(f2)
                if isinstance(obj, dict) and "meta" in obj:
                    return obj
            except Exception:
                pass
            return None

        def wrapper(*args, _load: bool = True, _force: bool = False, _skip_save: bool = False, **kwargs) -> Any:
            """Cached call: hash inputs, load on hit, else compute and save.

            Control kwargs (_load/_force/_skip_save) are keyword-only and are
            stripped before hashing so they never affect the cache key.
            """
            invoke_env_globals = f.__globals__
            # -------- function name for filename label --------
            fname = getattr(f, "__name__", "anon")
            # -------- normalize arguments (include defaults) --------
            bound = sig.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            args_for_hash: Dict[str, Any] = dict(bound.arguments)
            # remove control params from hashing (in case f itself declares them)
            args_for_hash.pop("_load", None)
            args_for_hash.pop("_force", None)
            args_for_hash.pop("_skip_save", None)
            if ignore_args:
                for nm in ignore_args:
                    args_for_hash.pop(nm, None)
            # order by argument name for stability across call styles
            args_for_hash = dict(sorted(args_for_hash.items(), key=lambda kv: kv[0]))
            # Sort **kwargs dict values for order-independent hashing
            for param_name, param in sig.parameters.items():
                if param.kind == inspect.Parameter.VAR_KEYWORD and param_name in args_for_hash:
                    val = args_for_hash[param_name]
                    if isinstance(val, dict):
                        args_for_hash[param_name] = dict(sorted(val.items(), key=lambda kv: kv[0]))
            # -------- resolve symlinks and normalize file_args --------
            # hash_file_paths=True: key on the resolved path string (location matters)
            # hash_file_paths=False: key on the content hash (location-independent)
            if file_args:
                for nm in file_args:
                    if nm in args_for_hash:
                        val = args_for_hash[nm]
                        if isinstance(val, (str, Path)):
                            resolved = str(Path(val).resolve())
                            if not hash_file_paths and Path(resolved).exists():
                                args_for_hash[nm] = _get_path_hash(resolved)
                            else:
                                args_for_hash[nm] = resolved

            # -------- dynamic path scanning over arguments --------
            def _collect_paths(val: Any) -> List[Path]:
                """Recursively extract file/directory paths from any value."""
                out: List[Path] = []
                if isinstance(val, Path):
                    out.append(val)
                elif isinstance(val, str):
                    # only treat as path if it looks like one
                    # NOTE(review): checks os.sep only — "/" in a string on
                    # Windows would not be detected; confirm this is intended
                    if os.sep in val or val.startswith(".") or val.startswith("~"):
                        out.append(Path(val))
                elif isinstance(val, dict):
                    for v in val.values():
                        out.extend(_collect_paths(v))
                elif isinstance(val, (list, tuple, set)):
                    for v in val:
                        out.extend(_collect_paths(v))
                return out

            dir_hashes_args: Dict[str, str] = {}
            if args_for_hash:
                # restrict scanning to declared file_args when given
                if file_args:
                    scan_items = {k: v for k, v in args_for_hash.items() if k in file_args}
                else:
                    scan_items = args_for_hash
                for nm, expr_val in scan_items.items():
                    paths = _collect_paths(expr_val)
                    if not paths:
                        continue
                    for p in paths:
                        if p.exists():
                            h = _get_path_hash(p)
                            dir_hashes_args[str(p.resolve())] = h
            # -------- static path scanning (currently just stubbed lists) --------
            # literals
            static_hashes_lit: Dict[str, str] = {}
            for lit in static_dirs_lit:
                h = _get_path_hash(lit)
                static_hashes_lit[lit] = h
            # symbols: look up in globals and hash underlying paths
            static_hashes_sym: Dict[str, str] = {}
            for sym in static_dirs_sym:
                val = invoke_env_globals.get(sym)
                if isinstance(val, (str, Path, list, tuple)):
                    paths = _collect_paths(val)
                    sub_hashes = {str(Path(p).resolve()): _get_path_hash(p) for p in paths}
                    static_hashes_sym[f"sym:{sym}"] = _digest_obj(sub_hashes, algo=algo)
            # -------- environment variables --------
            current_envs: Optional[Dict[str, Optional[str]]] = None
            if env_vars:
                # sorted so the hash is independent of declaration order;
                # unset variables hash as None (distinct from empty string)
                vars_sorted = sorted(env_vars)
                current_envs = {name: os.getenv(name) for name in vars_sorted}
            # -------- recursive closure hash --------
            deep_hash = get_recursive_closure_hash(f, algo=algo)
            # -------- package version detection --------
            pkg_versions = _get_package_versions(_import_names, f)
            # -------- build master hash --------
            dir_states: Dict[str, str] = {}
            dir_states.update(dir_hashes_args)
            dir_states.update(static_hashes_lit)
            dir_states.update(static_hashes_sym)
            # -------- depends_on_files hashing --------
            dep_file_hashes = None
            if depends_on_files:
                dep_file_hashes = {p: _get_path_hash(p) for p in sorted(depends_on_files)}
            # hashlist is stored as 'meta' alongside the result so a later
            # cache miss can diagnose which component changed
            hashlist = {
                "call": args_for_hash,
                "closure": deep_hash,
                "dir_states": dict(sorted(dir_states.items(), key=lambda kv: kv[0])),
                "envs": current_envs,
                "version": version,
                "depends_on_files": dep_file_hashes,
                "depends_on_vars": depends_on_vars,
                "pkgs": pkg_versions,
            }
            args_hash = _digest_obj(hashlist, algo=algo)
            outfile = cache_dir_path / f"{fname}.{args_hash}.{ext}"
            # -------- register node in cache tree --------
            node_id = f"{fname}:{args_hash}"
            _cache_tree_register_node(node_id, fname, args_hash, outfile)
            _cache_tree_call_stack.append(node_id)
            try:
                # 1. optimistic load
                sentinel_path = outfile.with_suffix(outfile.suffix + ".computing")
                if _load and not _force and outfile.exists():
                    try:
                        result = _safe_load(outfile)
                        if verbose:
                            logger.info("[%s] cache hit", fname)
                        return result
                    except Exception:
                        # partial/corrupt -> ignore and recompute
                        pass
                # 1b. check if another process is already computing
                # NOTE(review): _wait_for_sentinel signals "no result" with
                # None, so a legitimately-None cached value cannot be returned
                # from this path — confirm against the helper's contract
                if not _force:
                    waited_result = _wait_for_sentinel(
                        sentinel_path, outfile, _safe_load, fname
                    )
                    if waited_result is not None:
                        return waited_result
                # verbose: report why we're computing, by diffing the newest
                # existing entry's stored meta against our current hashlist
                if verbose:
                    if _force:
                        logger.info("[%s] forced re-execution", fname)
                    else:
                        existing = sorted(
                            cache_dir_path.glob(f"{fname}.*.{ext}"),
                            key=lambda p: p.stat().st_mtime,
                        )
                        if not existing:
                            logger.info("[%s] first execution", fname)
                        else:
                            stored = _safe_load_full(existing[-1])
                            if stored is not None:
                                sm = stored["meta"]
                                _MISS_LABELS = {
                                    "call": "arguments",
                                    "closure": "function body/closure",
                                    "dir_states": "file/directory contents",
                                    "envs": "environment variables",
                                    "version": "version",
                                    "depends_on_files": "explicit file dependencies",
                                    "depends_on_vars": "explicit variable dependencies",
                                    "pkgs": "package versions",
                                }
                                changes = [
                                    label for key, label in _MISS_LABELS.items()
                                    if sm.get(key) != hashlist.get(key)
                                ]
                                if not changes:
                                    changes = ["unknown (possibly new argument combination)"]
                                logger.info("[%s] cache miss -- changed: %s", fname, ", ".join(changes))
                            else:
                                logger.info("[%s] cache miss (previous entry unreadable)", fname)
                # 2. record pre-execution file hashes for modification warning
                pre_file_hashes: Dict[str, str] = {}
                if file_args and dir_hashes_args:
                    pre_file_hashes = dict(dir_hashes_args)
                # 3. compute with sentinel (advertises in-progress work to
                #    other processes; always removed in the finally below)
                try:
                    sentinel_path.touch()
                except OSError:
                    pass
                try:
                    dat = f(*args, **kwargs)
                except Exception:
                    # Remove graph node on error
                    _cache_tree_graph.pop(node_id, None)
                    raise
                finally:
                    # Always clean up sentinel
                    try:
                        sentinel_path.unlink(missing_ok=True)
                    except OSError:
                        pass
                # 4. check for file modification during execution
                if pre_file_hashes:
                    import warnings as _warnings
                    _file_state_cache.clear()  # force re-hash so stale cached states can't mask a change
                    for pstr, old_h in pre_file_hashes.items():
                        p = Path(pstr)
                        if p.exists():
                            new_h = _get_path_hash(p)
                            if new_h != old_h:
                                _warnings.warn(
                                    f"File modified during execution: {pstr}",
                                    stacklevel=2,
                                )
                if not _skip_save:
                    save_data = {"dat": dat, "meta": hashlist}
                    # 5. try file locking if available (filelock is optional;
                    #    without it we rely on _atomic_save's temp+rename)
                    lock = None
                    lock_path = outfile.with_suffix(outfile.suffix + ".lock")
                    try:
                        from filelock import FileLock  # type: ignore
                        lock = FileLock(str(lock_path))
                        lock.acquire(timeout=5)
                    except Exception:
                        lock = None
                    try:
                        # double-check: maybe someone else wrote it while we computed
                        if _load and not _force and outfile.exists():
                            try:
                                return _safe_load(outfile)
                            except Exception:
                                pass
                        _atomic_save(save_data, outfile)
                    finally:
                        if lock is not None:
                            try:
                                lock.release()
                            except Exception:
                                pass
                return dat
            finally:
                _cache_tree_call_stack.pop()

        # preserve metadata (manual subset of functools.wraps)
        wrapper.__name__ = getattr(f, "__name__", "cached_fn")
        wrapper.__doc__ = f.__doc__
        wrapper.__wrapped__ = f  # for inspect
        return wrapper

    # defaults are remembered by CacheDecorator so per-function overrides
    # (e.g. @cf(verbose=True)) can be merged on top of them
    defaults = dict(
        cache_dir=cache_dir, backend=backend, file_args=file_args,
        ignore_args=ignore_args, file_pattern=file_pattern,
        env_vars=env_vars, algo=algo, version=version,
        depends_on_files=depends_on_files, depends_on_vars=depends_on_vars,
        verbose=verbose, hash_file_paths=hash_file_paths,
    )
    return CacheDecorator(decorator, **defaults)