import json
import logging
from plumbum import ProcessExecutionError
from plumbum.cmd import git
from rich.text import Text
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Horizontal, Vertical, VerticalScroll
from textual.screen import ModalScreen
from textual.widgets import (
Footer,
Header,
Input,
Label,
Markdown,
Static,
TabbedContent,
TabPane,
Tree,
)
[docs]
class EmanagerApp(App):
    """A Textual app to manage experiments.

    The screen is split into a sidebar (search box + experiment tree) and a
    tabbed detail view (overview markdown, diff against the parent
    experiment, recent commits).

    All experiment operations are delegated to the ``tool_instance`` passed
    to the constructor; this class only drives the UI and keeps a cached
    copy of the tool's metadata.
    """

    TITLE = "ML Research Tools - Experiment Manager"

    CSS = """
    #sidebar {
        width: 30%;
        dock: left;
        border-right: solid green;
    }
    #tree-view {
        height: 1fr;
    }
    #search-input {
        margin: 1;
    }
    #detail-view {
        width: 70%;
        height: 1fr;
    }
    .scrollable {
        height: 1fr;
        overflow: auto;
    }
    """

    BINDINGS = [
        Binding("d", "toggle_dark", "Toggle dark mode"),
        Binding("c", "checkout", "Checkout branch"),
        Binding("n", "new_branch", "New branch"),
        Binding("a", "toggle_archive", "Toggle archive"),
        Binding("m", "add_note", "Add note"),
        Binding("q", "quit", "Quit"),
    ]

    def __init__(self, tool_instance):
        """Create the app around an experiment-manager tool.

        Args:
            tool_instance: Object providing the private operations used by
                this app (``_read_metadata``, ``_get_current_tag``,
                ``_switch``, ``_create``, ``_archive``, ``_unarchive``,
                ``_note`` and ``_generate_readme_text``).
        """
        super().__init__()
        self.tool = tool_instance
        self.metadata = self.tool._read_metadata()
        # Mapping of experiment tag -> per-experiment metadata dict.
        self.experiments = self.metadata.get("experiments", {})
        self.current_tag = self.tool._get_current_tag()
        # Tag of the tree node currently highlighted, or None.
        self.selected_tag = None
        # Current filter text for the tree; empty string means "show all".
        self.search_term = ""

    def compose(self) -> ComposeResult:
        """Build the widget hierarchy: header, sidebar, detail tabs, footer."""
        yield Header()
        with Horizontal():
            with Vertical(id="sidebar"):
                yield Input(placeholder="Search experiments...", id="search-input")
                yield Tree("Experiments", id="tree-view")
            with Vertical(id="detail-view"):
                with TabbedContent(initial="tab-overview"):
                    with TabPane("Overview", id="tab-overview"):
                        with VerticalScroll(classes="scrollable"):
                            yield Markdown(id="experiment-markdown")
                    with TabPane("Diff (vs Parent)", id="tab-diff"):
                        with VerticalScroll(classes="scrollable"):
                            yield Static(id="experiment-diff", expand=True)
                    with TabPane("Commits", id="tab-commits"):
                        with VerticalScroll(classes="scrollable"):
                            yield Static(id="experiment-commits", expand=True)
        yield Footer()

    def on_mount(self) -> None:
        """Populate the experiment tree once the widgets exist."""
        self.rebuild_tree()

    def on_input_changed(self, event: Input.Changed) -> None:
        """Re-filter the tree whenever the sidebar search box changes.

        Input messages from other inputs (for example one shown inside a
        modal) bubble up to the app as well, so they are filtered out by
        widget id.
        """
        if event.input.id != "search-input":
            return
        self.search_term = event.value.strip()
        self.rebuild_tree()

    def _reload_metadata(self) -> None:
        """Re-read metadata from the tool and refresh the cached experiments."""
        self.metadata = self.tool._read_metadata()
        self.experiments = self.metadata.get("experiments", {})

    def rebuild_tree(self):
        """Rebuild the sidebar tree from ``self.experiments``.

        Experiments are nested under their ``parent`` tag; an experiment
        whose parent is missing or unknown becomes a root. When a search
        term is set, only experiments whose tag or notes contain the term
        (case-insensitively), or that have such a descendant, are shown,
        and every shown node is expanded.
        """
        tree = self.query_one("#tree-view", Tree)
        tree.clear()

        known_tags = set(self.experiments)
        children = {}
        roots = []
        for tag, data in self.experiments.items():
            parent = data.get("parent")
            if not parent or parent not in known_tags:
                roots.append(tag)
            else:
                children.setdefault(parent, []).append(tag)

        term = self.search_term.lower()
        # Memo of tag -> "this tag or one of its descendants matches", so
        # each subtree is examined once per rebuild instead of once per
        # ancestor.
        match_cache: dict = {}

        def subtree_matches(tag):
            cached = match_cache.get(tag)
            if cached is None:
                # Defensive: tolerate a missing/None notes list and
                # non-string entries rather than crashing the search.
                notes = self.experiments[tag].get("notes") or []
                notes_text = " ".join(str(note) for note in notes).lower()
                cached = (
                    term in tag.lower()
                    or term in notes_text
                    or any(subtree_matches(child) for child in children.get(tag, []))
                )
                match_cache[tag] = cached
            return cached

        def add_node(parent_node, tag):
            if term and not subtree_matches(tag):
                return
            data = self.experiments[tag]
            label = tag
            if tag == self.current_tag:
                label = f"🟢 [bold]{label}[/bold]"
            if data.get("archived"):
                label += " (archived)"
            if data.get("sync_commits"):
                label += " 🔄"
            node = parent_node.add(label, data=tag)
            if term:
                # Reveal matches that would otherwise hide in collapsed nodes.
                node.expand()
            for child in sorted(children.get(tag, [])):
                add_node(node, child)

        for root in sorted(roots):
            add_node(tree.root, root)
        if not term:
            tree.root.expand()

    def on_tree_node_highlighted(self, event: Tree.NodeHighlighted) -> None:
        """Track the highlighted experiment and show its details."""
        tag = event.node.data
        if tag is None:
            # Nodes without data (such as the tree root) select nothing.
            self.selected_tag = None
            return
        self.selected_tag = tag
        self.update_detail_view(tag)

    def _experiment_diff_text(self, tag, parent) -> str:
        """Return the coloured ``git diff`` of ``exp/<parent>..exp/<tag>``.

        Returns a placeholder message when there is no parent, when the
        diff is empty, or when git fails.
        """
        if not parent:
            return "No parent to diff against."
        try:
            diff_output = git("diff", "--color", f"exp/{parent}..exp/{tag}")
        except ProcessExecutionError as e:
            return f"Could not compute diff:\n{e.stderr or e.stdout or str(e)}"
        return diff_output if diff_output.strip() else "No changes."

    def _experiment_commits_text(self, tag) -> str:
        """Return up to 50 one-line commits of ``exp/<tag>``, or an error text."""
        try:
            return git("log", "--color", "--oneline", "-n", "50", f"exp/{tag}")
        except ProcessExecutionError as e:
            return f"Could not get commits:\n{e.stderr or e.stdout or str(e)}"

    def update_detail_view(self, tag):
        """Refresh the Overview, Diff and Commits tabs for ``tag``.

        Does nothing when ``tag`` has no (non-empty) entry in
        ``self.experiments``.
        """
        data = self.experiments.get(tag)
        if not data:
            return

        # Overview tab
        md_text = self.tool._generate_readme_text(tag, self.metadata)
        self.query_one("#experiment-markdown", Markdown).update(md_text)

        # Diff tab; git output carries ANSI colour codes, hence from_ansi.
        diff_text = self._experiment_diff_text(tag, data.get("parent"))
        self.query_one("#experiment-diff", Static).update(Text.from_ansi(diff_text))

        # Commits tab
        commits_text = self._experiment_commits_text(tag)
        self.query_one("#experiment-commits", Static).update(Text.from_ansi(commits_text))

    def action_checkout(self) -> None:
        """Switch to the selected experiment (no-op when nothing is selected)."""
        tag = self.selected_tag
        if not tag:
            return
        if tag == self.current_tag:
            self.notify("Already on this experiment")
            return
        # The tool reports success with a zero return code.
        if self.tool._switch(tag) == 0:
            self.current_tag = self.tool._get_current_tag()
            self.rebuild_tree()
            self.notify(f"Checked out {tag}")
        else:
            self.notify(f"Failed to checkout {tag} (Check CLI output)", severity="error")

    def action_new_branch(self) -> None:
        """Prompt for a tag and create a new experiment from the selected one."""
        # Capture the parent now so the callback cannot act on a selection
        # that changed while the prompt was open.
        parent_tag = self.selected_tag
        if not parent_tag:
            self.notify("Select an experiment first to branch from", severity="error")
            return

        def check_new_branch(new_tag):
            if not new_tag:
                return
            # Switch to the parent first; creating from the wrong
            # experiment would silently record the wrong lineage, so a
            # failed switch aborts the whole operation.
            if parent_tag != self.current_tag and self.tool._switch(parent_tag) != 0:
                self.notify(
                    f"Failed to checkout {parent_tag} (Check CLI output)", severity="error"
                )
                return
            res = self.tool._create(new_tag)
            # Refresh even when creation failed: the switch above may
            # already have changed the current experiment.
            self._reload_metadata()
            self.current_tag = self.tool._get_current_tag()
            self.rebuild_tree()
            if res == 0:
                self.notify(f"Created and checked out {new_tag}")
            else:
                self.notify(f"Failed to create {new_tag}", severity="error")

        self.push_screen(InputModal("Enter new experiment tag:"), check_new_branch)

    def action_toggle_archive(self) -> None:
        """Archive the selected experiment, or unarchive it if archived."""
        tag = self.selected_tag
        data = self.experiments.get(tag) if tag else None
        if not data:
            return
        is_archived = data.get("archived", False)
        if is_archived:
            self.tool._unarchive(tag)
        else:
            self.tool._archive(tag)
        self._reload_metadata()
        self.rebuild_tree()
        self.notify(f"{'Unarchived' if is_archived else 'Archived'} {tag}")

    def action_add_note(self) -> None:
        """Prompt for a note and attach it to the selected experiment."""
        # Capture the target now so the callback cannot act on a selection
        # that changed while the prompt was open.
        tag = self.selected_tag
        if not tag:
            self.notify("Select an experiment first", severity="error")
            return

        def check_new_note(note):
            if not note:
                return
            self.tool._note(tag, note, False)
            self._reload_metadata()
            self.update_detail_view(tag)
            self.notify(f"Added note to {tag}")

        self.push_screen(InputModal(f"Enter note for {tag}:"), check_new_note)