Skip to content

Commit 8db1921

Browse files
authored
Refactor invokeSpecialMethod for reflection parity (#568)
We overhaul the `invokeSpecialMethod` backend to make it robust against JNI aborts and closer to standard Java reflection behavior, while maintaining reasonable performance.

Key improvements:
- Fix illegal `jclass{nullptr}` usage during return value boxing by globally caching primitive wrapper classes (`Integer`, `Double`, etc.).
- Eliminate heap allocations for argument mapping by replacing `std::vector` with stack-allocated memory via `alloca`.
- Prevent JNI type confusion crashes by introducing strict type validation (`IsInstanceOf`) before unboxing arguments.
- Safely catch exceptions thrown by the target method and wrap them in `InvocationTargetException`, mirroring `java.lang.reflect.Method.invoke`.
- Gracefully handle `null` receivers and zero-length argument arrays.
1 parent fe8e069 commit 8db1921

1 file changed

Lines changed: 198 additions & 117 deletions

File tree

native/src/jni/hook_bridge.cpp

Lines changed: 198 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <alloca.h>
12
#include <parallel_hashmap/phmap.h>
23

34
#include <lsplant.hpp>
@@ -273,171 +274,251 @@ VECTOR_DEF_NATIVE_METHOD(jobject, HookBridge, allocateObject, jclass cls) {
273274
}
274275

275276
/**
276-
* @brief A high-performance, low-level implementation of Method.invoke for super.method() calls.
277+
* Core JNI backend for non-virtual method invocation and special object initialization.
277278
*
278-
* This function manually unboxes arguments from a jobject array into a jvalue C-style array,
279-
* calls the appropriate JNI `CallNonvirtual...MethodA` function,
280-
* and then boxes the return value back into a jobject.
281-
* This avoids the overhead of Java reflection.
282-
*
283-
* @warning This is a very sensitive function.
284-
* The `shorty` descriptor must perfectly match the method's actual signature.
279+
* Implementation details:
280+
* 1. Dispatches using JNI CallNonvirtual<Type>MethodA.
281+
* 2. Employs stack allocation (alloca) for JNI argument mapping.
282+
* 3. Mirrors standard Java reflection for null inputs (NullPointerException on a null receiver, IllegalArgumentException on a null primitive argument).
283+
* 4. Prevents JNI Type Confusion and memory leaks by caching primitive wrappers globally,
284+
* while leveraging java.lang.Number for fast implicit widening/narrowing.
285+
* 5. Accurately catches and wraps target method exceptions into InvocationTargetException.
285286
*/
286287
VECTOR_DEF_NATIVE_METHOD(jobject, HookBridge, invokeSpecialMethod, jobject method,
287288
jcharArray shorty, jclass cls, jobject thiz, jobjectArray args) {
288-
// --- Cache all necessary MethodIDs for boxing/unboxing primitive wrappers
289-
// --- This is a major performance optimization, done only once.
290-
static auto *const get_int =
291-
env->GetMethodID(env->FindClass("java/lang/Integer"), "intValue", "()I");
292-
static auto *const get_double =
293-
env->GetMethodID(env->FindClass("java/lang/Double"), "doubleValue", "()D");
294-
static auto *const get_long =
295-
env->GetMethodID(env->FindClass("java/lang/Long"), "longValue", "()J");
296-
static auto *const get_float =
297-
env->GetMethodID(env->FindClass("java/lang/Float"), "floatValue", "()F");
298-
static auto *const get_short =
299-
env->GetMethodID(env->FindClass("java/lang/Short"), "shortValue", "()S");
300-
static auto *const get_byte =
301-
env->GetMethodID(env->FindClass("java/lang/Byte"), "byteValue", "()B");
302-
static auto *const get_char =
303-
env->GetMethodID(env->FindClass("java/lang/Character"), "charValue", "()C");
304-
static auto *const get_boolean =
305-
env->GetMethodID(env->FindClass("java/lang/Boolean"), "booleanValue", "()Z");
306-
static auto *const set_int = env->GetStaticMethodID(env->FindClass("java/lang/Integer"),
307-
"valueOf", "(I)Ljava/lang/Integer;");
308-
static auto *const set_double = env->GetStaticMethodID(env->FindClass("java/lang/Double"),
309-
"valueOf", "(D)Ljava/lang/Double;");
289+
// --- JNI Global Reference Caching ---
290+
// Cached once per process lifecycle to maintain extreme performance and prevent JNI aborts.
291+
static jclass cls_Number = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Number"));
292+
static jclass cls_Boolean = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Boolean"));
293+
static jclass cls_Character = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Character"));
294+
295+
// Globally cache primitive wrapper classes for safe return value boxing
296+
static jclass cls_Integer = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Integer"));
297+
static jclass cls_Double = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Double"));
298+
static jclass cls_Long = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Long"));
299+
static jclass cls_Float = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Float"));
300+
static jclass cls_Short = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Short"));
301+
static jclass cls_Byte = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Byte"));
302+
303+
static jclass cls_ITE =
304+
(jclass)env->NewGlobalRef(env->FindClass("java/lang/reflect/InvocationTargetException"));
305+
306+
static auto *const ctor_ite = env->GetMethodID(cls_ITE, "<init>", "(Ljava/lang/Throwable;)V");
307+
308+
static auto *const get_int = env->GetMethodID(cls_Number, "intValue", "()I");
309+
static auto *const get_double = env->GetMethodID(cls_Number, "doubleValue", "()D");
310+
static auto *const get_long = env->GetMethodID(cls_Number, "longValue", "()J");
311+
static auto *const get_float = env->GetMethodID(cls_Number, "floatValue", "()F");
312+
static auto *const get_short = env->GetMethodID(cls_Number, "shortValue", "()S");
313+
static auto *const get_byte = env->GetMethodID(cls_Number, "byteValue", "()B");
314+
315+
static auto *const get_char = env->GetMethodID(cls_Character, "charValue", "()C");
316+
static auto *const get_boolean = env->GetMethodID(cls_Boolean, "booleanValue", "()Z");
317+
318+
static auto *const set_int =
319+
env->GetStaticMethodID(cls_Integer, "valueOf", "(I)Ljava/lang/Integer;");
320+
static auto *const set_double =
321+
env->GetStaticMethodID(cls_Double, "valueOf", "(D)Ljava/lang/Double;");
310322
static auto *const set_long =
311-
env->GetStaticMethodID(env->FindClass("java/lang/Long"), "valueOf", "(J)Ljava/lang/Long;");
312-
static auto *const set_float = env->GetStaticMethodID(env->FindClass("java/lang/Float"),
313-
"valueOf", "(F)Ljava/lang/Float;");
314-
static auto *const set_short = env->GetStaticMethodID(env->FindClass("java/lang/Short"),
315-
"valueOf", "(S)Ljava/lang/Short;");
323+
env->GetStaticMethodID(cls_Long, "valueOf", "(J)Ljava/lang/Long;");
324+
static auto *const set_float =
325+
env->GetStaticMethodID(cls_Float, "valueOf", "(F)Ljava/lang/Float;");
326+
static auto *const set_short =
327+
env->GetStaticMethodID(cls_Short, "valueOf", "(S)Ljava/lang/Short;");
316328
static auto *const set_byte =
317-
env->GetStaticMethodID(env->FindClass("java/lang/Byte"), "valueOf", "(B)Ljava/lang/Byte;");
318-
static auto *const set_char = env->GetStaticMethodID(env->FindClass("java/lang/Character"),
319-
"valueOf", "(C)Ljava/lang/Character;");
320-
static auto *const set_boolean = env->GetStaticMethodID(env->FindClass("java/lang/Boolean"),
321-
"valueOf", "(Z)Ljava/lang/Boolean;");
329+
env->GetStaticMethodID(cls_Byte, "valueOf", "(B)Ljava/lang/Byte;");
330+
static auto *const set_char =
331+
env->GetStaticMethodID(cls_Character, "valueOf", "(C)Ljava/lang/Character;");
332+
static auto *const set_boolean =
333+
env->GetStaticMethodID(cls_Boolean, "valueOf", "(Z)Ljava/lang/Boolean;");
322334

323335
auto target = env->FromReflectedMethod(method);
324-
auto param_len = env->GetArrayLength(shorty) - 1; // First char is return type.
336+
auto param_len = env->GetArrayLength(shorty) - 1;
325337

326-
// --- Argument Validation ---
327-
if (env->GetArrayLength(args) != param_len) {
338+
// --- Argument & Receiver Validation ---
339+
auto args_len = args != nullptr ? env->GetArrayLength(args) : 0;
340+
if (args_len != param_len) {
328341
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
329342
"args.length does not match parameter count");
330343
return nullptr;
331344
}
345+
332346
if (thiz == nullptr) {
333-
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
334-
"`this` cannot be null for a non-virtual call");
347+
env->ThrowNew(env->FindClass("java/lang/NullPointerException"), "null receiver");
335348
return nullptr;
336349
}
337350

338-
// --- Unbox Arguments ---
339-
std::vector<jvalue> a(param_len);
351+
// Allocate jvalue array on the stack
352+
jvalue *a = param_len > 0 ? static_cast<jvalue *>(alloca(param_len * sizeof(jvalue))) : nullptr;
353+
340354
auto *const shorty_char = env->GetCharArrayElements(shorty, nullptr);
355+
if (shorty_char == nullptr) {
356+
return nullptr; // JVM already threw OutOfMemoryError
357+
}
358+
359+
// Early-exit helper: releases the shorty array elements (JNI_ABORT) and returns nullptr
360+
auto abort_and_return = [&]() {
361+
env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT);
362+
return nullptr;
363+
};
364+
365+
// --- Safe Unboxing ---
341366
for (jint i = 0; i != param_len; ++i) {
342367
jobject element = env->GetObjectArrayElement(args, i);
343-
if (env->ExceptionCheck()) {
344-
env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT);
345-
return nullptr;
346-
}
368+
if (env->ExceptionCheck()) return abort_and_return();
347369

