Skip to content

Index

optimus_dl.recipe.serve

ServeConfig dataclass

ServeConfig(serve: optimus_dl.recipe.serve.config.ServeRecipeConfig = &lt;factory&gt;, common: optimus_dl.recipe.serve.config.ServeCommonConfig = &lt;factory&gt;)

Parameters:

Name Type Description Default
serve ServeRecipeConfig

ServeRecipeConfig(port: int = 8000, host: str = '0.0.0.0')

<dynamic>
common ServeCommonConfig

ServeCommonConfig(checkpoint_path: str | None = None, model: Any = None, tokenizer: optimus_dl.modules.tokenizer.config.BaseTokenizerConfig = '???', device: str = 'auto')

<dynamic>
Source code in optimus_dl/recipe/serve/config.py
@dataclass
class ServeConfig:
    """Top-level configuration for the serve recipe."""

    # HTTP server settings (port, host) — see ServeRecipeConfig.
    serve: ServeRecipeConfig = field(default_factory=ServeRecipeConfig)
    # Checkpoint/model/tokenizer/device settings — see ServeCommonConfig.
    common: ServeCommonConfig = field(default_factory=ServeCommonConfig)

ServeRecipe

Recipe for serving LLM Baselines models via simple HTTP API.

This class loads a model from a checkpoint or config, initializes the tokenizer, and starts an HTTP server compatible with OpenAI clients.

Source code in optimus_dl/recipe/serve/base.py
class ServeRecipe:
    """Recipe for serving LLM Baselines models via simple HTTP API.

    This class loads a model from a checkpoint or config, initializes the
    tokenizer, and starts an HTTP server compatible with OpenAI clients.
    """

    def __init__(self, cfg: ServeConfig):
        """Store the configuration and prepare checkpoint/model builders.

        Args:
            cfg: Combined serve (host/port) and common (model/tokenizer/device)
                configuration.
        """
        self.cfg = cfg
        # Populated by setup(); None until then.
        self.model = None
        self.tokenizer = None
        self.device = None

        # Initialize builder with empty config as we load from checkpoint
        chkp_cfg = CheckpointManagerConfig()
        self.checkpoint_manager = CheckpointManager(chkp_cfg)

        modelb_cfg = ModelBuilderConfig()
        self.model_builder = ModelBuilder(modelb_cfg)

    def setup(self):
        """Load model weights and tokenizer, and configure the device."""
        # Setup device: "auto" prefers CUDA when available.
        if self.cfg.common.device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(self.cfg.common.device)

        logger.info(f"Using device: {self.device}")

        # Build collective for potential distributed init
        collective = build_best_collective(
            device=None if self.device.type == "cuda" else torch.device("cpu"),
            config=DistributedConfig(),
        )

        # Exactly one of checkpoint_path / model must be set (XOR).
        # NOTE(review): assert is stripped under `python -O`; consider raising
        # ValueError instead if this validation must always run.
        assert (self.cfg.common.checkpoint_path is not None) ^ (
            self.cfg.common.model is not None
        ), "Either checkpoint_path or model must be specified, but not both"

        if self.cfg.common.checkpoint_path is not None:
            logger.info(
                f"Loading model from checkpoint: {self.cfg.common.checkpoint_path}"
            )
            self.model, _ = self.checkpoint_manager.build_model_from_checkpoint(
                checkpoint_path=self.cfg.common.checkpoint_path, device=self.device
            )
        else:
            logger.info("Building model from config")
            self.model = self.model_builder.build_model(
                model_config=self.cfg.common.model,
                collective=collective,
            )

        self.model.to(self.device)
        self.model.eval()  # inference only: disable dropout etc.

        # Build tokenizer
        self.tokenizer = build_tokenizer(self.cfg.common.tokenizer)
        logger.info("Model and tokenizer loaded")

    @torch.no_grad()
    def _debug_tokens_log(self, input_ids):
        """Log each input token id alongside its decoded text, for debugging."""
        tokens = []
        for token in input_ids.cpu().reshape(-1):
            token = token.item()
            tokens.append(f"{token}:'{self.tokenizer.decode([token])}'")
        logger.debug(f"Input tokens: {' '.join(tokens)}")

    @torch.no_grad()
    def generate_stream(
        self,
        prompt_or_messages: str | list[dict],
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: int | None = None,
    ):
        """Generate text continuation yielding chunks.

        Handles tokenization (including chat templates), inference loop,
        sampling, and detokenization delta logic for streaming.

        Args:
            prompt_or_messages: Input string or list of chat messages.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (0.0 for greedy).
            top_k: Optional top-k sampling.

        Yields:
            String chunks of generated text.
        """
        if isinstance(prompt_or_messages, list):
            # Chat request: render messages through the tokenizer's chat template.
            input_ids_list = self.tokenizer.apply_chat_template(
                prompt_or_messages, tokenize=True, add_generation_prompt=True
            )
            input_ids = torch.tensor(
                input_ids_list, dtype=torch.long, device=self.device
            ).unsqueeze(0)
        else:
            # Plain completion request: encode the prompt string directly.
            # (A previous inner `isinstance(..., list)` re-check here was
            # unreachable — the list case is handled above — and was removed.)
            input_ids = torch.tensor(
                self.tokenizer.encode(prompt_or_messages),
                dtype=torch.long,
                device=self.device,
            ).unsqueeze(0)

        self._debug_tokens_log(input_ids)

        generated_ids = []
        last_text = ""

        for _ in range(max_new_tokens):
            # Crop context to the model's maximum sequence length if needed.
            if input_ids.size(1) > self.model.config.sequence_length:
                input_cond = input_ids[:, -self.model.config.sequence_length :]
            else:
                input_cond = input_ids

            outputs = self.model(input_cond)
            logits = outputs["logits"][:, -1, :]  # next-token logits

            if temperature > 0:
                logits = logits / temperature
                if top_k is not None:
                    # Mask everything below the k-th largest logit.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # temperature == 0: greedy decoding.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_token], dim=1)
            generated_ids.append(next_token.item())

            # Simple streaming: decode all and yield diff
            # This is inefficient but safe for bytes/utf-8 boundaries
            current_text = self.tokenizer.decode(generated_ids)
            new_text = current_text[len(last_text) :]

            if new_text:
                yield new_text
                last_text = current_text

            # NOTE(review): this reads eos_token_id off the tokenizer *config*
            # (self.cfg.common.tokenizer), not the built self.tokenizer —
            # confirm which object is meant to carry the EOS id.
            if (
                hasattr(self.cfg.common.tokenizer, "eos_token_id")
                and next_token.item() == self.cfg.common.tokenizer.eos_token_id
            ):
                break

    def generate(
        self,
        prompt_or_messages: str | list[dict],
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> str:
        """Generate full text continuation.

        Wrapper around `generate_stream` that accumulates all chunks.
        """
        # str.join consumes the generator directly; no intermediate list needed.
        return "".join(
            self.generate_stream(prompt_or_messages, max_new_tokens, temperature, top_k)
        )

    def run(self):
        """Start the HTTP server."""
        self.setup()

        server_address = (self.cfg.serve.host, self.cfg.serve.port)
        httpd = HTTPServer(server_address, RequestHandler)
        # Expose the recipe so RequestHandler instances can reach the model.
        httpd.recipe = self

        logger.info(f"Serving at http://{self.cfg.serve.host}:{self.cfg.serve.port}")

        # Log ready-to-paste curl examples for both endpoints.
        text_completion_ex = json.dumps(
            {
                "prompt": "Once upon a time",
                "max_tokens": 20,
                "temperature": 0.8,
            }
        )
        logger.info(
            f"Text Completion Example:\ncurl -X POST http://{self.cfg.serve.host}:{self.cfg.serve.port}/v1/completions -d '{text_completion_ex}'"
        )

        chat_completion_ex = json.dumps(
            {
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello!"},
                ],
                "max_tokens": 20,
                "temperature": 0.8,
                "stream": True,
            }
        )
        logger.info(
            f"Chat Streaming Example:\ncurl -X POST http://{self.cfg.serve.host}:{self.cfg.serve.port}/v1/chat/completions -d '{chat_completion_ex}'"
        )

        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            pass  # Ctrl-C is the expected shutdown path
        httpd.server_close()
        logger.info("Server stopped")

generate(prompt_or_messages, max_new_tokens=50, temperature=1.0, top_k=None)

Generate full text continuation.

Wrapper around generate_stream that accumulates all chunks.

Source code in optimus_dl/recipe/serve/base.py
def generate(
    self,
    prompt_or_messages: str | list[dict],
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = None,
) -> str:
    """Generate full text continuation.

    Wrapper around `generate_stream` that accumulates all chunks.
    """
    return "".join(
        list(
            self.generate_stream(
                prompt_or_messages, max_new_tokens, temperature, top_k
            )
        )
    )

generate_stream(prompt_or_messages, max_new_tokens=50, temperature=1.0, top_k=None)

Generate text continuation yielding chunks.

Handles tokenization (including chat templates), inference loop, sampling, and detokenization delta logic for streaming.

Parameters:

Name Type Description Default
prompt_or_messages str | list[dict]

Input string or list of chat messages.

required
max_new_tokens int

Maximum number of tokens to generate.

50
temperature float

Sampling temperature (0.0 for greedy).

1.0
top_k int | None

Optional top-k sampling.

None

Yields:

Type Description

String chunks of generated text.

Source code in optimus_dl/recipe/serve/base.py
@torch.no_grad()
def generate_stream(
    self,
    prompt_or_messages: str | list[dict],
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = None,
):
    """Generate text continuation yielding chunks.

    Handles tokenization (including chat templates), inference loop,
    sampling, and detokenization delta logic for streaming.

    Args:
        prompt_or_messages: Input string or list of chat messages.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (0.0 for greedy).
        top_k: Optional top-k sampling.

    Yields:
        String chunks of generated text.
    """
    if isinstance(prompt_or_messages, list):
        # Apply chat template
        input_ids_list = self.tokenizer.apply_chat_template(
            prompt_or_messages, tokenize=True, add_generation_prompt=True
        )
        input_ids = torch.tensor(
            input_ids_list, dtype=torch.long, device=self.device
        ).unsqueeze(0)
    else:
        if isinstance(prompt_or_messages, list):
            # Handle list of strings? Simple server assumes single string prompt
            prompt_or_messages = prompt_or_messages[0]

        input_ids = torch.tensor(
            self.tokenizer.encode(prompt_or_messages),
            dtype=torch.long,
            device=self.device,
        ).unsqueeze(0)

    self._debug_tokens_log(input_ids)

    generated_ids = []
    last_text = ""

    for _ in range(max_new_tokens):
        # Crop context if needed
        if input_ids.size(1) > self.model.config.sequence_length:
            input_cond = input_ids[:, -self.model.config.sequence_length :]
        else:
            input_cond = input_ids

        outputs = self.model(input_cond)
        logits = outputs["logits"][:, -1, :]

        if temperature > 0:
            logits = logits / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        input_ids = torch.cat([input_ids, next_token], dim=1)
        generated_ids.append(next_token.item())

        # Simple streaming: decode all and yield diff
        # This is inefficient but safe for bytes/utf-8 boundaries
        current_text = self.tokenizer.decode(generated_ids)
        new_text = current_text[len(last_text) :]

        if new_text:
            yield new_text
            last_text = current_text

        if (
            hasattr(self.cfg.common.tokenizer, "eos_token_id")
            and next_token.item() == self.cfg.common.tokenizer.eos_token_id
        ):
            break

run()

Start the HTTP server.

Source code in optimus_dl/recipe/serve/base.py
def run(self):
    """Start the HTTP server."""
    self.setup()

    httpd = HTTPServer((self.cfg.serve.host, self.cfg.serve.port), RequestHandler)
    httpd.recipe = self  # expose the recipe to RequestHandler instances

    logger.info(f"Serving at http://{self.cfg.serve.host}:{self.cfg.serve.port}")

    # Log ready-to-paste curl examples for both endpoints.
    completion_payload = json.dumps(
        {
            "prompt": "Once upon a time",
            "max_tokens": 20,
            "temperature": 0.8,
        }
    )
    logger.info(
        f"Text Completion Example:\ncurl -X POST http://{self.cfg.serve.host}:{self.cfg.serve.port}/v1/completions -d '{completion_payload}'"
    )

    chat_payload = json.dumps(
        {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"},
            ],
            "max_tokens": 20,
            "temperature": 0.8,
            "stream": True,
        }
    )
    logger.info(
        f"Chat Streaming Example:\ncurl -X POST http://{self.cfg.serve.host}:{self.cfg.serve.port}/v1/chat/completions -d '{chat_payload}'"
    )

    # Serve until interrupted; Ctrl-C is the expected shutdown path.
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    httpd.server_close()
    logger.info("Server stopped")

setup()

Load model weights and tokenizer, and configure the device.

Source code in optimus_dl/recipe/serve/base.py
def setup(self):
    """Load model weights and tokenizer, and configure the device."""
    # Resolve the device; "auto" prefers CUDA when available.
    requested = self.cfg.common.device
    if requested == "auto":
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        self.device = torch.device(requested)

    logger.info(f"Using device: {self.device}")

    # Build collective for potential distributed init
    collective = build_best_collective(
        device=None if self.device.type == "cuda" else torch.device("cpu"),
        config=DistributedConfig(),
    )

    # Exactly one of checkpoint_path / model must be provided (XOR).
    has_checkpoint = self.cfg.common.checkpoint_path is not None
    has_model_config = self.cfg.common.model is not None
    assert (
        has_checkpoint ^ has_model_config
    ), "Either checkpoint_path or model must be specified, but not both"

    if has_checkpoint:
        logger.info(
            f"Loading model from checkpoint: {self.cfg.common.checkpoint_path}"
        )
        self.model, _ = self.checkpoint_manager.build_model_from_checkpoint(
            checkpoint_path=self.cfg.common.checkpoint_path, device=self.device
        )
    else:
        logger.info("Building model from config")
        self.model = self.model_builder.build_model(
            model_config=self.cfg.common.model,
            collective=collective,
        )

    self.model.to(self.device)
    self.model.eval()  # inference only

    # Build tokenizer
    self.tokenizer = build_tokenizer(self.cfg.common.tokenizer)
    logger.info("Model and tokenizer loaded")

Modules and Sub-packages

  • base: Serving recipe for LLM Baselines models.
  • config: Serving configuration dataclasses (ServeConfig, ServeRecipeConfig, ServeCommonConfig).
  • models: A single message in a chat conversation.