Skip to content

Allow Tensor/View ops outside XLADispatchMode by delegating to env.dispatch()#98

Open
lk-chen wants to merge 1 commit into
google:mainfrom
lk-chen:lkchen/dispatch-outside-mode
Open

Allow Tensor/View ops outside XLADispatchMode by delegating to env.dispatch()#98
lk-chen wants to merge 1 commit into
google:mainfrom
lk-chen:lkchen/dispatch-outside-mode

Conversation

@lk-chen

@lk-chen lk-chen commented Jun 14, 2026

Copy link
Copy Markdown

Tensor.__torch_dispatch__ and View.__torch_dispatch__ previously raised AssertionError unconditionally for any op other than wait_tensor/prim.device. This caused TypeError when torchax tensors were used in arithmetic outside an active XLADispatchMode context — for example during vLLM AOT lower, where XLADispatchMode is not installed but torchax Tensors and Views appear as arguments to aten ops (e.g. Tensor + View from deepstack vision features).

Fix: before raising, scan args/kwargs for the first Tensor or View and delegate to its _env.dispatch(). env.dispatch() calls v2t_iso() which materializes any View arguments to Tensors, then t2j_iso() converts to JAX arrays and runs the op. This matches the behaviour already provided by XLADispatchMode when that mode is active.

The View.__torch_dispatch__ change handles the View-as-left-operand case. The Tensor.__torch_dispatch__ change is the primary fix: for Tensor + View, Python dispatches to Tensor first (left operand), so View.__torch_dispatch__ is never reached unless Tensor's dispatch succeeds or returns NotImplemented.

…spatch()

Tensor.__torch_dispatch__ and View.__torch_dispatch__ previously raised
AssertionError unconditionally for any op other than wait_tensor/prim.device.
This caused TypeError when torchax tensors were used in arithmetic outside an
active XLADispatchMode context — for example during vLLM AOT lower, where
XLADispatchMode is not installed but torchax Tensors and Views appear as
arguments to aten ops (e.g. Tensor + View from deepstack vision features).

Fix: before raising, scan args/kwargs for the first Tensor or View and
delegate to its _env.dispatch(). env.dispatch() calls v2t_iso() which
materialises any View arguments to Tensors, then t2j_iso() converts to JAX
arrays and runs the op. This matches the behaviour already provided by
XLADispatchMode when that mode is active.

The View.__torch_dispatch__ change handles the View-as-left-operand case.
The Tensor.__torch_dispatch__ change is the primary fix: for Tensor + View,
Python dispatches to Tensor first (left operand), so View.__torch_dispatch__
is never reached unless Tensor's dispatch succeeds or returns NotImplemented.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant