Module lore.types.utils

Expand source code
from typing import Optional

import huggingface_hub as hf_hub


def check_commit_hash_branch_association(name: str, branch: str, commit_hash: str) -> None:
    """Ensure ``commit_hash`` appears in the commit history of ``branch``.

    Fetches the history of ``branch`` in Hub repository ``name`` through
    ``HfApi.list_repo_commits`` and looks for an entry whose ``commit_id`` is
    exactly equal to ``commit_hash``.

    Args:
        name: Repository id on the Hugging Face Hub.
        branch: Branch (revision) whose history is searched.
        commit_hash: Commit id that is expected to be part of that history.

    Raises:
        RuntimeError: if no commit in the branch history matches ``commit_hash``.
    """
    # The Hub lists the history newest-first, so a recently pinned commit
    # is matched early and ``any`` stops scanning there.
    history = hf_hub.HfApi().list_repo_commits(name, revision=branch)

    belongs_to_branch = any(entry.commit_id == commit_hash for entry in history)
    if not belongs_to_branch:
        raise RuntimeError(f"commit_hash {commit_hash} does not belong to branch {branch}")


def get_latest_commit_hash_from_branch(name: str, branch: str) -> Optional[str]:
    """Return the SHA of the head commit of ``branch`` in Hub repository ``name``.

    The lookup goes through ``HfApi.model_info`` with ``revision=branch``.

    Args:
        name: Repository id on the Hugging Face Hub.
        branch: Branch (revision) whose latest commit is requested.

    Returns:
        The commit SHA, or ``None`` when the Hub answers with HTTP 404
        (e.g. the branch does not exist). NOTE(review): a 404 may also be
        produced for a missing repository — confirm whether callers need to
        distinguish that case.

    Raises:
        Any error raised by ``model_info`` other than an HTTP 404 is re-raised
        unchanged, including errors that carry no HTTP response at all
        (for instance connection failures).
        RuntimeError: if the Hub response does not contain a commit SHA.
    """
    api = hf_hub.HfApi()
    try:
        model_info = api.model_info(name, revision=branch)
    except Exception as e:
        # Only HTTP errors expose a ``response`` attribute, and even then it may
        # be None. Look it up defensively so an AttributeError never hides the
        # original exception.
        response = getattr(e, "response", None)
        if getattr(response, "status_code", None) == 404:
            # The requested revision was not found: report "no commit".
            return None
        raise

    sha = model_info.sha
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    # Otherwise a missing SHA would be returned as None and mistaken for the
    # "branch not found" result above.
    if not isinstance(sha, str):
        raise RuntimeError(f"could not determine latest commit hash of branch {branch} in {name}")
    return sha

Functions

def check_commit_hash_branch_association(name: str, branch: str, commit_hash: str) -> None
Expand source code
def check_commit_hash_branch_association(name: str, branch: str, commit_hash: str) -> None:
    """Ensure ``commit_hash`` appears in the commit history of ``branch``.

    Fetches the history of ``branch`` in Hub repository ``name`` through
    ``HfApi.list_repo_commits`` and looks for an entry whose ``commit_id`` is
    exactly equal to ``commit_hash``.

    Args:
        name: Repository id on the Hugging Face Hub.
        branch: Branch (revision) whose history is searched.
        commit_hash: Commit id that is expected to be part of that history.

    Raises:
        RuntimeError: if no commit in the branch history matches ``commit_hash``.
    """
    # The Hub lists the history newest-first, so a recently pinned commit
    # is matched early and ``any`` stops scanning there.
    history = hf_hub.HfApi().list_repo_commits(name, revision=branch)

    belongs_to_branch = any(entry.commit_id == commit_hash for entry in history)
    if not belongs_to_branch:
        raise RuntimeError(f"commit_hash {commit_hash} does not belong to branch {branch}")
def get_latest_commit_hash_from_branch(name: str, branch: str) -> Optional[str]
Expand source code
def get_latest_commit_hash_from_branch(name: str, branch: str) -> Optional[str]:
    """Return the SHA of the head commit of ``branch`` in Hub repository ``name``.

    The lookup goes through ``HfApi.model_info`` with ``revision=branch``.

    Args:
        name: Repository id on the Hugging Face Hub.
        branch: Branch (revision) whose latest commit is requested.

    Returns:
        The commit SHA, or ``None`` when the Hub answers with HTTP 404
        (e.g. the branch does not exist). NOTE(review): a 404 may also be
        produced for a missing repository — confirm whether callers need to
        distinguish that case.

    Raises:
        Any error raised by ``model_info`` other than an HTTP 404 is re-raised
        unchanged, including errors that carry no HTTP response at all
        (for instance connection failures).
        RuntimeError: if the Hub response does not contain a commit SHA.
    """
    api = hf_hub.HfApi()
    try:
        model_info = api.model_info(name, revision=branch)
    except Exception as e:
        # Only HTTP errors expose a ``response`` attribute, and even then it may
        # be None. Look it up defensively so an AttributeError never hides the
        # original exception.
        response = getattr(e, "response", None)
        if getattr(response, "status_code", None) == 404:
            # The requested revision was not found: report "no commit".
            return None
        raise

    sha = model_info.sha
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    # Otherwise a missing SHA would be returned as None and mistaken for the
    # "branch not found" result above.
    if not isinstance(sha, str):
        raise RuntimeError(f"could not determine latest commit hash of branch {branch} in {name}")
    return sha