@@ -290,6 +290,10 @@ export class Agent<Env = typeof env, State = unknown> extends Server<Env> {
290
290
private _ParentClass : typeof Agent < Env , State > =
291
291
Object . getPrototypeOf ( this ) . constructor ;
292
292
293
+ // Connection state management for race condition fix
294
+ private _connectionStates = new Map < string , 'connecting' | 'ready' > ( ) ;
295
+ private _messageQueues = new Map < string , Array < { connection : Connection , message : WSMessage } > > ( ) ;
296
+
293
297
mcp : MCPClientManager = new MCPClientManager ( this . _ParentClass . name , "0.0.1" ) ;
294
298
295
299
/**
@@ -354,6 +358,44 @@ export class Agent<Env = typeof env, State = unknown> extends Server<Env> {
354
358
*/
355
359
observability ?: Observability = genericObservability ;
356
360
361
+ /**
362
+ * Clean up connection state and message queue for a given connection
363
+ * @param connectionId The connection ID to clean up
364
+ */
365
+ private _cleanupConnection ( connectionId : string ) {
366
+ this . _connectionStates . delete ( connectionId ) ;
367
+ this . _messageQueues . delete ( connectionId ) ;
368
+ }
369
+
370
+ /**
371
+ * Process all queued messages for a connection once it's ready
372
+ * @param connectionId The connection ID to process messages for
373
+ */
374
+ private async _processQueuedMessages ( connectionId : string ) {
375
+ const queue = this . _messageQueues . get ( connectionId ) ;
376
+ if ( ! queue || queue . length === 0 ) {
377
+ return ;
378
+ }
379
+
380
+ // Process each queued message in order
381
+ for ( const { connection, message } of queue ) {
382
+ try {
383
+ await this . _originalOnMessage ( connection , message ) ;
384
+ } catch ( error ) {
385
+ console . error ( `Error processing queued message for connection ${ connectionId } :` , error ) ;
386
+ }
387
+ }
388
+
389
+ // Clear the queue after processing
390
+ this . _messageQueues . set ( connectionId , [ ] ) ;
391
+ }
392
+
393
+ /**
394
+ * The original onMessage handler that processes messages normally
395
+ * This will be set during constructor to preserve the original logic
396
+ */
397
+ private _originalOnMessage ! : ( connection : Connection , message : WSMessage ) => Promise < void > ;
398
+
357
399
/**
358
400
* Execute SQL queries against the Agent's database
359
401
* @template T Type of the returned rows
@@ -464,7 +506,9 @@ export class Agent<Env = typeof env, State = unknown> extends Server<Env> {
464
506
} ;
465
507
466
508
const _onMessage = this . onMessage . bind ( this ) ;
467
- this . onMessage = async ( connection : Connection , message : WSMessage ) => {
509
+
510
+ // Store the original message handler for processing queued messages
511
+ this . _originalOnMessage = async ( connection : Connection , message : WSMessage ) => {
468
512
return agentContext . run (
469
513
{ agent : this , connection, request : undefined , email : undefined } ,
470
514
async ( ) => {
@@ -555,14 +599,33 @@ export class Agent<Env = typeof env, State = unknown> extends Server<Env> {
555
599
) ;
556
600
} ;
557
601
602
+ // New onMessage handler with race condition protection
603
+ this . onMessage = async ( connection : Connection , message : WSMessage ) => {
604
+ // Check if connection is still in connecting state
605
+ if ( this . _connectionStates . get ( connection . id ) === 'connecting' ) {
606
+ // Queue the message for later processing
607
+ const queue = this . _messageQueues . get ( connection . id ) ;
608
+ if ( queue ) {
609
+ queue . push ( { connection, message } ) ;
610
+ }
611
+ return ;
612
+ }
613
+
614
+ // Process message normally if connection is ready
615
+ return this . _originalOnMessage ( connection , message ) ;
616
+ } ;
617
+
558
618
const _onConnect = this . onConnect . bind ( this ) ;
559
- this . onConnect = ( connection : Connection , ctx : ConnectionContext ) => {
560
- // TODO: This is a hack to ensure the state is sent after the connection is established
561
- // must fix this
619
+ this . onConnect = async ( connection : Connection , ctx : ConnectionContext ) => {
620
+ // Initialize connection state and message queue
621
+ this . _connectionStates . set ( connection . id , 'connecting' ) ;
622
+ this . _messageQueues . set ( connection . id , [ ] ) ;
623
+
562
624
return agentContext . run (
563
625
{ agent : this , connection, request : ctx . request , email : undefined } ,
564
626
async ( ) => {
565
- setTimeout ( ( ) => {
627
+ try {
628
+ // Send initial state immediately (no setTimeout needed)
566
629
if ( this . state ) {
567
630
connection . send (
568
631
JSON . stringify ( {
@@ -572,13 +635,15 @@ export class Agent<Env = typeof env, State = unknown> extends Server<Env> {
572
635
) ;
573
636
}
574
637
638
+ // Send MCP servers state
575
639
connection . send (
576
640
JSON . stringify ( {
577
641
mcp : this . getMcpServers ( ) ,
578
642
type : "cf_agent_mcp_servers"
579
643
} )
580
644
) ;
581
645
646
+ // Emit observability event
582
647
this . observability ?. emit (
583
648
{
584
649
displayMessage : "Connection established" ,
@@ -591,8 +656,17 @@ export class Agent<Env = typeof env, State = unknown> extends Server<Env> {
591
656
} ,
592
657
this . ctx
593
658
) ;
594
- return this . _tryCatch ( ( ) => _onConnect ( connection , ctx ) ) ;
595
- } , 20 ) ;
659
+
660
+ // Call user's onConnect handler and wait for completion
661
+ await this . _tryCatch ( ( ) => _onConnect ( connection , ctx ) ) ;
662
+
663
+ // Process any queued messages that arrived during setup
664
+ await this . _processQueuedMessages ( connection . id ) ;
665
+
666
+ } finally {
667
+ // Mark connection as ready
668
+ this . _connectionStates . set ( connection . id , 'ready' ) ;
669
+ }
596
670
}
597
671
) ;
598
672
} ;
@@ -890,9 +964,14 @@ export class Agent<Env = typeof env, State = unknown> extends Server<Env> {
890
964
if ( connectionOrError && error ) {
891
965
theError = error ;
892
966
// this is a websocket connection error
967
+ const connection = connectionOrError as Connection ;
968
+
969
+ // Clean up connection state on error
970
+ this . _cleanupConnection ( connection . id ) ;
971
+
893
972
console . error (
894
973
"Error on websocket connection:" ,
895
- ( connectionOrError as Connection ) . id ,
974
+ connection . id ,
896
975
theError
897
976
) ;
898
977
console . error (
@@ -1337,6 +1416,10 @@ export class Agent<Env = typeof env, State = unknown> extends Server<Env> {
1337
1416
* Destroy the Agent, removing all state and scheduled tasks
1338
1417
*/
1339
1418
async destroy ( ) {
1419
+ // Clean up all connection states and message queues
1420
+ this . _connectionStates . clear ( ) ;
1421
+ this . _messageQueues . clear ( ) ;
1422
+
1340
1423
// drop all tables
1341
1424
this . sql `DROP TABLE IF EXISTS cf_agents_state` ;
1342
1425
this . sql `DROP TABLE IF EXISTS cf_agents_schedules` ;
0 commit comments