diff --git a/pom.xml b/pom.xml index 1ac460d..25d4605 100644 --- a/pom.xml +++ b/pom.xml @@ -64,6 +64,12 @@ 3.5.4 test + + org.eclipse.jetty + jetty-servlet + 11.0.20 + test + com.fasterxml.jackson.core jackson-databind diff --git a/src/test/java/io/prerender/PrerenderFilterTest.java b/src/test/java/io/prerender/PrerenderFilterTest.java index 3cd41ff..ca010ff 100644 --- a/src/test/java/io/prerender/PrerenderFilterTest.java +++ b/src/test/java/io/prerender/PrerenderFilterTest.java @@ -1,68 +1,104 @@ package io.prerender; -import com.github.tomakehurst.wiremock.client.WireMock; -import com.github.tomakehurst.wiremock.core.WireMockConfiguration; +import com.github.tomakehurst.wiremock.http.Fault; import com.github.tomakehurst.wiremock.junit5.WireMockExtension; -import jakarta.servlet.FilterChain; +import jakarta.servlet.DispatcherType; +import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.servlet.FilterHolder; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.RegisterExtension; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import java.io.PrintWriter; -import java.io.StringWriter; +import java.io.IOException; +import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.EnumSet; -import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.get; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; -@ExtendWith(MockitoExtension.class) +/** + * Drives the filter inside a real servlet container (embedded Jetty), hit by a real + * HTTP client. The upstream Prerender service is faked with WireMock. Catches + * container-level behaviour (filter chain wiring, request URL/query, status and + * header propagation, static-asset pass-through) that Mockito on servlet objects + * would miss. + */ class PrerenderFilterTest { private static final String BOT_UA = "Mozilla/5.0 (compatible; Googlebot/2.1)"; private static final String BROWSER_UA = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"; private static final String PRERENDERED_HTML = "prerendered"; + private static final String ORIGINAL = "original"; @RegisterExtension static WireMockExtension wireMock = WireMockExtension.newInstance() .options(wireMockConfig().dynamicPort()) .build(); - @Mock private HttpServletRequest request; - @Mock private HttpServletResponse response; - @Mock private FilterChain chain; + private static Server jetty; + private static String baseUrl; + private static final HttpClient httpClient = HttpClient.newHttpClient(); + + @BeforeAll + static void startJetty() throws Exception { + jetty = new Server(0); + ServletContextHandler context = new ServletContextHandler(); + context.setContextPath("/"); + + FilterHolder filter = new FilterHolder(new PrerenderFilter()); + filter.setInitParameter("prerenderToken", "test-token"); + filter.setInitParameter("prerenderServiceUrl", wireMock.baseUrl()); + context.addFilter(filter, "/*", EnumSet.of(DispatcherType.REQUEST)); + context.addServlet(new ServletHolder(new OriginalServlet()), "/*"); + + jetty.setHandler(context); + jetty.start(); + int port = ((ServerConnector) jetty.getConnectors()[0]).getLocalPort(); + baseUrl = "http://127.0.0.1:" + port; + } - private StringWriter responseWriter; - private PrerenderFilter filter; + @AfterAll + static void stopJetty() throws Exception { + if (jetty != null) jetty.stop(); + } @BeforeEach - void setUp() throws Exception { + void resetStubs() { wireMock.resetAll(); - responseWriter = new StringWriter(); - lenient().when(response.getWriter()).thenReturn(new PrintWriter(responseWriter)); - PrerenderConfig config = new PrerenderConfig(null, "http://localhost:" + wireMock.getPort()); - filter = new PrerenderFilter(HttpClient.newHttpClient(), config); + } + + private HttpResponse send(String method, String path, String userAgent, String... extraHeaders) throws Exception { + HttpRequest.Builder b = HttpRequest.newBuilder(URI.create(baseUrl + path)) + .header("User-Agent", userAgent); + for (int i = 0; i + 1 < extraHeaders.length; i += 2) { + b.header(extraHeaders[i], extraHeaders[i + 1]); + } + HttpRequest req = "POST".equals(method) + ? b.POST(HttpRequest.BodyPublishers.noBody()).build() + : b.GET().build(); + return httpClient.send(req, HttpResponse.BodyHandlers.ofString()); } @Test void browserRequest_passesThrough() throws Exception { - when(request.getMethod()).thenReturn("GET"); - when(request.getRequestURI()).thenReturn("/"); - when(request.getParameter("_escaped_fragment_")).thenReturn(null); - when(request.getHeader("X-Bufferbot")).thenReturn(null); - when(request.getHeader("User-Agent")).thenReturn(BROWSER_UA); - - filter.doFilter(request, response, chain); + HttpResponse res = send("GET", "/", BROWSER_UA); - verify(chain).doFilter(request, response); - verify(response, never()).setStatus(anyInt()); + assertEquals(200, res.statusCode()); + assertEquals(ORIGINAL, res.body()); } @Test @@ -70,30 +106,18 @@ void botRequest_receivesPrerenderedResponse() throws Exception { wireMock.stubFor(get(anyUrl()) .willReturn(aResponse().withStatus(200).withBody(PRERENDERED_HTML))); - when(request.getMethod()).thenReturn("GET"); - when(request.getRequestURI()).thenReturn("/"); - when(request.getParameter("_escaped_fragment_")).thenReturn(null); - when(request.getHeader("X-Bufferbot")).thenReturn(null); - when(request.getHeader("User-Agent")).thenReturn(BOT_UA); - when(request.getRequestURL()).thenReturn(new StringBuffer("http://example.com/")); - when(request.getQueryString()).thenReturn(null); - - filter.doFilter(request, response, chain); + HttpResponse res = send("GET", "/about", BOT_UA); - verify(response).setStatus(200); - verify(chain, never()).doFilter(any(), any()); - assertEquals(PRERENDERED_HTML, responseWriter.toString()); + assertEquals(200, res.statusCode()); + assertEquals(PRERENDERED_HTML, res.body()); } @Test void botRequest_staticAsset_passesThrough() throws Exception { - when(request.getMethod()).thenReturn("GET"); - when(request.getRequestURI()).thenReturn("/styles.css"); + HttpResponse res = send("GET", "/styles.css", BOT_UA); - filter.doFilter(request, response, chain); - - verify(chain).doFilter(request, response); - verify(response, never()).setStatus(anyInt()); + assertEquals(200, res.statusCode()); + assertEquals(ORIGINAL, res.body()); } @Test @@ -101,17 +125,10 @@ void escapedFragment_triggersPrerender() throws Exception { wireMock.stubFor(get(anyUrl()) .willReturn(aResponse().withStatus(200).withBody(PRERENDERED_HTML))); - when(request.getMethod()).thenReturn("GET"); - when(request.getRequestURI()).thenReturn("/"); - when(request.getParameter("_escaped_fragment_")).thenReturn(""); - when(request.getHeader("User-Agent")).thenReturn(BROWSER_UA); - when(request.getRequestURL()).thenReturn(new StringBuffer("http://example.com/")); - when(request.getQueryString()).thenReturn("_escaped_fragment_="); - - filter.doFilter(request, response, chain); + HttpResponse res = send("GET", "/?_escaped_fragment_=", BROWSER_UA); - verify(response).setStatus(200); - verify(chain, never()).doFilter(any(), any()); + assertEquals(200, res.statusCode()); + assertEquals(PRERENDERED_HTML, res.body()); } @Test @@ -119,45 +136,37 @@ void xBufferbot_triggersPrerender() throws Exception { wireMock.stubFor(get(anyUrl()) .willReturn(aResponse().withStatus(200).withBody(PRERENDERED_HTML))); - when(request.getMethod()).thenReturn("GET"); - when(request.getRequestURI()).thenReturn("/"); - when(request.getParameter("_escaped_fragment_")).thenReturn(null); - when(request.getHeader("X-Bufferbot")).thenReturn("true"); - when(request.getHeader("User-Agent")).thenReturn(BROWSER_UA); - when(request.getRequestURL()).thenReturn(new StringBuffer("http://example.com/")); - when(request.getQueryString()).thenReturn(null); - - filter.doFilter(request, response, chain); + HttpResponse res = send("GET", "/", BROWSER_UA, "X-Bufferbot", "true"); - verify(response).setStatus(200); - verify(chain, never()).doFilter(any(), any()); + assertEquals(200, res.statusCode()); + assertEquals(PRERENDERED_HTML, res.body()); } @Test void postRequest_passesThrough() throws Exception { - when(request.getMethod()).thenReturn("POST"); + HttpResponse res = send("POST", "/", BOT_UA); - filter.doFilter(request, response, chain); - - verify(chain).doFilter(request, response); - verify(response, never()).setStatus(anyInt()); + assertEquals(200, res.statusCode()); + assertEquals(ORIGINAL, res.body()); } @Test void networkError_fallsBackToNormalResponse() throws Exception { wireMock.stubFor(get(anyUrl()) - .willReturn(aResponse().withFault(com.github.tomakehurst.wiremock.http.Fault.CONNECTION_RESET_BY_PEER))); + .willReturn(aResponse().withFault(Fault.CONNECTION_RESET_BY_PEER))); - when(request.getMethod()).thenReturn("GET"); - when(request.getRequestURI()).thenReturn("/"); - when(request.getParameter("_escaped_fragment_")).thenReturn(null); - when(request.getHeader("X-Bufferbot")).thenReturn(null); - when(request.getHeader("User-Agent")).thenReturn(BOT_UA); - when(request.getRequestURL()).thenReturn(new StringBuffer("http://example.com/")); - when(request.getQueryString()).thenReturn(null); + HttpResponse res = send("GET", "/", BOT_UA); - filter.doFilter(request, response, chain); + assertEquals(200, res.statusCode()); + assertEquals(ORIGINAL, res.body()); + } - verify(chain).doFilter(request, response); + public static class OriginalServlet extends HttpServlet { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws IOException { + resp.setStatus(200); + resp.setContentType("text/plain"); + resp.getWriter().write(ORIGINAL); + } } }