Module lore.client

import logging
import os
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import datasets as hf
import httpx
import pandas as pd
from determined.common.api import Session, bindings
from determined.common.api.authentication import Authentication, Credentials
from determined.experimental import checkpoint

import lore.const as const
import lore.types.base_types as bt
import lore.types.enums as bte
import lore.types.exceptions as ex
from lore.backend.utils import RequestClient
from lore.utils.server import create_session, get_det_master_address, obtain_token

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
SupportedCredentials = Union[Credentials, str, Session, Tuple[str, str], Authentication]


class Lore(RequestClient):
    def __init__(
        self,
        host: str = f"http://localhost:{const.DEFAULT_SERVER_PORT}",
        prefix=f"{const.GENAI_URL_PREFIX}/api/v1",
        client: Optional[httpx.Client] = None,
        credentials: Optional[SupportedCredentials] = None,
    ) -> None:
        if not host.startswith(("http://", "https://")):
            host = f"http://{host}"
        logger.info("Created lore client with base url: %s%s", host, prefix)
        super().__init__(client or httpx.Client(base_url=host), prefix)
        self.project: Optional[bt.Project] = None
        self.workspace: Optional[bt.Workspace] = None
        self._client_key = "notebook"  # To differentiate from web client to avoid interference
        self.max_new_tokens = None
        self.batch_size = None
        if credentials:
            self.login(credentials)

    @property
    def _required_token(self) -> str:
        """Return the token required for this operation.

        Some operations do not clearly handle user management; this property
        helps transition them to using a token.
        """
        assert self.token is not None, "Token is required and assumed present for this operation."
        return self.token

    def _set_client_key(self, key: str):
        # This should only be used in testing
        self._client_key = key

    def login(
        self,
        credentials: SupportedCredentials,
        token: Optional[str] = None,
        mlde_host: Optional[str] = None,
    ) -> str:
        """Login to the lore server.

        token: deprecated; pass a token string as 'credentials' instead.
        """
        if token is not None:
            # for backward compatibility.
            self.token = token
            return token
        token = None
        if isinstance(credentials, str):
            token = credentials
        elif isinstance(credentials, Credentials):
            if mlde_host is None:
                logger.warning(
                    f"mlde_host is not provided for login. Falling back to {get_det_master_address()}"
                )
                mlde_host = get_det_master_address()
            token = obtain_token(
                credentials.username, credentials.password, master_address=mlde_host
            )
        elif isinstance(credentials, tuple):
            if mlde_host is None:
                logger.warning(
                    f"mlde_host is not provided for login. Falling back to {get_det_master_address()}"
                )
                mlde_host = get_det_master_address()
            token = obtain_token(credentials[0], credentials[1], master_address=mlde_host)
        elif isinstance(credentials, Session):
            assert credentials._auth is not None, "Session must be authenticated."
            token = credentials._auth.get_session_token()
        elif isinstance(credentials, Authentication):
            token = credentials.get_session_token()
        else:
            raise ValueError(f"Unsupported credentials type: {type(credentials)}")
        self.token = token
        return token

    def set_workspace(self, workspace: Union[str, bt.Workspace, int]) -> None:
        if isinstance(workspace, str):
            workspace_name = workspace
            workspaces = self.get_workspaces()
            workspace = next((w for w in workspaces if w.name == workspace_name), None)
            assert isinstance(workspace, bt.Workspace), f"Workspace {workspace_name} not found."
            self.workspace = workspace
        elif isinstance(workspace, bt.Workspace):
            self.workspace = self.get_workspace(workspace.id)
        else:
            self.workspace = self.get_workspace(workspace)
        self.project = self._get(
            f"{self._prefix}/project/{self.workspace.experiment_project_id}", out_cls=bt.Project
        )

    # Workspace resources.

    def create_workspace(self, workspace_name: str) -> bt.Workspace:
        workspace = bt.CreateWorkspaceRequest(name=workspace_name)
        return self._post(f"{self._prefix}/workspace", workspace, out_cls=bt.Workspace)

    def _can_reach_server(self, timeout: int = 60):
        self._get(f"{self._prefix}/workspaces", out_cls=None, timeout=timeout)

    def get_workspaces(self) -> List[bt.Workspace]:
        return self._get(f"{self._prefix}/workspaces", out_cls=bt.Workspace, seq=True)

    def get_workspace(self, workspace_id: int) -> bt.Workspace:
        return self._get(f"{self._prefix}/workspace/{workspace_id}", out_cls=bt.Workspace)

    def delete_workspace(self, workspace_id: int) -> None:
        return self._delete(f"{self._prefix}/workspace/{workspace_id}")

    # Dataset resources.

    def construct_dataset_from_hf(
        self,
        dataset_name: str,
        dataset_config_name: Optional[str] = None,
        task_type: Optional[str] = None,
        data_type: Optional[List[bt.DatasetDataType]] = None,
        token: Optional[str] = None,
    ) -> bt.Dataset:
        data = bt.ConstructDatasetFromHFRequest(
            dataset_name=dataset_name,
            dataset_config_name=dataset_config_name,
            task_type=task_type,
            data_type=data_type,
            token=token,
            workspace_id=self.workspace.id,
        )
        return self._post(
            f"{self._prefix}/construct_dataset_from_hf",
            in_data=data,
            out_cls=bt.Dataset,
            sync=False,
        )

    def construct_dataset_from_local(
        self,
        hf_dataset: Optional[Union[hf.Dataset, hf.DatasetDict]] = None,
        dataset_name: Optional[str] = None,
    ) -> bt.Dataset:
        if dataset_name is None:
            if isinstance(hf_dataset, hf.Dataset):
                dataset_name = hf_dataset.info.dataset_name
            else:
                # Find the first split with info and dataset_name
                dataset_name = None
                for key in hf_dataset:
                    if hasattr(hf_dataset[key], "info") and hasattr(
                        hf_dataset[key].info, "dataset_name"
                    ):
                        dataset_name = hf_dataset[key].info.dataset_name
                        break

        assert (
            dataset_name
        ), "Could not infer 'dataset_name' from hf_dataset. Please provide the dataset name explicitly."

        with tempfile.TemporaryDirectory() as tmp_dir:
            with tempfile.TemporaryDirectory() as zip_tmp_dir:
                hf_dataset.save_to_disk(os.path.join(tmp_dir, "hf_dataset"))

                shutil.make_archive(
                    os.path.join(zip_tmp_dir, "hf_dataset"), "zip", root_dir=tmp_dir
                )
                uploaded_files = self._upload_arrow_files(
                    os.path.join(zip_tmp_dir, "hf_dataset.zip")
                )
                data = bt.ConstructDatasetFromLocalRequest(
                    workspace_id=self.workspace.id,
                    dataset_name=dataset_name,
                    arrow_file_path=uploaded_files.file_paths[0],
                )

        return self._post(
            f"{self._prefix}/construct_dataset_from_local",
            in_data=data,
            out_cls=bt.Dataset,
            sync=False,
        )

    def commit_dataset(self, dataset: bt.Dataset) -> bt.Dataset:
        return self._post(f"{self._prefix}/dataset", in_data=dataset, out_cls=bt.Dataset)

    def get_datasets(self) -> List[bt.Dataset]:
        return self._get(
            f"{self._prefix}/workspace/{self.workspace.id}/datasets", out_cls=bt.Dataset, seq=True
        )

    def get_dataset(self, dataset_id: int) -> bt.Dataset:
        return self._get(
            f"{self._prefix}/dataset/{dataset_id}",
            out_cls=bt.Dataset,
        )

    def get_dataset_by_name(self, dataset_name: str) -> bt.Dataset:
        return self._get(
            f"{self._prefix}/workspace/{self.workspace.id}/dataset/name",
            params={"name": dataset_name},
            out_cls=bt.Dataset,
        )

    def delete_dataset(self, dataset_id: int) -> None:
        return self._delete(f"{self._prefix}/dataset/{dataset_id}", sync=False)

    def sample_dataset(
        self,
        dataset: Union[bt.Dataset, str, int],
        start_index: Optional[int] = None,
        number_of_samples: Optional[int] = None,
        ratio: Optional[float] = None,
        splits: Optional[List[str]] = None,
        seed: int = 1234,
        as_dataset: bool = False,
    ) -> Union[bt.Dataset, Dict]:
        dataset = self._get_dataset(dataset)
        request = bt.SampleDatasetRequest(
            dataset=dataset,
            start_index=start_index,
            number_of_samples=number_of_samples,
            ratio=ratio,
            splits=splits,
            workspace_id=self.workspace.id,
            seed=seed,
            as_dataset=as_dataset,
        )

        out_cls = bt.Dataset if as_dataset is True else None
        output = self._post(
            f"{self._prefix}/dataset/sample", in_data=request, out_cls=out_cls, sync=False
        )
        if isinstance(output, bt.Dataset):
            return output
        else:
            dataset_sample = dict(output)
            return {key: pd.read_json(sample) for key, sample in dataset_sample.items()}

    def merge_and_resplit(
        self,
        dataset: Union[bt.Dataset, str, int],
        train_ratio: float,
        validation_ratio: float,
        test_ratio: float,
        splits_to_resplit: List[str],
        as_dataset: bool = True,
        shuffle: bool = True,
        seed: int = 1234,
    ) -> Union[bt.Dataset, Dict]:
        dataset = self._get_dataset(dataset)
        data = bt.MergeAndResplitRequest(
            dataset=dataset,
            train_ratio=train_ratio,
            validation_ratio=validation_ratio,
            test_ratio=test_ratio,
            splits_to_resplit=splits_to_resplit,
            workspace_id=self.workspace.id,
            shuffle=shuffle,
            seed=seed,
        )
        out_cls = bt.Dataset if as_dataset is True else dict
        return self._post(
            f"{self._prefix}/dataset/merge_and_resplit",
            out_cls=out_cls,
            in_data=data,
            sync=False,
        )

    def apply_prompt_template(
        self,
        dataset: bt.Dataset,
        prompt_template: bt.PromptTemplate,
        model: Optional[bt.Model] = None,
        new_column_name: str = "genai_input_text",
        expected_output_column_name: Optional[str] = None,
        batch_size: int = 10,
    ) -> bt.Dataset:
        data = {
            "dataset": dataset,
            "prompt_template": prompt_template,
            "model": model,
            "new_column_name": new_column_name,
            "expected_output_column_name": expected_output_column_name,
            "batch_size": batch_size,
            "workspace_id": self.workspace.id,
        }
        return self._post(
            f"{self._prefix}/dataset/apply_prompt_template",
            in_data=data,  # type: ignore
            out_cls=bt.Dataset,
            sync=False,
        )

    def compute_metrics(
        self,
        dataset: bt.Dataset,
        ground_truth_column_name: str,
        predictions_column_name: str,
        metrics: List[str] = ["exact_match"],
        substr_match: bool = True,
        split: Optional[str] = None,
        strip: bool = False,
        lower: bool = False,
    ) -> Tuple[bt.Dataset, Union[Dict[str, float], Dict[str, Dict[str, float]]]]:
        data = bt.ComputeMetricsRequest(
            dataset=dataset,
            ground_truth_column_name=ground_truth_column_name,
            predictions_column_name=predictions_column_name,
            metrics=metrics,
            substr_match=substr_match,
            workspace_id=self.workspace.id,
            split=split,
            strip=strip,
            lower=lower,
        )
        results = self._post(
            f"{self._prefix}/dataset/compute_metrics",
            in_data=data,
            out_cls=None,
            sync=False,
        )
        return bt.Dataset.parse_obj(results[0]), results[1]

    def materialize_dataset(
        self, dataset: bt.Dataset, output_path: Optional[str] = None
    ) -> hf.DatasetDict:
        if os.path.exists(dataset.storage_path):
            storage_path = dataset.storage_path
        else:
            storage_path = self._download_dataset(dataset, output_path)

        return hf.load_from_disk(storage_path)

    def _download_dataset(self, dataset: bt.Dataset, output_path: Optional[str] = None) -> str:
        # TODO (GAS-448): migrate to use self._get which handles authentication.
        file_response = self._client.get(
            f"{self._prefix}/dataset/{dataset.id}/download",
            headers={"Authorization": f"Bearer {self._required_token}"},
        )
        assert file_response.status_code == 200

        if output_path is None:
            hf_cache_dir = os.getenv("HF_CACHE")
            if hf_cache_dir:
                output_path = hf_cache_dir
            else:
                lore_cache_dir = os.path.expanduser("~/.cache/lore/datasets")
                os.makedirs(lore_cache_dir, exist_ok=True)
                output_path = lore_cache_dir

        with tempfile.NamedTemporaryFile(suffix=".zip") as temp_archive:
            for chunk in file_response.iter_bytes(chunk_size=8192):
                temp_archive.write(chunk)
            temp_archive.flush()
            with zipfile.ZipFile(temp_archive.name, "r") as zip_ref:
                zip_ref.extractall(output_path)
        return output_path

    def _upload_arrow_files(self, zip_file: str) -> bt.UploadedFiles:
        with open(zip_file, "rb") as f:
            file_name = zip_file.split("/")[-1]
            files = [("files", (file_name, f))]

            return self._post(
                f"{self._prefix}/dataset/upload_files", files=files, out_cls=bt.UploadedFiles
            )

    def concatenate_datasets(
        self, datasets: List[bt.Dataset], name: str, register_dataset: bool = False
    ) -> bt.Dataset:
        data = {
            "datasets": datasets,
            "name": name,
            "workspace_id": self.workspace.id,
            "register_dataset": register_dataset,
            "system_generated": False,  # as it is trigger from user facing Python client
        }
        request_body = bt.ConcatenateDatasetsRequest.parse_obj(data)
        return self._post(
            f"{self._prefix}/dataset/concatenate", in_data=request_body, out_cls=bt.Dataset
        )

    # Model resources.

    def get_models(self) -> List[bt.Model]:
        return self._get(
            f"{self._prefix}/workspace/{self.workspace.id}/models", out_cls=bt.Model, seq=True
        )

    def get_base_models(self) -> List[bt.Model]:
        models = self._get(f"{self._prefix}/base-models", out_cls=bt.Model, seq=True)
        models = sorted(models, key=lambda x: x.hf_model_name)
        return models

    def get_base_model(self, model_name: str) -> bt.Model:
        models = self.get_base_models()
        model = next((model for model in models if model.hf_model_name == model_name), None)
        if model is None:
            raise ValueError(f"{model_name} not found.")
        return model

    def get_model(self, model_id: int) -> bt.Model:
        return self._get(f"{self._prefix}/model/{model_id}", out_cls=bt.Model)

    def get_model_by_name(self, model_name: str, is_base_model: bool = True) -> bt.Model:
        model = self._post(
            f"{self._prefix}/model",
            in_data=bt.GetModelByNameRequest(
                name=model_name, workspace_id=self.workspace.id, is_base_model=is_base_model
            ),
            out_cls=bt.Model,
            seq=False,
        )
        return model

    def register_model(self, model: bt.Model) -> bt.Model:
        if model.workspace_id is None:
            model.workspace_id = self.workspace.id
        assert model.workspace_id is not None, "Workspace id is required for model registration"
        return self._post(f"{self._prefix}/model/register", model, out_cls=bt.Model)

    def delete_model(self, model_id: int) -> None:
        self._delete(f"{self._prefix}/model/{model_id}")

    def create_hf_model(
        self,
        hf_model_name: str,
        branch: Optional[str] = None,
        hf_model_commit: Optional[str] = None,
        formatting_tokens: Optional[bt.FormattingTokens] = None,
        problem_type: str = "",
        model_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> bt.Model:
        return bt.Model.from_hf(
            hf_model_name=hf_model_name,
            branch=branch,
            hf_model_commit=hf_model_commit,
            formatting_tokens=formatting_tokens,
            problem_type=problem_type,
            model_config=model_config,
            **kwargs,
        )

    def download_model(
        self, model_id: int, output_path: Optional[str] = None, mlde_host: Optional[str] = None
    ) -> str:
        """Downloads a model by checkpoint through the determined master directly
        Arguments:
            model_id (int): id for the model
            output_path (string): the folder path to save the model to
        Returns:
            String of model folder path
        """
        model = self.get_model(model_id)
        assert (
            model.genai_checkpoint_uuid
        ), f"Model ({model_id}) is missing a checkpoint uuid. Cancelling download"
        if mlde_host is None:
            logger.warning(
                f"mlde_host is not provided for download_model. Falling back to {get_det_master_address()}"
            )
            mlde_host = get_det_master_address()
        # Call determined directly to avoid downloading our
        # checkpoints in two hops.
        session = create_session(self._required_token, master_address=mlde_host)
        resp = bindings.get_GetCheckpoint(session, checkpointUuid=model.genai_checkpoint_uuid)
        ckpt = checkpoint.Checkpoint._from_bindings(resp.checkpoint, session)
        assert output_path is not None, "output_path is required to download a model."
        Path(output_path).mkdir(parents=True, exist_ok=True)
        ckpt.download(output_path)
        return str(output_path)

    def upload_model_to_hf(
        self,
        model: bt.Model,
        hf_repo_owner: str,
        hf_repo_name: str,
        hf_token: str,
        private: bool = False,
    ) -> int:
        data = bt.UploadModelToHFRequest(
            model=model,
            hf_repo_owner=hf_repo_owner,
            hf_repo_name=hf_repo_name,
            hf_token=hf_token,
            private=private,
        )
        response: bt.UploadModelToHFResponse = self._post(
            f"{self._prefix}/model/upload_model_to_hf",
            in_data=data,
            out_cls=bt.UploadModelToHFResponse,
            sync=True,
        )
        return response.experiment_id

    # Prompt template resources.

    def commit_prompt_template(self, prompt_template: bt.PromptTemplate) -> bt.PromptTemplate:
        return self._post(
            f"{self._prefix}/prompt-template", in_data=prompt_template, out_cls=bt.PromptTemplate
        )

    def get_prompt_templates(self) -> List[bt.PromptTemplate]:
        return self._get(
            f"{self._prefix}/workspace/{self.workspace.id}/prompt-templates",
            out_cls=bt.PromptTemplate,
            seq=True,
        )

    def get_prompt_template(self, prompt_template_id: int) -> bt.PromptTemplate:
        return self._get(
            f"{self._prefix}/prompt-template/{prompt_template_id}", out_cls=bt.PromptTemplate
        )

    def get_start_prompt_templates(self) -> List[bt.PromptTemplate]:
        return self._get(
            f"{self._prefix}/starting-prompt-templates", out_cls=bt.PromptTemplate, seq=True
        )

    def update_prompt_template(self, prompt_template: bt.PromptTemplate) -> bt.PromptTemplate:
        return self._put(
            f"{self._prefix}/prompt-template/{prompt_template.id}",
            in_data=prompt_template,
            out_cls=bt.PromptTemplate,
        )

    def delete_prompt_template(self, prompt_template_id: int) -> None:
        self._delete(f"{self._prefix}/prompt-template/{prompt_template_id}")

    # Chat resources.
    def _get_recommended_vllm_configs(
        self, resource_pool, model: Union[str, bt.Model], is_base_model: bool
    ) -> bt.ResourcePoolInferenceConfigs:
        if isinstance(model, str):
            model = self.get_model_by_name(model, is_base_model)

        if resource_pool is not None:
            request = bt.InferenceConfigRequest(max_config_per_resource_pool=1, model_id=model.id)
            result = self._post(
                f"{self._prefix}/workspace/{self.workspace.id}/inference/config",
                in_data=request,
                out_cls=bt.ResourcePoolInferenceConfigs,
                seq=True,
            )
            filtered_result = [c for c in result if c.resource_pool.name == resource_pool]
            if len(filtered_result) == 0:
                raise ex.InferenceConfigNotFoundException(
                    f"Resource pool {resource_pool} cannot run model {model.name}."
                )
            config = filtered_result[0]
        else:
            request = bt.InferenceConfigRequest(
                max_config_per_resource_pool=1, max_resource_pool=1, model_id=model.id
            )
            result = self._post(
                f"{self._prefix}/workspace/{self.workspace.id}/inference/config",
                in_data=request,
                out_cls=bt.ResourcePoolInferenceConfigs,
                seq=True,
            )
            if len(result) == 0:
                raise ex.InferenceConfigNotFoundException(
                    f"We cannot find any recommended configs for model {model.name}."
                )
            config = result[0]

        logger.info(
            f"Recommended vllm inference config is: resource_pool: {config.resource_pool.name}, slots: {config.vllm_configs[0].slots_per_trial}"
        )
        return config

    def _try_recommend_vllm_config(
        self,
        use_vllm,
        vllm_tensor_parallel_size,
        is_arbitrary_hf_model,
        disable_inference_config_recommendation,
    ):
        if disable_inference_config_recommendation:
            return False

        if not use_vllm:
            return False

        # load_model input validation already forces a resource pool to be given if vllm_tensor_parallel_size is provided
        if vllm_tensor_parallel_size:
            return False

        # No config recommendation for non-vetted HF models
        if is_arbitrary_hf_model:
            return False
        return True

    def load_model(
        self,
        model: Union[str, bt.Model],
        is_genai_base_model: bool = True,
        genai_model_version: Optional[int] = None,
        *,
        resource_pool: Optional[str] = None,
        use_vllm: bool = True,
        vllm_tensor_parallel_size: Optional[int] = None,
        vllm_swap_space: Optional[int] = None,
        hf_token: Optional[str] = None,
        torch_dtype: Optional[str] = None,
        low_cpu_mem_usage: bool = False,
        trust_remote_code: bool = False,
        is_arbitrary_hf_model: bool = False,
        disable_inference_config_recommendation: bool = False,
    ) -> bool:
        # Input validation
        if not use_vllm and vllm_tensor_parallel_size is not None:
            raise ValueError(
                "vllm_tensor_parallel_size should not be provided if use_vllm is False"
            )

        if vllm_tensor_parallel_size is not None and resource_pool is None:
            raise ValueError(
                "Please provide resource_pool when vllm_tensor_parallel_size is given. Otherwise, leave both fields blank for system recommendation."
            )

        if is_arbitrary_hf_model and is_genai_base_model:
            raise ValueError("is_arbitrary_hf_model and is_genai_base_model cannot both be True")

        if torch_dtype is not None and torch_dtype not in (
            const.BF16_TYPE_STR,
            const.FP16_TYPE_STR,
        ):
            raise ValueError(
                f"torch_dtype should be {const.BF16_TYPE_STR} or {const.FP16_TYPE_STR}. Note that support for {const.BF16_TYPE_STR} is GPU dependent."
            )

        # End input validation

        try_config_recommendation = self._try_recommend_vllm_config(
            use_vllm=use_vllm,
            vllm_tensor_parallel_size=vllm_tensor_parallel_size,
            is_arbitrary_hf_model=is_arbitrary_hf_model,
            disable_inference_config_recommendation=disable_inference_config_recommendation,
        )

        recommended_config_found = False
        if try_config_recommendation:
            logger.info("Getting recommended vllm inference config from server")
            try:
                recommended_config = self._get_recommended_vllm_configs(
                    resource_pool=resource_pool, model=model, is_base_model=is_genai_base_model
                )
                logger.info(
                    "Using recommended config to run inference, to turn off recommendation, set disable_inference_config_recommendation = True"
                )
                # We should be only returning one vllm config per resource pool
                inference_config = recommended_config.vllm_configs[0]
                slots = inference_config.slots_per_trial
                resource_pool = recommended_config.resource_pool.name
                if torch_dtype is not None and torch_dtype != inference_config.torch_dtype:
                    logger.info(
                        f"torch_dtype {torch_dtype} will be overwritten by recommended config value {inference_config.torch_dtype}"
                    )
                torch_dtype = inference_config.torch_dtype
                self.max_new_tokens = inference_config.max_new_tokens
                self.batch_size = inference_config.batch_size
                vllm_config = {
                    "tensor-parallel-size": slots,
                    "swap-space": inference_config.swap_space,
                }
                recommended_config_found = True
            except ex.InferenceConfigNotFoundException as e:
                if resource_pool:
                    logger.warning(
                        f"No recommended inference config is found for model. Will use default number of GPU: {const.DEFAULT_SLOTS_PER_TRIAL} with provided resource_pool: {resource_pool}"
                    )
                else:
                    logger.warning(
                        f"No recommended inference config is found for model. Will use default number of GPU: {const.DEFAULT_SLOTS_PER_TRIAL} with default resource pool for the workspace"
                    )

        if not recommended_config_found:
            logger.info("User provided / default model config is used.")
            # Reset max_new_tokens and batch_size
            self.max_new_tokens = None
            self.batch_size = None
            if vllm_tensor_parallel_size:
                slots = vllm_tensor_parallel_size
            else:
                slots = const.DEFAULT_SLOTS_PER_TRIAL

            vllm_config = {
                "tensor-parallel-size": slots,
                "swap-space": vllm_swap_space if vllm_swap_space else const.DEFAULT_SWAP_SPACE,
            }

        model_load_config: bt.ModelLoadConfig = bt.ModelLoadConfig(
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=low_cpu_mem_usage,
            trust_remote_code=trust_remote_code,
            token=hf_token,
            vllm_config=vllm_config,
        )

        if isinstance(model, str):
            params = {
                "name": model,
                "is_base_model": is_genai_base_model,
                "genai_model_version": genai_model_version,
                "resource_pool": resource_pool,
                "slots": slots,
                "use_vllm": use_vllm,
                "is_arbitrary_hf_model": is_arbitrary_hf_model,
                "client_key": self._client_key,
            }
            endpoint = f"{self._prefix}/workspace/{self.workspace.id}/inference/load_model"
        else:
            params = {
                "resource_pool": resource_pool,
                "slots": slots,
                "use_vllm": use_vllm,
                "client_key": self._client_key,
            }
            endpoint = (
                f"{self._prefix}/workspace/{self.workspace.id}/inference/load_model/{model.id}"
            )
        params = {k: params[k] for k in params if params[k] is not None}
        return self._put(endpoint, params=params, in_data=model_load_config, sync=False)

    def generate(
        self,
        text: Union[str, List[str]],
        generation_config: Dict[str, Any] = {},
        batch_size: Optional[int] = None,
    ) -> Union[str, List[str]]:
        # Input validation
        if (
            self.batch_size is not None and batch_size is not None
        ) and batch_size > self.batch_size:
            raise ValueError(
                f"batch_size should be <= {self.batch_size} based on recommended model config"
            )
        if self.max_new_tokens is not None and "max_new_tokens" in generation_config:
            if generation_config["max_new_tokens"] > self.max_new_tokens:
                raise ValueError(
                    f"max_new_tokens should be <= {self.max_new_tokens} based on recommended model config"
                )

        # End input validation

        seq = False if isinstance(text, str) else True
        data = bt.GenerateChatResponseRequest(
            prompts=text,
            generation_config=generation_config,
            batch_size=batch_size,
            client_key=self._client_key,
        )
        return self._post(
            f"{self._prefix}/workspace/{self.workspace.id}/inference/generate",
            in_data=data,
            out_cls=str,
            sync=False,
            seq=seq,
            poll_interval=0.1,
        )

    def generate_on_dataset(
        self,
        dataset: bt.Dataset,
        prompt_template: bt.PromptTemplate,
        as_dataset: bool = True,
        output_column_name: str = "genai_generated_text",
        expected_output_column_name: Optional[str] = "genai_expected_text",
        split: Optional[str] = None,
        start_index: Optional[int] = None,
        number_of_samples: Optional[int] = None,
        batch_size: int = 10,
        test_mode: bool = False,
        generation_config: Dict[str, Any] = {},
        num_proc: Optional[int] = None,
    ) -> bt.Dataset:
        # Input validation
        if (
            self.batch_size is not None and batch_size is not None
        ) and batch_size > self.batch_size:
            raise ValueError(
                f"batch_size should be <= {self.batch_size} based on recommended model config"
            )
        if self.max_new_tokens is not None and "max_new_tokens" in generation_config:
            if generation_config["max_new_tokens"] > self.max_new_tokens:
                raise ValueError(
                    f"max_new_tokens should be <= {self.max_new_tokens} based on recommended model config"
                )
        # End input validation

        data = bt.GenerateOnDatasetRequest(
            dataset=dataset,
            prompt_template=prompt_template,
            output_column_name=output_column_name,
            expected_output_column_name=expected_output_column_name,
            split=split,
            start_index=start_index,
            number_of_samples=number_of_samples,
            batch_size=batch_size,
            test_mode=test_mode,
            generation_config=generation_config,
            num_proc=num_proc,
            client_key=self._client_key,
        )

        out_cls = bt.Dataset if as_dataset else dict
        return self._post(
            f"{self._prefix}/workspace/{self.workspace.id}/inference/generate_on_dataset",
            in_data=data,
            out_cls=out_cls,
            sync=False,
        )

    # Experiment resources.
    def _try_recommend_training_config(
        self, base_model: bt.Model, disable_training_config_recommendation: bool
    ) -> bool:
        if disable_training_config_recommendation:
            return False
        if base_model.model_architecture == "":
            # We need to know model_architecture to recommend training configs
            return False

        return True

    def _find_training_configs_recommendation(
        self, model_architecture: str, resource_pool: Optional[str]
    ) -> bt.ResourcePoolTrainingConfigs:
        if resource_pool is not None:
            request = bt.TrainingConfigRequest(
                model_architecture=model_architecture, max_config_per_resource_pool=1
            )
            result = self._post(
                f"{self._prefix}/workspace/{self.workspace.id}/model/training-configs",
                in_data=request,
                out_cls=bt.ResourcePoolTrainingConfigs,
                seq=True,
            )

            filtered_result = [c for c in result if c.resource_pool.name == resource_pool]
            if len(filtered_result) == 0:
                raise ex.TrainingConfigNotFoundException(
                    f"No config found for resource pool {resource_pool}"
                )
            config = filtered_result[0]
        else:
            request = bt.TrainingConfigRequest(
                model_architecture=model_architecture,
                max_config_per_resource_pool=1,
                max_resource_pool=1,
            )
            result = self._post(
                f"{self._prefix}/workspace/{self.workspace.id}/model/training-configs",
                in_data=request,
                out_cls=bt.ResourcePoolTrainingConfigs,
                seq=True,
            )

            if len(result) == 0:
                raise ex.TrainingConfigNotFoundException(
                    f"No config found for the model with architecture {model_architecture}"
                )
            config = result[0]
        return config

    def launch_training(
        self,
        dataset: bt.Dataset,
        base_model: bt.Model,
        name: Optional[str] = None,
        prompt_template: Optional[bt.PromptTemplate] = None,
        hf_token: Optional[str] = None,
        torch_dtype: Optional[str] = None,
        low_cpu_mem_usage: bool = False,
        trust_remote_code: bool = False,
        slots_per_trial: Optional[int] = None,
        max_steps: int = -1,
        num_train_epochs: int = 1,
        save_total_limit: int = 1,
        learning_rate: float = 5e-5,
        per_device_train_batch_size: Optional[int] = None,
        per_device_eval_batch_size: Optional[int] = None,
        do_train: bool = True,
        do_eval: bool = True,
        block_size: Optional[int] = None,
        deepspeed: Optional[bool] = None,
        gradient_checkpointing: Optional[bool] = None,
        ds_cpu_offload: bool = False,
        seed: int = 1337,
        output_dir: str = "/tmp/lore_output",
        eval_steps: Optional[int] = None,
        logging_steps: Optional[int] = None,
        save_steps: Optional[int] = None,
        logging_strategy: str = "",
        evaluation_strategy: str = "",
        save_strategy: str = "",
        resource_pool: Optional[str] = None,
        disable_training_config_recommendation: bool = False,
    ) -> bt.Experiment:
        # Input validation
        if slots_per_trial is not None and resource_pool is None:
            raise ValueError(
                "slots_per_trial is provided but not resource_pool. Either provide both or neither."
            )

        if torch_dtype is not None and torch_dtype not in (
            const.BF16_TYPE_STR,
            const.FP16_TYPE_STR,
        ):
            raise ValueError(
                f"torch_dtype should be {const.FP16_TYPE_STR} or {const.BF16_TYPE_STR}. Note that {const.BF16_TYPE_STR} support is GPU dependent"
            )
        # End input validation

        # Register dataset if needed
        if dataset.id is None:
            dataset = self.commit_dataset(dataset)

        # Try get training config recommendation
        try_recommend_training_config = self._try_recommend_training_config(
            base_model=base_model,
            disable_training_config_recommendation=disable_training_config_recommendation,
        )
        config_recommendation_found = False

        if try_recommend_training_config:
            try:
                config = self._find_training_configs_recommendation(
                    model_architecture=base_model.model_architecture, resource_pool=resource_pool
                )
                logger.info(
                    "Using recommended config to launch training, to turn off recommendation, set disable_training_config_recommendation = True"
                )
                config_recommendation_found = True
                training_config = config.training_configs[0]
                slots_per_trial = training_config.slots_per_trial
                resource_pool = config.resource_pool.name
                logger.info(f"resource_pool is {resource_pool}")

                if (
                    per_device_train_batch_size is not None
                    and per_device_train_batch_size != training_config.batch_size
                ):
                    logger.warning(
                        f"per_device_train_batch_size: {per_device_train_batch_size} will be overwritten by config recommendation to {training_config.batch_size}"
                    )
                per_device_train_batch_size = training_config.batch_size

                if (
                    per_device_eval_batch_size is not None
                    and per_device_eval_batch_size != training_config.batch_size
                ):
                    logger.warning(
                        f"per_device_eval_batch_size: {per_device_eval_batch_size} will be overwritten by config recommendation to {training_config.batch_size}"
                    )
                per_device_eval_batch_size = training_config.batch_size

                if deepspeed is not None and deepspeed != training_config.deepspeed:
                    logger.warning(
                        f"deepspeed {deepspeed} will be overwritten by config recommendation to {training_config.deepspeed}"
                    )
                deepspeed = training_config.deepspeed

                if (
                    gradient_checkpointing is not None
                    and gradient_checkpointing != training_config.gradient_checkpointing
                ):
                    logger.warning(
                        f"gradient_checkpointing {gradient_checkpointing} will be overwritten by config recommendation to {training_config.gradient_checkpointing}"
                    )
                gradient_checkpointing = training_config.gradient_checkpointing

                if block_size is not None and block_size != training_config.context_window:
                    logger.warning(
                        f"block_size {block_size} will be overwritten by config recommendation to {training_config.context_window}"
                    )
                block_size = training_config.context_window

                if torch_dtype is not None and torch_dtype != training_config.torch_dtype:
                    logger.warning(
                        f"torch_dtype {torch_dtype} will be overwritten by config recommendation to {training_config.torch_dtype}"
                    )
                torch_dtype = training_config.torch_dtype

            except ex.TrainingConfigNotFoundException as e:
                logger.warning(str(e))

        if not config_recommendation_found:
            if slots_per_trial is None:
                slots_per_trial = const.DEFAULT_SLOTS_PER_TRIAL
                logger.info(f"Using default slots_per_trial: {slots_per_trial}")
            else:
                logger.info(f"Using user provided slots_per_trial: {slots_per_trial}")

            if per_device_train_batch_size is None:
                per_device_train_batch_size = const.DEFAULT_PER_DEVICE_TRAIN_BATCH_SIZE
                logger.info(
                    f"Using default per_device_train_batch_size: {per_device_train_batch_size}"
                )
            else:
                logger.info(
                    f"Using user provided per_device_train_batch_size: {per_device_train_batch_size}"
                )

            if per_device_eval_batch_size is None:
                per_device_eval_batch_size = const.DEFAULT_PER_DEVICE_EVAL_BATCH_SIZE
                logger.info(
                    f"Using default per_device_eval_batch_size: {per_device_eval_batch_size}"
                )
            else:
                logger.info(
                    f"Using user provided per_device_eval_batch_size: {per_device_eval_batch_size}"
                )

            if torch_dtype is None:
                torch_dtype = const.DEFAULT_TRAIN_TORCH_DTYPE
                logger.info(f"Using default torch_dtype: {torch_dtype}")
            else:
                logger.info(f"Using user provided torch_dtype: {torch_dtype}")

            if deepspeed is None:
                deepspeed = const.DEFAULT_TRAIN_DEEPSPEED_FLAG
                logger.info(f"Using default deepspeed: {deepspeed}")
            else:
                logger.info(f"Using user provided deepspeed: {deepspeed}")

            if gradient_checkpointing is None:
                gradient_checkpointing = const.DEFAULT_TRAIN_GRADIENT_CHECKPOINTING_FLAG
                logger.info(f"Using default gradient_checkpointing: {gradient_checkpointing}")
            else:
                logger.info(f"Using user provided gradient_checkpointing: {gradient_checkpointing}")

        # TODO: build on top of the existing ExperimentSettings instead of an
        # ad-hoc dict.
        data = {
            "name": name,
            "model": base_model,
            "project_id": self.project.id,
            "model_load_config": {
                "torch_dtype": torch_dtype,
                "low_cpu_mem_usage": low_cpu_mem_usage,
                "trust_remote_code": trust_remote_code,
                "token": hf_token,
            },
            "train_config": {
                "data_args": {
                    "dataset_data": dataset,
                    "block_size": block_size,
                },
                "slots_per_trial": slots_per_trial,
                "max_steps": max_steps,
                "num_train_epochs": num_train_epochs,
                "save_total_limit": save_total_limit,
                "learning_rate": learning_rate,
                "per_device_train_batch_size": per_device_train_batch_size,
                "per_device_eval_batch_size": per_device_eval_batch_size,
                "do_train": do_train,
                "do_eval": do_eval,
                "fp16": True if torch_dtype == const.FP16_TYPE_STR else False,
                "bf16": True if torch_dtype == const.BF16_TYPE_STR else False,
                "deepspeed": deepspeed,
                "gradient_checkpointing": gradient_checkpointing,
                "ds_cpu_offload": ds_cpu_offload,
                "seed": seed,
                "output_dir": output_dir,
                "eval_steps": eval_steps,
                "logging_steps": logging_steps,
                "save_steps": save_steps,
                "logging_strategy": logging_strategy,
                "evaluation_strategy": evaluation_strategy,
                "save_strategy": save_strategy,
            },
            "prompt_template": prompt_template,
            "resource_pool": resource_pool,
        }
        return self._launch_experiment(in_data=bt.ExperimentSettings.parse_obj(data))

    def launch_evaluation(
        self,
        model: bt.Model,
        tasks: Union[str, List[str]],
        hf_token: Optional[str] = None,
        torch_dtype: Optional[str] = None,
        low_cpu_mem_usage: bool = False,
        trust_remote_code: bool = False,
        name: Optional[str] = None,
        num_fewshot: int = 0,
        use_accelerate: bool = False,
        slots_per_trial: int = 1,
        max_concurrent_trials: int = 1,
        resource_pool: Optional[str] = None,
    ) -> bt.Experiment:
        data = bt.ExperimentSettings(
            name=name,
            model=model,
            project_id=self.project.id,
            model_load_config=bt.ModelLoadConfig(
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=low_cpu_mem_usage,
                trust_remote_code=trust_remote_code,
                token=hf_token,
            ),
            eval_config=bt.EvaluationConfig(
                data_args={},
                tasks=tasks,
                num_fewshot=num_fewshot,
                use_accelerate=use_accelerate,
                slots_per_trial=slots_per_trial,
                max_concurrent_trials=max_concurrent_trials,
            ),
            resource_pool=resource_pool,
        )

        return self._launch_experiment(in_data=data)

    def launch_evalution_custom_multi_choice(
        self,
        model: bt.Model,
        dataset: bt.Dataset,
        name: Optional[str] = None,
        hf_token: Optional[str] = None,
        torch_dtype: Optional[str] = None,
        low_cpu_mem_usage: bool = False,
        trust_remote_code: bool = False,
        query_col: str = "query",
        choices_col: str = "choices",
        gold_col: str = "gold",
        num_fewshot: int = 0,
        use_accelerate: bool = False,
        slots_per_trial: int = 1,
        max_concurrent_trials: int = 1,
        resource_pool: Optional[str] = None,
    ) -> bt.Experiment:
        data = {
            "name": name,
            "model": model,
            "project_id": self.project.id,
            "model_load_config": {
                "torch_dtype": torch_dtype,
                "low_cpu_mem_usage": low_cpu_mem_usage,
                "trust_remote_code": trust_remote_code,
                "token": hf_token,
            },
            "eval_config": {
                "data_args": {
                    "dataset_data": dataset,
                },
                "tasks": [
                    {
                        "query_col": query_col,
                        "choices_col": choices_col,
                        "gold_col": gold_col,
                        "dataset_uuid": None,
                        "hf_dataset_name": None,
                        "hf_dataset_path": dataset.storage_path,
                    },
                ],
                "num_fewshot": num_fewshot,
                "use_accelerate": use_accelerate,
                "slots_per_trial": slots_per_trial,
                "max_concurrent_trials": max_concurrent_trials,
            },
            "resource_pool": resource_pool,
        }

        return self._launch_experiment(in_data=bt.ExperimentSettings.parse_obj(data))

    def _launch_experiment(self, in_data: bt.ExperimentSettings) -> bt.Experiment:
        return self._post(
            f"{self._prefix}/experiment",
            in_data=in_data,
            out_cls=bt.Experiment,
        )

    def get_experiments(self) -> List[bt.Experiment]:
        return self._get(
            f"{self._prefix}/project/{self.project.id}/experiments", seq=True, out_cls=bt.Experiment
        )

    def get_experiment(self, exp_id: int) -> bt.Experiment:
        return self._get(f"{self._prefix}/experiment/{exp_id}", out_cls=bt.Experiment)

    def delete_experiment(self, exp_id: int) -> None:
        self._delete(f"{self._prefix}/experiment/{exp_id}")

    def kill_experiment(self, exp_id: int) -> None:
        self._put(f"{self._prefix}/experiment/{exp_id}/kill")

    def get_experiment_models(self, exp_id: int) -> List[bt.Model]:
        return self._get(f"{self._prefix}/experiment/{exp_id}/models", seq=True, out_cls=bt.Model)

    def get_experiment_metrics(self, exp_id: int) -> Dict[str, Dict[str, Any]]:
        """Gets latest metrics for an experiment."""
        return self._get(f"{self._prefix}/experiment/{exp_id}/metrics", out_cls=None)

    def create_playground_snapshot(
        self, request: bt.CreatePlaygroundSnaphotRequest
    ) -> bt.PlaygroundSnapshot:
        return self._post(
            f"{self._prefix}/playground-snapshot", in_data=request, out_cls=bt.PlaygroundSnapshot
        )

    def get_playground_snapshot(self, id: int) -> bt.PlaygroundSnapshot:
        return self._get(f"{self._prefix}/playground-snapshot/{id}", out_cls=bt.PlaygroundSnapshot)

    def get_playground_snapshots(self) -> List[bt.PlaygroundSnapshot]:
        return self._get(
            f"{self._prefix}/project/{self.workspace.experiment_project_id}/playground-snapshots",
            seq=True,
            out_cls=bt.PlaygroundSnapshot,
        )

    def delete_playground_snapshot(self, id: int):
        self._delete(f"{self._prefix}/playground-snapshot/{id}")

    def update_playground_snapshot(
        self, request: bt.UpdatePlaygroundSnaphotRequest
    ) -> bt.PlaygroundSnapshot:
        return self._put(
            f"{self._prefix}/playground-snapshot/{request.id}",
            in_data=request,
            out_cls=bt.PlaygroundSnapshot,
        )

    def restart_controller(
        self, controller_type: bte.ControllerType
    ) -> bt.RestartControllerResponse:
        if controller_type == bte.ControllerType.DATA:
            request = bt.RestartControllerRequest(
                controller_type=controller_type,
                workspace_id=self.workspace.id,
            )
        else:
            request = bt.RestartControllerRequest(
                controller_type=controller_type,
                workspace_id=self.workspace.id,
                client_key=self._client_key,
            )
        route = f"{self._prefix}/controller/restart"
        out_cls = bt.RestartControllerResponse
        return self._put(route, request, out_cls=out_cls)

    # Utils

    def _get_dataset(self, dataset: Union[bt.Dataset, str, int]) -> bt.Dataset:
        if isinstance(dataset, bt.Dataset):
            return dataset
        elif isinstance(dataset, str):
            return self.get_dataset_by_name(dataset)
        else:
            return self.get_dataset(dataset)

    def build_local_dataset(
        self,
        train_files: Optional[Union[str, List[str]]] = None,
        validation_files: Optional[Union[str, List[str]]] = None,
        test_files: Optional[Union[str, List[str]]] = None,
    ) -> Union[hf.Dataset, hf.DatasetDict]:
        train_files = self._str_to_list(train_files)
        validation_files = self._str_to_list(validation_files)
        test_files = self._str_to_list(test_files)
        splits = {
            "train": train_files,
            "validation": validation_files,
            "test": test_files,
        }
        splits = {key: value for key, value in splits.items() if value is not None}
        file_format = self._get_file_extension(splits)

        dataset = hf.load_dataset(
            file_format,
            data_files=splits,
        )
        return dataset

    def _get_file_extension(self, splits: Dict[str, List[str]]) -> str:
        extensions: Set = set()
        for filenames in splits.values():
            if filenames is None:
                continue
            split_extensions = {filename.split(".")[-1] for filename in filenames}
            extensions.update(split_extensions)
            if len(extensions) > 1:
                raise ValueError(
                    f"Files have different extensions: {extensions}. "
                    f"Only one file file format (csv, json, or jsonl) allowed at a time."
                )

        extension = list(extensions)[0]
        if extension not in ["csv", "json", "jsonl"]:
            raise ValueError(
                f"File extension {extension} not supported. " f"Only csv, json, or jsonl allowed."
            )

        # AC: jsonl can be loaded as json.
        if extension == "jsonl":
            extension = "json"

        return extension

    def _str_to_list(self, value: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
        return [value] if isinstance(value, str) else value

    def _join_filepaths_for_splits(
        self, dataset_directory_path: str, splits: Dict[str, List[str]]
    ) -> Dict[str, List[str]]:
        for s, filenames in splits.items():
            new_filenames = [os.path.join(dataset_directory_path, f) for f in filenames]
            splits[s] = new_filenames

        return splits
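
The dataset helpers above (build_local_dataset, construct_dataset_from_local, commit_dataset, sample_dataset) are typically chained together as in the following sketch. The token, workspace name, file paths, and dataset name are hypothetical placeholders, and a running lore server is assumed; this is an illustrative outline, not part of the module.

from lore.client import Lore

# Placeholder token and workspace name; substitute real values for your deployment.
client = Lore(credentials="my-lore-token")
client.set_workspace("demo-workspace")

# Build an in-memory Hugging Face dataset from local CSV files (paths are
# placeholders), register it with the server, and preview a few rows.
hf_dataset = client.build_local_dataset(
    train_files="data/train.csv",
    validation_files="data/validation.csv",
)
dataset = client.construct_dataset_from_local(hf_dataset, dataset_name="demo-dataset")
dataset = client.commit_dataset(dataset)

# sample_dataset with as_dataset=False returns one pandas DataFrame per split.
samples = client.sample_dataset(dataset, number_of_samples=5, splits=["train"])
print(samples["train"].head())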

Classes

class Lore (host: str = 'http://localhost:9011', prefix='/genai/api/v1', client: Optional[httpx.Client] = None, credentials: Optional[Union[determined.common.api.authentication.Credentials, str, determined.common.api._session.Session, Tuple[str, str], determined.common.api.authentication.Authentication]] = None)
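
A minimal quick-start sketch, assuming a reachable lore server and Determined master; the host addresses, credentials, workspace, and model name below are placeholders rather than values defined by this module.

from lore.client import Lore

# Credentials may be a token string, a (username, password) tuple, or a
# Determined Credentials/Session/Authentication object (see SupportedCredentials).
client = Lore(host="http://localhost:9011")
client.login(("determined", "password"), mlde_host="http://localhost:8080")
client.set_workspace("demo-workspace")

# Load a base model for inference, then generate a completion.
client.load_model("mistralai/Mistral-7B-Instruct-v0.2", is_genai_base_model=True)
print(client.generate("Write a haiku about checkpoints."))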
                    os.path.join(zip_tmp_dir, "hf_dataset.zip")
                )
                data = bt.ConstructDatasetFromLocalRequest(
                    workspace_id=self.workspace.id,
                    dataset_name=dataset_name,
                    arrow_file_path=uploaded_files.file_paths[0],
                )

        return self._post(
            f"{self._prefix}/construct_dataset_from_local",
            in_data=data,
            out_cls=bt.Dataset,
            sync=False,
        )

    def commit_dataset(self, dataset: bt.Dataset) -> bt.Dataset:
        return self._post(f"{self._prefix}/dataset", in_data=dataset, out_cls=bt.Dataset)

    def get_datasets(self) -> List[bt.Dataset]:
        return self._get(
            f"{self._prefix}/workspace/{self.workspace.id}/datasets", out_cls=bt.Dataset, seq=True
        )

    def get_dataset(self, dataset_id: int) -> bt.Dataset:
        return self._get(
            f"{self._prefix}/dataset/{dataset_id}",
            out_cls=bt.Dataset,
        )

    def get_dataset_by_name(self, dataset_name: str) -> bt.Dataset:
        return self._get(
            f"{self._prefix}/workspace/{self.workspace.id}/dataset/name",
            params={"name": dataset_name},
            out_cls=bt.Dataset,
        )

    def delete_dataset(self, dataset_id: int) -> None:
        return self._delete(f"{self._prefix}/dataset/{dataset_id}", sync=False)

    def sample_dataset(
        self,
        dataset: Union[bt.Dataset, str, int],
        start_index: Optional[int] = None,
        number_of_samples: Optional[int] = None,
        ratio: Optional[float] = None,
        splits: Optional[List[str]] = None,
        seed: int = 1234,
        as_dataset: bool = False,
    ) -> Union[bt.Dataset, Dict]:
        dataset = self._get_dataset(dataset)
        request = bt.SampleDatasetRequest(
            dataset=dataset,
            start_index=start_index,
            number_of_samples=number_of_samples,
            ratio=ratio,
            splits=splits,
            workspace_id=self.workspace.id,
            seed=seed,
            as_dataset=as_dataset,
        )

        out_cls = bt.Dataset if as_dataset is True else None
        output = self._post(
            f"{self._prefix}/dataset/sample", in_data=request, out_cls=out_cls, sync=False
        )
        if isinstance(output, bt.Dataset):
            return output
        else:
            dataset_sample = dict(output)
            return {key: pd.read_json(sample) for key, sample in dataset_sample.items()}

    def merge_and_resplit(
        self,
        dataset: Union[bt.Dataset, str, int],
        train_ratio: float,
        validation_ratio: float,
        test_ratio: float,
        splits_to_resplit: List[str],
        as_dataset: bool = True,
        shuffle: bool = True,
        seed: int = 1234,
    ) -> Union[bt.Dataset, Dict]:
        dataset = self._get_dataset(dataset)
        data = bt.MergeAndResplitRequest(
            dataset=dataset,
            train_ratio=train_ratio,
            validation_ratio=validation_ratio,
            test_ratio=test_ratio,
            splits_to_resplit=splits_to_resplit,
            workspace_id=self.workspace.id,
            shuffle=shuffle,
            seed=seed,
        )
        out_cls = bt.Dataset if as_dataset is True else dict
        return self._post(
            f"{self._prefix}/dataset/merge_and_resplit",
            out_cls=out_cls,
            in_data=data,
            sync=False,
        )

    def apply_prompt_template(
        self,
        dataset: bt.Dataset,
        prompt_template: bt.PromptTemplate,
        model: Optional[bt.Model] = None,
        new_column_name: str = "genai_input_text",
        expected_output_column_name: Optional[str] = None,
        batch_size: int = 10,
    ) -> bt.Dataset:
        data = {
            "dataset": dataset,
            "prompt_template": prompt_template,
            "model": model,
            "new_column_name": new_column_name,
            "expected_output_column_name": expected_output_column_name,
            "batch_size": batch_size,
            "workspace_id": self.workspace.id,
        }
        return self._post(
            f"{self._prefix}/dataset/apply_prompt_template",
            in_data=data,  # type: ignore
            out_cls=bt.Dataset,
            sync=False,
        )

    def compute_metrics(
        self,
        dataset: bt.Dataset,
        ground_truth_column_name: str,
        predictions_column_name: str,
        metrics: List[str] = ["exact_match"],
        substr_match: bool = True,
        split: Optional[str] = None,
        strip: bool = False,
        lower: bool = False,
    ) -> Tuple[bt.Dataset, Union[Dict[str, float], Dict[str, Dict[str, float]]]]:
        data = bt.ComputeMetricsRequest(
            dataset=dataset,
            ground_truth_column_name=ground_truth_column_name,
            predictions_column_name=predictions_column_name,
            metrics=metrics,
            substr_match=substr_match,
            workspace_id=self.workspace.id,
            split=split,
            strip=strip,
            lower=lower,
        )
        results = self._post(
            f"{self._prefix}/dataset/compute_metrics",
            in_data=data,
            out_cls=None,
            sync=False,
        )
        return bt.Dataset.parse_obj(results[0]), results[1]

    def materialize_dataset(
        self, dataset: bt.Dataset, output_path: Optional[str] = None
    ) -> hf.DatasetDict:
        if os.path.exists(dataset.storage_path):
            storage_path = dataset.storage_path
        else:
            storage_path = self._download_dataset(dataset, output_path)

        return hf.load_from_disk(storage_path)

    def _download_dataset(self, dataset: bt.Dataset, output_path: Optional[str] = None) -> str:
        # TODO (GAS-448): migrate to use self._get which handles authentication.
        file_response = self._client.get(
            f"{self._prefix}/dataset/{dataset.id}/download",
            headers={"Authorization": f"Bearer {self._required_token}"},
        )
        assert file_response.status_code == 200

        if output_path is None:
            hf_cache_dir = os.getenv("HF_CACHE")
            if hf_cache_dir:
                output_path = hf_cache_dir
            else:
                lore_cache_dir = os.path.expanduser("~/.cache/lore/datasets")
                os.makedirs(lore_cache_dir, exist_ok=True)
                output_path = lore_cache_dir

        with tempfile.NamedTemporaryFile(suffix=".zip") as temp_archive:
            for chunk in file_response.iter_bytes(chunk_size=8192):
                temp_archive.write(chunk)
            temp_archive.flush()
            with zipfile.ZipFile(temp_archive.name, "r") as zip_ref:
                zip_ref.extractall(output_path)
        return output_path

    def _upload_arrow_files(self, zip_file: str) -> bt.UploadedFiles:
        with open(zip_file, "rb") as f:
            file_name = zip_file.split("/")[-1]
            files = [("files", (file_name, f))]

            return self._post(
                f"{self._prefix}/dataset/upload_files", files=files, out_cls=bt.UploadedFiles
            )

    def concatenate_datasets(
        self, datasets: List[bt.Dataset], name: str, register_dataset: bool = False
    ) -> bt.Dataset:
        data = {
            "datasets": datasets,
            "name": name,
            "workspace_id": self.workspace.id,
            "register_dataset": register_dataset,
            "system_generated": False,  # as it is trigger from user facing Python client
        }
        request_body = bt.ConcatenateDatasetsRequest.parse_obj(data)
        return self._post(
            f"{self._prefix}/dataset/concatenate", in_data=request_body, out_cls=bt.Dataset
        )

    # Model resources.

    def get_models(self) -> List[bt.Model]:
        return self._get(
            f"{self._prefix}/workspace/{self.workspace.id}/models", out_cls=bt.Model, seq=True
        )

    def get_base_models(self) -> List[bt.Model]:
        models = self._get(f"{self._prefix}/base-models", out_cls=bt.Model, seq=True)
        models = sorted(models, key=lambda x: x.hf_model_name)
        return models

    def get_base_model(self, model_name: str) -> bt.Model:
        models = self.get_base_models()
        model = next((model for model in models if model.hf_model_name == model_name), None)
        if model is None:
            raise ValueError(f"{model_name} not found.")
        return model

    def get_model(self, model_id: int) -> bt.Model:
        return self._get(f"{self._prefix}/model/{model_id}", out_cls=bt.Model)

    def get_model_by_name(self, model_name: str, is_base_model: bool = True) -> bt.Model:
        model = self._post(
            f"{self._prefix}/model",
            in_data=bt.GetModelByNameRequest(
                name=model_name, workspace_id=self.workspace.id, is_base_model=is_base_model
            ),
            out_cls=bt.Model,
            seq=False,
        )
        return model

    def register_model(self, model: bt.Model) -> bt.Model:
        if model.workspace_id is None:
            model.workspace_id = self.workspace.id
        assert model.workspace_id is not None, f"Workspace id is required for model registration"
        return self._post(f"{self._prefix}/model/register", model, out_cls=bt.Model)

    def delete_model(self, model_id: int) -> None:
        self._delete(f"{self._prefix}/model/{model_id}")

    def create_hf_model(
        self,
        hf_model_name: str,
        branch: Optional[str] = None,
        hf_model_commit: Optional[str] = None,
        formatting_tokens: Optional[bt.FormattingTokens] = None,
        problem_type: str = "",
        model_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> bt.Model:
        return bt.Model.from_hf(
            hf_model_name=hf_model_name,
            branch=branch,
            hf_model_commit=hf_model_commit,
            formatting_tokens=formatting_tokens,
            problem_type=problem_type,
            model_config=model_config,
            **kwargs,
        )

    def download_model(
        self, model_id: int, output_path: Optional[str] = None, mlde_host: Optional[str] = None
    ) -> str:
        """Downloads a model by checkpoint through the determined master directly
        Arguments:
            model_id (int): id for the model
            output_path (string): the folder path to save the model to
        Returns:
            String of model folder path
        """
        model = self.get_model(model_id)
        assert (
            model.genai_checkpoint_uuid
        ), f"Model ({model_id}) is missing a checkpoint uuid. Cancelling download"
        if mlde_host is None:
            logger.warn(
                f"mlde_host is not provided for download_model. Falling back {get_det_master_address()}"
            )
            mlde_host = get_det_master_address()
        # Call determined directly to avoid downloading our
        # checkpoints in two hops.
        session = create_session(self._required_token, master_address=mlde_host)
        resp = bindings.get_GetCheckpoint(session, checkpointUuid=model.genai_checkpoint_uuid)
        ckpt = checkpoint.Checkpoint._from_bindings(resp.checkpoint, session)
        Path(output_path).mkdir(parents=True)
        ckpt.download(output_path)
        return str(output_path)

    def upload_model_to_hf(
        self,
        model: bt.Model,
        hf_repo_owner: str,
        hf_repo_name: str,
        hf_token: str,
        private: bool = False,
    ) -> int:
        data = bt.UploadModelToHFRequest(
            model=model,
            hf_repo_owner=hf_repo_owner,
            hf_repo_name=hf_repo_name,
            hf_token=hf_token,
            private=private,
        )
        response: bt.UploadModelToHFResponse = self._post(
            f"{self._prefix}/model/upload_model_to_hf",
            in_data=data,
            out_cls=bt.UploadModelToHFResponse,
            sync=True,
        )
        return response.experiment_id

    # Prompt template resources.

    def commit_prompt_template(self, prompt_template: bt.PromptTemplate) -> bt.PromptTemplate:
        return self._post(
            f"{self._prefix}/prompt-template", in_data=prompt_template, out_cls=bt.PromptTemplate
        )

    def get_prompt_templates(self) -> List[bt.PromptTemplate]:
        return self._get(
            f"{self._prefix}/workspace/{self.workspace.id}/prompt-templates",
            out_cls=bt.PromptTemplate,
            seq=True,
        )

    def get_prompt_template(self, prompt_template_id: int) -> bt.PromptTemplate:
        return self._get(
            f"{self._prefix}/prompt-template/{prompt_template_id}", out_cls=bt.PromptTemplate
        )

    def get_start_prompt_templates(self) -> List[bt.PromptTemplate]:
        return self._get(
            f"{self._prefix}/starting-prompt-templates", out_cls=bt.PromptTemplate, seq=True
        )

    def update_prompt_template(self, prompt_template: bt.PromptTemplate) -> bt.PromptTemplate:
        return self._put(
            f"{self._prefix}/prompt-template/{prompt_template.id}",
            in_data=prompt_template,
            out_cls=bt.PromptTemplate,
        )

    def delete_prompt_template(self, prompt_template_id: int) -> None:
        self._delete(f"{self._prefix}/prompt-template/{prompt_template_id}")

    # Chat resources.
    def _get_recommended_vllm_configs(
        self, resource_pool, model: Union[str, bt.Model], is_base_model: bool
    ) -> bt.ResourcePoolInferenceConfigs:
        if isinstance(model, str):
            model = self.get_model_by_name(model, is_base_model)

        if resource_pool is not None:
            request = bt.InferenceConfigRequest(max_config_per_resource_pool=1, model_id=model.id)
            result = self._post(
                f"{self._prefix}/workspace/{self.workspace.id}/inference/config",
                in_data=request,
                out_cls=bt.ResourcePoolInferenceConfigs,
                seq=True,
            )
            filtered_result = [c for c in result if c.resource_pool.name == resource_pool]
            if len(filtered_result) == 0:
                raise ex.InferenceConfigNotFoundException(
                    f"Resource pool {resource_pool} cannot run model {model.name}."
                )
            config = filtered_result[0]
        else:
            request = bt.InferenceConfigRequest(
                max_config_per_resource_pool=1, max_resource_pool=1, model_id=model.id
            )
            result = self._post(
                f"{self._prefix}/workspace/{self.workspace.id}/inference/config",
                in_data=request,
                out_cls=bt.ResourcePoolInferenceConfigs,
                seq=True,
            )
            if len(result) == 0:
                raise ex.InferenceConfigNotFoundException(
                    f"We cannot find any recommended configs for model {model.name}."
                )
            config = result[0]

        logger.info(
            f"Recommended vllm inference config is: resource_pool: {config.resource_pool.name}, slots: {config.vllm_configs[0].slots_per_trial}"
        )
        return config

    def _try_recommend_vllm_config(
        self,
        use_vllm,
        vllm_tensor_parallel_size,
        is_arbitrary_hf_model,
        disable_inference_config_recommendation,
    ):
        if disable_inference_config_recommendation:
            return False

        if not use_vllm:
            return False

        # load_model input validation already forces resource_pool to be given if vllm_tensor_parallel_size is provided
        if vllm_tensor_parallel_size:
            return False

        # No config recommendation for non-vetted HF models
        if is_arbitrary_hf_model:
            return False
        return True

    def load_model(
        self,
        model: Union[str, bt.Model],
        is_genai_base_model: bool = True,
        genai_model_version: Optional[int] = None,
        *,
        resource_pool: Optional[str] = None,
        use_vllm: bool = True,
        vllm_tensor_parallel_size: Optional[int] = None,
        vllm_swap_space: Optional[int] = None,
        hf_token: Optional[str] = None,
        torch_dtype: Optional[str] = None,
        low_cpu_mem_usage: bool = False,
        trust_remote_code: bool = False,
        is_arbitrary_hf_model: bool = False,
        disable_inference_config_recommendation: bool = False,
    ) -> bool:
        # Input validation
        if not use_vllm and vllm_tensor_parallel_size is not None:
            raise ValueError(
                "vllm_tensor_parallel_size should not be provided if use_vllm is False"
            )

        if vllm_tensor_parallel_size is not None and resource_pool is None:
            raise ValueError(
                "Please provide resource_pool when vllm_tensor_parallel_size is given. Otherwise, leave both fields blank for system recommendation."
            )

        if is_arbitrary_hf_model and is_genai_base_model:
            raise ValueError("is_arbitrary_hf_model and is_genai_base_model cannot both be True")

        if torch_dtype is not None and torch_dtype not in (
            const.BF16_TYPE_STR,
            const.FP16_TYPE_STR,
        ):
            raise ValueError(
                f"torch_dtype should be {const.BF16_TYPE_STR} or {const.FP16_TYPE_STR}. Note that support for {const.BF16_TYPE_STR} is GPU dependent."
            )

        # End input validation

        try_config_recommendation = self._try_recommend_vllm_config(
            use_vllm=use_vllm,
            vllm_tensor_parallel_size=vllm_tensor_parallel_size,
            is_arbitrary_hf_model=is_arbitrary_hf_model,
            disable_inference_config_recommendation=disable_inference_config_recommendation,
        )

        recommended_config_found = False
        if try_config_recommendation:
            logger.info("Getting recommended vllm inference config from server")
            try:
                recommended_config = self._get_recommended_vllm_configs(
                    resource_pool=resource_pool, model=model, is_base_model=is_genai_base_model
                )
                logger.info(
                    "Using the recommended config to run inference; to turn off recommendation, set disable_inference_config_recommendation = True"
                )
                # We should be only returning one vllm config per resource pool
                inference_config = recommended_config.vllm_configs[0]
                slots = inference_config.slots_per_trial
                resource_pool = recommended_config.resource_pool.name
                if torch_dtype is not None and torch_dtype != inference_config.torch_dtype:
                    logger.info(
                        f"torch_dtype {torch_dtype} will be overwritten by recommended config value {inference_config.torch_dtype}"
                    )
                torch_dtype = inference_config.torch_dtype
                self.max_new_tokens = inference_config.max_new_tokens
                self.batch_size = inference_config.batch_size
                vllm_config = {
                    "tensor-parallel-size": slots,
                    "swap-space": inference_config.swap_space,
                }
                recommended_config_found = True
            except ex.InferenceConfigNotFoundException as e:
                if resource_pool:
                    logger.warning(
                        f"No recommended inference config was found for the model. Will use the default number of GPUs: {const.DEFAULT_SLOTS_PER_TRIAL} with the provided resource_pool: {resource_pool}"
                    )
                else:
                    logger.warning(
                        f"No recommended inference config was found for the model. Will use the default number of GPUs: {const.DEFAULT_SLOTS_PER_TRIAL} with the workspace's default resource pool"
                    )

        if not recommended_config_found:
            logger.info("User provided / default model config is used.")
            # Reset max_new_tokens and batch_size
            self.max_new_tokens = None
            self.batch_size = None
            if vllm_tensor_parallel_size:
                slots = vllm_tensor_parallel_size
            else:
                slots = const.DEFAULT_SLOTS_PER_TRIAL

            vllm_config = {
                "tensor-parallel-size": slots,
                "swap-space": vllm_swap_space if vllm_swap_space else const.DEFAULT_SWAP_SPACE,
            }

        model_load_config: bt.ModelLoadConfig = bt.ModelLoadConfig(
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=low_cpu_mem_usage,
            trust_remote_code=trust_remote_code,
            token=hf_token,
            vllm_config=vllm_config,
        )

        if isinstance(model, str):
            params = {
                "name": model,
                "is_base_model": is_genai_base_model,
                "genai_model_version": genai_model_version,
                "resource_pool": resource_pool,
                "slots": slots,
                "use_vllm": use_vllm,
                "is_arbitrary_hf_model": is_arbitrary_hf_model,
                "client_key": self._client_key,
            }
            endpoint = f"{self._prefix}/workspace/{self.workspace.id}/inference/load_model"
        else:
            params = {
                "resource_pool": resource_pool,
                "slots": slots,
                "use_vllm": use_vllm,
                "client_key": self._client_key,
            }
            endpoint = (
                f"{self._prefix}/workspace/{self.workspace.id}/inference/load_model/{model.id}"
            )
        params = {k: params[k] for k in params if params[k] is not None}
        return self._put(endpoint, params=params, in_data=model_load_config, sync=False)

    def generate(
        self,
        text: Union[str, List[str]],
        generation_config: Dict[str, Any] = {},
        batch_size: Optional[int] = None,
    ) -> Union[str, List[str]]:
        # Input validation
        if (
            self.batch_size is not None and batch_size is not None
        ) and batch_size > self.batch_size:
            raise ValueError(
                f"batch_size should be <= {self.batch_size} based on recommended model config"
            )
        if self.max_new_tokens is not None and "max_new_tokens" in generation_config:
            if generation_config["max_new_tokens"] > self.max_new_tokens:
                raise ValueError(
                    f"max_new_tokens should be <= {self.max_new_tokens} based on recommended model config"
                )

        # End input validation

        seq = False if isinstance(text, str) else True
        data = bt.GenerateChatResponseRequest(
            prompts=text,
            generation_config=generation_config,
            batch_size=batch_size,
            client_key=self._client_key,
        )
        return self._post(
            f"{self._prefix}/workspace/{self.workspace.id}/inference/generate",
            in_data=data,
            out_cls=str,
            sync=False,
            seq=seq,
            poll_interval=0.1,
        )

    def generate_on_dataset(
        self,
        dataset: bt.Dataset,
        prompt_template: bt.PromptTemplate,
        as_dataset: bool = True,
        output_column_name: str = "genai_generated_text",
        expected_output_column_name: Optional[str] = "genai_expected_text",
        split: Optional[str] = None,
        start_index: Optional[int] = None,
        number_of_samples: Optional[int] = None,
        batch_size: int = 10,
        test_mode: bool = False,
        generation_config: Dict[str, Any] = {},
        num_proc: Optional[int] = None,
    ) -> bt.Dataset:
        # Input validation
        if (
            self.batch_size is not None and batch_size is not None
        ) and batch_size > self.batch_size:
            raise ValueError(
                f"batch_size should be <= {self.batch_size} based on recommended model config"
            )
        if self.max_new_tokens is not None and "max_new_tokens" in generation_config:
            if generation_config["max_new_tokens"] > self.max_new_tokens:
                raise ValueError(
                    f"max_new_tokens should be <= {self.max_new_tokens} based on recommended model config"
                )
        # End input validation

        data = bt.GenerateOnDatasetRequest(
            dataset=dataset,
            prompt_template=prompt_template,
            output_column_name=output_column_name,
            expected_output_column_name=expected_output_column_name,
            split=split,
            start_index=start_index,
            number_of_samples=number_of_samples,
            batch_size=batch_size,
            test_mode=test_mode,
            generation_config=generation_config,
            num_proc=num_proc,
            client_key=self._client_key,
        )

        out_cls = bt.Dataset if as_dataset else dict
        return self._post(
            f"{self._prefix}/workspace/{self.workspace.id}/inference/generate_on_dataset",
            in_data=data,
            out_cls=out_cls,
            sync=False,
        )

    # Experiment resources.
    def _try_recommend_training_config(
        self, base_model: bt.Model, disable_training_config_recommendation: bool
    ) -> bool:
        if disable_training_config_recommendation:
            return False
        if base_model.model_architecture == "":
            # We need to know model_architecture to recommend training configs
            return False

        return True

    def _find_training_configs_recommendation(
        self, model_architecture: str, resource_pool: Optional[str]
    ) -> bt.ResourcePoolTrainingConfigs:
        if resource_pool is not None:
            request = bt.TrainingConfigRequest(
                model_architecture=model_architecture, max_config_per_resource_pool=1
            )
            result = self._post(
                f"{self._prefix}/workspace/{self.workspace.id}/model/training-configs",
                in_data=request,
                out_cls=bt.ResourcePoolTrainingConfigs,
                seq=True,
            )

            filtered_result = [c for c in result if c.resource_pool.name == resource_pool]
            if len(filtered_result) == 0:
                raise ex.TrainingConfigNotFoundException(
                    f"No config found for resource pool {resource_pool}"
                )
            config = filtered_result[0]
        else:
            request = bt.TrainingConfigRequest(
                model_architecture=model_architecture,
                max_config_per_resource_pool=1,
                max_resource_pool=1,
            )
            result = self._post(
                f"{self._prefix}/workspace/{self.workspace.id}/model/training-configs",
                in_data=request,
                out_cls=bt.ResourcePoolTrainingConfigs,
                seq=True,
            )

            if len(result) == 0:
                raise ex.TrainingConfigNotFoundException(
                    f"No config found for the model with architecture {model_architecture}"
                )
            config = result[0]
        return config

    def launch_training(
        self,
        dataset: bt.Dataset,
        base_model: bt.Model,
        name: Optional[str] = None,
        prompt_template: Optional[bt.PromptTemplate] = None,
        hf_token: Optional[str] = None,
        torch_dtype: Optional[str] = None,
        low_cpu_mem_usage: bool = False,
        trust_remote_code: bool = False,
        slots_per_trial: Optional[int] = None,
        max_steps: int = -1,
        num_train_epochs: int = 1,
        save_total_limit: int = 1,
        learning_rate: float = 5e-5,
        per_device_train_batch_size: Optional[int] = None,
        per_device_eval_batch_size: Optional[int] = None,
        do_train: bool = True,
        do_eval: bool = True,
        block_size: Optional[int] = None,
        deepspeed: Optional[bool] = None,
        gradient_checkpointing: Optional[bool] = None,
        ds_cpu_offload: bool = False,
        seed: int = 1337,
        output_dir: str = "/tmp/lore_output",
        eval_steps: Optional[int] = None,
        logging_steps: Optional[int] = None,
        save_steps: Optional[int] = None,
        logging_strategy: str = "",
        evaluation_strategy: str = "",
        save_strategy: str = "",
        resource_pool: Optional[str] = None,
        disable_training_config_recommendation: bool = False,
    ) -> bt.Experiment:
        # Input validation
        if slots_per_trial is not None and resource_pool is None:
            raise ValueError(
                "slots_per_trial is provided but not resource_pool. Either provide both or neither."
            )

        if torch_dtype is not None and torch_dtype not in (
            const.BF16_TYPE_STR,
            const.FP16_TYPE_STR,
        ):
            raise ValueError(
                f"torch_dtype should be {const.FP16_TYPE_STR} or {const.BF16_TYPE_STR}. Note that {const.BF16_TYPE_STR} support is GPU dependent"
            )
        # End input validation

        # Register dataset if needed
        if dataset.id is None:
            dataset = self.commit_dataset(dataset)

        # Try get training config recommendation
        try_recommend_training_config = self._try_recommend_training_config(
            base_model=base_model,
            disable_training_config_recommendation=disable_training_config_recommendation,
        )
        config_recommendation_found = False

        if try_recommend_training_config:
            try:
                config = self._find_training_configs_recommendation(
                    model_architecture=base_model.model_architecture, resource_pool=resource_pool
                )
                logger.info(
                    "Using the recommended config to launch training; to turn off recommendation, set disable_training_config_recommendation = True"
                )
                config_recommendation_found = True
                training_config = config.training_configs[0]
                slots_per_trial = training_config.slots_per_trial
                resource_pool = config.resource_pool.name
                logger.info(f"resource_pool is {resource_pool}")

                if (
                    per_device_train_batch_size is not None
                    and per_device_train_batch_size != training_config.batch_size
                ):
                    logger.warn(
                        f"per_device_train_batch_size: {per_device_train_batch_size} will be overwritten by config recommendation to {training_config.batch_size}"
                    )
                per_device_train_batch_size = training_config.batch_size

                if (
                    per_device_eval_batch_size is not None
                    and per_device_eval_batch_size != training_config.batch_size
                ):
                    logger.warn(
                        f"per_device_eval_batch_size: {per_device_eval_batch_size} will be overwritten by config recommendation to {training_config.batch_size}"
                    )
                per_device_eval_batch_size = training_config.batch_size

                if deepspeed is not None and deepspeed != training_config.deepspeed:
                    logger.warn(
                        f"deepspeed {deepspeed} will be overwritten by config recommendation to {training_config.deepspeed}"
                    )
                deepspeed = training_config.deepspeed

                if (
                    gradient_checkpointing is not None
                    and gradient_checkpointing != training_config.gradient_checkpointing
                ):
                    logger.warn(
                        f"gradient_checkpointing {gradient_checkpointing} will be overwritten by config recommendation to {training_config.gradient_checkpointing}"
                    )
                gradient_checkpointing = training_config.gradient_checkpointing

                if block_size is not None and block_size != training_config.context_window:
                    logger.warn(
                        f"block_size {block_size} will be overwritten by config recommendation to {training_config.context_window}"
                    )
                block_size = training_config.context_window

                if torch_dtype is not None and torch_dtype != training_config.torch_dtype:
                    logger.warn(
                        f"torch_dtype {torch_dtype} will be overwritten by config recommendation to {training_config.torch_dtype}"
                    )
                torch_dtype = training_config.torch_dtype

            except ex.TrainingConfigNotFoundException as e:
                logger.warn(str(e))

        if not config_recommendation_found:
            if slots_per_trial is None:
                slots_per_trial = const.DEFAULT_SLOTS_PER_TRIAL
                logger.info(f"Using default slots_per_trial: {slots_per_trial}")
            else:
                logger.info(f"Using user provided slots_per_trial: {slots_per_trial}")

            if per_device_train_batch_size is None:
                per_device_train_batch_size = const.DEFAULT_PER_DEVICE_TRAIN_BATCH_SIZE
                logger.info(
                    f"Using default per_device_train_batch_size: {per_device_train_batch_size}"
                )
            else:
                logger.info(
                    f"Using user provided per_device_train_batch_size: {per_device_train_batch_size}"
                )

            if per_device_eval_batch_size is None:
                per_device_eval_batch_size = const.DEFAULT_PER_DEVICE_EVAL_BATCH_SIZE
                logger.info(
                    f"Using default per_device_eval_batch_size: {per_device_eval_batch_size}"
                )
            else:
                logger.info(
                    f"Using user provided per_device_eval_batch_size: {per_device_eval_batch_size}"
                )

            if torch_dtype is None:
                torch_dtype = const.DEFAULT_TRAIN_TORCH_DTYPE
                logger.info(f"Using default torch_dtype: {torch_dtype}")
            else:
                logger.info(f"Using user provided torch_dtype: {torch_dtype}")

            if deepspeed is None:
                deepspeed = const.DEFAULT_TRAIN_DEEPSPEED_FLAG
                logger.info(f"Using default deepspeed: {deepspeed}")
            else:
                logger.info(f"Using user provided deepspeed: {deepspeed}")

            if gradient_checkpointing is None:
                gradient_checkpointing = const.DEFAULT_TRAIN_GRADIENT_CHECKPOINTING_FLAG
                logger.info(f"Using default gradient_checkpointing: {gradient_checkpointing}")
            else:
                logger.info(f"Using user provided gradient_checkpointing: {gradient_checkpointing}")

        # TODO: build on top of the existing ExperimentSettings instead of an
        # ad-hoc dict.
        data = {
            "name": name,
            "model": base_model,
            "project_id": self.project.id,
            "model_load_config": {
                "torch_dtype": torch_dtype,
                "low_cpu_mem_usage": low_cpu_mem_usage,
                "trust_remote_code": trust_remote_code,
                "token": hf_token,
            },
            "train_config": {
                "data_args": {
                    "dataset_data": dataset,
                    "block_size": block_size,
                },
                "slots_per_trial": slots_per_trial,
                "max_steps": max_steps,
                "num_train_epochs": num_train_epochs,
                "save_total_limit": save_total_limit,
                "learning_rate": learning_rate,
                "per_device_train_batch_size": per_device_train_batch_size,
                "per_device_eval_batch_size": per_device_eval_batch_size,
                "do_train": do_train,
                "do_eval": do_eval,
                "fp16": True if torch_dtype == const.FP16_TYPE_STR else False,
                "bf16": True if torch_dtype == const.BF16_TYPE_STR else False,
                "deepspeed": deepspeed,
                "gradient_checkpointing": gradient_checkpointing,
                "ds_cpu_offload": ds_cpu_offload,
                "seed": seed,
                "output_dir": output_dir,
                "eval_steps": eval_steps,
                "logging_steps": logging_steps,
                "save_steps": save_steps,
                "logging_strategy": logging_strategy,
                "evaluation_strategy": evaluation_strategy,
                "save_strategy": save_strategy,
            },
            "prompt_template": prompt_template,
            "resource_pool": resource_pool,
        }
        return self._launch_experiment(in_data=bt.ExperimentSettings.parse_obj(data))

    def launch_evaluation(
        self,
        model: bt.Model,
        tasks: Union[str, List[str]],
        hf_token: Optional[str] = None,
        torch_dtype: Optional[str] = None,
        low_cpu_mem_usage: bool = False,
        trust_remote_code: bool = False,
        name: Optional[str] = None,
        num_fewshot: int = 0,
        use_accelerate: bool = False,
        slots_per_trial: int = 1,
        max_concurrent_trials: int = 1,
        resource_pool: Optional[str] = None,
    ) -> bt.Experiment:
        data = bt.ExperimentSettings(
            name=name,
            model=model,
            project_id=self.project.id,
            model_load_config=bt.ModelLoadConfig(
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=low_cpu_mem_usage,
                trust_remote_code=trust_remote_code,
                token=hf_token,
            ),
            eval_config=bt.EvaluationConfig(
                data_args={},
                tasks=tasks,
                num_fewshot=num_fewshot,
                use_accelerate=use_accelerate,
                slots_per_trial=slots_per_trial,
                max_concurrent_trials=max_concurrent_trials,
            ),
            resource_pool=resource_pool,
        )

        return self._launch_experiment(in_data=data)

    def launch_evalution_custom_multi_choice(
        self,
        model: bt.Model,
        dataset: bt.Dataset,
        name: Optional[str] = None,
        hf_token: Optional[str] = None,
        torch_dtype: Optional[str] = None,
        low_cpu_mem_usage: bool = False,
        trust_remote_code: bool = False,
        query_col: str = "query",
        choices_col: str = "choices",
        gold_col: str = "gold",
        num_fewshot: int = 0,
        use_accelerate: bool = False,
        slots_per_trial: int = 1,
        max_concurrent_trials: int = 1,
        resource_pool: Optional[str] = None,
    ) -> bt.Experiment:
        data = {
            "name": name,
            "model": model,
            "project_id": self.project.id,
            "model_load_config": {
                "torch_dtype": torch_dtype,
                "low_cpu_mem_usage": low_cpu_mem_usage,
                "trust_remote_code": trust_remote_code,
                "token": hf_token,
            },
            "eval_config": {
                "data_args": {
                    "dataset_data": dataset,
                },
                "tasks": [
                    {
                        "query_col": query_col,
                        "choices_col": choices_col,
                        "gold_col": gold_col,
                        "dataset_uuid": None,
                        "hf_dataset_name": None,
                        "hf_dataset_path": dataset.storage_path,
                    },
                ],
                "num_fewshot": num_fewshot,
                "use_accelerate": use_accelerate,
                "slots_per_trial": slots_per_trial,
                "max_concurrent_trials": max_concurrent_trials,
            },
            "resource_pool": resource_pool,
        }

        return self._launch_experiment(in_data=bt.ExperimentSettings.parse_obj(data))

    def _launch_experiment(self, in_data: bt.ExperimentSettings) -> bt.Experiment:
        return self._post(
            f"{self._prefix}/experiment",
            in_data=in_data,
            out_cls=bt.Experiment,
        )

    def get_experiments(self) -> List[bt.Experiment]:
        return self._get(
            f"{self._prefix}/project/{self.project.id}/experiments", seq=True, out_cls=bt.Experiment
        )

    def get_experiment(self, exp_id: int) -> bt.Experiment:
        return self._get(f"{self._prefix}/experiment/{exp_id}", out_cls=bt.Experiment)

    def delete_experiment(self, exp_id: int) -> None:
        self._delete(f"{self._prefix}/experiment/{exp_id}")

    def kill_experiment(self, exp_id: int) -> None:
        self._put(f"{self._prefix}/experiment/{exp_id}/kill")

    def get_experiment_models(self, exp_id: int) -> List[bt.Model]:
        return self._get(f"{self._prefix}/experiment/{exp_id}/models", seq=True, out_cls=bt.Model)

    def get_experiment_metrics(self, exp_id: int) -> Dict[str, Dict[str, Any]]:
        """Gets latest metrics for an experiment."""
        return self._get(f"{self._prefix}/experiment/{exp_id}/metrics", out_cls=None)

    def create_playground_snapshot(
        self, request: bt.CreatePlaygroundSnaphotRequest
    ) -> bt.PlaygroundSnapshot:
        return self._post(
            f"{self._prefix}/playground-snapshot", in_data=request, out_cls=bt.PlaygroundSnapshot
        )

    def get_playground_snapshot(self, id: int) -> bt.PlaygroundSnapshot:
        return self._get(f"{self._prefix}/playground-snapshot/{id}", out_cls=bt.PlaygroundSnapshot)

    def get_playground_snapshots(self) -> List[bt.PlaygroundSnapshot]:
        return self._get(
            f"{self._prefix}/project/{self.workspace.experiment_project_id}/playground-snapshots",
            seq=True,
            out_cls=bt.PlaygroundSnapshot,
        )

    def delete_playground_snapshot(self, id: int):
        self._delete(f"{self._prefix}/playground-snapshot/{id}")

    def update_playground_snapshot(
        self, request: bt.UpdatePlaygroundSnaphotRequest
    ) -> bt.PlaygroundSnapshot:
        return self._put(
            f"{self._prefix}/playground-snapshot/{request.id}",
            in_data=request,
            out_cls=bt.PlaygroundSnapshot,
        )

    def restart_controller(
        self, controller_type: bte.ControllerType
    ) -> bt.RestartControllerResponse:
        if controller_type == bte.ControllerType.DATA:
            request = bt.RestartControllerRequest(
                controller_type=controller_type,
                workspace_id=self.workspace.id,
            )
        else:
            request = bt.RestartControllerRequest(
                controller_type=controller_type,
                workspace_id=self.workspace.id,
                client_key=self._client_key,
            )
        route = f"{self._prefix}/controller/restart"
        out_cls = bt.RestartControllerResponse
        return self._put(route, request, out_cls=out_cls)

    # Utils

    def _get_dataset(self, dataset: Union[bt.Dataset, str, int]) -> bt.Dataset:
        if isinstance(dataset, bt.Dataset):
            return dataset
        elif isinstance(dataset, str):
            return self.get_dataset_by_name(dataset)
        else:
            return self.get_dataset(dataset)

    def build_local_dataset(
        self,
        train_files: Optional[Union[str, List[str]]] = None,
        validation_files: Optional[Union[str, List[str]]] = None,
        test_files: Optional[Union[str, List[str]]] = None,
    ) -> Union[hf.Dataset, hf.DatasetDict]:
        train_files = self._str_to_list(train_files)
        validation_files = self._str_to_list(validation_files)
        test_files = self._str_to_list(test_files)
        splits = {
            "train": train_files,
            "validation": validation_files,
            "test": test_files,
        }
        splits = {key: value for key, value in splits.items() if value is not None}
        file_format = self._get_file_extension(splits)

        dataset = hf.load_dataset(
            file_format,
            data_files=splits,
        )
        return dataset

    def _get_file_extension(self, splits: Dict[str, List[str]]) -> str:
        extensions: Set = set()
        for filenames in splits.values():
            if filenames is None:
                continue
            split_extensions = {filename.split(".")[-1] for filename in filenames}
            extensions.update(split_extensions)
            if len(extensions) > 1:
                raise ValueError(
                    f"Files have different extensions: {extensions}. "
                    f"Only one file file format (csv, json, or jsonl) allowed at a time."
                )

        extension = list(extensions)[0]
        if extension not in ["csv", "json", "jsonl"]:
            raise ValueError(
                f"File extension {extension} not supported. " f"Only csv, json, or jsonl allowed."
            )

        # AC: jsonl can be loaded as json.
        if extension == "jsonl":
            extension = "json"

        return extension

    def _str_to_list(self, value: Union[str, List[str]]) -> List[str]:
        return [value] if isinstance(value, str) else value

    def _join_filepaths_for_splits(
        self, dataset_directory_path: str, splits: Dict[str, List[str]]
    ) -> Dict[str, List[str]]:
        for s, filenames in splits.items():
            new_filenames = [os.path.join(dataset_directory_path, f) for f in filenames]
            splits[s] = new_filenames

        return splits
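
Taken together, these methods support a notebook-style workflow. The snippet below is a minimal sketch rather than a complete recipe; the host, credentials, workspace name, and model name are placeholders, and it assumes a reachable lore server and Determined (MLDE) master.

from determined.common.api.authentication import Credentials

from lore.client import Lore

# Connect and authenticate (placeholder host, credentials, and master address).
client = Lore(host="http://localhost:9011")
client.login(Credentials("determined", "password"), mlde_host="http://localhost:8080")

# Scope subsequent dataset/model calls to a workspace (placeholder name).
client.set_workspace("my-workspace")

# Load a vetted base model for inference and generate from it (placeholder model name).
base_model = client.get_base_model("mistralai/Mistral-7B-Instruct-v0.2")
client.load_model(base_model.hf_model_name, is_genai_base_model=True)
print(client.generate("Briefly introduce yourself."))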

Ancestors

lore.backend.utils.RequestClient

Methods

def apply_prompt_template(self, dataset: Dataset, prompt_template: PromptTemplate, model: Optional[Model] = None, new_column_name: str = 'genai_input_text', expected_output_column_name: Optional[str] = None, batch_size: int = 10) ‑> Dataset
Expand source code
def apply_prompt_template(
    self,
    dataset: bt.Dataset,
    prompt_template: bt.PromptTemplate,
    model: Optional[bt.Model] = None,
    new_column_name: str = "genai_input_text",
    expected_output_column_name: Optional[str] = None,
    batch_size: int = 10,
) -> bt.Dataset:
    data = {
        "dataset": dataset,
        "prompt_template": prompt_template,
        "model": model,
        "new_column_name": new_column_name,
        "expected_output_column_name": expected_output_column_name,
        "batch_size": batch_size,
        "workspace_id": self.workspace.id,
    }
    return self._post(
        f"{self._prefix}/dataset/apply_prompt_template",
        in_data=data,  # type: ignore
        out_cls=bt.Dataset,
        sync=False,
    )
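
For example, a stored template can be rendered over a registered dataset. A sketch, with client as an authenticated Lore instance with a workspace set; the dataset name, template id, and expected-output column are placeholders.

dataset = client.get_dataset_by_name("qa-dataset")            # placeholder dataset name
template = client.get_prompt_template(prompt_template_id=1)   # placeholder id
rendered = client.apply_prompt_template(
    dataset,
    template,
    new_column_name="genai_input_text",
    expected_output_column_name="answer",                     # placeholder column name
)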
def build_local_dataset(self, train_files: Union[str, List[str], ForwardRef(None)] = None, validation_files: Union[str, List[str], ForwardRef(None)] = None, test_files: Union[str, List[str], ForwardRef(None)] = None) ‑> Union[datasets.arrow_dataset.Dataset, datasets.dataset_dict.DatasetDict]
Expand source code
def build_local_dataset(
    self,
    train_files: Optional[Union[str, List[str]]] = None,
    validation_files: Optional[Union[str, List[str]]] = None,
    test_files: Optional[Union[str, List[str]]] = None,
) -> Union[hf.Dataset, hf.DatasetDict]:
    train_files = self._str_to_list(train_files)
    validation_files = self._str_to_list(validation_files)
    test_files = self._str_to_list(test_files)
    splits = {
        "train": train_files,
        "validation": validation_files,
        "test": test_files,
    }
    splits = {key: value for key, value in splits.items() if value is not None}
    file_format = self._get_file_extension(splits)

    dataset = hf.load_dataset(
        file_format,
        data_files=splits,
    )
    return dataset
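
A sketch with local CSV splits (paths are placeholders); all files passed in one call must share a single extension (csv, json, or jsonl).

hf_dataset = client.build_local_dataset(
    train_files="data/train.csv",                       # placeholder paths
    validation_files="data/validation.csv",
    test_files=["data/test_a.csv", "data/test_b.csv"],
)
# The result can then be registered via construct_dataset_from_local.
dataset = client.construct_dataset_from_local(hf_dataset, dataset_name="local-qa")   # placeholder name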
def commit_dataset(self, dataset: Dataset) ‑> Dataset
Expand source code
def commit_dataset(self, dataset: bt.Dataset) -> bt.Dataset:
    return self._post(f"{self._prefix}/dataset", in_data=dataset, out_cls=bt.Dataset)
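
Calls such as merge_and_resplit may return Dataset objects that are not yet registered (launch_training, for example, commits such datasets automatically). A sketch of committing one explicitly; the dataset name and split ratios are placeholders.

base = client.get_dataset_by_name("qa-dataset")   # placeholder name
resplit = client.merge_and_resplit(
    base,
    train_ratio=0.8,
    validation_ratio=0.1,
    test_ratio=0.1,
    splits_to_resplit=["train", "validation", "test"],
)
resplit = client.commit_dataset(resplit)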
def commit_prompt_template(self, prompt_template: PromptTemplate) ‑> PromptTemplate
Expand source code
def commit_prompt_template(self, prompt_template: bt.PromptTemplate) -> bt.PromptTemplate:
    return self._post(
        f"{self._prefix}/prompt-template", in_data=prompt_template, out_cls=bt.PromptTemplate
    )
def compute_metrics(self, dataset: Dataset, ground_truth_column_name: str, predictions_column_name: str, metrics: List[str] = ['exact_match'], substr_match: bool = True, split: Optional[str] = None, strip: bool = False, lower: bool = False) ‑> Tuple[Dataset, Union[Dict[str, float], Dict[str, Dict[str, float]]]]
Expand source code
def compute_metrics(
    self,
    dataset: bt.Dataset,
    ground_truth_column_name: str,
    predictions_column_name: str,
    metrics: List[str] = ["exact_match"],
    substr_match: bool = True,
    split: Optional[str] = None,
    strip: bool = False,
    lower: bool = False,
) -> Tuple[bt.Dataset, Union[Dict[str, float], Dict[str, Dict[str, float]]]]:
    data = bt.ComputeMetricsRequest(
        dataset=dataset,
        ground_truth_column_name=ground_truth_column_name,
        predictions_column_name=predictions_column_name,
        metrics=metrics,
        substr_match=substr_match,
        workspace_id=self.workspace.id,
        split=split,
        strip=strip,
        lower=lower,
    )
    results = self._post(
        f"{self._prefix}/dataset/compute_metrics",
        in_data=data,
        out_cls=None,
        sync=False,
    )
    return bt.Dataset.parse_obj(results[0]), results[1]
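
For instance, exact-match accuracy between a predictions column and a ground-truth column might be computed as below. The column names are the defaults used by generate_on_dataset and may differ for your dataset; the dataset name is a placeholder.

dataset = client.get_dataset_by_name("qa-dataset")   # placeholder name
dataset, metrics = client.compute_metrics(
    dataset,
    ground_truth_column_name="genai_expected_text",
    predictions_column_name="genai_generated_text",
    metrics=["exact_match"],
    strip=True,
    lower=True,
)
print(metrics)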
def concatenate_datasets(self, datasets: List[Dataset], name: str, register_dataset: bool = False) ‑> Dataset
Expand source code
def concatenate_datasets(
    self, datasets: List[bt.Dataset], name: str, register_dataset: bool = False
) -> bt.Dataset:
    data = {
        "datasets": datasets,
        "name": name,
        "workspace_id": self.workspace.id,
        "register_dataset": register_dataset,
        "system_generated": False,  # as it is trigger from user facing Python client
    }
    request_body = bt.ConcatenateDatasetsRequest.parse_obj(data)
    return self._post(
        f"{self._prefix}/dataset/concatenate", in_data=request_body, out_cls=bt.Dataset
    )
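
A sketch combining two registered datasets into one; the dataset names are placeholders.

combined = client.concatenate_datasets(
    [client.get_dataset_by_name("part-a"), client.get_dataset_by_name("part-b")],   # placeholder names
    name="combined-parts",
    register_dataset=True,
)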
def construct_dataset_from_hf(self, dataset_name: str, dataset_config_name: Optional[str] = None, task_type: Optional[str] = None, data_type: Optional[List[DatasetDataType]] = None, token: Optional[str] = None) ‑> Dataset
Expand source code
def construct_dataset_from_hf(
    self,
    dataset_name: str,
    dataset_config_name: Optional[str] = None,
    task_type: Optional[str] = None,
    data_type: Optional[List[bt.DatasetDataType]] = None,
    token: Optional[str] = None,
) -> bt.Dataset:
    data = bt.ConstructDatasetFromHFRequest(
        dataset_name=dataset_name,
        dataset_config_name=dataset_config_name,
        task_type=task_type,
        data_type=data_type,
        token=token,
        workspace_id=self.workspace.id,
    )
    return self._post(
        f"{self._prefix}/construct_dataset_from_hf",
        in_data=data,
        out_cls=bt.Dataset,
        sync=False,
    )
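
A sketch importing a public Hugging Face dataset into the current workspace; the dataset name is a placeholder, and token would be needed for gated datasets.

dataset = client.construct_dataset_from_hf(
    "squad",                   # placeholder Hugging Face dataset name
    dataset_config_name=None,
)
# Materialize it locally as a datasets.DatasetDict for inspection.
hf_splits = client.materialize_dataset(dataset)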
def construct_dataset_from_local(self, hf_dataset: Union[datasets.arrow_dataset.Dataset, datasets.dataset_dict.DatasetDict, ForwardRef(None)] = None, dataset_name: Optional[str] = None) ‑> Dataset
Expand source code
def construct_dataset_from_local(
    self,
    hf_dataset: Optional[Union[hf.Dataset, hf.DatasetDict]] = None,
    dataset_name: Optional[str] = None,
) -> bt.Dataset:
    if dataset_name is None:
        if isinstance(hf_dataset, hf.Dataset):
            dataset_name = hf_dataset.info.dataset_name
        else:
            # Find the first split with info and dataset_name
            dataset_name = None
            for key in hf_dataset:
                if hasattr(hf_dataset[key], "info") and hasattr(
                    hf_dataset[key].info, "dataset_name"
                ):
                    dataset_name = hf_dataset[key].info.dataset_name
                    break

    assert (
        dataset_name
    ), "The 'dataset_name' parameter is missing in hf_dataset. Please provide the dataset name explicitly."

    with tempfile.TemporaryDirectory() as tmp_dir:
        with tempfile.TemporaryDirectory() as zip_tmp_dir:
            hf_dataset.save_to_disk(os.path.join(tmp_dir, "hf_dataset"))

            shutil.make_archive(
                os.path.join(zip_tmp_dir, "hf_dataset"), "zip", root_dir=tmp_dir
            )
            uploaded_files = self._upload_arrow_files(
                os.path.join(zip_tmp_dir, "hf_dataset.zip")
            )
            data = bt.ConstructDatasetFromLocalRequest(
                workspace_id=self.workspace.id,
                dataset_name=dataset_name,
                arrow_file_path=uploaded_files.file_paths[0],
            )

    return self._post(
        f"{self._prefix}/construct_dataset_from_local",
        in_data=data,
        out_cls=bt.Dataset,
        sync=False,
    )
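
Example usage (a minimal sketch; assumes `client` is an authenticated Lore instance with a workspace selected, and the dataset name is a placeholder):

import datasets as hf

# Build a tiny in-memory Hugging Face dataset and register it with the server.
local = hf.Dataset.from_dict({"text": ["hello", "world"], "label": [0, 1]})
dataset = client.construct_dataset_from_local(local, dataset_name="tiny-demo")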
def create_hf_model(self, hf_model_name: str, branch: Optional[str] = None, hf_model_commit: Optional[str] = None, formatting_tokens: Optional[FormattingTokens] = None, problem_type: str = '', model_config: Optional[Dict[str, Any]] = None, **kwargs) ‑> Model
Expand source code
def create_hf_model(
    self,
    hf_model_name: str,
    branch: Optional[str] = None,
    hf_model_commit: Optional[str] = None,
    formatting_tokens: Optional[bt.FormattingTokens] = None,
    problem_type: str = "",
    model_config: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> bt.Model:
    return bt.Model.from_hf(
        hf_model_name=hf_model_name,
        branch=branch,
        hf_model_commit=hf_model_commit,
        formatting_tokens=formatting_tokens,
        problem_type=problem_type,
        model_config=model_config,
        **kwargs,
    )
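
Example usage (a minimal sketch; assumes `client` is an authenticated Lore instance with a workspace selected, and the model name is a placeholder):

# Build a Model object that describes a Hugging Face model, then optionally
# persist it in the current workspace.
model = client.create_hf_model("gpt2")          # any Hugging Face model id
model = client.register_model(model)            # optional: register it in the workspace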
def create_playground_snapshot(self, request: CreatePlaygroundSnaphotRequest) ‑> PlaygroundSnapshot
Expand source code
def create_playground_snapshot(
    self, request: bt.CreatePlaygroundSnaphotRequest
) -> bt.PlaygroundSnapshot:
    return self._post(
        f"{self._prefix}/playground-snapshot", in_data=request, out_cls=bt.PlaygroundSnapshot
    )
def create_workspace(self, workspace_name: str) ‑> Workspace
Expand source code
def create_workspace(self, workspace_name: str) -> bt.Workspace:
    workspace = bt.CreateWorkspaceRequest(name=workspace_name)
    return self._post(f"{self._prefix}/workspace", workspace, out_cls=bt.Workspace)
def delete_dataset(self, dataset_id: int) ‑> None
Expand source code
def delete_dataset(self, dataset_id: int) -> None:
    return self._delete(f"{self._prefix}/dataset/{dataset_id}", sync=False)
def delete_experiment(self, exp_id: int) ‑> None
Expand source code
def delete_experiment(self, exp_id: int) -> None:
    self._delete(f"{self._prefix}/experiment/{exp_id}")
def delete_model(self, model_id: int) ‑> None
Expand source code
def delete_model(self, model_id: int) -> None:
    self._delete(f"{self._prefix}/model/{model_id}")
def delete_playground_snapshot(self, id: int)
Expand source code
def delete_playground_snapshot(self, id: int):
    self._delete(f"{self._prefix}/playground-snapshot/{id}")
def delete_prompt_template(self, prompt_template_id: int) ‑> None
Expand source code
def delete_prompt_template(self, prompt_template_id: int) -> None:
    self._delete(f"{self._prefix}/prompt-template/{prompt_template_id}")
def delete_workspace(self, workspace_id: int) ‑> None
Expand source code
def delete_workspace(self, workspace_id: int) -> None:
    return self._delete(f"{self._prefix}/workspace/{workspace_id}")
def download_model(self, model_id: int, output_path: Optional[str] = None, mlde_host: Optional[str] = None) ‑> str

Downloads a model's checkpoint directly from the Determined master.

Arguments

model_id (int): id for the model
output_path (string): the folder path to save the model to

Returns

String of model folder path

Expand source code
def download_model(
    self, model_id: int, output_path: Optional[str] = None, mlde_host: Optional[str] = None
) -> str:
    """Downloads a model by checkpoint through the determined master directly
    Arguments:
        model_id (int): id for the model
        output_path (string): the folder path to save the model to
    Returns:
        String of model folder path
    """
    model = self.get_model(model_id)
    assert (
        model.genai_checkpoint_uuid
    ), f"Model ({model_id}) is missing a checkpoint uuid. Cancelling download"
    if mlde_host is None:
        logger.warn(
            f"mlde_host is not provided for download_model. Falling back {get_det_master_address()}"
        )
        mlde_host = get_det_master_address()
    # Call determined directly to avoid downloading our
    # checkpoints in two hops.
    session = create_session(self._required_token, master_address=mlde_host)
    resp = bindings.get_GetCheckpoint(session, checkpointUuid=model.genai_checkpoint_uuid)
    ckpt = checkpoint.Checkpoint._from_bindings(resp.checkpoint, session)
    Path(output_path).mkdir(parents=True)
    ckpt.download(output_path)
    return str(output_path)
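
Example usage (a minimal sketch; the model id and output path are placeholders):

# Download the checkpoint backing model 123 (placeholder id) to a new local folder.
path = client.download_model(123, output_path="/tmp/model-123")
print("model saved under", path)

Note that this implementation creates output_path via Path(output_path).mkdir(parents=True), so output_path should name a directory that does not already exist.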
def generate(self, text: Union[str, List[str]], generation_config: Dict[str, Any] = {}, batch_size: Optional[int] = None) ‑> Union[str, List[str]]
Expand source code
def generate(
    self,
    text: Union[str, List[str]],
    generation_config: Dict[str, Any] = {},
    batch_size: Optional[int] = None,
) -> Union[str, List[str]]:
    # Input validation
    if (
        self.batch_size is not None and batch_size is not None
    ) and batch_size > self.batch_size:
        raise ValueError(
            f"batch_size should be <= {self.batch_size} based on recommended model config"
        )
    if self.max_new_tokens is not None and "max_new_tokens" in generation_config:
        if generation_config["max_new_tokens"] > self.max_new_tokens:
            raise ValueError(
                f"max_new_tokens should be <= {self.max_new_tokens} based on recommended model config"
            )

    # End input validation

    # A list of prompts yields a list of responses; a single string yields one response.
    seq = not isinstance(text, str)
    data = bt.GenerateChatResponseRequest(
        prompts=text,
        generation_config=generation_config,
        batch_size=batch_size,
        client_key=self._client_key,
    )
    return self._post(
        f"{self._prefix}/workspace/{self.workspace.id}/inference/generate",
        in_data=data,
        out_cls=str,
        sync=False,
        seq=seq,
        poll_interval=0.1,
    )
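
Example usage (a minimal sketch; assumes a model has already been loaded for inference via load_model, and the prompts are placeholders):

single = client.generate("Summarize the plot of Hamlet in one sentence.")
batch = client.generate(
    ["What is 2 + 2?", "Name a prime number."],
    generation_config={"max_new_tokens": 64},
)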
def generate_on_dataset(self, dataset: Dataset, prompt_template: PromptTemplate, as_dataset: bool = True, output_column_name: str = 'genai_generated_text', expected_output_column_name: Optional[str] = 'genai_expected_text', split: Optional[str] = None, start_index: Optional[int] = None, number_of_samples: Optional[int] = None, batch_size: int = 10, test_mode: bool = False, generation_config: Dict[str, Any] = {}, num_proc: Optional[int] = None) ‑> Dataset
Expand source code
def generate_on_dataset(
    self,
    dataset: bt.Dataset,
    prompt_template: bt.PromptTemplate,
    as_dataset: bool = True,
    output_column_name: str = "genai_generated_text",
    expected_output_column_name: Optional[str] = "genai_expected_text",
    split: Optional[str] = None,
    start_index: Optional[int] = None,
    number_of_samples: Optional[int] = None,
    batch_size: int = 10,
    test_mode: bool = False,
    generation_config: Dict[str, Any] = {},
    num_proc: Optional[int] = None,
) -> bt.Dataset:
    # Input validation
    if (
        self.batch_size is not None and batch_size is not None
    ) and batch_size > self.batch_size:
        raise ValueError(
            f"batch_size should be <= {self.batch_size} based on recommended model config"
        )
    if self.max_new_tokens is not None and "max_new_tokens" in generation_config:
        if generation_config["max_new_tokens"] > self.max_new_tokens:
            raise ValueError(
                f"max_new_tokens should be <= {self.max_new_tokens} based on recommended model config"
            )
    # End input validation

    data = bt.GenerateOnDatasetRequest(
        dataset=dataset,
        prompt_template=prompt_template,
        output_column_name=output_column_name,
        expected_output_column_name=expected_output_column_name,
        split=split,
        start_index=start_index,
        number_of_samples=number_of_samples,
        batch_size=batch_size,
        test_mode=test_mode,
        generation_config=generation_config,
        num_proc=num_proc,
        client_key=self._client_key,
    )

    out_cls = bt.Dataset if as_dataset else dict
    return self._post(
        f"{self._prefix}/workspace/{self.workspace.id}/inference/generate_on_dataset",
        in_data=data,
        out_cls=out_cls,
        sync=False,
    )
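
Example usage (a minimal sketch; assumes a model is loaded, the workspace contains at least one prompt template, and the dataset name and split are placeholders):

# Run batched generation over a registered dataset with a prompt template.
dataset = client.get_dataset_by_name("eval-prompts")      # placeholder dataset name
template = client.get_prompt_templates()[0]               # or get_prompt_template(<id>)
result = client.generate_on_dataset(
    dataset,
    template,
    split="test",
    number_of_samples=50,
    batch_size=8,
)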
def get_base_model(self, model_name: str) ‑> Model
Expand source code
def get_base_model(self, model_name: str) -> bt.Model:
    models = self.get_base_models()
    model = next((model for model in models if model.hf_model_name == model_name), None)
    if model is None:
        raise ValueError(f"{model_name} not found.")
    return model
def get_base_models(self) ‑> List[Model]
Expand source code
def get_base_models(self) -> List[bt.Model]:
    models = self._get(f"{self._prefix}/base-models", out_cls=bt.Model, seq=True)
    models = sorted(models, key=lambda x: x.hf_model_name)
    return models
def get_dataset(self, dataset_id: int) ‑> Dataset
Expand source code
def get_dataset(self, dataset_id: int) -> bt.Dataset:
    return self._get(
        f"{self._prefix}/dataset/{dataset_id}",
        out_cls=bt.Dataset,
    )
def get_dataset_by_name(self, dataset_name: str) ‑> Dataset
Expand source code
def get_dataset_by_name(self, dataset_name: str) -> bt.Dataset:
    return self._get(
        f"{self._prefix}/workspace/{self.workspace.id}/dataset/name",
        params={"name": dataset_name},
        out_cls=bt.Dataset,
    )
def get_datasets(self) ‑> List[Dataset]
Expand source code
def get_datasets(self) -> List[bt.Dataset]:
    return self._get(
        f"{self._prefix}/workspace/{self.workspace.id}/datasets", out_cls=bt.Dataset, seq=True
    )
def get_experiment(self, exp_id: int) ‑> Experiment
Expand source code
def get_experiment(self, exp_id: int) -> bt.Experiment:
    return self._get(f"{self._prefix}/experiment/{exp_id}", out_cls=bt.Experiment)
def get_experiment_metrics(self, exp_id: int) ‑> Dict[str, Dict[str, Any]]

Gets latest metrics for an experiment.

Expand source code
def get_experiment_metrics(self, exp_id: int) -> Dict[str, Dict[str, Any]]:
    """Gets latest metrics for an experiment."""
    return self._get(f"{self._prefix}/experiment/{exp_id}/metrics", out_cls=None)
def get_experiment_models(self, exp_id: int) ‑> List[Model]
Expand source code
def get_experiment_models(self, exp_id: int) -> List[bt.Model]:
    return self._get(f"{self._prefix}/experiment/{exp_id}/models", seq=True, out_cls=bt.Model)
def get_experiments(self) ‑> List[Experiment]
Expand source code
def get_experiments(self) -> List[bt.Experiment]:
    return self._get(
        f"{self._prefix}/project/{self.project.id}/experiments", seq=True, out_cls=bt.Experiment
    )
def get_model(self, model_id: int) ‑> Model
Expand source code
def get_model(self, model_id: int) -> bt.Model:
    return self._get(f"{self._prefix}/model/{model_id}", out_cls=bt.Model)
def get_model_by_name(self, model_name: str, is_base_model: bool = True) ‑> Model
Expand source code
def get_model_by_name(self, model_name: str, is_base_model: bool = True) -> bt.Model:
    model = self._post(
        f"{self._prefix}/model",
        in_data=bt.GetModelByNameRequest(
            name=model_name, workspace_id=self.workspace.id, is_base_model=is_base_model
        ),
        out_cls=bt.Model,
        seq=False,
    )
    return model
def get_models(self) ‑> List[Model]
Expand source code
def get_models(self) -> List[bt.Model]:
    return self._get(
        f"{self._prefix}/workspace/{self.workspace.id}/models", out_cls=bt.Model, seq=True
    )
def get_playground_snapshot(self, id: int) ‑> PlaygroundSnapshot
Expand source code
def get_playground_snapshot(self, id: int) -> bt.PlaygroundSnapshot:
    return self._get(f"{self._prefix}/playground-snapshot/{id}", out_cls=bt.PlaygroundSnapshot)
def get_playground_snapshots(self) ‑> List[PlaygroundSnapshot]
Expand source code
def get_playground_snapshots(self) -> List[bt.PlaygroundSnapshot]:
    return self._get(
        f"{self._prefix}/project/{self.workspace.experiment_project_id}/playground-snapshots",
        seq=True,
        out_cls=bt.PlaygroundSnapshot,
    )
def get_prompt_template(self, prompt_template_id: int) ‑> PromptTemplate
Expand source code
def get_prompt_template(self, prompt_template_id: int) -> bt.PromptTemplate:
    return self._get(
        f"{self._prefix}/prompt-template/{prompt_template_id}", out_cls=bt.PromptTemplate
    )
def get_prompt_templates(self) ‑> List[PromptTemplate]
Expand source code
def get_prompt_templates(self) -> List[bt.PromptTemplate]:
    return self._get(
        f"{self._prefix}/workspace/{self.workspace.id}/prompt-templates",
        out_cls=bt.PromptTemplate,
        seq=True,
    )
def get_start_prompt_templates(self) ‑> List[PromptTemplate]
Expand source code
def get_start_prompt_templates(self) -> List[bt.PromptTemplate]:
    return self._get(
        f"{self._prefix}/starting-prompt-templates", out_cls=bt.PromptTemplate, seq=True
    )
def get_workspace(self, workspace_id: int) ‑> Workspace
Expand source code
def get_workspace(self, workspace_id: int) -> bt.Workspace:
    return self._get(f"{self._prefix}/workspace/{workspace_id}", out_cls=bt.Workspace)
def get_workspaces(self) ‑> List[Workspace]
Expand source code
def get_workspaces(self) -> List[bt.Workspace]:
    return self._get(f"{self._prefix}/workspaces", out_cls=bt.Workspace, seq=True)
def kill_experiment(self, exp_id: int) ‑> None
Expand source code
def kill_experiment(self, exp_id: int) -> None:
    self._put(f"{self._prefix}/experiment/{exp_id}/kill")
def launch_evaluation(self, model: Model, tasks: Union[str, List[str]], hf_token: Optional[str] = None, torch_dtype: Optional[str] = None, low_cpu_mem_usage: bool = False, trust_remote_code: bool = False, name: Optional[str] = None, num_fewshot: int = 0, use_accelerate: bool = False, slots_per_trial: int = 1, max_concurrent_trials: int = 1, resource_pool: Optional[str] = None) ‑> Experiment
Expand source code
def launch_evaluation(
    self,
    model: bt.Model,
    tasks: Union[str, List[str]],
    hf_token: Optional[str] = None,
    torch_dtype: Optional[str] = None,
    low_cpu_mem_usage: bool = False,
    trust_remote_code: bool = False,
    name: Optional[str] = None,
    num_fewshot: int = 0,
    use_accelerate: bool = False,
    slots_per_trial: int = 1,
    max_concurrent_trials: int = 1,
    resource_pool: Optional[str] = None,
) -> bt.Experiment:
    data = bt.ExperimentSettings(
        name=name,
        model=model,
        project_id=self.project.id,
        model_load_config=bt.ModelLoadConfig(
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=low_cpu_mem_usage,
            trust_remote_code=trust_remote_code,
            token=hf_token,
        ),
        eval_config=bt.EvaluationConfig(
            data_args={},
            tasks=tasks,
            num_fewshot=num_fewshot,
            use_accelerate=use_accelerate,
            slots_per_trial=slots_per_trial,
            max_concurrent_trials=max_concurrent_trials,
        ),
        resource_pool=resource_pool,
    )

    return self._launch_experiment(in_data=data)
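
Example usage (a minimal sketch; the base model name and task name are placeholders, and which task names are accepted depends on the evaluation backend):

model = client.get_base_model("gpt2")           # placeholder base model name
experiment = client.launch_evaluation(
    model,
    tasks=["hellaswag"],                        # placeholder task name
    num_fewshot=0,
    slots_per_trial=1,
)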
def launch_evalution_custom_multi_choice(self, model: Model, dataset: Dataset, name: Optional[str] = None, hf_token: Optional[str] = None, torch_dtype: Optional[str] = None, low_cpu_mem_usage: bool = False, trust_remote_code: bool = False, query_col: str = 'query', choices_col: str = 'choices', gold_col: str = 'gold', num_fewshot: int = 0, use_accelerate: bool = False, slots_per_trial: int = 1, max_concurrent_trials: int = 1, resource_pool: Optional[str] = None) ‑> Experiment
Expand source code
def launch_evalution_custom_multi_choice(
    self,
    model: bt.Model,
    dataset: bt.Dataset,
    name: Optional[str] = None,
    hf_token: Optional[str] = None,
    torch_dtype: Optional[str] = None,
    low_cpu_mem_usage: bool = False,
    trust_remote_code: bool = False,
    query_col: str = "query",
    choices_col: str = "choices",
    gold_col: str = "gold",
    num_fewshot: int = 0,
    use_accelerate: bool = False,
    slots_per_trial: int = 1,
    max_concurrent_trials: int = 1,
    resource_pool: Optional[str] = None,
) -> bt.Experiment:
    data = {
        "name": name,
        "model": model,
        "project_id": self.project.id,
        "model_load_config": {
            "torch_dtype": torch_dtype,
            "low_cpu_mem_usage": low_cpu_mem_usage,
            "trust_remote_code": trust_remote_code,
            "token": hf_token,
        },
        "eval_config": {
            "data_args": {
                "dataset_data": dataset,
            },
            "tasks": [
                {
                    "query_col": query_col,
                    "choices_col": choices_col,
                    "gold_col": gold_col,
                    "dataset_uuid": None,
                    "hf_dataset_name": None,
                    "hf_dataset_path": dataset.storage_path,
                },
            ],
            "num_fewshot": num_fewshot,
            "use_accelerate": use_accelerate,
            "slots_per_trial": slots_per_trial,
            "max_concurrent_trials": max_concurrent_trials,
        },
        "resource_pool": resource_pool,
    }

    return self._launch_experiment(in_data=bt.ExperimentSettings.parse_obj(data))
def launch_training(self, dataset: Dataset, base_model: Model, name: Optional[str] = None, prompt_template: Optional[PromptTemplate] = None, hf_token: Optional[str] = None, torch_dtype: Optional[str] = None, low_cpu_mem_usage: bool = False, trust_remote_code: bool = False, slots_per_trial: Optional[int] = None, max_steps: int = -1, num_train_epochs: int = 1, save_total_limit: int = 1, learning_rate: float = 5e-05, per_device_train_batch_size: Optional[int] = None, per_device_eval_batch_size: Optional[int] = None, do_train: bool = True, do_eval: bool = True, block_size: Optional[int] = None, deepspeed: Optional[bool] = None, gradient_checkpointing: Optional[bool] = None, ds_cpu_offload: bool = False, seed: int = 1337, output_dir: str = '/tmp/lore_output', eval_steps: Optional[int] = None, logging_steps: Optional[int] = None, save_steps: Optional[int] = None, logging_strategy: str = '', evaluation_strategy: str = '', save_strategy: str = '', resource_pool: Optional[str] = None, disable_training_config_recommendation: bool = False) ‑> Experiment
Expand source code
def launch_training(
    self,
    dataset: bt.Dataset,
    base_model: bt.Model,
    name: Optional[str] = None,
    prompt_template: Optional[bt.PromptTemplate] = None,
    hf_token: Optional[str] = None,
    torch_dtype: Optional[str] = None,
    low_cpu_mem_usage: bool = False,
    trust_remote_code: bool = False,
    slots_per_trial: Optional[int] = None,
    max_steps: int = -1,
    num_train_epochs: int = 1,
    save_total_limit: int = 1,
    learning_rate: float = 5e-5,
    per_device_train_batch_size: Optional[int] = None,
    per_device_eval_batch_size: Optional[int] = None,
    do_train: bool = True,
    do_eval: bool = True,
    block_size: Optional[int] = None,
    deepspeed: Optional[bool] = None,
    gradient_checkpointing: Optional[bool] = None,
    ds_cpu_offload: bool = False,
    seed: int = 1337,
    output_dir: str = "/tmp/lore_output",
    eval_steps: Optional[int] = None,
    logging_steps: Optional[int] = None,
    save_steps: Optional[int] = None,
    logging_strategy: str = "",
    evaluation_strategy: str = "",
    save_strategy: str = "",
    resource_pool: Optional[str] = None,
    disable_training_config_recommendation: bool = False,
) -> bt.Experiment:
    # Input validation
    if slots_per_trial is not None and resource_pool is None:
        raise ValueError(
            "slots_per_trial is provided but not resource_pool. Either provide both or neither."
        )

    if torch_dtype is not None and torch_dtype not in (
        const.BF16_TYPE_STR,
        const.FP16_TYPE_STR,
    ):
        raise ValueError(
            f"torch_dtype should be {const.FP16_TYPE_STR} or {const.BF16_TYPE_STR}. Note that {const.BF16_TYPE_STR} support is GPU dependent"
        )
    # End input validation

    # Register dataset if needed
    if dataset.id is None:
        dataset = self.commit_dataset(dataset)

    # Try get training config recommendation
    try_recommend_training_config = self._try_recommend_training_config(
        base_model=base_model,
        disable_training_config_recommendation=disable_training_config_recommendation,
    )
    config_recommendation_found = False

    if try_recommend_training_config:
        try:
            config = self._find_training_configs_recommendation(
                model_architecture=base_model.model_architecture, resource_pool=resource_pool
            )
            logger.info(
                "Using recommended config to launch training; to turn off recommendation, set disable_training_config_recommendation=True"
            )
            config_recommendation_found = True
            training_config = config.training_configs[0]
            slots_per_trial = training_config.slots_per_trial
            resource_pool = config.resource_pool.name
            logger.info(f"resource_pool is {resource_pool}")

            if (
                per_device_train_batch_size is not None
                and per_device_train_batch_size != training_config.batch_size
            ):
                logger.warn(
                    f"per_device_train_batch_size: {per_device_train_batch_size} will be overwritten by config recommendation to {training_config.batch_size}"
                )
            per_device_train_batch_size = training_config.batch_size

            if (
                per_device_eval_batch_size is not None
                and per_device_eval_batch_size != training_config.batch_size
            ):
                logger.warn(
                    f"per_device_eval_batch_size: {per_device_eval_batch_size} will be overwritten by config recommendation to {training_config.batch_size}"
                )
            per_device_eval_batch_size = training_config.batch_size

            if deepspeed is not None and deepspeed != training_config.deepspeed:
                logger.warn(
                    f"deepspeed {deepspeed} will be overwritten by config recommendation to {training_config.deepspeed}"
                )
            deepspeed = training_config.deepspeed

            if (
                gradient_checkpointing is not None
                and gradient_checkpointing != training_config.gradient_checkpointing
            ):
                logger.warn(
                    f"gradient_checkpointing {gradient_checkpointing} will be overwritten by config recommendation to {training_config.gradient_checkpointing}"
                )
            gradient_checkpointing = training_config.gradient_checkpointing

            if block_size is not None and block_size != training_config.context_window:
                logger.warn(
                    f"block_size {block_size} will be overwritten by config recommendation to {training_config.context_window}"
                )
            block_size = training_config.context_window

            if torch_dtype is not None and torch_dtype != training_config.torch_dtype:
                logger.warn(
                    f"torch_dtype {torch_dtype} will be overwritten by config recommendation to {training_config.torch_dtype}"
                )
            torch_dtype = training_config.torch_dtype

        except ex.TrainingConfigNotFoundException as e:
            logger.warn(str(e))

    if not config_recommendation_found:
        if slots_per_trial is None:
            slots_per_trial = const.DEFAULT_SLOTS_PER_TRIAL
            logger.info(f"Using default slots_per_trial: {slots_per_trial}")
        else:
            logger.info(f"Using user provided slots_per_trial: {slots_per_trial}")

        if per_device_train_batch_size is None:
            per_device_train_batch_size = const.DEFAULT_PER_DEVICE_TRAIN_BATCH_SIZE
            logger.info(
                f"Using default per_device_train_batch_size: {per_device_train_batch_size}"
            )
        else:
            logger.info(
                f"Using user provided per_device_train_batch_size: {per_device_train_batch_size}"
            )

        if per_device_eval_batch_size is None:
            per_device_eval_batch_size = const.DEFAULT_PER_DEVICE_EVAL_BATCH_SIZE
            logger.info(
                f"Using default per_device_eval_batch_size: {per_device_eval_batch_size}"
            )
        else:
            logger.info(
                f"Using user provided per_device_eval_batch_size: {per_device_eval_batch_size}"
            )

        if torch_dtype is None:
            torch_dtype = const.DEFAULT_TRAIN_TORCH_DTYPE
            logger.info(f"Using default torch_dtype: {torch_dtype}")
        else:
            logger.info(f"Using user provided torch_dtype: {torch_dtype}")

        if deepspeed is None:
            deepspeed = const.DEFAULT_TRAIN_DEEPSPEED_FLAG
            logger.info(f"Using default deepspeed: {deepspeed}")
        else:
            logger.info(f"Using user provided deepspeed: {deepspeed}")

        if gradient_checkpointing is None:
            gradient_checkpointing = const.DEFAULT_TRAIN_GRADIENT_CHECKPOINTING_FLAG
            logger.info(f"Using default gradient_checkpointing: {gradient_checkpointing}")
        else:
            logger.info(f"Using user provided gradient_checkpointing: {gradient_checkpointing}")

    # TODO: build on top of the existing ExperimentSettings instead of an
    # ad-hoc dict.
    data = {
        "name": name,
        "model": base_model,
        "project_id": self.project.id,
        "model_load_config": {
            "torch_dtype": torch_dtype,
            "low_cpu_mem_usage": low_cpu_mem_usage,
            "trust_remote_code": trust_remote_code,
            "token": hf_token,
        },
        "train_config": {
            "data_args": {
                "dataset_data": dataset,
                "block_size": block_size,
            },
            "slots_per_trial": slots_per_trial,
            "max_steps": max_steps,
            "num_train_epochs": num_train_epochs,
            "save_total_limit": save_total_limit,
            "learning_rate": learning_rate,
            "per_device_train_batch_size": per_device_train_batch_size,
            "per_device_eval_batch_size": per_device_eval_batch_size,
            "do_train": do_train,
            "do_eval": do_eval,
            "fp16": True if torch_dtype == const.FP16_TYPE_STR else False,
            "bf16": True if torch_dtype == const.BF16_TYPE_STR else False,
            "deepspeed": deepspeed,
            "gradient_checkpointing": gradient_checkpointing,
            "ds_cpu_offload": ds_cpu_offload,
            "seed": seed,
            "output_dir": output_dir,
            "eval_steps": eval_steps,
            "logging_steps": logging_steps,
            "save_steps": save_steps,
            "logging_strategy": logging_strategy,
            "evaluation_strategy": evaluation_strategy,
            "save_strategy": save_strategy,
        },
        "prompt_template": prompt_template,
        "resource_pool": resource_pool,
    }
    return self._launch_experiment(in_data=bt.ExperimentSettings.parse_obj(data))
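
Example usage (a minimal sketch; assumes `client` has a workspace selected, and the dataset and base model names are placeholders):

# Fine-tune a base model on a registered dataset.
dataset = client.get_dataset_by_name("instruction-data")   # placeholder dataset name
base_model = client.get_base_model("gpt2")                 # placeholder base model name
experiment = client.launch_training(
    dataset,
    base_model,
    name="gpt2-finetune-demo",
    num_train_epochs=1,
    learning_rate=5e-5,
)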
def load_model(self, model: Union[str, Model], is_genai_base_model: bool = True, genai_model_version: Optional[int] = None, *, resource_pool: Optional[str] = None, use_vllm: bool = True, vllm_tensor_parallel_size: Optional[int] = None, vllm_swap_space: Optional[int] = None, hf_token: Optional[str] = None, torch_dtype: Optional[str] = None, low_cpu_mem_usage: bool = False, trust_remote_code: bool = False, is_arbitrary_hf_model: bool = False, disable_inference_config_recommendation: bool = False) ‑> bool
Expand source code
def load_model(
    self,
    model: Union[str, bt.Model],
    is_genai_base_model: bool = True,
    genai_model_version: Optional[int] = None,
    *,
    resource_pool: Optional[str] = None,
    use_vllm: bool = True,
    vllm_tensor_parallel_size: Optional[int] = None,
    vllm_swap_space: Optional[int] = None,
    hf_token: Optional[str] = None,
    torch_dtype: Optional[str] = None,
    low_cpu_mem_usage: bool = False,
    trust_remote_code: bool = False,
    is_arbitrary_hf_model: bool = False,
    disable_inference_config_recommendation: bool = False,
) -> bool:
    # Input validation
    if not use_vllm and vllm_tensor_parallel_size is not None:
        raise ValueError(
            "vllm_tensor_parallel_size should not be provided if use_vllm is False"
        )

    if vllm_tensor_parallel_size is not None and resource_pool is None:
        raise ValueError(
            "Please provide resource_pool when vllm_tensor_parallel_size is given. Otherwise, leave both fields blank for system recommendation."
        )

    if is_arbitrary_hf_model and is_genai_base_model:
        raise ValueError("is_arbitrary_hf_model and is_genai_base_model cannot both be True")

    if torch_dtype is not None and torch_dtype not in (
        const.BF16_TYPE_STR,
        const.FP16_TYPE_STR,
    ):
        raise ValueError(
            f"torch_dtype should be {const.BF16_TYPE_STR} or {const.FP16_TYPE_STR}. Note that support for {const.BF16_TYPE_STR} is GPU dependent."
        )

    # End input validation

    try_config_recommendation = self._try_recommend_vllm_config(
        use_vllm=use_vllm,
        vllm_tensor_parallel_size=vllm_tensor_parallel_size,
        is_arbitrary_hf_model=is_arbitrary_hf_model,
        disable_inference_config_recommendation=disable_inference_config_recommendation,
    )

    recommended_config_found = False
    if try_config_recommendation:
        logger.info("Getting recommended vllm inference config from server")
        try:
            recommended_config = self._get_recommended_vllm_configs(
                resource_pool=resource_pool, model=model, is_base_model=is_genai_base_model
            )
            logger.info(
                "Using recommended config to run inference; to turn off recommendation, set disable_inference_config_recommendation=True"
            )
            # We should only be returning one vllm config per resource pool
            inference_config = recommended_config.vllm_configs[0]
            slots = inference_config.slots_per_trial
            resource_pool = recommended_config.resource_pool.name
            if torch_dtype is not None and torch_dtype != inference_config.torch_dtype:
                logger.info(
                    f"torch_dtype {torch_dtype} will be overwritten by recommended config value {inference_config.torch_dtype}"
                )
            torch_dtype = inference_config.torch_dtype
            self.max_new_tokens = inference_config.max_new_tokens
            self.batch_size = inference_config.batch_size
            vllm_config = {
                "tensor-parallel-size": slots,
                "swap-space": inference_config.swap_space,
            }
            recommended_config_found = True
        except ex.InferenceConfigNotFoundException as e:
            if resource_pool:
                logger.warning(
                    f"No recommended inference config is found for model. Will use default number of GPU: {const.DEFAULT_SLOTS_PER_TRIAL} with provided resource_pool: {resource_pool}"
                )
            else:
                logger.warning(
                    f"No recommended inference config is found for model. Will use default number of GPU: {const.DEFAULT_SLOTS_PER_TRIAL} with default resource pool for the workspace"
                )

    if not recommended_config_found:
        logger.info("User provided / default model config is used.")
        # Reset max_new_tokens and batch_size
        self.max_new_tokens = None
        self.batch_size = None
        if vllm_tensor_parallel_size:
            slots = vllm_tensor_parallel_size
        else:
            slots = const.DEFAULT_SLOTS_PER_TRIAL

        vllm_config = {
            "tensor-parallel-size": slots,
            "swap-space": vllm_swap_space if vllm_swap_space else const.DEFAULT_SWAP_SPACE,
        }

    model_load_config: bt.ModelLoadConfig = bt.ModelLoadConfig(
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=low_cpu_mem_usage,
        trust_remote_code=trust_remote_code,
        token=hf_token,
        vllm_config=vllm_config,
    )

    if isinstance(model, str):
        params = {
            "name": model,
            "is_base_model": is_genai_base_model,
            "genai_model_version": genai_model_version,
            "resource_pool": resource_pool,
            "slots": slots,
            "use_vllm": use_vllm,
            "is_arbitrary_hf_model": is_arbitrary_hf_model,
            "client_key": self._client_key,
        }
        endpoint = f"{self._prefix}/workspace/{self.workspace.id}/inference/load_model"
    else:
        params = {
            "resource_pool": resource_pool,
            "slots": slots,
            "use_vllm": use_vllm,
            "client_key": self._client_key,
        }
        endpoint = (
            f"{self._prefix}/workspace/{self.workspace.id}/inference/load_model/{model.id}"
        )
    params = {k: v for k, v in params.items() if v is not None}
    return self._put(endpoint, params=params, in_data=model_load_config, sync=False)
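
Example usage (a minimal sketch; the model name and id are placeholders):

# Load a GenAI base model by name for inference, letting the server recommend
# the vLLM configuration.
client.load_model("gpt2", is_genai_base_model=True)

# Or load a fine-tuned model object, e.g. one returned by get_experiment_models().
tuned = client.get_model(42)                    # placeholder model id
client.load_model(tuned, is_genai_base_model=False)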
def login(self, credentials: Union[determined.common.api.authentication.Credentials, str, determined.common.api._session.Session, Tuple[str, str], determined.common.api.authentication.Authentication], token: Optional[str] = None, mlde_host: Optional[str] = None) ‑> str

Login to the lore server.

token: @deprecated

Expand source code
def login(
    self,
    credentials: SupportedCredentials,
    token: Optional[str] = None,
    mlde_host: Optional[str] = None,
) -> str:
    """Login to the lore server.

    token: @deprecated
    """
    if token is not None:
        # for backward compatibility.
        self.token = token
        return token
    token = None
    if isinstance(credentials, str):
        token = credentials
    elif isinstance(credentials, Credentials):
        if mlde_host is None:
            logger.warn(
                f"mlde_host is not provided for login. Falling back {get_det_master_address()}"
            )
            mlde_host = get_det_master_address()
        token = obtain_token(
            credentials.username, credentials.password, master_address=mlde_host
        )
    elif isinstance(credentials, tuple):
        if mlde_host is None:
            logger.warn(
                f"mlde_host is not provided for login. Falling back {get_det_master_address()}"
            )
            mlde_host = get_det_master_address()
        token = obtain_token(credentials[0], credentials[1], master_address=mlde_host)
    elif isinstance(credentials, Session):
        assert credentials._auth is not None, "Session must be authenticated."
        token = credentials._auth.get_session_token()
    elif isinstance(credentials, Authentication):
        token = credentials.get_session_token()
    else:
        raise ValueError(f"Unsupported credentials type: {type(credentials)}")
    self.token = token
    return token
def materialize_dataset(self, dataset: Dataset, output_path: Optional[str] = None) ‑> datasets.dataset_dict.DatasetDict
Expand source code
def materialize_dataset(
    self, dataset: bt.Dataset, output_path: Optional[str] = None
) -> hf.DatasetDict:
    if os.path.exists(dataset.storage_path):
        storage_path = dataset.storage_path
    else:
        storage_path = self._download_dataset(dataset, output_path)

    return hf.load_from_disk(storage_path)
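
Example usage (a minimal sketch; the dataset name is a placeholder):

# Materialize a registered dataset as a local Hugging Face DatasetDict.
dataset = client.get_dataset_by_name("imdb")    # placeholder dataset name
ds_dict = client.materialize_dataset(dataset)
print(ds_dict)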
def merge_and_resplit(self, dataset: Union[Dataset, str, int], train_ratio: float, validation_ratio: float, test_ratio: float, splits_to_resplit: List[str], as_dataset: bool = True, shuffle: bool = True, seed: int = 1234) ‑> Union[Dataset, Dict]
Expand source code
def merge_and_resplit(
    self,
    dataset: Union[bt.Dataset, str, int],
    train_ratio: float,
    validation_ratio: float,
    test_ratio: float,
    splits_to_resplit: List[str],
    as_dataset: bool = True,
    shuffle: bool = True,
    seed: int = 1234,
) -> Union[bt.Dataset, Dict]:
    dataset = self._get_dataset(dataset)
    data = bt.MergeAndResplitRequest(
        dataset=dataset,
        train_ratio=train_ratio,
        validation_ratio=validation_ratio,
        test_ratio=test_ratio,
        splits_to_resplit=splits_to_resplit,
        workspace_id=self.workspace.id,
        shuffle=shuffle,
        seed=seed,
    )
    out_cls = bt.Dataset if as_dataset else dict
    return self._post(
        f"{self._prefix}/dataset/merge_and_resplit",
        out_cls=out_cls,
        in_data=data,
        sync=False,
    )
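
Example usage (a minimal sketch; the dataset name and split names are placeholders):

# Merge the listed splits and re-split them 80/10/10.
dataset = client.get_dataset_by_name("imdb")    # placeholder dataset name
resplit = client.merge_and_resplit(
    dataset,
    train_ratio=0.8,
    validation_ratio=0.1,
    test_ratio=0.1,
    splits_to_resplit=["train", "test"],
)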
def register_model(self, model: Model) ‑> Model
Expand source code
def register_model(self, model: bt.Model) -> bt.Model:
    if model.workspace_id is None:
        model.workspace_id = self.workspace.id
    assert model.workspace_id is not None, f"Workspace id is required for model registration"
    return self._post(f"{self._prefix}/model/register", model, out_cls=bt.Model)
def restart_controller(self, controller_type: ControllerType) ‑> RestartControllerResponse
Expand source code
def restart_controller(
    self, controller_type: bte.ControllerType
) -> bt.RestartControllerResponse:
    if controller_type == bte.ControllerType.DATA:
        request = bt.RestartControllerRequest(
            controller_type=controller_type,
            workspace_id=self.workspace.id,
        )
    else:
        request = bt.RestartControllerRequest(
            controller_type=controller_type,
            workspace_id=self.workspace.id,
            client_key=self._client_key,
        )
    route = f"{self._prefix}/controller/restart"
    out_cls = bt.RestartControllerResponse
    return self._put(route, request, out_cls=out_cls)
def sample_dataset(self, dataset: Union[Dataset, str, int], start_index: Optional[int] = None, number_of_samples: Optional[int] = None, ratio: Optional[float] = None, splits: Optional[List[str]] = None, seed: int = 1234, as_dataset: bool = False) ‑> Union[Dataset, Dict]
Expand source code
def sample_dataset(
    self,
    dataset: Union[bt.Dataset, str, int],
    start_index: Optional[int] = None,
    number_of_samples: Optional[int] = None,
    ratio: Optional[float] = None,
    splits: Optional[List[str]] = None,
    seed: int = 1234,
    as_dataset: bool = False,
) -> Union[bt.Dataset, Dict]:
    dataset = self._get_dataset(dataset)
    request = bt.SampleDatasetRequest(
        dataset=dataset,
        start_index=start_index,
        number_of_samples=number_of_samples,
        ratio=ratio,
        splits=splits,
        workspace_id=self.workspace.id,
        seed=seed,
        as_dataset=as_dataset,
    )

    out_cls = bt.Dataset if as_dataset else None
    output = self._post(
        f"{self._prefix}/dataset/sample", in_data=request, out_cls=out_cls, sync=False
    )
    if isinstance(output, bt.Dataset):
        return output
    else:
        dataset_sample = dict(output)
        return {key: pd.read_json(sample) for key, sample in dataset_sample.items()}
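
Example usage (a minimal sketch; the dataset name and split name are placeholders). With as_dataset=False (the default), the result is a dict mapping split names to pandas DataFrames:

dataset = client.get_dataset_by_name("imdb")    # placeholder dataset name
preview = client.sample_dataset(dataset, number_of_samples=5, splits=["train"])
print(preview["train"].head())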
def set_workspace(self, workspace: Union[str, Workspace, int]) ‑> None
Expand source code
def set_workspace(self, workspace: Union[str, bt.Workspace, int]) -> None:
    if isinstance(workspace, str):
        workspace_name = workspace
        workspaces = self.get_workspaces()
        workspace = next((w for w in workspaces if w.name == workspace_name), None)
        assert isinstance(workspace, bt.Workspace), f"Workspace {workspace_name} not found."
        self.workspace = workspace
    elif isinstance(workspace, bt.Workspace):
        self.workspace = self.get_workspace(workspace.id)
    else:
        self.workspace = self.get_workspace(workspace)
    self.project = self._get(
        f"{self._prefix}/project/{self.workspace.experiment_project_id}", out_cls=bt.Project
    )
def update_playground_snapshot(self, request: UpdatePlaygroundSnaphotRequest) ‑> PlaygroundSnapshot
Expand source code
def update_playground_snapshot(
    self, request: bt.UpdatePlaygroundSnaphotRequest
) -> bt.PlaygroundSnapshot:
    return self._put(
        f"{self._prefix}/playground-snapshot/{request.id}",
        in_data=request,
        out_cls=bt.PlaygroundSnapshot,
    )
def update_prompt_template(self, prompt_template: PromptTemplate) ‑> PromptTemplate
Expand source code
def update_prompt_template(self, prompt_template: bt.PromptTemplate) -> bt.PromptTemplate:
    return self._put(
        f"{self._prefix}/prompt-template/{prompt_template.id}",
        in_data=prompt_template,
        out_cls=bt.PromptTemplate,
    )
def upload_model_to_hf(self, model: Model, hf_repo_owner: str, hf_repo_name: str, hf_token: str, private: bool = False) ‑> int
Expand source code
def upload_model_to_hf(
    self,
    model: bt.Model,
    hf_repo_owner: str,
    hf_repo_name: str,
    hf_token: str,
    private: bool = False,
) -> int:
    data = bt.UploadModelToHFRequest(
        model=model,
        hf_repo_owner=hf_repo_owner,
        hf_repo_name=hf_repo_name,
        hf_token=hf_token,
        private=private,
    )
    response: bt.UploadModelToHFResponse = self._post(
        f"{self._prefix}/model/upload_model_to_hf",
        in_data=data,
        out_cls=bt.UploadModelToHFResponse,
        sync=True,
    )
    return response.experiment_id
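
Example usage (a minimal sketch; the model id, repo owner, repo name, and token are placeholders):

# Push a trained model to the Hugging Face Hub.
model = client.get_model(42)                    # placeholder model id
upload_experiment_id = client.upload_model_to_hf(
    model,
    hf_repo_owner="my-org",
    hf_repo_name="my-finetuned-model",
    hf_token="hf_xxxxxxxx",
    private=True,
)

The request is made with sync=True and returns the id of the experiment that performs the upload.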