Skip to content

Commit 93561a7

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add support for plugins and credential service in AdkApp
PiperOrigin-RevId: 799667543
1 parent 248a365 commit 93561a7

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,16 @@ def test_set_up(self):
220220
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL),
221221
)
222222
assert app._tmpl_attrs.get("runner") is None
223+
assert app._tmpl_attrs.get("session_service") is None
224+
assert app._tmpl_attrs.get("artifact_service") is None
225+
assert app._tmpl_attrs.get("memory_service") is None
226+
assert app._tmpl_attrs.get("credential_service") is None
223227
app.set_up()
224228
assert app._tmpl_attrs.get("runner") is not None
229+
assert app._tmpl_attrs.get("session_service") is not None
230+
assert app._tmpl_attrs.get("artifact_service") is not None
231+
assert app._tmpl_attrs.get("memory_service") is not None
232+
assert app._tmpl_attrs.get("credential_service") is not None
225233

226234
def test_clone(self):
227235
app = reasoning_engines.AdkApp(

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@
4444
except (ImportError, AttributeError):
4545
BaseAgent = Any
4646

47+
try:
48+
from google.adk.plugins.base_plugin import BasePlugin
49+
50+
BasePlugin = BasePlugin
51+
except (ImportError, AttributeError):
52+
BasePlugin = Any
53+
4754
try:
4855
from google.adk.sessions import BaseSessionService
4956

@@ -72,6 +79,15 @@
7279
except (ImportError, AttributeError):
7380
BaseMemoryService = Any
7481

82+
try:
83+
from google.adk.auth.credential_service.base_credential_service import (
84+
BaseCredentialService,
85+
)
86+
87+
BaseCredentialService = BaseCredentialService
88+
except (ImportError, AttributeError):
89+
BaseCredentialService = Any
90+
7591
try:
7692
from opentelemetry.sdk import trace
7793

@@ -322,10 +338,14 @@ def __init__(
322338
self,
323339
*,
324340
agent: "BaseAgent",
341+
plugins: Optional[List["BasePlugin"]] = None,
325342
enable_tracing: bool = False,
326343
session_service_builder: Optional[Callable[..., "BaseSessionService"]] = None,
327344
artifact_service_builder: Optional[Callable[..., "BaseArtifactService"]] = None,
328345
memory_service_builder: Optional[Callable[..., "BaseMemoryService"]] = None,
346+
credential_service_builder: Optional[
347+
Callable[..., "BaseCredentialService"]
348+
] = None,
329349
env_vars: Optional[Dict[str, str]] = None,
330350
):
331351
"""An ADK Application."""
@@ -343,10 +363,12 @@ def __init__(
343363
"project": initializer.global_config.project,
344364
"location": initializer.global_config.location,
345365
"agent": agent,
366+
"plugins": plugins,
346367
"enable_tracing": enable_tracing,
347368
"session_service_builder": session_service_builder,
348369
"artifact_service_builder": artifact_service_builder,
349370
"memory_service_builder": memory_service_builder,
371+
"credential_service_builder": credential_service_builder,
350372
"app_name": _DEFAULT_APP_NAME,
351373
"env_vars": env_vars or {},
352374
}
@@ -533,8 +555,19 @@ def set_up(self):
533555
else:
534556
self._tmpl_attrs["memory_service"] = InMemoryMemoryService()
535557

558+
credential_service_builder = self._tmpl_attrs.get("credential_service_builder")
559+
if credential_service_builder:
560+
self._tmpl_attrs["credential_service"] = credential_service_builder()
561+
else:
562+
from google.adk.auth.credential_service.in_memory_credential_service import (
563+
InMemoryCredentialService,
564+
)
565+
566+
self._tmpl_attrs["credential_service"] = InMemoryCredentialService()
567+
536568
self._tmpl_attrs["runner"] = Runner(
537569
agent=self._tmpl_attrs.get("agent"),
570+
plugins=self._tmpl_attrs.get("plugins"),
538571
session_service=self._tmpl_attrs.get("session_service"),
539572
artifact_service=self._tmpl_attrs.get("artifact_service"),
540573
memory_service=self._tmpl_attrs.get("memory_service"),
@@ -548,6 +581,7 @@ def set_up(self):
548581
session_service=self._tmpl_attrs.get("in_memory_session_service"),
549582
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
550583
memory_service=self._tmpl_attrs.get("in_memory_memory_service"),
584+
credential_service=self._tmpl_attrs.get("credential_service"),
551585
app_name=self._tmpl_attrs.get("app_name"),
552586
)
553587

0 commit comments

Comments
 (0)