Skip to content

hf_llama

optimus_dl.modules.model.presets.hf_llama

Preset for loading Hugging Face Llama models.

HFLlamaConfig dataclass

Bases: LlamaConfig

HFLlamaConfig(_name: str | None = None, block_size: int = 1024, vocab_size: int = 50304, n_layer: int = 12, n_head: int = 12, n_embd: int = 768, head_dim: int | None = None, dropout: float = 0.0, bias: bool = False, tie_word_embeddings: bool = True, shard_every_ith_layer: int = 1, padding_token_id: int | None = None, sequence_length: int = 16000, rmsnorm_eps: float = 1e-05, attention_bias: bool = False, n_kv_head: int | None = None, intermediate_size: int | None = None, multiple_of: int = 256, rope_theta: float = 10000.0, rope_scaling: dict | None = None, use_liger_rmsnorm: bool | None = None, use_liger_swiglu: bool | None = None, hf_model_name: str = 'meta-llama/Llama-2-7b-hf', load_weights: bool = True)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `hf_model_name` | `str` | Hugging Face model identifier of the checkpoint to load. | `'meta-llama/Llama-2-7b-hf'` |
| `load_weights` | `bool` | If True, download and load the checkpoint weights. If False, only the config is used (random init). | `True` |
Source code in optimus_dl/modules/model/presets/hf_llama.py
@dataclass
class HFLlamaConfig(LlamaConfig):
    """Llama config preset backed by a Hugging Face checkpoint.

    Extends ``LlamaConfig`` with the checkpoint identity and a switch
    controlling whether pretrained weights are actually loaded.
    """

    # Hugging Face model identifier of the checkpoint to load.
    hf_model_name: str = "meta-llama/Llama-2-7b-hf"
    # If True, will download and load weights. If False, just config is
    # used (random init).
    load_weights: bool = True

make_hf_llama_model(cfg, **_)

Create a Llama model loaded with weights from Hugging Face.

Source code in optimus_dl/modules/model/presets/hf_llama.py
@register_model("preset_hfllama2", HFLlamaConfig)
def make_hf_llama_model(cfg: HFLlamaConfig, **_):
    """Create a Llama model loaded with weights from Hugging Face.

    Fetches the HF config for ``cfg.hf_model_name``, mirrors it into the
    local config, and builds a local ``Llama`` model. Unless
    ``cfg.load_weights`` is False, the HF checkpoint weights are then
    copied into the local model's state dict.

    Args:
        cfg: Preset configuration. ``hf_model_name`` selects the
            checkpoint; ``load_weights`` controls whether weights are
            downloaded and copied.
        **_: Ignored extra keyword arguments (registry call convention).

    Returns:
        A ``Llama`` model — randomly initialized when
        ``cfg.load_weights`` is False, otherwise carrying the HF
        checkpoint weights.
    """
    import gc  # local import kept function-scoped; used for eager cleanup below

    logger.info(f"Loading HF model: {cfg.hf_model_name}")

    # Mirror the HF config into the local config so the local Llama
    # architecture matches the checkpoint before construction.
    hf_config = AutoConfig.from_pretrained(cfg.hf_model_name)
    update_config_from_hf(cfg, hf_config)

    # Initialize local Llama model (random weights at this point).
    model = Llama(cfg)

    if not cfg.load_weights:
        # Config-only mode: return the randomly initialized model.
        return model

    # Load HF model weights.
    logger.info("Loading HF model weights...")
    # NOTE(review): `dtype=` requires a recent transformers version
    # (older releases spell this `torch_dtype=`) — confirm the pinned
    # transformers version accepts it.
    hf_model = AutoModelForCausalLM.from_pretrained(
        cfg.hf_model_name,
        dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    hf_sd = hf_model.state_dict()
    mapper = WeightMapper(hf_sd, model.state_dict())

    logger.info("Copying weights...")
    _copy_hf_llama_weights(mapper, cfg)

    # Validation: confirm the mapping covered the expected parameters
    # (tied embeddings share a tensor, so the check must know about it).
    mapper.validate(tie_word_embeddings=cfg.tie_word_embeddings)

    # Drop the HF model and its state dict eagerly to keep peak host
    # memory down before returning.
    del hf_model
    del hf_sd
    gc.collect()

    return model


def _copy_hf_llama_weights(mapper, cfg: HFLlamaConfig) -> None:
    """Copy every HF Llama parameter into the local model via *mapper*.

    The HF and local models use different parameter names; each
    ``mapper.copy`` pairs an HF key with the corresponding local key.
    """
    # Embeddings
    mapper.copy("model.embed_tokens.weight", "transformer.wte.weight")

    # Per-layer weights
    for i in range(cfg.n_layer):
        # Attention projections. NOTE(review): `permute=True` on Q/K
        # presumably reorders rows for the local rotary-embedding
        # layout — confirm against WeightMapper's implementation.
        mapper.copy(
            f"model.layers.{i}.self_attn.q_proj.weight",
            f"transformer.h.{i}.attn.wq.weight",
            permute=True,
            n_heads=cfg.n_head,
            head_dim=cfg.head_dim,
        )
        mapper.copy(
            f"model.layers.{i}.self_attn.k_proj.weight",
            f"transformer.h.{i}.attn.wk.weight",
            permute=True,
            n_heads=cfg.n_kv_head,  # K uses the (possibly smaller) KV head count
            head_dim=cfg.head_dim,
        )
        mapper.copy(
            f"model.layers.{i}.self_attn.v_proj.weight",
            f"transformer.h.{i}.attn.wv.weight",
        )
        mapper.copy(
            f"model.layers.{i}.self_attn.o_proj.weight",
            f"transformer.h.{i}.attn.wo.weight",
        )

        # MLP (SwiGLU: gate -> w1, up -> w2, down -> c_proj)
        mapper.copy(
            f"model.layers.{i}.mlp.gate_proj.weight", f"transformer.h.{i}.mlp.w1.weight"
        )
        mapper.copy(
            f"model.layers.{i}.mlp.up_proj.weight", f"transformer.h.{i}.mlp.w2.weight"
        )
        mapper.copy(
            f"model.layers.{i}.mlp.down_proj.weight",
            f"transformer.h.{i}.mlp.c_proj.weight",
        )

        # Layer Norms (pre-attention and pre-MLP RMSNorm weights)
        mapper.copy(
            f"model.layers.{i}.input_layernorm.weight", f"transformer.h.{i}.ln_1.weight"
        )
        mapper.copy(
            f"model.layers.{i}.post_attention_layernorm.weight",
            f"transformer.h.{i}.ln_2.weight",
        )

    # Final Norm
    mapper.copy("model.norm.weight", "transformer.ln_f.weight")

    # LM Head (same key on both sides)
    mapper.copy("lm_head.weight", "lm_head.weight")