Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,7 @@ def init(
experiment: str | None = ...,
description: str | None = ...,
dataset: "Dataset | None" = ...,
_internal_btql: dict[str, Any] | None = ...,
parameters: RemoteEvalParameters | ParametersRef | None = ...,
open: Literal[False] = ...,
base_experiment: str | None = ...,
Expand All @@ -1560,6 +1561,7 @@ def init(
experiment: str | None = ...,
description: str | None = ...,
dataset: "Dataset | None" = ...,
_internal_btql: dict[str, Any] | None = ...,
parameters: RemoteEvalParameters | ParametersRef | None = ...,
open: Literal[True] = ...,
base_experiment: str | None = ...,
Expand All @@ -1584,6 +1586,7 @@ def init(
experiment: str | None = None,
description: str | None = None,
dataset: "Dataset | None | DatasetRef" = None,
_internal_btql: dict[str, Any] | None = None,
parameters: RemoteEvalParameters | ParametersRef | None = None,
open: bool = False,
base_experiment: str | None = None,
Expand Down Expand Up @@ -1719,6 +1722,12 @@ def compute_metadata():
args["dataset_id"] = dataset.id
args["dataset_version"] = dataset.version

dataset_filter = _internal_btql
if dataset_filter is None and isinstance(dataset, Dataset):
dataset_filter = dataset._internal_btql
if dataset_filter is not None:
args["internal_metadata"] = {"dataset_filter": dataset_filter}

parameters_ref = _get_parameters_ref(parameters)
if parameters_ref is not None:
args["parameters_id"] = parameters_ref["id"]
Expand Down Expand Up @@ -2876,6 +2885,38 @@ def __next__(self) -> T:
MAX_BTQL_ITERATIONS = 10000


def _is_internal_btql_filter_clause(value: Any) -> bool:
return isinstance(value, dict) and isinstance(value.get("op"), str)


def _normalize_internal_btql(
    internal_btql: dict[str, Any] | None,
) -> dict[str, Any] | None:
    """Fold a legacy ``filters`` list into the single ``filter`` key BTQL expects.

    Args:
        internal_btql: Optional BTQL fragment. It may carry a ``filters`` list
            of clause dicts (each ``{"op": <str>, ...}``) instead of the
            canonical single ``filter`` expression.

    Returns:
        The input unchanged when there is nothing to normalize (``None``, or no
        ``filters`` key) or when ``filters`` is malformed (not a list of clause
        dicts) — malformed input is passed through rather than guessed at.
        Otherwise a new dict without ``filters``: if a ``filter`` key already
        exists it takes precedence; a single clause becomes ``filter`` as-is;
        multiple clauses are joined under an ``{"op": "and"}`` node; an empty
        list simply drops the ``filters`` key.
    """
    if internal_btql is None or "filters" not in internal_btql:
        return internal_btql

    # Shallow copy without "filters"; an explicit "filter" wins if present.
    normalized = {key: value for key, value in internal_btql.items() if key != "filters"}
    if "filter" in normalized:
        return normalized

    filters = internal_btql.get("filters")
    if not isinstance(filters, list) or not all(_is_internal_btql_filter_clause(value) for value in filters):
        return internal_btql

    # Single exit: one clause is used directly, several are and-ed together,
    # and an empty list adds no "filter" at all.
    if len(filters) == 1:
        normalized["filter"] = filters[0]
    elif filters:
        normalized["filter"] = {
            "op": "and",
            "children": filters,
        }
    return normalized


class ObjectFetcher(ABC, Generic[TMapping]):
def __init__(
self,
Expand Down Expand Up @@ -2941,6 +2982,7 @@ def _refetch(self, batch_size: int | None = None) -> list[TMapping]:
cursor = None
data = None
iterations = 0
normalized_internal_btql = _normalize_internal_btql(self._internal_btql)
while True:
resp = state.api_conn().post(
f"btql",
Expand All @@ -2962,7 +3004,7 @@ def _refetch(self, batch_size: int | None = None) -> list[TMapping]:
},
"cursor": cursor,
"limit": limit,
**(self._internal_btql or {}),
**(normalized_internal_btql or {}),
},
"use_columnstore": False,
"brainstore_realtime": True,
Expand Down