diff --git a/.github/actions/run-inference/action.yml b/.github/actions/run-inference/action.yml
index 314a8be5..fe3d574b 100644
--- a/.github/actions/run-inference/action.yml
+++ b/.github/actions/run-inference/action.yml
@@ -3,7 +3,7 @@ description: Run one llama-tornado inference pass and write the metrics + sideca
inputs:
backend:
- description: 'GPU backend (opencl or ptx)'
+ description: 'GPU backend (opencl, ptx, or cuda)'
required: true
model_file:
description: 'Model filename inside $MODELS_DIR (e.g. Llama-3.2-1B-Instruct-F16.gguf)'
diff --git a/.github/actions/setup-tornadovm/action.yml b/.github/actions/setup-tornadovm/action.yml
index 3b1c5070..01fc41ac 100644
--- a/.github/actions/setup-tornadovm/action.yml
+++ b/.github/actions/setup-tornadovm/action.yml
@@ -3,17 +3,29 @@ description: Build TornadoVM once per backend and reuse across runs via a local
inputs:
backend:
- description: 'TornadoVM backend to build (opencl or ptx)'
+ description: 'TornadoVM backend to build (opencl, ptx, or cuda)'
required: true
runs:
using: composite
steps:
+ - name: Determine TornadoVM branch
+   id: branch
+   shell: bash
+   env:
+     BACKEND: ${{ inputs.backend }}
+   run: |
+     # The CUDA backend currently lives on the cuda2 branch (TornadoVM PR #861)
+     # until it is merged to master; all other backends build from master.
+     # BACKEND comes in via env so the input is never spliced into the script.
+     if [ "$BACKEND" = "cuda" ]; then ref=cuda2; else ref=master; fi
+     echo "ref=$ref" >> "$GITHUB_OUTPUT"
+
- name: Get TornadoVM HEAD SHA
id: tornado_sha
shell: bash
run: |
- SHA=$(git ls-remote https://github.com/beehive-lab/TornadoVM HEAD | cut -f1)
+ SHA=$(git ls-remote https://github.com/beehive-lab/TornadoVM "refs/heads/${{ steps.branch.outputs.ref }}" | cut -f1)
echo "sha=$SHA" >> $GITHUB_OUTPUT
- name: Check local build sentinel
@@ -27,12 +39,12 @@ runs:
echo "up-to-date=false" >> $GITHUB_OUTPUT
fi
- - name: Clone TornadoVM master
+ - name: Clone TornadoVM
if: steps.sentinel.outputs.up-to-date != 'true'
shell: bash
run: |
rm -rf $TORNADO_ROOT
- git clone --depth 1 --branch master \
+ git clone --depth 1 --branch ${{ steps.branch.outputs.ref }} \
https://github.com/beehive-lab/TornadoVM.git \
$TORNADO_ROOT
diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml
index f60083cc..3b33059c 100644
--- a/.github/workflows/build-and-run.yml
+++ b/.github/workflows/build-and-run.yml
@@ -18,7 +18,7 @@ env:
jobs:
code-quality:
if: github.repository == 'beehive-lab/GPULlama3.java'
- runs-on: self-hosted
+ runs-on: [self-hosted]
timeout-minutes: 30
steps:
@@ -31,7 +31,7 @@ jobs:
# ./mvnw -T12C -Pspotless spotless:check
# Build: TornadoVM → GPULlama3 → Quarkus LangChain4j
- # max-parallel: 1 ensures the opencl and ptx variants run sequentially so
+ # max-parallel: 1 ensures the opencl, ptx and cuda variants run sequentially so
# there are no workspace conflicts between matrix jobs.
build:
if: github.repository == 'beehive-lab/GPULlama3.java'
@@ -45,6 +45,7 @@ jobs:
backend:
- name: opencl
- name: ptx
+ - name: cuda
steps:
- name: Checkout GPULlama3
@@ -99,6 +100,7 @@ jobs:
backend:
- name: opencl
- name: ptx
+ - name: cuda
steps:
- name: Checkout GPULlama3
@@ -523,6 +525,7 @@ jobs:
backend:
- name: opencl
- name: ptx
+ - name: cuda
steps:
- name: Checkout GPULlama3
diff --git a/README.md b/README.md
index 2e2db217..9cb3a0a9 100644
--- a/README.md
+++ b/README.md
@@ -66,7 +66,8 @@ GPULlama3ChatModel model = GPULlama3ChatModel.builder()
Ensure you have the following installed and configured:
- **Java 21**: Required for Vector API support & TornadoVM.
-- [TornadoVM](https://github.com/beehive-lab/TornadoVM) with OpenCL or PTX backends.
+- [TornadoVM](https://github.com/beehive-lab/TornadoVM) with OpenCL, PTX, or CUDA backends.
+ - The `--cuda` backend requires a TornadoVM build that includes the CUDA backend from [TornadoVM PR #861](https://github.com/beehive-lab/TornadoVM/pull/861). This project currently builds against TornadoVM `4.0.2-jdk21-dev`.
- GCC/G++ 13 or newer: Required to build and run TornadoVM native components.
### Install, Build, and Run
@@ -305,6 +306,12 @@ Run a model with a text prompt:
./llama-tornado --gpu --verbose-init --opencl --model beehive-llama-3.2-1b-instruct-fp16.gguf --prompt "Explain the benefits of GPU acceleration."
```
+Select a backend explicitly with `--opencl`, `--ptx` (NVIDIA), `--cuda` (NVIDIA), or `--metal` (Apple Silicon). For example, to run on the CUDA backend:
+
+```bash
+./llama-tornado --gpu --cuda --model beehive-llama-3.2-1b-instruct-fp16.gguf --prompt "Explain the benefits of GPU acceleration."
+```
+
#### GPU Execution (FP16 Model)
Enable GPU acceleration with Q8_0 quantization:
```bash
@@ -393,7 +400,7 @@ Supported command-line options include:
```bash
cmd ➜ llama-tornado --help
usage: llama-tornado [-h] --model MODEL_PATH [--prompt PROMPT] [-sp SYSTEM_PROMPT] [--temperature TEMPERATURE] [--top-p TOP_P] [--seed SEED] [-n MAX_TOKENS]
- [--stream STREAM] [--echo ECHO] [-i] [--instruct] [--gpu] [--opencl] [--ptx] [--gpu-memory GPU_MEMORY] [--heap-min HEAP_MIN] [--heap-max HEAP_MAX]
+ [--stream STREAM] [--echo ECHO] [-i] [--instruct] [--gpu] [--opencl] [--ptx] [--cuda] [--metal] [--gpu-memory GPU_MEMORY] [--heap-min HEAP_MIN] [--heap-max HEAP_MAX]
[--debug] [--profiler] [--profiler-dump-dir PROFILER_DUMP_DIR] [--print-bytecodes] [--print-threads] [--print-kernel] [--full-dump]
[--show-command] [--execute-after-show] [--opencl-flags OPENCL_FLAGS] [--max-wait-events MAX_WAIT_EVENTS] [--verbose]
@@ -424,7 +431,9 @@ Mode Selection:
Hardware Configuration:
--gpu Enable GPU acceleration (default: False)
--opencl Use OpenCL backend (default) (default: None)
- --ptx Use PTX/CUDA backend (default: None)
+ --ptx Use PTX backend (default: None)
+ --cuda Use CUDA backend (requires TornadoVM built with the CUDA backend) (default: None)
+ --metal Use Apple Metal backend (macOS only) (default: None)
--gpu-memory GPU_MEMORY
GPU memory allocation (default: 7GB)
--heap-min HEAP_MIN Minimum JVM heap size (default: 20g)
@@ -480,9 +489,9 @@ View TornadoVM's internal behavior:
- **Support for GGUF format models** with full FP16 and partial support for Q8_0 and Q4_0 quantization.
- **Instruction-following and chat modes** for various use cases.
- **Interactive CLI** with `--interactive` and `--instruct` modes.
- - **Flexible backend switching** - choose OpenCL or PTX at runtime (need to build TornadoVM with both enabled).
+ - **Flexible backend switching** - choose OpenCL, PTX, or CUDA at runtime (TornadoVM must be built with each backend you want to use).
- **Cross-platform compatibility**:
- - ✅ NVIDIA GPUs (OpenCL & PTX )
+ - ✅ NVIDIA GPUs (OpenCL, PTX & CUDA)
- ✅ Intel GPUs (OpenCL)
- ✅ Apple GPUs (OpenCL)
diff --git a/llama-tornado b/llama-tornado
index 1d6c3d23..78388295 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
llama-tornado: GPU-accelerated Java LLM runner with TornadoVM
-Run LLM models using either OpenCL or PTX backends.
+Run LLM models using OpenCL, PTX, CUDA, or Metal backends.
"""
import argparse
@@ -19,6 +19,7 @@ from enum import Enum
class Backend(Enum):
OPENCL = "opencl"
PTX = "ptx"
+ CUDA = "cuda"
METAL = "metal"
@@ -178,6 +179,14 @@ class LlamaRunner:
"ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.ptx",
]
)
+ elif args.backend == Backend.CUDA:
+ module_config.extend(
+ [
+ f"@{self.tornado_sdk}/etc/exportLists/cuda-exports",
+ "--add-modules",
+ "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.cuda",
+ ]
+ )
elif args.backend == Backend.METAL:
module_config.extend(
[
@@ -426,7 +435,14 @@ def create_parser() -> argparse.ArgumentParser:
dest="backend",
action="store_const",
const=Backend.PTX,
- help="Use PTX/CUDA backend",
+ help="Use PTX backend",
+ )
+ hw_group.add_argument(
+ "--cuda",
+ dest="backend",
+ action="store_const",
+ const=Backend.CUDA,
+ help="Use CUDA backend (requires TornadoVM built with the CUDA backend)",
)
hw_group.add_argument(
"--metal",
diff --git a/pom.xml b/pom.xml
index a83c7ecf..82e875e9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -39,9 +39,10 @@
0.4.0
- 4.0.1
+ 4.0.2
-jdk21
- ${tornadovm.base.version}${jdk.version.suffix}
+
+ ${tornadovm.base.version}${jdk.version.suffix}-dev
25
25
@@ -147,7 +148,8 @@
21
21
-jdk21
- ${tornadovm.base.version}${jdk.version.suffix}
+
+ ${tornadovm.base.version}${jdk.version.suffix}-dev