diff --git a/pom.xml b/pom.xml
index 1ac460d..25d4605 100644
--- a/pom.xml
+++ b/pom.xml
@@ -64,6 +64,12 @@
3.5.4
test
+
+ org.eclipse.jetty
+ jetty-servlet
+ 11.0.20
+ test
+
com.fasterxml.jackson.core
jackson-databind
diff --git a/src/test/java/io/prerender/PrerenderFilterTest.java b/src/test/java/io/prerender/PrerenderFilterTest.java
index 3cd41ff..ca010ff 100644
--- a/src/test/java/io/prerender/PrerenderFilterTest.java
+++ b/src/test/java/io/prerender/PrerenderFilterTest.java
@@ -1,68 +1,104 @@
package io.prerender;
-import com.github.tomakehurst.wiremock.client.WireMock;
-import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
+import com.github.tomakehurst.wiremock.http.Fault;
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
-import jakarta.servlet.FilterChain;
+import jakarta.servlet.DispatcherType;
+import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
+import org.eclipse.jetty.server.Server;
+import org.eclipse.jetty.server.ServerConnector;
+import org.eclipse.jetty.servlet.FilterHolder;
+import org.eclipse.jetty.servlet.ServletContextHandler;
+import org.eclipse.jetty.servlet.ServletHolder;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
-import org.mockito.Mock;
-import org.mockito.junit.jupiter.MockitoExtension;
-import java.io.PrintWriter;
-import java.io.StringWriter;
+import java.io.IOException;
+import java.net.URI;
import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.util.EnumSet;
-import static com.github.tomakehurst.wiremock.client.WireMock.*;
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
+import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
+import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.*;
-@ExtendWith(MockitoExtension.class)
+/**
+ * Drives the filter inside a real servlet container (embedded Jetty), hit by a real
+ * HTTP client. The upstream Prerender service is faked with WireMock. Catches
+ * container-level behaviour (filter chain wiring, request URL/query, status and
+ * header propagation, static-asset pass-through) that Mockito on servlet objects
+ * would miss.
+ */
class PrerenderFilterTest {
private static final String BOT_UA = "Mozilla/5.0 (compatible; Googlebot/2.1)";
private static final String BROWSER_UA = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36";
private static final String PRERENDERED_HTML = "prerendered";
+ private static final String ORIGINAL = "original";
@RegisterExtension
static WireMockExtension wireMock = WireMockExtension.newInstance()
.options(wireMockConfig().dynamicPort())
.build();
- @Mock private HttpServletRequest request;
- @Mock private HttpServletResponse response;
- @Mock private FilterChain chain;
+ private static Server jetty;
+ private static String baseUrl;
+ private static final HttpClient httpClient = HttpClient.newHttpClient();
+
+ @BeforeAll
+ static void startJetty() throws Exception {
+ jetty = new Server(0);
+ ServletContextHandler context = new ServletContextHandler();
+ context.setContextPath("/");
+
+ FilterHolder filter = new FilterHolder(new PrerenderFilter());
+ filter.setInitParameter("prerenderToken", "test-token");
+ filter.setInitParameter("prerenderServiceUrl", wireMock.baseUrl());
+ context.addFilter(filter, "/*", EnumSet.of(DispatcherType.REQUEST));
+ context.addServlet(new ServletHolder(new OriginalServlet()), "/*");
+
+ jetty.setHandler(context);
+ jetty.start();
+ int port = ((ServerConnector) jetty.getConnectors()[0]).getLocalPort();
+ baseUrl = "http://127.0.0.1:" + port;
+ }
- private StringWriter responseWriter;
- private PrerenderFilter filter;
+ @AfterAll
+ static void stopJetty() throws Exception {
+ if (jetty != null) jetty.stop();
+ }
@BeforeEach
- void setUp() throws Exception {
+ void resetStubs() {
wireMock.resetAll();
- responseWriter = new StringWriter();
- lenient().when(response.getWriter()).thenReturn(new PrintWriter(responseWriter));
- PrerenderConfig config = new PrerenderConfig(null, "http://localhost:" + wireMock.getPort());
- filter = new PrerenderFilter(HttpClient.newHttpClient(), config);
+ }
+
+ private HttpResponse send(String method, String path, String userAgent, String... extraHeaders) throws Exception {
+ HttpRequest.Builder b = HttpRequest.newBuilder(URI.create(baseUrl + path))
+ .header("User-Agent", userAgent);
+ for (int i = 0; i + 1 < extraHeaders.length; i += 2) {
+ b.header(extraHeaders[i], extraHeaders[i + 1]);
+ }
+ HttpRequest req = "POST".equals(method)
+ ? b.POST(HttpRequest.BodyPublishers.noBody()).build()
+ : b.GET().build();
+ return httpClient.send(req, HttpResponse.BodyHandlers.ofString());
}
@Test
void browserRequest_passesThrough() throws Exception {
- when(request.getMethod()).thenReturn("GET");
- when(request.getRequestURI()).thenReturn("/");
- when(request.getParameter("_escaped_fragment_")).thenReturn(null);
- when(request.getHeader("X-Bufferbot")).thenReturn(null);
- when(request.getHeader("User-Agent")).thenReturn(BROWSER_UA);
-
- filter.doFilter(request, response, chain);
+ HttpResponse res = send("GET", "/", BROWSER_UA);
- verify(chain).doFilter(request, response);
- verify(response, never()).setStatus(anyInt());
+ assertEquals(200, res.statusCode());
+ assertEquals(ORIGINAL, res.body());
}
@Test
@@ -70,30 +106,18 @@ void botRequest_receivesPrerenderedResponse() throws Exception {
wireMock.stubFor(get(anyUrl())
.willReturn(aResponse().withStatus(200).withBody(PRERENDERED_HTML)));
- when(request.getMethod()).thenReturn("GET");
- when(request.getRequestURI()).thenReturn("/");
- when(request.getParameter("_escaped_fragment_")).thenReturn(null);
- when(request.getHeader("X-Bufferbot")).thenReturn(null);
- when(request.getHeader("User-Agent")).thenReturn(BOT_UA);
- when(request.getRequestURL()).thenReturn(new StringBuffer("http://example.com/"));
- when(request.getQueryString()).thenReturn(null);
-
- filter.doFilter(request, response, chain);
+ HttpResponse res = send("GET", "/about", BOT_UA);
- verify(response).setStatus(200);
- verify(chain, never()).doFilter(any(), any());
- assertEquals(PRERENDERED_HTML, responseWriter.toString());
+ assertEquals(200, res.statusCode());
+ assertEquals(PRERENDERED_HTML, res.body());
}
@Test
void botRequest_staticAsset_passesThrough() throws Exception {
- when(request.getMethod()).thenReturn("GET");
- when(request.getRequestURI()).thenReturn("/styles.css");
+ HttpResponse res = send("GET", "/styles.css", BOT_UA);
- filter.doFilter(request, response, chain);
-
- verify(chain).doFilter(request, response);
- verify(response, never()).setStatus(anyInt());
+ assertEquals(200, res.statusCode());
+ assertEquals(ORIGINAL, res.body());
}
@Test
@@ -101,17 +125,10 @@ void escapedFragment_triggersPrerender() throws Exception {
wireMock.stubFor(get(anyUrl())
.willReturn(aResponse().withStatus(200).withBody(PRERENDERED_HTML)));
- when(request.getMethod()).thenReturn("GET");
- when(request.getRequestURI()).thenReturn("/");
- when(request.getParameter("_escaped_fragment_")).thenReturn("");
- when(request.getHeader("User-Agent")).thenReturn(BROWSER_UA);
- when(request.getRequestURL()).thenReturn(new StringBuffer("http://example.com/"));
- when(request.getQueryString()).thenReturn("_escaped_fragment_=");
-
- filter.doFilter(request, response, chain);
+ HttpResponse res = send("GET", "/?_escaped_fragment_=", BROWSER_UA);
- verify(response).setStatus(200);
- verify(chain, never()).doFilter(any(), any());
+ assertEquals(200, res.statusCode());
+ assertEquals(PRERENDERED_HTML, res.body());
}
@Test
@@ -119,45 +136,37 @@ void xBufferbot_triggersPrerender() throws Exception {
wireMock.stubFor(get(anyUrl())
.willReturn(aResponse().withStatus(200).withBody(PRERENDERED_HTML)));
- when(request.getMethod()).thenReturn("GET");
- when(request.getRequestURI()).thenReturn("/");
- when(request.getParameter("_escaped_fragment_")).thenReturn(null);
- when(request.getHeader("X-Bufferbot")).thenReturn("true");
- when(request.getHeader("User-Agent")).thenReturn(BROWSER_UA);
- when(request.getRequestURL()).thenReturn(new StringBuffer("http://example.com/"));
- when(request.getQueryString()).thenReturn(null);
-
- filter.doFilter(request, response, chain);
+ HttpResponse res = send("GET", "/", BROWSER_UA, "X-Bufferbot", "true");
- verify(response).setStatus(200);
- verify(chain, never()).doFilter(any(), any());
+ assertEquals(200, res.statusCode());
+ assertEquals(PRERENDERED_HTML, res.body());
}
@Test
void postRequest_passesThrough() throws Exception {
- when(request.getMethod()).thenReturn("POST");
+ HttpResponse res = send("POST", "/", BOT_UA);
- filter.doFilter(request, response, chain);
-
- verify(chain).doFilter(request, response);
- verify(response, never()).setStatus(anyInt());
+ assertEquals(200, res.statusCode());
+ assertEquals(ORIGINAL, res.body());
}
@Test
void networkError_fallsBackToNormalResponse() throws Exception {
wireMock.stubFor(get(anyUrl())
- .willReturn(aResponse().withFault(com.github.tomakehurst.wiremock.http.Fault.CONNECTION_RESET_BY_PEER)));
+ .willReturn(aResponse().withFault(Fault.CONNECTION_RESET_BY_PEER)));
- when(request.getMethod()).thenReturn("GET");
- when(request.getRequestURI()).thenReturn("/");
- when(request.getParameter("_escaped_fragment_")).thenReturn(null);
- when(request.getHeader("X-Bufferbot")).thenReturn(null);
- when(request.getHeader("User-Agent")).thenReturn(BOT_UA);
- when(request.getRequestURL()).thenReturn(new StringBuffer("http://example.com/"));
- when(request.getQueryString()).thenReturn(null);
+ HttpResponse res = send("GET", "/", BOT_UA);
- filter.doFilter(request, response, chain);
+ assertEquals(200, res.statusCode());
+ assertEquals(ORIGINAL, res.body());
+ }
- verify(chain).doFilter(request, response);
+ public static class OriginalServlet extends HttpServlet {
+ @Override
+ protected void service(HttpServletRequest req, HttpServletResponse resp) throws IOException {
+ resp.setStatus(200);
+ resp.setContentType("text/plain");
+ resp.getWriter().write(ORIGINAL);
+ }
}
}