Source code for ml_research_tools.cache.redis

"""Redis caching utilities for ML Research Tools.

This module provides utilities for Redis caching, including a Redis cache manager
class and functions for common caching operations.
"""

import hashlib
import itertools
import json
import pickle
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast

import redis

from ml_research_tools.core.config import RedisConfig
from ml_research_tools.core.logging_tools import get_logger

logger = get_logger(__name__)

T = TypeVar("T")
R = TypeVar("R")


[docs] def create_redis_client(config: RedisConfig) -> Optional[redis.Redis]: """Create and return a Redis client based on configuration. Args: config: Redis configuration from the Config object Returns: Redis client instance or None if disabled or connection failed """ if not config.enabled: logger.info("Redis caching is disabled") return None try: client = redis.Redis( host=config.host, port=config.port, db=config.db, password=config.password, decode_responses=False, # Keep binary for proper serialization ) # Test connection client.ping() logger.info(f"Connected to Redis at {config.host}:{config.port} (db: {config.db})") return client except redis.ConnectionError as e: logger.warning(f"Failed to connect to Redis: {e}. Caching will be disabled.") return None except Exception as e: logger.warning(f"Unexpected error connecting to Redis: {e}. Caching will be disabled.") return None
[docs] def generate_cache_key( args: Any = None, kwargs: dict[str, Any] | None = None, prefix: str = "" ) -> str: """Generate a unique cache key based on input parameters. Args: *args: Arguments to include in the key generation prefix: Optional prefix for the key (e.g., function name) Returns: A string key suitable for Redis """ args = args or [] kwargs = kwargs or {} # Convert arguments to a consistent string representation key_parts = [] for key, arg in itertools.chain(enumerate(args), kwargs.items()): if isinstance(arg, (str, int, float, bool, type(None))): key_parts.append(f"{key}:{arg}") else: try: # Try to convert to JSON for consistent string representation arg = json.dumps(arg, sort_keys=True) except (TypeError, ValueError): arg = repr(arg) logger.debug(f"Non-serializable argument {key}: {arg}. Using repr() instead.") key_parts.append(f"{key}:{arg}") # Create a combined string with prefix combined = f"{prefix}|{'|'.join(key_parts)}" # Hash it to create a fixed-length key that's safe for Redis return hashlib.md5(combined.encode("utf-8")).hexdigest()
[docs] def get_from_cache(redis_client: Optional[redis.Redis], cache_key: str) -> Optional[bytes]: """Retrieve data from Redis cache if available. Args: redis_client: Redis client instance or None cache_key: Unique cache key for the data Returns: Cached data as bytes or None if not found """ if redis_client is None: return None try: cached_data = redis_client.get(cache_key) if cached_data: logger.debug(f"Cache hit for key: {cache_key}") return cached_data logger.debug(f"Cache miss for key: {cache_key}") return None except Exception as e: logger.warning(f"Error retrieving from cache: {e}") return None
[docs] def save_to_cache( redis_client: Optional[redis.Redis], cache_key: str, data: bytes, ttl: int ) -> bool: """Save data to Redis cache with the specified TTL. Args: redis_client: Redis client instance or None cache_key: Unique cache key for the data data: Data to cache (as bytes) ttl: Time-to-live in seconds (0 for no expiration) Returns: True if successfully cached, False otherwise """ if redis_client is None: return False try: if ttl > 0: redis_client.setex(cache_key, ttl, data) logger.debug(f"Saved data to cache key {cache_key} with TTL of {ttl} seconds") else: redis_client.set(cache_key, data) logger.debug(f"Saved data to cache key {cache_key} with no TTL") return True except Exception as e: logger.warning(f"Error saving to cache: {e}") return False
[docs] class RedisCache: """Redis cache manager for ML Research Tools. This class provides a simple interface for Redis caching operations, including serialization and deserialization of complex Python objects. Example: :: from ml_research_tools.config import get_config from ml_research_tools.cache import RedisCache config = get_config() cache = RedisCache(config.redis) # Cache a Python object data = {"results": [1, 2, 3]} cache.set("my_key", data) # Get it back retrieved = cache.get("my_key") """
[docs] def __init__(self, config: RedisConfig): """Initialize Redis cache manager. Args: config: Redis configuration from Config object """ self.config = config self.client = create_redis_client(config) self._enabled = config.enabled and self.client is not None self._recache = config.recache self.ttl = config.ttl
@property def enabled(self) -> bool: """Return whether caching is enabled.""" return self._enabled @property def recache(self) -> bool: """Return whether recaching is enabled (don't use cached values).""" return self._recache
[docs] def get(self, key: str, default: Optional[T] = None) -> Optional[T]: """Get value from cache by key. Args: key: Cache key default: Default value to return if key not found Returns: Cached value or default """ if not self.enabled or self.recache: return default cached = get_from_cache(self.client, key) if cached is None: return default try: return pickle.loads(cached) # type: ignore except (pickle.PickleError, EOFError) as e: logger.warning(f"Failed to deserialize cached value: {e}") return default
[docs] def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: """Set value in cache with optional TTL. Args: key: Cache key value: Value to cache (can be any pickle-serializable object) ttl: Time-to-live in seconds (None uses config default) Returns: True if successfully cached, False otherwise """ if not self.enabled: return False ttl_value = ttl if ttl is not None else self.ttl try: serialized = pickle.dumps(value) return save_to_cache(self.client, key, serialized, ttl_value) except (pickle.PickleError, TypeError) as e: logger.warning(f"Failed to serialize value for caching: {e}") return False
[docs] def delete(self, key: str) -> bool: """Delete key from cache. Args: key: Cache key to delete Returns: True if key was deleted, False otherwise """ if not self.enabled: return False try: return bool(self.client.delete(key)) # type: ignore except Exception as e: logger.warning(f"Error deleting key from cache: {e}") return False
[docs] def exists(self, key: str) -> bool: """Check if key exists in cache. Args: key: Cache key to check Returns: True if key exists, False otherwise """ if not self.enabled: return False try: return bool(self.client.exists(key)) # type: ignore except Exception as e: logger.warning(f"Error checking key existence in cache: {e}") return False
[docs] def clear(self, pattern: str = "*") -> bool: """Clear cache keys matching pattern. Args: pattern: Redis key pattern to match (default: all keys) Returns: True if successful, False otherwise """ if not self.enabled: return False try: pipeline = self.client.pipeline() # type: ignore # Get keys matching pattern keys = self.client.keys(pattern) # type: ignore if keys: # Use pipeline to delete all keys at once pipeline.delete(*keys) pipeline.execute() logger.info(f"Cleared {len(keys)} keys from cache") else: logger.info(f"No keys found matching pattern: {pattern}") return True except Exception as e: logger.warning(f"Error clearing cache: {e}") return False
[docs] def cached( prefix: str = "", ttl: Optional[int] = None, key_fn: Optional[Callable[..., str]] = None, ) -> Callable[[Callable[..., R]], Callable[..., R]]: """Decorator to cache function results in Redis. Args: prefix: Prefix for cache keys ttl: Time-to-live in seconds (None uses config default) key_fn: Custom function to generate cache key (if None, uses generate_cache_key) Returns: Decorator function Example: :: from ml_research_tools.cache.redis import cached from ml_research_tools.config import get_config config = get_config() @cached(prefix="expensive_computation", ttl=3600) def expensive_computation(a, b, c): # ... some expensive calculation return result """ def decorator(func: Callable[..., R]) -> Callable[..., R]: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> R: cache_instance = None for arg in itertools.chain(args, kwargs.values()): if isinstance(arg, RedisCache): cache_instance = arg break else: assert False, "RedisCache instance not found in arguments" if not cache_instance.enabled: return func(*args, **kwargs) key_args = [arg for arg in args if not isinstance(arg, RedisCache)] key_kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, RedisCache)} # Generate key if key_fn: key = key_fn(*key_args, **key_kwargs) assert prefix == "", "Prefix should be empty when using key_fn" else: func_prefix = f"{func.__module__}.{func.__name__}" if prefix: func_prefix = f"{prefix}:{func_prefix}" key = generate_cache_key(key_args, key_kwargs, prefix=func_prefix) # Check cache cached_result = cache_instance.get(key) if cached_result is not None: return cast(R, cached_result) # Call the function result = func(*args, **kwargs) # Cache the result cache_instance.set(key, result, ttl) return result return wrapper return decorator