Skip to content

Commit 02fda84

Browse files
committed
JIT: factorize tail calls to reduce binary size
Use a cache to remember tail calls that were already implemented and replace further implementations of the same tail call with a jump to the previous implementation. Coverage shows that all cases are covered in libs/estdlib/src and libs/jit/src: OP_RETURN: 50 misses, 1735 hits (97%) OP_JUMP/OP_CALL_LAST/OP_CALL_ONLY: 656 misses, 389 hits (37%) OP_CALL_LAST: 220 misses, 206 hits (48%) OP_FUNC_INFO: 58 misses, 1619 hits (97%) Signed-off-by: Paul Guyot <[email protected]>
1 parent 4621b7d commit 02fda84

File tree

5 files changed

+214
-64
lines changed

5 files changed

+214
-64
lines changed

libs/jit/src/jit.erl

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@
100100
labels_count :: pos_integer(),
101101
atom_resolver :: fun((integer()) -> atom()),
102102
literal_resolver :: fun((integer()) -> any()),
103-
type_resolver :: fun((integer()) -> any())
103+
type_resolver :: fun((integer()) -> any()),
104+
tail_cache :: [{tuple(), non_neg_integer()}]
104105
}).
105106

106107
-type stream() :: any().
@@ -142,7 +143,8 @@ compile(
142143
labels_count = LabelsCount,
143144
atom_resolver = AtomResolver,
144145
literal_resolver = LiteralResolver,
145-
type_resolver = TypeResolver
146+
type_resolver = TypeResolver,
147+
tail_cache = []
146148
},
147149
{State1, MSt2} = first_pass(Opcodes, MMod, MSt1, State0),
148150
MSt3 = second_pass(MMod, MSt2, State1),
@@ -170,18 +172,30 @@ first_pass(
170172
?ASSERT_ALL_NATIVE_FREE(MSt1),
171173
first_pass(Rest1, MMod, MSt1, State0);
172174
% 2
173-
first_pass(<<?OP_FUNC_INFO, Rest0/binary>>, MMod, MSt0, State0) ->
175+
first_pass(<<?OP_FUNC_INFO, Rest0/binary>>, MMod, MSt0, #state{tail_cache = TC} = State0) ->
174176
?ASSERT_ALL_NATIVE_FREE(MSt0),
175177
{_ModuleAtom, Rest1} = decode_atom(Rest0),
176178
{_FunctionName, Rest2} = decode_atom(Rest1),
177179
{_Arity, Rest3} = decode_literal(Rest2),
178180
?TRACE("OP_FUNC_INFO ~p, ~p, ~p\n", [_ModuleAtom, _FunctionName, _Arity]),
179-
% Implement function clause at the previous label. (TODO: optimize it out to save space)
180-
MSt1 = MMod:call_primitive_last(MSt0, ?PRIM_RAISE_ERROR, [
181-
ctx, jit_state, offset, ?FUNCTION_CLAUSE_ATOM
182-
]),
183-
?ASSERT_ALL_NATIVE_FREE(MSt1),
184-
first_pass(Rest3, MMod, MSt1, State0);
181+
% Implement function clause at the previous label.
182+
Offset = MMod:offset(MSt0),
183+
{MSt1, OffsetReg} = MMod:move_to_native_register(MSt0, Offset),
184+
TailCacheKey = {call_primitive_last, ?PRIM_RAISE_ERROR, [OffsetReg, ?FUNCTION_CLAUSE_ATOM]},
185+
State1 =
186+
case lists:keyfind(TailCacheKey, 1, TC) of
187+
false ->
188+
MSt3 = MMod:call_primitive_last(MSt1, ?PRIM_RAISE_ERROR, [
189+
ctx, jit_state, {free, OffsetReg}, ?FUNCTION_CLAUSE_ATOM
190+
]),
191+
State0#state{tail_cache = [{TailCacheKey, Offset} | TC]};
192+
{TailCacheKey, CacheOffset} ->
193+
MSt2 = MMod:jump_to_offset(MSt1, CacheOffset),
194+
MSt3 = MMod:free_native_registers(MSt2, [OffsetReg]),
195+
State0
196+
end,
197+
?ASSERT_ALL_NATIVE_FREE(MSt3),
198+
first_pass(Rest3, MMod, MSt3, State1);
185199
% 3
186200
first_pass(
187201
<<?OP_INT_CALL_END>>, MMod, MSt0, #state{labels_count = LabelsCount} = State
@@ -203,26 +217,56 @@ first_pass(<<?OP_CALL, Rest0/binary>>, MMod, MSt0, State0) ->
203217
?ASSERT_ALL_NATIVE_FREE(MSt1),
204218
first_pass(Rest2, MMod, MSt1, State0);
205219
% 5
206-
first_pass(<<?OP_CALL_LAST, Rest0/binary>>, MMod, MSt0, State0) ->
220+
first_pass(<<?OP_CALL_LAST, Rest0/binary>>, MMod, MSt0, #state{tail_cache = TC} = State0) ->
207221
?ASSERT_ALL_NATIVE_FREE(MSt0),
208222
{_Arity, Rest1} = decode_literal(Rest0),
209223
{Label, Rest2} = decode_label(Rest1),
210224
{NWords, Rest3} = decode_literal(Rest2),
211225
?TRACE("OP_CALL_LAST ~p, ~p, ~p\n", [_Arity, Label, NWords]),
212-
MSt1 = MMod:move_to_cp(MSt0, {y_reg, NWords}),
213-
MSt2 = MMod:increment_sp(MSt1, NWords + 1),
214-
MSt3 = MMod:call_only_or_schedule_next(MSt2, Label),
226+
TailCacheKey0 = {op_call_last, NWords, Label},
227+
case lists:keyfind(TailCacheKey0, 1, TC) of
228+
false ->
229+
Offset0 = MMod:offset(MSt0),
230+
MSt1 = MMod:move_to_cp(MSt0, {y_reg, NWords}),
231+
MSt2 = MMod:increment_sp(MSt1, NWords + 1),
232+
TailCacheKey1 = {op_call_only, Label},
233+
case lists:keyfind(TailCacheKey1, 1, TC) of
234+
false ->
235+
Offset1 = MMod:offset(MSt2),
236+
MSt3 = MMod:call_only_or_schedule_next(MSt2, Label),
237+
State1 = State0#state{
238+
tail_cache = [{TailCacheKey1, Offset1}, {TailCacheKey0, Offset0} | TC]
239+
};
240+
{TailCacheKey1, Offset1} ->
241+
MSt3 = MMod:jump_to_offset(MSt2, Offset1),
242+
State1 = State0#state{
243+
tail_cache = [{TailCacheKey0, Offset0} | TC]
244+
}
245+
end;
246+
{TailCacheKey0, Offset0} ->
247+
MSt3 = MMod:jump_to_offset(MSt0, Offset0),
248+
State1 = State0
249+
end,
215250
?ASSERT_ALL_NATIVE_FREE(MSt3),
216-
first_pass(Rest3, MMod, MSt3, State0);
251+
first_pass(Rest3, MMod, MSt3, State1);
217252
% 6
218-
first_pass(<<?OP_CALL_ONLY, Rest0/binary>>, MMod, MSt0, State0) ->
253+
first_pass(<<?OP_CALL_ONLY, Rest0/binary>>, MMod, MSt0, #state{tail_cache = TC} = State0) ->
219254
?ASSERT_ALL_NATIVE_FREE(MSt0),
220255
{_Arity, Rest1} = decode_literal(Rest0),
221256
{Label, Rest2} = decode_label(Rest1),
222257
?TRACE("OP_CALL_ONLY ~p, ~p\n", [_Arity, Label]),
223-
MSt1 = MMod:call_only_or_schedule_next(MSt0, Label),
258+
TailCacheKey = {op_call_only, Label},
259+
case lists:keyfind(TailCacheKey, 1, TC) of
260+
false ->
261+
Offset = MMod:offset(MSt0),
262+
MSt1 = MMod:call_only_or_schedule_next(MSt0, Label),
263+
State1 = State0#state{tail_cache = [{TailCacheKey, Offset} | TC]};
264+
{TailCacheKey, Offset} ->
265+
MSt1 = MMod:jump_to_offset(MSt0, Offset),
266+
State1 = State0
267+
end,
224268
?ASSERT_ALL_NATIVE_FREE(MSt1),
225-
first_pass(Rest2, MMod, MSt1, State0);
269+
first_pass(Rest2, MMod, MSt1, State1);
226270
% 7
227271
first_pass(<<?OP_CALL_EXT, Rest0/binary>>, MMod, MSt0, State0) ->
228272
?ASSERT_ALL_NATIVE_FREE(MSt0),
@@ -348,7 +392,7 @@ first_pass(<<?OP_DEALLOCATE, Rest0/binary>>, MMod, MSt0, State0) ->
348392
?ASSERT_ALL_NATIVE_FREE(MSt2),
349393
first_pass(Rest1, MMod, MSt2, State0);
350394
% 19
351-
first_pass(<<?OP_RETURN, Rest/binary>>, MMod, MSt0, State0) ->
395+
first_pass(<<?OP_RETURN, Rest/binary>>, MMod, MSt0, #state{tail_cache = TC} = State0) ->
352396
?ASSERT_ALL_NATIVE_FREE(MSt0),
353397
?TRACE("OP_RETURN\n", []),
354398
% Optimized return: check if returning within same module
@@ -371,9 +415,18 @@ first_pass(<<?OP_RETURN, Rest/binary>>, MMod, MSt0, State0) ->
371415
),
372416
MSt5 = MMod:free_native_registers(MSt4, [CpReg0]),
373417
% Different module: use existing slow path
374-
MSt6 = MMod:call_primitive_last(MSt5, ?PRIM_RETURN, [ctx, jit_state]),
418+
TailCacheKey = {call_primitive_last, ?PRIM_RETURN},
419+
case lists:keyfind(TailCacheKey, 1, TC) of
420+
false ->
421+
Offset = MMod:offset(MSt5),
422+
MSt6 = MMod:call_primitive_last(MSt5, ?PRIM_RETURN, [ctx, jit_state]),
423+
State1 = State0#state{tail_cache = [{TailCacheKey, Offset} | TC]};
424+
{TailCacheKey, Offset} ->
425+
MSt6 = MMod:jump_to_offset(MSt5, Offset),
426+
State1 = State0
427+
end,
375428
?ASSERT_ALL_NATIVE_FREE(MSt6),
376-
first_pass(Rest, MMod, MSt6, State0);
429+
first_pass(Rest, MMod, MSt6, State1);
377430
% 20
378431
first_pass(<<?OP_SEND, Rest/binary>>, MMod, MSt0, State0) ->
379432
?ASSERT_ALL_NATIVE_FREE(MSt0),
@@ -836,13 +889,22 @@ first_pass(<<?OP_SELECT_TUPLE_ARITY, Rest0/binary>>, MMod, MSt0, State0) ->
836889
?ASSERT_ALL_NATIVE_FREE(MSt5),
837890
first_pass(Rest4, MMod, MSt5, State0);
838891
% 61
839-
first_pass(<<?OP_JUMP, Rest0/binary>>, MMod, MSt0, State0) ->
892+
first_pass(<<?OP_JUMP, Rest0/binary>>, MMod, MSt0, #state{tail_cache = TC} = State0) ->
840893
?ASSERT_ALL_NATIVE_FREE(MSt0),
841894
{Label, Rest1} = decode_label(Rest0),
842895
?TRACE("OP_JUMP ~p\n", [Label]),
843-
MSt1 = MMod:call_only_or_schedule_next(MSt0, Label),
844-
?ASSERT_ALL_NATIVE_FREE(MSt1),
845-
first_pass(Rest1, MMod, MSt1, State0);
896+
TailCacheKey = {op_call_only, Label},
897+
case lists:keyfind(TailCacheKey, 1, TC) of
898+
false ->
899+
Offset = MMod:offset(MSt0),
900+
MSt1 = MMod:call_only_or_schedule_next(MSt0, Label),
901+
?ASSERT_ALL_NATIVE_FREE(MSt1),
902+
first_pass(Rest1, MMod, MSt1, State0#state{tail_cache = [{TailCacheKey, Offset} | TC]});
903+
{TailCacheKey, Offset} ->
904+
MSt1 = MMod:jump_to_offset(MSt0, Offset),
905+
?ASSERT_ALL_NATIVE_FREE(MSt1),
906+
first_pass(Rest1, MMod, MSt1, State0)
907+
end;
846908
% 62
847909
% Same implementation as OP_TRY, to confirm.
848910
first_pass(<<?OP_CATCH, Rest0/binary>>, MMod, MSt0, State0) ->

libs/jit/src/jit_aarch64.erl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
return_if_not_equal_to_ctx/2,
3939
jump_to_label/2,
4040
jump_to_continuation/2,
41+
jump_to_offset/2,
4142
if_block/3,
4243
if_else_block/4,
4344
shift_right/3,
@@ -531,6 +532,13 @@ jump_to_label(
531532
State#state{stream = Stream1, branches = [Reloc | AccBranches]}
532533
end.
533534

535+
jump_to_offset(#state{stream_module = StreamModule, stream = Stream0} = State, TargetOffset) ->
536+
Offset = StreamModule:offset(Stream0),
537+
Rel = TargetOffset - Offset,
538+
I1 = jit_aarch64_asm:b(Rel),
539+
Stream1 = StreamModule:append(Stream0, I1),
540+
State#state{stream = Stream1}.
541+
534542
%%-----------------------------------------------------------------------------
535543
%% @doc Jump to a continuation address stored in a register.
536544
%% This is used for optimized intra-module returns.

libs/jit/src/jit_armv6m.erl

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
return_if_not_equal_to_ctx/2,
3939
jump_to_label/2,
4040
jump_to_continuation/2,
41+
jump_to_offset/2,
4142
if_block/3,
4243
if_else_block/4,
4344
shift_right/3,
@@ -731,6 +732,12 @@ jump_to_label(
731732
Stream1 = StreamModule:append(Stream0, CodeBlock),
732733
State1#state{stream = Stream1}.
733734

735+
jump_to_offset(#state{stream_module = StreamModule, stream = Stream0} = State, TargetOffset) ->
736+
Offset = StreamModule:offset(Stream0),
737+
CodeBlock = branch_to_offset_code(State, Offset, TargetOffset),
738+
Stream1 = StreamModule:append(Stream0, CodeBlock),
739+
State#state{stream = Stream1}.
740+
734741
%%-----------------------------------------------------------------------------
735742
%% @doc Jump to address in continuation pointer register
736743
%% The continuation points to a function prologue, so we need to compute
@@ -793,15 +800,14 @@ jump_to_continuation(
793800
% Free all registers as this is a terminal instruction
794801
State1#state{stream = Stream2, available_regs = ?AVAILABLE_REGS, used_regs = []}.
795802

796-
branch_to_label_code(State, Offset, Label, {Label, LabelOffset}) when
797-
LabelOffset - Offset =< 2050, LabelOffset - Offset >= -2044
803+
branch_to_offset_code(_State, Offset, TargetOffset) when
804+
TargetOffset - Offset =< 2050, TargetOffset - Offset >= -2044
798805
->
799806
% Near branch: use direct B instruction
800-
Rel = LabelOffset - Offset,
801-
CodeBlock = jit_armv6m_asm:b(Rel),
802-
{State, CodeBlock};
803-
branch_to_label_code(
804-
#state{available_regs = [TempReg | _]} = State0, Offset, Label, {Label, LabelOffset}
807+
Rel = TargetOffset - Offset,
808+
jit_armv6m_asm:b(Rel);
809+
branch_to_offset_code(
810+
#state{available_regs = [TempReg | _]}, Offset, TargetOffset
805811
) ->
806812
% Far branch: use register-based sequence, need temporary register
807813
if
@@ -812,19 +818,22 @@ branch_to_label_code(
812818
I3 = jit_armv6m_asm:bx(TempReg),
813819
% Unaligned : need nop
814820
I4 = jit_armv6m_asm:nop(),
815-
LiteralValue = LabelOffset - Offset - 5,
821+
LiteralValue = TargetOffset - Offset - 5,
816822
I5 = <<LiteralValue:32/little>>,
817-
CodeBlock = <<I1/binary, I2/binary, I3/binary, I4/binary, I5/binary>>;
823+
<<I1/binary, I2/binary, I3/binary, I4/binary, I5/binary>>;
818824
true ->
819825
% Unaligned
820826
I1 = jit_armv6m_asm:ldr(TempReg, {pc, 4}),
821827
I2 = jit_armv6m_asm:add(TempReg, pc),
822828
I3 = jit_armv6m_asm:bx(TempReg),
823-
LiteralValue = LabelOffset - Offset - 5,
829+
LiteralValue = TargetOffset - Offset - 5,
824830
I4 = <<LiteralValue:32/little>>,
825-
CodeBlock = <<I1/binary, I2/binary, I3/binary, I4/binary>>
826-
end,
827-
{State0, CodeBlock};
831+
<<I1/binary, I2/binary, I3/binary, I4/binary>>
832+
end.
833+
834+
branch_to_label_code(State, Offset, Label, {Label, LabelOffset}) ->
835+
CodeBlock = branch_to_offset_code(State, Offset, LabelOffset),
836+
{State, CodeBlock};
828837
branch_to_label_code(
829838
#state{available_regs = [TempReg | _], branches = Branches} = State0, Offset, Label, false
830839
) ->

libs/jit/src/jit_x86_64.erl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
return_if_not_equal_to_ctx/2,
3939
jump_to_label/2,
4040
jump_to_continuation/2,
41+
jump_to_offset/2,
4142
if_block/3,
4243
if_else_block/4,
4344
shift_right/3,
@@ -524,6 +525,13 @@ jump_to_label(
524525
State#state{stream = Stream1, branches = [Reloc | AccBranches]}
525526
end.
526527

528+
jump_to_offset(#state{stream_module = StreamModule, stream = Stream0} = State, TargetOffset) ->
529+
Offset = StreamModule:offset(Stream0),
530+
RelOffset = TargetOffset - Offset,
531+
I1 = jit_x86_64_asm:jmp(RelOffset),
532+
Stream1 = StreamModule:append(Stream0, I1),
533+
State#state{stream = Stream1}.
534+
527535
%%-----------------------------------------------------------------------------
528536
%% @doc Jump to a continuation address stored in a register.
529537
%% This is used for optimized intra-module returns.

0 commit comments

Comments
 (0)