348-
// The shorty string at index i+1 describes the type of the i-th parameter.
349-
switch (shorty_char[i + 1]) {
350-
case 'I':
351-
a[i].i = env->CallIntMethod(element, get_int);
352-
break;
353-
case 'D':
354-
a[i].d = env->CallDoubleMethod(element, get_double);
355-
break;
356-
case 'J':
357-
a[i].j = env->CallLongMethod(element, get_long);
358-
break;
359-
case 'F':
360-
a[i].f = env->CallFloatMethod(element, get_float);
361-
break;
362-
case 'S':
363-
a[i].s = env->CallShortMethod(element, get_short);
364-
break;
365-
case 'B':
366-
a[i].b = env->CallByteMethod(element, get_byte);
367-
break;
368-
case 'C':
369-
a[i].c = env->CallCharMethod(element, get_char);
370-
break;
371-
case 'Z':
372-
a[i].z = env->CallBooleanMethod(element, get_boolean);
373-
break;
374-
default: // Assumes 'L' or '[' for object types
375-
a[i].l = element;
376-
// Set element to null so we don't delete the local ref twice.
377-
// The reference is stored in the jvalue array and is still valid.
378-
element = nullptr;
379-
break;
370+
char type = shorty_char[i + 1];
371+
372+
if (element == nullptr) {
373+
if (type != 'L' && type != '[') {
374+
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
375+
"null primitive argument");
376+
return abort_and_return();
377+
}
378+
a[i].l = nullptr;
379+
} else {
380+
if (type == 'Z') {
381+
if (!env->IsInstanceOf(element, cls_Boolean)) {
382+
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
383+
"Expected Boolean");
384+
return abort_and_return();
385+
}
386+
a[i].z = env->CallBooleanMethod(element, get_boolean);
387+
} else if (type == 'C') {
388+
if (!env->IsInstanceOf(element, cls_Character)) {
389+
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
390+
"Expected Character");
391+
return abort_and_return();
392+
}
393+
a[i].c = env->CallCharMethod(element, get_char);
394+
} else if (type != 'L' && type != '[') {
395+
bool is_number = env->IsInstanceOf(element, cls_Number) == JNI_TRUE;
396+
bool is_character =
397+
!is_number && (env->IsInstanceOf(element, cls_Character) == JNI_TRUE);
398+
399+
if (!is_number && !is_character) {
400+
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
401+
"Expected Number or Character");
402+
return abort_and_return();
403+
}
404+
405+
// If a Character is passed to a numeric parameter, extract its value for widening
406+
jchar c_val = 0;
407+
if (is_character) {
408+
c_val = env->CallCharMethod(element, get_char);
409+
if (env->ExceptionCheck()) return abort_and_return();
410+
}
411+
412+
switch (type) {
413+
case 'I':
414+
a[i].i = env->CallIntMethod(element, get_int);
415+
break;
416+
case 'D':
417+
a[i].d = env->CallDoubleMethod(element, get_double);
418+
break;
419+
case 'J':
420+
a[i].j = env->CallLongMethod(element, get_long);
421+
break;
422+
case 'F':
423+
a[i].f = env->CallFloatMethod(element, get_float);
424+
break;
425+
case 'S':
426+
a[i].s = env->CallShortMethod(element, get_short);
427+
break;
428+
case 'B':
429+
a[i].b = env->CallByteMethod(element, get_byte);
430+
break;
431+
}
432+
} else {
433+
a[i].l = element;
434+
element =
435+
nullptr; // Transferred ownership to jvalue array; will be freed on return
436+
}
380437
}
381438

