Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/SignalR/common/Shared/MessageBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,16 @@ private async Task RunTimer()

public ValueTask<FlushResult> WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken)
{
return WriteAsyncCore(hubMessage.Message!, hubMessage.GetSerializedMessage(_protocol), cancellationToken);
// Default to HubInvocationMessage as that's the only type we use SerializedHubMessage for currently. Should harden this in the future.
return WriteAsyncCore(hubMessage.Message?.GetType() ?? typeof(HubInvocationMessage), hubMessage.GetSerializedMessage(_protocol), cancellationToken);
}

public ValueTask<FlushResult> WriteAsync(HubMessage hubMessage, CancellationToken cancellationToken)
{
return WriteAsyncCore(hubMessage, _protocol.GetMessageBytes(hubMessage), cancellationToken);
return WriteAsyncCore(hubMessage.GetType(), _protocol.GetMessageBytes(hubMessage), cancellationToken);
}

private async ValueTask<FlushResult> WriteAsyncCore(HubMessage hubMessage, ReadOnlyMemory<byte> messageBytes, CancellationToken cancellationToken)
private async ValueTask<FlushResult> WriteAsyncCore(Type hubMessageType, ReadOnlyMemory<byte> messageBytes, CancellationToken cancellationToken)
{
// TODO: Add backpressure based on message count
if (_bufferedByteCount > _bufferLimit)
Expand Down Expand Up @@ -158,7 +159,7 @@ private async ValueTask<FlushResult> WriteAsyncCore(HubMessage hubMessage, ReadO
await _writeLock.WaitAsync(cancellationToken: default).ConfigureAwait(false);
try
{
if (hubMessage is HubInvocationMessage invocationMessage)
if (typeof(HubInvocationMessage).IsAssignableFrom(hubMessageType))
{
_totalMessageCount++;
_bufferedByteCount += messageBytes.Length;
Expand Down
3 changes: 3 additions & 0 deletions src/SignalR/server/Core/src/SerializedHubMessage.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.AspNetCore.SignalR;
Expand Down Expand Up @@ -40,6 +41,8 @@ public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
/// <param name="message">The hub message for the cache. This will be serialized with an <see cref="IHubProtocol"/> in <see cref="GetSerializedMessage"/> to get the message's serialized representation.</param>
public SerializedHubMessage(HubMessage message)
{
// Type currently only used for invocation messages, we should probably refactor it to be explicit about that e.g. new property for message type?
Debug.Assert(message.GetType().IsAssignableTo(typeof(HubInvocationMessage)));
Message = message;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO.Pipelines;
using System.Text.Json;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Time.Testing;

Expand Down Expand Up @@ -169,6 +170,62 @@ public async Task UnAckedMessageResentOnReconnect()
Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
}

// Regression test for https://github.com/dotnet/aspnetcore/issues/55575
[Fact]
public async Task UnAckedSerializedMessageResentOnReconnect()
{
var protocol = new JsonHubProtocol();
var connection = new TestConnectionContext();
var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions());
connection.Transport = pipes.Transport;
using var messageBuffer = new MessageBuffer(connection, protocol, bufferLimit: 1000, NullLogger.Instance);

var invocationMessage = new SerializedHubMessage([new SerializedMessage(protocol.Name,
protocol.GetMessageBytes(new InvocationMessage("method1", [1])))]);
await messageBuffer.WriteAsync(invocationMessage, default);

var res = await pipes.Application.Input.ReadAsync();

var buffer = res.Buffer;
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message));
var parsedMessage = Assert.IsType<InvocationMessage>(message);
Assert.Equal("method1", parsedMessage.Target);
Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32());

pipes.Application.Input.AdvanceTo(buffer.Start);

DuplexPipe.UpdateConnectionPair(ref pipes, connection);
await messageBuffer.ResendAsync(pipes.Transport.Output);

Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance));
Assert.True(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
Assert.True(messageBuffer.ShouldProcessMessage(new SequenceMessage(1)));

res = await pipes.Application.Input.ReadAsync();

buffer = res.Buffer;
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message));
var seqMessage = Assert.IsType<SequenceMessage>(message);
Assert.Equal(1, seqMessage.SequenceId);

pipes.Application.Input.AdvanceTo(buffer.Start);

res = await pipes.Application.Input.ReadAsync();

buffer = res.Buffer;
Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message));
parsedMessage = Assert.IsType<InvocationMessage>(message);
Assert.Equal("method1", parsedMessage.Target);
Assert.Equal(1, ((JsonElement)Assert.Single(parsedMessage.Arguments)).GetInt32());

pipes.Application.Input.AdvanceTo(buffer.Start);

messageBuffer.ShouldProcessMessage(new SequenceMessage(1));

Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance));
Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null)));
}

[Fact]
public async Task AckedMessageNotResentOnReconnect()
{
Expand Down
173 changes: 167 additions & 6 deletions src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using System.Net.WebSockets;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Client;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Xunit;

namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests;

Expand Down Expand Up @@ -213,7 +211,105 @@ public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLit
}
}

