Skip to content

Commit 629a0b2

Browse files
authored
Merge pull request #110 from Shopify/yjit-arg-types
YJIT: pass argument types to callees
2 parents 10f54ed + 2b3d1a1 commit 629a0b2

2 files changed

Lines changed: 48 additions & 12 deletions

File tree

bootstraptest/test_yjit.rb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,20 @@ def foo
8989
retval = foo()
9090
}
9191

92+
# Passing argument types to callees
93+
assert_equal '8.5', %q{
94+
def foo(x, y)
95+
x + y
96+
end
97+
98+
def bar
99+
foo(7, 1.5)
100+
end
101+
102+
bar
103+
bar
104+
}
105+
92106
# Recursive Ruby-to-Ruby calls
93107
assert_equal '21', %q{
94108
def fib(n)

yjit_codegen.c

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,13 @@ yjit_entry_prologue(void)
256256
return code_ptr;
257257
}
258258

259-
260259
// Generate code to check for interrupts and take a side-exit
261260
static void
262261
yjit_check_ints(codeblock_t* cb, uint8_t* side_exit)
263262
{
264263
// Check for interrupts
265264
// see RUBY_VM_CHECK_INTS(ec) macro
265+
ADD_COMMENT(cb, "RUBY_VM_CHECK_INTS(ec)");
266266
mov(cb, REG0_32, member_opnd(REG_EC, rb_execution_context_t, interrupt_mask));
267267
not(cb, REG0_32);
268268
test(cb, member_opnd(REG_EC, rb_execution_context_t, interrupt_flag), REG0_32);
@@ -288,7 +288,6 @@ jit_jump_to_next_insn(jitstate_t *jit, const ctx_t *current_context)
288288
);
289289
}
290290

291-
292291
// Compile a sequence of bytecode instructions for a given basic block version
293292
void
294293
yjit_gen_block(ctx_t *ctx, block_t *block, rb_execution_context_t *ec)
@@ -517,22 +516,33 @@ gen_putself(jitstate_t* jit, ctx_t* ctx)
517516
return YJIT_KEEP_COMPILING;
518517
}
519518

519+
// Compute the index of a local variable from its slot index
520+
uint32_t slot_to_local_idx(const rb_iseq_t *iseq, int32_t slot_idx)
521+
{
522+
// Convoluted rules from local_var_name() in iseq.c
523+
int32_t local_table_size = iseq->body->local_table_size;
524+
int32_t op = slot_idx - VM_ENV_DATA_SIZE;
525+
int32_t local_idx = local_idx = local_table_size - op - 1;
526+
RUBY_ASSERT(local_idx >= 0 && local_idx < local_table_size);
527+
return (uint32_t)local_idx;
528+
}
529+
520530
static codegen_status_t
521531
gen_getlocal_wc0(jitstate_t* jit, ctx_t* ctx)
522532
{
533+
// Compute the offset from BP to the local
534+
int32_t slot_idx = (int32_t)jit_get_arg(jit, 0);
535+
const int32_t offs = -(SIZEOF_VALUE * slot_idx);
536+
uint32_t local_idx = slot_to_local_idx(jit->iseq, slot_idx);
537+
523538
// Load environment pointer EP from CFP
524539
mov(cb, REG0, member_opnd(REG_CFP, rb_control_frame_t, ep));
525540

526-
// Compute the offset from BP to the local
527-
int32_t local_idx = (int32_t)jit_get_arg(jit, 0);
528-
const int32_t offs = -(SIZEOF_VALUE * local_idx);
529-
530541
// Load the local from the EP
531542
mov(cb, REG0, mem_opnd(64, REG0, offs));
532543

533544
// Write the local at SP
534545
x86opnd_t stack_top = ctx_stack_push_local(ctx, local_idx);
535-
//x86opnd_t stack_top = ctx_stack_push(ctx, TYPE_UNKNOWN);
536546
mov(cb, stack_top, REG0);
537547

538548
return YJIT_KEEP_COMPILING;
@@ -581,7 +591,8 @@ gen_setlocal_wc0(jitstate_t* jit, ctx_t* ctx)
581591
}
582592
*/
583593

584-
int32_t local_idx = (int32_t)jit_get_arg(jit, 0);
594+
int32_t slot_idx = (int32_t)jit_get_arg(jit, 0);
595+
uint32_t local_idx = slot_to_local_idx(jit->iseq, slot_idx);
585596

