diff --git a/annot/pom.xml b/annot/pom.xml index 35b38f53e5..fd600ce221 100644 --- a/annot/pom.xml +++ b/annot/pom.xml @@ -90,6 +90,11 @@ junit-jupiter-engine test + + org.junit.jupiter + junit-jupiter-params + test + org.hamcrest hamcrest diff --git a/annot/src/main/java/com/predic8/membrane/annot/util/JsonPathUtil.java b/annot/src/main/java/com/predic8/membrane/annot/util/JsonPathUtil.java new file mode 100644 index 0000000000..890d161df2 --- /dev/null +++ b/annot/src/main/java/com/predic8/membrane/annot/util/JsonPathUtil.java @@ -0,0 +1,99 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.annot.util; + +import com.predic8.membrane.annot.yaml.ParsingContext; + +public final class JsonPathUtil { + + + private JsonPathUtil() {} + + /** + * Returns the parent path before the last JSONPath segment. + * + * @param jsonPath JSONPath created by {@link ParsingContext} + * @return parent path without the last segment + */ + public static String getParentPath(String jsonPath) { + return jsonPath.substring(0, findLastSegmentStart(jsonPath)); + } + + /** + * Returns the last part of a JSONPath created by {@link ParsingContext}. + *

+ * Examples: + * {@code $.api.methods} -> {@code methods} + * {@code $.api.methods['rpc.echo']} -> {@code rpc.echo} + * {@code $.api.methods[0]} -> {@code 0} + * + * @param jsonPath JSONPath created by {@link ParsingContext} + * @return unescaped field name or array index of the last segment + */ + public static String getLastSegment(String jsonPath) { + String segment = jsonPath.substring(findLastSegmentStart(jsonPath)); + + if (segment.startsWith(".")) { + return segment.substring(1); + } + + if (isQuotedPropertySegment(segment)) { + return decodeQuotedPropertySegment(segment); + } + + if (isArrayIndexSegment(segment)) { + return segment.substring(1, segment.length() - 1); + } + + throw new IllegalArgumentException("Unsupported JSONPath segment: " + segment); + } + + private static boolean isQuotedPropertySegment(String segment) { + return segment.startsWith("['") && segment.endsWith("']"); + } + + private static boolean isArrayIndexSegment(String segment) { + return segment.startsWith("[") && segment.endsWith("]"); + } + + /** + * Removes JSONPath bracket quoting and unescapes supported characters. + */ + private static String decodeQuotedPropertySegment(String segment) { + return segment.substring(2, segment.length() - 2) + .replace("\\'", "'") + + .replace("\\\\", "\\"); + } + + /** + * Returns the start index of the last JSONPath segment. + */ + private static int findLastSegmentStart(String jsonPath) { + int depth = 0; + for (int i = jsonPath.length() - 1; i >= 0; i--) { + char c = jsonPath.charAt(i); + if (c == ']') { + depth++; + } else if (c == '[') { + depth--; + } + if (depth == 0 && (c == '.' || c == '[')) { + return i; + } + } + throw new IllegalArgumentException("Cannot determine parent path of: " + jsonPath); + } +} diff --git a/annot/src/main/java/com/predic8/membrane/annot/yaml/ParsingContext.java b/annot/src/main/java/com/predic8/membrane/annot/yaml/ParsingContext.java index 791ed2df82..fae8b08769 100644 --- a/annot/src/main/java/com/predic8/membrane/annot/yaml/ParsingContext.java +++ b/annot/src/main/java/com/predic8/membrane/annot/yaml/ParsingContext.java @@ -20,6 +20,9 @@ import org.jetbrains.annotations.*; import org.slf4j.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + /** * Immutable parsing state passed down while traversing YAML. * - context: current element scope used for local type resolution in {@link Grammar}. @@ -29,6 +32,7 @@ public record ParsingContext(String context, T registry, Grammar grammar, JsonNode topLevel, String path, String key) { private static final Logger log = LoggerFactory.getLogger(ParsingContext.class); + public static final Pattern VALID_PROPERTY_NAME = Pattern.compile("[A-Za-z_][A-Za-z0-9_]*"); public ParsingContext updateContext(String context) { return new ParsingContext<>(context, registry, grammar, topLevel, path,key); @@ -96,7 +100,7 @@ public String getPath() { * such as `.` or `'` that JSONPath would otherwise interpret as syntax. */ private static String toJsonPathProperty(String property) { - if (property.matches("[A-Za-z_][A-Za-z0-9_]*")) { + if (VALID_PROPERTY_NAME.matcher(property).matches()) { return "." + property; } return "['" + property diff --git a/annot/src/main/java/com/predic8/membrane/annot/yaml/error/LineYamlErrorRenderer.java b/annot/src/main/java/com/predic8/membrane/annot/yaml/error/LineYamlErrorRenderer.java index c938a56810..4eed14ec25 100644 --- a/annot/src/main/java/com/predic8/membrane/annot/yaml/error/LineYamlErrorRenderer.java +++ b/annot/src/main/java/com/predic8/membrane/annot/yaml/error/LineYamlErrorRenderer.java @@ -25,6 +25,8 @@ import java.nio.file.Path; +import static com.predic8.membrane.annot.util.JsonPathUtil.getLastSegment; +import static com.predic8.membrane.annot.util.JsonPathUtil.getParentPath; import static com.predic8.membrane.common.TerminalColorsMini.*; /** @@ -58,6 +60,12 @@ public static String renderErrorReport(ParsingContext pc) throws JsonProcessingE return renderErrorReport(pc, null); } + /** + * Renders a YAML representation of the JSON node with line-based error markers + * and, if available, the source file path. + * + * @return YAML string with ">" prefixes and red color for error lines + */ public static String renderErrorReport(ParsingContext pc, Path sourceFile) throws JsonProcessingException { JsonNode node = pc.getNode(); @@ -270,6 +278,9 @@ private static String markLinesWithPrefix(String yaml, Path sourceFile) { return result.toString().trim(); } + /** + * Counts the leading spaces in a YAML line. + */ private static int getIndentation(String line) { int indent = 0; for (char c : line.toCharArray()) { @@ -281,67 +292,4 @@ private static int getIndentation(String line) { } return indent; } - - static String getParentPath(String jsonPath) { - return jsonPath.substring(0, findLastSegmentStart(jsonPath)); - } - - /** - * Returns the last part of a JSONPath created by {@link ParsingContext}. - * Examples: - * `$.api.methods` -> `methods` - * `$.api.methods['rpc.echo']` -> `rpc.echo` - * `$.api.methods[0]` -> `0` - */ - static String getLastSegment(String jsonPath) { - String segment = jsonPath.substring(findLastSegmentStart(jsonPath)); - - if (isPropertySegment(segment)) { - return segment.substring(1); - } - - if (isQuotedPropertySegment(segment)) { - return decodeQuotedPropertySegment(segment); - } - - if (isArrayIndexSegment(segment)) { - return segment.substring(1, segment.length() - 1); - } - - throw new IllegalArgumentException("Unsupported JSONPath segment: " + segment); - } - - private static boolean isPropertySegment(String segment) { - return segment.startsWith("."); - } - - private static boolean isQuotedPropertySegment(String segment) { - return segment.startsWith("['") && segment.endsWith("']"); - } - - private static String decodeQuotedPropertySegment(String segment) { - return segment.substring(2, segment.length() - 2) - .replace("\\'", "'") - .replace("\\\\", "\\"); - } - - private static boolean isArrayIndexSegment(String segment) { - return segment.startsWith("[") && segment.endsWith("]"); - } - - private static int findLastSegmentStart(String jsonPath) { - int depth = 0; - for (int i = jsonPath.length() - 1; i >= 0; i--) { - char c = jsonPath.charAt(i); - if (c == ']') { - depth++; - } else if (c == '[') { - depth--; - } - if (depth == 0 && (c == '.' || c == '[')) { - return i; - } - } - throw new IllegalArgumentException("Cannot determine parent path of: " + jsonPath); - } } diff --git a/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ScalarValueConverter.java b/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ScalarValueConverter.java index bffcf2c1b7..e3004dfaff 100644 --- a/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ScalarValueConverter.java +++ b/annot/src/main/java/com/predic8/membrane/annot/yaml/parsing/binding/ScalarValueConverter.java @@ -42,6 +42,9 @@ public final class ScalarValueConverter { private final SpelEvaluator spelEvaluator = new SpelEvaluator(); private final ReferenceResolver referenceResolver = new ReferenceResolver(); + /** + * Converts a YAML scalar node to the setter target type or resolves a reference. + */ public Object coerceScalarOrReference(ParsingContext ctx, Method setter, JsonNode node, String key, Class wanted) throws WrongEnumConstantException { if (wanted.equals(String.class)) return node.isTextual() ? spelEvaluator.resolve(node.asText(), String.class) : node.asText(); @@ -55,12 +58,18 @@ public Object coerceScalarOrReference(ParsingContext ctx, Method setter, Json return coerceNonTextual(ctx, setter, node, key, wanted); } + /** + * Converts a scalar node directly, evaluating SpEL first for textual values. + */ public static Object convertScalarOrSpel(JsonNode node, Class targetType) { if (node == null || !node.isTextual()) return SCALAR_MAPPER.convertValue(node, targetType); return STATIC_SPEL_EVALUATOR.resolve(node.asText(), targetType); } + /** + * Handles textual YAML values, including SpEL, numbers, references, and maps. + */ private Object coerceTextual(ParsingContext ctx, Method setter, JsonNode node, String key, Class wanted) { String evaluated = evaluateSpelForString(key, node.asText()); if (evaluated == null) { @@ -84,6 +93,9 @@ private Object coerceTextual(ParsingContext ctx, Method setter, JsonNode node throw unsupported(wanted, key, node); } + /** + * Handles already typed JSON values such as booleans and numbers. + */ private Object coerceNonTextual(ParsingContext ctx, Method setter, JsonNode node, String key, Class wanted) { if (isInteger(wanted)) return node.isInt() ? node.intValue() : parseInt(node.asText()); @@ -102,9 +114,9 @@ private Object coerceNonTextual(ParsingContext ctx, Method setter, JsonNode n /** * Converts the value of one entry from an {@code @MCOtherAttributes} map. - * Example: for `methods: { 'rpc.echo': { params: ... } }` the key is `rpc.echo`. + * Example: for {@code methods: { 'rpc.echo': { params: ... } }} the key is {@code rpc.echo}. * The nested object is then bound as the configured Java type for the map value, - * not left as a raw map. Plain scalar values such as `timeout: 5` stay plain values. + * not left as a raw map. Plain scalar values such as {@code timeout: 5} stay plain values. */ private Object convertAnySetterValue(ParsingContext ctx, Method setter, JsonNode node, String key) { Class valueType = getMapValueType(setter); @@ -112,7 +124,14 @@ private Object convertAnySetterValue(ParsingContext ctx, Method setter, JsonN return SCALAR_MAPPER.convertValue(node, Object.class); } if (valueType == String.class) { - return node.isTextual() ? evaluateSpelForString(key, node.asText()) : node.asText(); + if (node.isTextual()) + return evaluateSpelForString(key, node.asText()); + if (node.isObject() || node.isArray()) + throw unsupported(valueType, key, node); + return SCALAR_MAPPER.convertValue(node, valueType); + } + if (isScalarCompatible(valueType)) { + return SCALAR_MAPPER.convertValue(node, valueType); } return bind( ctx.updateContext(getElementName(valueType)).addProperty(key), @@ -121,6 +140,9 @@ private Object convertAnySetterValue(ParsingContext ctx, Method setter, JsonN ); } + /** + * Returns the runtime class used as the value type of a map setter parameter. + */ private static Class getMapValueType(Method setter) { Type genericType = setter.getGenericParameterTypes()[0]; if (!(genericType instanceof ParameterizedType parameterizedType)) { @@ -150,6 +172,9 @@ private String evaluateSpelForString(String key, String value) { } } + /** + * Parses numeric text and reports invalid values as configuration errors. + */ private Object parseNumericOrThrow(ParsingContext ctx, String key, Class wanted, String value, JsonNode node) { try { if (isInteger(wanted)) @@ -167,12 +192,18 @@ private Object parseNumericOrThrow(ParsingContext ctx, String key, Class w throw unsupported(wanted, key, node); } + /** + * Checks whether the target type should be treated as a bean reference. + */ private static boolean isBeanReference(Class wanted) { if (wanted == Integer.TYPE || wanted == Long.TYPE || wanted == Float.TYPE || wanted == Double.TYPE || wanted == Boolean.TYPE || wanted == String.class) return false; return !wanted.isEnum(); } + /** + * Parses enum values case-insensitively. + */ @SuppressWarnings("unchecked") private static > E parseEnum(Class enumClass, JsonNode node) throws WrongEnumConstantException { String value = node.asText().toUpperCase(ROOT); @@ -203,6 +234,31 @@ private static boolean isNumber(Class wanted) { return isInteger(wanted) || isLong(wanted) || isDouble(wanted); } + /** + * Checks whether Jackson can convert the node as a scalar value. + */ + private static boolean isScalarCompatible(Class wanted) { + return wanted.isEnum() + || isPrimitiveScalar(wanted) + || isPrimitiveWrapper(wanted) + || Number.class.isAssignableFrom(wanted); + } + + private static boolean isPrimitiveScalar(Class wanted) { + return wanted.isPrimitive() && wanted != void.class; + } + + private static boolean isPrimitiveWrapper(Class wanted) { + return wanted == Boolean.class + || wanted == Character.class + || wanted == Byte.class + || wanted == Short.class + || wanted == Integer.class + || wanted == Long.class + || wanted == Float.class + || wanted == Double.class; + } + private static ConfigurationParsingException unsupported(Class wanted, String key, JsonNode node) { return new ConfigurationParsingException("Unsupported setter type: %s for key '%s' with node type %s".formatted(wanted.getName(), key, node.getNodeType())); } diff --git a/annot/src/test/java/com/predic8/membrane/annot/yaml/error/LineYamlErrorRendererTest.java b/annot/src/test/java/com/predic8/membrane/annot/util/JsonPathUtilTest.java similarity index 87% rename from annot/src/test/java/com/predic8/membrane/annot/yaml/error/LineYamlErrorRendererTest.java rename to annot/src/test/java/com/predic8/membrane/annot/util/JsonPathUtilTest.java index 556f7536ec..c49fae1305 100644 --- a/annot/src/test/java/com/predic8/membrane/annot/yaml/error/LineYamlErrorRendererTest.java +++ b/annot/src/test/java/com/predic8/membrane/annot/util/JsonPathUtilTest.java @@ -12,15 +12,15 @@ See the License for the specific language governing permissions and limitations under the License. */ -package com.predic8.membrane.annot.yaml.error; +package com.predic8.membrane.annot.util; import org.junit.jupiter.api.Test; -import static com.predic8.membrane.annot.yaml.error.LineYamlErrorRenderer.getLastSegment; +import static com.predic8.membrane.annot.util.JsonPathUtil.getLastSegment; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -class LineYamlErrorRendererTest { +class JsonPathUtilTest { @Test void splitsJsonPathSegments() { @@ -48,6 +48,6 @@ void rejectsJsonPathWithoutSegment() { private static void assertPath(String jsonPath, String expectedParentPath, String expectedLastSegment) { assertEquals(expectedLastSegment, getLastSegment(jsonPath)); - assertEquals(expectedParentPath, LineYamlErrorRenderer.getParentPath(jsonPath)); + assertEquals(expectedParentPath, JsonPathUtil.getParentPath(jsonPath)); } } diff --git a/annot/src/test/java/com/predic8/membrane/annot/yaml/parsing/MethodSetterTest.java b/annot/src/test/java/com/predic8/membrane/annot/yaml/parsing/MethodSetterTest.java index c38a2f59a4..9150b8dc87 100644 --- a/annot/src/test/java/com/predic8/membrane/annot/yaml/parsing/MethodSetterTest.java +++ b/annot/src/test/java/com/predic8/membrane/annot/yaml/parsing/MethodSetterTest.java @@ -14,16 +14,25 @@ package com.predic8.membrane.annot.yaml.parsing; -import com.fasterxml.jackson.databind.*; -import com.predic8.membrane.annot.*; -import com.predic8.membrane.annot.util.*; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.predic8.membrane.annot.MCChildElement; +import com.predic8.membrane.annot.MCOtherAttributes; +import com.predic8.membrane.annot.util.GrammarMock; import com.predic8.membrane.annot.yaml.ConfigurationParsingException; import com.predic8.membrane.annot.yaml.ParsingContext; -import org.junit.jupiter.api.*; - -import static com.predic8.membrane.annot.yaml.parsing.MethodSetter.*; +import com.predic8.membrane.annot.yaml.parsing.binding.ScalarValueConverter; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static com.predic8.membrane.annot.yaml.parsing.MethodSetter.getMethodSetter; import static org.junit.jupiter.api.Assertions.*; -import static org.junit.jupiter.api.Assertions.assertEquals; class MethodSetterTest { @@ -49,6 +58,40 @@ public static class B {} public static class C {} + private static final String OTHER_ATTRIBUTE_NAME = "timeout"; + + @SuppressWarnings("unused") + static class IntegerAttributes { + @MCOtherAttributes + public void setAttributes(Map attributes) {} + } + + @SuppressWarnings("unused") + static class EnumAttributes { + @MCOtherAttributes + public void setAttributes(Map attributes) {} + } + + @SuppressWarnings("unused") + static class StringAttributes { + @MCOtherAttributes + public void setAttributes(Map attributes) {} + } + + @SuppressWarnings({"rawtypes", "unused"}) + static class RawAttributes { + public void setAttributes(Map attributes) {} + } + + @SuppressWarnings("unused") + static class ListAttributes { + public void setAttributes(Map> attributes) {} + } + + enum Mode { + FAST + } + @Test void dontUseMethodsWithoutChildElementAnnotation() { MethodSetter ms = getMethodSetter(new ParsingContext("foo",null, @@ -82,4 +125,58 @@ void coerceScalar() throws Exception { assertInstanceOf(Long.class, l); assertEquals(1L, l); } -} \ No newline at end of file + + @ParameterizedTest + @MethodSource("directlyConvertedOtherAttributes") + void otherAttributeValueIsConvertedAccordingToMapValueType( + Class attributesClass, + String json, + Object expectedValue + ) throws Exception { + Map attributes = coerceOtherAttribute(attributesClass, json); + + assertEquals(expectedValue, attributes.get(OTHER_ATTRIBUTE_NAME)); + } + + static Stream directlyConvertedOtherAttributes() { + return Stream.of( + Arguments.of(IntegerAttributes.class, "5", 5), + Arguments.of(EnumAttributes.class, "\"FAST\"", Mode.FAST), + Arguments.of(StringAttributes.class, "\"value\"", "value") + ); + } + + @Test + void stringOtherAttributeRejectsNestedObject() { + assertThrows(ConfigurationParsingException.class, + () -> coerceOtherAttribute(StringAttributes.class, "{\"nested\":true}")); + } + + @ParameterizedTest + @MethodSource("mapValueTypes") + void mapValueTypeIsReadFromSetterParameter(Class attributesClass, Class expectedValueType) throws Exception { + assertEquals(expectedValueType, getMapValueType(attributesClass)); + } + + static Stream mapValueTypes() { + return Stream.of( + Arguments.of(IntegerAttributes.class, Integer.class), + Arguments.of(ListAttributes.class, List.class), + Arguments.of(RawAttributes.class, Object.class) + ); + } + + private Map coerceOtherAttribute(Class clazz, String json) throws Exception { + Method setter = clazz.getMethod("setAttributes", Map.class); + Object value = new MethodSetter(setter, null) + .coerceScalarOrReference(null, om.readTree(json), OTHER_ATTRIBUTE_NAME, Map.class); + + return assertInstanceOf(Map.class, value); + } + + private Class getMapValueType(Class clazz) throws Exception { + Method method = ScalarValueConverter.class.getDeclaredMethod("getMapValueType", Method.class); + method.setAccessible(true); + return (Class) method.invoke(null, clazz.getMethod("setAttributes", Map.class)); + } +} diff --git a/annot/src/test/java/com/predic8/membrane/annot/yaml/parsing/binding/ScalarValueConverterTest.java b/annot/src/test/java/com/predic8/membrane/annot/yaml/parsing/binding/ScalarValueConverterTest.java new file mode 100644 index 0000000000..34d8ccef69 --- /dev/null +++ b/annot/src/test/java/com/predic8/membrane/annot/yaml/parsing/binding/ScalarValueConverterTest.java @@ -0,0 +1,113 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.annot.yaml.parsing.binding; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.predic8.membrane.annot.MCOtherAttributes; +import com.predic8.membrane.annot.yaml.ConfigurationParsingException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.lang.reflect.Method; +import java.util.Map; +import java.util.stream.Stream; + +import static com.predic8.membrane.annot.yaml.parsing.binding.ScalarValueConverterTest.Mode.FAST; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class ScalarValueConverterTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String OTHER_ATTRIBUTE_NAME = "timeout"; + + private final ScalarValueConverter converter = new ScalarValueConverter(); + + @SuppressWarnings("unused") + static class IntegerAttributes { + @MCOtherAttributes + public void setAttributes(Map attributes) { + } + } + + @SuppressWarnings("unused") + static class EnumAttributes { + @MCOtherAttributes + public void setAttributes(Map attributes) { + } + } + + @SuppressWarnings("unused") + static class StringAttributes { + @MCOtherAttributes + public void setAttributes(Map attributes) { + } + } + + enum Mode { + FAST + } + + @Test + void coercesNonTextualScalars() throws Exception { + assertEquals(true, coerce("true", boolean.class)); + assertEquals(true, coerce("true", Boolean.class)); + assertEquals(1, coerce("1", int.class)); + assertEquals(1.0, coerce("1", double.class)); + + Object value = coerce("1", long.class); + + assertInstanceOf(Long.class, value); + assertEquals(1L, value); + } + + @ParameterizedTest + @MethodSource("directlyConvertedOtherAttributes") + void otherAttributeValueIsConvertedAccordingToMapValueType( + Class attributesClass, + String json, + Object expectedValue + ) throws Exception { + assertEquals(expectedValue, coerceOtherAttribute(attributesClass, json).get(OTHER_ATTRIBUTE_NAME)); + } + + static Stream directlyConvertedOtherAttributes() { + return Stream.of( + Arguments.of(IntegerAttributes.class, "5", 5), + Arguments.of(EnumAttributes.class, "\"FAST\"", FAST), + Arguments.of(StringAttributes.class, "\"value\"", "value") + ); + } + + @Test + void stringOtherAttributeRejectsNestedObject() { + assertThrows(ConfigurationParsingException.class, + () -> coerceOtherAttribute(StringAttributes.class, "{\"nested\":true}")); + } + + private Object coerce(String json, Class wanted) throws Exception { + return converter.coerceScalarOrReference(null, null, MAPPER.readTree(json), "value", wanted); + } + + private Map coerceOtherAttribute(Class clazz, String json) throws Exception { + Method setter = clazz.getMethod("setAttributes", Map.class); + Object value = converter.coerceScalarOrReference(null, setter, MAPPER.readTree(json), OTHER_ATTRIBUTE_NAME, Map.class); + + return assertInstanceOf(Map.class, value); + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCProtectionInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCProtectionInterceptor.java index 996aef33e2..d8b14b970f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCProtectionInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCProtectionInterceptor.java @@ -90,6 +90,7 @@ public class JsonRPCProtectionInterceptor extends AbstractInterceptor { private static final Logger log = LoggerFactory.getLogger(JsonRPCProtectionInterceptor.class); private static final ObjectMapper OM = new ObjectMapper(); + private static final String JSON_RPC_ELIGIBLE = JsonRPCProtectionInterceptor.class.getName() + ".jsonRpcEligible"; private static final String RESPONSE_VALIDATION_CONTEXT = JsonRPCProtectionInterceptor.class.getName() + ".responseValidationContext"; private BatchRule batchRule = new BatchRule(); @@ -124,7 +125,9 @@ public Outcome handleRequest(Exchange exc) { )); } - RequestValidationResult validation = getValidator().validateRequest(exc.getRequest().getBodyAsStringDecoded()); + exc.setProperty(JSON_RPC_ELIGIBLE, Boolean.TRUE); + + RequestValidationResult validation = validator.validateRequest(exc.getRequest().getBodyAsStringDecoded()); if (validation.responseValidationContext() != null) { exc.setProperty(RESPONSE_VALIDATION_CONTEXT, validation.responseValidationContext()); } @@ -137,12 +140,16 @@ public Outcome handleResponse(Exchange exc) { return CONTINUE; } + if (!Boolean.TRUE.equals(exc.getProperty(JSON_RPC_ELIGIBLE, Boolean.class))) { + return CONTINUE; + } + ResponseValidationContext context = exc.getProperty(RESPONSE_VALIDATION_CONTEXT, ResponseValidationContext.class); if (context == null && schemaValidation.hasErrorValidation()) { context = new ResponseValidationContext(getPayloadType(exc.getResponse().getBodyAsStringDecoded()), java.util.Map.of()); } - return rejectResponse(exc, getValidator().validateResponse(exc.getResponse().getBodyAsStringDecoded(), context)); + return rejectResponse(exc, validator.validateResponse(exc.getResponse().getBodyAsStringDecoded(), context)); } private Outcome rejectRequest(Exchange exc, ValidationError error) { @@ -202,13 +209,6 @@ public List getMethods() { return methods; } - private JsonRPCValidator getValidator() { - if (validator == null) { - validator = createValidator(); - } - return validator; - } - private JsonRPCValidator createValidator() { schemaValidation.init(router.getResolverMap(), router.getConfiguration().getUriFactory(), getBeanBaseLocation()); return new JsonRPCValidator(batchRule, methods, schemaValidation); diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCSchemaValidation.java b/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCSchemaValidation.java index 7056ab631f..526c4c5870 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCSchemaValidation.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCSchemaValidation.java @@ -28,6 +28,8 @@ import com.predic8.membrane.core.util.URIFactory; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.LinkedHashMap; import java.util.Map; @@ -36,6 +38,8 @@ import static com.networknt.schema.SchemaRegistry.withDefaultDialect; import static com.networknt.schema.SpecificationVersion.DRAFT_2020_12; import static com.predic8.membrane.core.resolver.ResolverMap.combine; +import static java.nio.charset.StandardCharsets.*; +import static java.util.Base64.getUrlEncoder; import static java.util.Collections.unmodifiableMap; /** @@ -196,15 +200,15 @@ private Schema loadInlineSchema(SchemaRegistry registry, String methodName, Stri } private String createInlineSchemaLocation(String methodName, String schemaRole, URIFactory uriFactory, String beanBaseLocation) { - String syntheticFile = "__jsonrpc_%s_%s.schema.json".formatted(sanitize(methodName), sanitize(schemaRole)); + String syntheticFile = "__jsonrpc_%s_%s.schema.json".formatted(encodeLocationToken(methodName), encodeLocationToken(schemaRole)); if (beanBaseLocation == null || beanBaseLocation.isBlank()) { return "membrane:%s".formatted(syntheticFile); } return combine(uriFactory, beanBaseLocation, syntheticFile); } - private static String sanitize(String value) { - return value.replaceAll("[^A-Za-z0-9._-]", "_"); + private static String encodeLocationToken(String value) { + return getUrlEncoder().withoutPadding().encodeToString(value.getBytes(UTF_8)); } private Schema loadSchema(String description, SchemaRegistry registry, SchemaLocation schemaLocation, Resolver resolver, String configuredLocation) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCValidator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCValidator.java index b27003c1aa..0bcf213792 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCValidator.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCValidator.java @@ -163,6 +163,9 @@ private ValidationError validateSingleResponse(JsonNode node, ResponseValidation } if (response.isError()) { + if (context.methodFor(response.getId()) == null) { + return invalidResponse(payloadType, response.getId(), "JSON-RPC response id '%s' does not match any request.".formatted(response.getId())); + } return validateErrorResponse(node, payloadType, response.getId()); } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionInterceptor.java new file mode 100644 index 0000000000..e0f6bd88a3 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionInterceptor.java @@ -0,0 +1,254 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.mcp; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.predic8.membrane.annot.MCChildElement; +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.http.Response; +import com.predic8.membrane.core.http.Response.ResponseBuilder; +import com.predic8.membrane.core.interceptor.AbstractInterceptor; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.mcp.MCPProtectionValidator.ValidationError; +import com.predic8.membrane.core.util.config.allowdeny.Rule; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; + +import static com.predic8.membrane.core.http.MimeType.APPLICATION_JSON; +import static com.predic8.membrane.core.http.Request.METHOD_POST; +import static com.predic8.membrane.core.http.Response.statusCode; +import static com.predic8.membrane.core.jsonrpc.JSONRPCRequest.parse; +import static com.predic8.membrane.core.interceptor.Interceptor.Flow.REQUEST; +import static com.predic8.membrane.core.interceptor.Interceptor.Flow.RESPONSE; +import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; +import static com.predic8.membrane.core.interceptor.Outcome.RETURN; +import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.ERR_INVALID_REQUEST; +import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.error; +import static com.predic8.membrane.core.mcp.MCPToolsList.METHOD; +import static java.util.EnumSet.of; + +/** + * @description + *

Protects an MCP endpoint by validating incoming JSON-RPC requests and + * restricting the MCP methods and tools that clients may use.

+ * + *

Only HTTP POST requests with an + * application/json content type are accepted. JSON-RPC batch + * requests are rejected.

+ * + *

initialize and ping are always allowed. + * tools/list, tools/call, and notifications are + * enabled by default and can be disabled under methods. Every + * other method is rejected.

+ * + *

Tool rules are evaluated in declaration order and the first matching + * rule wins. If no rule matches, the tool is allowed. Consequently, all tools + * are allowed when tools is omitted. Add a final + * deny: ".*" rule to change this into an allowlist.

+ * + *

Denied tools are also removed from tools/list responses so + * clients do not discover tools they are not allowed to call.

+ * + * @yaml + *

+ * - mcpProtection:
+ *     methods:
+ *       toolsList: true
+ *       toolsCall: true
+ *       notifications: true
+ *     tools:
+ *       - allow: "listProxies"
+ *       - allow: "getStatistics"
+ *       - allow: "getExchanges"
+ *       - deny: ".*"           # tool names also support regular expressions
+ * 
+ */ +@MCElement(name = "mcpProtection") +public class MCPProtectionInterceptor extends AbstractInterceptor { + + private static final Logger log = LoggerFactory.getLogger(MCPProtectionInterceptor.class); + + private static final ObjectMapper OM = new ObjectMapper(); + private MCPProtectionMethods methods = new MCPProtectionMethods(); + private static final String TOOLS_LIST_REQUEST = MCPProtectionInterceptor.class.getName() + ".toolsListRequest"; + private List tools = List.of(); + private MCPProtectionValidator validator; + + public MCPProtectionInterceptor() { + name = "MCP protection"; + setAppliedFlow(of(REQUEST, RESPONSE)); + } + + @Override + public void init() { + super.init(); + validator = createValidator(); + } + + @Override + public Outcome handleRequest(Exchange exc) { + if (!exc.getRequest().isPOSTRequest()) { + return reject(exc, new ValidationError( + null, + false, + 405, + ERR_INVALID_REQUEST, + "HTTP method %s is not supported. Expected POST.".formatted(exc.getRequest().getMethod()) + ), true); + } + + if (!exc.getRequest().isJSON()) { + return reject(exc, new ValidationError( + null, + false, + 415, + ERR_INVALID_REQUEST, + "Content-Type %s is not supported. Expected application/json.".formatted( + exc.getRequest().getHeader().getContentType() + ) + ), false); + } + + String body = exc.getRequest().getBodyAsStringDecoded(); + ValidationError error = getValidator().validate(body); + if (error == null && isToolsListRequest(body)) { + exc.setProperty(TOOLS_LIST_REQUEST, Boolean.TRUE); + } + return reject(exc, error, false); + } + + @Override + public Outcome handleResponse(Exchange exc) { + if (exc.getResponse() == null || !Boolean.TRUE.equals(exc.getProperty(TOOLS_LIST_REQUEST))) { + return CONTINUE; + } + + try { + JsonNode response = OM.readTree(exc.getResponse().getBodyAsStringDecoded()); + JsonNode resultNode = response == null ? null : response.get("result"); + if (!(resultNode instanceof ObjectNode result)) { + return CONTINUE; + } + + JsonNode listedTools = result.get("tools"); + if (listedTools == null || !listedTools.isArray()) { + return CONTINUE; + } + + var filteredTools = OM.createArrayNode(); + listedTools.forEach(tool -> { + JsonNode name = tool.get("name"); + if (name == null || !name.isTextual() || getValidator().isToolPermitted(name.textValue())) { + filteredTools.add(tool); + } + }); + if (filteredTools.size() != listedTools.size()) { + result.set("tools", filteredTools); + exc.getResponse().setBodyContent(OM.writeValueAsBytes(response)); + } + } catch (JsonProcessingException e) { + log.debug("Could not filter MCP tools/list response.", e); + } + return CONTINUE; + } + + private boolean isToolsListRequest(String body) { + try { + return METHOD.equals(parse(body).getMethod()); + } catch (IOException | RuntimeException e) { + return false; + } + } + + /** + * @description + *

Configures the optional MCP method groups. initialize + * and ping cannot be disabled by this interceptor.

+ */ + @MCChildElement(order = 1) + public void setMethods(MCPProtectionMethods methods) { + this.methods = methods == null ? new MCPProtectionMethods() : methods; + } + + /** + * @description + *

Configures ordered allow and deny rules for tool names used by + * tools/call and advertised by tools/list. + * Rules support regular expressions. The first matching rule wins; tools + * unmatched by any rule are allowed.

+ * + *

When no rules are configured, all tools are allowed. To allow only + * explicitly listed tools, place their allow rules first and + * finish the list with deny: ".*".

+ */ + @MCChildElement(order = 2) + public void setTools(List tools) { + this.tools = tools == null ? List.of() : tools; + } + + public MCPProtectionMethods getMethods() { + return methods; + } + + public List getTools() { + return tools; + } + + private MCPProtectionValidator getValidator() { + if (validator == null) { + validator = createValidator(); + } + return validator; + } + + private MCPProtectionValidator createValidator() { + return new MCPProtectionValidator(methods, tools); + } + + private Outcome reject(Exchange exc, ValidationError error, boolean addAllowHeader) { + if (error == null) { + return CONTINUE; + } + + log.info("Rejected MCP request: {}", error.message()); + if (error.notification()) { + exc.setResponse(statusCode(error.httpStatus()).bodyEmpty().build()); + return RETURN; + } + exc.setResponse(createErrorResponse(error, addAllowHeader)); + return RETURN; + } + + private Response createErrorResponse(ValidationError error, boolean addAllowHeader) { + try { + ResponseBuilder builder = statusCode(error.httpStatus()) + .contentType(APPLICATION_JSON) + .body(error(error.responseId(), error.code(), error.message()).toJson()); + if (addAllowHeader) { + builder.header("Allow", METHOD_POST); + } + return builder.build(); + } catch (IOException e) { + throw new RuntimeException("Could not create MCP error response", e); + } + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionMethods.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionMethods.java new file mode 100644 index 0000000000..59f6fda004 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionMethods.java @@ -0,0 +1,75 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.mcp; + +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCElement; + +/** + * @description + *

Controls the optional MCP methods accepted by + * mcpProtection.

+ * + *

+ * initialize and ping are always allowed. + * Unsupported MCP methods are always rejected.

+ */ +@MCElement(name = "methods", id = "mcp-methods") +public class MCPProtectionMethods { + + private boolean toolsList = true; + private boolean toolsCall = true; + private boolean notifications = true; + + /** + * @description Enables the MCP tools/list method. + * @default true + */ + @MCAttribute + public void setToolsList(boolean toolsList) { + this.toolsList = toolsList; + } + + public boolean isToolsList() { + return toolsList; + } + + /** + * @description Enables the MCP tools/call method. + * @default true + */ + @MCAttribute + public void setToolsCall(boolean toolsCall) { + this.toolsCall = toolsCall; + } + + public boolean isToolsCall() { + return toolsCall; + } + + /** + * @description Enables MCP methods whose names start with + * notifications/. + * @default true + */ + @MCAttribute + public void setNotifications(boolean notifications) { + this.notifications = notifications; + } + + public boolean isNotifications() { + return notifications; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionValidator.java b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionValidator.java new file mode 100644 index 0000000000..72eca044d6 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionValidator.java @@ -0,0 +1,150 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.mcp; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.predic8.membrane.core.jsonrpc.JSONRPCRequest; +import com.predic8.membrane.core.mcp.MCPInitialize; +import com.predic8.membrane.core.mcp.MCPPing; +import com.predic8.membrane.core.mcp.MCPToolsCall; +import com.predic8.membrane.core.mcp.MCPToolsList; +import com.predic8.membrane.core.util.config.allowdeny.Rule; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; + +import static com.predic8.membrane.core.jsonrpc.JSONRPCRequest.fromNode; +import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.*; +import static com.predic8.membrane.core.mcp.MCPToolsCall.from; + +final class MCPProtectionValidator { + + private static final Logger log = LoggerFactory.getLogger(MCPProtectionValidator.class); + private static final ObjectMapper OM = new ObjectMapper(); + + private final MCPProtectionMethods methods; + private final List toolRules; + + MCPProtectionValidator(MCPProtectionMethods methods, List toolRules) { + this.methods = methods == null ? new MCPProtectionMethods() : methods; + this.toolRules = toolRules == null ? List.of() : toolRules; + } + + ValidationError validate(String body) { + if (body == null || body.isBlank()) { + return invalidRequest("MCP request body must not be empty."); + } + + JsonNode root; + try { + root = OM.readTree(body); + } catch (JsonProcessingException e) { + log.debug("Invalid MCP JSON payload.", e); + return new ValidationError(null, false, 400, ERR_PARSE_ERROR, "Invalid JSON payload."); + } + + if (root == null || !root.isObject()) { + if (root != null && root.isArray()) { + return invalidRequest("MCP does not support JSON-RPC batch requests."); + } + return invalidRequest("MCP payload must be a JSON-RPC request object."); + } + + JSONRPCRequest request; + try { + request = fromNode(root); + } catch (IOException | RuntimeException e) { + return invalidRequest("Invalid JSON-RPC request: " + e.getMessage()); + } + + if (!isMethodAllowed(request.getMethod())) { + return new ValidationError( + responseId(request), + request.isNotification(), + 403, + ERR_METHOD_NOT_FOUND, + "MCP method '%s' is not allowed.".formatted(request.getMethod()) + ); + } + + if (!MCPToolsCall.METHOD.equals(request.getMethod())) { + return null; + } + + MCPToolsCall toolsCall; + try { + toolsCall = from(request); + } catch (IllegalArgumentException e) { + return new ValidationError( + responseId(request), + request.isNotification(), + 400, + ERR_INVALID_PARAMS, + e.getMessage() + ); + } + + if (isToolPermitted(toolsCall.getName())) { + return null; + } + + return new ValidationError( + responseId(request), + request.isNotification(), + 403, + ERR_INVALID_PARAMS, + "MCP tool '%s' is not allowed.".formatted(toolsCall.getName()) + ); + } + + boolean isToolPermitted(String name) { + for (Rule rule : toolRules) { + if (rule.matches(name)) { + return rule.permits(); + } + } + return true; + } + + private boolean isMethodAllowed(String method) { + return switch (method) { + case MCPInitialize.METHOD, MCPPing.METHOD -> true; + case MCPToolsList.METHOD -> methods.isToolsList(); + case MCPToolsCall.METHOD -> methods.isToolsCall(); + default -> methods.isNotifications() && method.startsWith("notifications/"); + }; + } + + private ValidationError invalidRequest(String message) { + return new ValidationError(null, false, 400, ERR_INVALID_REQUEST, message); + } + + private Object responseId(JSONRPCRequest request) { + return request.isNotification() ? null : request.getId(); + } + + record ValidationError( + Object responseId, + boolean notification, + int httpStatus, + int code, + String message + ) { + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/util/config/allowdeny/Rule.java b/core/src/main/java/com/predic8/membrane/core/util/config/allowdeny/Rule.java index a582d6bd87..46e09beda4 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/config/allowdeny/Rule.java +++ b/core/src/main/java/com/predic8/membrane/core/util/config/allowdeny/Rule.java @@ -33,7 +33,7 @@ public boolean matches(String probe) { if (probe == null) { return false; } - return compiledPattern != null && compiledPattern.matcher(probe).matches(); + return compiledPattern.matcher(probe).matches(); } public abstract boolean permits(); diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCProtectionInterceptorTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCProtectionInterceptorTest.java index 1aa5636b18..c9a9c56dbb 100644 --- a/core/src/test/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCProtectionInterceptorTest.java +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/json/rpc/JsonRPCProtectionInterceptorTest.java @@ -93,6 +93,28 @@ class JsonRPCProtectionInterceptorTest { ok: type: boolean """; + private static final String COLLIDING_INLINE_RESPONSE_CONFIG = """ + schemaValidation: + methods: + 'rpc/a': + response: + schema: + type: object + required: + - ok + properties: + ok: + type: boolean + 'rpc_a': + response: + schema: + type: object + required: + - message + properties: + message: + type: string + """; private static final String BATCH_RESPONSE_CONFIG = """ schemaValidation: methods: @@ -254,6 +276,20 @@ void batchMaxSizeMustBePositive() { assertTrue(exception.getMessage().contains("batch maxSize must be greater than 0")); } + @Test + void skipsResponseValidationWithoutEligibleRequestMarker() throws Exception { + var interceptor = interceptor(ERROR_INLINE_CONFIG); + var exc = exchange(METHOD_GET, TEXT_PLAIN, null); + String responseBody = """ + {"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"broken"}} + """; + + exc.setResponse(jsonResponse(responseBody)); + + assertEquals(CONTINUE, interceptor.handleResponse(exc)); + assertEquals(responseBody, exc.getResponse().getBodyAsStringDecoded()); + } + private static Stream requestCases() { return Stream.of( requestContinues( @@ -507,6 +543,18 @@ private static Stream responseCases() { "Invalid result for method 'rpc.inline'", 1 ), + responseRejects( + "inline response schemas keep distinct locations for method names with same sanitized form", + COLLIDING_INLINE_RESPONSE_CONFIG, + """ + {"jsonrpc":"2.0","id":1,"method":"rpc_a"} + """, + """ + {"jsonrpc":"2.0","id":1,"result":{"ok":true}} + """, + "Invalid result for method 'rpc_a'", + 1 + ), responseContinues( "response validation uses exact method names", RESPONSE_LOCATION_CONFIG, @@ -539,6 +587,18 @@ private static Stream responseCases() { "JSON-RPC response id '9' does not match any request.", 9 ), + responseRejects( + "unknown error response ids are rejected", + ERROR_LOCATION_CONFIG, + """ + {"jsonrpc":"2.0","id":1,"method":"rpc.echo"} + """, + """ + {"jsonrpc":"2.0","id":9,"error":{"code":-32000,"message":"broken","data":{"reason":"timeout"}}} + """, + "JSON-RPC response id '9' does not match any request.", + 9 + ), responseRejects( "batch response validation resolves methods by request id", BATCH_RESPONSE_CONFIG, @@ -641,15 +701,13 @@ private static Stream responseCases() { "Invalid error response", 1 ), - responseRejects( - "error validation also works without prior request context", + responseContinues( + "error validation is skipped without prior request context", ERROR_INLINE_CONFIG, null, """ {"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"broken"}} - """, - "Invalid error response", - 1 + """ ) ); } diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionInterceptorTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionInterceptorTest.java new file mode 100644 index 0000000000..606bf8b06c --- /dev/null +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/mcp/MCPProtectionInterceptorTest.java @@ -0,0 +1,526 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.core.interceptor.mcp; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.predic8.membrane.annot.beanregistry.BeanRegistryImplementation; +import com.predic8.membrane.annot.yaml.ParsingContext; +import com.predic8.membrane.annot.yaml.parsing.GenericYamlParser; +import com.predic8.membrane.core.config.spring.GrammarAutoGenerated; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.http.Response; +import com.predic8.membrane.core.http.Request; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.router.DefaultRouter; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static com.predic8.membrane.core.http.MimeType.APPLICATION_JSON; +import static com.predic8.membrane.core.http.MimeType.TEXT_PLAIN; +import static com.predic8.membrane.core.http.Request.METHOD_GET; +import static com.predic8.membrane.core.http.Request.METHOD_POST; +import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; +import static com.predic8.membrane.core.interceptor.Outcome.RETURN; +import static com.predic8.membrane.core.jsonrpc.JSONRPCResponse.*; +import static org.junit.jupiter.api.Assertions.*; + +class MCPProtectionInterceptorTest { + + private static final ObjectMapper OM = new ObjectMapper(); + private static final ObjectMapper YAML = new ObjectMapper(new YAMLFactory()); + + @ParameterizedTest(name = "{0}") + @MethodSource("requestCases") + void validatesRequests(RequestCase testCase) throws Exception { + var exc = exchange(testCase.httpMethod(), testCase.contentType(), testCase.body()); + + assertEquals(testCase.expectedOutcome(), interceptor(testCase.config()).handleRequest(exc)); + + if (testCase.expectedOutcome() == CONTINUE) { + assertNull(exc.getResponse()); + return; + } + + assertError( + exc, + testCase.expectedStatus(), + testCase.expectedCode(), + testCase.expectedMessage(), + testCase.expectedId() + ); + if (testCase.expectedStatus() == 405) { + assertEquals(METHOD_POST, exc.getResponse().getHeader().getFirstValue("Allow")); + } + } + + private static Stream requestCases() { + String toolsEnabled = """ + methods: + toolsList: true + toolsCall: true + notifications: true + tools: + - allow: '^(listProxies|getStatistics|getExchanges)$' + - deny: '.*' + """; + + String toolsCallEnabled = """ + methods: + toolsCall: true + """; + + return Stream.of( + continues( + "initialize is always allowed", + "", + """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{}} + """ + ), + continues( + "ping is always allowed", + "", + """ + {"jsonrpc":"2.0","id":2,"method":"ping","params":{}} + """ + ), + continues( + "enabled tools list is allowed", + toolsEnabled, + """ + {"jsonrpc":"2.0","id":3,"method":"tools/list","params":{}} + """ + ), + rejects( + "unsupported MCP methods are always denied", + toolsEnabled, + """ + {"jsonrpc":"2.0","id":6,"method":"resources/read","params":{}} + """, + 403, + ERR_METHOD_NOT_FOUND, + "MCP method 'resources/read' is not allowed.", + 6 + ), + continues( + "allowed tool call", + toolsEnabled, + """ + {"jsonrpc":"2.0","id":7,"method":"tools/call","params":{"name":"getStatistics","arguments":{}}} + """ + ), + rejects( + "denied tool call", + toolsEnabled, + """ + {"jsonrpc":"2.0","id":8,"method":"tools/call","params":{"name":"deleteEverything","arguments":{}}} + """, + 403, + ERR_INVALID_PARAMS, + "MCP tool 'deleteEverything' is not allowed.", + 8 + ), + continues( + "first matching tool rule wins", + """ + methods: + toolsCall: true + tools: + - allow: '^admin\\.status$' + - deny: '^admin\\..*$' + """, + """ + {"jsonrpc":"2.0","id":9,"method":"tools/call","params":{"name":"admin.status"}} + """ + ), + rejects( + "tools call requires named params", + toolsCallEnabled, + """ + {"jsonrpc":"2.0","id":10,"method":"tools/call","params":[]} + """, + 400, + ERR_INVALID_PARAMS, + "requires a params object", + 10 + ), + rejects( + "tools call requires a tool name", + toolsCallEnabled, + """ + {"jsonrpc":"2.0","id":11,"method":"tools/call","params":{"arguments":{}}} + """, + 400, + ERR_INVALID_PARAMS, + "'name' must be a non-empty string", + 11 + ), + rejects( + "batch requests are rejected", + "", + """ + [{"jsonrpc":"2.0","id":1,"method":"ping"}] + """, + 400, + ERR_INVALID_REQUEST, + "does not support JSON-RPC batch requests", + null + ), + rejects( + "invalid JSON is rejected", + "", + "not-json", + 400, + ERR_PARSE_ERROR, + "Invalid JSON payload.", + null + ), + rejects( + "invalid JSON-RPC structure is rejected", + "", + """ + {"jsonrpc":"2.0","id":12} + """, + 400, + ERR_INVALID_REQUEST, + "'method' must be a string", + null + ), + rejects( + "blank request bodies are rejected", + "", + " ", + 400, + ERR_INVALID_REQUEST, + "must not be empty", + null + ), + rejects( + "non-JSON content types are rejected", + "", + METHOD_POST, + TEXT_PLAIN, + "{}", + 415, + ERR_INVALID_REQUEST, + "Expected application/json", + null + ), + rejects( + "non-POST requests are rejected", + "", + METHOD_GET, + APPLICATION_JSON, + null, + 405, + ERR_INVALID_REQUEST, + "Expected POST", + null + ) + ); + } + + private static RequestCase continues(String name, String config, String body) { + return new RequestCase(name, config, METHOD_POST, APPLICATION_JSON, body, CONTINUE, null, null, null, null); + } + + private static RequestCase rejects(String name, + String config, + String body, + int status, + int code, + String message, + Object id) { + return rejects(name, config, METHOD_POST, APPLICATION_JSON, body, status, code, message, id); + } + + private static RequestCase rejects(String name, + String config, + String httpMethod, + String contentType, + String body, + int status, + int code, + String message, + Object id) { + return new RequestCase(name, config, httpMethod, contentType, body, RETURN, status, code, message, id); + } + + private MCPProtectionInterceptor interceptor(String config) throws Exception { + MCPProtectionInterceptor interceptor = parseInterceptor(config); + interceptor.init(new DefaultRouter()); + return interceptor; + } + + private MCPProtectionInterceptor parseInterceptor(String config) throws Exception { + JsonNode node = YAML.readTree(wrapConfig(config)); + var grammar = new GrammarAutoGenerated(); + var registry = new BeanRegistryImplementation(grammar); + return GenericYamlParser.createAndPopulateNode( + new ParsingContext<>("mcpProtection", registry, grammar, node, "$.mcpProtection", null), + MCPProtectionInterceptor.class, + node.get("mcpProtection") + ); + } + + private String wrapConfig(String config) { + if (config == null || config.isBlank()) { + return "mcpProtection: {}\n"; + } + return "mcpProtection:\n" + config.stripIndent().indent(2); + } + + private Exchange exchange(String method, String contentType, String body) { + Request.Builder builder = new Request.Builder() + .method(method) + .uri("/"); + if (contentType != null) { + builder.contentType(contentType); + } + if (body != null) { + builder.body(body); + } + return builder.buildExchange(); + } + + private void assertError(Exchange exc, int status, int code, String message, Object id) throws Exception { + assertNotNull(exc.getResponse()); + assertEquals(status, exc.getResponse().getStatusCode()); + assertEquals(APPLICATION_JSON, exc.getResponse().getHeader().getContentType()); + + JsonNode response = OM.readTree(exc.getResponse().getBodyAsStringDecoded()); + assertEquals("2.0", response.path("jsonrpc").asText()); + assertEquals(code, response.path("error").path("code").asInt()); + assertTrue(response.path("error").path("message").asText().contains(message)); + + if (id == null) { + assertTrue(response.path("id").isNull()); + } else { + assertEquals(OM.valueToTree(id), response.path("id")); + } + } + + private record RequestCase( + String name, + String config, + String httpMethod, + String contentType, + String body, + Outcome expectedOutcome, + Integer expectedStatus, + Integer expectedCode, + String expectedMessage, + Object expectedId + ) { + @Override + public String toString() { + return name; + } + } + + @Test + void suppressesJsonRpcErrorForRejectedNotification() throws Exception { + var exc = exchange(METHOD_POST, APPLICATION_JSON, """ + {"jsonrpc":"2.0","method":"notifications/cancelled","params":{}} + """); + + assertEquals(RETURN, interceptor("methods:\n notifications: false").handleRequest(exc)); + assertNotNull(exc.getResponse()); + assertEquals(403, exc.getResponse().getStatusCode()); + assertTrue(exc.getResponse().getBodyAsStringDecoded().isEmpty()); + } + + @Test + void sendsJsonRpcErrorForRejectedRequestWithNullId() throws Exception { + var exc = exchange(METHOD_POST, APPLICATION_JSON, """ + {"jsonrpc":"2.0","id":null,"method":"resources/read","params":{}} + """); + + assertEquals(RETURN, interceptor("").handleRequest(exc)); + assertError( + exc, + 403, + ERR_METHOD_NOT_FOUND, + "MCP method 'resources/read' is not allowed.", + null + ); + } + + @Test + void filtersDeniedToolsFromToolsListResponse() throws Exception { + String config = """ + tools: + - allow: '^safe\\..*$' + - deny: '.*' + """; + + var interceptor = interceptor(config); + var exc = exchange( + METHOD_POST, + APPLICATION_JSON, + """ + {"jsonrpc":"2.0","id":21,"method":"tools/list","params":{}} + """ + ); + + assertEquals(CONTINUE, interceptor.handleRequest(exc)); + exc.setResponse(Response.ok() + .contentType(APPLICATION_JSON) + .body(""" + { + "jsonrpc": "2.0", + "id": 21, + "result": { + "tools": [ + {"name": "safe.read"}, + {"name": "hidden.delete"}, + {"name": "public.info"} + ], + "nextCursor": "page-2", + "untouched": true + } + } + """) + .build()); + + assertEquals(CONTINUE, interceptor.handleResponse(exc)); + assertEquals(List.of("safe.read"), toolNames(exc)); + + JsonNode response = OM.readTree(exc.getResponse().getBodyAsStringDecoded()); + assertEquals("page-2", response.path("result").path("nextCursor").asText()); + assertTrue(response.path("result").path("untouched").asBoolean()); + } + + @Test + void hidesDeniedToolsFromToolsListResponseWhenUnmatchedToolsAreAllowedByDefault() throws Exception { + String config = """ + tools: + - deny: '^private\\..*$' + """; + + var interceptor = interceptor(config); + var exc = toolsListExchange(23); + + assertEquals(CONTINUE, interceptor.handleRequest(exc)); + exc.setResponse(Response.ok() + .contentType(APPLICATION_JSON) + .body(""" + { + "jsonrpc": "2.0", + "id": 23, + "result": { + "tools": [ + {"name": "public.info"}, + {"name": "private.delete"}, + {"name": "unmatched.tool"} + ] + } + } + """) + .build()); + + assertEquals(CONTINUE, interceptor.handleResponse(exc)); + assertEquals(List.of("public.info", "unmatched.tool"), toolNames(exc)); + } + + @Test + void hidesToolsDeniedByFirstMatchingRuleFromToolsListResponse() throws Exception { + String config = """ + tools: + - deny: '^admin\\..*$' + - allow: '^admin\\.status$' + - allow: '^public\\..*$' + """; + + var interceptor = interceptor(config); + var exc = toolsListExchange(24); + + assertEquals(CONTINUE, interceptor.handleRequest(exc)); + exc.setResponse(Response.ok() + .contentType(APPLICATION_JSON) + .body(""" + { + "jsonrpc": "2.0", + "id": 24, + "result": { + "tools": [ + {"name": "admin.status"}, + {"name": "admin.delete"}, + {"name": "public.info"}, + {"name": "unmatched.tool"} + ] + } + } + """) + .build()); + + assertEquals(CONTINUE, interceptor.handleResponse(exc)); + assertEquals(List.of("public.info", "unmatched.tool"), toolNames(exc)); + } + + @Test + void leavesNonToolsListResponsesUnchanged() throws Exception { + var interceptor = interceptor("tools:\n - deny: '.*'"); + var exc = exchange( + METHOD_POST, + APPLICATION_JSON, + """ + {"jsonrpc":"2.0","id":22,"method":"ping","params":{}} + """ + ); + + assertEquals(CONTINUE, interceptor.handleRequest(exc)); + String body = """ + { + "jsonrpc": "2.0", + "id": 22, + "result": {"tools": [{"name": "hidden.delete"}]} + } + """; + exc.setResponse(Response.ok() + .contentType(APPLICATION_JSON) + .body(body) + .build()); + + assertEquals(CONTINUE, interceptor.handleResponse(exc)); + assertEquals(body, exc.getResponse().getBodyAsStringDecoded()); + } + + private Exchange toolsListExchange(int id) { + return exchange( + METHOD_POST, + APPLICATION_JSON, + """ + {"jsonrpc":"2.0","id":%s,"method":"tools/list","params":{}} + """.formatted(id) + ); + } + + private List toolNames(Exchange exc) throws Exception { + var names = new ArrayList(); + OM.readTree(exc.getResponse().getBodyAsStringDecoded()) + .path("result") + .path("tools") + .forEach(tool -> names.add(tool.path("name").asText())); + return names; + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/mcp/AbstractMcpTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/mcp/AbstractMcpTutorialTest.java new file mode 100644 index 0000000000..37422cbb79 --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/mcp/AbstractMcpTutorialTest.java @@ -0,0 +1,25 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.mcp; + +import com.predic8.membrane.tutorials.AbstractMembraneTutorialTest; + +public abstract class AbstractMcpTutorialTest extends AbstractMembraneTutorialTest { + + @Override + protected String getTutorialDir() { + return "ai/mcp"; + } +} diff --git a/distribution/src/test/java/com/predic8/membrane/tutorials/ai/mcp/McpProtectionTutorialTest.java b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/mcp/McpProtectionTutorialTest.java new file mode 100644 index 0000000000..acf4ea2219 --- /dev/null +++ b/distribution/src/test/java/com/predic8/membrane/tutorials/ai/mcp/McpProtectionTutorialTest.java @@ -0,0 +1,145 @@ +/* Copyright 2026 predic8 GmbH, www.predic8.com + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +package com.predic8.membrane.tutorials.ai.mcp; + +import io.restassured.response.Response; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static io.restassured.RestAssured.given; +import static io.restassured.http.ContentType.JSON; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; + +public class McpProtectionTutorialTest extends AbstractMcpTutorialTest { + + private String sessionId; + + @Override + protected String getTutorialYaml() { + return "20-MCP-Protection.yaml"; + } + + @BeforeEach + void initializeMcpSession() { + // @formatter:off + Response initialize = given() + .contentType(JSON) + .body(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { + "name": "tutorial-test", + "version": "1.0" + } + } + } + """) + .when() + .post("http://localhost:2000") + .then() + .statusCode(200) + .contentType(JSON) + .header("Mcp-Session-Id", notNullValue()) + .extract() + .response(); + + sessionId = initialize.getHeader("Mcp-Session-Id"); + + given() + .contentType(JSON) + .header("Mcp-Session-Id", sessionId) + .body(""" + {"jsonrpc":"2.0","method":"notifications/initialized","params":{}} + """) + .when() + .post("http://localhost:2000") + .then() + .statusCode(202); + // @formatter:on + } + + @Test + void filtersDeniedToolsFromToolsList() { + // @formatter:off + given() + .contentType(JSON) + .header("Mcp-Session-Id", sessionId) + .body(""" + {"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}} + """) + .when() + .post("http://localhost:2000") + .then() + .statusCode(200) + .contentType(JSON) + .body("jsonrpc", equalTo("2.0")) + .body("id", equalTo(2)) + .body("result.tools.name", contains("listProxies", "getStatistics")) + .body("result.tools.name", not(hasItem("getExchanges"))); + // @formatter:on + } + + @Test + void allowsConfiguredTool() { + // @formatter:off + given() + .contentType(JSON) + .header("Mcp-Session-Id", sessionId) + .body(""" + {"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"listProxies","arguments":{}}} + """) + .when() + .post("http://localhost:2000") + .then() + .statusCode(200) + .contentType(JSON) + .body("jsonrpc", equalTo("2.0")) + .body("id", equalTo(3)) + .body("result.content[0].type", equalTo("text")) + .body("result.content[0].text", containsString("Protected-MCP-Endpoint")); + // @formatter:on + } + + @Test + void rejectsBlockedTool() { + // @formatter:off + given() + .contentType(JSON) + .header("Mcp-Session-Id", sessionId) + .body(""" + {"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"getExchanges","arguments":{}}} + """) + .when() + .post("http://localhost:2000") + .then() + .statusCode(403) + .contentType(JSON) + .body("jsonrpc", equalTo("2.0")) + .body("id", equalTo(4)) + .body("error.code", equalTo(-32602)) + .body("error.message", containsString("getExchanges")); + // @formatter:on + } +} diff --git a/distribution/tutorials/README.md b/distribution/tutorials/README.md index 771f8c14e0..fc66f17ad1 100644 --- a/distribution/tutorials/README.md +++ b/distribution/tutorials/README.md @@ -32,7 +32,7 @@ Run and observe Membrane in production. ## [AI / MCP](ai/mcp) -Expose Membrane as an MCP server for AI clients, inspect recent API traffic, and protect the MCP endpoint with an API key. +Expose Membrane as an MCP server for AI clients, inspect recent API traffic, restrict MCP tools, and protect the endpoint with an API key. ## [Security](security) diff --git a/distribution/tutorials/ai/mcp/20-MCP-Protection.yaml b/distribution/tutorials/ai/mcp/20-MCP-Protection.yaml new file mode 100644 index 0000000000..a724c5093c --- /dev/null +++ b/distribution/tutorials/ai/mcp/20-MCP-Protection.yaml @@ -0,0 +1,87 @@ +# yaml-language-server: $schema=https://www.membrane-api.io/v7.3.0.json +# +# Tutorial: MCP Protection +# +# Small MCP protection example: +# - exposes the protected MCP endpoint on http://localhost:2000 +# - forwards accepted MCP requests to the local Membrane MCP server on port 3000 +# - allows only the tools `listProxies` and `getStatistics` +# - hides all other tools from `tools/list` +# +# Notes: +# - `initialize` and `ping` are always allowed. +# - Tool rules are evaluated top-down. +# - The first matching rule wins. +# +# Try: +# +# 1.) Start Membrane: +# +# Linux/Mac: +# ./membrane.sh -c 20-MCP-Protection.yaml +# Windows: +# membrane.cmd -c 20-MCP-Protection.yaml +# +# 2.) Start the MCP Inspector in a second terminal. It usually opens the +# browser automatically: +# +# npx @modelcontextprotocol/inspector +# +# 3.) In the Inspector, connect to this MCP endpoint: +# Transport Type: Streamable HTTP +# URL: http://localhost:2000 +# +# Then click Connect. +# +# 4.) Open the Tools tab and list tools. +# +# The list should contain only `listProxies` and `getStatistics`. +# +# 5.) Call `listProxies`. +# +# 6.) `getExchanges` is also provided by the Membrane MCP server on port 3000, +# but it is hidden from the Inspector and blocked by `mcpProtection`. +# To verify the block, send a manual JSON-RPC request: +# +# curl -i -d '{"jsonrpc":"2.0","id":99,"method":"tools/call","params":{"name":"getExchanges","arguments":{}}}' -H "Content-Type: application/json" http://localhost:2000 +# +# The response should be 403 with a JSON-RPC error explaining that the tool +# is not allowed. + +api: + port: 2000 + name: Protected-MCP-Endpoint + flow: + - mcpProtection: + tools: + - allow: '^(listProxies|getStatistics)$' + - deny: '.*' + target: + url: http://localhost:3000 + +--- +# Local MCP server protected by the API above. +api: + port: 3000 + name: MCP-Server + flow: + - membraneMCPServer: + maxExchanges: 100 + +--- +# Demo API that makes `listProxies` return something useful. +api: + port: 3001 + name: Fruitshop + path: + uri: /shop/v2/ + target: + url: https://api.predic8.de + +--- +# Second demo API, also visible through the allowed `listProxies` tool. +api: + port: 3002 + name: ApiBin + target: + url: https://apibin.io/ diff --git a/distribution/tutorials/ai/mcp/README.md b/distribution/tutorials/ai/mcp/README.md index bf606aaf2b..b7c4c3aa7c 100644 --- a/distribution/tutorials/ai/mcp/README.md +++ b/distribution/tutorials/ai/mcp/README.md @@ -5,10 +5,12 @@ It covers: - running a local MCP server - inspecting recent API traffic through MCP +- allowing only selected MCP tools with `mcpProtection` To begin, open [10-MCP-Server.yaml](10-MCP-Server.yaml) and follow the instructions in the file. +Then continue with [20-MCP-Protection.yaml](20-MCP-Protection.yaml) to test a protected MCP endpoint with the MCP Inspector UI. -If you want to protect the MCP endpoint, put the `membraneMCPServer` behind an `apiKey` interceptor and expose it only through the network path you actually want clients to use. If the client should keep talking to a simple local URL, you can also place a small local proxy in front of the protected MCP endpoint and let that proxy add the required authentication details. +Use `mcpProtection` to restrict MCP methods and tools. If you also need client authentication, put the `membraneMCPServer` behind an `apiKey` interceptor and expose it only through the network path you actually want clients to use. If the client should keep talking to a simple local URL, you can also place a small local proxy in front of the protected MCP endpoint and let that proxy add the required authentication details. ## Claude Desktop Setup