Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 53 additions & 61 deletions lib/lua.ex
Original file line number Diff line number Diff line change
Expand Up @@ -303,67 +303,11 @@ defmodule Lua do
raise Lua.RuntimeException, "Lua.set!/3 cannot have empty keys"
end

def set!(%__MODULE__{} = lua, keys, func) when is_function(func, 1) do
def set!(%__MODULE__{} = lua, keys, func) when is_function(func, 1) or is_function(func, 2) do
keys = keys |> List.wrap() |> Enum.map(&to_lua_key/1)
{function_name, scope} = List.pop_at(keys, -1)

wrapped =
{:native_func,
fn args, state ->
return = List.wrap(func.(args))

if not Util.list_encoded?(return) do
{function_name, scope} = List.pop_at(keys, -1)

raise Lua.RuntimeException,
function: function_name,
scope: scope,
message: "deflua functions must return encoded data, got #{inspect(return)}"
end

{return, state}
end}

state = do_set_nested(lua.state, keys, wrapped)
%{lua | state: state}
end

def set!(%__MODULE__{} = lua, keys, func) when is_function(func, 2) do
keys = keys |> List.wrap() |> Enum.map(&to_lua_key/1)

wrapped =
{:native_func,
fn args, state ->
{function_name, scope} = List.pop_at(keys, -1)

case func.(args, wrap(state)) do
{:error, reason, %__MODULE__{} = returned_lua} ->
raise RuntimeError, value: reason, state: returned_lua.state

{value, %__MODULE__{} = lua} ->
value = List.wrap(value)

if not Util.list_encoded?(value) do
raise Lua.RuntimeException,
function: function_name,
scope: scope,
message: "deflua functions must return encoded data, got #{inspect(value)}"
end

{value, lua.state}

value ->
value = List.wrap(value)

if not Util.list_encoded?(value) do
raise Lua.RuntimeException,
function: function_name,
scope: scope,
message: "deflua functions must return encoded data, got #{inspect(value)}"
end

{value, state}
end
end}
wrapped = wrap_callback(func, function_name, scope)

state = do_set_nested(lua.state, keys, wrapped)
%{lua | state: state}
Expand All @@ -372,12 +316,13 @@ defmodule Lua do
def set!(%__MODULE__{} = lua, keys, value) do
keys = keys |> List.wrap() |> Enum.map(&to_lua_key/1)
value = Display.unwrap(value)
{function_name, scope} = List.pop_at(keys, -1)

{encoded, state} =
if Util.encoded?(value) do
{value, lua.state}
else
Value.encode(value, lua.state)
Value.encode(value, lua.state, &wrap_callback(&1, function_name, scope))
end

state = do_set_nested(state, keys, encoded)
Expand Down Expand Up @@ -935,7 +880,7 @@ defmodule Lua do
if Util.encoded?(value) do
{value, lua}
else
{encoded, state} = Value.encode(value, state)
{encoded, state} = Value.encode(value, state, &wrap_callback(&1, :anonymous, []))
{encoded, %{lua | state: state}}
end
rescue
Expand Down Expand Up @@ -1215,6 +1160,53 @@ defmodule Lua do

defp wrap(state), do: %__MODULE__{state: state}

# Builds the `{:native_func, closure}` for a user-supplied Elixir callback.
#
# Both arities speak the documented `t:Lua.t/0` convention at the VM
# boundary: the raw `Lua.VM.State` handed in by the executor is wrapped as a
# `%Lua{}` for the callback, the returned `%Lua{}` is unwrapped back to its
# raw `.state`, and returns are validated as encoded. Sharing one builder
# keeps `Lua.set!/3` at a path, functions nested inside a `set!` value, and
# `Lua.encode!/2` on the same convention. `function_name`/`scope` feed only
# the error message.
defp wrap_callback(func, function_name, scope) when is_function(func, 1) do
{:native_func,
fn args, state ->
return = List.wrap(func.(args))
validate_encoded!(return, function_name, scope)
{return, state}
end}
end

