Skip to content
36 changes: 22 additions & 14 deletions src/Middleware/OutputCaching/src/Memory/MemoryOutputCacheStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Linq;
using Microsoft.Extensions.Caching.Memory;

namespace Microsoft.AspNetCore.OutputCaching.Memory;

internal sealed class MemoryOutputCacheStore : IOutputCacheStore
{
private readonly MemoryCache _cache;
private readonly Dictionary<string, HashSet<string>> _taggedEntries = new();
private readonly Dictionary<string, HashSet<TaggedEntry>> _taggedEntries = [];
private readonly object _tagsLock = new();

internal MemoryOutputCacheStore(MemoryCache cache)
Expand All @@ -20,7 +21,7 @@ internal MemoryOutputCacheStore(MemoryCache cache)
}

// For testing
internal Dictionary<string, HashSet<string>> TaggedEntries => _taggedEntries;
internal Dictionary<string, HashSet<string>> TaggedEntries => _taggedEntries.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.Select(t => t.Key).ToHashSet());

public ValueTask EvictByTagAsync(string tag, CancellationToken cancellationToken)
{
Expand All @@ -30,7 +31,7 @@ public ValueTask EvictByTagAsync(string tag, CancellationToken cancellationToken
{
if (_taggedEntries.TryGetValue(tag, out var keys))
{
if (keys != null && keys.Count > 0)
if (keys is { Count: > 0 })
{
// If MemoryCache changed to run eviction callbacks inline in Remove, iterating over keys could throw
// To prevent allocating a copy of the keys we check if the eviction callback ran,
Expand All @@ -40,7 +41,7 @@ public ValueTask EvictByTagAsync(string tag, CancellationToken cancellationToken
while (i > 0)
{
var oldCount = keys.Count;
foreach (var key in keys)
foreach (var (key, _) in keys)
{
_cache.Remove(key);
i--;
Expand Down Expand Up @@ -74,6 +75,8 @@ public ValueTask SetAsync(string key, byte[] value, string[]? tags, TimeSpan val
ArgumentNullException.ThrowIfNull(key);
ArgumentNullException.ThrowIfNull(value);

var entryId = Guid.NewGuid();

if (tags != null)
{
// Lock with SetEntry() to prevent EvictByTagAsync() from trying to remove a tag whose entry hasn't been added yet.
Expand All @@ -90,27 +93,27 @@ public ValueTask SetAsync(string key, byte[] value, string[]? tags, TimeSpan val

if (!_taggedEntries.TryGetValue(tag, out var keys))
{
keys = new HashSet<string>();
keys = new HashSet<TaggedEntry>();
_taggedEntries[tag] = keys;
}

Debug.Assert(keys != null);

keys.Add(key);
keys.Add(new TaggedEntry(key, entryId));
}

SetEntry(key, value, tags, validFor);
SetEntry(key, value, tags, validFor, entryId);
}
}
else
{
SetEntry(key, value, tags, validFor);
SetEntry(key, value, tags, validFor, entryId);
}

return ValueTask.CompletedTask;
}

void SetEntry(string key, byte[] value, string[]? tags, TimeSpan validFor)
private void SetEntry(string key, byte[] value, string[]? tags, TimeSpan validFor, Guid entryId)
{
Debug.Assert(key != null);

Expand All @@ -120,30 +123,33 @@ void SetEntry(string key, byte[] value, string[]? tags, TimeSpan validFor)
Size = value.Length
};

if (tags != null && tags.Length > 0)
if (tags is { Length: > 0 })
{
// Remove cache keys from tag lists when the entry is evicted
options.RegisterPostEvictionCallback(RemoveFromTags, tags);
options.RegisterPostEvictionCallback(RemoveFromTags, (tags, entryId));
}

_cache.Set(key, value, options);
}

void RemoveFromTags(object key, object? value, EvictionReason reason, object? state)
private void RemoveFromTags(object key, object? value, EvictionReason reason, object? state)
{
var tags = state as string[];
Debug.Assert(state != null);

var (tags, entryId) = ((string[] Tags, Guid EntryId))state;

Debug.Assert(tags != null);
Debug.Assert(tags.Length > 0);
Debug.Assert(key is string);
Debug.Assert(entryId != Guid.Empty);

lock (_tagsLock)
{
foreach (var tag in tags)
{
if (_taggedEntries.TryGetValue(tag, out var tagged))
{
tagged.Remove((string)key);
tagged.Remove(new TaggedEntry((string)key, entryId));

// Remove the collection if there is no more keys in it
if (tagged.Count == 0)
Expand All @@ -154,4 +160,6 @@ void RemoveFromTags(object key, object? value, EvictionReason reason, object? st
}
}
}

private record TaggedEntry(string Key, Guid EntryId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,43 @@ public async Task ExpiredEntries_AreRemovedFromTags()
Assert.Single(tag2s);
}

[Fact]
public async Task ReplacedEntries_AreNotRemovedFromTags()
{
var testClock = new TestMemoryOptionsClock { UtcNow = DateTimeOffset.UtcNow };
var cache = new MemoryCache(new MemoryCacheOptions { SizeLimit = 1000, Clock = testClock, ExpirationScanFrequency = TimeSpan.FromMilliseconds(1) });
var store = new MemoryOutputCacheStore(cache);
var value = "abc"u8.ToArray();

await store.SetAsync("a", value, new[] { "tag1", "tag2" }, TimeSpan.FromMilliseconds(5), default);
await store.SetAsync("a", value, new[] { "tag1" }, TimeSpan.FromMilliseconds(20), default);

testClock.Advance(TimeSpan.FromMilliseconds(10));

// Trigger background expiration by accessing the cache.
_ = cache.Get("a");

var resulta = await store.GetAsync("a", default);

Assert.NotNull(resulta);

HashSet<string> tag1s, tag2s;

// Wait for the tag2 HashSet to be removed by the background expiration thread.

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

while (store.TaggedEntries.TryGetValue("tag2", out tag2s) && !cts.IsCancellationRequested)
{
await Task.Yield();
}

store.TaggedEntries.TryGetValue("tag1", out tag1s);

Assert.Null(tag2s);
Assert.Single(tag1s);
}

[Theory]
[InlineData(null)]
public async Task Store_Throws_OnInvalidTag(string tag)
Expand Down
Loading