private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null)
[ConditionalTheory]
[SkipIfDockerNotPresent]
[InlineData("messagepack")]
[InlineData("json")]
public async Task StatefulReconnectPreservesMessageFromOtherServer(string protocolName)
{
using (StartVerifiableLog())
{
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);

ClientWebSocket innerWs = null;
WebSocketWrapper ws = null;
TaskCompletionSource reconnectTcs = null;
TaskCompletionSource startedReconnectTcs = null;

var connection = CreateConnection(_serverFixture.FirstServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory,
customizeConnection: builder =>
{
builder.WithStatefulReconnect();
builder.Services.Configure<HttpConnectionOptions>(o =>
{
// Replace the websocket creation for the first connection so we can make the client think there was an ungraceful closure
// Which will trigger the stateful reconnect flow
o.WebSocketFactory = async (context, token) =>
{
if (reconnectTcs is null)
{
reconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
startedReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}
else
{
startedReconnectTcs.SetResult();
// We only want to wait on the reconnect, not the initial connection attempt
await reconnectTcs.Task.DefaultTimeout();
}

innerWs = new ClientWebSocket();
ws = new WebSocketWrapper(innerWs);
await innerWs.ConnectAsync(context.Uri, token);

_ = Task.Run(async () =>
{
try
{
while (innerWs.State == WebSocketState.Open)
{
var buffer = new byte[1024];
var res = await innerWs.ReceiveAsync(buffer, default);
ws.SetReceiveResult((res, buffer.AsMemory(0, res.Count)));
}
}
// Log but ignore receive errors, that likely just means the connection closed
catch (Exception ex)
{
Logger.LogInformation(ex, "Error while reading from inner websocket");
}
});

return ws;
};
});
});
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/stateful", HttpTransportType.WebSockets, protocol, LoggerFactory);

var tcs = new TaskCompletionSource<string>();
connection.On<string>("SendToAll", message => tcs.TrySetResult(message));

var tcs2 = new TaskCompletionSource<string>();
secondConnection.On<string>("SendToAll", message => tcs2.TrySetResult(message));

await connection.StartAsync().DefaultTimeout();
await secondConnection.StartAsync().DefaultTimeout();

// Close first connection before the second connection sends a message to all clients
await ws.CloseOutputAsync(WebSocketCloseStatus.InternalServerError, statusDescription: null, default);
await startedReconnectTcs.Task.DefaultTimeout();

// Send to all clients, since both clients are on different servers this means the backplane will be used
// And we want to test that messages are still preserved for stateful reconnect purposes when a client disconnects
// But is on a different server from the original message sender.
await secondConnection.SendAsync("SendToAll", "test message").DefaultTimeout();

// Check that second connection still receives the message
Assert.Equal("test message", await tcs2.Task.DefaultTimeout());
Assert.False(tcs.Task.IsCompleted);

// allow first connection to reconnect
reconnectTcs.SetResult();

// Check that first connection received the message once it reconnected
Assert.Equal("test message", await tcs.Task.DefaultTimeout());

await connection.DisposeAsync().DefaultTimeout();
}
}

private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null,
Action<IHubConnectionBuilder> customizeConnection = null)
{
var hubConnectionBuilder = new HubConnectionBuilder()
.WithLoggerFactory(loggerFactory)
Expand All @@ -227,6 +323,8 @@ private static HubConnection CreateConnection(string url, HttpTransportType tran

hubConnectionBuilder.Services.AddSingleton(protocol);

customizeConnection?.Invoke(hubConnectionBuilder);

return hubConnectionBuilder.Build();
}

Expand Down Expand Up @@ -255,4 +353,67 @@ public static IEnumerable<object[]> TransportTypesAndProtocolTypes
}
}
}

internal sealed class WebSocketWrapper : WebSocket
{
private readonly WebSocket _inner;
private TaskCompletionSource<(WebSocketReceiveResult, ReadOnlyMemory<byte>)> _receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

public WebSocketWrapper(WebSocket inner)
{
_inner = inner;
}

public override WebSocketCloseStatus? CloseStatus => _inner.CloseStatus;

public override string CloseStatusDescription => _inner.CloseStatusDescription;

public override WebSocketState State => _inner.State;

public override string SubProtocol => _inner.SubProtocol;

public override void Abort()
{
_inner.Abort();
}

public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
return _inner.CloseAsync(closeStatus, statusDescription, cancellationToken);
}

public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
_receiveTcs.TrySetException(new IOException("force reconnect"));
return Task.CompletedTask;
}

public override void Dispose()
{
_inner.Dispose();
}

public void SetReceiveResult((WebSocketReceiveResult, ReadOnlyMemory<byte>) result)
{
_receiveTcs.SetResult(result);
}

public override async Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
{
var res = await _receiveTcs.Task;
// Handle zero-byte reads
if (buffer.Count == 0)
{
return res.Item1;
}
_receiveTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
res.Item2.CopyTo(buffer);
return res.Item1;
}

public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
return _inner.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
}
}
}
1 change: 1 addition & 0 deletions src/SignalR/server/StackExchangeRedis/test/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public void Configure(IApplicationBuilder app)
app.UseEndpoints(endpoints =>
{
endpoints.MapHub<EchoHub>("/echo");
endpoints.MapHub<StatefulHub>("/stateful", o => o.AllowStatefulReconnects = true);
});
}

Expand Down
12 changes: 12 additions & 0 deletions src/SignalR/server/StackExchangeRedis/test/StatefulHub.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests;

public class StatefulHub : Hub
{
public Task SendToAll(string message)
{
return Clients.All.SendAsync("SendToAll", message);
}
}
Loading