Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 92 additions & 27 deletions lib/lua/compiler/codegen.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,7 @@ defmodule Lua.Compiler.Codegen do

# Hint for "attempt to index a nil value (...)" if `obj` is nil/non-table.
obj_hint = name_hint(obj_expr, ctx)
call_hint = {:method, method, obj_hint}

# Compile the object expression
{object_instructions, obj_reg, ctx} = gen_expr(obj_expr, ctx)
Expand All @@ -1314,38 +1315,102 @@ defmodule Lua.Compiler.Codegen do
ctx = record_peak(ctx)
ctx = %{ctx | next_reg: base_reg + 2}

# Compile arguments into temp registers above the arg window
arg_count = length(args)
ctx = %{ctx | next_reg: base_reg + 2 + arg_count}
# Classify the last argument — same calling conventions as `Expr.Call`,
# offset by one register for `self`.
last_arg_type =
case args do
[] ->
:normal

{arg_instructions, arg_regs, ctx} =
Enum.reduce(args, {[], [], ctx}, fn arg, {instructions, regs, ctx} ->
{arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx)
{instructions ++ arg_instructions, regs ++ [arg_reg], ctx}
end)
_ ->
case List.last(args) do
%Expr.Vararg{} -> :vararg
%Expr.Call{} -> :multi_call
%Expr.MethodCall{} -> :multi_call
_ -> :normal
end
end

# Move each arg result to its expected position (base+2+i)
move_instructions =
arg_regs
|> Enum.with_index()
|> Enum.flat_map(fn {arg_reg, i} ->
expected_reg = base_reg + 2 + i
case last_arg_type do
:vararg ->
# obj:m(a, b, ...) — load a, b then all varargs
init_args = Enum.slice(args, 0..-2//1)
arg_count = length(init_args)
ctx = %{ctx | next_reg: base_reg + 2 + arg_count}

if arg_reg == expected_reg do
[]
else
[Instruction.move(expected_reg, arg_reg)]
end
end)
{arg_instructions, arg_regs, ctx} =
Enum.reduce(init_args, {[], [], ctx}, fn arg, {instructions, regs, ctx} ->
{arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx)
{instructions ++ arg_instructions, regs ++ [arg_reg], ctx}
end)

# Call with arg_count + 1 for self
call_instruction = Instruction.call(base_reg, arg_count + 1, 1, {:method, method, obj_hint})
move_instructions = gen_move_args(arg_regs, base_reg + 2)

{object_instructions ++
[self_instruction] ++
arg_instructions ++
move_instructions ++
[call_instruction], base_reg, ctx}
vararg_base = base_reg + 2 + arg_count
vararg_instruction = Instruction.vararg(vararg_base, 0)
# Fixed slots above R[base]: self (1) + arg_count.
call_instruction = Instruction.call(base_reg, -(arg_count + 2), 1, call_hint)

{object_instructions ++
[self_instruction] ++
arg_instructions ++
move_instructions ++
[vararg_instruction, call_instruction], base_reg, ctx}

:multi_call ->
# obj:m(a, b, g()) — load a, b then expand all results of g()
init_args = Enum.slice(args, 0..-2//1)
last_call = List.last(args)
fixed_count = length(init_args)

ctx = %{ctx | next_reg: base_reg + 2 + fixed_count}

{arg_instructions, arg_regs, ctx} =
Enum.reduce(init_args, {[], [], ctx}, fn arg, {instructions, regs, ctx} ->
{arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx)
{instructions ++ arg_instructions, regs ++ [arg_reg], ctx}
end)

move_instructions = gen_move_args(arg_regs, base_reg + 2)

# Position next_reg for the inner call so its base lands at base+2+fixed_count.
ctx = %{ctx | next_reg: base_reg + 2 + fixed_count}

{inner_call_instructions, _inner_base, ctx} = gen_expr(last_call, ctx)

# Patch the inner call's result_count to -2 (expand all results into place).
inner_call_instructions = patch_call_result_count(inner_call_instructions, -2)

# Fixed slots above R[base]: self (1) + fixed_count.
call_instruction = Instruction.call(base_reg, {:multi, 1 + fixed_count}, 1, call_hint)

{object_instructions ++
[self_instruction] ++
arg_instructions ++
move_instructions ++
inner_call_instructions ++
[call_instruction], base_reg, ctx}

:normal ->
arg_count = length(args)
ctx = %{ctx | next_reg: base_reg + 2 + arg_count}

{arg_instructions, arg_regs, ctx} =
Enum.reduce(args, {[], [], ctx}, fn arg, {instructions, regs, ctx} ->
{arg_instructions, arg_reg, ctx} = gen_expr(arg, ctx)
{instructions ++ arg_instructions, regs ++ [arg_reg], ctx}
end)

move_instructions = gen_move_args(arg_regs, base_reg + 2)

call_instruction = Instruction.call(base_reg, arg_count + 1, 1, call_hint)

{object_instructions ++
[self_instruction] ++
arg_instructions ++
move_instructions ++
[call_instruction], base_reg, ctx}
end
end

defp gen_expr(%Expr.Vararg{}, ctx) do
Expand Down
52 changes: 52 additions & 0 deletions test/language/function_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,56 @@ defmodule Lua.Language.FunctionTest do

assert {["error msg"], _} = Lua.eval!(lua, code)
end

test "method call expands table.unpack in tail position", %{lua: lua} do
code = ~S"""
local t = {}
function t:m(...) return select('#', ...) end
local vals = {"a", "b", "c"}
return t:m(table.unpack(vals))
"""

assert {[3], _} = Lua.eval!(lua, code)
end

test "method call expands vararg in tail position", %{lua: lua} do
code = ~S"""
local t = {}
function t:m(...) return select('#', ...) end
local function wrap(...) return t:m(...) end
return wrap("a", "b", "c")
"""

assert {[3], _} = Lua.eval!(lua, code)
end

test "method call expands inner call in tail position", %{lua: lua} do
code = ~S"""
local t = {}
function t:m(...) return select('#', ...) end
local function three() return 1, 2, 3 end
return t:m(three())
"""

assert {[3], _} = Lua.eval!(lua, code)
end

test "method call expands table.unpack with leading fixed args", %{lua: lua} do
code = ~S"""
local t = {}
function t:m(x, ...) return x, select('#', ...) end
return t:m("first", table.unpack({"a","b","c"}))
"""

assert {["first", 3], _} = Lua.eval!(lua, code)
end

test "string:format with table.unpack expands all values", %{lua: lua} do
code = ~S"""
local args = {"a", "b", "c"}
return ("[%s,%s,%s]"):format(table.unpack(args))
"""

assert {["[a,b,c]"], _} = Lua.eval!(lua, code)
end
end
Loading