382-
// Clean up the local reference for the wrapper object if it was created.
383439
if (element) env->DeleteLocalRef(element);
440+
if (env->ExceptionCheck()) return abort_and_return();
441+
}
384442

385-
// Check for exceptions during the unboxing call (e.g.,
386-
// NullPointerException).
387-
if (env->ExceptionCheck()) {
388-
env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT);
389-
return nullptr;
443+
// --- Non-virtual Invocation ---
444+
jvalue ret_val;
445+
switch (shorty_char[0]) {
446+
case 'I':
447+
ret_val.i = env->CallNonvirtualIntMethodA(thiz, cls, target, a);
448+
break;
449+
case 'D':
450+
ret_val.d = env->CallNonvirtualDoubleMethodA(thiz, cls, target, a);
451+
break;
452+
case 'J':
453+
ret_val.j = env->CallNonvirtualLongMethodA(thiz, cls, target, a);
454+
break;
455+
case 'F':
456+
ret_val.f = env->CallNonvirtualFloatMethodA(thiz, cls, target, a);
457+
break;
458+
case 'S':
459+
ret_val.s = env->CallNonvirtualShortMethodA(thiz, cls, target, a);
460+
break;
461+
case 'B':
462+
ret_val.b = env->CallNonvirtualByteMethodA(thiz, cls, target, a);
463+
break;
464+
case 'C':
465+
ret_val.c = env->CallNonvirtualCharMethodA(thiz, cls, target, a);
466+
break;
467+
case 'Z':
468+
ret_val.z = env->CallNonvirtualBooleanMethodA(thiz, cls, target, a);
469+
break;
470+
case 'L':
471+
ret_val.l = env->CallNonvirtualObjectMethodA(thiz, cls, target, a);
472+
break;
473+
default:
474+
env->CallNonvirtualVoidMethodA(thiz, cls, target, a);
475+
break;
476+
}
477+
478+
// --- Exception Wrapping ---
479+
jthrowable target_exception = env->ExceptionOccurred();
480+
if (target_exception) {
481+
env->ExceptionClear();
482+
jobject ite = env->NewObject(cls_ITE, ctor_ite, target_exception);
483+
// Ensure NewObject didn't fail due to OOM before throwing
484+
if (ite) {
485+
env->Throw(static_cast<jthrowable>(ite));
390486
}
487+
return abort_and_return();
391488
}
392489

