|
| 1 | +#include <alloca.h> |
1 | 2 | #include <parallel_hashmap/phmap.h> |
2 | 3 |
|
3 | 4 | #include <lsplant.hpp> |
@@ -273,171 +274,251 @@ VECTOR_DEF_NATIVE_METHOD(jobject, HookBridge, allocateObject, jclass cls) { |
273 | 274 | } |
274 | 275 |
|
275 | 276 | /** |
276 | | - * @brief A high-performance, low-level implementation of Method.invoke for super.method() calls. |
| 277 | + * Core JNI backend for non-virtual method invocation and special object initialization. |
277 | 278 | * |
278 | | - * This function manually unboxes arguments from a jobject array into a jvalue C-style array, |
279 | | - * calls the appropriate JNI `CallNonvirtual...MethodA` function, |
280 | | - * and then boxes the return value back into a jobject. |
281 | | - * This avoids the overhead of Java reflection. |
282 | | - * |
283 | | - * @warning This is a very sensitive function. |
284 | | - * The `shorty` descriptor must perfectly match the method's actual signature. |
| 279 | + * Implementation details: |
| 280 | + * 1. Dispatches using JNI CallNonvirtual<Type>MethodA. |
| 281 | + * 2. Employs stack allocation (alloca) for JNI argument mapping. |
| 282 | + * 3. Safely mirrors standard Java reflection (NPEs on null primitives/receivers). |
| 283 | + * 4. Prevents JNI Type Confusion and memory leaks by caching primitive wrappers globally, |
| 284 | + * while leveraging java.lang.Number for fast implicit widening/narrowing. |
| 285 | + * 5. Accurately catches and wraps target method exceptions into InvocationTargetException. |
285 | 286 | */ |
286 | 287 | VECTOR_DEF_NATIVE_METHOD(jobject, HookBridge, invokeSpecialMethod, jobject method, |
287 | 288 | jcharArray shorty, jclass cls, jobject thiz, jobjectArray args) { |
288 | | - // --- Cache all necessary MethodIDs for boxing/unboxing primitive wrappers |
289 | | - // --- This is a major performance optimization, done only once. |
290 | | - static auto *const get_int = |
291 | | - env->GetMethodID(env->FindClass("java/lang/Integer"), "intValue", "()I"); |
292 | | - static auto *const get_double = |
293 | | - env->GetMethodID(env->FindClass("java/lang/Double"), "doubleValue", "()D"); |
294 | | - static auto *const get_long = |
295 | | - env->GetMethodID(env->FindClass("java/lang/Long"), "longValue", "()J"); |
296 | | - static auto *const get_float = |
297 | | - env->GetMethodID(env->FindClass("java/lang/Float"), "floatValue", "()F"); |
298 | | - static auto *const get_short = |
299 | | - env->GetMethodID(env->FindClass("java/lang/Short"), "shortValue", "()S"); |
300 | | - static auto *const get_byte = |
301 | | - env->GetMethodID(env->FindClass("java/lang/Byte"), "byteValue", "()B"); |
302 | | - static auto *const get_char = |
303 | | - env->GetMethodID(env->FindClass("java/lang/Character"), "charValue", "()C"); |
304 | | - static auto *const get_boolean = |
305 | | - env->GetMethodID(env->FindClass("java/lang/Boolean"), "booleanValue", "()Z"); |
306 | | - static auto *const set_int = env->GetStaticMethodID(env->FindClass("java/lang/Integer"), |
307 | | - "valueOf", "(I)Ljava/lang/Integer;"); |
308 | | - static auto *const set_double = env->GetStaticMethodID(env->FindClass("java/lang/Double"), |
309 | | - "valueOf", "(D)Ljava/lang/Double;"); |
| 289 | + // --- JNI Global Reference Caching --- |
| 290 | + // Cached once per process lifecycle to maintain extreme performance and prevent JNI aborts. |
| 291 | + static jclass cls_Number = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Number")); |
| 292 | + static jclass cls_Boolean = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Boolean")); |
| 293 | + static jclass cls_Character = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Character")); |
| 294 | + |
| 295 | + // Globally cache primitive wrapper classes for safe return value boxing |
| 296 | + static jclass cls_Integer = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Integer")); |
| 297 | + static jclass cls_Double = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Double")); |
| 298 | + static jclass cls_Long = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Long")); |
| 299 | + static jclass cls_Float = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Float")); |
| 300 | + static jclass cls_Short = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Short")); |
| 301 | + static jclass cls_Byte = (jclass)env->NewGlobalRef(env->FindClass("java/lang/Byte")); |
| 302 | + |
| 303 | + static jclass cls_ITE = |
| 304 | + (jclass)env->NewGlobalRef(env->FindClass("java/lang/reflect/InvocationTargetException")); |
| 305 | + |
| 306 | + static auto *const ctor_ite = env->GetMethodID(cls_ITE, "<init>", "(Ljava/lang/Throwable;)V"); |
| 307 | + |
| 308 | + static auto *const get_int = env->GetMethodID(cls_Number, "intValue", "()I"); |
| 309 | + static auto *const get_double = env->GetMethodID(cls_Number, "doubleValue", "()D"); |
| 310 | + static auto *const get_long = env->GetMethodID(cls_Number, "longValue", "()J"); |
| 311 | + static auto *const get_float = env->GetMethodID(cls_Number, "floatValue", "()F"); |
| 312 | + static auto *const get_short = env->GetMethodID(cls_Number, "shortValue", "()S"); |
| 313 | + static auto *const get_byte = env->GetMethodID(cls_Number, "byteValue", "()B"); |
| 314 | + |
| 315 | + static auto *const get_char = env->GetMethodID(cls_Character, "charValue", "()C"); |
| 316 | + static auto *const get_boolean = env->GetMethodID(cls_Boolean, "booleanValue", "()Z"); |
| 317 | + |
| 318 | + static auto *const set_int = |
| 319 | + env->GetStaticMethodID(cls_Integer, "valueOf", "(I)Ljava/lang/Integer;"); |
| 320 | + static auto *const set_double = |
| 321 | + env->GetStaticMethodID(cls_Double, "valueOf", "(D)Ljava/lang/Double;"); |
310 | 322 | static auto *const set_long = |
311 | | - env->GetStaticMethodID(env->FindClass("java/lang/Long"), "valueOf", "(J)Ljava/lang/Long;"); |
312 | | - static auto *const set_float = env->GetStaticMethodID(env->FindClass("java/lang/Float"), |
313 | | - "valueOf", "(F)Ljava/lang/Float;"); |
314 | | - static auto *const set_short = env->GetStaticMethodID(env->FindClass("java/lang/Short"), |
315 | | - "valueOf", "(S)Ljava/lang/Short;"); |
| 323 | + env->GetStaticMethodID(cls_Long, "valueOf", "(J)Ljava/lang/Long;"); |
| 324 | + static auto *const set_float = |
| 325 | + env->GetStaticMethodID(cls_Float, "valueOf", "(F)Ljava/lang/Float;"); |
| 326 | + static auto *const set_short = |
| 327 | + env->GetStaticMethodID(cls_Short, "valueOf", "(S)Ljava/lang/Short;"); |
316 | 328 | static auto *const set_byte = |
317 | | - env->GetStaticMethodID(env->FindClass("java/lang/Byte"), "valueOf", "(B)Ljava/lang/Byte;"); |
318 | | - static auto *const set_char = env->GetStaticMethodID(env->FindClass("java/lang/Character"), |
319 | | - "valueOf", "(C)Ljava/lang/Character;"); |
320 | | - static auto *const set_boolean = env->GetStaticMethodID(env->FindClass("java/lang/Boolean"), |
321 | | - "valueOf", "(Z)Ljava/lang/Boolean;"); |
| 329 | + env->GetStaticMethodID(cls_Byte, "valueOf", "(B)Ljava/lang/Byte;"); |
| 330 | + static auto *const set_char = |
| 331 | + env->GetStaticMethodID(cls_Character, "valueOf", "(C)Ljava/lang/Character;"); |
| 332 | + static auto *const set_boolean = |
| 333 | + env->GetStaticMethodID(cls_Boolean, "valueOf", "(Z)Ljava/lang/Boolean;"); |
322 | 334 |
|
323 | 335 | auto target = env->FromReflectedMethod(method); |
324 | | - auto param_len = env->GetArrayLength(shorty) - 1; // First char is return type. |
| 336 | + auto param_len = env->GetArrayLength(shorty) - 1; |
325 | 337 |
|
326 | | - // --- Argument Validation --- |
327 | | - if (env->GetArrayLength(args) != param_len) { |
| 338 | + // --- Argument & Receiver Validation --- |
| 339 | + auto args_len = args != nullptr ? env->GetArrayLength(args) : 0; |
| 340 | + if (args_len != param_len) { |
328 | 341 | env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), |
329 | 342 | "args.length does not match parameter count"); |
330 | 343 | return nullptr; |
331 | 344 | } |
| 345 | + |
332 | 346 | if (thiz == nullptr) { |
333 | | - env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), |
334 | | - "`this` cannot be null for a non-virtual call"); |
| 347 | + env->ThrowNew(env->FindClass("java/lang/NullPointerException"), "null receiver"); |
335 | 348 | return nullptr; |
336 | 349 | } |
337 | 350 |
|
338 | | - // --- Unbox Arguments --- |
339 | | - std::vector<jvalue> a(param_len); |
| 351 | + // Allocate jvalue array on the stack |
| 352 | + jvalue *a = param_len > 0 ? static_cast<jvalue *>(alloca(param_len * sizeof(jvalue))) : nullptr; |
| 353 | + |
340 | 354 | auto *const shorty_char = env->GetCharArrayElements(shorty, nullptr); |
| 355 | + if (shorty_char == nullptr) { |
| 356 | + return nullptr; // JVM already threw OutOfMemoryError |
| 357 | + } |
| 358 | + |
| 359 | + // RAII/Helper for clean JNI array exits |
| 360 | + auto abort_and_return = [&]() { |
| 361 | + env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT); |
| 362 | + return nullptr; |
| 363 | + }; |
| 364 | + |
| 365 | + // --- Safe Unboxing --- |
341 | 366 | for (jint i = 0; i != param_len; ++i) { |
342 | 367 | jobject element = env->GetObjectArrayElement(args, i); |
343 | | - if (env->ExceptionCheck()) { |
344 | | - env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT); |
345 | | - return nullptr; |
346 | | - } |
| 368 | + if (env->ExceptionCheck()) return abort_and_return(); |
347 | 369 |
|
348 | | - // The shorty string at index i+1 describes the type of the i-th parameter. |
349 | | - switch (shorty_char[i + 1]) { |
350 | | - case 'I': |
351 | | - a[i].i = env->CallIntMethod(element, get_int); |
352 | | - break; |
353 | | - case 'D': |
354 | | - a[i].d = env->CallDoubleMethod(element, get_double); |
355 | | - break; |
356 | | - case 'J': |
357 | | - a[i].j = env->CallLongMethod(element, get_long); |
358 | | - break; |
359 | | - case 'F': |
360 | | - a[i].f = env->CallFloatMethod(element, get_float); |
361 | | - break; |
362 | | - case 'S': |
363 | | - a[i].s = env->CallShortMethod(element, get_short); |
364 | | - break; |
365 | | - case 'B': |
366 | | - a[i].b = env->CallByteMethod(element, get_byte); |
367 | | - break; |
368 | | - case 'C': |
369 | | - a[i].c = env->CallCharMethod(element, get_char); |
370 | | - break; |
371 | | - case 'Z': |
372 | | - a[i].z = env->CallBooleanMethod(element, get_boolean); |
373 | | - break; |
374 | | - default: // Assumes 'L' or '[' for object types |
375 | | - a[i].l = element; |
376 | | - // Set element to null so we don't delete the local ref twice. |
377 | | - // The reference is stored in the jvalue array and is still valid. |
378 | | - element = nullptr; |
379 | | - break; |
| 370 | + char type = shorty_char[i + 1]; |
| 371 | + |
| 372 | + if (element == nullptr) { |
| 373 | + if (type != 'L' && type != '[') { |
| 374 | + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), |
| 375 | + "null primitive argument"); |
| 376 | + return abort_and_return(); |
| 377 | + } |
| 378 | + a[i].l = nullptr; |
| 379 | + } else { |
| 380 | + if (type == 'Z') { |
| 381 | + if (!env->IsInstanceOf(element, cls_Boolean)) { |
| 382 | + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), |
| 383 | + "Expected Boolean"); |
| 384 | + return abort_and_return(); |
| 385 | + } |
| 386 | + a[i].z = env->CallBooleanMethod(element, get_boolean); |
| 387 | + } else if (type == 'C') { |
| 388 | + if (!env->IsInstanceOf(element, cls_Character)) { |
| 389 | + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), |
| 390 | + "Expected Character"); |
| 391 | + return abort_and_return(); |
| 392 | + } |
| 393 | + a[i].c = env->CallCharMethod(element, get_char); |
| 394 | + } else if (type != 'L' && type != '[') { |
| 395 | + bool is_number = env->IsInstanceOf(element, cls_Number) == JNI_TRUE; |
| 396 | + bool is_character = |
| 397 | + !is_number && (env->IsInstanceOf(element, cls_Character) == JNI_TRUE); |
| 398 | + |
| 399 | + if (!is_number && !is_character) { |
| 400 | + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), |
| 401 | + "Expected Number or Character"); |
| 402 | + return abort_and_return(); |
| 403 | + } |
| 404 | + |
| 405 | + // If a Character is passed to a numeric parameter, extract its value for widening |
| 406 | + jchar c_val = 0; |
| 407 | + if (is_character) { |
| 408 | + c_val = env->CallCharMethod(element, get_char); |
| 409 | + if (env->ExceptionCheck()) return abort_and_return(); |
| 410 | + } |
| 411 | + |
| 412 | + switch (type) { |
| 413 | + case 'I': |
| 414 | + a[i].i = env->CallIntMethod(element, get_int); |
| 415 | + break; |
| 416 | + case 'D': |
| 417 | + a[i].d = env->CallDoubleMethod(element, get_double); |
| 418 | + break; |
| 419 | + case 'J': |
| 420 | + a[i].j = env->CallLongMethod(element, get_long); |
| 421 | + break; |
| 422 | + case 'F': |
| 423 | + a[i].f = env->CallFloatMethod(element, get_float); |
| 424 | + break; |
| 425 | + case 'S': |
| 426 | + a[i].s = env->CallShortMethod(element, get_short); |
| 427 | + break; |
| 428 | + case 'B': |
| 429 | + a[i].b = env->CallByteMethod(element, get_byte); |
| 430 | + break; |
| 431 | + } |
| 432 | + } else { |
| 433 | + a[i].l = element; |
| 434 | + element = |
| 435 | + nullptr; // Transferred ownership to jvalue array; will be freed on return |
| 436 | + } |
380 | 437 | } |
381 | 438 |
|
382 | | - // Clean up the local reference for the wrapper object if it was created. |
383 | 439 | if (element) env->DeleteLocalRef(element); |
| 440 | + if (env->ExceptionCheck()) return abort_and_return(); |
| 441 | + } |
384 | 442 |
|
385 | | - // Check for exceptions during the unboxing call (e.g., |
386 | | - // NullPointerException). |
387 | | - if (env->ExceptionCheck()) { |
388 | | - env->ReleaseCharArrayElements(shorty, shorty_char, JNI_ABORT); |
389 | | - return nullptr; |
| 443 | + // --- Non-virtual Invocation --- |
| 444 | + jvalue ret_val; |
| 445 | + switch (shorty_char[0]) { |
| 446 | + case 'I': |
| 447 | + ret_val.i = env->CallNonvirtualIntMethodA(thiz, cls, target, a); |
| 448 | + break; |
| 449 | + case 'D': |
| 450 | + ret_val.d = env->CallNonvirtualDoubleMethodA(thiz, cls, target, a); |
| 451 | + break; |
| 452 | + case 'J': |
| 453 | + ret_val.j = env->CallNonvirtualLongMethodA(thiz, cls, target, a); |
| 454 | + break; |
| 455 | + case 'F': |
| 456 | + ret_val.f = env->CallNonvirtualFloatMethodA(thiz, cls, target, a); |
| 457 | + break; |
| 458 | + case 'S': |
| 459 | + ret_val.s = env->CallNonvirtualShortMethodA(thiz, cls, target, a); |
| 460 | + break; |
| 461 | + case 'B': |
| 462 | + ret_val.b = env->CallNonvirtualByteMethodA(thiz, cls, target, a); |
| 463 | + break; |
| 464 | + case 'C': |
| 465 | + ret_val.c = env->CallNonvirtualCharMethodA(thiz, cls, target, a); |
| 466 | + break; |
| 467 | + case 'Z': |
| 468 | + ret_val.z = env->CallNonvirtualBooleanMethodA(thiz, cls, target, a); |
| 469 | + break; |
| 470 | + case 'L': |
| 471 | + ret_val.l = env->CallNonvirtualObjectMethodA(thiz, cls, target, a); |
| 472 | + break; |
| 473 | + default: |
| 474 | + env->CallNonvirtualVoidMethodA(thiz, cls, target, a); |
| 475 | + break; |
| 476 | + } |
| 477 | + |
| 478 | + // --- Exception Wrapping --- |
| 479 | + jthrowable target_exception = env->ExceptionOccurred(); |
| 480 | + if (target_exception) { |
| 481 | + env->ExceptionClear(); |
| 482 | + jobject ite = env->NewObject(cls_ITE, ctor_ite, target_exception); |
| 483 | + // Ensure NewObject didn't fail due to OOM before throwing |
| 484 | + if (ite) { |
| 485 | + env->Throw(static_cast<jthrowable>(ite)); |
390 | 486 | } |
| 487 | + return abort_and_return(); |
391 | 488 | } |
392 | 489 |
|
393 | | - // --- Call Non-virtual Method and Box Return Value --- |
| 490 | + // --- Box Return Value --- |
394 | 491 | jobject value = nullptr; |
395 | | - // The shorty string at index 0 describes the return type. |
396 | 492 | switch (shorty_char[0]) { |
397 | 493 | case 'I': |
398 | | - value = |
399 | | - env->CallStaticObjectMethod(jclass{nullptr}, |
400 | | - set_int, // Use Integer.valueOf() to box |
401 | | - env->CallNonvirtualIntMethodA(thiz, cls, target, a.data())); |
| 494 | + value = env->CallStaticObjectMethod(cls_Integer, set_int, ret_val.i); |
402 | 495 | break; |
403 | 496 | case 'D': |
404 | | - value = env->CallStaticObjectMethod( |
405 | | - jclass{nullptr}, set_double, |
406 | | - env->CallNonvirtualDoubleMethodA(thiz, cls, target, a.data())); |
| 497 | + value = env->CallStaticObjectMethod(cls_Double, set_double, ret_val.d); |
407 | 498 | break; |
408 | 499 | case 'J': |
409 | | - value = env->CallStaticObjectMethod( |
410 | | - jclass{nullptr}, set_long, env->CallNonvirtualLongMethodA(thiz, cls, target, a.data())); |
| 500 | + value = env->CallStaticObjectMethod(cls_Long, set_long, ret_val.j); |
411 | 501 | break; |
412 | 502 | case 'F': |
413 | | - value = env->CallStaticObjectMethod( |
414 | | - jclass{nullptr}, set_float, |
415 | | - env->CallNonvirtualFloatMethodA(thiz, cls, target, a.data())); |
| 503 | + value = env->CallStaticObjectMethod(cls_Float, set_float, ret_val.f); |
416 | 504 | break; |
417 | 505 | case 'S': |
418 | | - value = env->CallStaticObjectMethod( |
419 | | - jclass{nullptr}, set_short, |
420 | | - env->CallNonvirtualShortMethodA(thiz, cls, target, a.data())); |
| 506 | + value = env->CallStaticObjectMethod(cls_Short, set_short, ret_val.s); |
421 | 507 | break; |
422 | 508 | case 'B': |
423 | | - value = env->CallStaticObjectMethod( |
424 | | - jclass{nullptr}, set_byte, env->CallNonvirtualByteMethodA(thiz, cls, target, a.data())); |
| 509 | + value = env->CallStaticObjectMethod(cls_Byte, set_byte, ret_val.b); |
425 | 510 | break; |
426 | 511 | case 'C': |
427 | | - value = env->CallStaticObjectMethod( |
428 | | - jclass{nullptr}, set_char, env->CallNonvirtualCharMethodA(thiz, cls, target, a.data())); |
| 512 | + value = env->CallStaticObjectMethod(cls_Character, set_char, ret_val.c); |
429 | 513 | break; |
430 | 514 | case 'Z': |
431 | | - value = env->CallStaticObjectMethod( |
432 | | - jclass{nullptr}, set_boolean, |
433 | | - env->CallNonvirtualBooleanMethodA(thiz, cls, target, a.data())); |
| 515 | + value = env->CallStaticObjectMethod(cls_Boolean, set_boolean, ret_val.z); |
434 | 516 | break; |
435 | | - case 'L': // Return type is an object, no boxing needed. |
436 | | - value = env->CallNonvirtualObjectMethodA(thiz, cls, target, a.data()); |
| 517 | + case 'L': |
| 518 | + value = ret_val.l; |
437 | 519 | break; |
438 | | - default: // Assumes 'V' for void return type. |
439 | 520 | case 'V': |
440 | | - env->CallNonvirtualVoidMethodA(thiz, cls, target, a.data()); |
| 521 | + value = nullptr; |
441 | 522 | break; |
442 | 523 | } |
443 | 524 |
|
|
0 commit comments