defp wrap_callback(func, function_name, scope) when is_function(func, 2) do
{:native_func,
fn args, state ->
case func.(args, wrap(state)) do
{:error, reason, %__MODULE__{} = returned_lua} ->
raise RuntimeError, value: reason, state: returned_lua.state

{value, %__MODULE__{} = lua} ->
value = List.wrap(value)
validate_encoded!(value, function_name, scope)
{value, lua.state}

value ->
value = List.wrap(value)
validate_encoded!(value, function_name, scope)
{value, state}
end
end}
end

defp validate_encoded!(value, function_name, scope) do
if not Util.list_encoded?(value) do
raise Lua.RuntimeException,
function: function_name,
scope: scope,
message: "deflua functions must return encoded data, got #{inspect(value)}"
end
end

defp to_lua_key(key) when is_atom(key), do: Atom.to_string(key)
defp to_lua_key(key) when is_binary(key), do: key
defp to_lua_key(key), do: key
Expand Down
43 changes: 24 additions & 19 deletions lib/lua/vm/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -227,41 +227,37 @@ defmodule Lua.VM.Value do

Returns `{encoded_value, state}` since encoding maps and lists allocates tables.
"""
@spec encode(term(), State.t()) :: {term(), State.t()}
def encode(nil, state), do: {nil, state}
def encode(value, state) when is_boolean(value), do: {value, state}
def encode(value, state) when is_number(value), do: {value, state}
def encode(value, state) when is_binary(value), do: {value, state}
def encode(value, state) when is_atom(value), do: {Atom.to_string(value), state}

def encode(fun, state) when is_function(fun, 2), do: {{:native_func, fun}, state}

def encode(fun, state) when is_function(fun, 1) do
wrapper = fn args, st -> {List.wrap(fun.(args)), st} end
{{:native_func, wrapper}, state}
end
@spec encode(term(), State.t(), (fun() -> term())) :: {term(), State.t()}
def encode(value, state, fun_wrapper \\ &default_fun_wrapper/1)
def encode(nil, state, _fun_wrapper), do: {nil, state}
def encode(value, state, _fun_wrapper) when is_boolean(value), do: {value, state}
def encode(value, state, _fun_wrapper) when is_number(value), do: {value, state}
def encode(value, state, _fun_wrapper) when is_binary(value), do: {value, state}
def encode(value, state, _fun_wrapper) when is_atom(value), do: {Atom.to_string(value), state}

def encode(fun, state, fun_wrapper) when is_function(fun, 1) or is_function(fun, 2), do: {fun_wrapper.(fun), state}

def encode({:userdata, value}, state) do
def encode({:userdata, value}, state, _fun_wrapper) do
State.alloc_userdata(state, value)
end

def encode(map, state) when is_map(map) do
def encode(map, state, fun_wrapper) when is_map(map) do
{data, state} =
Enum.reduce(map, {%{}, state}, fn {k, v}, {data, state} ->
key = if is_atom(k), do: Atom.to_string(k), else: k
{encoded_v, state} = encode(v, state)
{encoded_v, state} = encode(v, state, fun_wrapper)
{Map.put(data, key, encoded_v), state}
end)

State.alloc_table(state, data)
end

def encode(list, state) when is_list(list) do
def encode(list, state, fun_wrapper) when is_list(list) do
if keyword_list?(list) do
{data, state} =
Enum.reduce(list, {%{}, state}, fn {k, v}, {data, state} ->
key = Atom.to_string(k)
{encoded_v, state} = encode(v, state)
{encoded_v, state} = encode(v, state, fun_wrapper)
{Map.put(data, key, encoded_v), state}
end)

Expand All @@ -271,14 +267,23 @@ defmodule Lua.VM.Value do
list
|> Enum.with_index(1)
|> Enum.reduce({%{}, state}, fn {v, idx}, {data, state} ->
{encoded_v, state} = encode(v, state)
{encoded_v, state} = encode(v, state, fun_wrapper)
{Map.put(data, idx, encoded_v), state}
end)

State.alloc_table(state, data)
end
end

# Default function wrapper, preserving the raw-state calling convention for
# low-level callers. `Lua.set!/3` and `Lua.encode!/2` inject their own
# wrapper so callbacks receive the public `t:Lua.t/0` instead.
defp default_fun_wrapper(fun) when is_function(fun, 2), do: {:native_func, fun}

defp default_fun_wrapper(fun) when is_function(fun, 1) do
{:native_func, fn args, st -> {List.wrap(fun.(args)), st} end}
end

@doc """
Encodes a list of Elixir values, threading state through each encoding.

