diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs index 4293d54bc30..6cc340b8609 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs @@ -225,7 +225,7 @@ private async ValueTask WritePayloadAsync(string key, CacheItem cacheItem, Buffe byte[] oversized = ArrayPool.Shared.Rent(maxLength); int length = HybridCachePayload.Write(oversized, key, cacheItem.CreationTimestamp, GetL2AbsoluteExpirationRelativeToNow(options), - HybridCachePayload.PayloadFlags.None, cacheItem.Tags, payload.AsSequence()); + HybridCachePayload.PayloadFlags.None, cacheItem.Tags, payload.AsSequence(), localCacheSize: options?.LocalSize); await SetDirectL2Async(key, new(oversized, 0, length, true), GetL2DistributedCacheOptions(options), token).ConfigureAwait(false); diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs index c48e01969bf..a1230e1a929 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs @@ -13,6 +13,7 @@ internal sealed partial class MutableCacheItem : CacheItem // used to hold private IHybridCacheSerializer? _serializer; private BufferChunk _buffer; private T? _fallbackValue; // only used in the case of serialization failures + private long? _localSizeOverride; public MutableCacheItem(long creationTimestamp, TagSet tags) : base(creationTimestamp, tags) @@ -66,7 +67,7 @@ public override bool TryGetSize(out long size) // only if we haven't already burned if (TryReserve()) { - size = _buffer.Length; + size = _localSizeOverride ?? _buffer.Length; _ = Release(); return true; } @@ -75,6 +76,11 @@ public override bool TryGetSize(out long size) return false; } + public void SetLocalSizeOverride(long size) + { + _localSizeOverride = size; + } + public override bool TryReserveBuffer(out BufferChunk buffer) { // only if we haven't already burned diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs index 34a68ef30aa..63047832ffa 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs @@ -18,12 +18,13 @@ internal partial class DefaultHybridCache internal sealed class StampedeState : StampedeState { // note on terminology: L1 and L2 are, for brevity, used interchangeably with "local" and "distributed" cache, i.e. `IMemoryCache` and `IDistributedCache` - private const HybridCacheEntryFlags FlagsDisableL1AndL2Write = HybridCacheEntryFlags.DisableLocalCacheWrite | HybridCacheEntryFlags.DisableDistributedCacheWrite; private readonly TaskCompletionSource>? _result; private TState? _state; - private Func>? _underlying; // main data factory + private Func>? _factory; // main data factory + private Func>? _factoryWithOptions; // options-aware data factory private HybridCacheEntryOptions? _options; + private int _factoryOptionsRevision; // initial Revision of the options instance passed to the factory; used to detect mutations private Task? _sharedUnwrap; // allows multiple non-cancellable callers to share a single task (when no defensive copy needed) // ONLY set the result, without any other side-effects @@ -44,14 +45,14 @@ public StampedeState(DefaultHybridCache cache, in StampedeKey key, TagSet tags, public override Type Type => typeof(T); - public void QueueUserWorkItem(in TState state, Func> underlying, HybridCacheEntryOptions? options) + public void QueueUserWorkItem(in TState state, Func> factory, HybridCacheEntryOptions? options) { - Debug.Assert(_underlying is null, "should not already have factory field"); - Debug.Assert(underlying is not null, "factory argument should be meaningful"); + Debug.Assert(_factory is null, "should not already have factory field"); + Debug.Assert(factory is not null, "factory argument should be meaningful"); // initialize the callback state _state = state; - _underlying = underlying; + _factory = factory; _options = options; #if NETCOREAPP3_0_OR_GREATER @@ -61,16 +62,49 @@ public void QueueUserWorkItem(in TState state, Func> factory, HybridCacheEntryOptions options) + { + Debug.Assert(_factoryWithOptions is null, "should not already have factory field"); + Debug.Assert(factory is not null, "factory argument should be meaningful"); + + // initialize the callback state + _state = state; + _factoryWithOptions = factory; + _options = options; + _factoryOptionsRevision = _options.Revision; + +#if NETCOREAPP3_0_OR_GREATER + ThreadPool.UnsafeQueueUserWorkItem(this, false); +#else + ThreadPool.UnsafeQueueUserWorkItem(SharedWaitCallback, this); +#endif + } + + [SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Cancellation is handled separately via SharedToken")] + public Task ExecuteDirectAsync(in TState state, Func> factory, HybridCacheEntryOptions? options) + { + Debug.Assert(_factory is null, "should not already have factory field"); + Debug.Assert(factory is not null, "factory argument should be meaningful"); + + // initialize the callback state + _state = state; + _factory = factory; + _options = options; + + return BackgroundFetchAsync(); + } + [SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Cancellation is handled separately via SharedToken")] - public Task ExecuteDirectAsync(in TState state, Func> underlying, HybridCacheEntryOptions? options) + public Task ExecuteDirectAsync(in TState state, Func> factory, HybridCacheEntryOptions options) { - Debug.Assert(_underlying is null, "should not already have factory field"); - Debug.Assert(underlying is not null, "factory argument should be meaningful"); + Debug.Assert(_factoryWithOptions is null, "should not already have factory field"); + Debug.Assert(factory is not null, "factory argument should be meaningful"); // initialize the callback state _state = state; - _underlying = underlying; + _factoryWithOptions = factory; _options = options; + _factoryOptionsRevision = _options.Revision; return BackgroundFetchAsync(); } @@ -164,12 +198,18 @@ private async Task BackgroundFetchAsync() try { HybridCacheEntryFlags activeFlags = Key.Flags; + + // Track write-side flags that are required regardless of + // anything the user-supplied options or the factory say. + HybridCacheEntryFlags mandatoryWriteSideFlags = Cache._hardFlags & WriteSideFlags; + if ((activeFlags & HybridCacheEntryFlags.DisableDistributedCache) != HybridCacheEntryFlags.DisableDistributedCache) { // in order to use distributed cache, the tags and keys must be valid unicode, to avoid security complications if (!ValidateUnicodeCorrectness(Cache._logger, Key.Key, CacheItem.Tags)) { activeFlags |= HybridCacheEntryFlags.DisableDistributedCache; + mandatoryWriteSideFlags |= HybridCacheEntryFlags.DisableDistributedCacheWrite; } } @@ -225,7 +265,7 @@ private async Task BackgroundFetchAsync() // result is the wider payload including HC headers; unwrap it: HybridCachePayload.HybridCachePayloadParseResult parseResult = HybridCachePayload.TryParse( result.AsArraySegment(), Key.Key, CacheItem.Tags, Cache, out ArraySegment payload, out TimeSpan remainingTime, - out HybridCachePayload.PayloadFlags flags, out ushort entropy, out TagSet pendingTags, out Exception? fault); + out HybridCachePayload.PayloadFlags flags, out ushort entropy, out TagSet pendingTags, out long? payloadLocalSize, out Exception? fault); switch (parseResult) { case HybridCachePayload.HybridCachePayloadParseResult.Success: @@ -234,7 +274,7 @@ private async Task BackgroundFetchAsync() { // move into the payload segment (minus any framing/header/etc data) result = new(payload.Array!, payload.Offset, payload.Count, result.ReturnToPool); - SetResultAndRecycleIfAppropriate(ref result, remainingTime); + SetResultAndRecycleIfAppropriate(ref result, remainingTime, payloadLocalSize); return; } @@ -265,7 +305,14 @@ private async Task BackgroundFetchAsync() HybridCacheEventSource.Log.UnderlyingDataQueryStart(); } - newValue = await _underlying!(_state!, SharedToken).ConfigureAwait(false); + if (_factoryWithOptions is not null) + { + newValue = await _factoryWithOptions(_state!, _options!, SharedToken).ConfigureAwait(false); + } + else + { + newValue = await _factory!(_state!, SharedToken).ConfigureAwait(false); + } if (eventSourceEnabled) { @@ -289,6 +336,13 @@ private async Task BackgroundFetchAsync() throw; } + // honor any mutations the factory made to the options it received; we use Revision + // as a fast-path skip for the common case where the factory didn't touch them. + if (_factoryWithOptions is not null && _options!.Revision != _factoryOptionsRevision) + { + ApplyFactoryOptions(_options, mandatoryWriteSideFlags, ref activeFlags); + } + // check whether we're going to hit a timing problem with tag invalidation if (!Cache.IsValid(CacheItem)) { @@ -318,20 +372,26 @@ private async Task BackgroundFetchAsync() CacheItem.UnsafeSetCreationTimestamp(time); } - // If we're writing this value *anywhere*, we're going to need to serialize; this is obvious - // in the case of L2, but we also need it for L1, because MemoryCache might be enforcing - // SizeLimit (we can't know - it is an abstraction), and for *that* we need to know the item size. - // Likewise, if we're writing to a MutableCacheItem, we'll be serializing *anyway* for the payload. - // - // Rephrasing that: the only scenario in which we *do not* need to serialize is if: - // - it is an ImmutableCacheItem (so we don't need bytes for the CacheItem, L1) - // - we're not writing to L2 + // Serialization is required in any of these cases: + // - the CacheItem is mutable (the serialized bytes are the L1 and stampede storage, since each + // read deserializes a defensive copy, and they double as the L2 payload); + // - the CacheItem is immutable but we are writing to L2 (serialization produces + // the payload); + // - the CacheItem is immutable, we are writing to L1, and the entry size is not + // known up-front (MemoryCache may enforce SizeLimit, and we report the + // serialized byte length as the size). When the caller or factory has supplied + // LocalSize via the entry options the size is already known, so this last + // case does not apply. + CacheItem cacheItem = CacheItem; - bool skipSerialize = cacheItem is ImmutableCacheItem && (activeFlags & FlagsDisableL1AndL2Write) == FlagsDisableL1AndL2Write; + long? knownLocalSize = ResolveLocalSize(); + bool skipSerialize = cacheItem is ImmutableCacheItem + && (activeFlags & HybridCacheEntryFlags.DisableDistributedCacheWrite) != 0 + && ((activeFlags & HybridCacheEntryFlags.DisableLocalCacheWrite) != 0 || knownLocalSize is not null); if (skipSerialize) { - SetImmutableResultWithoutSerialize(newValue); + SetImmutableResultWithoutSerialize(newValue, activeFlags); } else if (cacheItem.TryReserve()) { @@ -354,7 +414,7 @@ private async Task BackgroundFetchAsync() buffer = buffer.DoNotReturnToPool(); // set the underlying result for this operation (includes L1 write if appropriate) - SetResultPreSerialized(newValue, ref bufferToRelease, serializer); + SetResultPreSerialized(newValue, ref bufferToRelease, serializer, activeFlags); // Note that at this point we've already released most or all of the waiting callers. Everything // from this point onwards happens in the background, from the perspective of the calling code. @@ -385,7 +445,7 @@ private async Task BackgroundFetchAsync() { // unable to serialize (or quota exceeded); try to at least store the onwards value; this is // especially useful for immutable data types - SetResultPreSerialized(newValue, ref bufferToRelease, serializer); + SetResultPreSerialized(newValue, ref bufferToRelease, serializer, activeFlags); } // Release our hook on the CacheItem (only really important for "mutable"). @@ -435,18 +495,20 @@ private void SetDefaultResult() } } - private void SetResultAndRecycleIfAppropriate(ref BufferChunk value, TimeSpan remainingTime) + private void SetResultAndRecycleIfAppropriate(ref BufferChunk value, TimeSpan remainingTime, long? payloadLocalSize = null) { // set a result from L2 cache Debug.Assert(value.OversizedArray is not null, "expected buffer"); + // precedence: size persisted in the payload > caller-provided options.LocalSize > buffer.Length + long? localSizeOverride = payloadLocalSize ?? ResolveLocalSize(); IHybridCacheSerializer serializer = Cache.GetSerializer(); CacheItem cacheItem; switch (CacheItem) { case ImmutableCacheItem immutable: // deserialize; and store object; buffer can be recycled now - immutable.SetValue(serializer.Deserialize(new(value.OversizedArray!, value.Offset, value.Length)), value.Length); + immutable.SetValue(serializer.Deserialize(new(value.OversizedArray!, value.Offset, value.Length)), localSizeOverride ?? value.Length); value.RecycleIfAppropriate(); cacheItem = immutable; break; @@ -454,6 +516,11 @@ private void SetResultAndRecycleIfAppropriate(ref BufferChunk value, TimeSpan re // use the buffer directly as the backing in the cache-item; do *not* recycle now mutable.SetValue(ref value, serializer); mutable.DebugOnlyTrackBuffer(Cache); + if (localSizeOverride is { } mutableSize) + { + mutable.SetLocalSizeOverride(mutableSize); + } + cacheItem = mutable; break; default: @@ -461,12 +528,23 @@ private void SetResultAndRecycleIfAppropriate(ref BufferChunk value, TimeSpan re break; } - SetResult(cacheItem, remainingTime); + SetResult(cacheItem, Key.Flags, remainingTime); + } + + private long? ResolveLocalSize() + { + // factory mutations are applied directly to _options (which is a clone). null means "use + // implementation default", per the LocalSize API contract — which we honor by falling + // back to HybridCacheOptions.DefaultEntryOptions.LocalSize (also nullable). + return _options?.LocalSize ?? Cache._defaultLocalSize; } - private void SetImmutableResultWithoutSerialize(T value) + private void SetImmutableResultWithoutSerialize(T value, HybridCacheEntryFlags activeFlags) { - Debug.Assert((Key.Flags & FlagsDisableL1AndL2Write) == FlagsDisableL1AndL2Write, "Only expected if L1+L2 disabled"); + Debug.Assert( + (activeFlags & HybridCacheEntryFlags.DisableDistributedCacheWrite) != 0 + && ((activeFlags & HybridCacheEntryFlags.DisableLocalCacheWrite) != 0 || ResolveLocalSize() is not null), + "Only expected if L2 is disabled and either L1 is disabled or LocalSize is known."); // set a result from a value we calculated directly CacheItem cacheItem; @@ -474,7 +552,7 @@ private void SetImmutableResultWithoutSerialize(T value) { case ImmutableCacheItem immutable: // no serialize needed - immutable.SetValue(value, size: -1); + immutable.SetValue(value, size: ResolveLocalSize() ?? -1); cacheItem = immutable; break; default: @@ -482,10 +560,10 @@ private void SetImmutableResultWithoutSerialize(T value) break; } - SetResult(cacheItem); + SetResult(cacheItem, activeFlags); } - private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCacheSerializer? serializer) + private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCacheSerializer? serializer, HybridCacheEntryFlags activeFlags) { // set a result from a value we calculated directly that // has ALREADY BEEN SERIALIZED (we can optionally consume this buffer) @@ -494,7 +572,7 @@ private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCach { case ImmutableCacheItem immutable: // no serialize needed - immutable.SetValue(value, size: buffer.Length); + immutable.SetValue(value, size: ResolveLocalSize() ?? buffer.Length); cacheItem = immutable; // (but leave the buffer alone) @@ -511,6 +589,11 @@ private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCach mutable.DebugOnlyTrackBuffer(Cache); } + if (ResolveLocalSize() is { } mutableSize) + { + mutable.SetLocalSizeOverride(mutableSize); + } + cacheItem = mutable; break; default: @@ -518,14 +601,14 @@ private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCach break; } - SetResult(cacheItem); + SetResult(cacheItem, activeFlags); } - private void SetResult(CacheItem value) => SetResult(value, TimeSpan.MaxValue); + private void SetResult(CacheItem value, HybridCacheEntryFlags activeFlags) => SetResult(value, activeFlags, TimeSpan.MaxValue); - private void SetResult(CacheItem value, TimeSpan maxRelativeTime) + private void SetResult(CacheItem value, HybridCacheEntryFlags activeFlags, TimeSpan maxRelativeTime) { - if ((Key.Flags & HybridCacheEntryFlags.DisableLocalCacheWrite) == 0) + if ((activeFlags & HybridCacheEntryFlags.DisableLocalCacheWrite) == 0) { Cache.SetL1(Key.Key, value, _options, maxRelativeTime); // we can do this without a TCS, for SetValue } @@ -536,6 +619,29 @@ private void SetResult(CacheItem value, TimeSpan maxRelativeTime) _ = _result.TrySetResult(value); } } + + private const HybridCacheEntryFlags WriteSideFlags = + HybridCacheEntryFlags.DisableLocalCacheWrite | HybridCacheEntryFlags.DisableDistributedCacheWrite | HybridCacheEntryFlags.DisableCompression; + + /// + /// Applies factory mutations to the active flags after the factory callback has executed. + /// Only write-side flags are honored; read-side flags are ignored (reads already happened). + /// The factory's write-side flags fully replace the user-supplied write-side flags, but + /// are always preserved. Expiration / LocalCacheExpiration / LocalSize mutations need no + /// action here because the factory wrote directly to , which is read + /// by SetL1 / SetL2Async / ResolveLocalSize. LocalSize is validated via + /// since this is the only point at which factory mutations are observed. + /// + private static void ApplyFactoryOptions( + HybridCacheEntryOptions factoryOptions, + HybridCacheEntryFlags mandatoryWriteSideFlags, + ref HybridCacheEntryFlags activeFlags) + { + ValidateOptions(factoryOptions); + + HybridCacheEntryFlags factoryFlags = factoryOptions.Flags ?? HybridCacheEntryFlags.None; + activeFlags = (activeFlags & ~WriteSideFlags) | (factoryFlags & WriteSideFlags) | mandatoryWriteSideFlags; + } } private static bool ValidateUnicodeCorrectness(ILogger logger, string key, TagSet tags) diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.TagInvalidation.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.TagInvalidation.cs index ef5b7f1a01a..92f9aeb4309 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.TagInvalidation.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.TagInvalidation.cs @@ -21,7 +21,7 @@ internal partial class DefaultHybridCache private Task _globalInvalidateTimestamp; - public override ValueTask RemoveByTagAsync(string tag, CancellationToken token = default) + public override ValueTask RemoveByTagAsync(string tag, CancellationToken cancellationToken = default) { if (string.IsNullOrWhiteSpace(tag)) { @@ -30,7 +30,7 @@ public override ValueTask RemoveByTagAsync(string tag, CancellationToken token = long now = CurrentTimestamp(); InvalidateTagLocalCore(tag, now, isNow: true); // isNow to be 100% explicit - return InvalidateL2TagAsync(tag, now, token); + return InvalidateL2TagAsync(tag, now, cancellationToken); } public bool IsValid(CacheItem cacheItem) diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs index 93e1e5457cb..6102d05e899 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs @@ -26,13 +26,13 @@ internal sealed partial class DefaultHybridCache : HybridCache { internal const int DefaultExpirationMinutes = 5; - [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] + [SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] private readonly IDistributedCache? _backendCache; - [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] + [SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] private readonly IMemoryCache _localCache; private readonly IServiceProvider _services; // we can't resolve per-type serializers until we see each T private readonly IHybridCacheSerializerFactory[] _serializerFactories; - [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] + [SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] private readonly HybridCacheOptions _options; private readonly ILogger _logger; private readonly CacheFeatures _features; // used to avoid constant type-testing @@ -42,6 +42,7 @@ internal sealed partial class DefaultHybridCache : HybridCache private readonly HybridCacheEntryFlags _defaultFlags; // note this already includes hardFlags private readonly TimeSpan _defaultExpiration; private readonly TimeSpan _defaultLocalCacheExpiration; + private readonly long? _defaultLocalSize; private readonly int _maximumKeyLength; private readonly DistributedCacheEntryOptions _defaultDistributedCacheExpiration; @@ -111,6 +112,7 @@ public DefaultHybridCache(HybridCacheOptions options, IServiceProvider services) _maximumKeyLength = _options.MaximumKeyLength; HybridCacheEntryOptions? defaultEntryOptions = _options.DefaultEntryOptions; + ValidateOptions(defaultEntryOptions); if (_backendCache is null) { @@ -120,6 +122,7 @@ public DefaultHybridCache(HybridCacheOptions options, IServiceProvider services) _defaultFlags = (defaultEntryOptions?.Flags ?? HybridCacheEntryFlags.None) | _hardFlags; _defaultExpiration = defaultEntryOptions?.Expiration ?? TimeSpan.FromMinutes(DefaultExpirationMinutes); _defaultLocalCacheExpiration = GetEffectiveLocalCacheExpiration(defaultEntryOptions) ?? _defaultExpiration; + _defaultLocalSize = defaultEntryOptions?.LocalSize; _defaultDistributedCacheExpiration = new DistributedCacheEntryOptions { AbsoluteExpirationRelativeToNow = _defaultExpiration }; #if NET9_0_OR_GREATER @@ -135,86 +138,192 @@ public DefaultHybridCache(HybridCacheOptions options, IServiceProvider services) internal HybridCacheOptions Options => _options; - public override ValueTask GetOrCreateAsync(string key, TState state, Func> underlyingDataCallback, + public override ValueTask GetOrCreateAsync(string key, TState state, Func> factory, HybridCacheEntryOptions? options = null, IEnumerable? tags = null, CancellationToken cancellationToken = default) { - bool canBeCanceled = cancellationToken.CanBeCanceled; + GetOrCreateOutcome outcome = TryBeginGetOrCreate(key, options, tags, cancellationToken, + out HybridCacheEntryFlags flags, out bool canBeCanceled, out StampedeState? stampede, out ValueTask result); + + switch (outcome) + { + case GetOrCreateOutcome.NoCache: + return (flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0 + ? factory(state, cancellationToken) : default; + case GetOrCreateOutcome.L1Hit: + case GetOrCreateOutcome.JoinedStampede: + return result; + } + + // GetOrCreateOutcome.NewStampede: we own this stampede and must dispatch the factory + if (canBeCanceled) + { + // *we* might cancel, but someone else might be depending on the result; start the + // work independently, then join the outcome + stampede!.QueueUserWorkItem(in state, factory, options); + return stampede.JoinAsync(_logger, cancellationToken); + } + + // we're going to run to completion; no need to get complicated + _ = stampede!.ExecuteDirectAsync(in state, factory, options); // this larger task includes L2 write etc + + return stampede.UnwrapReservedAsync(_logger); + } + + public override ValueTask GetOrCreateAsync(string key, TState state, + Func> factory, + HybridCacheEntryOptions? options = null, IEnumerable? tags = null, CancellationToken cancellationToken = default) + { + GetOrCreateOutcome outcome = TryBeginGetOrCreate(key, options, tags, cancellationToken, + out HybridCacheEntryFlags flags, out bool canBeCanceled, out StampedeState? stampede, out ValueTask result); + + // We always pass a clone (or fresh instance) to the factory so + // it can mutate the options without affecting the caller's instance. + // In the NoCache case, there's no cache to honor the mutations against, but the factory may rely on + // being able to mutate the parameter without surprising the caller. + + switch (outcome) + { + case GetOrCreateOutcome.NoCache: + return (flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0 + ? factory(state, CloneOptionsOrNew(options), cancellationToken) : default; + case GetOrCreateOutcome.L1Hit: + case GetOrCreateOutcome.JoinedStampede: + return result; + } + + options = CloneOptionsOrNew(options); + + if (canBeCanceled) + { + stampede!.QueueUserWorkItem(in state, factory, options); + return stampede.JoinAsync(_logger, cancellationToken); + } + + _ = stampede!.ExecuteDirectAsync(in state, factory, options); + + return stampede.UnwrapReservedAsync(_logger); + } + + private enum GetOrCreateOutcome + { + NoCache, + L1Hit, + JoinedStampede, + NewStampede + } + + // Performs the shared pre-factory work for both GetOrCreateAsync overloads. + private GetOrCreateOutcome TryBeginGetOrCreate( + string key, + HybridCacheEntryOptions? options, + IEnumerable? tags, + CancellationToken cancellationToken, + out HybridCacheEntryFlags flags, + out bool canBeCanceled, + out StampedeState? stampede, + out ValueTask result) + { + ValidateOptions(options); + + canBeCanceled = cancellationToken.CanBeCanceled; if (canBeCanceled) { cancellationToken.ThrowIfCancellationRequested(); } - HybridCacheEntryFlags flags = GetEffectiveFlags(options); + flags = GetEffectiveFlags(options); + stampede = null; + result = default; + if (!ValidateKey(key)) { - // we can't use cache, but we can still provide the data - return RunWithoutCacheAsync(flags, state, underlyingDataCallback, cancellationToken); + return GetOrCreateOutcome.NoCache; } bool eventSourceEnabled = HybridCacheEventSource.Log.IsEnabled(); if ((flags & HybridCacheEntryFlags.DisableLocalCacheRead) == 0) { - if (TryGetExisting(key, out CacheItem? typed) + if (TryGetExisting(key, out CacheItem? typed) && typed.TryGetValue(_logger, out T? value)) { - // short-circuit if (eventSourceEnabled) { HybridCacheEventSource.Log.LocalCacheHit(); } - return new(value); + result = new(value); + return GetOrCreateOutcome.L1Hit; } - else + + if (eventSourceEnabled) { - if (eventSourceEnabled) - { - HybridCacheEventSource.Log.LocalCacheMiss(); - } + HybridCacheEventSource.Log.LocalCacheMiss(); } } - if (GetOrCreateStampedeState(key, flags, out StampedeState? stampede, canBeCanceled, tags)) + if (GetOrCreateStampedeState(key, flags, out stampede, canBeCanceled, tags)) { - // new query; we're responsible for making it happen - if (canBeCanceled) - { - // *we* might cancel, but someone else might be depending on the result; start the - // work independently, then we'll with join the outcome - stampede.QueueUserWorkItem(in state, underlyingDataCallback, options); - } - else - { - // we're going to run to completion; no need to get complicated - _ = stampede.ExecuteDirectAsync(in state, underlyingDataCallback, options); // this larger task includes L2 write etc - return stampede.UnwrapReservedAsync(_logger); - } + return GetOrCreateOutcome.NewStampede; } - else + + // joined a pre-existing stampede + if (eventSourceEnabled) { - // pre-existing query - if (eventSourceEnabled) - { - HybridCacheEventSource.Log.StampedeJoin(); - } + HybridCacheEventSource.Log.StampedeJoin(); } - return stampede.JoinAsync(_logger, cancellationToken); + result = stampede.JoinAsync(_logger, cancellationToken); + return GetOrCreateOutcome.JoinedStampede; } - public override ValueTask RemoveAsync(string key, CancellationToken token = default) + public override ValueTask RemoveAsync(string key, CancellationToken cancellationToken = default) { _localCache.Remove(key); - return _backendCache is null ? default : new(_backendCache.RemoveAsync(key, token)); + return _backendCache is null ? default : new(_backendCache.RemoveAsync(key, cancellationToken)); + } + + private static void ValidateOptions(HybridCacheEntryOptions? options) + { + if (options?.LocalSize is { } size && size < 0) + { + Throw.ArgumentException(nameof(options), + $"{nameof(HybridCacheEntryOptions)}.{nameof(HybridCacheEntryOptions.LocalSize)} must be non-negative."); + } } - public override ValueTask SetAsync(string key, T value, HybridCacheEntryOptions? options = null, IEnumerable? tags = null, CancellationToken token = default) + internal static HybridCacheEntryOptions CloneOptionsOrNew(HybridCacheEntryOptions? options) { + if (options is null) + { + return new HybridCacheEntryOptions(); + } + +#if NET8_0_OR_GREATER + return Clone(options); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = nameof(Clone))] + extern static HybridCacheEntryOptions Clone(HybridCacheEntryOptions options); +#else + // Down-level TFMs cannot reach the internal Clone(); copy by hand. + return new HybridCacheEntryOptions + { + Expiration = options.Expiration, + LocalCacheExpiration = options.LocalCacheExpiration, + Flags = options.Flags, + LocalSize = options.LocalSize, + }; +#endif + } + + public override ValueTask SetAsync(string key, T value, HybridCacheEntryOptions? options = null, IEnumerable? tags = null, CancellationToken cancellationToken = default) + { + ValidateOptions(options); + // since we're forcing a write: disable L1+L2 read; we'll use a direct pass-thru of the value as the callback, to reuse all the code // note also that stampede token is not shared with anyone else HybridCacheEntryFlags flags = GetEffectiveFlags(options) | (HybridCacheEntryFlags.DisableLocalCacheRead | HybridCacheEntryFlags.DisableDistributedCacheRead); - var state = new StampedeState(this, new StampedeKey(key, flags), TagSet.Create(tags), token); + var state = new StampedeState(this, new StampedeKey(key, flags), TagSet.Create(tags), cancellationToken); return new(state.ExecuteDirectAsync(value, static (state, _) => new(state), options)); // note this spans L2 write etc } @@ -228,14 +337,6 @@ internal TimeSpan GetL1AbsoluteExpirationRelativeToNow(HybridCacheEntryOptions? internal HybridCacheEntryFlags GetEffectiveFlags(HybridCacheEntryOptions? options) => (options?.Flags | _hardFlags) ?? _defaultFlags; - private static ValueTask RunWithoutCacheAsync(HybridCacheEntryFlags flags, TState state, - Func> underlyingDataCallback, - CancellationToken cancellationToken) - { - return (flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0 - ? underlyingDataCallback(state, cancellationToken) : default; - } - private static TimeSpan? GetEffectiveLocalCacheExpiration(HybridCacheEntryOptions? options) { // If LocalCacheExpiration is not specified, then use option's Expiration, to keep in sync by default. diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCachePayload.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCachePayload.cs index 9199a5cfd15..50d4269f161 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCachePayload.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCachePayload.cs @@ -11,29 +11,40 @@ namespace Microsoft.Extensions.Caching.Hybrid.Internal; // logic related to the payload that we send to IDistributedCache internal static class HybridCachePayload { - // FORMAT (v1): - // fixed-size header (so that it can be reliably broadcast) + // FORMAT (v2): + // fixed-size header (so that it can be reliably broadcast) // 2 bytes: sentinel+version - // 2 bytes: entropy (this is a random, and is to help with multi-node collisions at the same time) + // 2 bytes: entropy (random, to help with multi-node collisions at the same time) // 8 bytes: creation time (UTC ticks, little-endian) // and the dynamic part // varint: flags (little-endian) // varint: payload size // varint: duration (ticks relative to creation time) + // (when PayloadFlags.HasLocalSize is set): varint: local cache size // varint: tag count // varint+utf8: key // (for each tag): varint+utf8: tagN // (payload-size bytes): payload // 2 bytes: sentinel+version (repeated, for reliability) // (at this point, all bytes *must* be exhausted, or it is treated as failure) - - // the encoding for varint etc is akin to BinaryWriter, also comparable to FormatterBinaryWriter in OutputCaching + // + // FORMAT (v1): identical to v2 but predates the optional local-size varint. + // Writers currently emit v1 when no LocalSize override is + // being persisted, so existing L2 entries remain readable across upgrades and v1-only readers + // can still read newer writers' output in the common case. v2 is emitted when LocalSize is + // being persisted; older readers reject v2 (FormatNotRecognized -> cache miss). + // + // The expectation is that readers will support reading all old formats + // (at least as long as they stay easily backwards compatible), while writers will write a couple + // of the latest formats, to make switching versions easier, while not accumulating too much cruft in the code. private const int MaxVarint64Length = 10; private const byte SentinelPrefix = 0x03; - private const byte ProtocolVersion = 0x01; - private const ushort UInt16SentinelPrefixPair = (ProtocolVersion << 8) | SentinelPrefix; + private const byte ProtocolVersion1 = 0x01; + private const byte ProtocolVersion2 = 0x02; + private const ushort UInt16SentinelPrefixPairV1 = (ProtocolVersion1 << 8) | SentinelPrefix; + private const ushort UInt16SentinelPrefixPairV2 = (ProtocolVersion2 << 8) | SentinelPrefix; private static readonly Random _entropySource = new(); // doesn't need to be cryptographic @@ -42,6 +53,10 @@ internal static class HybridCachePayload internal enum PayloadFlags : uint { None = 0, + + // When set, the dynamic part carries an additional varint encoding a per-entry + // local cache size override. + HasLocalSize = 1, } internal enum HybridCachePayloadParseResult @@ -67,6 +82,7 @@ public static int GetMaxBytes(string key, TagSet tags, int payloadSize) + MaxVarint64Length // flags + MaxVarint64Length // payload size + MaxVarint64Length // duration + + MaxVarint64Length // optional local cache size + MaxVarint64Length // tag count + 2 // trailing sentinel + version + GetMaxStringLength(key.Length) // key @@ -99,11 +115,35 @@ static int GetMaxStringLength(int charCount) => [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA5394:Do not use insecure randomness", Justification = "Not cryptographic")] public static int Write(byte[] destination, - string key, long creationTime, TimeSpan duration, PayloadFlags flags, TagSet tags, ReadOnlySequence payload) + string key, long creationTime, TimeSpan duration, PayloadFlags flags, TagSet tags, ReadOnlySequence payload, + long? localCacheSize = null) { int payloadLength = checked((int)payload.Length); - BinaryPrimitives.WriteUInt16LittleEndian(destination.AsSpan(0, 2), UInt16SentinelPrefixPair); + // a negative localCacheSize is the "reset to default" sentinel - treat it as absent. + if (localCacheSize is < 0) + { + localCacheSize = null; + } + + // v2 is used if we are persisting a LocalSize override; otherwise stay on v1 so that + // upgrades don't invalidate every existing L2 entry and older readers can still read + // entries written by newer writers when no override is used. + ushort sentinel; + if (localCacheSize is null) + { + sentinel = UInt16SentinelPrefixPairV1; + + // defensive: never persist a stale HasLocalSize bit without an accompanying value + flags &= ~PayloadFlags.HasLocalSize; + } + else + { + sentinel = UInt16SentinelPrefixPairV2; + flags |= PayloadFlags.HasLocalSize; + } + + BinaryPrimitives.WriteUInt16LittleEndian(destination.AsSpan(0, 2), sentinel); BinaryPrimitives.WriteUInt16LittleEndian(destination.AsSpan(2, 2), (ushort)_entropySource.Next(0, 0x010000)); // Next is exclusive at RHS BinaryPrimitives.WriteInt64LittleEndian(destination.AsSpan(4, 8), creationTime); int len = 12; @@ -117,6 +157,11 @@ public static int Write(byte[] destination, Write7BitEncodedInt64(destination, ref len, (uint)flags); Write7BitEncodedInt64(destination, ref len, (ulong)payloadLength); Write7BitEncodedInt64(destination, ref len, (ulong)durationTicks); + if (localCacheSize is { } size) + { + Write7BitEncodedInt64(destination, ref len, (ulong)size); + } + Write7BitEncodedInt64(destination, ref len, (ulong)tags.Count); WriteString(destination, ref len, key); switch (tags.Count) @@ -137,7 +182,7 @@ public static int Write(byte[] destination, payload.CopyTo(destination.AsSpan(len, payloadLength)); len += payloadLength; - BinaryPrimitives.WriteUInt16LittleEndian(destination.AsSpan(len, 2), UInt16SentinelPrefixPair); + BinaryPrimitives.WriteUInt16LittleEndian(destination.AsSpan(len, 2), sentinel); return len + 2; static void Write7BitEncodedInt64(byte[] target, ref int offset, ulong value) @@ -171,7 +216,7 @@ static void WriteString(byte[] target, ref int offset, string value) "SA1122:Use string.Empty for empty strings", Justification = "Subjective, but; ugly")] [System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.OrderingRules", "SA1204:Static elements should appear before instance elements", Justification = "False positive?")] public static HybridCachePayloadParseResult TryParse(ArraySegment source, string key, TagSet knownTags, DefaultHybridCache cache, - out ArraySegment payload, out TimeSpan remainingTime, out PayloadFlags flags, out ushort entropy, out TagSet pendingTags, out Exception? fault) + out ArraySegment payload, out TimeSpan remainingTime, out PayloadFlags flags, out ushort entropy, out TagSet pendingTags, out long? localCacheSize, out Exception? fault) { fault = null; @@ -180,6 +225,7 @@ public static HybridCachePayloadParseResult TryParse(ArraySegment source, payload = default; flags = 0; remainingTime = TimeSpan.Zero; + localCacheSize = null; string[] pendingTagBuffer = []; int pendingTagsCount = 0; @@ -194,128 +240,142 @@ public static HybridCachePayloadParseResult TryParse(ArraySegment source, char[] scratch = []; try { - switch (BinaryPrimitives.ReadUInt16LittleEndian(bytes)) + ushort sentinel = BinaryPrimitives.ReadUInt16LittleEndian(bytes); + switch (sentinel) { - case UInt16SentinelPrefixPair: - entropy = BinaryPrimitives.ReadUInt16LittleEndian(bytes.Slice(2)); - long creationTime = BinaryPrimitives.ReadInt64LittleEndian(bytes.Slice(4)); - bytes = bytes.Slice(12); // the end of the fixed part + case UInt16SentinelPrefixPairV1: + case UInt16SentinelPrefixPairV2: + break; + default: + return HybridCachePayloadParseResult.FormatNotRecognized; + } - if (cache.IsWildcardExpired(creationTime)) - { - return HybridCachePayloadParseResult.ExpiredByWildcard; - } + entropy = BinaryPrimitives.ReadUInt16LittleEndian(bytes.Slice(2)); + long creationTime = BinaryPrimitives.ReadInt64LittleEndian(bytes.Slice(4)); + bytes = bytes.Slice(12); // the end of the fixed part - if (!TryRead7BitEncodedInt64(ref bytes, out ulong u64)) // flags - { - return HybridCachePayloadParseResult.InvalidData; - } + if (cache.IsWildcardExpired(creationTime)) + { + return HybridCachePayloadParseResult.ExpiredByWildcard; + } - flags = (PayloadFlags)u64; + if (!TryRead7BitEncodedInt64(ref bytes, out ulong u64)) // flags + { + return HybridCachePayloadParseResult.InvalidData; + } - if (!TryRead7BitEncodedInt64(ref bytes, out u64) || u64 > int.MaxValue) // payload length - { - return HybridCachePayloadParseResult.InvalidData; - } + flags = (PayloadFlags)u64; - int payloadLength = (int)u64; + if (!TryRead7BitEncodedInt64(ref bytes, out u64) || u64 > int.MaxValue) // payload length + { + return HybridCachePayloadParseResult.InvalidData; + } - if (!TryRead7BitEncodedInt64(ref bytes, out ulong duration)) // duration - { - return HybridCachePayloadParseResult.InvalidData; - } + int payloadLength = (int)u64; - var remainingTicks = (creationTime + (long)duration) - now; - if (remainingTicks <= 0) - { - return HybridCachePayloadParseResult.ExpiredByEntry; - } + if (!TryRead7BitEncodedInt64(ref bytes, out ulong duration)) // duration + { + return HybridCachePayloadParseResult.InvalidData; + } - remainingTime = DefaultHybridCache.TicksToTimeSpan(remainingTicks); + var remainingTicks = (creationTime + (long)duration) - now; + if (remainingTicks <= 0) + { + return HybridCachePayloadParseResult.ExpiredByEntry; + } - if (!TryRead7BitEncodedInt64(ref bytes, out u64) || u64 > int.MaxValue) // tag count - { - return HybridCachePayloadParseResult.InvalidData; - } + remainingTime = DefaultHybridCache.TicksToTimeSpan(remainingTicks); - int tagCount = (int)u64; + if ((flags & PayloadFlags.HasLocalSize) != 0) + { + if (!TryRead7BitEncodedInt64(ref bytes, out ulong sizeU64) || sizeU64 > long.MaxValue) + { + return HybridCachePayloadParseResult.InvalidData; + } - if (!TryReadString(ref bytes, ref scratch, out ReadOnlySpan stringSpan)) - { - return HybridCachePayloadParseResult.InvalidData; - } + localCacheSize = (long)sizeU64; + } - if (!stringSpan.SequenceEqual(key.AsSpan())) - { - return HybridCachePayloadParseResult.InvalidKey; // key must match! - } + if (!TryRead7BitEncodedInt64(ref bytes, out u64) || u64 > int.MaxValue) // tag count + { + return HybridCachePayloadParseResult.InvalidData; + } - for (int i = 0; i < tagCount; i++) - { - if (!TryReadString(ref bytes, ref scratch, out stringSpan)) - { - return HybridCachePayloadParseResult.InvalidData; - } - - bool isTagExpired; - bool isPending; - if (knownTags.TryFind(stringSpan, out string? tagString)) - { - // prefer to re-use existing tag strings when they exist - isTagExpired = cache.IsTagExpired(tagString, creationTime, out isPending); - } - else - { - // if an unknown tag; we might need to juggle - isTagExpired = cache.IsTagExpired(stringSpan, creationTime, out isPending); - } - - if (isPending) - { - // might be expired, but the operation is still in-flight - if (pendingTagsCount == pendingTagBuffer.Length) - { - string[] newBuffer = ArrayPool.Shared.Rent(Math.Max(4, pendingTagsCount * 2)); - pendingTagBuffer.CopyTo(newBuffer, 0); - ArrayPool.Shared.Return(pendingTagBuffer); - pendingTagBuffer = newBuffer; - } - - pendingTagBuffer[pendingTagsCount++] = tagString ?? stringSpan.ToString(); - } - else if (isTagExpired) - { - // definitely an expired tag - return HybridCachePayloadParseResult.ExpiredByTag; - } - } + int tagCount = (int)u64; - if (bytes.Length != payloadLength + 2 - || BinaryPrimitives.ReadUInt16LittleEndian(bytes.Slice(payloadLength)) != UInt16SentinelPrefixPair) - { - return HybridCachePayloadParseResult.InvalidData; - } + if (!TryReadString(ref bytes, ref scratch, out ReadOnlySpan stringSpan)) + { + return HybridCachePayloadParseResult.InvalidData; + } - int start = source.Offset + source.Count - (payloadLength + 2); - payload = new(source.Array!, start, payloadLength); + if (!stringSpan.SequenceEqual(key.AsSpan())) + { + return HybridCachePayloadParseResult.InvalidKey; // key must match! + } - // finalize the pending tag buffer (in-flight tag expirations) - switch (pendingTagsCount) + for (int i = 0; i < tagCount; i++) + { + if (!TryReadString(ref bytes, ref scratch, out stringSpan)) + { + return HybridCachePayloadParseResult.InvalidData; + } + + bool isTagExpired; + bool isPending; + if (knownTags.TryFind(stringSpan, out string? tagString)) + { + // prefer to re-use existing tag strings when they exist + isTagExpired = cache.IsTagExpired(tagString, creationTime, out isPending); + } + else + { + // if an unknown tag; we might need to juggle + isTagExpired = cache.IsTagExpired(stringSpan, creationTime, out isPending); + } + + if (isPending) + { + // might be expired, but the operation is still in-flight + if (pendingTagsCount == pendingTagBuffer.Length) { - case 0: - break; - case 1: - pendingTags = new(pendingTagBuffer[0]); - break; - default: - pendingTags = new(pendingTagBuffer.AsSpan(0, pendingTagsCount).ToArray()); - break; + string[] newBuffer = ArrayPool.Shared.Rent(Math.Max(4, pendingTagsCount * 2)); + pendingTagBuffer.CopyTo(newBuffer, 0); + ArrayPool.Shared.Return(pendingTagBuffer); + pendingTagBuffer = newBuffer; } - return HybridCachePayloadParseResult.Success; + pendingTagBuffer[pendingTagsCount++] = tagString ?? stringSpan.ToString(); + } + else if (isTagExpired) + { + // definitely an expired tag + return HybridCachePayloadParseResult.ExpiredByTag; + } + } + + if (bytes.Length != payloadLength + 2 + || BinaryPrimitives.ReadUInt16LittleEndian(bytes.Slice(payloadLength)) != sentinel) + { + return HybridCachePayloadParseResult.InvalidData; + } + + int start = source.Offset + source.Count - (payloadLength + 2); + payload = new(source.Array!, start, payloadLength); + + // finalize the pending tag buffer (in-flight tag expirations) + switch (pendingTagsCount) + { + case 0: + break; + case 1: + pendingTags = new(pendingTagBuffer[0]); + break; default: - return HybridCachePayloadParseResult.FormatNotRecognized; + pendingTags = new(pendingTagBuffer.AsSpan(0, pendingTagsCount).ToArray()); + break; } + + return HybridCachePayloadParseResult.Success; } catch (Exception ex) { @@ -385,7 +445,11 @@ static bool TryRead7BitEncodedInt64(ref ReadOnlySpan buffer, out ulong res int index = 0; for (int shift = 0; shift < MaxBytesWithoutOverflow * 7; shift += 7) { - // ReadByte handles end of stream cases for us. + if (index >= buffer.Length) + { + return false; // truncated + } + byteReadJustNow = buffer[index++]; result |= (byteReadJustNow & 0x7Ful) << shift; @@ -400,10 +464,15 @@ static bool TryRead7BitEncodedInt64(ref ReadOnlySpan buffer, out ulong res // the value of this byte must fit within 1 bit (64 - 63), // and it must not have the high bit set. + if (index >= buffer.Length) + { + return false; // truncated + } + byteReadJustNow = buffer[index++]; if (byteReadJustNow > 0b_1u) { - throw new OverflowException(); + return false; } result |= (ulong)byteReadJustNow << (MaxBytesWithoutOverflow * 7); diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/FactoryOptionsTests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/FactoryOptionsTests.cs new file mode 100644 index 00000000000..12a4aca23ad --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/FactoryOptionsTests.cs @@ -0,0 +1,554 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Reflection; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Hybrid.Internal; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Internal; +using Xunit.Abstractions; +using static Microsoft.Extensions.Caching.Hybrid.Tests.DistributedCacheTests; +using static Microsoft.Extensions.Caching.Hybrid.Tests.L2Tests; + +namespace Microsoft.Extensions.Caching.Hybrid.Tests; + +// Covers mutations the factory makes to the HybridCacheEntryOptions it is handed. +public class FactoryOptionsTests(ITestOutputHelper log) : IClassFixture +{ + private static ServiceProvider GetDefaultCache(out DefaultHybridCache cache, Action? config = null) + { + var services = new ServiceCollection(); + config?.Invoke(services); + services.AddHybridCache(); + ServiceProvider provider = services.BuildServiceProvider(); + cache = Assert.IsType(provider.GetRequiredService()); + return provider; + } + + private static ServiceProvider BuildCacheWithL2(ITestOutputHelper log, out DefaultHybridCache cache, out CapturingCache localCache) + { + var captured = new CapturingCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))); + localCache = captured; + return GetDefaultCache(out cache, services => services.AddSingleton(new LoggingCache(log, captured))); + } + + private static async Task WaitForBackgroundL2WriteAsync(CapturingCache cache, string key) + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + + try + { + await cache.WaitForWriteAsync(key, cts.Token); + } + catch (OperationCanceledException) when (cts.IsCancellationRequested) + { + Assert.Fail($"Timed out waiting for background L2 write for key '{key}'."); + } + } + + [Fact] + public async Task FactoryCanReEnableL2Write_ThatCallerDisabled() + { + // Caller disabled L2 writes; factory clears the flag — value must be persisted to L2. + using var provider = BuildCacheWithL2(log, out var cache, out var localCache); + string key = nameof(FactoryCanReEnableL2Write_ThatCallerDisabled); + + _ = await cache.GetOrCreateAsync( + key, + (entryOptions, _) => + { + entryOptions.Flags = HybridCacheEntryFlags.None; + return new ValueTask(Guid.NewGuid()); + }, + options: new HybridCacheEntryOptions + { + Expiration = TimeSpan.FromMinutes(1), + Flags = HybridCacheEntryFlags.DisableDistributedCacheWrite, + }); + + await WaitForBackgroundL2WriteAsync(localCache, key); + Assert.NotNull(localCache.Get(key)); + } + + [Fact] + public async Task FactoryCanDisableL2Write_ThatCallerEnabled() + { + // Symmetric tightening: caller allowed L2 writes (None), factory disables them. + using var provider = BuildCacheWithL2(log, out var cache, out var localCache); + string key = nameof(FactoryCanDisableL2Write_ThatCallerEnabled); + + _ = await cache.GetOrCreateAsync( + key, + (entryOptions, _) => + { + entryOptions.Flags = HybridCacheEntryFlags.DisableDistributedCacheWrite; + return new ValueTask(Guid.NewGuid()); + }, + options: new HybridCacheEntryOptions + { + Expiration = TimeSpan.FromMinutes(1), + Flags = HybridCacheEntryFlags.None, + }); + + await Task.Delay(500); + Assert.Null(localCache.Get(key)); + } + + [Fact] + public async Task FactoryCanEnableL1Write_ThatCallerDisabled() + { + // L1 counterpart of FactoryCanReEnableL2Write: caller passed DisableLocalCacheWrite, + // factory clears it. If the override sticks, a subsequent read returns the same value + // from L1 without re-invoking the factory. + using var provider = GetDefaultCache(out var cache); + + string key = nameof(FactoryCanEnableL1Write_ThatCallerDisabled); + int factoryCalls = 0; + + var first = await cache.GetOrCreateAsync( + key, + (entryOptions, _) => + { + Interlocked.Increment(ref factoryCalls); + entryOptions.Flags = HybridCacheEntryFlags.None; + return new ValueTask(Guid.NewGuid()); + }, + options: new HybridCacheEntryOptions + { + Expiration = TimeSpan.FromMinutes(1), + Flags = HybridCacheEntryFlags.DisableLocalCacheWrite, + }); + + Assert.Equal(1, factoryCalls); + + // Second call should be served from L1 — same Guid, factory not invoked again. + var second = await cache.GetOrCreateAsync(key, _ => new(Guid.NewGuid())); + Assert.Equal(first, second); + Assert.Equal(1, factoryCalls); + } + + [Fact] + public async Task FactoryExpirationMutation_PropagatesToL2() + { + // Factory mutates Expiration; the L2 backend must receive DistributedCacheEntryOptions + // whose AbsoluteExpirationRelativeToNow matches the factory-set value. + var captured = new CapturingCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))); + using var provider = GetDefaultCache(out var cache, services => services.AddSingleton(captured)); + + var factoryExpiration = TimeSpan.FromMinutes(7); + _ = await cache.GetOrCreateAsync( + nameof(FactoryExpirationMutation_PropagatesToL2), + (entryOptions, _) => + { + entryOptions.Expiration = factoryExpiration; + return new ValueTask(Guid.NewGuid()); + }, + options: new HybridCacheEntryOptions { Expiration = TimeSpan.FromMinutes(1) }); + + await WaitForBackgroundL2WriteAsync(captured, nameof(FactoryExpirationMutation_PropagatesToL2)); + Assert.Equal(factoryExpiration, captured.LastSetOptions?.AbsoluteExpirationRelativeToNow); + } + + [Fact] + public async Task FactoryLocalCacheExpirationMutation_ShortensL1Only() + { + // Factory sets LocalCacheExpiration tighter than Expiration. After the L1 entry expires + // but before the overall entry does, the next call must re-fetch from L2 (factory not + // invoked again), proving the factory-set L1 expiration was honored. + var clock = new FakeTime(); + using var l1 = new MemoryCache(new MemoryCacheOptions { Clock = clock }); + var l2 = new LoggingCache(log, new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions { Clock = clock }))); + + using var provider = GetDefaultCache(out var cache, services => + { + services.AddSingleton(clock); + services.AddSingleton(clock); + services.AddSingleton(l1); + services.AddSingleton(l2); + }); + + string key = nameof(FactoryLocalCacheExpirationMutation_ShortensL1Only); + int factoryCalls = 0; + Func> factory = (entryOptions, _) => + { + Interlocked.Increment(ref factoryCalls); + entryOptions.LocalCacheExpiration = TimeSpan.FromSeconds(30); + return new ValueTask(Guid.NewGuid()); + }; + + var options = new HybridCacheEntryOptions { Expiration = TimeSpan.FromMinutes(5) }; + + var first = await cache.GetOrCreateAsync(key, factory, options); + Assert.Equal(1, factoryCalls); + + // Past factory-set L1 expiration (30s) but well before overall expiration (5m). + clock.Add(TimeSpan.FromSeconds(45)); + + var second = await cache.GetOrCreateAsync(key, factory, options); + + // Same value (came from L2 round-trip), factory not invoked again. + Assert.Equal(first, second); + Assert.Equal(1, factoryCalls); + } + + [Fact] + public async Task FactoryLocalSizeMutation_HonoredForL1SizeAccounting() + { + // Tight L1 SizeLimit + a payload large enough to exceed it. Without the override the + // entry would be evicted from L1; the factory sets LocalSize = 1 to make the entry fit. + using var provider = GetDefaultCache(out var cache, services => services.AddMemoryCache(options => options.SizeLimit = 5)); + + string key = nameof(FactoryLocalSizeMutation_HonoredForL1SizeAccounting); + int factoryCalls = 0; + + // Use a string payload large enough that its serialized size would otherwise exceed + // the L1 SizeLimit (the default L1 size for a string is its byte length). + string payload = new('x', 256); + + var first = await cache.GetOrCreateAsync( + key, + (entryOptions, _) => + { + Interlocked.Increment(ref factoryCalls); + entryOptions.LocalSize = 1; + return new ValueTask(payload); + }, + options: new HybridCacheEntryOptions + { + Expiration = TimeSpan.FromMinutes(1), + Flags = HybridCacheEntryFlags.DisableDistributedCache, // no L2; force L1-only path + }); + + Assert.Equal(1, factoryCalls); + + // Second call: served from L1 because the override kept the entry under SizeLimit. + var second = await cache.GetOrCreateAsync(key, _ => new(Guid.NewGuid().ToString())); + Assert.Equal(first, second); + Assert.Equal(1, factoryCalls); + } + + [Fact] + public async Task FactoryMutations_DoNotLeakToCallerOptionsInstance() + { + // The implementation passes a clone of the caller's options to the + // factory so that any mutations the factory performs do not bleed back into the caller's + // shared instance. A caller that reuses the same options across many calls must see the + // exact values it constructed. + using var provider = GetDefaultCache(out var cache); + + var callerOptions = new HybridCacheEntryOptions + { + Expiration = TimeSpan.FromMinutes(1), + LocalCacheExpiration = TimeSpan.FromSeconds(30), + LocalSize = 100, + Flags = HybridCacheEntryFlags.None, + }; + + // Snapshot before + var origExpiration = callerOptions.Expiration; + var origLocalCacheExpiration = callerOptions.LocalCacheExpiration; + var origLocalSize = callerOptions.LocalSize; + var origFlags = callerOptions.Flags; + + _ = await cache.GetOrCreateAsync( + nameof(FactoryMutations_DoNotLeakToCallerOptionsInstance), + (entryOptions, _) => + { + // Aggressively mutate everything; none of this should leak. + Assert.NotSame(callerOptions, entryOptions); + entryOptions.Expiration = TimeSpan.FromHours(99); + entryOptions.LocalCacheExpiration = TimeSpan.FromHours(99); + entryOptions.LocalSize = 9_999_999; + entryOptions.Flags = HybridCacheEntryFlags.DisableDistributedCache | HybridCacheEntryFlags.DisableLocalCache; + return new ValueTask(Guid.NewGuid()); + }, + options: callerOptions); + + Assert.Equal(origExpiration, callerOptions.Expiration); + Assert.Equal(origLocalCacheExpiration, callerOptions.LocalCacheExpiration); + Assert.Equal(origLocalSize, callerOptions.LocalSize); + Assert.Equal(origFlags, callerOptions.Flags); + } + + [Fact] + public async Task FactoryReceivesUsableOptions_WhenCallerPassedNull() + { + // The options-aware overload must hand the factory a real, mutable instance even when + // the caller did not supply one. + using var provider = BuildCacheWithL2(log, out var cache, out var localCache); + string key = nameof(FactoryReceivesUsableOptions_WhenCallerPassedNull); + int factoryCalls = 0; + + _ = await cache.GetOrCreateAsync( + key, + (entryOptions, _) => + { + Interlocked.Increment(ref factoryCalls); + Assert.NotNull(entryOptions); + entryOptions.Flags = HybridCacheEntryFlags.DisableDistributedCacheWrite; + return new ValueTask(Guid.NewGuid()); + }); + + Assert.Equal(1, factoryCalls); + + // The factory's mutation must have taken effect — no value written to L2. + await Task.Delay(500); + Assert.Null(localCache.Get(key)); + } + + [Fact] + public async Task FactoryLocalSize_PersistedInL2_AndReappliedOnL2Reload() + { + // The factory-set LocalSize must be persisted into the L2 payload so a *different* cache + // instance reading from the shared L2 still gets the size override applied to its L1 entry. + // + // Setup: + // Cache A: unlimited L1 SizeLimit; factory sets LocalSize=1, payload is 256 bytes. + // Cache B: shares L2 with A, has SizeLimit=5. On its first read it must fetch from L2, + // and the persisted LocalSize=1 must be reapplied to its L1 entry so the (256-byte) + // value fits. A second read on B then comes from L1 — observable via the LoggingCache + // L2 op count not increasing. + var sharedL2 = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + string key = nameof(FactoryLocalSize_PersistedInL2_AndReappliedOnL2Reload); + string payload = new('x', 256); + + // ---- Cache A: writes with factory-set LocalSize override ---- + var capturingA = new CapturingCache(sharedL2); + var servicesA = new ServiceCollection(); + servicesA.AddMemoryCache(); // unlimited + servicesA.AddSingleton(capturingA); + servicesA.AddHybridCache(); + using (var providerA = servicesA.BuildServiceProvider()) + { + var cacheA = providerA.GetRequiredService(); + _ = await cacheA.GetOrCreateAsync( + key, + (entryOptions, _) => + { + entryOptions.LocalSize = 1; + return new ValueTask(payload); + }, + options: new HybridCacheEntryOptions { Expiration = TimeSpan.FromMinutes(5) }); + + await WaitForBackgroundL2WriteAsync(capturingA, key); + } + + // Confirm the value was actually written to L2 by Cache A. + Assert.NotNull(sharedL2.Get(key)); + + // ---- Cache B: shares L2, has a tight SizeLimit that the raw payload would exceed ---- + var loggingB = new LoggingCache(log, sharedL2); + var servicesB = new ServiceCollection(); + servicesB.AddMemoryCache(options => options.SizeLimit = 5); + servicesB.AddSingleton(loggingB); + servicesB.AddHybridCache(); + using var providerB = servicesB.BuildServiceProvider(); + var cacheB = providerB.GetRequiredService(); + + // First call: L1 miss, fetches from L2; persisted LocalSize override must let it fit in L1. + var firstB = await cacheB.GetOrCreateAsync(key, _ => new(Guid.NewGuid().ToString())); + Assert.Equal(payload, firstB); + int opsAfterFirst = loggingB.OpCount; + + // Second call must hit L1 — no further L2 traffic. If the LocalSize override + // wasn't reapplied, the entry would have been evicted from L1 immediately and + // this call would have to re-read L2. + var secondB = await cacheB.GetOrCreateAsync(key, _ => new(Guid.NewGuid().ToString())); + Assert.Equal(payload, secondB); + Assert.Equal(opsAfterFirst, loggingB.OpCount); + } + + // Test-only IDistributedCache that wraps another IDistributedCache and adds: + // - LastSetOptions: the DistributedCacheEntryOptions from the most recent Set/SetAsync, + // for tests that assert on the options the HybridCache layer produced. + // - WaitForWriteAsync(key): a Task that completes after the (possibly background) write + // for that key has finished, so tests can wait deterministically instead of sleeping. + private sealed class CapturingCache(IDistributedCache tail) : IDistributedCache + { + private readonly ConcurrentDictionary> _writes = new(StringComparer.Ordinal); + + public DistributedCacheEntryOptions? LastSetOptions { get; private set; } + + public Task WaitForWriteAsync(string key, CancellationToken cancellationToken = default) + { + var tcs = SignalFor(key); + if (!cancellationToken.CanBeCanceled) + { + return tcs.Task; + } + + return WaitWithCancellationAsync(tcs.Task, cancellationToken); + + static async Task WaitWithCancellationAsync(Task task, CancellationToken ct) + { + var cancelTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using (ct.Register(static state => ((TaskCompletionSource)state!).TrySetCanceled(), cancelTcs)) + { + var completed = await Task.WhenAny(task, cancelTcs.Task).ConfigureAwait(false); + await completed.ConfigureAwait(false); + } + } + } + + public byte[]? Get(string key) => tail.Get(key); + public Task GetAsync(string key, CancellationToken token = default) => tail.GetAsync(key, token); + public void Refresh(string key) => tail.Refresh(key); + public Task RefreshAsync(string key, CancellationToken token = default) => tail.RefreshAsync(key, token); + public void Remove(string key) => tail.Remove(key); + public Task RemoveAsync(string key, CancellationToken token = default) => tail.RemoveAsync(key, token); + + public void Set(string key, byte[] value, DistributedCacheEntryOptions options) + { + LastSetOptions = options; + tail.Set(key, value, options); + _ = SignalFor(key).TrySetResult(true); + } + + public async Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default) + { + LastSetOptions = options; + await tail.SetAsync(key, value, options, token).ConfigureAwait(false); + _ = SignalFor(key).TrySetResult(true); + } + + private TaskCompletionSource SignalFor(string key) + => _writes.GetOrAdd(key, _ => new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously)); + } + + [Fact] + public void CloneOptionsOrNew_CopiesEveryPublicWritableProperty() + { + // Guards against silent data loss when HybridCacheEntryOptions + // gains a new property and the down-level branch of CloneOptionsOrNew is not updated. + // Also exercises the UnsafeAccessor path on net8.0+. + var writableProps = typeof(HybridCacheEntryOptions) + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(p => p.CanRead && p.CanWrite) + // Revision is an internal mutation counter used by StampedeStateT to detect + // factory mutations; a clone deliberately starts fresh so the diff is meaningful. + .Where(p => p.Name != "Revision") + .ToArray(); + + Assert.NotEmpty(writableProps); + + var source = new HybridCacheEntryOptions(); + var expected = new Dictionary(writableProps.Length); + foreach (var prop in writableProps) + { + object? value = MakeDistinctiveValue(prop.PropertyType, prop.Name); + prop.SetValue(source, value); + expected[prop.Name] = value; + } + + var clone = DefaultHybridCache.CloneOptionsOrNew(source); + + Assert.NotSame(source, clone); + foreach (var prop in writableProps) + { + Assert.Equal(expected[prop.Name], prop.GetValue(clone)); + } + } + + private static object MakeDistinctiveValue(Type t, string propName) + { + var underlying = Nullable.GetUnderlyingType(t) ?? t; + + // Per-property unique values so cross-wired assignments + // (e.g. Expiration <-> LocalCacheExpiration) are also caught. + if (underlying == typeof(TimeSpan)) + { + return TimeSpan.FromSeconds((StableHash(propName) % 3600) + 1); + } + + if (underlying == typeof(long)) + { + return (long)StableHash(propName) + 1; + } + + if (underlying.IsEnum) + { + var nonDefault = Enum.GetValues(underlying).Cast() + .FirstOrDefault(v => !v.Equals(Activator.CreateInstance(underlying))); + return nonDefault ?? Activator.CreateInstance(underlying)!; + } + + throw new NotSupportedException( + $"HybridCacheEntryOptions has a new property '{propName}' of type {underlying.FullName}. " + + $"Add a distinctive value generator here AND update DefaultHybridCache.CloneOptionsOrNew."); + } + + private static int StableHash(string s) + { + // FNV-ish stable hash; avoids dependence on string.GetHashCode randomization. + int h = 17; + foreach (char c in s) + { + h = unchecked((h * 31) + c); + } + + return Math.Abs(h); + } + + [Fact] + public async Task FactoryNegativeLocalSize_Throws() + { + using var provider = GetDefaultCache(out var cache); + + var ex = await Assert.ThrowsAsync(() => cache.GetOrCreateAsync( + nameof(FactoryNegativeLocalSize_Throws), + (entryOptions, _) => + { + entryOptions.LocalSize = -1; + return new ValueTask(0); + }).AsTask()); + + Assert.Equal("options", ex.ParamName); + } + + [Fact] + public async Task DefaultEntryOptionsLocalSize_AppliedWhenCallerOmitsIt() + { + // We prove that DefaultEntryOptions are honored by setting a tight + // L1 SizeLimit + a default LocalSize=1 so a 256-byte payload (which would normally exceed + // SizeLimit and be evicted) survives in L1. A second call must hit L1 (same value back, + // factory not re-invoked) without the caller ever supplying per-call options. + string key = nameof(DefaultEntryOptionsLocalSize_AppliedWhenCallerOmitsIt); + int factoryCalls = 0; + + using var provider = GetDefaultCache(out var cache, services => + { + services.AddMemoryCache(o => o.SizeLimit = 5); + services.Configure(o => o.DefaultEntryOptions = new HybridCacheEntryOptions + { + LocalSize = 1, + Flags = HybridCacheEntryFlags.DisableDistributedCache + }); + }); + + string payload = new('x', 256); + var first = await cache.GetOrCreateAsync(key, _ => + { + Interlocked.Increment(ref factoryCalls); + return new ValueTask(payload); + }); + Assert.Equal(1, factoryCalls); + + var second = await cache.GetOrCreateAsync(key, _ => new ValueTask(Guid.NewGuid().ToString())); + Assert.Equal(first, second); + Assert.Equal(1, factoryCalls); + } + + [Fact] + public void DefaultEntryOptionsNegativeLocalSize_ThrowsAtConstruction() + { + var services = new ServiceCollection(); + services.AddHybridCache(); + services.Configure(o => o.DefaultEntryOptions = new HybridCacheEntryOptions { LocalSize = -1 }); + using var provider = services.BuildServiceProvider(); + + var ex = Assert.Throws(() => provider.GetRequiredService()); + Assert.Equal("options", ex.ParamName); + } +} diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/FunctionalTests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/FunctionalTests.cs index 730013dbe4f..dd36ae1a3f3 100644 --- a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/FunctionalTests.cs +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/FunctionalTests.cs @@ -78,6 +78,31 @@ public async Task RemoveMultipleKeysViaArray() Assert.Equal(96, await cache.GetOrCreateAsync(key, _ => new ValueTask(96))); } - private static string Me([CallerMemberName] string caller = "") => caller; + [Fact] + public async Task GetOrCreateAsync_RejectsNegativeLocalSize() + { + using var provider = GetDefaultCache(out var cache); + var options = new HybridCacheEntryOptions { LocalSize = -1 }; + + var ex = await Assert.ThrowsAsync( + () => cache.GetOrCreateAsync(Me(), _ => new ValueTask(42), options).AsTask()); + Assert.Equal("options", ex.ParamName); + + ex = await Assert.ThrowsAsync( + () => cache.GetOrCreateAsync(Me(), 0, (_, _, _) => new ValueTask(42), options).AsTask()); + Assert.Equal("options", ex.ParamName); + } + [Fact] + public async Task SetAsync_RejectsNegativeLocalSize() + { + using var provider = GetDefaultCache(out var cache); + var options = new HybridCacheEntryOptions { LocalSize = -1 }; + + var ex = await Assert.ThrowsAsync( + () => cache.SetAsync(Me(), 42, options).AsTask()); + Assert.Equal("options", ex.ParamName); + } + + private static string Me([CallerMemberName] string caller = "") => caller; } diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs index 8e23143475f..627695c9a6b 100644 --- a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs @@ -11,7 +11,7 @@ public class HybridCacheEventSourceTests(ITestOutputHelper log, TestEventListene { // see notes in TestEventListener for context on fixture usage - [SkippableFact] + [Fact] public void MatchesNameAndGuid() { // Assert diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/PayloadTests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/PayloadTests.cs index 6ee4a8a5558..2cce7b5efba 100644 --- a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/PayloadTests.cs +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/PayloadTests.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; @@ -54,13 +54,72 @@ public void RoundTrip_Success(string delimitedTags, int expectedLength, int tagC Assert.Equal(expectedLength, actualLength); clock.Add(TimeSpan.FromSeconds(10)); - var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); + var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _, out _); log.WriteLine($"Entropy: {entropy}; Flags: {flags}"); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.Success, result); Assert.True(payload.SequenceEqual(bytes)); Assert.True(pendingTags.IsEmpty); } + [Theory] + [InlineData(null)] + [InlineData(0L)] + [InlineData(42L)] + [InlineData(long.MaxValue)] + public void RoundTrip_LocalCacheSize(long? localCacheSize) + { + using var provider = GetDefaultCache(out var cache); + + byte[] bytes = new byte[64]; + new Random().NextBytes(bytes); + + string key = "k"; + TagSet tags = TagSet.Empty; + int maxLen = HybridCachePayload.GetMaxBytes(key, tags, bytes.Length); + byte[] oversized = ArrayPool.Shared.Rent(maxLen); + + int actualLength = HybridCachePayload.Write(oversized, key, cache.CurrentTimestamp(), TimeSpan.FromMinutes(1), 0, tags, new(bytes), localCacheSize: localCacheSize); + + var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, + out var payload, out _, out var flags, out _, out var pendingTags, out long? parsedSize, out _); + + Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.Success, result); + Assert.True(payload.SequenceEqual(bytes)); + Assert.True(pendingTags.IsEmpty); + Assert.Equal(localCacheSize, parsedSize); + Assert.Equal(localCacheSize is not null, (flags & HybridCachePayload.PayloadFlags.HasLocalSize) != 0); + + ArrayPool.Shared.Return(oversized); + } + + // Guards the v1 wire format. This test pins a literal v1 byte sequence and asserts the reader + // still accepts it. + // + // The bytes below were produced by `HybridCachePayload.Write` with: + // key="frozen-key", tags=Empty, payload=[0x01..0x08], + // creationTime = 638000000000000000 ticks (~2023-01-30 UTC), + // duration = TimeSpan.FromDays(36500) (100 years headroom so the payload stays valid). + [Fact] + public void V1_FrozenBytes_StillReadable() + { + using var provider = GetDefaultCache(out var cache); + + byte[] frozen = Convert.FromBase64String( + "AwGPlwAAs6aeodoIAAiAgIztsrqCOAAKZnJvemVuLWtleQECAwQFBgcIAwE="); + + var result = HybridCachePayload.TryParse( + new(frozen), "frozen-key", TagSet.Empty, cache, + out var payload, out var remaining, out var flags, out _, out var pendingTags, + out long? parsedSize, out _); + + Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.Success, result); + Assert.True(payload.AsSpan().SequenceEqual([ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08 ])); + Assert.True(pendingTags.IsEmpty); + Assert.Null(parsedSize); + Assert.Equal(HybridCachePayload.PayloadFlags.None, flags & HybridCachePayload.PayloadFlags.HasLocalSize); + Assert.True(remaining > TimeSpan.Zero, "v1 frozen entry should not be expired (100-year duration)."); + } + [Fact] public void RoundTrip_SelfExpiration() { @@ -83,13 +142,14 @@ public void RoundTrip_SelfExpiration() Assert.Equal(1063, actualLength); clock.Add(TimeSpan.FromSeconds(58)); - var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); + var result = HybridCachePayload.TryParse( + new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.Success, result); Assert.True(payload.SequenceEqual(bytes)); Assert.True(pendingTags.IsEmpty); clock.Add(TimeSpan.FromSeconds(4)); - result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out payload, out remaining, out flags, out entropy, out pendingTags, out _); + result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out payload, out remaining, out flags, out entropy, out pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.ExpiredByEntry, result); Assert.Equal(0, payload.Count); Assert.True(pendingTags.IsEmpty); @@ -119,7 +179,8 @@ public async Task RoundTrip_WildcardExpiration() clock.Add(TimeSpan.FromSeconds(2)); await cache.RemoveByTagAsync("*"); - var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); + var result = HybridCachePayload.TryParse( + new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.ExpiredByWildcard, result); Assert.Equal(0, payload.Count); Assert.True(pendingTags.IsEmpty); @@ -149,13 +210,14 @@ public async Task RoundTrip_TagExpiration() clock.Add(TimeSpan.FromSeconds(2)); await cache.RemoveByTagAsync("other_tag"); - var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); + var result = HybridCachePayload.TryParse( + new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.Success, result); Assert.True(payload.SequenceEqual(bytes)); Assert.True(pendingTags.IsEmpty); await cache.RemoveByTagAsync("some_tag"); - result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out payload, out remaining, out flags, out entropy, out pendingTags, out _); + result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out payload, out remaining, out flags, out entropy, out pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.ExpiredByTag, result); Assert.Equal(0, payload.Count); Assert.True(pendingTags.IsEmpty); @@ -187,7 +249,8 @@ public async Task RoundTrip_TagExpiration_Pending() var tcs = new TaskCompletionSource(); cache.DebugInvalidateTag("some_tag", tcs.Task); - var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); + var result = HybridCachePayload.TryParse( + new(oversized, 0, actualLength), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.Success, result); Assert.True(payload.SequenceEqual(bytes)); Assert.Equal(1, pendingTags.Count); @@ -209,15 +272,18 @@ public void Gibberish() byte[] bytes = new byte[1024]; new Random().NextBytes(bytes); - var result = HybridCachePayload.TryParse(new(bytes), "whatever", TagSet.Empty, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); + var result = HybridCachePayload.TryParse(new(bytes), "whatever", TagSet.Empty, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.FormatNotRecognized, result); Assert.Equal(0, payload.Count); Assert.True(pendingTags.IsEmpty); } [Fact] - public void RoundTrip_Truncated() + public void RoundTrip_TruncatedAtEveryLength_NeverThrows() { + // Truncating the payload at *every* prefix length must surface as a clean parse result + // (never an exception that becomes ParseFault). In particular, a buffer that ends in the + // middle of a varint must not read past the end of the span. var clock = new FakeTime(); using var provider = GetDefaultCache(out var cache, config => { @@ -236,10 +302,19 @@ public void RoundTrip_Truncated() log.WriteLine($"bytes written: {actualLength}"); Assert.Equal(1063, actualLength); - var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength - 1), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); - Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.InvalidData, result); - Assert.Equal(0, payload.Count); - Assert.True(pendingTags.IsEmpty); + for (int truncatedLength = 0; truncatedLength < actualLength; truncatedLength++) + { + var result = HybridCachePayload.TryParse( + new(oversized, 0, truncatedLength), key, tags, cache, + out var payload, out _, out _, out _, out var pendingTags, out _, out var fault); + + Assert.Null(fault); + Assert.NotEqual(HybridCachePayload.HybridCachePayloadParseResult.Success, result); + Assert.Equal(0, payload.Count); + Assert.True(pendingTags.IsEmpty); + } + + ArrayPool.Shared.Return(oversized); } [Fact] @@ -263,7 +338,8 @@ public void RoundTrip_Oversized() log.WriteLine($"bytes written: {actualLength}"); Assert.Equal(1063, actualLength); - var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength + 1), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); + var result = HybridCachePayload.TryParse( + new(oversized, 0, actualLength + 1), key, tags, cache, out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.InvalidData, result); Assert.Equal(0, payload.Count); Assert.True(pendingTags.IsEmpty); @@ -323,6 +399,83 @@ public async Task MalformedTagDetected() collector.AssertErrors([Log.IdTagInvalidUnicode]); } + public enum FactoryMutation + { + MutateNonFlags, + SetFlagsToNone, + SetFlagsReadSideOnly, + CallerDisabledL2Write_FactoryClears, + } + + [Theory] + [InlineData(FactoryMutation.MutateNonFlags)] + [InlineData(FactoryMutation.SetFlagsToNone)] + [InlineData(FactoryMutation.SetFlagsReadSideOnly)] + [InlineData(FactoryMutation.CallerDisabledL2Write_FactoryClears)] + public async Task MalformedKey_DoesNotWriteToL2_EvenWhenFactoryMutatesOptions(FactoryMutation mutation) + { + // When the key fails unicode validation, make sure the (corrupted) key/tags are not persisted to L2. + using var collector = new LogCollector(); + MemoryDistributedCache? localCache = null; + using var provider = GetDefaultCache(out var cache, config => + { + localCache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + config.AddSingleton(new LoggingCache(log, localCache)); + config.AddLogging(options => + { + options.ClearProviders(); + options.AddProvider(collector); + }); + }); + + Assert.NotNull(localCache); + + string key = "my\uD801\uD802key"; // malformed unicode + string[] tags = ["mytag"]; + + HybridCacheEntryFlags callerFlags = mutation == FactoryMutation.CallerDisabledL2Write_FactoryClears + ? HybridCacheEntryFlags.DisableDistributedCacheWrite + : HybridCacheEntryFlags.None; + + _ = await cache.GetOrCreateAsync( + key, + (entryOptions, _) => + { + switch (mutation) + { + case FactoryMutation.MutateNonFlags: + entryOptions.LocalSize = 1234; + break; + case FactoryMutation.SetFlagsToNone: + entryOptions.Flags = HybridCacheEntryFlags.None; + break; + case FactoryMutation.SetFlagsReadSideOnly: + entryOptions.Flags = HybridCacheEntryFlags.DisableLocalCacheRead; + break; + case FactoryMutation.CallerDisabledL2Write_FactoryClears: + entryOptions.Flags = HybridCacheEntryFlags.None; + break; + } + + return new ValueTask(Guid.NewGuid()); + }, + options: new HybridCacheEntryOptions { Expiration = TimeSpan.FromMinutes(1), Flags = callerFlags }, + tags: tags); + + // Wait until the unicode-validation log fires (synchronous, inside BackgroundFetchAsync, + // just before the L2 write decision); then give the background work a brief moment + // to flush any erroneous L2 SetAsync. + await collector.WaitForLogsAsync([Log.IdKeyInvalidUnicode], TimeSpan.FromSeconds(5)); + await Task.Delay(500); + + collector.WriteTo(log); + collector.AssertErrors([Log.IdKeyInvalidUnicode]); + + // The corrupted key must NOT have been persisted to L2. (Note: unrelated + // tag-invalidation reads against valid tag keys may still appear in the backend.) + Assert.Null(localCache.Get(key)); + } + [Theory] [InlineData("tag1,tag2", 2)] [InlineData("tag1,tag2,tag3", 3)] @@ -358,7 +511,7 @@ public void RoundTrip_WithPendingTags_WhenKnownTagsMismatch(string delimitedTags // Parse with empty knownTags to force all tags into pendingTags via the rented buffer path var result = HybridCachePayload.TryParse(new(oversized, 0, actualLength), key, TagSet.Empty, cache, - out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _); + out var payload, out var remaining, out var flags, out var entropy, out var pendingTags, out _, out _); Assert.Equal(HybridCachePayload.HybridCachePayloadParseResult.Success, result); Assert.True(payload.SequenceEqual(bytes)); Assert.Equal(tagCount, pendingTags.Count);