586597
// Load environment pointer EP from CFP
587598
mov(cb, REG0, member_opnd(REG_CFP, rb_control_frame_t, ep));
@@ -605,7 +616,7 @@ gen_setlocal_wc0(jitstate_t* jit, ctx_t* ctx)
605616
mov(cb, REG1, stack_top);
606617

607618
// Write the value at the environment pointer
608-
const int32_t offs = -8 * local_idx;
619+
const int32_t offs = -8 * slot_idx;
609620
mov(cb, mem_opnd(64, REG0, offs), REG1);
610621

611622
return YJIT_KEEP_COMPILING;
@@ -1700,7 +1711,7 @@ gen_oswb_iseq(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const r
17001711
lea(cb, REG0, ctx_sp_opnd(ctx, sizeof(VALUE) * -(argc + 1)));
17011712
mov(cb, member_opnd(REG_CFP, rb_control_frame_t, sp), REG0);
17021713

1703-
// Store the next PC i the current frame
1714+
// Store the next PC in the current frame
17041715
mov(cb, REG0, const_ptr_opnd(jit->pc + insn_len(BIN(opt_send_without_block))));
17051716
mov(cb, mem_opnd(64, REG_CFP, offsetof(rb_control_frame_t, pc)), REG0);
17061717

@@ -1763,6 +1774,15 @@ gen_oswb_iseq(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const r
17631774
// Stub so we can return to JITted code
17641775
blockid_t return_block = { jit->iseq, jit_next_insn_idx(jit) };
17651776

1777+
// Create a context for the callee
1778+
ctx_t callee_ctx = DEFAULT_CTX;
1779+
1780+
// Set the argument type in the callee's context
1781+
for (int32_t arg_idx = 0; arg_idx < argc; ++arg_idx) {
1782+
val_type_t arg_type = ctx_get_opnd_type(ctx, OPND_STACK(argc - arg_idx - 1));
1783+
ctx_set_local_type(&callee_ctx, arg_idx, arg_type);
1784+
}
1785+
17661786
// Pop arguments and receiver in return context, push the return value
17671787
// After the return, the JIT and interpreter SP will match up
17681788
ctx_t return_ctx = *ctx;
@@ -1784,12 +1804,12 @@ gen_oswb_iseq(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const r
17841804
//print_str(cb, "calling Ruby func:");
17851805
//print_str(cb, rb_id2name(vm_ci_mid(ci)));
17861806

1787-
// Load the updated SP
1807+
// Load the updated SP from the CFP
17881808
mov(cb, REG_SP, member_opnd(REG_CFP, rb_control_frame_t, sp));
17891809

17901810
// Directly jump to the entry point of the callee
17911811
gen_direct_jump(
1792-
&DEFAULT_CTX,
1812+
&callee_ctx,
17931813
(blockid_t){ iseq, 0 }
17941814
);
17951815

@@ -1943,6 +1963,7 @@ gen_leave(jitstate_t* jit, ctx_t* ctx)
19431963
mov(cb, REG0, member_opnd(REG_CFP, rb_control_frame_t, ep));
19441964

19451965
// if (flags & VM_FRAME_FLAG_FINISH) != 0
1966+
ADD_COMMENT(cb, "check for finish frame");
19461967
x86opnd_t flags_opnd = mem_opnd(64, REG0, sizeof(VALUE) * VM_ENV_DATA_INDEX_FLAGS);
19471968
test(cb, flags_opnd, imm_opnd(VM_FRAME_FLAG_FINISH));
19481969
jnz_ptr(cb, COUNTED_EXIT(side_exit, leave_se_finish_frame));
@@ -1968,6 +1989,7 @@ gen_leave(jitstate_t* jit, ctx_t* ctx)
19681989
mov(cb, mem_opnd(64, REG_SP, -SIZEOF_VALUE), REG0);
19691990

19701991
// If the return address is NULL, fall back to the interpreter
1992+
ADD_COMMENT(cb, "check for jit return");
19711993
int FALLBACK_LABEL = cb_new_label(cb, "FALLBACK");
19721994
test(cb, REG1, REG1);
19731995
jz_label(cb, FALLBACK_LABEL);

0 commit comments

Comments
 (0)