393-
// --- Call Non-virtual Method and Box Return Value ---
490+
// --- Box Return Value ---
394491
jobject value = nullptr;
395-
// The shorty string at index 0 describes the return type.
396492
switch (shorty_char[0]) {
397493
case 'I':
398-
value =
399-
env->CallStaticObjectMethod(jclass{nullptr},
400-
set_int, // Use Integer.valueOf() to box
401-
env->CallNonvirtualIntMethodA(thiz, cls, target, a.data()));
494+
value = env->CallStaticObjectMethod(cls_Integer, set_int, ret_val.i);
402495
break;
403496
case 'D':
404-
value = env->CallStaticObjectMethod(
405-
jclass{nullptr}, set_double,
406-
env->CallNonvirtualDoubleMethodA(thiz, cls, target, a.data()));
497+
value = env->CallStaticObjectMethod(cls_Double, set_double, ret_val.d);
407498
break;
408499
case 'J':
409-
value = env->CallStaticObjectMethod(
410-
jclass{nullptr}, set_long, env->CallNonvirtualLongMethodA(thiz, cls, target, a.data()));
500+
value = env->CallStaticObjectMethod(cls_Long, set_long, ret_val.j);
411501
break;
412502
case 'F':
413-
value = env->CallStaticObjectMethod(
414-
jclass{nullptr}, set_float,
415-
env->CallNonvirtualFloatMethodA(thiz, cls, target, a.data()));
503+
value = env->CallStaticObjectMethod(cls_Float, set_float, ret_val.f);
416504
break;
417505
case 'S':
418-
value = env->CallStaticObjectMethod(
419-
jclass{nullptr}, set_short,
420-
env->CallNonvirtualShortMethodA(thiz, cls, target, a.data()));
506+
value = env->CallStaticObjectMethod(cls_Short, set_short, ret_val.s);
421507
break;
422508
case 'B':
423-
value = env->CallStaticObjectMethod(
424-
jclass{nullptr}, set_byte, env->CallNonvirtualByteMethodA(thiz, cls, target, a.data()));
509+
value = env->CallStaticObjectMethod(cls_Byte, set_byte, ret_val.b);
425510
break;
426511
case 'C':
427-
value = env->CallStaticObjectMethod(
428-
jclass{nullptr}, set_char, env->CallNonvirtualCharMethodA(thiz, cls, target, a.data()));
512+
value = env->CallStaticObjectMethod(cls_Character, set_char, ret_val.c);
429513
break;
430514
case 'Z':
431-
value = env->CallStaticObjectMethod(
432-
jclass{nullptr}, set_boolean,
433-
env->CallNonvirtualBooleanMethodA(thiz, cls, target, a.data()));
515+
value = env->CallStaticObjectMethod(cls_Boolean, set_boolean, ret_val.z);
434516
break;
435-
case 'L': // Return type is an object, no boxing needed.
436-
value = env->CallNonvirtualObjectMethodA(thiz, cls, target, a.data());
517+
case 'L':
518+
value = ret_val.l;
437519
break;
438-
default: // Assumes 'V' for void return type.
439520
case 'V':
440-
env->CallNonvirtualVoidMethodA(thiz, cls, target, a.data());
521+
value = nullptr;
441522
break;
442523
}
443524

0 commit comments

Comments
 (0)