Skip to content

Commit a250a7e

Browse files
committed
feat: Tackling Race/State condition issue by Changing the Code Design
- adding new type `ApiStreamGroundingChunk` to the stream type - collecting sources in the `Task.ts` instead -> decoupling
1 parent 11c454f commit a250a7e

File tree

3 files changed

+72
-17
lines changed

3 files changed

+72
-17
lines changed

src/api/providers/gemini.ts

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import { safeJsonParse } from "../../shared/safeJsonParse"
1515

1616
import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
1717
import { t } from "i18next"
18-
import type { ApiStream } from "../transform/stream"
18+
import type { ApiStream, GroundingSource } from "../transform/stream"
1919
import { getModelParams } from "../transform/model-params"
2020

2121
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -132,9 +132,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
132132
}
133133

134134
if (pendingGroundingMetadata) {
135-
const citations = this.extractCitationsOnly(pendingGroundingMetadata)
136-
if (citations) {
137-
yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` }
135+
const sources = this.extractGroundingSources(pendingGroundingMetadata)
136+
if (sources.length > 0) {
137+
yield { type: "grounding", sources }
138138
}
139139
}
140140

@@ -175,28 +175,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
175175
return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
176176
}
177177

178-
private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
178+
private extractGroundingSources(groundingMetadata?: GroundingMetadata): GroundingSource[] {
179179
const chunks = groundingMetadata?.groundingChunks
180180

181181
if (!chunks) {
182-
return null
182+
return []
183183
}
184184

185-
const citationLinks = chunks
186-
.map((chunk, i) => {
185+
return chunks
186+
.map((chunk): GroundingSource | null => {
187187
const uri = chunk.web?.uri
188+
const title = chunk.web?.title || uri || "Unknown Source"
189+
188190
if (uri) {
189-
return `[${i + 1}](${uri})`
191+
return {
192+
title,
193+
url: uri,
194+
}
190195
}
191196
return null
192197
})
193-
.filter((link): link is string => link !== null)
198+
.filter((source): source is GroundingSource => source !== null)
199+
}
200+
201+
private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
202+
const sources = this.extractGroundingSources(groundingMetadata)
194203

195-
if (citationLinks.length > 0) {
196-
return citationLinks.join(", ")
204+
if (sources.length === 0) {
205+
return null
197206
}
198207

199-
return null
208+
const citationLinks = sources.map((source, i) => `[${i + 1}](${source.url})`)
209+
return citationLinks.join(", ")
200210
}
201211

202212
async completePrompt(prompt: string): Promise<string> {

src/api/transform/stream.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
export type ApiStream = AsyncGenerator<ApiStreamChunk>
22

3-
export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError
3+
export type ApiStreamChunk =
4+
| ApiStreamTextChunk
5+
| ApiStreamUsageChunk
6+
| ApiStreamReasoningChunk
7+
| ApiStreamGroundingChunk
8+
| ApiStreamError
49

510
export interface ApiStreamError {
611
type: "error"
@@ -27,3 +32,14 @@ export interface ApiStreamUsageChunk {
2732
reasoningTokens?: number
2833
totalCost?: number
2934
}
35+
36+
export interface ApiStreamGroundingChunk {
37+
type: "grounding"
38+
sources: GroundingSource[]
39+
}
40+
41+
export interface GroundingSource {
42+
title: string
43+
url: string
44+
snippet?: string
45+
}

src/core/task/Task.ts

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import { CloudService, ExtensionBridgeService } from "@roo-code/cloud"
3939

4040
// api
4141
import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api"
42-
import { ApiStream } from "../../api/transform/stream"
42+
import { ApiStream, GroundingSource } from "../../api/transform/stream"
4343

4444
// shared
4545
import { findLastIndex } from "../../shared/array"
@@ -1746,7 +1746,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
17461746
this.didFinishAbortingStream = true
17471747
}
17481748

1749-
// Reset streaming state.
1749+
// Reset streaming state for each new API request
17501750
this.currentStreamingContentIndex = 0
17511751
this.currentStreamingDidCheckpoint = false
17521752
this.assistantMessageContent = []
@@ -1767,6 +1767,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
17671767
const stream = this.attemptApiRequest()
17681768
let assistantMessage = ""
17691769
let reasoningMessage = ""
1770+
let pendingGroundingSources: GroundingSource[] = []
17701771
this.isStreaming = true
17711772

17721773
try {
@@ -1793,6 +1794,13 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
17931794
cacheReadTokens += chunk.cacheReadTokens ?? 0
17941795
totalCost = chunk.totalCost
17951796
break
1797+
case "grounding":
1798+
// Handle grounding sources separately from regular content
1799+
// to prevent state persistence issues - store them separately
1800+
if (chunk.sources && chunk.sources.length > 0) {
1801+
pendingGroundingSources.push(...chunk.sources)
1802+
}
1803+
break
17961804
case "text": {
17971805
assistantMessage += chunk.text
17981806

@@ -2086,9 +2094,30 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
20862094
let didEndLoop = false
20872095

20882096
if (assistantMessage.length > 0) {
2097+
// Display grounding sources to the user if they exist
2098+
if (pendingGroundingSources.length > 0) {
2099+
const citationLinks = pendingGroundingSources.map((source, i) => `[${i + 1}](${source.url})`)
2100+
const sourcesText = `Sources: ${citationLinks.join(", ")}`
2101+
2102+
await this.say("text", sourcesText, undefined, false, undefined, undefined, {
2103+
isNonInteractive: true,
2104+
})
2105+
}
2106+
2107+
// Strip grounding sources from assistant message before persisting to API history
2108+
// This prevents state persistence issues while maintaining user experience
2109+
let cleanAssistantMessage = assistantMessage
2110+
if (pendingGroundingSources.length > 0) {
2111+
// Remove any grounding source references that might have been integrated into the message
2112+
cleanAssistantMessage = assistantMessage
2113+
.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "")
2114+
.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "")
2115+
.trim()
2116+
}
2117+
20892118
await this.addToApiConversationHistory({
20902119
role: "assistant",
2091-
content: [{ type: "text", text: assistantMessage }],
2120+
content: [{ type: "text", text: cleanAssistantMessage }],
20922121
})
20932122

20942123
TelemetryService.instance.captureConversationMessage(this.taskId, "assistant")

0 commit comments

Comments
 (0)