Skip to content

Commit 3be951e

Browse files
committed
Wrap classes before validating result
1 parent 8f4510a commit 3be951e

8 files changed

Lines changed: 439 additions & 203 deletions

File tree

src/main/java/com/hubspot/jinjava/el/ext/MethodValidator.java

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.hubspot.jinjava.el.ext;
22

33
import com.google.common.collect.ImmutableSet;
4+
import com.hubspot.jinjava.interpret.JinjavaInterpreter;
45
import java.lang.reflect.Method;
56

67
public final class MethodValidator {
@@ -39,32 +40,39 @@ public Method validateMethod(Method m) {
3940
if (m == null) {
4041
return null;
4142
}
42-
String canonicalDeclaringClassName = m.getDeclaringClass().getCanonicalName();
43-
return (
44-
allowedMethods.contains(m) ||
45-
allowedDeclaredMethodsFromCanonicalClassNames.contains(
46-
canonicalDeclaringClassName
47-
) ||
48-
allowedDeclaredMethodsFromCanonicalClassPrefixes
49-
.stream()
50-
.anyMatch(canonicalDeclaringClassName::startsWith)
51-
)
52-
? m
53-
: null;
43+
Class<?> clazz = m.getDeclaringClass();
44+
String canonicalClassName = clazz.getCanonicalName();
45+
if (
46+
allowedMethods.contains(m) ||
47+
allowedDeclaredMethodsFromCanonicalClassNames.contains(canonicalClassName) ||
48+
allowedDeclaredMethodsFromCanonicalClassPrefixes
49+
.stream()
50+
.anyMatch(canonicalClassName::startsWith)
51+
) {
52+
return m;
53+
}
54+
return null;
5455
}
5556

5657
public Object validateResult(Object o) {
57-
if (o == null) {
58+
Object wrapped = JinjavaInterpreter
59+
.getCurrentMaybe()
60+
.map(jinjavaInterpreter -> jinjavaInterpreter.wrap(o))
61+
.orElse(o);
62+
63+
if (wrapped == null) {
5864
return null;
5965
}
60-
String canonicalClassName = o.getClass().getCanonicalName();
61-
return (
62-
allowedResultCanonicalClassNames.contains(canonicalClassName) ||
63-
allowedResultCanonicalClassPrefixes
64-
.stream()
65-
.anyMatch(canonicalClassName::startsWith)
66-
)
67-
? o
68-
: null;
66+
Class<?> clazz = wrapped.getClass();
67+
String canonicalClassName = clazz.getCanonicalName();
68+
if (
69+
allowedResultCanonicalClassNames.contains(canonicalClassName) ||
70+
allowedResultCanonicalClassPrefixes
71+
.stream()
72+
.anyMatch(canonicalClassName::startsWith)
73+
) {
74+
return wrapped;
75+
}
76+
return null;
6977
}
7078
}

src/main/java/com/hubspot/jinjava/el/ext/MethodValidatorConfig.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
package com.hubspot.jinjava.el.ext;
22

33
import com.fasterxml.jackson.databind.ObjectMapper;
4+
import com.google.common.collect.ForwardingCollection;
5+
import com.google.common.collect.ForwardingList;
6+
import com.google.common.collect.ForwardingMap;
47
import com.google.common.collect.ImmutableSet;
58
import com.hubspot.jinjava.JinjavaImmutableStyle;
69
import com.hubspot.jinjava.lib.exptest.ExpTest;
710
import com.hubspot.jinjava.lib.filter.Filter;
11+
import com.hubspot.jinjava.objects.DummyObject;
12+
import com.hubspot.jinjava.objects.Namespace;
13+
import com.hubspot.jinjava.objects.SafeString;
814
import com.hubspot.jinjava.objects.collections.PyList;
915
import com.hubspot.jinjava.objects.collections.PyMap;
1016
import com.hubspot.jinjava.objects.collections.SizeLimitingPyList;
1117
import com.hubspot.jinjava.objects.collections.SizeLimitingPyMap;
18+
import com.hubspot.jinjava.objects.collections.SnakeCaseAccessibleMap;
19+
import com.hubspot.jinjava.objects.date.FormattedDate;
20+
import com.hubspot.jinjava.objects.date.PyishDate;
1221
import com.hubspot.jinjava.util.ForLoop;
1322
import java.lang.reflect.Method;
23+
import java.math.BigDecimal;
24+
import java.util.AbstractCollection;
25+
import java.util.ArrayList;
1426
import java.util.Arrays;
1527
import java.util.function.Function;
1628
import java.util.stream.Stream;
@@ -145,6 +157,7 @@ public enum AllowlistGroup {
145157
float.class,
146158
boolean.class,
147159
short.class,
160+
BigDecimal.class,
148161
};
149162

150163
@Override
@@ -160,12 +173,36 @@ String[] allowedDeclaredMethodsFromCanonicalClassPrefixes() {
160173
return ARRAY;
161174
}
162175
},
176+
JinjavaObjects {
177+
private static final Class<?>[] ARRAY = {
178+
PyList.class,
179+
PyMap.class,
180+
SizeLimitingPyMap.class,
181+
SizeLimitingPyList.class,
182+
SnakeCaseAccessibleMap.class,
183+
FormattedDate.class,
184+
PyishDate.class,
185+
DummyObject.class,
186+
Namespace.class,
187+
SafeString.class,
188+
};
189+
190+
@Override
191+
Class<?>[] allowedDeclaredMethodsFromClasses() {
192+
return ARRAY;
193+
}
194+
},
163195
Collections {
164196
private static final Class<?>[] ARRAY = {
165197
PyList.class,
166198
PyMap.class,
167199
SizeLimitingPyMap.class,
168200
SizeLimitingPyList.class,
201+
ArrayList.class,
202+
ForwardingList.class,
203+
ForwardingMap.class,
204+
ForwardingCollection.class,
205+
AbstractCollection.class,
169206
};
170207

171208
@Override

src/test/java/com/hubspot/jinjava/el/ExpressionResolverTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.Optional;
3030
import java.util.Set;
3131
import java.util.function.Supplier;
32+
import org.junit.After;
3233
import org.junit.Before;
3334
import org.junit.Test;
3435

@@ -44,6 +45,12 @@ public void setup() {
4445
jinjava = new Jinjava();
4546
interpreter = jinjava.newInterpreter();
4647
context = interpreter.getContext();
48+
JinjavaInterpreter.pushCurrent(interpreter);
49+
}
50+
51+
@After
52+
public void teardown() {
53+
JinjavaInterpreter.popCurrent();
4754
}
4855

4956
@Test
@@ -467,6 +474,7 @@ public void itBlocksDisabledFunctions() {
467474
);
468475

469476
String template = "hi {% for i in range(1, 3) %}{{i}} {% endfor %}";
477+
JinjavaInterpreter.popCurrent();
470478

471479
String rendered = jinjava.render(template, context);
472480
assertEquals("hi 1 2 ", rendered);

src/test/java/com/hubspot/jinjava/el/ext/AstDictTest.java

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,63 +8,105 @@
88
import com.hubspot.jinjava.interpret.TemplateError.ErrorType;
99
import java.util.Map;
1010
import java.util.Set;
11-
import org.junit.Before;
1211
import org.junit.Test;
1312

1413
public class AstDictTest {
1514

16-
private JinjavaInterpreter interpreter;
17-
18-
@Before
19-
public void setup() {
20-
interpreter = new Jinjava().newInterpreter();
21-
}
22-
2315
@Test
2416
public void itGetsDictValues() {
25-
interpreter.getContext().put("foo", ImmutableMap.of("bar", "test"));
26-
assertThat(interpreter.resolveELExpression("foo.bar", -1)).isEqualTo("test");
17+
try (
18+
var a = JinjavaInterpreter
19+
.closeablePushCurrent(new Jinjava().newInterpreter())
20+
.get()
21+
) {
22+
JinjavaInterpreter interpreter = a.value();
23+
interpreter.getContext().put("foo", ImmutableMap.of("bar", "test"));
24+
assertThat(interpreter.resolveELExpression("foo.bar", -1)).isEqualTo("test");
25+
}
2726
}
2827

2928
@Test
3029
public void itGetsDictValuesWithEnumKeys() {
31-
interpreter.getContext().put("foo", ImmutableMap.of(ErrorType.FATAL, "test"));
32-
assertThat(interpreter.resolveELExpression("foo.FATAL", -1)).isEqualTo("test");
30+
try (
31+
var a = JinjavaInterpreter
32+
.closeablePushCurrent(new Jinjava().newInterpreter())
33+
.get()
34+
) {
35+
JinjavaInterpreter interpreter = a.value();
36+
interpreter.getContext().put("foo", ImmutableMap.of(ErrorType.FATAL, "test"));
37+
assertThat(interpreter.resolveELExpression("foo.FATAL", -1)).isEqualTo("test");
38+
}
3339
}
3440

3541
@Test
3642
public void itGetsDictValuesWithEnumKeysUsingToString() {
37-
interpreter.getContext().put("foo", ImmutableMap.of(TestEnum.BAR, "test"));
38-
assertThat(interpreter.resolveELExpression("foo.barName", -1)).isEqualTo("test");
43+
try (
44+
var a = JinjavaInterpreter
45+
.closeablePushCurrent(new Jinjava().newInterpreter())
46+
.get()
47+
) {
48+
JinjavaInterpreter interpreter = a.value();
49+
interpreter.getContext().put("foo", ImmutableMap.of(TestEnum.BAR, "test"));
50+
assertThat(interpreter.resolveELExpression("foo.barName", -1)).isEqualTo("test");
51+
}
3952
}
4053

4154
@Test
4255
public void itDoesItemsMethodCall() {
43-
interpreter.getContext().put("foo", ImmutableMap.of(TestEnum.BAR, "test"));
44-
assertThat(interpreter.resolveELExpression("foo.items()", -1))
45-
.isInstanceOf(Set.class);
56+
try (
57+
var a = JinjavaInterpreter
58+
.closeablePushCurrent(new Jinjava().newInterpreter())
59+
.get()
60+
) {
61+
JinjavaInterpreter interpreter = a.value();
62+
interpreter.getContext().put("foo", ImmutableMap.of(TestEnum.BAR, "test"));
63+
assertThat(interpreter.resolveELExpression("foo.items()", -1))
64+
.isInstanceOf(Set.class);
65+
}
4666
}
4767

4868
@Test
4969
public void itDoesKeysMethodCall() {
50-
interpreter.getContext().put("foo", ImmutableMap.of(TestEnum.BAR, "test"));
51-
assertThat(interpreter.resolveELExpression("foo.keys()", -1)).isInstanceOf(Set.class);
70+
try (
71+
var a = JinjavaInterpreter
72+
.closeablePushCurrent(new Jinjava().newInterpreter())
73+
.get()
74+
) {
75+
JinjavaInterpreter interpreter = a.value();
76+
interpreter.getContext().put("foo", ImmutableMap.of(TestEnum.BAR, "test"));
77+
assertThat(interpreter.resolveELExpression("foo.keys()", -1))
78+
.isInstanceOf(Set.class);
79+
}
5280
}
5381

5482
@Test
5583
public void itHandlesEmptyMaps() {
56-
interpreter.getContext().put("foo", ImmutableMap.of());
57-
assertThat(interpreter.resolveELExpression("foo.FATAL", -1)).isNull();
58-
assertThat(interpreter.getErrors()).isEmpty();
84+
try (
85+
var a = JinjavaInterpreter
86+
.closeablePushCurrent(new Jinjava().newInterpreter())
87+
.get()
88+
) {
89+
JinjavaInterpreter interpreter = a.value();
90+
interpreter.getContext().put("foo", ImmutableMap.of());
91+
assertThat(interpreter.resolveELExpression("foo.FATAL", -1)).isNull();
92+
assertThat(interpreter.getErrors()).isEmpty();
93+
}
5994
}
6095

6196
@Test
6297
public void itGetsDictValuesWithEnumKeysInObjects() {
63-
interpreter
64-
.getContext()
65-
.put("test", new TestClass(ImmutableMap.of(ErrorType.FATAL, "test")));
66-
assertThat(interpreter.resolveELExpression("test.my_map.FATAL", -1))
67-
.isEqualTo("test");
98+
try (
99+
var a = JinjavaInterpreter
100+
.closeablePushCurrent(new Jinjava().newInterpreter())
101+
.get()
102+
) {
103+
JinjavaInterpreter interpreter = a.value();
104+
interpreter
105+
.getContext()
106+
.put("test", new TestClass(ImmutableMap.of(ErrorType.FATAL, "test")));
107+
assertThat(interpreter.resolveELExpression("test.my_map.FATAL", -1))
108+
.isEqualTo("test");
109+
}
68110
}
69111

70112
public class TestClass {

src/test/java/com/hubspot/jinjava/el/ext/eager/EagerAstDotTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import com.hubspot.jinjava.BaseInterpretingTest;
77
import com.hubspot.jinjava.JinjavaConfig;
88
import com.hubspot.jinjava.LegacyOverrides;
9+
import com.hubspot.jinjava.el.ext.MethodValidator;
10+
import com.hubspot.jinjava.el.ext.MethodValidatorConfig;
911
import com.hubspot.jinjava.interpret.Context;
1012
import com.hubspot.jinjava.interpret.DeferredValue;
1113
import com.hubspot.jinjava.interpret.DeferredValueException;
@@ -28,6 +30,17 @@ public void setup() {
2830
.withLegacyOverrides(
2931
LegacyOverrides.newBuilder().withUsePyishObjectMapper(true).build()
3032
)
33+
.withMethodValidator(
34+
MethodValidator.create(
35+
MethodValidatorConfig
36+
.builder()
37+
.addDefaultAllowlistGroups()
38+
.addAllowedDeclaredMethodsFromCanonicalClassPrefixes(
39+
EagerAstDotTest.class.getCanonicalName()
40+
)
41+
.build()
42+
)
43+
)
3144
.withMaxMacroRecursionDepth(5)
3245
.withEnableRecursiveMacroCalls(true)
3346
.build();

src/test/java/com/hubspot/jinjava/el/ext/eager/EagerAstMethodTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import com.hubspot.jinjava.JinjavaConfig;
88
import com.hubspot.jinjava.LegacyOverrides;
99
import com.hubspot.jinjava.el.ext.DeferredParsingException;
10+
import com.hubspot.jinjava.el.ext.MethodValidator;
11+
import com.hubspot.jinjava.el.ext.MethodValidatorConfig;
1012
import com.hubspot.jinjava.el.ext.eager.EagerAstDotTest.Foo;
1113
import com.hubspot.jinjava.interpret.Context;
1214
import com.hubspot.jinjava.interpret.DeferredValue;
@@ -35,6 +37,17 @@ public void setup() {
3537
.withLegacyOverrides(
3638
LegacyOverrides.newBuilder().withUsePyishObjectMapper(true).build()
3739
)
40+
.withMethodValidator(
41+
MethodValidator.create(
42+
MethodValidatorConfig
43+
.builder()
44+
.addDefaultAllowlistGroups()
45+
.addAllowedDeclaredMethodsFromCanonicalClassPrefixes(
46+
EagerAstDotTest.class.getCanonicalName()
47+
)
48+
.build()
49+
)
50+
)
3851
.withMaxMacroRecursionDepth(5)
3952
.withEnableRecursiveMacroCalls(true)
4053
.build();

0 commit comments

Comments
 (0)