From cc5263cab5f13615333c80def2edd497968115a8 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 19 May 2026 07:59:40 -0700 Subject: [PATCH] feat: remove special handling for Gemini 3 function response ordering PiperOrigin-RevId: 917837583 --- .../adk/flows/llmflows/BaseLlmFlow.java | 3 +- .../google/adk/flows/llmflows/Contents.java | 65 +- .../java/com/google/adk/models/Gemini.java | 317 +++++---- .../adk/flows/llmflows/BaseLlmFlowTest.java | 27 + .../adk/flows/llmflows/ContentsTest.java | 60 +- .../com/google/adk/models/GeminiTest.java | 647 +++++++++++++++++- .../google/adk/tutorials/CityTimeWeather.java | 2 +- 7 files changed, 915 insertions(+), 206 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fffeab698..67c29ae77 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -685,7 +685,8 @@ private Flowable buildPostprocessingEvents( Event modelResponseEvent = buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse); - if (modelResponseEvent.functionCalls().isEmpty()) { + if (modelResponseEvent.functionCalls().isEmpty() + || modelResponseEvent.partial().orElse(false)) { return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 876f3a206..36e93014e 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -57,13 +57,6 @@ public Single processRequest( } LlmAgent llmAgent = (LlmAgent) context.agent(); - String modelName; - try { - modelName = llmAgent.resolvedModel().modelName().orElse(""); - } catch (IllegalStateException e) { - modelName = ""; - } - ImmutableList sessionEvents; synchronized (context.session().events()) { sessionEvents = ImmutableList.copyOf(context.session().events()); @@ -75,17 +68,13 @@ public Single processRequest( request.toBuilder() .contents( getCurrentTurnContents( - context.branch().orElse(null), - sessionEvents, - context.agent().name(), - modelName)) + context.branch().orElse(null), sessionEvents, context.agent().name())) .build(), ImmutableList.of())); } ImmutableList contents = - getContents( - context.branch().orElse(null), sessionEvents, context.agent().name(), modelName); + getContents(context.branch().orElse(null), sessionEvents, context.agent().name()); return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -94,19 +83,19 @@ public Single processRequest( /** Gets contents for the current turn only (no conversation history). */ private ImmutableList getCurrentTurnContents( - @Nullable String currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName) { // Find the latest event that starts the current turn and process from there. for (int i = events.size() - 1; i >= 0; i--) { Event event = events.get(i); if (event.author().equals("user") || isOtherAgentReply(agentName, event)) { - return getContents(currentBranch, events.subList(i, events.size()), agentName, modelName); + return getContents(currentBranch, events.subList(i, events.size()), agentName); } } return ImmutableList.of(); } private ImmutableList getContents( - @Nullable String currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName) { List filteredEvents = new ArrayList<>(); boolean hasCompactEvent = false; @@ -148,7 +137,7 @@ private ImmutableList getContents( } List resultEvents = rearrangeEventsForLatestFunctionResponse(filteredEvents); - resultEvents = rearrangeEventsForAsyncFunctionResponsesInHistory(resultEvents, modelName); + resultEvents = rearrangeEventsForAsyncFunctionResponsesInHistory(resultEvents); return resultEvents.stream() .map(Event::content) @@ -564,8 +553,7 @@ private static List rearrangeEventsForLatestFunctionResponse(List return resultEvents; } - private static List rearrangeEventsForAsyncFunctionResponsesInHistory( - List events, String modelName) { + private static List rearrangeEventsForAsyncFunctionResponsesInHistory(List events) { Map functionCallIdToResponseEventIndex = new HashMap<>(); for (int i = 0; i < events.size(); i++) { final int index = i; @@ -592,11 +580,6 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( List resultEvents = new ArrayList<>(); // Keep track of response events already added to avoid duplicates when merging Set processedResponseIndices = new HashSet<>(); - List responseEventsBuffer = new ArrayList<>(); - - // Gemini 3 requires function calls to be grouped first and only then function responses: - // FC1 FC2 FR1 FR2 - boolean shouldBufferResponseEvents = modelName.contains("gemini-3"); for (int i = 0; i < events.size(); i++) { Event event = events.get(i); @@ -641,47 +624,21 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( for (int index : sortedIndices) { if (processedResponseIndices.add(index)) { // Add index and check if it was newly added - responseEventsBuffer.add(events.get(index)); responseEventsToAdd.add(events.get(index)); } } - if (!shouldBufferResponseEvents) { - if (responseEventsToAdd.size() == 1) { - resultEvents.add(responseEventsToAdd.get(0)); - } else if (responseEventsToAdd.size() > 1) { - resultEvents.add(mergeFunctionResponseEvents(responseEventsToAdd)); - } + if (responseEventsToAdd.size() == 1) { + resultEvents.add(responseEventsToAdd.get(0)); + } else if (responseEventsToAdd.size() > 1) { + resultEvents.add(mergeFunctionResponseEvents(responseEventsToAdd)); } } } else { - // gemini-3 specific part: buffer response events - if (shouldBufferResponseEvents) { - if (!responseEventsBuffer.isEmpty()) { - if (responseEventsBuffer.size() == 1) { - resultEvents.add(responseEventsBuffer.get(0)); - } else { - resultEvents.add(mergeFunctionResponseEvents(responseEventsBuffer)); - } - responseEventsBuffer.clear(); - } - } resultEvents.add(event); } } - // gemini-3 specific part: buffer response events - if (shouldBufferResponseEvents) { - if (!responseEventsBuffer.isEmpty()) { - if (responseEventsBuffer.size() == 1) { - resultEvents.add(responseEventsBuffer.get(0)); - } else { - resultEvents.add(mergeFunctionResponseEvents(responseEventsBuffer)); - } - responseEventsBuffer.clear(); - } - } - return resultEvents; } diff --git a/core/src/main/java/com/google/adk/models/Gemini.java b/core/src/main/java/com/google/adk/models/Gemini.java index 6f145e1de..06da9a97d 100644 --- a/core/src/main/java/com/google/adk/models/Gemini.java +++ b/core/src/main/java/com/google/adk/models/Gemini.java @@ -19,11 +19,11 @@ import static com.google.common.base.StandardSystemProperty.JAVA_VERSION; import com.google.adk.Version; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.Client; import com.google.genai.ResponseStream; -import com.google.genai.types.Candidate; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; import com.google.genai.types.GenerateContentConfig; @@ -35,7 +35,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.CompletableFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -226,21 +225,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre () -> processRawResponses( Flowable.fromFuture(streamFuture).flatMapIterable(iterable -> iterable))) - .filter( - llmResponse -> - llmResponse - .content() - .flatMap(Content::parts) - .map( - parts -> - !parts.isEmpty() - && parts.stream() - .anyMatch( - p -> - p.functionCall().isPresent() - || p.functionResponse().isPresent() - || p.text().isPresent())) - .orElse(false)); + .filter(Gemini::shouldEmit); } else { logger.debug("Sending generateContent request to model {}", effectiveModelName); return Flowable.fromFuture( @@ -253,109 +238,68 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre } static Flowable processRawResponses(Flowable rawResponses) { - final StringBuilder accumulatedText = new StringBuilder(); - final StringBuilder accumulatedThoughtText = new StringBuilder(); - // Array to bypass final local variable reassignment in lambda. - final GenerateContentResponse[] lastRawResponseHolder = {null}; - return rawResponses - .concatMap( - rawResponse -> { - lastRawResponseHolder[0] = rawResponse; - logger.trace("Raw streaming response: {}", rawResponse); - - List responsesToEmit = new ArrayList<>(); - LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse); - Optional part = GeminiUtil.getPart0FromLlmResponse(currentProcessedLlmResponse); - String currentTextChunk = part.flatMap(Part::text).orElse(""); - - if (!currentTextChunk.isBlank()) { - if (part.get().thought().orElse(false)) { - accumulatedThoughtText.append(currentTextChunk); - responsesToEmit.add( - thinkingResponseFromText(currentTextChunk).toBuilder() - .usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null)) - .partial(true) - .build()); - } else { - accumulatedText.append(currentTextChunk); - responsesToEmit.add( - responseFromText(currentTextChunk).toBuilder() - .usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null)) - .partial(true) - .build()); - } - } else { - if (accumulatedThoughtText.length() > 0 - && GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) { - LlmResponse aggregatedThoughtResponse = - thinkingResponseFromText(accumulatedThoughtText.toString()); - responsesToEmit.add(aggregatedThoughtResponse); - accumulatedThoughtText.setLength(0); - } - if (accumulatedText.length() > 0 - && GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) { - LlmResponse aggregatedTextResponse = responseFromText(accumulatedText.toString()); - responsesToEmit.add(aggregatedTextResponse); - accumulatedText.setLength(0); - } - responsesToEmit.add(currentProcessedLlmResponse); - } - logger.debug("Responses to emit: {}", responsesToEmit); - return Flowable.fromIterable(responsesToEmit); - }) - .concatWith( - Flowable.defer( - () -> { - GenerateContentResponse finalRawResp = lastRawResponseHolder[0]; - if (finalRawResp == null) { - return Flowable.empty(); - } - boolean isStop = - finalRawResp - .candidates() - .flatMap(candidates -> candidates.stream().findFirst()) - .flatMap(Candidate::finishReason) - .map(finishReason -> finishReason.knownEnum() == FinishReason.Known.STOP) - .orElse(false); - - if (isStop) { - List finalResponses = new ArrayList<>(); - if (accumulatedThoughtText.length() > 0) { - finalResponses.add( - thinkingResponseFromText(accumulatedThoughtText.toString()).toBuilder() - .usageMetadata( - accumulatedText.length() > 0 - ? null - : finalRawResp.usageMetadata().orElse(null)) - .build()); - } - if (accumulatedText.length() > 0) { - finalResponses.add( - responseFromText(accumulatedText.toString()).toBuilder() - .usageMetadata(finalRawResp.usageMetadata().orElse(null)) - .build()); - } - - return Flowable.fromIterable(finalResponses); - } - return Flowable.empty(); - })); + return Flowable.defer(() -> new StreamingResponseAggregator().process(rawResponses)); } - private static LlmResponse responseFromText(String accumulatedText) { - return LlmResponse.builder() - .content(Content.builder().role("model").parts(Part.fromText(accumulatedText)).build()) - .build(); + /** + * Returns true if {@code response} should be emitted downstream by the streaming pipeline. + * + *

Drops chunks that carry neither semantic content (i.e. they are an empty-text-only response + * per {@link #isEmptyTextOnlyResponse}) nor any useful metadata (per {@link #hasUsefulMetadata}). + * + *

Package-private for testing. + */ + static boolean shouldEmit(LlmResponse response) { + return !isEmptyTextOnlyResponse(response) || hasUsefulMetadata(response); } - private static LlmResponse thinkingResponseFromText(String accumulatedThoughtText) { - return LlmResponse.builder() - .content( - Content.builder() - .role("model") - .parts(Part.fromText(accumulatedThoughtText).toBuilder().thought(true).build()) - .build()) - .build(); + /** + * Returns true if {@code response} carries any non-content metadata that should be propagated + * downstream (e.g. {@code usageMetadata}, {@code finishReason}, transcriptions, grounding or + * error info). Inspects only top-level {@link LlmResponse} fields; the response's content/parts + * are intentionally not considered here. + */ + private static boolean hasUsefulMetadata(LlmResponse response) { + return response.usageMetadata().isPresent() + || response.finishReason().isPresent() + || response.errorCode().isPresent() + || response.groundingMetadata().isPresent() + || response.inputTranscription().isPresent() + || response.outputTranscription().isPresent(); + } + + /** + * Returns true if {@code response} consists of exactly one {@link Part} whose only meaningful + * payload is an empty text string (i.e. {@code parts:[{text:""}]}). Such a chunk can be safely + * dropped from the streaming aggregator because it carries no semantic content for the agent + * pipeline. A part is considered to carry semantic content if any of its non-text payloads + * ({@code functionCall}, {@code functionResponse}, {@code inlineData}, {@code executableCode}, + * {@code codeExecutionResult}, {@code fileData}, {@code thoughtSignature}, {@code videoMetadata}, + * {@code toolCall}, {@code toolResponse}) is present. + */ + private static boolean isEmptyTextOnlyResponse(LlmResponse response) { + return response + .content() + .flatMap(Content::parts) + .map( + parts -> { + if (parts.size() != 1) { + return false; + } + Part part = parts.get(0); + return part.text().map(String::isEmpty).orElse(false) + && part.functionCall().isEmpty() + && part.functionResponse().isEmpty() + && part.inlineData().isEmpty() + && part.executableCode().isEmpty() + && part.codeExecutionResult().isEmpty() + && part.fileData().isEmpty() + && part.thoughtSignature().isEmpty() + && part.videoMetadata().isEmpty() + && part.toolCall().isEmpty() + && part.toolResponse().isEmpty(); + }) + .orElse(false); } @Override @@ -372,4 +316,145 @@ public BaseLlmConnection connect(LlmRequest llmRequest) { return new GeminiLlmConnection(apiClient, effectiveModelName, liveConnectConfig); } + + private static final class StreamingResponseAggregator { + private final List accumulatedSequence = new ArrayList<>(); + private final StringBuilder currentTextBuffer = new StringBuilder(); + private boolean currentTextIsThought = false; + private byte[] currentThoughtSignature = null; + private GenerateContentResponse lastRawResponse = null; + + /** + * Processes a stream of raw responses, emitting partial and aggregated {@link LlmResponse}s. + */ + private Flowable process(Flowable rawResponses) { + return rawResponses + .concatMap(this::processRawResponse) + .concatWith(Flowable.defer(this::processFinalResponse)); + } + + /** + * Processes a single raw streaming chunk, accumulating parts and emitting intermediate + * responses. + */ + private Flowable processRawResponse(GenerateContentResponse rawResponse) { + lastRawResponse = rawResponse; + logger.trace("Raw streaming response: {}", rawResponse); + + LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse); + List parts = + currentProcessedLlmResponse.content().flatMap(Content::parts).orElse(ImmutableList.of()); + + if (accumulateParts(parts)) { + return Flowable.just(currentProcessedLlmResponse.toBuilder().partial(true).build()); + } + + // If the chunk has no text or function calls (e.g. metadata-only or empty), we suppress it + // during streaming so it doesn't emit an empty partial response. + // Exception: If this is a standalone empty chunk in an otherwise completely empty stream + // (and not a STOP chunk), we emit it directly as a non-partial empty response. + if (!isStop(currentProcessedLlmResponse) + && accumulatedSequence.isEmpty() + && currentTextBuffer.isEmpty()) { + return Flowable.just(currentProcessedLlmResponse.toBuilder().partial(null).build()); + } + + return Flowable.empty(); + } + + /** + * Accumulates text and function calls from incoming parts. + * + * @return true if any text or function call was present, false otherwise. + */ + private boolean accumulateParts(List parts) { + boolean hasTextOrFc = false; + for (Part part : parts) { + part.thoughtSignature().ifPresent(sig -> currentThoughtSignature = sig); + String text = part.text().orElse(""); + if (!text.isEmpty()) { + hasTextOrFc = true; + boolean isThought = part.thought().orElse(false); + // Immediately flush the active text buffer to preserve the exact interleaved blocks of + // text/thoughts. + if (!currentTextBuffer.isEmpty() && isThought != currentTextIsThought) { + flushTextBufferToSequence(); + } + if (currentTextBuffer.isEmpty()) { + currentTextIsThought = isThought; + } + currentTextBuffer.append(text); + } + if (part.functionCall().isPresent()) { + hasTextOrFc = true; + flushTextBufferToSequence(); + accumulatedSequence.add(part); + } + } + return hasTextOrFc; + } + + /** Flushes any accumulated text or thought content in the buffer as a new {@link Part}. */ + private void flushTextBufferToSequence() { + if (!currentTextBuffer.isEmpty()) { + Part.Builder partBuilder = + Part.builder().text(currentTextBuffer.toString()).thought(currentTextIsThought); + if (currentThoughtSignature != null) { + partBuilder.thoughtSignature(currentThoughtSignature); + currentThoughtSignature = null; + } + accumulatedSequence.add(partBuilder.build()); + currentTextBuffer.setLength(0); + currentTextIsThought = false; + } + } + + /** + * Emits a single final aggregated, non-partial response carrying the complete chronological + * sequence of accumulated parts (thoughts, text, and function calls) when the stream completes + * with a STOP reason. + */ + private Flowable processFinalResponse() { + if (lastRawResponse == null) { + return Flowable.empty(); + } + LlmResponse currentResponse = LlmResponse.create(lastRawResponse); + if (!isStop(currentResponse)) { + return Flowable.empty(); + } + + flushTextBufferToSequence(); + + if (accumulatedSequence.isEmpty()) { + return Flowable.just(currentResponse.toBuilder().partial(null).build()); + } + + LlmResponse.Builder finalResponseBuilder = currentResponse.toBuilder().partial(null); + + // If the final STOP chunk carries a thoughtSignature (e.g. from a preceding function call or + // thought), attach it to the last accumulated part in the sequence. + GeminiUtil.getPart0FromLlmResponse(currentResponse) + .flatMap(Part::thoughtSignature) + .ifPresent( + signature -> { + int targetIndex = accumulatedSequence.size() - 1; + Part targetPart = accumulatedSequence.get(targetIndex); + accumulatedSequence.set( + targetIndex, targetPart.toBuilder().thoughtSignature(signature).build()); + }); + + return Flowable.just( + finalResponseBuilder + .content(Content.builder().role("model").parts(accumulatedSequence).build()) + .build()); + } + + /** Checks whether the response finish reason indicates the stream has finished with STOP. */ + private static boolean isStop(LlmResponse response) { + return response + .finishReason() + .map(reason -> reason.knownEnum() == FinishReason.Known.STOP) + .orElse(false); + } + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 2a06c1f0a..70706d7fe 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -216,6 +216,33 @@ public void run_withLongRunningFunctionCall_returnsCorrectEventsWithLongRunningT assertThat(events.get(2).content()).hasValue(secondContent); } + @Test + public void run_withPartialFunctionCall_doesNotExecuteTool() { + Content partialContent = + Content.fromParts(Part.fromFunctionCall("my_function", ImmutableMap.of("arg1", "value1"))); + LlmResponse partialResponse = + LlmResponse.builder().content(partialContent).partial(true).build(); + TestLlm testLlm = createTestLlm(partialResponse); + ImmutableMap testResponse = + ImmutableMap.of("response", "response for my_function"); + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .tools(ImmutableList.of(new TestTool("my_function", testResponse))) + .build()); + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow( + /* requestProcessors= */ ImmutableList.of(), + /* responseProcessors= */ ImmutableList.of(), + /* maxSteps= */ Optional.of(1)); + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).partial()).hasValue(true); + assertThat(events.get(0).functionCalls()).hasSize(1); + } + @Test public void run_withRequestProcessor_doesNotModifyRequest() { Content content = Content.fromParts(Part.fromText("LLM response")); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 1e6267dde..4313b14ca 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -561,16 +561,66 @@ public void rearrangeHistory_gemini3interleavedFCFR_groupsFcAndFr() { List result = runContentsProcessorWithModelName(inputEvents, "gemini-3-flash-exp"); - assertThat(result).hasSize(4); - assertThat(result.get(0)).isEqualTo(u1.content().get()); - assertThat(result.get(1)).isEqualTo(fc1.content().get()); - assertThat(result.get(2)).isEqualTo(fc2.content().get()); - Content mergedContent = result.get(3); + assertThat(result).isEqualTo(eventsToContents(inputEvents)); + } + + @Test + public void rearrangeHistory_sequentialCalls_preservesInterleavedOrder() { + Event u1 = createUserEvent("u1", "Query"); + Event fc1 = createFunctionCallEvent("fc1", "tool1", "call1"); + Event fr1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event fc2 = createFunctionCallEvent("fc2", "tool2", "call2"); + Event fr2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + + ImmutableList inputEvents = ImmutableList.of(u1, fc1, fr1, fc2, fr2); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).isEqualTo(eventsToContents(inputEvents)); + } + + @Test + public void rearrangeHistory_parallelCallsSeparateResponseEvents_mergesResponses() { + Event fcEvent = createParallelFunctionCallEvent("fc1", "tool1", "call1", "tool2", "call2"); + Event frEvent1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event frEvent2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + ImmutableList inputEvents = + ImmutableList.of(createUserEvent("u1", "Query"), fcEvent, frEvent1, frEvent2); + + List result = runContentsProcessorWithModelName(inputEvents, "gemini-3-flash-exp"); + + assertThat(result).hasSize(3); // u1, fc1, merged_fr + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); + Content mergedContent = result.get(2); + assertThat(mergedContent.parts().get()).hasSize(2); + assertThat(mergedContent.parts().get().get(0).functionResponse().get().name()) + .hasValue("tool1"); + assertThat(mergedContent.parts().get().get(1).functionResponse().get().name()) + .hasValue("tool2"); + } + + @Test + public void rearrangeHistory_parallelCallsSeparateResponseEventsInHistory_mergesResponses() { + Event fcEvent = createParallelFunctionCallEvent("fc1", "tool1", "call1", "tool2", "call2"); + Event frEvent1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event frEvent2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + Event u2 = createUserEvent("u2", "Second Query"); + ImmutableList inputEvents = + ImmutableList.of(createUserEvent("u1", "Query"), fcEvent, frEvent1, frEvent2, u2); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).hasSize(4); // u1, fc1, merged_fr, u2 + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); + Content mergedContent = result.get(2); assertThat(mergedContent.parts().get()).hasSize(2); assertThat(mergedContent.parts().get().get(0).functionResponse().get().name()) .hasValue("tool1"); assertThat(mergedContent.parts().get().get(1).functionResponse().get().name()) .hasValue("tool2"); + assertThat(result.get(3)).isEqualTo(inputEvents.get(4).content().get()); } @Test diff --git a/core/src/test/java/com/google/adk/models/GeminiTest.java b/core/src/test/java/com/google/adk/models/GeminiTest.java index c230f5f68..b5d7cb156 100644 --- a/core/src/test/java/com/google/adk/models/GeminiTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiTest.java @@ -17,6 +17,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Candidate; import com.google.genai.types.Content; @@ -27,6 +28,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.functions.Predicate; import io.reactivex.rxjava3.subscribers.TestSubscriber; +import java.nio.charset.StandardCharsets; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -59,8 +61,57 @@ public void processRawResponses_withTextChunks_emitsPartialResponses() { assertLlmResponses( llmResponses, isPartialTextResponse("Thinking..."), - isFinalTextResponse("Thinking..."), - isFunctionCallResponse()); + isPartialFunctionCallResponse("test_function")); + // No final response is emitted, no FinishReason.STOP means the stream is treated as + // incomplete/aborted. + } + + @Test + public void processRawResponses_chunkWithBothTextAndFunctionCall_emitsPartialWithBoth() { + GenerateContentResponse chunkWithBoth = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.fromText("Here is the call:"), + Part.fromFunctionCall("my_tool", ImmutableMap.of())) + .build()) + .build()) + .build(); + + Flowable rawResponses = Flowable.just(chunkWithBoth); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, isPartialTextAndFunctionCallResponse("Here is the call:", "my_tool")); + // No final response is emitted, no FinishReason.STOP means the stream is treated as + // incomplete/aborted. + } + + @Test + public void processRawResponses_streamingFunctionCallsAndStop_emitsPartialsThenFinalAggregated() { + Part fc1 = Part.fromFunctionCall("tool1", ImmutableMap.of("arg1", "val1")); + Part fc2 = Part.fromFunctionCall("tool2", ImmutableMap.of("arg2", "val2")); + GenerateContentResponse fc2WithStop = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(fc2).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .build(); + Flowable rawResponses = Flowable.just(toResponse(fc1), fc2WithStop); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("tool1"), + isPartialFunctionCallResponse("tool2"), + isFinalAggregatedFunctionCallResponse("tool1", "tool2")); } @Test @@ -108,20 +159,20 @@ public void processRawResponses_finishReasonNotStop_doesNotEmitFinalAccumulatedT assertLlmResponses( llmResponses, isPartialTextResponse("Hello"), isPartialTextResponse(" world")); + // No final response is emitted, no FinishReason.STOP means the stream is treated as + // incomplete/aborted. } @Test - public void processRawResponses_textThenEmpty_emitsPartialTextThenFullTextAndEmpty() { + public void processRawResponses_textThenEmpty_emitsPartialTextThenFullText() { Flowable rawResponses = Flowable.just(toResponseWithText("Thinking..."), GenerateContentResponse.builder().build()); Flowable llmResponses = Gemini.processRawResponses(rawResponses); - assertLlmResponses( - llmResponses, - isPartialTextResponse("Thinking..."), - isFinalTextResponse("Thinking..."), - isEmptyResponse()); + assertLlmResponses(llmResponses, isPartialTextResponse("Thinking...")); + // No final response is emitted, no FinishReason.STOP means the stream is treated as + // incomplete/aborted. } @Test @@ -138,6 +189,8 @@ public void processRawResponses_withTextChunks_partialResponsesIncludeUsageMetad llmResponses, isPartialTextResponseWithUsageMetadata("Hello", metadata1), isPartialTextResponseWithUsageMetadata(" world", metadata2)); + // No final response is emitted, no FinishReason.STOP means the stream is treated as + // incomplete/aborted. } @Test @@ -157,6 +210,27 @@ public void processRawResponses_textAndStopReason_finalResponseIncludesUsageMeta isFinalTextResponseWithUsageMetadata("Hello world", metadata)); } + @Test + public void + processRawResponses_textThenEmptyStopWithUsageMetadata_finalResponseIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(10, 20, 30); + GenerateContentResponse stopResponse = + GenerateContentResponse.builder() + .candidates( + Candidate.builder().finishReason(new FinishReason(FinishReason.Known.STOP)).build()) + .usageMetadata(metadata) + .build(); + Flowable rawResponses = + Flowable.just(toResponseWithText("Hello"), stopResponse); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Hello"), + isFinalTextResponseWithUsageMetadata("Hello", metadata)); + } + @Test public void processRawResponses_thoughtChunksAndStop_includeUsageMetadata() { GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); @@ -190,12 +264,366 @@ public void processRawResponses_thoughtAndTextWithStop_onlyFinalTextIncludesUsag llmResponses, isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), isPartialTextResponseWithUsageMetadata("Answer", metadata2), - isFinalThoughtResponseWithNoUsageMetadata("Thinking"), - isFinalTextResponseWithUsageMetadata("Answer", metadata2)); + isFinalThoughtAndTextResponseWithUsageMetadata("Thinking", "Answer", metadata2)); } - // Helper methods for assertions + @Test + public void + processRawResponses_interleavedThoughtAndTextWithStop_separatelyAggregatesThoughtAndText() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata3 = createUsageMetadata(10, 15, 25); + GenerateContentResponseUsageMetadata metadata4 = createUsageMetadata(10, 20, 30); + Flowable rawResponses = + Flowable.just( + toResponseWithThoughtText("Thinking 1", metadata1), + toResponseWithText("Answer 1", metadata2), + toResponseWithThoughtText(" Thinking 2", metadata3), + toResponseWithText(" Answer 2", FinishReason.Known.STOP, metadata4)); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking 1", metadata1), + isPartialTextResponseWithUsageMetadata("Answer 1", metadata2), + isPartialThoughtResponseWithUsageMetadata(" Thinking 2", metadata3), + isPartialTextResponseWithUsageMetadata(" Answer 2", metadata4), + isFinalInterleavedThoughtAndTextResponseWithUsageMetadata( + "Thinking 1", "Answer 1", " Thinking 2", " Answer 2", metadata4)); + } + + @Test + public void + processRawResponses_textAndFunctionCallWithStop_onlyFinalFunctionCallIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(10, 20, 30); + Part fcPart = Part.fromFunctionCall("my_tool", ImmutableMap.of()); + GenerateContentResponse stopResponse = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(fcPart).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = + Flowable.just(toResponseWithText("Answer", metadata1), stopResponse); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponseWithUsageMetadata("Answer", metadata1), + isPartialFunctionCallResponse("my_tool"), + isFinalTextAndFunctionCallResponseWithUsageMetadata("Answer", metadata2, "my_tool")); + } + + @Test + public void + processRawResponses_thoughtThenEmptyWithSignatureAndStop_flushesThoughtWithSignature() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + GenerateContentResponse chunk1 = toResponseWithThoughtText("Thinking", metadata1); + GenerateContentResponse chunk2 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .thought(true) + .thoughtSignature("sig".getBytes(StandardCharsets.UTF_8)) + .build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isFinalThoughtResponseWithUsageMetadataAndSignature("Thinking", metadata2, "sig")); + } + + @Test + public void + processRawResponses_thoughtWithSignatureThenTextAndStop_flushesThoughtWithSignature() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + GenerateContentResponse chunk1 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .text("Thinking") + .thought(true) + .thoughtSignature("sig".getBytes(StandardCharsets.UTF_8)) + .build()) + .build()) + .build()) + .usageMetadata(metadata1) + .build(); + GenerateContentResponse chunk2 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts(Part.builder().text("Hello").thought(false).build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialTextResponseWithUsageMetadata("Hello", metadata2), + isFinalThoughtAndTextResponseWithUsageMetadataAndSignature( + "Thinking", "Hello", metadata2, "sig")); + } + + @Test + public void + processRawResponses_thoughtThenFunctionCallWithSignatureAndStop_attachesSignatureToFunctionCall() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + GenerateContentResponse chunk1 = toResponseWithThoughtText("Thinking", metadata1); + GenerateContentResponse chunk2 = + toResponse(Part.fromFunctionCall("my_tool", ImmutableMap.of())); + GenerateContentResponse chunk3 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .thought(true) + .thoughtSignature("sig".getBytes(StandardCharsets.UTF_8)) + .build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2, chunk3); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isPartialFunctionCallResponse("my_tool"), + isFinalThoughtAndFunctionCallResponseWithUsageMetadataAndSignature( + "Thinking", metadata2, "sig", "my_tool")); + } + + @Test + public void processRawResponses_emptyPartsThenSignature_doesNotThrowException() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(5, 10, 15); + GenerateContentResponse chunk1 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(ImmutableList.of()).build()) + .build()) + .build(); + GenerateContentResponse chunk2 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .thought(true) + .thoughtSignature("sig".getBytes(StandardCharsets.UTF_8)) + .build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isEmptyResponse(), + isFinalThoughtResponseWithUsageMetadataAndSignature("", metadata, "sig")); + } + + @Test + public void functionCallThenEmptyTextWithStop_emitsPartialThenFinalAggregatedFunctionCall() { + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP)); + + Flowable llmResponses = + Gemini.processRawResponses(rawResponses).filter(Gemini::shouldEmit); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("test_function"), + isFinalAggregatedFunctionCallResponse("test_function")); + } + + @Test + public void functionCallThenEmptyTextWithUsageMetadata_emitsFinalAggregatedWithUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(5, 10, 15); + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP, metadata)); + + Flowable llmResponses = + Gemini.processRawResponses(rawResponses).filter(Gemini::shouldEmit); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("test_function"), + isFinalAggregatedFunctionCallResponseWithUsageMetadata(metadata, "test_function")); + } + @Test + public void functionCallThenEmptyText_doesNotEmitExtraEmptyResponse() { + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("")); + + Flowable llmResponses = + Gemini.processRawResponses(rawResponses).filter(Gemini::shouldEmit); + + assertLlmResponses(llmResponses, isPartialFunctionCallResponse("test_function")); + // No final response is emitted, no FinishReason.STOP means the stream is treated as + // incomplete/aborted. + } + + @Test + public void textThenFunctionCallThenEmptyTextWithStop_emitsTextThenFunctionCalls() { + Flowable rawResponses = + Flowable.just( + toResponseWithText("Thinking..."), + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP)); + + Flowable llmResponses = + Gemini.processRawResponses(rawResponses).filter(Gemini::shouldEmit); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Thinking..."), + isPartialFunctionCallResponse("test_function"), + isFinalTextAndFunctionCallResponseWithNoUsageMetadata("Thinking...", "test_function")); + } + + // Test cases for the shouldEmit filter applied by generateContent after processRawResponses. + // shouldEmit drops chunks that are empty-text-only unless they carry final metadata (usage + // metadata or finish reason); everything else is forwarded. + // processRawResponses normally already strips empty-text-only chunks, so shouldEmit + // is defense-in-depth, but it must still behave correctly when fed any LlmResponse directly. + + @Test + public void shouldEmit_emptyTextOnlyResponseWithNoMetadata_returnsFalse() { + LlmResponse response = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .build(); + + assertThat(Gemini.shouldEmit(response)).isFalse(); + } + + @Test + public void shouldEmit_emptyTextOnlyResponseWithFinishReason_returnsTrue() { + LlmResponse response = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_emptyTextOnlyResponseWithUsageMetadata_returnsTrue() { + LlmResponse response = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .usageMetadata(createUsageMetadata(5, 10, 15)) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_nonEmptyTextResponse_returnsTrue() { + LlmResponse response = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("hello")).build()) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_functionCallResponse_returnsTrue() { + LlmResponse response = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromFunctionCall("test_function", ImmutableMap.of())) + .build()) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_contentlessResponse_returnsTrue() { + // A response with no content at all is not an empty-text-only response, so it should pass + // through regardless of metadata. This is the shape emitted by processRawResponses after it + // strips empty-text content while preserving metadata. + LlmResponse response = LlmResponse.builder().build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_multiPartResponseWithEmptyTextPart_returnsTrue() { + // Only single-part empty-text responses are considered "empty-text-only". A multi-part response + // is treated as carrying semantic content and must always pass through. + LlmResponse response = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(""), Part.fromText("hello")) + .build()) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + // Helper methods for assertions private void assertLlmResponses( Flowable llmResponses, Predicate... predicates) { TestSubscriber testSubscriber = llmResponses.test(); @@ -218,23 +646,65 @@ private static Predicate isPartialTextResponse(String expectedText) private static Predicate isFinalTextResponse(String expectedText) { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEqualTo(expectedText); return true; }; } - private static Predicate isFunctionCallResponse() { + private static Predicate isPartialFunctionCallResponse(String expectedToolName) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(response.content().get().parts().get()).hasSize(1); + assertThat(response.content().get().parts().get().get(0).functionCall().get().name()) + .hasValue(expectedToolName); + return true; + }; + } + + private static Predicate isPartialTextAndFunctionCallResponse( + String expectedText, String expectedToolName) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(response.content().get().parts().get()).hasSize(2); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedText); + assertThat(response.content().get().parts().get().get(1).functionCall().get().name()) + .hasValue(expectedToolName); + return true; + }; + } + + private static Predicate isFinalAggregatedFunctionCallResponse( + String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } + return true; + }; + } + + private static Predicate isFinalAggregatedFunctionCallResponseWithUsageMetadata( + GenerateContentResponseUsageMetadata expectedMetadata, String... expectedToolNames) { return response -> { - assertThat(response.content().get().parts().get().get(0).functionCall()).isNotNull(); + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } + assertThat(response.usageMetadata()).hasValue(expectedMetadata); return true; }; } private static Predicate isEmptyResponse() { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEmpty(); return true; @@ -268,7 +738,7 @@ private static Predicate isPartialThoughtResponseWithUsageMetadata( private static Predicate isFinalTextResponseWithUsageMetadata( String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEqualTo(expectedText); assertThat(response.usageMetadata()).hasValue(expectedMetadata); @@ -279,7 +749,7 @@ private static Predicate isFinalTextResponseWithUsageMetadata( private static Predicate isFinalThoughtResponseWithUsageMetadata( String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEqualTo(expectedText); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) @@ -289,21 +759,140 @@ private static Predicate isFinalThoughtResponseWithUsageMetadata( }; } - private static Predicate isFinalThoughtResponseWithNoUsageMetadata( - String expectedText) { + private static Predicate isFinalThoughtResponseWithUsageMetadataAndSignature( + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata, + String expectedSignature) { return response -> { - assertThat(response.partial()).isEmpty(); + assertThat(response.partial().orElse(false)).isFalse(); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) .isEqualTo(expectedText); assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) .isTrue(); + assertThat( + GeminiUtil.getPart0FromLlmResponse(response) + .flatMap(Part::thoughtSignature) + .orElse(new byte[0])) + .isEqualTo(expectedSignature.getBytes(StandardCharsets.UTF_8)); + + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtAndTextResponseWithUsageMetadata( + String expectedThought, + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(2); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedThought); + assertThat(response.content().get().parts().get().get(0).thought()).hasValue(true); + assertThat(response.content().get().parts().get().get(1).text()).hasValue(expectedText); + assertThat(response.content().get().parts().get().get(1).thought()).hasValue(false); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalThoughtAndTextResponseWithUsageMetadataAndSignature( + String expectedThought, + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata, + String expectedSignature) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(2); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedThought); + assertThat(response.content().get().parts().get().get(0).thought()).hasValue(true); + assertThat( + response.content().get().parts().get().get(0).thoughtSignature().orElse(new byte[0])) + .isEqualTo(expectedSignature.getBytes(StandardCharsets.UTF_8)); + assertThat(response.content().get().parts().get().get(1).text()).hasValue(expectedText); + assertThat(response.content().get().parts().get().get(1).thought()).hasValue(false); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalInterleavedThoughtAndTextResponseWithUsageMetadata( + String expectedThought1, + String expectedText1, + String expectedThought2, + String expectedText2, + GenerateContentResponseUsageMetadata expectedMetadata) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(4); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedThought1); + assertThat(response.content().get().parts().get().get(0).thought()).hasValue(true); + assertThat(response.content().get().parts().get().get(1).text()).hasValue(expectedText1); + assertThat(response.content().get().parts().get().get(1).thought()).hasValue(false); + assertThat(response.content().get().parts().get().get(2).text()).hasValue(expectedThought2); + assertThat(response.content().get().parts().get().get(2).thought()).hasValue(true); + assertThat(response.content().get().parts().get().get(3).text()).hasValue(expectedText2); + assertThat(response.content().get().parts().get().get(3).thought()).hasValue(false); + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalTextAndFunctionCallResponseWithUsageMetadata( + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata, + String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length + 1); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedText); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i + 1).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalTextAndFunctionCallResponseWithNoUsageMetadata( + String expectedText, String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length + 1); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedText); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i + 1).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } assertThat(response.usageMetadata()).isEmpty(); return true; }; } - // Helper methods to create responses for testing + private static Predicate + isFinalThoughtAndFunctionCallResponseWithUsageMetadataAndSignature( + String expectedThought, + GenerateContentResponseUsageMetadata expectedMetadata, + String expectedSignature, + String... expectedToolNames) { + return response -> { + assertThat(response.partial().orElse(false)).isFalse(); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length + 1); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedThought); + assertThat(response.content().get().parts().get().get(0).thought()).hasValue(true); + for (int i = 0; i < expectedToolNames.length; i++) { + Part part = response.content().get().parts().get().get(i + 1); + assertThat(part.functionCall().get().name()).hasValue(expectedToolNames[i]); + assertThat(part.thoughtSignature().orElse(new byte[0])) + .isEqualTo(expectedSignature.getBytes(StandardCharsets.UTF_8)); + } + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + // Helper methods to create responses for testing private GenerateContentResponse toResponseWithText(String text) { return toResponse(Part.fromText(text)); } @@ -316,14 +905,6 @@ private GenerateContentResponse toResponseWithText(String text, FinishReason.Kno .build()); } - private GenerateContentResponse toResponse(Part part) { - return toResponse(Candidate.builder().content(Content.builder().parts(part).build()).build()); - } - - private GenerateContentResponse toResponse(Candidate candidate) { - return GenerateContentResponse.builder().candidates(candidate).build(); - } - private GenerateContentResponse toResponseWithText( String text, GenerateContentResponseUsageMetadata usageMetadata) { return GenerateContentResponse.builder() @@ -349,6 +930,14 @@ private GenerateContentResponse toResponseWithText( .build(); } + private GenerateContentResponse toResponse(Part part) { + return toResponse(Candidate.builder().content(Content.builder().parts(part).build()).build()); + } + + private GenerateContentResponse toResponse(Candidate candidate) { + return GenerateContentResponse.builder().candidates(candidate).build(); + } + private GenerateContentResponse toResponseWithThoughtText( String text, GenerateContentResponseUsageMetadata usageMetadata) { Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); diff --git a/tutorials/city-time-weather/src/main/java/com/google/adk/tutorials/CityTimeWeather.java b/tutorials/city-time-weather/src/main/java/com/google/adk/tutorials/CityTimeWeather.java index 18c8f8786..85347c964 100644 --- a/tutorials/city-time-weather/src/main/java/com/google/adk/tutorials/CityTimeWeather.java +++ b/tutorials/city-time-weather/src/main/java/com/google/adk/tutorials/CityTimeWeather.java @@ -31,7 +31,7 @@ public class CityTimeWeather { public static final BaseAgent ROOT_AGENT = LlmAgent.builder() .name("multi_tool_agent") - .model("gemini-2.0-flash-lite") + .model("gemini-3.1-flash-lite") .description("Agent to answer questions about the time and weather in a city.") .instruction( "You are a helpful agent who can answer user questions about the time and weather in"