Expand Down
89 changes: 89 additions & 0 deletions test/lua/callback_state_asymmetry_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
defmodule Lua.CallbackStateAsymmetryTest do
@moduledoc """
A two-arity Elixir callback (`fn args, state -> {results, state} end`)
must receive the same `state` argument — and accept the same return
shape — regardless of how it entered the VM:

| How the callback enters the VM | `state` it receives |
| --------------------------------------------------------------- | ------------------- |
| `Lua.set!(lua, [:f], fun)` — function directly at the path | `t:Lua.t/0` |
| `deflua` + `Lua.load_api/3` | `t:Lua.t/0` |
| `Lua.set!(lua, [:t], %{"f" => fun})` — function inside a value | `t:Lua.t/0` |
| `Lua.encode!(lua, fun)` — closure handed to Lua at runtime | `t:Lua.t/0` |

Every entry point hands the callback the public `t:Lua.t/0`, so the public
API (`Lua.decode!/2`, `Lua.encode!/2`, `Lua.get_private!/2`, …) works and a
single closure written against the documented `Lua.set!/3` convention works
in every position.
"""
use ExUnit.Case, async: true

defmodule DefluaProbe do
@moduledoc false
use Lua.API, scope: "probe"

deflua whoami(), state do
{[inspect(state.__struct__)], state}
end
end

# Reports the struct name of the state the VM handed the callback.
defp probe_callback do
fn _args, state -> {[inspect(state.__struct__)], state} end
end

# A callback written exactly the way the `Lua.set!/3` docs teach: treat
# `state` as a `t:Lua.t/0`, use the public API on it, return it as received.
defp documented_convention_callback do
fn _args, state -> {[Lua.get_private!(state, :secret)], state} end
end

describe "callbacks that receive the public Lua.t (documented convention)" do
test "a function set! directly at a path receives Lua.t" do
lua = Lua.set!(Lua.new(), [:probe], probe_callback())

assert {["Lua"], _} = Lua.eval!(lua, "return probe()")
end

test "a deflua function loaded via load_api receives Lua.t" do
lua = Lua.load_api(Lua.new(), DefluaProbe)

assert {["Lua"], _} = Lua.eval!(lua, "return probe.whoami()")
end

test "a set!-at-path callback can use the public Lua API on its state" do
lua =
Lua.new()
|> Lua.put_private(:secret, "from-private")
|> Lua.set!([:fetch], documented_convention_callback())

assert {["from-private"], _} = Lua.eval!(lua, "return fetch()")
end
end

describe "callbacks that enter the VM as encoded values" do
test "a closure embedded via encode! receives the same Lua.t" do
{fun, lua} = Lua.encode!(Lua.new(), probe_callback())
lua = Lua.set!(lua, [:probe], fun)

assert {["Lua"], _} = Lua.eval!(lua, "return probe()")
end

test "a function nested inside a table passed to set! receives the same Lua.t" do
lua = Lua.set!(Lua.new(), [:api], %{"probe" => probe_callback()})

assert {["Lua"], _} = Lua.eval!(lua, "return api.probe()")
end

test "the same callback works identically via set! and via encode!" do
lua = Lua.put_private(Lua.new(), :secret, "from-private")

lua = Lua.set!(lua, [:registered], documented_convention_callback())
{encoded, lua} = Lua.encode!(lua, documented_convention_callback())
lua = Lua.set!(lua, [:embedded], encoded)

assert {["from-private", "from-private"], _} =
Lua.eval!(lua, "return registered(), embedded()")
end
end
end
Loading