From 4be33a912732d61102271dc1cd5323e5312535f2 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Sat, 23 May 2026 00:55:28 -0700 Subject: [PATCH] feat: Add tools and toolset to use SkillSource in ADK agents This change introduces a set of tools and a toolset to enable ADK agents to interact with and use skills loaded via `SkillSource`. Key changes: - ListSkillsTool: Lists available skills. - LoadSkillTool: Loads skill instructions. - LoadSkillResourceTool: Loads skill resources. - SkillToolset: Groups the above tools. - Integrates SkillSource into LlmAgent and BaseLlmFlow. - Tests for all new tools. PiperOrigin-RevId: 920071248 --- .../adk/flows/llmflows/BaseLlmFlow.java | 61 +++- .../adk/skills/AbstractSkillSource.java | 24 +- .../adk/skills/InMemorySkillSource.java | 15 +- .../google/adk/skills/LocalSkillSource.java | 27 +- .../adk/skills/SkillSourceException.java | 36 +- .../java/com/google/adk/tools/BaseTool.java | 3 +- .../com/google/adk/tools/BaseToolset.java | 2 +- .../google/adk/tools/LlmRequestProcessor.java | 43 +++ .../adk/tools/skills/ListSkillsTool.java | 59 ++++ .../tools/skills/LoadSkillResourceTool.java | 242 +++++++++++++ .../adk/tools/skills/LoadSkillTool.java | 87 +++++ .../google/adk/tools/skills/SkillToolset.java | 131 +++++++ .../com/google/adk/agents/LlmAgentTest.java | 2 +- .../adk/flows/llmflows/BaseLlmFlowTest.java | 99 ++++++ .../adk/skills/LocalSkillSourceTest.java | 124 +++++++ .../java/com/google/adk/testing/TestLlm.java | 2 +- .../adk/tools/skills/ListSkillsToolTest.java | 151 ++++++++ .../skills/LoadSkillResourceToolTest.java | 330 ++++++++++++++++++ .../adk/tools/skills/LoadSkillToolTest.java | 161 +++++++++ .../adk/tools/skills/SkillToolsetTest.java | 125 +++++++ 20 files changed, 1686 insertions(+), 38 deletions(-) create mode 100644 core/src/main/java/com/google/adk/tools/LlmRequestProcessor.java create mode 100644 core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java create mode 100644 core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java create mode 100644 core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java create mode 100644 core/src/main/java/com/google/adk/tools/skills/SkillToolset.java create mode 100644 core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java create mode 100644 core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java create mode 100644 core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java create mode 100644 core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fffeab698..06bc83026 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -38,6 +38,9 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.telemetry.Tracing; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.LlmRequestProcessor; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -58,6 +61,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -96,20 +100,8 @@ private Flowable preprocess( Context currentContext = Context.current(); LlmAgent agent = (LlmAgent) context.agent(); - RequestProcessor toolsProcessor = - (ctx, req) -> { - LlmRequest.Builder builder = req.toBuilder(); - return agent - .canonicalTools(new ReadonlyContext(ctx)) - .concatMapCompletable( - tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) - .andThen( - Single.fromCallable( - () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); - }; - Iterable allProcessors = - Iterables.concat(requestProcessors, ImmutableList.of(toolsProcessor)); + Iterables.concat(requestProcessors, ImmutableList.of(getRequestProcessorFromTools(agent))); return Flowable.fromIterable(allProcessors) .concatMap( @@ -121,6 +113,49 @@ private Flowable preprocess( result -> result.events() != null ? result.events() : ImmutableList.of())); } + /** + * Constructs a {@link RequestProcessor} that sequentially applies the {@code processLlmRequest} + * methods of all tools and toolsets associated with this agent to the incoming {@link + * LlmRequest}. + * + * @return A {@link RequestProcessor} that applies tool-specific modifications to LLM requests. + */ + private RequestProcessor getRequestProcessorFromTools(LlmAgent agent) { + return (context, request) -> { + ReadonlyContext readonlyContext = new ReadonlyContext(context); + List> processors = new ArrayList<>(); + + for (Object toolOrToolset : agent.toolsUnion()) { + if (toolOrToolset instanceof BaseTool baseTool) { + processors.add(baseTool::processLlmRequest); + } else if (toolOrToolset instanceof BaseToolset baseToolset) { + // First apply the toolset's own request processor if it implements the + // LlmRequestProcessor interface, then unwrap all tools from the toolset + // and apply each individual tool's request processor sequentially. + processors.add( + (builder, ctx) -> + (baseToolset instanceof LlmRequestProcessor processor + ? processor.processLlmRequest(builder, ctx) + : Completable.complete()) + .andThen(baseToolset.getTools(readonlyContext)) + .concatMapCompletable(b -> b.processLlmRequest(builder, ctx))); + } else { + throw new IllegalArgumentException( + "Object in tools list is not of a supported type: " + + toolOrToolset.getClass().getName()); + } + } + + LlmRequest.Builder builder = request.toBuilder(); + ToolContext toolContext = ToolContext.builder(context).build(); + return Flowable.fromIterable(processors) + .concatMapCompletable(f -> f.apply(builder, toolContext)) + .andThen( + Single.fromCallable( + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); + }; + } + /** * Post-processes the LLM response after receiving it from the LLM. Executes all registered {@link * ResponseProcessor} instances. Emits events for the model response and any subsequent function diff --git a/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java b/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java index aca399f92..31f17a18e 100644 --- a/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java +++ b/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java @@ -16,6 +16,8 @@ package com.google.adk.skills; +import static com.google.adk.skills.SkillSourceException.SKILL_FORMAT_ERROR; +import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; import static java.nio.channels.Channels.newReader; import static java.nio.charset.StandardCharsets.UTF_8; @@ -82,12 +84,14 @@ private Frontmatter loadFrontmatter(String skillName, PathT skillMdPath) Frontmatter frontmatter = yamlMapper.readValue(yaml, Frontmatter.class); if (!frontmatter.name().equals(skillName)) { throw new SkillSourceException( - "Skill name '%s' does not match directory name '%s'." - .formatted(frontmatter.name(), skillName)); + "Skill name in the frontmatter '%s' does not match skill name '%s'." + .formatted(frontmatter.name(), skillName), + SKILL_LOAD_ERROR); } return frontmatter; } catch (IOException e) { - throw new SkillSourceException("Cannot load frontmatter for skill '" + skillName + "'", e); + throw new SkillSourceException( + "Cannot load frontmatter for skill '" + skillName + "'", SKILL_LOAD_ERROR, e); } } @@ -100,7 +104,9 @@ public Single loadInstructions(String skillName) { return readInstructions(reader); } catch (IOException e) { throw new SkillSourceException( - "Failed to load instruction for skill '" + skillName + "'", e); + "Failed to load instruction for skill '" + skillName + "'", + SKILL_LOAD_ERROR, + e); } }); } @@ -140,7 +146,8 @@ private String readFrontmatterYaml(BufferedReader reader) throws IOException, SkillSourceException { String line = reader.readLine(); if (line == null || !line.trim().equals(THREE_DASHES)) { - throw new SkillSourceException("Skill file must start with " + THREE_DASHES); + throw new SkillSourceException( + "Skill file must start with " + THREE_DASHES, SKILL_FORMAT_ERROR); } StringBuilder sb = new StringBuilder(); @@ -151,14 +158,15 @@ private String readFrontmatterYaml(BufferedReader reader) sb.append(line).append("\n"); } throw new SkillSourceException( - "Skill file frontmatter not properly closed with " + THREE_DASHES); + "Skill file frontmatter not properly closed with " + THREE_DASHES, SKILL_FORMAT_ERROR); } private String readInstructions(BufferedReader reader) throws IOException, SkillSourceException { // Skip the frontmatter block String line = reader.readLine(); if (line == null || !line.trim().equals(THREE_DASHES)) { - throw new SkillSourceException("Skill file must start with " + THREE_DASHES); + throw new SkillSourceException( + "Skill file must start with " + THREE_DASHES, SKILL_FORMAT_ERROR); } boolean dashClosed = false; while ((line = reader.readLine()) != null) { @@ -169,7 +177,7 @@ private String readInstructions(BufferedReader reader) throws IOException, Skill } if (!dashClosed) { throw new SkillSourceException( - "Skill file frontmatter not properly closed with " + THREE_DASHES); + "Skill file frontmatter not properly closed with " + THREE_DASHES, SKILL_FORMAT_ERROR); } // Read the instructions till the end of the file StringBuilder sb = new StringBuilder(); diff --git a/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java b/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java index 42916e36a..d299dfb21 100644 --- a/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java +++ b/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java @@ -16,6 +16,8 @@ package com.google.adk.skills; +import static com.google.adk.skills.SkillSourceException.RESOURCE_NOT_FOUND; +import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.nio.charset.StandardCharsets.UTF_8; @@ -56,7 +58,8 @@ public Single> listFrontmatters() { public Single> listResources(String skillName, String resourceDirectory) { SkillData data = skills.get(skillName); if (data == null) { - return Single.error(new SkillSourceException("Skill not found: " + skillName)); + return Single.error( + new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); } String prefix = resourceDirectory.isEmpty() @@ -67,7 +70,8 @@ public Single> listResources(String skillName, String reso && data.resources().keySet().stream().noneMatch(path -> path.startsWith(prefix))) { return Single.error( new SkillSourceException( - "Resource directory not found: " + resourceDirectory + " for skill: " + skillName)); + "Resource directory not found: " + resourceDirectory + " for skill: " + skillName, + RESOURCE_NOT_FOUND)); } return Single.just( @@ -92,13 +96,16 @@ public Single loadResource(String skillName, String resourcePath) { .map(SkillData::resources) .mapOptional(m -> Optional.ofNullable(m.get(resourcePath))) .switchIfEmpty( - Single.error(new SkillSourceException("Resource not found: " + resourcePath))); + Single.error( + new SkillSourceException( + "Resource not found: " + resourcePath, RESOURCE_NOT_FOUND))); } private Single getSkillData(String skillName) { SkillData data = skills.get(skillName); if (data == null) { - return Single.error(new SkillSourceException("Skill not found: " + skillName)); + return Single.error( + new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); } return Single.just(data); } diff --git a/core/src/main/java/com/google/adk/skills/LocalSkillSource.java b/core/src/main/java/com/google/adk/skills/LocalSkillSource.java index 939c30b3c..b4b7d4876 100644 --- a/core/src/main/java/com/google/adk/skills/LocalSkillSource.java +++ b/core/src/main/java/com/google/adk/skills/LocalSkillSource.java @@ -16,6 +16,10 @@ package com.google.adk.skills; +import static com.google.adk.skills.SkillSourceException.RESOURCE_LOAD_ERROR; +import static com.google.adk.skills.SkillSourceException.RESOURCE_NOT_FOUND; +import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; +import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.nio.file.Files.isDirectory; @@ -43,14 +47,16 @@ public LocalSkillSource(Path skillsBasePath) { public Single> listResources(String skillName, String resourceDirectory) { Path skillDir = skillsBasePath.resolve(skillName); if (!isDirectory(skillDir)) { - return Single.error(new SkillSourceException("Skill not found: " + skillName)); + return Single.error( + new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); } Path resourceDir = skillDir.resolve(resourceDirectory); if (!isDirectory(resourceDir)) { return Single.error( new SkillSourceException( "Resource directory '%s' not found for skill '%s'" - .formatted(resourceDirectory, skillName))); + .formatted(resourceDirectory, skillName), + RESOURCE_NOT_FOUND)); } return Single.fromCallable( @@ -67,7 +73,9 @@ public Single> listResources(String skillName, String reso t -> Single.error( new SkillSourceException( - "Failed to traverse resource directory: " + resourceDirectory, t))); + "Failed to traverse resource directory: " + resourceDirectory, + RESOURCE_LOAD_ERROR, + t))); } @Override @@ -78,7 +86,9 @@ protected Flowable listSkills() { t -> Flowable.error( new SkillSourceException( - "Failed to list skills in directory: " + skillsBasePath, t))) + "Failed to list skills in directory: " + skillsBasePath, + SKILL_LOAD_ERROR, + t))) .filter(Files::isDirectory) .mapOptional(this::findSkillMd) .map(skillMd -> new SkillMdPath(skillMd.getParent().getFileName().toString(), skillMd)); @@ -88,7 +98,8 @@ protected Flowable listSkills() { protected Single findResourcePath(String skillName, String resourcePath) { Path file = skillsBasePath.resolve(skillName).resolve(resourcePath); if (!Files.exists(file)) { - return Single.error(new SkillSourceException("Resource not found: " + file)); + return Single.error( + new SkillSourceException("Resource not found: " + file, RESOURCE_NOT_FOUND)); } return Single.just(file); } @@ -97,11 +108,13 @@ protected Single findResourcePath(String skillName, String resourcePath) { protected Single findSkillMdPath(String skillName) { Path skillDir = skillsBasePath.resolve(skillName); if (!isDirectory(skillDir)) { - return Single.error(new SkillSourceException("Skill directory not found: " + skillName)); + return Single.error( + new SkillSourceException("Skill directory not found: " + skillName, SKILL_NOT_FOUND)); } return Maybe.fromOptional(findSkillMd(skillDir)) .switchIfEmpty( - Single.error(new SkillSourceException("SKILL.md not found in " + skillName))); + Single.error( + new SkillSourceException("SKILL.md not found in " + skillName, SKILL_NOT_FOUND))); } @Override diff --git a/core/src/main/java/com/google/adk/skills/SkillSourceException.java b/core/src/main/java/com/google/adk/skills/SkillSourceException.java index be23291da..273428897 100644 --- a/core/src/main/java/com/google/adk/skills/SkillSourceException.java +++ b/core/src/main/java/com/google/adk/skills/SkillSourceException.java @@ -22,11 +22,43 @@ */ public final class SkillSourceException extends Exception { - public SkillSourceException(String message) { + public static final String SKILL_LOAD_ERROR = "SKILL_LOAD_ERROR"; + public static final String SKILL_NOT_FOUND = "SKILL_NOT_FOUND"; + public static final String SKILL_FORMAT_ERROR = "SKILL_FORMAT_ERROR"; + public static final String RESOURCE_LOAD_ERROR = "RESOURCE_LOAD_ERROR"; + public static final String RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"; + + private final String errorCode; + + /** + * Constructs a new exception with the specified detail message and error code. + * + * @param message The detail message. + * @param errorCode The specific error code categorizing the failure. + */ + public SkillSourceException(String message, String errorCode) { super(message); + this.errorCode = errorCode; } - public SkillSourceException(String message, Throwable cause) { + /** + * Constructs a new exception with the specified detail message, error code, and cause. + * + * @param message The detail message. + * @param errorCode The specific error code categorizing the failure. + * @param cause The cause. + */ + public SkillSourceException(String message, String errorCode, Throwable cause) { super(message, cause); + this.errorCode = errorCode; + } + + /** + * Returns the error code categorizing the failure. + * + * @return The error code string. + */ + public String getErrorCode() { + return errorCode; } } diff --git a/core/src/main/java/com/google/adk/tools/BaseTool.java b/core/src/main/java/com/google/adk/tools/BaseTool.java index e7c15af20..7a896ad1a 100644 --- a/core/src/main/java/com/google/adk/tools/BaseTool.java +++ b/core/src/main/java/com/google/adk/tools/BaseTool.java @@ -45,7 +45,7 @@ import org.slf4j.LoggerFactory; /** The base class for all ADK tools. */ -public abstract class BaseTool { +public abstract class BaseTool implements LlmRequestProcessor { private final String name; private final String description; private final boolean isLongRunning; @@ -183,6 +183,7 @@ private Single runAsync( * internal list of tools. Override this method for processing the outgoing request. */ @CanIgnoreReturnValue + @Override public Completable processLlmRequest( LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { if (declaration().isEmpty()) { diff --git a/core/src/main/java/com/google/adk/tools/BaseToolset.java b/core/src/main/java/com/google/adk/tools/BaseToolset.java index 76369e5b9..79b2044b4 100644 --- a/core/src/main/java/com/google/adk/tools/BaseToolset.java +++ b/core/src/main/java/com/google/adk/tools/BaseToolset.java @@ -28,7 +28,7 @@ public interface BaseToolset extends AutoCloseable { * Return all tools in the toolset based on the provided context. * * @param readonlyContext Context used to filter tools available to the agent. - * @return A Single emitting a list of tools available under the specified context. + * @return A Flowable emitting tools available under the specified context. */ Flowable getTools(ReadonlyContext readonlyContext); diff --git a/core/src/main/java/com/google/adk/tools/LlmRequestProcessor.java b/core/src/main/java/com/google/adk/tools/LlmRequestProcessor.java new file mode 100644 index 000000000..db5986444 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/LlmRequestProcessor.java @@ -0,0 +1,43 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.google.adk.models.LlmRequest; +import io.reactivex.rxjava3.core.Completable; + +/** + * Interface for components that can intercept and modify an outgoing {@link LlmRequest} before it + * is sent to the LLM. + * + *

During flow execution, if a {@link BaseToolset} implements this, the toolset's {@link + * #processLlmRequest} method will be executed, followed by executing all the tools' {@link + * #processLlmRequest} method of tools returned by the toolset. + */ +public interface LlmRequestProcessor { + + /** + * Intercepts and modifies the outgoing {@link LlmRequest} using the provided builder. + * + *

Implementations can mutate the {@code llmRequestBuilder} to alter the request that will be + * sent to the LLM. + * + * @param llmRequestBuilder The builder for the outgoing {@link LlmRequest}. + * @param toolContext The context in which the tool or toolset is being executed. + * @return A {@link Completable} that completes when the request processing is done. + */ + Completable processLlmRequest(LlmRequest.Builder llmRequestBuilder, ToolContext toolContext); +} diff --git a/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java b/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java new file mode 100644 index 000000000..bc669632a --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java @@ -0,0 +1,59 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import com.google.adk.skills.SkillSource; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import com.google.genai.types.Type; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; + +/** Tool to list all available skills. */ +final class ListSkillsTool extends BaseTool { + private final SkillSource skillSource; + + ListSkillsTool(SkillSource skillSource) { + super("list_skills", "Lists all available skills with their names and descriptions."); + this.skillSource = skillSource; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters( + Schema.builder().type(Type.Known.OBJECT).properties(ImmutableMap.of()).build()) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + return skillSource + .listFrontmatters() + .map(ImmutableMap::values) + .map(SkillToolset::getSkillsPrompt) + .>map(skills -> ImmutableMap.of("skills_xml", skills)) + .onErrorResumeNext(SkillToolset::createErrorResponse); + } +} diff --git a/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java b/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java new file mode 100644 index 000000000..39f68d414 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java @@ -0,0 +1,242 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.adk.tools.skills.SkillToolset.createErrorResponse; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.net.URLConnection.guessContentTypeFromName; +import static java.net.URLConnection.guessContentTypeFromStream; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.adk.models.LlmRequest; +import com.google.adk.skills.SkillSource; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.io.ByteSource; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import com.google.genai.types.Type; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; + +/** Tool to load resources (references, assets, or scripts) from a skill. */ +final class LoadSkillResourceTool extends BaseTool { + + private static final ImmutableSet EXTRA_TEXT_MIME_TYPES = + ImmutableSet.of( + // go/keep-sorted start + "application/json", + "application/x-python", + "application/x-sh", + "application/x-shar", + "application/x-shellscript", + "application/xml", + "application/yaml" + // go/keep-sorted end + ); + private static final String BINARY_FILE_DETECTED_MSG = + "Binary file detected. The content has been included in the next part of the function" + + " response for you to analyze."; + private static final String SKILL_NAME = "skill_name"; + private static final String FILE_PATH = "file_path"; + private static final String CONTENT = "content"; + private static final String MIME_TYPE = "mime_type"; + + private final SkillSource skillSource; + + LoadSkillResourceTool(SkillSource skillSource) { + super( + "load_skill_resource", + "Loads a resource file (from references/, assets/, or scripts/) from within a skill."); + this.skillSource = skillSource; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters( + Schema.builder() + .type(Type.Known.OBJECT) + .properties( + ImmutableMap.of( + SKILL_NAME, + Schema.builder() + .type(Type.Known.STRING) + .description("The name of the skill.") + .build(), + FILE_PATH, + Schema.builder() + .type(Type.Known.STRING) + .description( + "The relative path to the resource (e.g.," + + " 'references/my_doc.md', 'assets/template.txt'," + + " or 'scripts/setup.sh').") + .build())) + .required(ImmutableList.of(SKILL_NAME, FILE_PATH)) + .build()) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + String skillName = (String) args.get(SKILL_NAME); + String resourcePath = (String) args.get(FILE_PATH); + + if (Strings.isNullOrEmpty(skillName)) { + return createErrorResponse("Skill name is required.", "MISSING_SKILL_NAME"); + } + if (Strings.isNullOrEmpty(resourcePath)) { + return createErrorResponse("Resource path is required.", "MISSING_RESOURCE_PATH"); + } + if (!resourcePath.startsWith("references/") + && !resourcePath.startsWith("assets/") + && !resourcePath.startsWith("scripts/")) { + return createErrorResponse( + "Path must start with 'references/', 'assets/', or 'scripts/'.", "INVALID_RESOURCE_PATH"); + } + + return skillSource + .loadResource(skillName, resourcePath) + .>map( + contentSource -> createResult(skillName, resourcePath, contentSource)) + .onErrorResumeNext(SkillToolset::createErrorResponse); + } + + private boolean hasBinaryContentResponse(FunctionResponse functionResponse) { + return functionResponse + .response() + .filter( + resp -> + resp.containsKey(SKILL_NAME) + && resp.containsKey(MIME_TYPE) + && resp.get(CONTENT) instanceof byte[]) + .isPresent(); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return super.processLlmRequest(llmRequestBuilder, toolContext) + .andThen( + Completable.fromRunnable( + () -> { + List contents = new ArrayList<>(llmRequestBuilder.build().contents()); + if (contents.isEmpty()) { + return; + } + + Content lastContent = Iterables.getLast(contents); + List parts = lastContent.parts().orElse(ImmutableList.of()); + + // Extract raw binary content into a dedicated binary Part + ImmutableList updatedParts = + parts.stream().flatMap(this::processPart).collect(toImmutableList()); + + if (!updatedParts.isEmpty()) { + contents.set( + contents.size() - 1, lastContent.toBuilder().parts(updatedParts).build()); + llmRequestBuilder.contents(contents); + } + })); + } + + /** + * Processes a {@link Part} to extract raw binary content from a function response. + * + *

If the part is a function response from this tool containing binary data, it returns a + * stream containing the updated function response part (with a placeholder message) and a new + * part containing the raw binary data. Otherwise, it returns a stream of the original {@link + * Part}. + * + * @param part the {@link Part} to process + * @return a stream containing the processed parts + */ + private Stream processPart(Part part) { + FunctionResponse binaryFunctionResponse = + part.functionResponse() + .filter(funcResp -> funcResp.name().orElse("").equals(name())) + .filter(this::hasBinaryContentResponse) + .orElse(null); + + if (binaryFunctionResponse == null) { + return Stream.of(part); + } + + return binaryFunctionResponse.response().stream() + .flatMap( + response -> { + Map newResponse = new HashMap<>(response); + + String mimeType = newResponse.remove(MIME_TYPE).toString(); + byte[] binaryContent = + (byte[]) newResponse.replace(CONTENT, BINARY_FILE_DETECTED_MSG); + + Part updatedPart = + part.toBuilder() + .functionResponse(binaryFunctionResponse.toBuilder().response(newResponse)) + .build(); + Part binaryPart = Part.fromBytes(binaryContent, mimeType); + + return Stream.of(updatedPart, binaryPart); + }); + } + + private ImmutableMap createResult( + String skillName, String resourcePath, ByteSource contentSource) throws IOException { + byte[] bytes = contentSource.read(); + // Special handling of shell script as the guessContentTypeFromName would return + // application/x-shar + String contentType = + resourcePath.endsWith(".sh") || resourcePath.endsWith(".bash") + ? "application/x-sh" + : guessContentTypeFromName(resourcePath); + if (contentType == null) { + contentType = guessContentTypeFromStream(new ByteArrayInputStream(bytes)); + } + if (contentType == null) { + contentType = "application/octet-stream"; + } + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.put(SKILL_NAME, skillName).put(FILE_PATH, resourcePath).put(MIME_TYPE, contentType); + + if (contentType.startsWith("text/") || EXTRA_TEXT_MIME_TYPES.contains(contentType)) { + builder.put(CONTENT, new String(bytes, UTF_8)); + } else { + builder.put(CONTENT, bytes); + } + return builder.buildOrThrow(); + } +} diff --git a/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java b/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java new file mode 100644 index 000000000..6d47e297a --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java @@ -0,0 +1,87 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.adk.tools.skills.SkillToolset.createErrorResponse; + +import com.google.adk.skills.SkillSource; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import com.google.genai.types.Type; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; + +/** Tool to load a skill's instructions. */ +final class LoadSkillTool extends BaseTool { + + private static final String SKILL_NAME = "skill_name"; + private final SkillSource skillSource; + + LoadSkillTool(SkillSource skillSource) { + super("load_skill", "Loads the SKILL.md instructions for a given skill."); + this.skillSource = skillSource; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters( + Schema.builder() + .type(Type.Known.OBJECT) + .properties( + ImmutableMap.of( + SKILL_NAME, + Schema.builder() + .type(Type.Known.STRING) + .description("The name of the skill to load.") + .build())) + .required(ImmutableList.of(SKILL_NAME)) + .build()) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + String skillName = (String) args.get(SKILL_NAME); + if (Strings.isNullOrEmpty(skillName)) { + return createErrorResponse("Skill name is required.", "MISSING_SKILL_NAME"); + } + + return skillSource + .loadFrontmatter(skillName) + .>zipWith( + skillSource.loadInstructions(skillName), + (frontmatter, instructions) -> + ImmutableMap.of( + "skill_name", + skillName, + "frontmatter", + frontmatter, + "instructions", + instructions)) + .onErrorResumeNext(SkillToolset::createErrorResponse); + } +} diff --git a/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java b/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java new file mode 100644 index 000000000..979592c9e --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java @@ -0,0 +1,131 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static java.util.Optional.ofNullable; + +import com.google.adk.agents.ReadonlyContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.SkillSource; +import com.google.adk.skills.SkillSourceException; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.LlmRequestProcessor; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.util.Collection; +import java.util.Map; +import java.util.StringJoiner; + +/** + * A toolset for managing and interacting with agent skills. Provides tools to list, load, and run + * skills. + */ +public class SkillToolset implements BaseToolset, LlmRequestProcessor { + + private static final String DEFAULT_SKILL_SYSTEM_INSTRUCTION = + """ + You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. + + Skills are folders of instructions and resources that extend your capabilities for specialized tasks. Each skill folder contains: + - **SKILL.md** (required): The main instruction file with skill metadata and detailed markdown instructions. + - **references/** (Optional): Additional documentation or examples for skill usage. + - **assets/** (Optional): Templates, scripts or other resources used by the skill. + - **scripts/** (Optional): Executable scripts that can be run via bash. + + This is very important: + + 1. If a skill seems relevant to the current user query, you MUST use the `load_skill` tool with `skill_name=""` to read its full instructions before proceeding. + 2. Once you have read the instructions, follow them exactly as documented before replying to the user. For example, If the instruction lists multiple steps, please make sure you complete all of them in order. + 3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`, `scripts/*`). Do NOT use other tools to access these files. + 4. Use `run_skill_script` to run scripts from a skill's `scripts/` directory. Use `load_skill_resource` to view script content first if needed. + """; + + private final SkillSource skillSource; + private final ImmutableList coreTools; + private final String systemInstruction; + + /** Initializes the SkillToolset with a SkillSource and default execution settings. */ + public SkillToolset(SkillSource skillSource) { + this(skillSource, DEFAULT_SKILL_SYSTEM_INSTRUCTION); + } + + /** Initializes the SkillToolset with a SkillSource. */ + public SkillToolset(SkillSource skillSource, String systemInstruction) { + this.skillSource = skillSource; + this.systemInstruction = systemInstruction; + this.coreTools = + ImmutableList.of( + new ListSkillsTool(skillSource), + new LoadSkillTool(skillSource), + new LoadSkillResourceTool(skillSource)); + } + + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.fromIterable(coreTools); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return skillSource + .listFrontmatters() + .map(ImmutableMap::values) + .map(SkillToolset::getSkillsPrompt) + .map( + skills -> + llmRequestBuilder.appendInstructions(ImmutableList.of(systemInstruction, skills))) + .ignoreElement(); + } + + @Override + public void close() throws Exception { + // No resources to release for now + } + + static Single> createErrorResponse(String errorMessage, String errorCode) { + return Single.just(ImmutableMap.of("error", errorMessage, "error_code", errorCode)); + } + + static Single> createErrorResponse(Throwable t) { + if (t instanceof SkillSourceException ex) { + return Single.just( + ImmutableMap.of( + "error", + ofNullable(ex.getMessage()).orElse(ex.toString()), + "error_code", + ex.getErrorCode())); + } + return Single.error(t); + } + + static String getSkillsPrompt(Collection frontmatters) { + return frontmatters.stream() + .map(Frontmatter::toXml) + .reduce( + new StringJoiner("\n", "", "").setEmptyValue(""), + StringJoiner::add, + StringJoiner::merge) + .toString(); + } +} diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index e40a83aa0..35cf12f6f 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -440,7 +440,7 @@ public void close_closesToolsets() throws Exception { } @Test - public void close_closesToolsetsOnException() throws Exception { + public void close_closesToolsetsOnException() { ClosableToolset toolset1 = new ClosableToolset() { @Override diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 2a06c1f0a..a4ec3a7c0 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -25,9 +25,12 @@ import static com.google.adk.testing.TestUtils.createTestLlm; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.ReadonlyContext; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult; @@ -35,6 +38,8 @@ import com.google.adk.models.LlmResponse; import com.google.adk.testing.TestLlm; import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.LlmRequestProcessor; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -47,6 +52,7 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.ContextKey; import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; @@ -762,4 +768,97 @@ public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { assertThat(event.author()).isEqualTo(invocationContext.agent().name()); assertThat(event.invocationId()).isEqualTo(invocationContext.invocationId()); } + + @Test + public void run_withTools_sequentiallyAppliesToolProcessors() { + BaseTool tool1 = + new BaseTool("tool1", "test tool 1") { + @Override + public Completable processLlmRequest( + LlmRequest.Builder builder, ToolContext toolContext) { + return Completable.fromAction( + () -> builder.appendInstructions(ImmutableList.of("instruction1"))); + } + }; + BaseTool tool2 = + new BaseTool("tool2", "test tool 2") { + @Override + public Completable processLlmRequest( + LlmRequest.Builder builder, ToolContext toolContext) { + return Completable.fromAction( + () -> builder.appendInstructions(ImmutableList.of("instruction2"))); + } + }; + + TestLlm testLlm = createTestLlm(LlmResponse.builder().build()); + LlmAgent agent = createTestAgentBuilder(testLlm).tools(tool1, tool2).build(); + + InvocationContext invocationContext = createInvocationContext(agent); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(testLlm.getLastRequest().getSystemInstructions()) + .containsExactly("instruction1\n\ninstruction2"); + } + + @Test + public void run_withToolset_appliesToolsetAndItsToolsProcessors() { + TestToolset toolset = new TestToolset(); + TestLlm testLlm = createTestLlm(LlmResponse.builder().build()); + LlmAgent agent = createTestAgentBuilder(testLlm).tools(toolset).build(); + + InvocationContext invocationContext = createInvocationContext(agent); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(testLlm.getLastRequest().getSystemInstructions()) + .containsExactly("toolset-instruction\n\ntool-instruction"); + } + + private static final class TestToolset implements BaseToolset, LlmRequestProcessor { + private final BaseTool tool1 = + new BaseTool("tool1", "test tool 1") { + @Override + public Completable processLlmRequest( + LlmRequest.Builder builder, ToolContext toolContext) { + return Completable.fromAction( + () -> builder.appendInstructions(ImmutableList.of("tool-instruction"))); + } + }; + + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.just(tool1); + } + + @Override + public Completable processLlmRequest(LlmRequest.Builder builder, ToolContext toolContext) { + return Completable.fromAction( + () -> builder.appendInstructions(ImmutableList.of("toolset-instruction"))); + } + + @Override + public void close() {} + } + + @Test + public void run_withUnsupportedToolType_throwsIllegalArgumentException() { + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .tools("unsupported-tool-type-string") + .build(); + + InvocationContext invocationContext = createInvocationContext(agent); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + Flowable flow = baseLlmFlow.run(invocationContext); + Single> single = flow.toList(); + IllegalArgumentException thrown = + assertThrows(IllegalArgumentException.class, single::blockingGet); + assertThat(thrown) + .hasMessageThat() + .contains("Object in tools list is not of a supported type: java.lang.String"); + } } diff --git a/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java b/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java index 256f1d66a..5efdb4ab9 100644 --- a/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java +++ b/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java @@ -224,6 +224,8 @@ public void testLoadResource_notFound() throws IOException { var single = source.loadResource("my-skill", "non-existent.txt"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + SkillSourceException cause = (SkillSourceException) exception.getCause(); + assertThat(cause.getErrorCode()).isEqualTo(SkillSourceException.RESOURCE_NOT_FOUND); } @Test @@ -234,6 +236,8 @@ public void testLoadFrontmatter_skillNotFound() { var single = source.loadFrontmatter("non-existent"); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + SkillSourceException cause = (SkillSourceException) exception.getCause(); + assertThat(cause.getErrorCode()).isEqualTo(SkillSourceException.SKILL_NOT_FOUND); } @Test @@ -249,5 +253,125 @@ public void testListSkillMdPaths_skillSourceException() throws IOException { var single = source.listFrontmatters(); RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + SkillSourceException cause = (SkillSourceException) exception.getCause(); + assertThat(cause.getErrorCode()).isEqualTo(SkillSourceException.SKILL_LOAD_ERROR); + } + + @Test + public void testLoadFrontmatter_missingStartDashes() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + name: my-skill + description: This is a test skill + --- + body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadFrontmatter("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file must start with ---"); + } + + @Test + public void testLoadInstructions_missingStartDashes() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + name: my-skill + description: Test + --- + Some Markdown Body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadInstructions("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file must start with ---"); + } + + @Test + public void testLoadFrontmatter_nameMismatch() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + --- + name: other-skill + description: This is a test skill + --- + body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadFrontmatter("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains( + "Skill name in the frontmatter 'other-skill' does not match skill name 'my-skill'."); + } + + @Test + public void testLoadFrontmatter_emptyFile() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString(skillDir.resolve("SKILL.md"), ""); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadFrontmatter("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file must start with ---"); + } + + @Test + public void testLoadInstructions_emptyFile() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString(skillDir.resolve("SKILL.md"), ""); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadInstructions("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file must start with ---"); } } diff --git a/core/src/test/java/com/google/adk/testing/TestLlm.java b/core/src/test/java/com/google/adk/testing/TestLlm.java index aaacf00a0..fc9ce3850 100644 --- a/core/src/test/java/com/google/adk/testing/TestLlm.java +++ b/core/src/test/java/com/google/adk/testing/TestLlm.java @@ -42,7 +42,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Predicate; import java.util.function.Supplier; -import javax.annotation.Nullable; +import org.jspecify.annotations.Nullable; /** * A test implementation of {@link BaseLlm}. diff --git a/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java b/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java new file mode 100644 index 000000000..fe5b202a2 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java @@ -0,0 +1,151 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.sessions.Session; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.InMemorySkillSource; +import com.google.adk.skills.SkillSource; +import com.google.adk.skills.SkillSourceException; +import com.google.adk.testing.TestBaseAgent; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ListSkillsToolTest { + + @Test + public void call_listSkillsTool_success() { + Frontmatter testFrontmatter = + Frontmatter.builder().name("test-skill").description("test skill").build(); + + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = + InMemorySkillSource.builder() + .skill(testFrontmatter.name()) + .frontmatter(testFrontmatter) + .instructions("Test instructions") + .build(); + ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); + Map response = + listSkillsTool + .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skills_xml", "" + testFrontmatter.toXml() + ""); + } + + @Test + public void call_listSkillsTool_empty() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + ListSkillsTool listSkillsTool = new ListSkillsTool(InMemorySkillSource.builder().build()); + Map response = + listSkillsTool + .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response).containsExactly("skills_xml", ""); + } + + @Test + public void call_listSkillsTool_skillSourceException() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + when(skillSource.listFrontmatters()) + .thenReturn( + Single.error(new SkillSourceException("Failed to list skills", SKILL_LOAD_ERROR))); + + ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); + Map response = + listSkillsTool + .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response) + .containsExactly("error", "Failed to list skills", "error_code", "SKILL_LOAD_ERROR"); + } + + @Test + public void call_listSkillsTool_otherException() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + RuntimeException expectedException = new RuntimeException("Unexpected error"); + when(skillSource.listFrontmatters()).thenReturn(Single.error(expectedException)); + + ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); + var single = + listSkillsTool.runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()); + + RuntimeException thrown = assertThrows(RuntimeException.class, single::blockingGet); + + assertThat(thrown).hasMessageThat().contains("Unexpected error"); + } + + @Test + public void call_listSkillsTool_declaration() { + ListSkillsTool listSkillsTool = new ListSkillsTool(mock(SkillSource.class)); + assertThat(listSkillsTool.declaration().get().name()).hasValue("list_skills"); + } +} diff --git a/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java b/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java new file mode 100644 index 000000000..3ab9b2a17 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java @@ -0,0 +1,330 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.InMemorySkillSource; +import com.google.adk.skills.SkillSource; +import com.google.adk.testing.TestBaseAgent; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.util.List; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class LoadSkillResourceToolTest { + + @Test + public void call_loadSkillResourceTool_reference_success() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "references/my_doc.md"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", "test-skill", + "file_path", "references/my_doc.md", + "mime_type", "text/markdown", + "content", "doc content"); + } + + @Test + public void call_loadSkillResourceTool_asset_success() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "assets/template.txt"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", "test-skill", + "file_path", "assets/template.txt", + "mime_type", "text/plain", + "content", "asset content"); + } + + @Test + public void call_loadSkillResourceTool_script_success() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "scripts/setup.sh"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", "test-skill", + "file_path", "scripts/setup.sh", + "mime_type", "application/x-sh", + "content", "echo hello"); + } + + @Test + public void call_loadSkillResourceTool_streamDetection_success() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "assets/data_no_ext"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", "test-skill", + "file_path", "assets/data_no_ext", + "mime_type", "application/xml", + "content", ""); + } + + @Test + public void call_loadSkillResourceTool_binaryReference_detected() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + ToolContext toolContext = createToolContext(); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "references/binary.dat"), + toolContext) + .blockingGet(); + + Part partFunctionResponse = + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(toolContext.functionCallId().orElse("")) + .name(loadSkillResourceTool.name()) + .response(response) + .build()) + .build(); + + // Binary data is added as separate part in the next request to LLM + LlmRequest.Builder builder = + LlmRequest.builder() + .contents( + ImmutableList.of( + Content.builder().role("user").parts(partFunctionResponse).build())); + loadSkillResourceTool.processLlmRequest(builder, toolContext).blockingAwait(); + + List contents = builder.build().contents(); + assertThat(contents).hasSize(1); + List parts = contents.get(0).parts().get(); + assertThat(parts).hasSize(2); + + FunctionResponse updatedFunctionResponse = parts.get(0).functionResponse().get(); + assertThat(updatedFunctionResponse.response().get()) + .containsExactly( + "skill_name", + "test-skill", + "file_path", + "references/binary.dat", + "content", + "Binary file detected. The content has been included in the next part of the function" + + " response for you to analyze."); + + Part binaryPart = parts.get(1); + assertThat(binaryPart.inlineData().get().mimeType()).hasValue("application/octet-stream"); + assertThat(binaryPart.inlineData().get().data().get()).isEqualTo(new byte[] {0, 1, 2, 3}); + } + + @Test + public void call_loadSkillResourceTool_nonBinaryReference_notChanged() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + ToolContext toolContext = createToolContext(); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "references/my_doc.md"), + toolContext) + .blockingGet(); + + Part partFunctionResponse = + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(toolContext.functionCallId().orElse("")) + .name(loadSkillResourceTool.name()) + .response(response) + .build()) + .build(); + + LlmRequest.Builder builder = + LlmRequest.builder() + .contents( + ImmutableList.of( + Content.builder().role("user").parts(partFunctionResponse).build())); + List expectedContents = builder.build().contents(); + + loadSkillResourceTool.processLlmRequest(builder, toolContext).blockingAwait(); + + assertThat(builder.build().contents()).isEqualTo(expectedContents); + } + + @Test + public void call_loadSkillResourceTool_missingSkillName() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync(ImmutableMap.of("file_path", "references/my_doc.md"), createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Skill name is required.", + "error_code", "MISSING_SKILL_NAME"); + } + + @Test + public void call_loadSkillResourceTool_missingPath() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync(ImmutableMap.of("skill_name", "test-skill"), createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Resource path is required.", + "error_code", "MISSING_RESOURCE_PATH"); + } + + @Test + public void call_loadSkillResourceTool_skillNotFound() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "other-skill", "file_path", "references/my_doc.md"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Skill not found: other-skill", + "error_code", "SKILL_NOT_FOUND"); + } + + @Test + public void call_loadSkillResourceTool_invalidPathPrefix() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "invalid/my_doc.md"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Path must start with 'references/', 'assets/', or 'scripts/'.", + "error_code", "INVALID_RESOURCE_PATH"); + } + + @Test + public void call_loadSkillResourceTool_resourceNotFound() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "references/missing.md"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Resource not found: references/missing.md", + "error_code", "RESOURCE_NOT_FOUND"); + } + + @Test + public void call_loadSkillResourceTool_declaration() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(mock(SkillSource.class)); + assertThat(loadSkillResourceTool.declaration().get().name()).hasValue("load_skill_resource"); + } + + @Test + public void call_loadSkillResourceTool_processLlmRequest_emptyContents() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(mock(SkillSource.class)); + LlmRequest.Builder builder = LlmRequest.builder().contents(ImmutableList.of()); + loadSkillResourceTool.processLlmRequest(builder, createToolContext()).blockingAwait(); + assertThat(builder.build().contents()).isEmpty(); + } + + private ToolContext createToolContext() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + return ToolContext.builder(invocationContext).build(); + } + + private SkillSource createTestSkillSource() { + return InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .addResource("references/my_doc.md", "doc content".getBytes(UTF_8)) + .addResource("references/binary.dat", new byte[] {0, 1, 2, 3}) + .addResource("assets/template.txt", "asset content".getBytes(UTF_8)) + .addResource("scripts/setup.sh", "echo hello".getBytes(UTF_8)) + .addResource("assets/data_no_ext", "".getBytes(UTF_8)) + .build(); + } +} diff --git a/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java b/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java new file mode 100644 index 000000000..29b127755 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java @@ -0,0 +1,161 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.sessions.Session; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.InMemorySkillSource; +import com.google.adk.skills.SkillSource; +import com.google.adk.skills.SkillSourceException; +import com.google.adk.testing.TestBaseAgent; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class LoadSkillToolTest { + + @Test + public void call_loadSkillTool_success() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = + InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .build(); + LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); + Map response = + loadSkillTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill"), + ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", + "test-skill", + "instructions", + "Test instructions", + "frontmatter", + Frontmatter.builder().name("test-skill").description("test skill").build()); + } + + @Test + public void call_loadSkillTool_missingSkillName() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); + Map response = + loadSkillTool + .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response) + .containsExactly("error", "Skill name is required.", "error_code", "MISSING_SKILL_NAME"); + } + + @Test + public void call_loadSkillTool_skillSourceException() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + when(skillSource.loadFrontmatter("test-skill")) + .thenReturn(Single.error(new SkillSourceException("Skill not found", SKILL_NOT_FOUND))); + when(skillSource.loadInstructions("test-skill")).thenReturn(Single.just("instructions")); + + LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); + Map response = + loadSkillTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill"), + ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response).containsExactly("error", "Skill not found", "error_code", SKILL_NOT_FOUND); + } + + @Test + public void call_loadSkillTool_otherException() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + RuntimeException expectedException = new RuntimeException("Unexpected error"); + when(skillSource.loadFrontmatter("test-skill")).thenReturn(Single.error(expectedException)); + when(skillSource.loadInstructions("test-skill")).thenReturn(Single.just("instructions")); + + LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); + var single = + loadSkillTool.runAsync( + ImmutableMap.of("skill_name", "test-skill"), + ToolContext.builder(invocationContext).build()); + + RuntimeException thrown = assertThrows(RuntimeException.class, single::blockingGet); + + assertThat(thrown).hasMessageThat().contains("Unexpected error"); + } + + @Test + public void call_loadSkillTool_declaration() { + LoadSkillTool loadSkillTool = new LoadSkillTool(mock(SkillSource.class)); + assertThat(loadSkillTool.declaration().get().name()).hasValue("load_skill"); + } +} diff --git a/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java b/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java new file mode 100644 index 000000000..d2a25439a --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java @@ -0,0 +1,125 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; + +import com.google.adk.models.LlmRequest; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.InMemorySkillSource; +import com.google.adk.skills.SkillSource; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.truth.Correspondence; +import io.reactivex.rxjava3.core.Flowable; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SkillToolsetTest { + + @Test + public void getTools_returnsCoreTools() throws Exception { + SkillSource mockSkillSource = mock(SkillSource.class); + try (SkillToolset toolSet = new SkillToolset(mockSkillSource)) { + Flowable tools = toolSet.getTools(null); + List baseTools = tools.toList().blockingGet(); + + assertThat(baseTools) + .comparingElementsUsing(Correspondence.transforming(BaseTool::name, "Tool name")) + .containsExactly("list_skills", "load_skill", "load_skill_resource"); + } + } + + @Test + public void getTools_withInMemorySkills() throws Exception { + SkillSource skillSource = + InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .build(); + try (SkillToolset toolSet = new SkillToolset(skillSource)) { + + Flowable tools = toolSet.getTools(null); + List baseTools = tools.toList().blockingGet(); + + assertThat(baseTools) + .comparingElementsUsing(Correspondence.transforming(BaseTool::name, "Tool name")) + .containsExactly("list_skills", "load_skill", "load_skill_resource"); + } + } + + @Test + public void processLlmRequest_addsInstructions() throws Exception { + SkillSource skillSource = + InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .build(); + try (SkillToolset toolSet = new SkillToolset(skillSource)) { + + LlmRequest.Builder requestBuilder = LlmRequest.builder(); + ToolContext mockToolContext = mock(ToolContext.class); + + toolSet.processLlmRequest(requestBuilder, mockToolContext).blockingAwait(); + + LlmRequest request = requestBuilder.build(); + ImmutableList instructions = request.getSystemInstructions(); + + assertThat(instructions).isNotEmpty(); + String instruction = instructions.get(0); + assertThat(instruction) + .contains("You can use specialized 'skills' to help you with complex tasks"); + assertThat(instruction).contains(""); + assertThat(instruction).contains("test-skill"); + } + } + + @Test + public void processLlmRequest_withCustomSystemInstruction_addsCustomInstructions() + throws Exception { + SkillSource skillSource = + InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .build(); + String customInstruction = "Custom system instruction for testing."; + try (SkillToolset toolSet = new SkillToolset(skillSource, customInstruction)) { + + LlmRequest.Builder requestBuilder = LlmRequest.builder(); + ToolContext mockToolContext = mock(ToolContext.class); + + toolSet.processLlmRequest(requestBuilder, mockToolContext).blockingAwait(); + + LlmRequest request = requestBuilder.build(); + ImmutableList instructions = request.getSystemInstructions(); + + assertThat(instructions).isNotEmpty(); + String instruction = instructions.get(0); + assertThat(instruction).contains(customInstruction); + assertThat(instruction).contains(""); + assertThat(instruction).contains("test-skill"); + } + } +}