Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 71 additions & 48 deletions flag_engine/segments/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
is_context_value,
)
from flag_engine.segments.utils import get_matching_function
from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids
from flag_engine.utils.hashing import get_hashed_percentage_for_object_id_pair
from flag_engine.utils.semver import is_semver
from flag_engine.utils.types import SupportsStr, get_casting_function

Expand All @@ -58,8 +58,9 @@ def get_evaluation_result(
:return: EvaluationResult containing the context, flags, and segments
"""
context = get_enriched_context(context)
identity_key = _get_identity_key(context)
segments, segment_overrides = evaluate_segments(context)
flags = evaluate_features(context, segment_overrides)
flags = evaluate_features(context, segment_overrides, identity_key=identity_key)

return {
"flags": flags,
Expand Down Expand Up @@ -138,26 +139,57 @@ def evaluate_segments(
def evaluate_features(
context: EvaluationContext[typing.Any, FeatureMetadataT],
segment_overrides: SegmentOverrides[FeatureMetadataT],
*,
identity_key: typing.Optional[str] = None,
) -> dict[str, FlagResult[FeatureMetadataT]]:
if not (features := context.get("features")):
return {}

# ``identity_key`` is invariant across all features in a single evaluation.
# Resolving it here (or accepting it from the caller) means the per-feature
# hot loop below doesn't have to re-walk ``context["identity"]`` N times.
if identity_key is None:
identity_key = _get_identity_key(context)

# Localise loop dependencies once so the tight per-feature loop doesn't
# chase module globals on every iteration. ``_build_flag_result`` is
# inlined below for environments with many features (e.g. 250+), where
# the function-call overhead is otherwise ~15% of per-call time.
hash_fn = get_hashed_percentage_for_object_id_pair
overrides_get = segment_overrides.get

flags: dict[str, FlagResult[FeatureMetadataT]] = {}
for feature_name, feature_context in features.items():
if segment_override := overrides_get(feature_name):
effective_feature_context = segment_override["feature_context"]
reason = f"TARGETING_MATCH; segment={segment_override['segment_name']}"
else:
effective_feature_context = feature_context
reason = "DEFAULT"

for feature_context in features.values():
feature_name = feature_context["name"]
if segment_override := segment_overrides.get(feature_name):
flags[feature_name] = get_flag_result_from_context(
context=context,
feature_context=segment_override["feature_context"],
reason=f"TARGETING_MATCH; segment={segment_override['segment_name']}",
)
continue
flags[feature_name] = get_flag_result_from_context(
context=context,
feature_context=context["features"][feature_name],
reason="DEFAULT",
)
value: typing.Any = effective_feature_context["value"]
if identity_key is not None and (
variants := effective_feature_context.get("variants")
):
percentage_value = hash_fn(effective_feature_context["key"], identity_key)
start_percentage = 0.0
for variant in sorted(variants, key=_variant_priority):
limit = (weight := variant["weight"]) + start_percentage
if start_percentage <= percentage_value < limit:
value = variant["value"]
reason = f"SPLIT; weight={weight}"
break
start_percentage = limit

flag_result: FlagResult[FeatureMetadataT] = {
"enabled": effective_feature_context["enabled"],
"name": effective_feature_context["name"],
"reason": reason,
"value": value,
}
if metadata := effective_feature_context.get("metadata"):
flag_result["metadata"] = metadata
flags[feature_name] = flag_result

return flags

Expand All @@ -176,47 +208,38 @@ def get_flag_result_from_context(
:param reason: reason to use when no variant selected
:return: the value for the feature name in the evaluation context
"""
key = _get_identity_key(context)
identity_key = _get_identity_key(context)
value: typing.Any = feature_context["value"]

flag_result: typing.Optional[FlagResult[FeatureMetadataT]] = None

if key is not None and (variants := feature_context.get("variants")):
percentage_value = get_hashed_percentage_for_object_ids(
[feature_context["key"], key]
if identity_key is not None and (variants := feature_context.get("variants")):
percentage_value = get_hashed_percentage_for_object_id_pair(
feature_context["key"], identity_key
)

start_percentage = 0.0

for variant in sorted(
variants,
key=operator.itemgetter("priority"),
):
for variant in sorted(variants, key=_variant_priority):
limit = (weight := variant["weight"]) + start_percentage
if start_percentage <= percentage_value < limit:
flag_result = {
"enabled": feature_context["enabled"],
"name": feature_context["name"],
"reason": f"SPLIT; weight={weight}",
"value": variant["value"],
}
value = variant["value"]
reason = f"SPLIT; weight={weight}"
break

start_percentage = limit

if flag_result is None:
flag_result = {
"enabled": feature_context["enabled"],
"name": feature_context["name"],
"reason": reason,
"value": feature_context["value"],
}

flag_result: FlagResult[FeatureMetadataT] = {
"enabled": feature_context["enabled"],
"name": feature_context["name"],
"reason": reason,
"value": value,
}
if metadata := feature_context.get("metadata"):
flag_result["metadata"] = metadata

return flag_result


def _variant_priority(variant: typing.Mapping[str, typing.Any]) -> int:
priority: int = variant["priority"]
return priority


def is_context_in_segment(
context: _EvaluationContextAnyMeta,
segment_context: SegmentContext[typing.Any, typing.Any],
Expand Down Expand Up @@ -290,14 +313,14 @@ def context_matches_condition(
if condition_operator == constants.PERCENTAGE_SPLIT:
if context_value is None:
return False

object_ids = [segment_key, context_value]

try:
float_value = float(condition["value"])
except ValueError:
return False
return get_hashed_percentage_for_object_ids(object_ids) <= float_value
return (
get_hashed_percentage_for_object_id_pair(segment_key, context_value)
<= float_value
)

if condition_operator == constants.IS_NOT_SET:
return context_value is None
Expand Down
21 changes: 21 additions & 0 deletions flag_engine/utils/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,24 @@ def get_hashed_percentage_for_object_ids(
)

return value


def get_hashed_percentage_for_object_id_pair(
    first: SupportsStr,
    second: SupportsStr,
) -> float:
    """Two-key fast path for variant selection and ``PERCENTAGE_SPLIT``
    conditions: avoids the per-call list/iterator wrapping done by the
    generic helper.

    Equivalent to ``get_hashed_percentage_for_object_ids([first, second])``.
    """
    digest = hashlib.md5(f"{first},{second}".encode("utf-8")).hexdigest()
    percentage = ((int(digest, base=16) % 9999) / 9998) * 100
    if percentage == 100:
        # Boundary case (md5 % 9999 == 9998) — practically unreachable;
        # delegate to the generic helper, which re-hashes with iterations=2.
        return get_hashed_percentage_for_object_ids([first, second], iterations=2)
    return percentage
16 changes: 16 additions & 0 deletions tests/engine_tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from flag_engine.result.types import EvaluationResult

TEST_CASES_PATH = Path(__file__).parent / "engine-test-data/test_cases"
LARGE_ENVIRONMENT_TEST_CASE = (
"test_000000cf-0000-0000-0000-000000000000__large_environment.json"
)

EnvironmentDocument = dict[str, typing.Any]

Expand Down Expand Up @@ -43,11 +46,19 @@ def _extract_benchmark_contexts(
yield pyjson5.loads((test_cases_dir_path / file_path).read_text())["context"]


def _load_test_case_context(name: str) -> EvaluationContext:
    """Read the named engine test-case file and return its ``context`` member."""
    raw_document = (TEST_CASES_PATH / name).read_text()
    context: EvaluationContext = pyjson5.loads(raw_document)["context"]
    return context


TEST_CASES = sorted(
_extract_test_cases(TEST_CASES_PATH),
key=lambda param: str(param.id),
)
BENCHMARK_CONTEXTS = list(_extract_benchmark_contexts(TEST_CASES_PATH))
LARGE_BENCHMARK_CONTEXT = _load_test_case_context(LARGE_ENVIRONMENT_TEST_CASE)


@pytest.mark.parametrize(
Expand All @@ -69,3 +80,8 @@ def test_engine(
def test_engine_benchmark() -> None:
    # Benchmark: run a full evaluation over every bundled benchmark context.
    # BENCHMARK_CONTEXTS is materialised once at module import, so only the
    # evaluation itself happens inside this function.
    for context in BENCHMARK_CONTEXTS:
        get_evaluation_result(context)


@pytest.mark.benchmark
def test_engine_benchmark_large_context() -> None:
    # Benchmark a single evaluation of the large-environment test case;
    # the context is loaded once at module import (LARGE_BENCHMARK_CONTEXT)
    # to keep file I/O out of the measured body.
    get_evaluation_result(LARGE_BENCHMARK_CONTEXT)
22 changes: 12 additions & 10 deletions tests/unit/segments/test_segments_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_context_in_segment_percentage_split(
}

mock_get_hashed_percentage = mocker.patch(
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids"
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair"
)
mock_get_hashed_percentage.return_value = identity_hashed_percentage

Expand Down Expand Up @@ -308,7 +308,7 @@ def test_context_in_segment_percentage_split__no_identity__returns_expected(
}

mock_get_hashed_percentage = mocker.patch(
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids"
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair"
)

# When
Expand Down Expand Up @@ -352,7 +352,7 @@ def test_context_in_segment_percentage_split__trait_value__calls_expected(
}

mock_get_hashed_percentage = mocker.patch(
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids"
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair"
)
mock_get_hashed_percentage.return_value = 1

Expand All @@ -361,7 +361,7 @@ def test_context_in_segment_percentage_split__trait_value__calls_expected(

# Then
mock_get_hashed_percentage.assert_called_once_with(
[segment_context["key"], "custom_value"]
segment_context["key"], "custom_value"
)
assert result

Expand Down Expand Up @@ -806,6 +806,7 @@ def test_segment_condition_matches_context_value_for_modulo(
"name": "my_feature",
"reason": "SPLIT; weight=30",
"value": "foo",
"metadata": {"id": 7},
},
),
(
Expand All @@ -815,6 +816,7 @@ def test_segment_condition_matches_context_value_for_modulo(
"name": "my_feature",
"reason": "SPLIT; weight=30",
"value": "bar",
"metadata": {"id": 7},
},
),
(
Expand All @@ -824,6 +826,7 @@ def test_segment_condition_matches_context_value_for_modulo(
"name": "my_feature",
"reason": "DEFAULT",
"value": "control",
"metadata": {"id": 7},
},
),
),
Expand All @@ -841,7 +844,7 @@ def test_get_flag_result_from_context__calls_returns_expected(
# we mock the function which gets the percentage value for an identity to
# return a deterministic value so we know which value to expect
get_hashed_percentage_for_object_ids_mock = mocker.patch(
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids",
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair",
)
get_hashed_percentage_for_object_ids_mock.return_value = percentage_value

Expand All @@ -851,6 +854,7 @@ def test_get_flag_result_from_context__calls_returns_expected(
"enabled": False,
"name": "my_feature",
"value": "control",
"metadata": {"id": 7},
"variants": [
{"value": "foo", "weight": 30, "priority": 1},
{"value": "bar", "weight": 30, "priority": 2},
Expand All @@ -870,10 +874,8 @@ def test_get_flag_result_from_context__calls_returns_expected(

# the function is called with the expected key
get_hashed_percentage_for_object_ids_mock.assert_called_once_with(
[
expected_feature_context_key,
expected_key,
]
expected_feature_context_key,
expected_key,
)


Expand All @@ -885,7 +887,7 @@ def test_get_flag_result_from_feature_context__null_key__calls_returns_expected(
expected_feature_context_key = "2"

get_hashed_percentage_for_object_ids_mock = mocker.patch(
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids",
"flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair",
)

feature_context: FeatureContext = {
Expand Down
33 changes: 32 additions & 1 deletion tests/unit/utils/test_utils_hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

import pytest

from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids
from flag_engine.utils.hashing import (
get_hashed_percentage_for_object_id_pair,
get_hashed_percentage_for_object_ids,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -146,3 +149,31 @@ def hexdigest_side_effect() -> str:
# the second call, with a string (in bytes) that contains each object id twice
expected_bytes_2 = ",".join(str(id_) for id_ in object_ids * 2).encode("utf-8")
assert call_list[1][0][0] == expected_bytes_2


@mock.patch("flag_engine.utils.hashing.hashlib")
def test_get_hashed_percentage_for_object_id_pair__value_is_100__falls_back(
    mock_hashlib: mock.Mock,
) -> None:
    """If the two-key fast path lands on exactly 100, it must defer to the
    generic helper with iterations=2 — the same boundary guarantee that
    ``get_hashed_percentage_for_object_ids`` provides."""

    # First digest 0x270e == 9998 drives the fast path to exactly 100;
    # the fallback's digest 0x270f == 9999 then yields a value of 0.
    digests = iter(["270e", "270f"])

    md5_result = mock.MagicMock()
    md5_result.hexdigest.side_effect = lambda: next(digests)
    mock_hashlib.md5.return_value = md5_result

    value = get_hashed_percentage_for_object_id_pair("12", "93")

    assert value == 0
    md5_calls = mock_hashlib.md5.call_args_list
    # First call: the fast-path pair hash; second: the recursive fallback,
    # which hashes the pair repeated twice.
    assert len(md5_calls) == 2
    assert md5_calls[0][0][0] == b"12,93"
    assert md5_calls[1][0][0] == b"12,93,12,93"
Loading