diff --git a/flag_engine/segments/evaluator.py b/flag_engine/segments/evaluator.py index 0b8f8fc9..e717db45 100644 --- a/flag_engine/segments/evaluator.py +++ b/flag_engine/segments/evaluator.py @@ -31,7 +31,7 @@ is_context_value, ) from flag_engine.segments.utils import get_matching_function -from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids +from flag_engine.utils.hashing import get_hashed_percentage_for_object_id_pair from flag_engine.utils.semver import is_semver from flag_engine.utils.types import SupportsStr, get_casting_function @@ -58,8 +58,9 @@ def get_evaluation_result( :return: EvaluationResult containing the context, flags, and segments """ context = get_enriched_context(context) + identity_key = _get_identity_key(context) segments, segment_overrides = evaluate_segments(context) - flags = evaluate_features(context, segment_overrides) + flags = evaluate_features(context, segment_overrides, identity_key=identity_key) return { "flags": flags, @@ -138,26 +139,57 @@ def evaluate_segments( def evaluate_features( context: EvaluationContext[typing.Any, FeatureMetadataT], segment_overrides: SegmentOverrides[FeatureMetadataT], + *, + identity_key: typing.Optional[str] = None, ) -> dict[str, FlagResult[FeatureMetadataT]]: if not (features := context.get("features")): return {} + # ``identity_key`` is invariant across all features in a single evaluation. + # Resolving it here (or accepting it from the caller) means the per-feature + # hot loop below doesn't have to re-walk ``context["identity"]`` N times. + if identity_key is None: + identity_key = _get_identity_key(context) + + # Localise loop dependencies once so the tight per-feature loop doesn't + # chase module globals on every iteration. ``_build_flag_result`` is + # inlined below for environments with many features (e.g. 250+), where + # the function-call overhead is otherwise ~15% of per-call time. 
+ hash_fn = get_hashed_percentage_for_object_id_pair + overrides_get = segment_overrides.get + flags: dict[str, FlagResult[FeatureMetadataT]] = {} + for feature_name, feature_context in features.items(): + if segment_override := overrides_get(feature_name): + effective_feature_context = segment_override["feature_context"] + reason = f"TARGETING_MATCH; segment={segment_override['segment_name']}" + else: + effective_feature_context = feature_context + reason = "DEFAULT" - for feature_context in features.values(): - feature_name = feature_context["name"] - if segment_override := segment_overrides.get(feature_name): - flags[feature_name] = get_flag_result_from_context( - context=context, - feature_context=segment_override["feature_context"], - reason=f"TARGETING_MATCH; segment={segment_override['segment_name']}", - ) - continue - flags[feature_name] = get_flag_result_from_context( - context=context, - feature_context=context["features"][feature_name], - reason="DEFAULT", - ) + value: typing.Any = effective_feature_context["value"] + if identity_key is not None and ( + variants := effective_feature_context.get("variants") + ): + percentage_value = hash_fn(effective_feature_context["key"], identity_key) + start_percentage = 0.0 + for variant in sorted(variants, key=_variant_priority): + limit = (weight := variant["weight"]) + start_percentage + if start_percentage <= percentage_value < limit: + value = variant["value"] + reason = f"SPLIT; weight={weight}" + break + start_percentage = limit + + flag_result: FlagResult[FeatureMetadataT] = { + "enabled": effective_feature_context["enabled"], + "name": effective_feature_context["name"], + "reason": reason, + "value": value, + } + if metadata := effective_feature_context.get("metadata"): + flag_result["metadata"] = metadata + flags[feature_name] = flag_result return flags @@ -176,47 +208,39 @@ def get_flag_result_from_context( :param reason: reason to use when no variant selected :return: the value for the feature name in the
evaluation context """ - key = _get_identity_key(context) + identity_key = _get_identity_key(context) + value: typing.Any = feature_context["value"] - flag_result: typing.Optional[FlagResult[FeatureMetadataT]] = None - - if key is not None and (variants := feature_context.get("variants")): - percentage_value = get_hashed_percentage_for_object_ids( - [feature_context["key"], key] + if identity_key is not None and (variants := feature_context.get("variants")): + percentage_value = get_hashed_percentage_for_object_id_pair( + feature_context["key"], identity_key ) - start_percentage = 0.0 - - for variant in sorted( - variants, - key=operator.itemgetter("priority"), - ): + start_percentage = 0.0 + for variant in sorted(variants, key=_variant_priority): limit = (weight := variant["weight"]) + start_percentage if start_percentage <= percentage_value < limit: - flag_result = { - "enabled": feature_context["enabled"], - "name": feature_context["name"], - "reason": f"SPLIT; weight={weight}", - "value": variant["value"], - } + value = variant["value"] + reason = f"SPLIT; weight={weight}" break - start_percentage = limit - if flag_result is None: - flag_result = { - "enabled": feature_context["enabled"], - "name": feature_context["name"], - "reason": reason, - "value": feature_context["value"], - } - + flag_result: FlagResult[FeatureMetadataT] = { + "enabled": feature_context["enabled"], + "name": feature_context["name"], + "reason": reason, + "value": value, + } if metadata := feature_context.get("metadata"): flag_result["metadata"] = metadata - return flag_result +def _variant_priority(variant: typing.Mapping[str, typing.Any]) -> int: + priority: int = variant["priority"] + return priority + + def is_context_in_segment( context: _EvaluationContextAnyMeta, segment_context: SegmentContext[typing.Any, typing.Any], @@ -290,14 +314,14 @@ def context_matches_condition( if condition_operator == constants.PERCENTAGE_SPLIT: if context_value is None: return False - - object_ids = [segment_key, context_value] 
- try: float_value = float(condition["value"]) except ValueError: return False - return get_hashed_percentage_for_object_ids(object_ids) <= float_value + return ( + get_hashed_percentage_for_object_id_pair(segment_key, context_value) + <= float_value + ) if condition_operator == constants.IS_NOT_SET: return context_value is None diff --git a/flag_engine/utils/hashing.py b/flag_engine/utils/hashing.py index c4618e1e..321f63c9 100644 --- a/flag_engine/utils/hashing.py +++ b/flag_engine/utils/hashing.py @@ -31,3 +31,24 @@ def get_hashed_percentage_for_object_ids( ) return value + + +def get_hashed_percentage_for_object_id_pair( + first: SupportsStr, + second: SupportsStr, +) -> float: + """Fast path for the hot two-key case used by variant selection and + ``PERCENTAGE_SPLIT`` conditions. Skips the iterator / list wrapping that + the generic helper performs on every call. + + Returns the same value as + ``get_hashed_percentage_for_object_ids([first, second])``. + """ + to_hash = f"{first},{second}" + hashed_value = hashlib.md5(to_hash.encode("utf-8")) + hashed_value_as_int = int(hashed_value.hexdigest(), base=16) + value = ((hashed_value_as_int % 9999) / 9998) * 100 + if value == 100: + # Extremely unlikely; fall back to the generic recursion-capable path. 
+ return get_hashed_percentage_for_object_ids([first, second], iterations=2) + return value diff --git a/tests/engine_tests/engine-test-data b/tests/engine_tests/engine-test-data index 9307930e..c2b2f034 160000 --- a/tests/engine_tests/engine-test-data +++ b/tests/engine_tests/engine-test-data @@ -1 +1 @@ -Subproject commit 9307930e9e64482a35e7d6b254225addb6e44687 +Subproject commit c2b2f0347a52b4069a429c663ad3bbc53fed3eb6 diff --git a/tests/engine_tests/test_engine.py b/tests/engine_tests/test_engine.py index fb98f4ba..8ac8c166 100644 --- a/tests/engine_tests/test_engine.py +++ b/tests/engine_tests/test_engine.py @@ -11,6 +11,9 @@ from flag_engine.result.types import EvaluationResult TEST_CASES_PATH = Path(__file__).parent / "engine-test-data/test_cases" +LARGE_ENVIRONMENT_TEST_CASE = ( + "test_000000cf-0000-0000-0000-000000000000__large_environment.json" +) EnvironmentDocument = dict[str, typing.Any] @@ -43,11 +46,19 @@ def _extract_benchmark_contexts( yield pyjson5.loads((test_cases_dir_path / file_path).read_text())["context"] +def _load_test_case_context(name: str) -> EvaluationContext: + ctx: EvaluationContext = pyjson5.loads((TEST_CASES_PATH / name).read_text())[ + "context" + ] + return ctx + + TEST_CASES = sorted( _extract_test_cases(TEST_CASES_PATH), key=lambda param: str(param.id), ) BENCHMARK_CONTEXTS = list(_extract_benchmark_contexts(TEST_CASES_PATH)) +LARGE_BENCHMARK_CONTEXT = _load_test_case_context(LARGE_ENVIRONMENT_TEST_CASE) @pytest.mark.parametrize( @@ -69,3 +80,8 @@ def test_engine( def test_engine_benchmark() -> None: for context in BENCHMARK_CONTEXTS: get_evaluation_result(context) + + +@pytest.mark.benchmark +def test_engine_benchmark_large_context() -> None: + get_evaluation_result(LARGE_BENCHMARK_CONTEXT) diff --git a/tests/unit/segments/test_segments_evaluator.py b/tests/unit/segments/test_segments_evaluator.py index 844d4fa2..906f6cd2 100644 --- a/tests/unit/segments/test_segments_evaluator.py +++ 
b/tests/unit/segments/test_segments_evaluator.py @@ -265,7 +265,7 @@ def test_context_in_segment_percentage_split( } mock_get_hashed_percentage = mocker.patch( - "flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids" + "flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair" ) mock_get_hashed_percentage.return_value = identity_hashed_percentage @@ -308,7 +308,7 @@ def test_context_in_segment_percentage_split__no_identity__returns_expected( } mock_get_hashed_percentage = mocker.patch( - "flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids" + "flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair" ) # When @@ -352,7 +352,7 @@ def test_context_in_segment_percentage_split__trait_value__calls_expected( } mock_get_hashed_percentage = mocker.patch( - "flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids" + "flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair" ) mock_get_hashed_percentage.return_value = 1 @@ -361,7 +361,7 @@ def test_context_in_segment_percentage_split__trait_value__calls_expected( # Then mock_get_hashed_percentage.assert_called_once_with( - [segment_context["key"], "custom_value"] + segment_context["key"], "custom_value" ) assert result @@ -806,6 +806,7 @@ def test_segment_condition_matches_context_value_for_modulo( "name": "my_feature", "reason": "SPLIT; weight=30", "value": "foo", + "metadata": {"id": 7}, }, ), ( @@ -815,6 +816,7 @@ def test_segment_condition_matches_context_value_for_modulo( "name": "my_feature", "reason": "SPLIT; weight=30", "value": "bar", + "metadata": {"id": 7}, }, ), ( @@ -824,6 +826,7 @@ def test_segment_condition_matches_context_value_for_modulo( "name": "my_feature", "reason": "DEFAULT", "value": "control", + "metadata": {"id": 7}, }, ), ), @@ -841,7 +844,7 @@ def test_get_flag_result_from_context__calls_returns_expected( # we mock the function which gets the percentage value for an identity to # return a deterministic value so 
we know which value to expect get_hashed_percentage_for_object_ids_mock = mocker.patch( - "flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids", + "flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair", ) get_hashed_percentage_for_object_ids_mock.return_value = percentage_value @@ -851,6 +854,7 @@ def test_get_flag_result_from_context__calls_returns_expected( "enabled": False, "name": "my_feature", "value": "control", + "metadata": {"id": 7}, "variants": [ {"value": "foo", "weight": 30, "priority": 1}, {"value": "bar", "weight": 30, "priority": 2}, @@ -870,10 +874,8 @@ def test_get_flag_result_from_context__calls_returns_expected( # the function is called with the expected key get_hashed_percentage_for_object_ids_mock.assert_called_once_with( - [ - expected_feature_context_key, - expected_key, - ] + expected_feature_context_key, + expected_key, ) @@ -885,7 +887,7 @@ def test_get_flag_result_from_feature_context__null_key__calls_returns_expected( expected_feature_context_key = "2" get_hashed_percentage_for_object_ids_mock = mocker.patch( - "flag_engine.segments.evaluator.get_hashed_percentage_for_object_ids", + "flag_engine.segments.evaluator.get_hashed_percentage_for_object_id_pair", ) feature_context: FeatureContext = { diff --git a/tests/unit/utils/test_utils_hashing.py b/tests/unit/utils/test_utils_hashing.py index 2b8621f9..a96ac63c 100644 --- a/tests/unit/utils/test_utils_hashing.py +++ b/tests/unit/utils/test_utils_hashing.py @@ -5,7 +5,10 @@ import pytest -from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids +from flag_engine.utils.hashing import ( + get_hashed_percentage_for_object_id_pair, + get_hashed_percentage_for_object_ids, +) @pytest.mark.parametrize( @@ -146,3 +149,31 @@ def hexdigest_side_effect() -> str: # the second call, with a string (in bytes) that contains each object id twice expected_bytes_2 = ",".join(str(id_) for id_ in object_ids * 2).encode("utf-8") assert call_list[1][0][0] == 
expected_bytes_2 + + +@mock.patch("flag_engine.utils.hashing.hashlib") +def test_get_hashed_percentage_for_object_id_pair__value_is_100__falls_back( + mock_hashlib: mock.Mock, +) -> None: + """When the two-key fast path would return exactly 100, it must fall back + to the generic helper with iterations=2 (same anti-boundary guarantee as + ``get_hashed_percentage_for_object_ids``).""" + + # 270e converts to 9998, forcing value == 100. 270f → 9999 → value == 0. + hashed_values = ["270f", "270e"] + + def hexdigest_side_effect() -> str: + return hashed_values.pop() + + mock_hash = mock.MagicMock() + mock_hashlib.md5.return_value = mock_hash + mock_hash.hexdigest.side_effect = hexdigest_side_effect + + value = get_hashed_percentage_for_object_id_pair("12", "93") + + assert value == 0 + # First call: fast-path two-key hash (single pair); second: recursive fallback. + call_list = mock_hashlib.md5.call_args_list + assert len(call_list) == 2 + assert call_list[0][0][0] == b"12,93" + assert call_list[1][0][0] == b"12,93,12,93"