diff --git a/lib/lua.ex b/lib/lua.ex index 97ae3a3..f5429e3 100644 --- a/lib/lua.ex +++ b/lib/lua.ex @@ -303,67 +303,11 @@ defmodule Lua do raise Lua.RuntimeException, "Lua.set!/3 cannot have empty keys" end - def set!(%__MODULE__{} = lua, keys, func) when is_function(func, 1) do + def set!(%__MODULE__{} = lua, keys, func) when is_function(func, 1) or is_function(func, 2) do keys = keys |> List.wrap() |> Enum.map(&to_lua_key/1) + {function_name, scope} = List.pop_at(keys, -1) - wrapped = - {:native_func, - fn args, state -> - return = List.wrap(func.(args)) - - if not Util.list_encoded?(return) do - {function_name, scope} = List.pop_at(keys, -1) - - raise Lua.RuntimeException, - function: function_name, - scope: scope, - message: "deflua functions must return encoded data, got #{inspect(return)}" - end - - {return, state} - end} - - state = do_set_nested(lua.state, keys, wrapped) - %{lua | state: state} - end - - def set!(%__MODULE__{} = lua, keys, func) when is_function(func, 2) do - keys = keys |> List.wrap() |> Enum.map(&to_lua_key/1) - - wrapped = - {:native_func, - fn args, state -> - {function_name, scope} = List.pop_at(keys, -1) - - case func.(args, wrap(state)) do - {:error, reason, %__MODULE__{} = returned_lua} -> - raise RuntimeError, value: reason, state: returned_lua.state - - {value, %__MODULE__{} = lua} -> - value = List.wrap(value) - - if not Util.list_encoded?(value) do - raise Lua.RuntimeException, - function: function_name, - scope: scope, - message: "deflua functions must return encoded data, got #{inspect(value)}" - end - - {value, lua.state} - - value -> - value = List.wrap(value) - - if not Util.list_encoded?(value) do - raise Lua.RuntimeException, - function: function_name, - scope: scope, - message: "deflua functions must return encoded data, got #{inspect(value)}" - end - - {value, state} - end - end} + wrapped = wrap_callback(func, function_name, scope) state = do_set_nested(lua.state, keys, wrapped) %{lua | state: state} @@ -372,12 +316,13 @@ defmodule Lua do def set!(%__MODULE__{} = lua, keys, value) do keys = keys |> List.wrap() |> Enum.map(&to_lua_key/1) value = Display.unwrap(value) + {function_name, scope} = List.pop_at(keys, -1) {encoded, state} = if Util.encoded?(value) do {value, lua.state} else - Value.encode(value, lua.state) + Value.encode(value, lua.state, &wrap_callback(&1, function_name, scope)) end state = do_set_nested(state, keys, encoded) @@ -935,7 +880,7 @@ defmodule Lua do if Util.encoded?(value) do {value, lua} else - {encoded, state} = Value.encode(value, state) + {encoded, state} = Value.encode(value, state, &wrap_callback(&1, :anonymous, [])) {encoded, %{lua | state: state}} end rescue @@ -1215,6 +1160,53 @@ defmodule Lua do defp wrap(state), do: %__MODULE__{state: state} + # Builds the `{:native_func, closure}` for a user-supplied Elixir callback. + # + # Both arities speak the documented `t:Lua.t/0` convention at the VM + # boundary: the raw `Lua.VM.State` handed in by the executor is wrapped as a + # `%Lua{}` for the callback, the returned `%Lua{}` is unwrapped back to its + # raw `.state`, and returns are validated as encoded. Sharing one builder + # keeps `Lua.set!/3` at a path, functions nested inside a `set!` value, and + # `Lua.encode!/2` on the same convention. `function_name`/`scope` feed only + # the error message. + defp wrap_callback(func, function_name, scope) when is_function(func, 1) do + {:native_func, + fn args, state -> + return = List.wrap(func.(args)) + validate_encoded!(return, function_name, scope) + {return, state} + end} + end + + defp wrap_callback(func, function_name, scope) when is_function(func, 2) do + {:native_func, + fn args, state -> + case func.(args, wrap(state)) do + {:error, reason, %__MODULE__{} = returned_lua} -> + raise RuntimeError, value: reason, state: returned_lua.state + + {value, %__MODULE__{} = lua} -> + value = List.wrap(value) + validate_encoded!(value, function_name, scope) + {value, lua.state} + + value -> + value = List.wrap(value) + validate_encoded!(value, function_name, scope) + {value, state} + end + end} + end + + defp validate_encoded!(value, function_name, scope) do + if not Util.list_encoded?(value) do + raise Lua.RuntimeException, + function: function_name, + scope: scope, + message: "deflua functions must return encoded data, got #{inspect(value)}" + end + end + defp to_lua_key(key) when is_atom(key), do: Atom.to_string(key) defp to_lua_key(key) when is_binary(key), do: key defp to_lua_key(key), do: key diff --git a/lib/lua/vm/value.ex b/lib/lua/vm/value.ex index 91dde89..d765d5a 100644 --- a/lib/lua/vm/value.ex +++ b/lib/lua/vm/value.ex @@ -227,41 +227,37 @@ defmodule Lua.VM.Value do Returns `{encoded_value, state}` since encoding maps and lists allocates tables. """ - @spec encode(term(), State.t()) :: {term(), State.t()} - def encode(nil, state), do: {nil, state} - def encode(value, state) when is_boolean(value), do: {value, state} - def encode(value, state) when is_number(value), do: {value, state} - def encode(value, state) when is_binary(value), do: {value, state} - def encode(value, state) when is_atom(value), do: {Atom.to_string(value), state} - - def encode(fun, state) when is_function(fun, 2), do: {{:native_func, fun}, state} - - def encode(fun, state) when is_function(fun, 1) do - wrapper = fn args, st -> {List.wrap(fun.(args)), st} end - {{:native_func, wrapper}, state} - end + @spec encode(term(), State.t(), (fun() -> term())) :: {term(), State.t()} + def encode(value, state, fun_wrapper \\ &default_fun_wrapper/1) + def encode(nil, state, _fun_wrapper), do: {nil, state} + def encode(value, state, _fun_wrapper) when is_boolean(value), do: {value, state} + def encode(value, state, _fun_wrapper) when is_number(value), do: {value, state} + def encode(value, state, _fun_wrapper) when is_binary(value), do: {value, state} + def encode(value, state, _fun_wrapper) when is_atom(value), do: {Atom.to_string(value), state} + + def encode(fun, state, fun_wrapper) when is_function(fun, 1) or is_function(fun, 2), do: {fun_wrapper.(fun), state} - def encode({:userdata, value}, state) do + def encode({:userdata, value}, state, _fun_wrapper) do State.alloc_userdata(state, value) end - def encode(map, state) when is_map(map) do + def encode(map, state, fun_wrapper) when is_map(map) do {data, state} = Enum.reduce(map, {%{}, state}, fn {k, v}, {data, state} -> key = if is_atom(k), do: Atom.to_string(k), else: k - {encoded_v, state} = encode(v, state) + {encoded_v, state} = encode(v, state, fun_wrapper) {Map.put(data, key, encoded_v), state} end) State.alloc_table(state, data) end - def encode(list, state) when is_list(list) do + def encode(list, state, fun_wrapper) when is_list(list) do if keyword_list?(list) do {data, state} = Enum.reduce(list, {%{}, state}, fn {k, v}, {data, state} -> key = Atom.to_string(k) - {encoded_v, state} = encode(v, state) + {encoded_v, state} = encode(v, state, fun_wrapper) {Map.put(data, key, encoded_v), state} end) @@ -271,7 +267,7 @@ defmodule Lua.VM.Value do list |> Enum.with_index(1) |> Enum.reduce({%{}, state}, fn {v, idx}, {data, state} -> - {encoded_v, state} = encode(v, state) + {encoded_v, state} = encode(v, state, fun_wrapper) {Map.put(data, idx, encoded_v), state} end) @@ -279,6 +275,15 @@ defmodule Lua.VM.Value do end end + # Default function wrapper, preserving the raw-state calling convention for + # low-level callers. `Lua.set!/3` and `Lua.encode!/2` inject their own + # wrapper so callbacks receive the public `t:Lua.t/0` instead. + defp default_fun_wrapper(fun) when is_function(fun, 2), do: {:native_func, fun} + + defp default_fun_wrapper(fun) when is_function(fun, 1) do + {:native_func, fn args, st -> {List.wrap(fun.(args)), st} end} + end + @doc """ Encodes a list of Elixir values, threading state through each encoding. diff --git a/test/lua/callback_state_asymmetry_test.exs b/test/lua/callback_state_asymmetry_test.exs new file mode 100644 index 0000000..8064d29 --- /dev/null +++ b/test/lua/callback_state_asymmetry_test.exs @@ -0,0 +1,89 @@ +defmodule Lua.CallbackStateAsymmetryTest do + @moduledoc """ + A two-arity Elixir callback (`fn args, state -> {results, state} end`) + must receive the same `state` argument — and accept the same return + shape — regardless of how it entered the VM: + + | How the callback enters the VM | `state` it receives | + | --------------------------------------------------------------- | ------------------- | + | `Lua.set!(lua, [:f], fun)` — function directly at the path | `t:Lua.t/0` | + | `deflua` + `Lua.load_api/3` | `t:Lua.t/0` | + | `Lua.set!(lua, [:t], %{"f" => fun})` — function inside a value | `t:Lua.t/0` | + | `Lua.encode!(lua, fun)` — closure handed to Lua at runtime | `t:Lua.t/0` | + + Every entry point hands the callback the public `t:Lua.t/0`, so the public + API (`Lua.decode!/2`, `Lua.encode!/2`, `Lua.get_private!/2`, …) works and a + single closure written against the documented `Lua.set!/3` convention works + in every position. + """ + use ExUnit.Case, async: true + + defmodule DefluaProbe do + @moduledoc false + use Lua.API, scope: "probe" + + deflua whoami(), state do + {[inspect(state.__struct__)], state} + end + end + + # Reports the struct name of the state the VM handed the callback. + defp probe_callback do + fn _args, state -> {[inspect(state.__struct__)], state} end + end + + # A callback written exactly the way the `Lua.set!/3` docs teach: treat + # `state` as a `t:Lua.t/0`, use the public API on it, return it as received. + defp documented_convention_callback do + fn _args, state -> {[Lua.get_private!(state, :secret)], state} end + end + + describe "callbacks that receive the public Lua.t (documented convention)" do + test "a function set! directly at a path receives Lua.t" do + lua = Lua.set!(Lua.new(), [:probe], probe_callback()) + + assert {["Lua"], _} = Lua.eval!(lua, "return probe()") + end + + test "a deflua function loaded via load_api receives Lua.t" do + lua = Lua.load_api(Lua.new(), DefluaProbe) + + assert {["Lua"], _} = Lua.eval!(lua, "return probe.whoami()") + end + + test "a set!-at-path callback can use the public Lua API on its state" do + lua = + Lua.new() + |> Lua.put_private(:secret, "from-private") + |> Lua.set!([:fetch], documented_convention_callback()) + + assert {["from-private"], _} = Lua.eval!(lua, "return fetch()") + end + end + + describe "callbacks that enter the VM as encoded values" do + test "a closure embedded via encode! receives the same Lua.t" do + {fun, lua} = Lua.encode!(Lua.new(), probe_callback()) + lua = Lua.set!(lua, [:probe], fun) + + assert {["Lua"], _} = Lua.eval!(lua, "return probe()") + end + + test "a function nested inside a table passed to set! receives the same Lua.t" do + lua = Lua.set!(Lua.new(), [:api], %{"probe" => probe_callback()}) + + assert {["Lua"], _} = Lua.eval!(lua, "return api.probe()") + end + + test "the same callback works identically via set! and via encode!" do + lua = Lua.put_private(Lua.new(), :secret, "from-private") + + lua = Lua.set!(lua, [:registered], documented_convention_callback()) + {encoded, lua} = Lua.encode!(lua, documented_convention_callback()) + lua = Lua.set!(lua, [:embedded], encoded) + + assert {["from-private", "from-private"], _} = + Lua.eval!(lua, "return registered(), embedded()") + end + end +end