Source code for ml_research_tools.exp.wandb_downloader_tool
#!/usr/bin/env python3
"""
Tool to download Weights & Biases (W&B) run logs to local JSON files.
"""
import argparse
import json
import logging
import os
import re
from typing import Set
import wandb
from rich.panel import Panel
from wandb.apis.public.runs import Run as WandbRun
from wandb.apis.public.runs import Runs as WandbRuns
from ml_research_tools.core.base_tool import BaseTool
class WandbDownloaderTool(BaseTool):
    """Tool for downloading W&B run logs to local JSON files.

    Each run is saved as ``<output_dir>/<sanitized_name>_<run_id>.json``: a JSON
    list of history records whose first record additionally carries the keys
    ``"last_heartbeat_time"`` and ``"run_info"`` (run metadata).
    """

    name = "wandb-downloader"
    description = "Download artifacts and runs from Weights & Biases"

    def __init__(self, services) -> None:
        """Initialize the W&B downloader tool."""
        super().__init__(services)
        self.logger = logging.getLogger(__name__)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        """Add tool-specific arguments to the parser."""
        parser.add_argument(
            "--entity",
            "-e",
            default=os.environ.get("WANDB_ENTITY"),
            help="W&B entity (username or team name). Can also use WANDB_ENTITY env variable.",
        )
        parser.add_argument(
            "--project",
            "-p",
            default=os.environ.get("WANDB_PROJECT"),
            help="W&B project name. Can also use WANDB_PROJECT env variable.",
        )
        parser.add_argument(
            "--output-dir",
            "-o",
            default="wandb_logs",
            help="Directory to save log files (default: wandb_logs)",
        )
        parser.add_argument(
            "--timeout",
            "-t",
            type=int,
            default=30,
            help="API timeout in seconds (default: 30)",
        )
        parser.add_argument(
            "--quiet",
            "-q",
            action="store_true",
            help="Suppress progress bar and detailed logging",
        )
        parser.add_argument(
            "--no-delete",
            action="store_true",
            help="Don't delete logs for runs that no longer exist",
        )

    def download_wandb_logs(
        self,
        entity: str,
        project: str,
        output_dir: str = "wandb_logs",
        timeout: int = 30,
        quiet: bool = False,
        delete_outdated: bool = True,
    ) -> int:
        """
        Download W&B logs for a specified project to local JSON files.

        Args:
            entity: The W&B entity (username or team name)
            project: The W&B project name
            output_dir: Directory where log files will be saved
            timeout: API timeout in seconds
            quiet: If True, suppress progress bar and the summary panel
            delete_outdated: If True, delete logs for runs that no longer exist

        Returns:
            0 on success; 1 if the API could not be initialized or the run list
            could not be fully retrieved. Failures on individual runs are logged
            and reported in the summary but do not change the exit code.
        """
        self.logger.info(f"Initializing W&B API for {entity}/{project}")
        try:
            api = wandb.Api(timeout=timeout)
        except Exception:
            self.logger.exception("Failed to initialize W&B API")
            return 1

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info("Retrieving runs from W&B...")
        try:
            runs: WandbRuns = api.runs(
                path=f"{entity}/{project}",
                per_page=32,
                order="-created_at",
            )
            self.logger.info(f"Found {len(runs)} runs, loading metadata...")
            # Iterating loads every page; errors here must abort before any
            # deletion, since an incomplete ID set would delete valid logs.
            current_run_ids = self._collect_run_ids(runs, quiet)
        except Exception:
            self.logger.exception(f"Failed to retrieve runs for {entity}/{project}")
            return 1

        if not quiet:
            self.console.print(
                Panel(
                    f"[bold]W&B Project Summary[/bold]\n"
                    f"🔹 Entity: [cyan]{entity}[/cyan]\n"
                    f"🔹 Project: [green]{project}[/green]\n"
                    f"🔹 Total runs: [yellow]{len(current_run_ids)}[/yellow]\n"
                    f"🔹 Output directory: [cyan]{output_dir}[/cyan]",
                    title="W&B Downloader",
                    border_style="blue",
                )
            )

        if delete_outdated:
            try:
                self.logger.info("Checking for outdated logs...")
                deleted_count = self.delete_outdated_logs(output_dir, current_run_ids)
                if deleted_count > 0:
                    self.logger.info(f"Deleted {deleted_count} outdated log file(s)")
            except Exception as e:
                self.logger.warning(f"Error during deletion of outdated logs: {e}")

        succeeded, total = self._process_runs(runs, output_dir, quiet)
        failed = total - succeeded
        if failed:
            self.logger.warning(f"Failed to download {failed} of {total} run(s)")
        self.logger.info(f"Successfully downloaded {succeeded} run(s) to {output_dir}")
        return 0

    def _collect_run_ids(self, runs: WandbRuns, quiet: bool) -> Set[str]:
        """Iterate *runs* once (showing progress unless *quiet*) and return their IDs."""
        if quiet:
            return {run.id for run in runs}
        run_ids: Set[str] = set()
        with self.create_progress() as progress:
            meta_task = progress.add_task("[cyan]Loading run metadata...", total=len(runs))
            for run in runs:
                run_ids.add(run.id)
                progress.update(
                    meta_task, advance=1, description=f"[cyan]Loading run {run.id}..."
                )
        return run_ids

    def _process_runs(self, runs: WandbRuns, output_dir: str, quiet: bool) -> tuple:
        """Process every run (with a progress bar unless *quiet*); return (succeeded, total)."""
        if quiet:
            results = [self._safe_process_run(run, output_dir) for run in runs]
        else:
            results = []
            with self.create_progress() as progress:
                task = progress.add_task("[cyan]Processing runs...", total=len(runs))
                for run in runs:
                    run_name = self.sanitize_filename(run.name)
                    progress.update(
                        task, description=f"[cyan]Processing run [bold]{run_name}[/bold]..."
                    )
                    results.append(self._safe_process_run(run, output_dir))
                    progress.update(task, advance=1)
        return sum(results), len(results)

    def _safe_process_run(self, run: WandbRun, output_dir: str) -> bool:
        """Call process_run, logging instead of raising on error; return True on success."""
        try:
            return bool(self.process_run(run, output_dir))
        except Exception as e:
            self.logger.warning(f"Error processing run {run.name}/{run.id}: {e}")
            return False

    def execute(self, config, args: argparse.Namespace) -> int:
        """
        Execute the W&B log download with the provided arguments.

        Args:
            config: Tool configuration (not used by this tool)
            args: Parsed command-line arguments

        Returns:
            Exit code (0 for success, non-zero for error)
        """
        if not args.entity:
            self.logger.error("--entity is required (or set WANDB_ENTITY environment variable)")
            return 1
        if not args.project:
            self.logger.error("--project is required (or set WANDB_PROJECT environment variable)")
            return 1

        status = self.download_wandb_logs(
            entity=args.entity,
            project=args.project,
            output_dir=args.output_dir,
            timeout=args.timeout,
            quiet=args.quiet,
            delete_outdated=not args.no_delete,
        )
        if status == 0:
            self.logger.info("Download completed successfully!")
        return status

    @staticmethod
    def sanitize_filename(name: str) -> str:
        """
        Sanitize the run name to create a valid filename.

        Characters other than word characters, whitespace, ``-`` and ``.`` are
        dropped, surrounding whitespace is stripped, and spaces become ``_``.

        Args:
            name: The original run name (``None`` is treated as an empty name)

        Returns:
            A sanitized string suitable for use as a filename
        """
        cleaned = re.sub(r"[^\w\s\-\.]", "", name or "").strip()
        return cleaned.replace(" ", "_")

    def delete_outdated_logs(self, output_dir: str, current_run_ids: Set[str]) -> int:
        """
        Delete log files that do not correspond to any current run ID.

        Log files are named ``<sanitized_name>_<run_id>.json``. Both the sanitized
        name and a user-chosen run ID may contain underscores, so a file is kept
        when *any* underscore-separated suffix of its stem is a current run ID.
        Only regular files matching ``*_<word chars>.json`` are candidates for
        deletion; everything else in the directory is left alone.

        Args:
            output_dir: Directory containing log files
            current_run_ids: Set of valid run IDs from W&B

        Returns:
            Number of files deleted
        """
        deleted_count = 0
        for file in os.listdir(output_dir):
            if not re.match(r".*_(\w+)\.json$", file):
                continue
            file_path = os.path.join(output_dir, file)
            if not os.path.isfile(file_path):
                continue
            parts = file[: -len(".json")].split("_")
            # Index 0 is always (part of) the name, so ID suffixes start at 1.
            if any("_".join(parts[i:]) in current_run_ids for i in range(1, len(parts))):
                continue
            try:
                os.remove(file_path)
            except OSError as e:
                # One undeletable file should not stop the rest of the cleanup.
                self.logger.warning(f"Could not delete outdated log file {file}: {e}")
                continue
            deleted_count += 1
            self.logger.info(f"Deleted outdated log file: {file}")
        return deleted_count

    def process_run(self, run: WandbRun, output_dir: str) -> bool:
        """
        Process a single W&B run and save its history to a JSON file.

        If a log file for the run exists and its stored heartbeat time matches
        the run's current one, the history is not re-downloaded; only the stored
        run metadata is refreshed (and the file rewritten only if it changed).

        Args:
            run: W&B run object
            output_dir: Directory where the log file will be saved

        Returns:
            True if the local file is up to date after the call, False if the
            history could not be retrieved or the file could not be written.
            Errors raised while reading run attributes propagate to the caller.
        """
        sanitized_name = self.sanitize_filename(run.name)
        filepath = os.path.join(output_dir, f"{sanitized_name}_{run.id}.json")

        current_last_heartbeat_time = run.heartbeatAt
        run_info = {
            "id": run.id,
            "name": run.name,
            "config": run.config,
            "tags": run.tags,
            "url": run.url,
            "state": run.state,
            "notes": run.notes,
            "summary": run.summary._json_dict if hasattr(run.summary, "_json_dict") else {},
        }

        existing_data = self._load_existing_log(filepath)
        if (
            existing_data is not None
            and existing_data[0].get("last_heartbeat_time") == current_last_heartbeat_time
        ):
            # History unchanged since the last download: avoid rewriting the
            # (possibly large) file unless the run metadata actually changed.
            if existing_data[0].get("run_info") == run_info:
                return True
            existing_data[0]["run_info"] = run_info
            return self._write_json(filepath, existing_data)

        try:
            history = run.history(pandas=True)
        except Exception as e:
            self.logger.warning(f"Failed to retrieve history for run {run.name}/{run.id}: {e}")
            return False

        # NOTE(review): per the W&B public API docs, Run.history() returns a
        # *sampled* subset of rows (samples=500 by default); Run.scan_history()
        # would be needed for complete logs.
        history_dict = history.to_dict(orient="records") or [{}]
        history_dict[0]["last_heartbeat_time"] = current_last_heartbeat_time
        history_dict[0]["run_info"] = run_info
        return self._write_json(filepath, history_dict)

    def _load_existing_log(self, filepath: str) -> "list | None":
        """Return the parsed log list stored at *filepath*, or None if absent or unusable."""
        if not os.path.exists(filepath):
            return None
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            self.logger.warning(f"Error reading existing file {filepath}: {e}. Will overwrite.")
            return None
        if not isinstance(data, list) or not data or not isinstance(data[0], dict):
            self.logger.warning(
                f"Error reading existing file {filepath}: unexpected format. Will overwrite."
            )
            return None
        return data

    def _write_json(self, filepath: str, data) -> bool:
        """Write *data* as indented JSON to *filepath* via a temp file; return True on success."""
        tmp_path = f"{filepath}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=4)
            # Replacing only after a complete dump means a serialization error
            # leaves the previous file intact instead of a truncated one.
            os.replace(tmp_path, filepath)
        except Exception as e:
            self.logger.error(f"Failed to save log file {filepath}: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the temp file may never have been created
            return False
        return True