Source code for ml_research_tools.exp.wandb_downloader_tool
#!/usr/bin/env python3
"""
Tool to download Weights & Biases (W&B) run logs to local JSON files.
"""
import argparse
import json
import logging
import os
import re
from typing import TYPE_CHECKING, Any, Set
if TYPE_CHECKING:
import wandb
from wandb.apis.public.runs import Run as WandbRun
from wandb.apis.public.runs import Runs as WandbRuns
else:
WandbRun = Any
WandbRuns = Any
from rich.panel import Panel
from ml_research_tools.core.base_tool import BaseTool
[docs]
class WandbDownloaderTool(BaseTool):
"""Tool for downloading W&B run logs to local JSON files."""
name = "wandb-downloader"
description = "Download artifacts and runs from Weights & Biases"
[docs]
def __init__(self, services) -> None:
"""Initialize the W&B downloader tool."""
super().__init__(services)
self.logger = logging.getLogger(__name__)
[docs]
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
"""Add tool-specific arguments to the parser."""
parser.add_argument(
"--entity",
"-e",
default=os.environ.get("WANDB_ENTITY"),
help="W&B entity (username or team name). Can also use WANDB_ENTITY env variable.",
)
parser.add_argument(
"--project",
"-p",
default=os.environ.get("WANDB_PROJECT"),
help="W&B project name. Can also use WANDB_PROJECT env variable.",
)
parser.add_argument(
"--output-dir",
"-o",
default="wandb_logs",
help="Directory to save log files (default: wandb_logs)",
)
parser.add_argument(
"--timeout",
"-t",
type=int,
default=30,
help="API timeout in seconds (default: 30)",
)
parser.add_argument(
"--quiet",
"-q",
action="store_true",
help="Suppress progress bar and detailed logging",
)
parser.add_argument(
"--no-delete",
action="store_true",
help="Don't delete logs for runs that no longer exist",
)
[docs]
def download_wandb_logs(
self,
entity: str,
project: str,
output_dir: str = "wandb_logs",
timeout: int = 30,
quiet: bool = False,
delete_outdated: bool = True,
) -> int:
"""
Download W&B logs for a specified project to local JSON files.
Args:
entity: The W&B entity (username or team name)
project: The W&B project name
output_dir: Directory where log files will be saved
timeout: API timeout in seconds
quiet: If True, suppress progress bar
delete_outdated: If True, delete logs for runs that no longer exist
"""
import wandb
# Initialize the W&B API
self.logger.info(f"Initializing W&B API for {entity}/{project}")
try:
api = wandb.Api(timeout=timeout)
except Exception as e:
self.logger.exception(f"Failed to initialize W&B API")
return 1
# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)
# Retrieve all runs from the specified project
self.logger.info("Retrieving runs from W&B...")
try:
runs: WandbRuns = api.runs(
path=f"{entity}/{project}",
per_page=32,
order="-created_at",
)
except Exception as e:
self.logger.exception(f"Failed to retrieve runs for {entity}/{project}")
return 1
# Use Rich Progress for loading metadata
self.logger.info(f"Found {len(runs)} runs, loading metadata...")
# Create a set of current run IDs from W&B
current_run_ids = set()
# Use rich progress if not in quiet mode
if not quiet:
with self.create_progress() as progress:
meta_task = progress.add_task("[cyan]Loading run metadata...", total=len(runs))
for run in runs:
current_run_ids.add(run.id)
progress.update(
meta_task, advance=1, description=f"[cyan]Loading run {run.id}..."
)
else:
# Simple operation if quiet mode
current_run_ids = set(run.id for run in runs)
# Display summary of runs
if not quiet:
self.console.print(
Panel(
f"[bold]W&B Project Summary[/bold]\n"
f"🔹 Entity: [cyan]{entity}[/cyan]\n"
f"🔹 Project: [green]{project}[/green]\n"
f"🔹 Total runs: [yellow]{len(current_run_ids)}[/yellow]\n"
f"🔹 Output directory: [cyan]{output_dir}[/cyan]",
title="W&B Downloader",
border_style="blue",
)
)
if delete_outdated:
# Handle deletion of outdated logs
try:
self.logger.info("Checking for outdated logs...")
deleted_count = self.delete_outdated_logs(output_dir, current_run_ids)
if deleted_count > 0:
self.logger.info(f"Deleted {deleted_count} outdated log file(s)")
except Exception as e:
self.logger.warning(f"Error during deletion of outdated logs: {e}")
# Use rich progress for run processing
if not quiet:
with self.create_progress() as progress:
task = progress.add_task("[cyan]Processing runs...", total=len(runs))
for run in runs:
run_name = self.sanitize_filename(run.name)
progress.update(
task, description=f"[cyan]Processing run [bold]{run_name}[/bold]..."
)
try:
self.process_run(run, output_dir)
except Exception as e:
self.logger.warning(f"Error processing run {run.name}/{run.id}: {e}")
progress.update(task, advance=1)
else:
# Process runs without progress display
for run in runs:
try:
self.process_run(run, output_dir)
except Exception as e:
self.logger.warning(f"Error processing run {run.name}/{run.id}: {e}")
self.logger.info(f"Successfully downloaded {len(runs)} run(s) to {output_dir}")
return 0
[docs]
def execute(self, config, args: argparse.Namespace) -> int:
"""
Execute the W&B log download with the provided arguments.
Args:
args: Parsed command-line arguments
Returns:
Exit code (0 for success, non-zero for error)
"""
# Validate required arguments
if not args.entity:
self.logger.error("--entity is required (or set WANDB_ENTITY environment variable)")
return 1
if not args.project:
self.logger.error("--project is required (or set WANDB_PROJECT environment variable)")
return 1
# Run the download
status = self.download_wandb_logs(
entity=args.entity,
project=args.project,
output_dir=args.output_dir,
timeout=args.timeout,
quiet=args.quiet,
delete_outdated=not args.no_delete,
)
if status == 0:
self.logger.info("Download completed successfully!")
return status
[docs]
@staticmethod
def sanitize_filename(name: str) -> str:
"""
Sanitize the run name to create a valid filename.
Args:
name: The original run name
Returns:
A sanitized string suitable for use as a filename
"""
return re.sub(r"[^\w\s\-\.]", "", name).strip().replace(" ", "_")
[docs]
def delete_outdated_logs(self, output_dir: str, current_run_ids: Set[str]) -> int:
"""
Delete log files that do not correspond to any current run ID.
Args:
output_dir: Directory containing log files
current_run_ids: Set of valid run IDs from W&B
Returns:
Number of files deleted
"""
# List all files in the output directory
local_files = os.listdir(output_dir)
deleted_count = 0
# Delete files that do not correspond to any current run ID
for file in local_files:
# Extract run ID from filename (assuming format: <sanitized_name>_<run_id>.json)
match = re.match(r".*_(\w+)\.json$", file)
if match:
run_id = match.group(1)
if run_id not in current_run_ids:
file_path = os.path.join(output_dir, file)
os.remove(file_path)
deleted_count += 1
self.logger.info(f"Deleted outdated log file: {file}")
return deleted_count
[docs]
def process_run(self, run: WandbRun, output_dir: str) -> None:
"""
Process a single W&B run and save its history to a JSON file.
Args:
run: W&B run object
output_dir: Directory where the log file will be saved
"""
run.load_full_data()
run.load()
# Sanitize the run name for use in filenames
sanitized_name = self.sanitize_filename(run.name)
# Construct the filename using run ID and sanitized run name
filename = f"{sanitized_name}_{run.id}.json"
filepath = os.path.join(output_dir, filename)
# Get current last heartbeat time from W&B run
current_last_heartbeat_time = run.heartbeatAt
run_info = {
"id": run.id,
"name": run.name,
"config": run.config,
"tags": run.tags,
"url": run.url,
"state": run.state,
"notes": run.notes,
"summary": run.summary._json_dict if hasattr(run.summary, "_json_dict") else {},
}
# Check if the file already exists to avoid redundant downloads
if os.path.exists(filepath):
try:
with open(filepath, "r") as f:
existing_data = json.load(f)
existing_last_heartbeat_time = existing_data[0].get("last_heartbeat_time", None)
# Skip updating if last heartbeat time hasn't changed
if existing_last_heartbeat_time == current_last_heartbeat_time:
existing_data[0]["run_info"] = run_info
with open(filepath, "w") as f:
json.dump(existing_data, f, indent=4, sort_keys=True)
return
except (json.JSONDecodeError, IndexError, KeyError) as e:
self.logger.warning(f"Error reading existing file {filepath}: {e}. Will overwrite.")
# Extract the history of the run as a dataframe
try:
history = run.history(pandas=True, samples=None)
except Exception as e:
self.logger.warning(f"Failed to retrieve history for run {run.name}/{run.id}: {e}")
return
# Convert the dataframe to a dictionary
history_dict = history.to_dict(orient="records")
if len(history_dict) == 0:
history_dict = [dict()]
# Add last heartbeat time and run info to history[0]
history_dict[0]["last_heartbeat_time"] = current_last_heartbeat_time
history_dict[0]["run_info"] = run_info
for entry in history_dict:
if "_step" in entry:
entry.setdefault("step", entry["_step"])
entry.setdefault("iteration", entry["_step"])
# Save the dictionary as a JSON file
try:
with open(filepath, "w") as f:
json.dump(history_dict, f, indent=4, sort_keys=True)
except Exception as e:
self.logger.error(f"Failed to save log file {filepath}: {e}")