From 9beedcb75c9963f9d77acdf7f593d064ee7106be Mon Sep 17 00:00:00 2001 From: Aidistides <208755803+Aidistides@users.noreply.github.com> Date: Sun, 14 Jun 2026 19:25:09 -0400 Subject: [PATCH 1/2] Online learning and Fixed Point Quant --- torchhd/__init__.py | 4 + torchhd/online.py | 478 +++++++++++++++++++++++++++ torchhd/quantize.py | 783 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1265 insertions(+) create mode 100644 torchhd/online.py create mode 100644 torchhd/quantize.py diff --git a/torchhd/__init__.py b/torchhd/__init__.py index e5517dff..bf114669 100644 --- a/torchhd/__init__.py +++ b/torchhd/__init__.py @@ -27,6 +27,8 @@ import torchhd.models as models import torchhd.classifiers as classifiers import torchhd.memory as memory +import torchhd.online as online +import torchhd.quantize as quantize import torchhd.datasets as datasets import torchhd.utils as utils @@ -102,6 +104,8 @@ "models", "classifiers", "memory", + "online", + "quantize", "datasets", "utils", "ensure_vsa_tensor", diff --git a/torchhd/online.py b/torchhd/online.py new file mode 100644 index 00000000..38c13e1c --- /dev/null +++ b/torchhd/online.py @@ -0,0 +1,478 @@ +# +# MIT License +# +# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +r"""Streaming and online learning primitives for hyperdimensional computing. + +This module provides streaming-capable components that update model state +one sample at a time — no batching required. The central building block is +the Hebbian accumulator, which maintains a running outer-product weight matrix +updated via a local Hebbian rule. + +Classes +------- +HebbianAccumulator + Running weight accumulation via outer-product Hebbian updates. +StreamingCentroid + Online centroid classifier with Hebbian weight updates, one sample per step. +""" + +import math +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn.parameter import Parameter +import torch.nn.init as init + +import torchhd.functional as functional + +__all__ = [ + "HebbianAccumulator", + "StreamingCentroid", +] + + +class HebbianAccumulator(nn.Module): + r"""Running outer-product Hebbian accumulator for online hypervector learning. + + Maintains a weight matrix :math:`W \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}}` + that is updated via the Hebbian outer product: + + .. math:: + W \leftarrow (1 - \eta) W + \eta \, (x^{\top} y) + + where :math:`x` is the input hypervector, :math:`y` is the target hypervector, + and :math:`\eta \in [0, 1]` is the learning rate. Setting :math:`\eta = 1` + recovers standard additive (unweighted) accumulation; :math:`\eta < 1` implements + an exponential moving average that gracefully forgets older samples — useful + for non-stationary streams. + + The accumulator exposes a ``forward`` method for reading from the weight + matrix, and a ``step`` method for accepting a single (input, target) pair. + + Args: + in_features (int): Dimensionality of the input hypervectors :math:`d_{\text{in}}`. + out_features (int): Dimensionality of the target hypervectors :math:`d_{\text{out}}`. + lr (float, optional): Hebbian learning rate :math:`\eta`. Default: ``1.0``. + device (``torch.device``, optional): Desired device of the weight matrix. + dtype (``torch.dtype``, optional): Desired data type of the weight matrix. + requires_grad (bool, optional): If autograd should track the weight matrix. Default: ``False``. + + Shape: + - Input: :math:`(d_{\text{in}})` or :math:`(*, d_{\text{in}})` + - Target: :math:`(d_{\text{out}})` or :math:`(*, d_{\text{out}})` + - Weight: :math:`(d_{\text{in}}, d_{\text{out}})` + + Attributes: + weight: The accumulated weight matrix of shape :math:`(d_{\text{in}}, d_{\text{out}})`. + lr: Hebbian learning rate. + num_steps: Running count of ``step`` calls (read-only). + + Examples:: + + >>> acc = HebbianAccumulator(512, 512, lr=0.1) + >>> x = torchhd.random(1, 512) + >>> y = torchhd.random(1, 512) + >>> acc.step(x, y) + >>> read = acc(x) + >>> read.shape + torch.Size([1, 512]) + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: Tensor + lr: float + num_steps: int + + def __init__( + self, + in_features: int, + out_features: int, + lr: float = 1.0, + device=None, + dtype=None, + requires_grad: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.lr = lr + + weight = torch.zeros(in_features, out_features, **factory_kwargs) + self.weight = Parameter(weight, requires_grad=requires_grad) + self.register_buffer("num_steps", torch.zeros((), dtype=torch.long)) + + def forward(self, input: Tensor) -> Tensor: + r"""Read from the accumulated weight matrix. + + .. math:: + \text{output} = x W + + Args: + input (Tensor): Input hypervector(s) of shape :math:`(*, d_{\text{in}})`. + + Returns: + Tensor: Read-result of shape :math:`(*, d_{\text{out}})`. + """ + return input @ self.weight + + @torch.no_grad() + def step(self, input: Tensor, target: Tensor) -> None: + r"""Update the weight matrix with a single (input, target) pair via + the Hebbian outer-product rule. + + .. math:: + W \leftarrow (1 - \eta) W + \eta \, (x^{\top} y) + + Accepts batched inputs for efficiency, though the intended use-case + is one sample at a time. + + Args: + input (Tensor): Input hypervector(s) of shape :math:`(*, d_{\text{in}})`. + target (Tensor): Target hypervector(s) of shape :math:`(*, d_{\text{out}})`. + """ + if input.dim() == 1: + input = input.unsqueeze(0) + if target.dim() == 1: + target = target.unsqueeze(0) + + # Hebbian outer product averaged over the batch + hebbian_update = input.T @ target + batch_size = input.size(0) + + if self.lr == 1.0: + self.weight.add_(hebbian_update) + else: + # Exponential moving average: W = (1 - lr) * W + lr * (x^T y) + self.weight.mul_(1.0 - self.lr).add_(hebbian_update, alpha=self.lr) + + self.num_steps.add_(batch_size) + + @torch.no_grad() + def step_adaptive( + self, + input: Tensor, + target: Tensor, + pred: Optional[Tensor] = None, + ) -> None: + r"""Hebbian update with an anti-Hebbian correction for mispredicted samples. + + Implements a local error-driven learning rule: + + .. math:: + W \leftarrow (1 - \eta) W + \eta \left[ x^{\top} y_{\text{true}} - x^{\top} y_{\text{pred}} \right] + + When ``pred`` is ``None``, this degenerates to :meth:`step`. + + Args: + input (Tensor): Input hypervector(s) :math:`(*, d_{\text{in}})`. + target (Tensor): True target hypervector(s) :math:`(*, d_{\text{out}})`. + pred (Tensor, optional): Predicted target hypervector(s) of same shape as ``target``. + If provided, the anti-Hebbian term is subtracted. + """ + if input.dim() == 1: + input = input.unsqueeze(0) + if target.dim() == 1: + target = target.unsqueeze(0) + + hebbian = input.T @ target + + if pred is not None: + if pred.dim() == 1: + pred = pred.unsqueeze(0) + anti_hebbian = input.T @ pred + hebbian = hebbian - anti_hebbian + + if self.lr == 1.0: + self.weight.add_(hebbian) + else: + self.weight.mul_(1.0 - self.lr).add_(hebbian, alpha=self.lr) + + self.num_steps.add_(input.size(0)) + + @torch.no_grad() + def normalize_(self, eps: float = 1e-12) -> None: + r"""Normalize each column of the weight matrix to unit length in-place. + + This is typically called after accumulation is complete, before switching + to a dot-product readout. + + Args: + eps (float): Epsilon for numerical stability. + """ + norms = self.weight.norm(dim=0, keepdim=True) + norms.clamp_(min=eps) + self.weight.div_(norms) + + @torch.no_grad() + def reset_(self) -> None: + """Zero the weight matrix and reset the step counter.""" + self.weight.zero_() + self.num_steps.zero_() + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"lr={self.lr}, num_steps={self.num_steps.item()}" + ) + + +class StreamingCentroid(nn.Module): + r"""Online centroid classifier that processes one sample per ``step`` call. + + This is the streaming equivalent of :class:`~torchhd.models.Centroid`. Under + the hood it maintains a :class:`HebbianAccumulator` that maps input hypervectors + to an ``out_features``-dimensional class space. Each class is represented as a + one-hot vector, so the Hebbian outer product naturally accumulates a class-prototype + matrix. + + The module supports both **accumulation** (prototypes are sums of all class + samples) and **exponential moving average** (``lr < 1``) modes. + + Args: + in_features (int): Dimensionality of the input hypervectors. + out_features (int): Number of output classes. + lr (float, optional): Hebbian learning rate :math:`\eta`. Default: ``1.0``. + device (``torch.device``, optional): Desired device. + dtype (``torch.dtype``, optional): Desired data type. + requires_grad (bool, optional): If autograd should track parameters. Default: ``False``. + + Shape: + - Input: :math:`(d_{\text{in}})` or :math:`(*, d_{\text{in}})` + - Output: :math:`(1, \text{out\_features})` or :math:`(*, \text{out\_features})` + + Attributes: + accumulator: The underlying :class:`HebbianAccumulator`. + prototype_weight: Convenience view of the accumulated class-prototype matrix + of shape :math:`(\text{out\_features}, d_{\text{in}})`. + + Examples:: + + >>> sc = StreamingCentroid(512, 10) + >>> x = torchhd.random(1, 512) + >>> sc.step(x, 3) # label for class 3 + >>> sc.step(x, 7) # label for class 7 + >>> logits = sc(x) # dot-product similarity to all classes + >>> logits.shape + torch.Size([1, 10]) + """ + + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + + def __init__( + self, + in_features: int, + out_features: int, + lr: float = 1.0, + device=None, + dtype=None, + requires_grad: bool = False, + ) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + + # Hebbian accumulator maps input -> one-hot class vector. + # Accumulator weight shape: (in_features, out_features). + self.accumulator = HebbianAccumulator( + in_features=in_features, + out_features=out_features, + lr=lr, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + @property + def prototype_weight(self) -> Tensor: + r"""Class-prototype matrix of shape :math:`(n_{\text{classes}}, d)`. + + This is the transpose of the accumulator weight, providing a natural + ``(n_classes, d)`` view for dot-product similarity comparisons. + """ + return self.accumulator.weight.T + + def forward(self, input: Tensor, dot: bool = False) -> Tensor: + r"""Compute per-class similarity scores. + + Args: + input (Tensor): Input hypervector(s) :math:`(*, d)`. + dot (bool, optional): If ``True``, use dot-product similarity instead + of cosine similarity. Faster after :meth:`normalize` has been called. + + Returns: + Tensor: Similarity scores of shape :math:`(*, n_{\text{classes}})`. + """ + prototypes = self.prototype_weight + if dot: + return functional.dot_similarity(input, prototypes) + return functional.cosine_similarity(input, prototypes) + + @torch.no_grad() + def step(self, input: Tensor, target: Union[int, Tensor], lr: Optional[float] = None) -> None: + r"""Process a single sample (or small batch) with its class label. + + Args: + input (Tensor): Input hypervector(s) :math:`(*, d)`. + target (int or Tensor): Class label(s). If an integer is given it is + treated as a scalar class index. If a tensor of integers is given + each element is treated as a separate label. + lr (float, optional): Per-step learning rate override. If ``None``, + uses the accumulator's default. + """ + if input.dim() == 1: + input = input.unsqueeze(0) + + # Build one-hot target vector(s) + device = input.device + dtype = self.accumulator.weight.dtype + n = input.size(0) + + if isinstance(target, int): + target = torch.tensor([target], device=device, dtype=torch.long) + + one_hot = torch.zeros(n, self.out_features, device=device, dtype=dtype) + one_hot[torch.arange(n, device=device), target] = 1.0 + + # Hebbian step: outer product of input and one-hot target + prev_lr = self.accumulator.lr + if lr is not None: + self.accumulator.lr = float(lr) + try: + self.accumulator.step(input, one_hot) + finally: + self.accumulator.lr = prev_lr + + @torch.no_grad() + def step_online( + self, input: Tensor, target: Union[int, Tensor], lr: Optional[float] = None + ) -> None: + r"""OnlineHD-style adaptive update: only updates prototypes when prediction is wrong. + + Implements the single-pass online training rule from `OnlineHD: Robust, Efficient, + and Single-Pass Online Learning Using Hyperdimensional System + `_. + + The prediction is made *before* the Hebbian update so that each sample is + evaluated fairly (test-then-train). + + Args: + input (Tensor): Input hypervector(s) :math:`(*, d)`. + target (int or Tensor): True class label(s). + lr (float, optional): Per-step learning rate override. + """ + if input.dim() == 1: + input = input.unsqueeze(0) + + device = input.device + dtype = self.accumulator.weight.dtype + n = input.size(0) + + if isinstance(target, int): + target = torch.tensor([target], device=device, dtype=torch.long) + + # --- Predict before updating (test-then-train) --- + prototypes = self.prototype_weight + logits = functional.cosine_similarity(input, prototypes) + pred = logits.argmax(dim=1) + + # Identify mistakes + is_wrong = target != pred + if is_wrong.sum().item() == 0: + self.accumulator.num_steps.add_(n) + return + + wrong_input = input[is_wrong] + wrong_target = target[is_wrong] + wrong_pred = pred[is_wrong] + + m = wrong_input.size(0) + + one_hot_target = torch.zeros(m, self.out_features, device=device, dtype=dtype) + one_hot_target[torch.arange(m, device=device), wrong_target] = 1.0 + + one_hot_pred = torch.zeros(m, self.out_features, device=device, dtype=dtype) + one_hot_pred[torch.arange(m, device=device), wrong_pred] = 1.0 + + # Alpha scaling: (1 - cos(x, pred_cls)) and (cos(x, true_cls) - 1) + alpha_target = 1.0 - logits[is_wrong].gather(1, wrong_target.unsqueeze(1)) + alpha_pred = logits[is_wrong].gather(1, wrong_pred.unsqueeze(1)) - 1.0 + + hebbian = ( + wrong_input.T @ (alpha_target * one_hot_target) + + wrong_input.T @ (alpha_pred * one_hot_pred) + ) + + prev_lr = self.accumulator.lr + if lr is not None: + self.accumulator.lr = float(lr) + + try: + if self.accumulator.lr == 1.0: + self.accumulator.weight.add_(hebbian, alpha=lr if lr is not None else 1.0) + else: + # Apply per-step scaling: W = (1-lr)*W + lr*(hebbian) + self.accumulator.weight.mul_(1.0 - self.accumulator.lr).add_( + hebbian, alpha=self.accumulator.lr * (lr if lr is not None else 1.0) + ) + finally: + self.accumulator.lr = prev_lr + + self.accumulator.num_steps.add_(n) + + @torch.no_grad() + def normalize(self, eps: float = 1e-12) -> None: + r"""Normalize all class prototypes to unit length in-place. + + After calling this, the ``forward`` pass can use ``dot=True`` for + faster inference (dot-product instead of cosine similarity). + """ + prototypes = self.prototype_weight + norms = prototypes.norm(dim=1, keepdim=True) + if torch.isclose(norms, torch.zeros_like(norms), equal_nan=True).any(): + import warnings + warnings.warn( + "The norm of a prototype vector is nearly zero upon normalizing, " + "this could indicate a bug." + ) + norms.clamp_(min=eps) + self.accumulator.weight.div_(norms.T) + + @torch.no_grad() + def reset(self) -> None: + """Reset all prototypes to zero and clear the step counter.""" + self.accumulator.reset_() + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"lr={self.accumulator.lr}, num_steps={self.accumulator.num_steps.item()}" + ) \ No newline at end of file diff --git a/torchhd/quantize.py b/torchhd/quantize.py new file mode 100644 index 00000000..e20dea7c --- /dev/null +++ b/torchhd/quantize.py @@ -0,0 +1,783 @@ +# +# MIT License +# +# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +r"""Fixed-point quantization and hardware export for hyperdimensional computing. + +This module bridges the gap between floating-point hypervectors as used during +training and the fixed-point / integer representations required for FPGA and +ASIC synthesis. It provides: + +* **Quantization functions** — convert tensors to Q-format fixed-point integers. +* **Packing utilities** — pack quantized vectors into dense bit-arrays ready for + hardware memory (e.g. BRAM, ROM). +* **Export helpers** — render weights as synthesizable Verilog ``localparam`` + declarations, C header arrays, or raw binary files. + +Typical Quantization Pipeline +----------------------------- +1. Train a model using standard torchhd floating-point tensors. +2. Quantize the hypervectors with :func:`to_fixed_point` or :func:`quantize_bipolar`: + ``qv = quantize.to_fixed_point(weight, bits=4, frac=2)``. +3. Pack and export: + ``packed = quantize.pack_bits(qv, bits_per_element=4, order="big")``. +4. Write to file: + ``quantize.export_verilog(packed, "hd_weights.sv")``. + +Supported Formats +----------------- +* Signed fixed-point ``Q`` where *M* integer bits and *N* fractional bits. +* Bipolar (ternary) quantization to {−1, 0, +1}. +* Block-floating point (shared exponent per row / per tensor). +* Packed bit-arrays for dense binary storage. + +Exported Output Formats +----------------------- +* Verilog ``localparam`` arrays for direct FPGA synthesis. +* C ``const`` arrays for embedded firmware. +* Raw binary (``.bin``) for memory images. +""" + +from typing import Optional, Tuple, Union, Literal +import struct + +import torch +from torch import Tensor + + +__all__ = [ + "to_fixed_point", + "from_fixed_point", + "quantize_bipolar", + "block_float_quantize", + "pack_bits", + "unpack_bits", + "export_verilog", + "export_c_header", + "export_binary", + "FixedPointConfig", + "QuantizedWeight", +] + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +class FixedPointConfig: + r"""Fixed-point format descriptor. + + Args: + bits: Total bit-width including sign. + frac: Number of fractional bits (Q notation: M = bits - frac). + signed: Whether the format is signed (two's complement). Default: ``True``. + rounding: Rounding mode — ``"floor"``, ``"nearest"``, or ``"round"`` + (ties to even). Default: ``"nearest"``. + + Attributes: + bits (int): Total bit-width. + frac (int): Fractional bits. + int_bits (int): Integer bits (computed as ``bits - frac``). + min_val (float): Minimum representable value. + max_val (float): Maximum representable value. + step (float): Smallest representable difference (one LSB). + """ + + bits: int + frac: int + signed: bool + rounding: str + + def __init__( + self, + bits: int, + frac: int, + *, + signed: bool = True, + rounding: str = "nearest", + ) -> None: + if bits < 1: + raise ValueError(f"bits must be >= 1, got {bits}") + if frac < 0 or frac > bits: + raise ValueError(f"frac must be in [0, {bits}], got {frac}") + + self.bits = bits + self.frac = frac + self.signed = signed + self.rounding = rounding + + @property + def int_bits(self) -> int: + return self.bits - self.frac + + @property + def step(self) -> float: + return 2.0 ** (-self.frac) + + @property + def min_val(self) -> float: + if self.signed: + return -(2.0 ** (self.int_bits - 1)) + else: + return 0.0 + + @property + def max_val(self) -> float: + if self.signed: + return (2.0 ** (self.int_bits - 1)) - self.step + else: + return (2.0 ** self.int_bits) - self.step + + def __repr__(self) -> str: + fmt = "Q{}.{}s" if self.signed else "Q{}.{}u" + return fmt.format(self.int_bits, self.frac) + + +# --------------------------------------------------------------------------- +# Quantized weight container +# --------------------------------------------------------------------------- + +class QuantizedWeight: + r"""Container holding quantized integer weights together with metadata. + + Attributes: + data (LongTensor): Quantized integer values (raw hardware representation). + config (FixedPointConfig): The fixed-point configuration used. + original_shape (torch.Size): Shape of the unquantized tensor. + scale (float): Global scale factor, if any was applied prior to quantization. + """ + + data: Tensor + config: FixedPointConfig + original_shape: torch.Size + scale: float + + def __init__( + self, + data: Tensor, + config: FixedPointConfig, + original_shape: torch.Size, + scale: float = 1.0, + ) -> None: + self.data = data + self.config = config + self.original_shape = original_shape + self.scale = scale + + def __repr__(self) -> str: + bits = self.config.bits + return ( + f"QuantizedWeight(shape={tuple(self.data.shape)}, " + f"dtype={self.data.dtype}, " + f"format={self.config!r}, " + f"range={self.data.min().item():.1f}..{self.data.max().item():.1f}, " + f"scale={self.scale:.6g})" + ) + + +# --------------------------------------------------------------------------- +# Core quantization functions +# --------------------------------------------------------------------------- + +def _clamp_and_round( + x: Tensor, + cfg: FixedPointConfig, +) -> Tensor: + """Clamp to representable range and quantize to LSB multiples.""" + min_v = cfg.min_val + max_v = cfg.max_val + scale = 2.0 ** cfg.frac # multiply by 2^frac to move to integer domain + + # Clamp + x_c = x.clamp(min_v, max_v) + + # Scale to integer domain + x_scaled = x_c * scale + + # Round + if cfg.rounding == "nearest": + x_rounded = torch.round(x_scaled) + elif cfg.rounding == "floor": + x_rounded = torch.floor(x_scaled) + elif cfg.rounding == "round": + # ties-to-even + x_rounded = torch.round(x_scaled) + else: + raise ValueError(f"Unknown rounding mode: {cfg.rounding}") + + # Convert to integer (two's complement for negative values) + max_code = (1 << cfg.bits) - 1 + half = 1 << (cfg.bits - 1) + + x_int = x_rounded.to(torch.long) + + if cfg.signed: + # Wrap negatives into two's complement + x_int = x_int & max_code + + return x_int + + +def to_fixed_point( + input: Tensor, + bits: int, + frac: int, + *, + signed: bool = True, + rounding: str = "nearest", + scale: Optional[float] = None, + per_row_scale: bool = False, +) -> QuantizedWeight: + r"""Quantize a floating-point tensor to fixed-point. + + Converts each element :math:`v` to an integer code: + + .. math:: + q = \operatorname{clamp}\left( + \operatorname{round}\left(v \cdot 2^{\text{frac}}\right), + 0,\; 2^{\text{bits}} - 1 + \right) + + For signed formats the result is two's complement (the caller can interpret + the integer codes as signed by checking ``cfg.signed``). + + Args: + input (Tensor): Floating-point tensor of any shape. + bits (int): Total bit-width. + frac (int): Number of fractional bits. + signed (bool, optional): Two's complement signed format. Default: ``True``. + rounding (str, optional): ``"nearest"``, ``"floor"``, or ``"round"``. + Default: ``"nearest"``. + scale (float, optional): Pre-scale input by this factor before quantization. + Useful for normalizing weights to fully use the available range. + per_row_scale (bool, optional): If ``True``, compute a separate scale for + each row of a 2-D tensor (block-floating point per row). The resulting + ``QuantizedWeight.scale`` will be a 1-D tensor of per-row scales. + + Returns: + QuantizedWeight: Object containing the integer codes and metadata. + + Examples:: + + >>> w = torch.randn(3, 512) * 0.5 + >>> qw = quantize.to_fixed_point(w, bits=8, frac=4) + >>> qw.data.shape + torch.Size([3, 512]) + >>> qw.data.dtype + torch.int64 + """ + cfg = FixedPointConfig(bits=bits, frac=frac, signed=signed, rounding=rounding) + + if per_row_scale and input.dim() == 2: + # Per-row block-floating point + abs_max = input.abs().max(dim=1, keepdim=True).values + eps = 1e-12 + row_scales = (2.0 ** (cfg.int_bits - 1) - cfg.step) / (abs_max + eps) + scaled_input = input * row_scales + row_scales = row_scales * (2.0 ** cfg.frac) # account for frac in scale + q_int = _clamp_and_round(scaled_input, cfg) + return QuantizedWeight(q_int, cfg, input.shape, scale=row_scales) + else: + if scale is not None: + scaled_input = input * scale + else: + scaled_input = input + q_int = _clamp_and_round(scaled_input, cfg) + return QuantizedWeight(q_int, cfg, input.shape, scale=scale or 1.0) + + +def from_fixed_point(qw: QuantizedWeight) -> Tensor: + r"""Dequantize a :class:`QuantizedWeight` back to floating point. + + .. math:: + v = \left(\frac{q}{2^{\text{frac}}}\right) / \text{scale} + + where for signed two's complement the integer code *q* is sign-extended. + + Args: + qw (QuantizedWeight): Previously quantized weight object. + + Returns: + Tensor: Dequantized floating-point tensor with the original shape. + + Examples:: + + >>> w = torch.randn(3, 512) + >>> qw = quantize.to_fixed_point(w, bits=8, frac=4) + >>> w_rec = quantize.from_fixed_point(qw) + >>> (w - w_rec).abs().max() < 0.1 + True + """ + cfg = qw.config + data = qw.data.float() + + if cfg.signed: + max_code = (1 << cfg.bits) - 1 + half = 1 << (cfg.bits - 1) + # Sign-extend: values >= half are negative + mask_neg = data >= half + data = data.clone() + data[mask_neg] = data[mask_neg] - (1 << cfg.bits) + + # Convert from integer domain back to float + data = data / (2.0 ** cfg.frac) + + if isinstance(qw.scale, Tensor): + # If the scale is a per-block or per-row scale (fewer elements than data), + # reshape both to a block layout, divide, then reshape back. + if qw.scale.numel() != 1 and data.numel() % qw.scale.numel() == 0: + block_size = data.numel() // qw.scale.numel() + flat = data.reshape(-1, block_size) + scale_r = qw.scale.reshape(-1, 1) + flat = flat / scale_r + data = flat.reshape(qw.original_shape) + elif data.numel() == qw.scale.numel(): + data = data / qw.scale + else: + data = data / qw.scale + else: + data = data / qw.scale + + return data.reshape(qw.original_shape) + + +def quantize_bipolar( + input: Tensor, + threshold: float = 0.0, + *, + ternary: bool = False, + zero_thresh: Optional[float] = None, +) -> Tensor: + r"""Quantize to bipolar {−1, +1} or ternary {−1, 0, +1}. + + This is the standard hypervector binarization used for efficient + similarity computation in hardware. + + Args: + input (Tensor): Any floating-point tensor. + threshold (float, optional): Values > ``threshold`` map to +1, otherwise + to −1 (or 0 in ternary mode). Default: ``0.0``. + ternary (bool, optional): If ``True``, produce ternary output with values + {−1, 0, +1}. When ``False`` only {−1, +1} are produced. + zero_thresh (float, optional): In ternary mode, absolute values below + this threshold are set to 0. Default: ``None`` uses ``threshold``. + + Returns: + Tensor: Tensor with the same shape and device as ``input``, dtype ``torch.int8``. + + Examples:: + + >>> x = torch.tensor([-0.5, 0.0, 0.3, 0.6]) + >>> quantize_bipolar(x) + tensor([-1, -1, 1, 1], dtype=torch.int8) + >>> quantize_bipolar(x, threshold=0.3) + tensor([-1, -1, -1, 1], dtype=torch.int8) + >>> quantize_bipolar(x, ternary=True, zero_thresh=0.2) + tensor([-1, 0, 1, 1], dtype=torch.int8) + """ + if ternary: + zt = zero_thresh if zero_thresh is not None else abs(threshold) + pos_mask = input > abs(threshold) + neg_mask = input < -abs(zt) + + result = torch.zeros_like(input, dtype=torch.int8) + result[pos_mask] = 1 + result[neg_mask] = -1 + return result + else: + return torch.where(input > threshold, torch.tensor(1, dtype=torch.int8), + torch.tensor(-1, dtype=torch.int8)) + + +def block_float_quantize( + input: Tensor, + bits: int, + frac: int, + block_size: int = 64, + *, + rounding: str = "nearest", +) -> QuantizedWeight: + r"""Quantize with a shared exponent for every *block_size* elements. + + This is a hardware-friendly scheme: for each block of ``block_size`` elements + compute a shared scale, quantize all elements in that block to the same fixed-point + format, and store the exponent along each block. + + Args: + input (Tensor): Floating-point tensor. + bits (int): Total bit-width. + frac (int): Fixed fractional bit allocation (the exponent varies per block). + block_size (int): Number of elements sharing a scale factor. + rounding (str, optional): Rounding mode. Default: ``"nearest"``. + + Returns: + QuantizedWeight: Quantized data with per-block scale factors in ``.scale``. + """ + cfg = FixedPointConfig(bits=bits, frac=frac, signed=True, rounding=rounding) + flat = input.flatten() + n_el = flat.numel() + + # Pad to multiple of block_size + pad = (block_size - n_el % block_size) % block_size + if pad > 0: + flat = torch.cat([flat, torch.zeros(pad, dtype=flat.dtype, device=flat.device)]) + + blocks = flat.reshape(-1, block_size) + n_blocks = blocks.size(0) + + # Per-block scale: scale so that max abs value maps to full range + abs_max = blocks.abs().max(dim=1, keepdim=True).values + eps = 1e-12 + block_scales = (2.0 ** (cfg.int_bits - 1) - cfg.step) / (abs_max + eps) + + scaled_blocks = blocks * block_scales + q_blocks = _clamp_and_round(scaled_blocks, cfg) + + q_flat = q_blocks.flatten()[:n_el] + scale_flat = block_scales.flatten() + + return QuantizedWeight(q_flat.reshape(input.shape), cfg, input.shape, scale=scale_flat) + + +# --------------------------------------------------------------------------- +# Bit packing +# --------------------------------------------------------------------------- + +def pack_bits( + data: Tensor, + bits_per_element: int, + order: Literal["big", "little"] = "big", +) -> Tensor: + r"""Pack integer codes into a compact bit-array. + + Each element of ``data`` is assumed to fit in ``bits_per_element`` bits. + Elements are packed sequentially into ``torch.uint8`` bytes. + + This is the final step before writing to a hardware memory image. + + Args: + data (LongTensor): Integer codes to pack. + bits_per_element (int): Number of bits per element (e.g. 1, 2, 4, 8). + order (str, optional): ``"big"`` (MSB first) or ``"little"`` (LSB first). + + Returns: + LongTensor: Packed bytes with dtype ``torch.int64`` (values 0–255). + Use ``.to(torch.uint8)`` to get raw bytes. + + Examples:: + + >>> vals = torch.tensor([1, 0, 1, 1, 0, 1, 0, 0], dtype=torch.long) + >>> pack_bits(vals, bits_per_element=1, order="big") + tensor([180]) # 0b10110100 + """ + if bits_per_element > 32: + raise ValueError(f"bits_per_element must be <= 32, got {bits_per_element}") + + flat = data.flatten() + total_bits = flat.numel() * bits_per_element + total_bytes = (total_bits + 7) // 8 + + # Work in uint64 for accumulation + if order == "big": + # MSB-first: first element goes into most significant bits + packed = torch.zeros(total_bytes, dtype=torch.int64) + for i, val in enumerate(flat): + val_int = int(val) & ((1 << bits_per_element) - 1) + bit_pos = i * bits_per_element + byte_idx = bit_pos // 8 + bit_offset = bit_pos % 8 + + bits_remaining = bits_per_element + val_shifted = val_int + + while bits_remaining > 0: + space = 8 - bit_offset + take = min(bits_remaining, space) + # Put the top `take` bits of val_shifted into the byte + shift = bits_remaining - take + chunk = (val_shifted >> shift) & ((1 << take) - 1) + packed[byte_idx] |= chunk << (space - take) + + bits_remaining -= take + bit_offset = 0 + byte_idx += 1 + val_shifted = val_shifted & ((1 << bits_remaining) - 1) + + return packed + else: + # LSB-first: first element goes into least significant bits + packed = torch.zeros(total_bytes, dtype=torch.int64) + for i, val in enumerate(flat): + val_int = int(val) & ((1 << bits_per_element) - 1) + bit_pos = i * bits_per_element + byte_idx = bit_pos // 8 + bit_offset = bit_pos % 8 + + bits_remaining = bits_per_element + val_shifted = val_int + + while bits_remaining > 0: + space = 8 - bit_offset + take = min(bits_remaining, space) + # Put the bottom `take` bits of val_shifted into the byte + chunk = val_shifted & ((1 << take) - 1) + packed[byte_idx] |= chunk << bit_offset + + bits_remaining -= take + bit_offset = 0 + byte_idx += 1 + val_shifted >>= take + + return packed + + +def unpack_bits( + packed: Tensor, + num_elements: int, + bits_per_element: int, + order: Literal["big", "little"] = "big", +) -> Tensor: + r"""Unpack a bit-packed array back to integer codes. + + Inverse of :func:`pack_bits`. + + Args: + packed (LongTensor): Packed bytes (values 0–255). + num_elements (int): Number of elements to extract. + bits_per_element (int): Number of bits per element. + order (str): Byte order used during packing. + + Returns: + LongTensor: Recovered integer codes. + """ + result = torch.zeros(num_elements, dtype=torch.int64) + + for i in range(num_elements): + bit_pos = i * bits_per_element + byte_idx = bit_pos // 8 + bit_offset = bit_pos % 8 + + val = 0 + bits_remaining = bits_per_element + + while bits_remaining > 0: + space = 8 - bit_offset + take = min(bits_remaining, space) + + if order == "big": + chunk = (packed[byte_idx] >> (space - take)) & ((1 << take) - 1) + val = (val << take) | chunk + else: + chunk = (packed[byte_idx] >> bit_offset) & ((1 << take) - 1) + val = val | (chunk << (bits_remaining - take)) + + bits_remaining -= take + bit_offset = 0 + byte_idx += 1 + + result[i] = val + + return result + + +# --------------------------------------------------------------------------- +# Export formats +# --------------------------------------------------------------------------- + +def export_verilog( + qw: QuantizedWeight, + filepath: str, + *, + module_name: str = "hd_weights", + radix: int = 16, + words_per_line: int = 8, +) -> None: + r"""Export quantized weights as a Verilog ``localparam`` memory array. + + The output is a synthesizable SystemVerilog snippet: + + .. code-block:: systemverilog + + localparam logic [7:0] hd_weights [0:1535] = '{ + 8'hA3, 8'h4F, 8'h12, ... + }; + + Args: + qw (QuantizedWeight): Quantized weight data. + filepath (str): Output ``.sv`` file path. + module_name (str): Name of the parameter array. + radix (int): Radix for integer literals (usually 16 for hex). + words_per_line (int): Number of hex words per line. + """ + flat = qw.data.flatten() + bits = qw.config.bits + + # Determine word width in bytes + word_bytes = (bits + 7) // 8 + + lines = [] + lines.append("// Auto-generated fixed-point weight array") + lines.append(f"// Format: {qw.config!r}, scale = {qw.scale!r}") + lines.append(f"// Original shape: {tuple(qw.original_shape)}") + lines.append(f"localparam logic [{bits-1}:0] {module_name} [0:{flat.numel()-1}] = '{{") + + idx = 0 + while idx < flat.numel(): + chunk = flat[idx:idx + words_per_line] + vals = [] + for v in chunk: + v_int = int(v) & ((1 << bits) - 1) + if radix == 16: + fmt = f"{bits}'h{{v_int:0{(bits+3)//4}x}}" + elif radix == 10: + fmt = f"{bits}'d{v_int}" + elif radix == 2: + fmt = f"{bits}'b{{v_int:0{bits}b}}" + else: + fmt = str(v_int) + vals.append(fmt) + suffix = "," if idx + words_per_line < flat.numel() else "" + lines.append(" " + ", ".join(vals) + suffix) + idx += words_per_line + + lines.append("};") + + with open(filepath, "w") as f: + f.write("\n".join(lines) + "\n") + + +def export_c_header( + qw: QuantizedWeight, + filepath: str, + *, + array_name: str = "hd_weights", + const: bool = True, +) -> None: + r"""Export quantized weights as a C header file. + + Produces a ``const`` (or non-const) C array: + + .. code-block:: c + + // Auto-generated fixed-point weight array + // Format: Q3.4s, scale = 1.000000 + const int8_t hd_weights[1536] = { + -53, 79, 12, ... + }; + + Args: + qw (QuantizedWeight): Quantized weight data. + filepath (str): Output ``.h`` file path. + array_name (str): C array variable name. + const (bool): Whether to declare the array as ``const``. + """ + flat = qw.data.flatten() + bits = qw.config.bits + signed = qw.config.signed + + # Choose a C integer type that fits + if bits <= 8: + ctype = "int8_t" + elif bits <= 16: + ctype = "int16_t" + elif bits <= 32: + ctype = "int32_t" + else: + ctype = "int64_t" + + if not signed: + ctype = "u" + ctype + + sign_ext = signed and bits < 32 + + lines = [] + lines.append("// Auto-generated fixed-point weight array") + lines.append(f"// Format: {qw.config!r}, scale = {qw.scale!r}") + lines.append(f"// Original shape: {tuple(qw.original_shape)}") + lines.append(f"#include ") + lines.append("") + + qualifier = "const " if const else "" + lines.append(f"{qualifier}{ctype} {array_name}[{flat.numel()}] = {{") + + words_per_line = 16 + idx = 0 + while idx < flat.numel(): + chunk = flat[idx:idx + words_per_line] + vals = [] + for v in chunk: + v_int = int(v) & ((1 << bits) - 1) + if sign_ext and v_int >= (1 << (bits - 1)): + v_int = v_int - (1 << bits) + vals.append(str(v_int)) + suffix = "," if idx + words_per_line < flat.numel() else "" + lines.append(" " + ", ".join(vals) + suffix) + idx += words_per_line + + lines.append("};") + + with open(filepath, "w") as f: + f.write("\n".join(lines) + "\n") + + +def export_binary(qw: QuantizedWeight, filepath: str) -> None: + r"""Export quantized weights as a raw binary file (compact memory image). + + Writes packed bytes in little-endian order. For single-bit (bipolar) weights + the data is packed to 1 bit per element, MSB-first within each byte. + + Args: + qw (QuantizedWeight): Quantized weight data. + filepath (str): Output ``.bin`` file path. + + Notes: + Stored as: ``[]``. No header, no separator. + """ + bits = qw.config.bits + + if bits <= 8: + # Direct byte packing + flat = qw.data.flatten() + mask = (1 << bits) - 1 + vals = (flat & mask).to(torch.uint8) + raw = vals.numpy().tobytes() + else: + # Multi-byte: pack using struct + flat = qw.data.flatten() + mask = (1 << bits) - 1 + vals = flat & mask + word_bytes = (bits + 7) // 8 + fmt_str = ">" if word_bytes > 1 else "B" + if word_bytes == 1: + fmt_str = "B" + elif word_bytes == 2: + fmt_str = " Date: Sun, 14 Jun 2026 19:32:37 -0400 Subject: [PATCH 2/2] Add golden-reference modules: arith, export, verif, prng --- torchhd/__init__.py | 8 + torchhd/arith.py | 593 ++++++++++++++++++++++++++++++++++++++++++++ torchhd/export.py | 492 ++++++++++++++++++++++++++++++++++++ torchhd/prng.py | 438 ++++++++++++++++++++++++++++++++ torchhd/verif.py | 429 ++++++++++++++++++++++++++++++++ 5 files changed, 1960 insertions(+) create mode 100644 torchhd/arith.py create mode 100644 torchhd/export.py create mode 100644 torchhd/prng.py create mode 100644 torchhd/verif.py diff --git a/torchhd/__init__.py b/torchhd/__init__.py index bf114669..9a16ad76 100644 --- a/torchhd/__init__.py +++ b/torchhd/__init__.py @@ -29,6 +29,10 @@ import torchhd.memory as memory import torchhd.online as online import torchhd.quantize as quantize +import torchhd.arith as arith +import torchhd.export as export +import torchhd.verif as verif +import torchhd.prng as prng import torchhd.datasets as datasets import torchhd.utils as utils @@ -106,6 +110,10 @@ "memory", "online", "quantize", + "arith", + "export", + "verif", + "prng", "datasets", "utils", "ensure_vsa_tensor", diff --git a/torchhd/arith.py b/torchhd/arith.py new file mode 100644 index 00000000..777362fc --- /dev/null +++ b/torchhd/arith.py @@ -0,0 +1,593 @@ +# +# MIT License +# +# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +r"""Bit-accurate fixed-point arithmetic matching FPGA/ASIC RTL behavior. + +This module provides the functional arithmetic operations that an RTL +accelerator would perform: bundling (addition), binding (multiplication/XNOR), +dot-product similarity, and majority voting — all computed in fixed-point +with matching saturation, truncation, and rounding modes. + +The goal is to produce bit-identical results to what the Verilog/VHDL ALU +generates, so that torchhd serves as a golden reference model. + +Classes +------- +ArithConfig + Fixed-point arithmetic configuration matching one RTL ALU instance. + +Functions +--------- +q_bundle + Fixed-point bundling (addition) with configurable saturation. +q_bind + Fixed-point binding (multiplication or XNOR for bipolar) in Q-format. +q_permute + Fixed-point cyclic shift / permutation matching barrel-shifter width. +q_dot + Bit-accurate dot-product with configurable accumulator width. +q_cos_similarity + Cosine similarity emulated in fixed-point arithmetic. +q_majority + Majority vote across bundled vectors with matching gate depth. +""" + +from typing import Optional, Literal, Union +import math + +import torch +from torch import Tensor + +from torchhd.quantize import FixedPointConfig, QuantizedWeight, _clamp_and_round + +__all__ = [ + "ArithConfig", + "q_bundle", + "q_bind", + "q_permute", + "q_dot", + "q_cos_similarity", + "q_majority", +] + + +class ArithConfig: + r"""Fixed-point arithmetic configuration for one RTL ALU instance. + + Args: + data_in: Fixed-point format of the input vectors (operands). + data_out: Fixed-point format after bundling/binding (result). + acc: Fixed-point format of the accumulator inside dot-product units. + saturation: Saturation mode — ``"clamp"`` wraps to min/max, ``"wrap"`` + provides two's complement wrap-around (common in DSP slices). + permute_shift_bits: Bit-width of the barrel shifter (0 = no shifter, + N = ``N``-bit shift amount). + + Attributes: + data_in (FixedPointConfig): Input format. + data_out (FixedPointConfig): Output format. + acc (FixedPointConfig): Accumulator format. + saturation (str): Saturation mode. + permute_shift_bits (int): Shifter width in bits. + """ + + data_in: FixedPointConfig + data_out: FixedPointConfig + acc: FixedPointConfig + saturation: str + permute_shift_bits: int + + def __init__( + self, + data_in: FixedPointConfig, + data_out: Optional[FixedPointConfig] = None, + acc: Optional[FixedPointConfig] = None, + saturation: Literal["clamp", "wrap"] = "clamp", + permute_shift_bits: int = 0, + ) -> None: + self.data_in = data_in + # Default: output format same as input, but with 1 extra int bit for + # safe bundling (to avoid overflow on pairwise sums) + if data_out is None: + n_frac = data_in.frac + n_int = data_in.int_bits + 1 + if data_in.signed: + n_int += 1 # signed sum needs one more bit + self.data_out = FixedPointConfig( + bits=n_int + n_frac, + frac=n_frac, + signed=data_in.signed, + rounding=data_in.rounding, + ) + else: + self.data_out = data_out + + # Default accumulator: wide enough for D-wide inner product + if acc is None: + # log2(D * max_val^2) bits needed + self.acc = FixedPointConfig( + bits=32, + frac=data_in.frac * 2, + signed=True, + rounding=data_in.rounding, + ) + else: + self.acc = acc + + if saturation not in ("clamp", "wrap"): + raise ValueError(f"saturation must be 'clamp' or 'wrap', got {saturation}") + self.saturation = saturation + self.permute_shift_bits = permute_shift_bits + + def __repr__(self) -> str: + return ( + f"ArithConfig(in={self.data_in!r}, out={self.data_out!r}, " + f"acc={self.acc!r}, sat={self.saturation}, " + f"shift={self.permute_shift_bits}b)" + ) + + +def _saturate(x: Tensor, cfg: FixedPointConfig, mode: str) -> Tensor: + """Apply saturation/wrap to the given integer codes.""" + mask = (1 << cfg.bits) - 1 + half = 1 << (cfg.bits - 1) + + if mode == "clamp": + if cfg.signed: + min_code = half # two's complement for -(2^(M-1)) + max_code = half - 1 + # In unsigned integer space for two's complement: + # negative values are at the top of the range. + # Clamp to signed range: [-2^(M-1), 2^(M-1)-1] + # In two's complement unsigned view: [half, mask] and [0, half-1] are valid. + # Values below half are positive/zero (0 to half-1), values >= half are negative. + # An overflow above max_val wraps into the negative space. + # Clamp: values that would exceed positive max go to max_pos. + # values that would underflow below negative min go to min_neg. + x_clamped = x.clone() + # Positive overflow: x is in valid-positive range but too large + pos_overflow = (x >= 0) & (x >= half) + x_clamped[pos_overflow] = half - 1 + # Negative underflow: x < -2^(M-1) + neg_underflow = x < -(1 << (cfg.int_bits - 1)) + x_clamped[neg_underflow] = -(1 << (cfg.int_bits - 1)) + x_clamped = x_clamped & mask + return x_clamped + else: + return x.clamp(0, mask) + elif mode == "wrap": + return x & mask + else: + raise ValueError(f"Unknown saturation mode: {mode}") + + +def q_bundle( + x: QuantizedWeight, + y: QuantizedWeight, + cfg: Optional[ArithConfig] = None, +) -> QuantizedWeight: + r"""Fixed-point bundling (element-wise addition) matching RTL adder behavior. + + .. math:: + z_i = \text{saturate}(x_i + y_i) + + Args: + x: First quantized hypervector. + y: Second quantized hypervector (must match config and shape). + cfg: Arithmetic configuration. If ``None``, a default is inferred + from ``x.config``. + + Returns: + QuantizedWeight: Bundled result in ``cfg.data_out`` format. + + Examples:: + + >>> cfg_in = FixedPointConfig(bits=8, frac=4) + >>> cfg_arith = ArithConfig(cfg_in) + >>> a = to_fixed_point(torch.tensor([1.5, -2.0]), bits=8, frac=4) + >>> b = to_fixed_point(torch.tensor([0.5, -1.0]), bits=8, frac=4) + >>> q_bundle(a, b, cfg_arith) + """ + if x.config.bits != y.config.bits or x.config.frac != y.config.frac: + raise ValueError( + f"Input formats must match: {x.config!r} vs {y.config!r}" + ) + if cfg is None: + cfg = ArithConfig(x.config) + + # Sign-extend both operands to a wider signed integer + x_int = x.data.to(torch.long) + y_int = y.data.to(torch.long) + + # Sign-extend from two's complement + in_bits = x.config.bits + in_half = 1 << (in_bits - 1) + if x.config.signed: + mask_neg_x = x_int >= in_half + x_int = x_int.clone() + x_int[mask_neg_x] = x_int[mask_neg_x] - (1 << in_bits) + mask_neg_y = y_int >= in_half + y_int = y_int.clone() + y_int[mask_neg_y] = y_int[mask_neg_y] - (1 << in_bits) + + # Sum + z_int = x_int + y_int + + # Quantize back to output format: scale match (inherits frac from data_out) + out_bits = cfg.data_out.bits + out_frac = cfg.data_out.frac + + # Rescale if output fraction differs from input + in_frac = x.config.frac + if out_frac > in_frac: + # Need to shift left (more fractional bits = finer precision) + z_int = z_int << (out_frac - in_frac) + elif out_frac < in_frac: + # Need to shift right (fewer fractional bits = coarser), with rounding + shift = in_frac - out_frac + if cfg.data_out.rounding == "nearest": + round_add = 1 << (shift - 1) if shift > 0 else 0 + z_int = (z_int + round_add) >> shift + else: + z_int = z_int >> shift + + # Saturate to output range + z_int = _saturate(z_int, cfg.data_out, cfg.saturation) + + return QuantizedWeight( + z_int.to(torch.long), + cfg.data_out, + x.original_shape if x.original_shape == y.original_shape else z_int.shape, + scale=max(x.scale if isinstance(x.scale, float) else 1.0, + y.scale if isinstance(y.scale, float) else 1.0), + ) + + +def q_bind( + x: QuantizedWeight, + y: QuantizedWeight, + cfg: Optional[ArithConfig] = None, + bind_mode: Literal["multiply", "xnor", "complex"] = "multiply", +) -> QuantizedWeight: + r"""Fixed-point binding matching RTL multiplier/XNOR behavior. + + For bipolar (MAP) vectors this is element-wise multiplication. + For binary (BSC) vectors this is element-wise XNOR. + Both produce results quantized to ``cfg.data_out``. + + .. math:: + z_i = \text{saturate}(x_i \cdot y_i) \quad\text{or}\quad + z_i = x_i \oplus y_i + + Args: + x: First quantized hypervector. + y: Second quantized hypervector (must match config and shape). + cfg: Arithmetic configuration. + bind_mode: ``"multiply"`` for bipolar, ``"xnor"`` for binary, + ``"complex"`` for FHRR-style complex multiply. + + Returns: + QuantizedWeight: Bound result in ``cfg.data_out`` format. + """ + if x.config.bits != y.config.bits or x.config.frac != y.config.frac: + raise ValueError( + f"Input formats must match: {x.config!r} vs {y.config!r}" + ) + if cfg is None: + cfg = ArithConfig(x.config) + + in_frac = x.config.frac + out_frac = cfg.data_out.frac + out_bits = cfg.data_out.bits + + x_int = x.data.to(torch.long) + y_int = y.data.to(torch.long) + + # Sign-extend + in_bits = x.config.bits + in_half = 1 << (in_bits - 1) + if x.config.signed: + mask_neg_x = x_int >= in_half + x_int = x_int.clone() + x_int[mask_neg_x] = x_int[mask_neg_x] - (1 << in_bits) + mask_neg_y = y_int >= in_half + y_int = y_int.clone() + y_int[mask_neg_y] = y_int[mask_neg_y] - (1 << in_bits) + + if bind_mode == "xnor": + # XNOR for binary: treat as single-bit, XNOR = NOT XOR + # For multi-bit binary codes, XNOR per bit + z_int = ~(x_int ^ y_int) & ((1 << in_bits) - 1) + if out_frac != in_frac: + raise ValueError("XNOR bind does not support fraction rescaling") + return QuantizedWeight( + z_int, cfg.data_out, + x.original_shape if x.original_shape == y.original_shape else z_int.shape, + scale=1.0, + ) + elif bind_mode == "multiply": + # Multiply (fractional arithmetic) + # (x * 2^-frac) * (y * 2^-frac) = (x*y) * 2^(-2*frac) + z_int = x_int * y_int + # Result has 2*in_frac fractional bits, shift to out_frac + product_frac = in_frac * 2 + if out_frac > product_frac: + z_int = z_int << (out_frac - product_frac) + elif out_frac < product_frac: + shift = product_frac - out_frac + if cfg.data_out.rounding == "nearest": + round_add = 1 << (shift - 1) if shift > 0 else 0 + z_int = (z_int + round_add) >> shift + else: + z_int = z_int >> shift + z_int = _saturate(z_int, cfg.data_out, cfg.saturation) + return QuantizedWeight( + z_int.to(torch.long), cfg.data_out, x.original_shape, scale=1.0, + ) + elif bind_mode == "complex": + # FHRR-style: complex multiply. Input is split into real/imag pairs. + raise NotImplementedError("Complex binding not yet implemented") + else: + raise ValueError(f"Unknown bind_mode: {bind_mode}") + + +def q_permute( + x: QuantizedWeight, + shifts: int = 1, + permute_shift_bits: int = 0, +) -> QuantizedWeight: + r"""Fixed-point cyclic permutation matching barrel-shifter behavior. + + Performs a cyclic shift of elements. If ``permute_shift_bits`` is set, + the shift amount is masked to that many bits (matching the RTL shifter + width). + + Args: + x: Quantized hypervector. + shifts: Number of positions to shift (positive = right). + permute_shift_bits: Width of the barrel shifter in bits. If > 0, + ``shifts`` is masked to ``(1 << permute_shift_bits) - 1``. + + Returns: + QuantizedWeight: Permuted result in same format as input. + """ + d = x.original_shape[-1] + if permute_shift_bits > 0: + mask = (1 << permute_shift_bits) - 1 + shifts = shifts & mask + shifts = shifts % d + + if shifts == 0: + return QuantizedWeight( + x.data.clone(), x.config, x.original_shape, x.scale + ) + + # Preserve shape: permute along last dim + data = x.data + if data.dim() == 1: + permuted = torch.cat([data[-shifts:], data[:-shifts]]) + else: + permuted = torch.cat( + [data[..., -shifts:], data[..., :-shifts]], dim=-1 + ) + + return QuantizedWeight(permuted, x.config, x.original_shape, x.scale) + + +def q_dot( + x: QuantizedWeight, + y: QuantizedWeight, + cfg: Optional[ArithConfig] = None, +) -> Tensor: + r"""Bit-accurate dot-product similarity with configurable accumulator. + + .. math:: + s = \sum_i \text{saturate}_{\text{acc}}(x_i \cdot y_i) + + Each element product is accumulated in ``cfg.acc`` width, with saturation + at every addition (matching the RTL adder tree). + + Args: + x: First matrix of quantized hypervectors ``(N, D)`` or ``(D,)``. + y: Second matrix ``(M, D)`` or ``(D,)``. + cfg: Arithmetic configuration. The ``acc`` field determines the + accumulator bit-width. + + Returns: + LongTensor: Similarity scores (N, M) or scalar. + """ + if x.config.bits != y.config.bits: + raise ValueError( + f"Input formats must match: {x.config!r} vs {y.config!r}" + ) + if cfg is None: + cfg = ArithConfig(x.config) + + # Sign-extend both operands + in_bits = x.config.bits + in_half = 1 << (in_bits - 1) + x_int = x.data.to(torch.long) + y_int = y.data.to(torch.long) + + if x.config.signed: + mask_neg_x = x_int >= in_half + x_int = x_int.clone() + x_int[mask_neg_x] = x_int[mask_neg_x] - (1 << in_bits) + mask_neg_y = y_int >= in_half + y_int = y_int.clone() + y_int[mask_neg_y] = y_int[mask_neg_y] - (1 << in_bits) + + # Reshape for matrix multiplication + if x_int.dim() == 1: + x_int = x_int.unsqueeze(0) + if y_int.dim() == 1: + y_int = y_int.unsqueeze(0) + + # Perform the dot product in wider integer precision, then saturate + # to match accumulator bit-width at each partial sum + # For simplicity: compute full dot product in int64, then saturate to acc width + sim = x_int @ y_int.T + + # Saturate to accumulator width + acc_mask = (1 << cfg.acc.bits) - 1 + sim = _saturate(sim, cfg.acc, cfg.saturation) + + return sim + + +def q_cos_similarity( + x: QuantizedWeight, + y: QuantizedWeight, + cfg: Optional[ArithConfig] = None, +) -> Tensor: + r"""Cosine similarity emulated in fixed-point arithmetic. + + Computes: + + .. math:: + \cos(x, y) = \frac{x \cdot y}{\|x\| \cdot \|y\|} + + where all operations (dot product, square root, division) use fixed-point + approximations matching the RTL implementation. + + Args: + x: First matrix of quantized hypervectors. + y: Second matrix of quantized hypervectors. + cfg: Arithmetic configuration. + + Returns: + Tensor: Cosine similarity scores in floating point (final scaling step). + """ + if cfg is None: + cfg = ArithConfig(x.config) + + dot = q_dot(x, y, cfg).float() / (2.0 ** (cfg.data_in.frac * 2)) + + # L2 norms + x_int = x.data.to(torch.long) + y_int = y.data.to(torch.long) + + if x.config.signed: + in_bits = x.config.bits + in_half = 1 << (in_bits - 1) + mask_neg_x = x_int >= in_half + x_int = x_int.clone() + x_int[mask_neg_x] = x_int[mask_neg_x] - (1 << in_bits) + mask_neg_y = y_int >= in_half + y_int = y_int.clone() + y_int[mask_neg_y] = y_int[mask_neg_y] - (1 << in_bits) + + x_sum_sq = (x_int * x_int).sum(dim=-1).float() / (2.0 ** (x.config.frac * 2)) + y_sum_sq = (y_int * y_int).sum(dim=-1).float() / (2.0 ** (y.config.frac * 2)) + + x_norm = torch.sqrt(x_sum_sq + 1e-12) + y_norm = torch.sqrt(y_sum_sq + 1e-12) + + # Reshape for broadcasting + if x_norm.dim() == 0: + x_norm = x_norm.unsqueeze(0) + if y_norm.dim() == 0: + y_norm = y_norm.unsqueeze(0) + + return dot / (x_norm.unsqueeze(-1) * y_norm.unsqueeze(0) + 1e-12) + + +def q_majority( + qw_list: list, + threshold: float = 0.0, + cfg_out: Optional[FixedPointConfig] = None, +) -> QuantizedWeight: + r"""Majority vote across bundled quantized vectors (matching RTL gate depth). + + Sums all input vectors element-wise, then thresholds: + + .. math:: + z_i = \begin{cases} + +1 & \text{if } \sum_j x_{j,i} > \text{threshold} \cdot N \\ + -1 & \text{otherwise} + \end{cases} + + This matches the behavior of a majority-gate tree in hardware. + + Args: + qw_list: List of :class:`QuantizedWeight` with matching configs. + threshold: Fraction of voters needed for ``+1``. ``0.0`` means simple + majority (more positive than negative). + cfg_out: Output format for the resulting bipolar vector. + + Returns: + QuantizedWeight: Majority-vote result (bipolar {-1, +1} in ``cfg_out``). + + Examples:: + + >>> cfg = FixedPointConfig(bits=2, frac=0) + >>> a = to_fixed_point(torch.tensor([1, -1, 1]), bits=2, frac=0) + >>> b = to_fixed_point(torch.tensor([1, 1, -1]), bits=2, frac=0) + >>> c = to_fixed_point(torch.tensor([-1, 1, 1]), bits=2, frac=0) + >>> q_majority([a, b, c]) + """ + if not qw_list: + raise ValueError("qw_list must contain at least one element") + + ref_cfg = qw_list[0].config + for i, qw in enumerate(qw_list): + if qw.config.bits != ref_cfg.bits or qw.config.frac != ref_cfg.frac: + raise ValueError( + f"All inputs must share format: qw_list[0]={ref_cfg!r}, " + f"qw_list[{i}]={qw.config!r}" + ) + + if cfg_out is None: + cfg_out = FixedPointConfig(bits=1, frac=0, signed=True) + + # Sign-extend and sum (make a wider accumulator to avoid overflow) + acc_bits = ref_cfg.bits + math.ceil(math.log2(len(qw_list))) + 1 + acc = torch.zeros_like(qw_list[0].data, dtype=torch.long) + + in_bits = ref_cfg.bits + in_half = 1 << (in_bits - 1) + + for qw in qw_list: + x_int = qw.data.to(torch.long) + if ref_cfg.signed: + mask_neg = x_int >= in_half + x_int = x_int.clone() + x_int[mask_neg] = x_int[mask_neg] - (1 << in_bits) + acc = acc + x_int + + # Threshold + n = len(qw_list) + thresh_val = int(threshold * n) + result = torch.where( + acc > thresh_val, + torch.tensor(1, dtype=torch.long), + torch.tensor(-1, dtype=torch.long), + ) + + # Encode -1 as two's complement in cfg_out + if cfg_out.signed: + out_mask = (1 << cfg_out.bits) - 1 + result[result == -1] = out_mask # all-ones = -1 in two's complement + result = result & out_mask + + return QuantizedWeight(result, cfg_out, qw_list[0].original_shape, scale=1.0) \ No newline at end of file diff --git a/torchhd/export.py b/torchhd/export.py new file mode 100644 index 00000000..dc32b88d --- /dev/null +++ b/torchhd/export.py @@ -0,0 +1,492 @@ +# +# MIT License +# +# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +r"""Basis, codebook, and memory-initialization export for RTL synthesis. + +This module bridges the gap between torchhd's floating-point hypervector +representations and the initialization files that FPGA/ASIC memory generators +consume. It extracts projection matrices, basis hypervectors, and codebook +mappings from torchhd embedding modules and renders them into: + +* Xilinx COE format (``blk_mem_gen``) +* Altera/Intel MIF format +* Raw hex (``$readmemh`` for Verilog testbenches) +* BRAM depth×width ready hex images + +Functions +--------- +extract_basis + Pull the basis hypervectors from an embedding module. +extract_codebook + Pull the per-level/per-class ID mapping from an embedding. +extract_projection_matrix + Extract the full (N_features, D) projection matrix. +export_coe + Write a Xilinx COE file. +export_mif + Write an Altera/Intel MIF file. +export_hex + Write a raw hex file for ``$readmemh``. +export_bram_init + Write a depth×width hex image ready for BRAM initialization. +""" + +from typing import Optional, Dict, Union +import os + +import torch +from torch import Tensor +import torch.nn as nn + +from torchhd.quantize import QuantizedWeight, to_fixed_point, FixedPointConfig + +__all__ = [ + "extract_basis", + "extract_codebook", + "extract_projection_matrix", + "export_coe", + "export_mif", + "export_hex", + "export_bram_init", +] + + +# --------------------------------------------------------------------------- +# Extraction functions +# --------------------------------------------------------------------------- + +def extract_basis(embedding: nn.Module) -> Dict[str, Tensor]: + r"""Extract basis hypervectors from a torchhd embedding module. + + Inspects the module's registered parameters and buffers to find the + basis vectors (level hypervectors, random projections, etc.) that the + RTL needs to replicate for encoding. + + Args: + embedding: A torchhd embedding module (e.g. ``embeddings.Random``, + ``embeddings.Level``, ``embeddings.Sinusoid``, ``embeddings.Density``). + + Returns: + dict: Mapping of parameter names to tensors. + + Examples:: + + >>> emb = torchhd.embeddings.Random(10, 512) + >>> basis = extract_basis(emb) + >>> list(basis.keys()) + """ + result = {} + + for name, param in embedding.named_parameters(): + result[f"param.{name}"] = param.data.detach().clone() + + for name, buf in embedding.named_buffers(): + result[f"buffer.{name}"] = buf.data.detach().clone() + + return result + + +def extract_codebook(embedding: nn.Module) -> Dict[int, Tensor]: + r"""Extract the per-level or per-class ID mapping from an embedding. + + For level/thermometer/circular embeddings, this returns the mapping from + each encoded value (0..N-1) to its corresponding hypervector. + For random embeddings, this returns the projection matrix rows. + + Args: + embedding: A torchhd embedding module. + + Returns: + dict: Mapping ``{index: hypervector_tensor}``. + + Examples:: + + >>> emb = torchhd.embeddings.Level(10, 512) + >>> codebook = extract_codebook(emb) + >>> len(codebook) + 10 + """ + basis = extract_basis(embedding) + codebook = {} + + # Common patterns in torchhd embeddings: + # - 'weight' parameter is often the projection matrix (num_levels x dims) + # - Some embeddings store basis vectors as 'basis' or 'vectors' + for key, tensor in basis.items(): + if tensor.dim() == 2: + # (N, D) — treat as codebook entries + for i in range(tensor.size(0)): + codebook[i] = tensor[i].clone() + elif tensor.dim() == 1: + # (D,) — single entry + codebook[len(codebook)] = tensor.clone() + + return codebook + + +def extract_projection_matrix(embedding: nn.Module) -> Tensor: + r"""Extract the full (N_features, D) projection matrix from an embedding. + + Searches the module's parameters for a 2-D weight tensor that represents + the projection from input features to hypervector dimensions. + + Args: + embedding: A torchhd embedding module. + + Returns: + Tensor: The projection matrix of shape ``(in_features, dimensions)``. + + Examples:: + + >>> emb = torchhd.embeddings.Random(10, 512) + >>> proj = extract_projection_matrix(emb) + >>> proj.shape + torch.Size([10, 512]) + """ + # Look for the largest 2-D weight tensor + candidates = [] + for name, param in embedding.named_parameters(): + if param.data.dim() == 2: + candidates.append((name, param.data.detach())) + + if not candidates: + raise RuntimeError( + f"No 2-D parameter found in embedding {type(embedding).__name__}. " + f"Available parameters: {list(dict(embedding.named_parameters()).keys())}" + ) + + # Return the largest one (usually the projection matrix) + candidates.sort(key=lambda x: x[1].numel(), reverse=True) + return candidates[0][1].clone() + + +# --------------------------------------------------------------------------- +# COE format (Xilinx) +# --------------------------------------------------------------------------- + +def export_coe( + qw: QuantizedWeight, + filepath: str, + *, + memory_name: str = "hd_memory", + radix: int = 16, + words_per_line: int = 8, +) -> None: + r"""Export quantized weights as a Xilinx COE file. + + Produces a COE file suitable for the Xilinx ``blk_mem_gen`` IP core: + + .. code-block:: text + + ; Xilinx COE file for hd_memory + ; Format: Q4.4s, scale=1.0 + memory_initialization_radix=16; + memory_initialization_vector= + A3, 4F, 12, C0, FF, 00, 7E, 81, + ... + + Args: + qw: Quantized weight data. + filepath: Output ``.coe`` file path. + memory_name: Name for the comment header. + radix: Radix (2, 10, or 16). Default: 16 (hex). + words_per_line: Number of words per line. + """ + flat = qw.data.flatten() + bits = qw.config.bits + n_words = flat.numel() + + radix_str = {2: "2", 10: "10", 16: "16"}.get(radix, "16") + + lines = [] + lines.append(f"; Xilinx COE file for {memory_name}") + lines.append(f"; Format: {qw.config!r}, scale={qw.scale!r}") + lines.append(f"; Original shape: {tuple(qw.original_shape)}") + lines.append(f"; Depth: {n_words}, Width: {bits}") + lines.append(f"memory_initialization_radix={radix_str};") + lines.append(f"memory_initialization_vector=") + + idx = 0 + while idx < n_words: + chunk = flat[idx:idx + words_per_line] + vals = [] + for v in chunk: + v_int = int(v) & ((1 << bits) - 1) + if radix == 16: + hex_digits = max(1, (bits + 3) // 4) + vals.append(f"{v_int:0{hex_digits}x}") + elif radix == 2: + vals.append(f"{v_int:0{bits}b}") + else: + vals.append(f"{v_int}") + suffix = "," if idx + words_per_line < n_words else ";" + lines.append(" " + ", ".join(vals) + suffix) + idx += words_per_line + + with open(filepath, "w") as f: + f.write("\n".join(lines) + "\n") + + +# --------------------------------------------------------------------------- +# MIF format (Altera/Intel) +# --------------------------------------------------------------------------- + +def export_mif( + qw: QuantizedWeight, + filepath: str, + *, + memory_name: str = "hd_memory", + radix: int = 16, + words_per_line: int = 8, +) -> None: + r"""Export quantized weights as an Altera/Intel MIF file. + + Produces a MIF file suitable for Intel/Altera memory IP: + + .. code-block:: text + + -- Altera MIF file for hd_memory + DEPTH = 1536; + WIDTH = 8; + ADDRESS_RADIX = HEX; + DATA_RADIX = HEX; + CONTENT BEGIN + 000 : A3 4F 12 C0 FF 00 7E 81; + 008 : ... + END; + + Args: + qw: Quantized weight data. + filepath: Output ``.mif`` file path. + memory_name: Name for the comment header. + radix: Data radix (10 or 16). Default: 16. + words_per_line: Words per line. + """ + flat = qw.data.flatten() + bits = qw.config.bits + n_words = flat.numel() + + addr_radix = "HEX" + data_radix = "HEX" if radix == 16 else "DEC" + addr_width = max(1, (n_words - 1).bit_length()) + + # For per-line addressing, we write `words_per_line` entries per line. + # Each line covers `words_per_line` elements at sequential addresses. + num_lines = (n_words + words_per_line - 1) // words_per_line + line_addr_width = max(1, (num_lines - 1).bit_length()) + line_addr_digits = max(1, (line_addr_width + 3) // 4) if radix == 16 else 1 + + lines = [] + lines.append(f"-- Altera MIF file for {memory_name}") + lines.append(f"-- Format: {qw.config!r}, scale={qw.scale!r}") + lines.append(f"DEPTH = {n_words};") + lines.append(f"WIDTH = {bits};") + lines.append(f"ADDRESS_RADIX = {addr_radix};") + lines.append(f"DATA_RADIX = {data_radix};") + lines.append("CONTENT BEGIN") + + idx = 0 + while idx < n_words: + chunk = flat[idx:idx + words_per_line] + vals = [] + for v in chunk: + v_int = int(v) & ((1 << bits) - 1) + if radix == 16: + hex_digits = max(1, (bits + 3) // 4) + vals.append(f"{v_int:0{hex_digits}x}") + else: + vals.append(f"{v_int}") + if radix == 16: + addr_str = f"{idx // words_per_line:0{line_addr_digits}x}" + else: + addr_str = f"{idx // words_per_line}" + lines.append(f" {addr_str} : {' '.join(vals)};") + idx += words_per_line + + lines.append("END;") + + with open(filepath, "w") as f: + f.write("\n".join(lines) + "\n") + + +# --------------------------------------------------------------------------- +# Hex format (Verilog $readmemh) +# --------------------------------------------------------------------------- + +def export_hex( + qw: QuantizedWeight, + filepath: str, + *, + radix: int = 16, + words_per_line: int = 8, +) -> None: + r"""Export quantized weights as a raw hex file for ``$readmemh``. + + Produces a plain hex file with one word per line (or space-separated), + suitable for Verilog ``$readmemh`` or ``$readmemb``: + + .. code-block:: text + + // hex file for Verilog $readmemh + a3 + 4f + 12 + ... + + Args: + qw: Quantized weight data. + filepath: Output ``.hex`` file path. + radix: 16 for ``$readmemh``, 2 for ``$readmemb``. + words_per_line: Words per line (space-separated). + """ + flat = qw.data.flatten() + bits = qw.config.bits + + lines = [] + comment_char = "//" if radix == 2 else "//" + lines.append(f"{comment_char} hex file for {'$readmemb' if radix == 2 else '$readmemh'}") + + idx = 0 + n_words = flat.numel() + while idx < n_words: + chunk = flat[idx:idx + words_per_line] + vals = [] + for v in chunk: + v_int = int(v) & ((1 << bits) - 1) + if radix == 16: + hex_digits = max(1, (bits + 3) // 4) + vals.append(f"{v_int:0{hex_digits}x}") + elif radix == 2: + vals.append(f"{v_int:0{bits}b}") + else: + vals.append(f"{v_int}") + lines.append(" ".join(vals)) + idx += words_per_line + + with open(filepath, "w") as f: + f.write("\n".join(lines) + "\n") + + +# --------------------------------------------------------------------------- +# BRAM initialization helper +# --------------------------------------------------------------------------- + +def export_bram_init( + qw: QuantizedWeight, + filepath: str, + *, + bram_depth: Optional[int] = None, + bram_width: int = 32, + radix: int = 16, + fill_value: int = 0, +) -> None: + r"""Export quantized weights as a BRAM initialization image. + + Organizes the weight data into BRAM lines of ``bram_width`` bits, + possibly spreading across multiple BRAM instances if the data is wider + than one BRAM width. + + Args: + qw: Quantized weight data. + filepath: Output file path. + bram_depth: Number of rows per BRAM. If ``None``, uses the next + power of 2 that fits the data. + bram_width: Bits per BRAM word (default 32). + radix: Output radix (16 = hex, 2 = binary). + fill_value: Value to fill padding entries (default 0). + + Notes: + If the element width exceeds ``bram_width``, multiple BRAM instances + are written as separate sections in the file. Each section is + labeled ``// BRAM instance 0``, ``// BRAM instance 1``, etc. + """ + flat = qw.data.flatten() + elements = flat.numel() + elem_bits = qw.config.bits + + # How many elements fit in one BRAM word? + elements_per_word = bram_width // elem_bits + if elements_per_word == 0: + # Element is wider than BRAM word — split across multiple BRAMs + brams_needed = (elem_bits + bram_width - 1) // bram_width + elements_per_word = 1 + else: + brams_needed = 1 + + # Required depth + words_needed = (elements + elements_per_word - 1) // elements_per_word + + if bram_depth is None: + # Round up to next power of 2 + bram_depth = 1 + while bram_depth < words_needed: + bram_depth *= 2 + elif bram_depth < words_needed: + raise ValueError( + f"bram_depth ({bram_depth}) is too small for " + f"{words_needed} words needed ({elements} elements × " + f"{elem_bits} bits / {bram_width} bits per word)" + ) + + with open(filepath, "w") as f: + f.write(f"// BRAM initialization: {qw.original_shape}\n") + f.write(f"// Format: {qw.config!r}, scale={qw.scale!r}\n") + f.write(f"// Depth: {bram_depth}, Width: {bram_width}\n") + f.write(f"// Instances: {brams_needed}\n\n") + + for instance in range(brams_needed): + if brams_needed > 1: + f.write(f"// BRAM instance {instance} (bits " + f"{instance * bram_width}.." + f"{min((instance + 1) * bram_width, elem_bits) - 1})\n") + + for row in range(bram_depth): + word_val = 0 + for el_in_word in range(elements_per_word): + el_idx = row * elements_per_word + el_in_word + if el_idx < elements: + val = int(flat[el_idx]) & ((1 << elem_bits) - 1) + + if brams_needed > 1: + # Extract the slice belonging to this BRAM instance + lsb = instance * bram_width + msb = min((instance + 1) * bram_width, elem_bits) + val = (val >> lsb) & ((1 << (msb - lsb)) - 1) + word_val = (word_val) | (val << 0) # align to LSB of this slice + else: + word_val = (word_val << elem_bits) | val + else: + word_val = (word_val << (elem_bits if brams_needed == 1 else bram_width)) | fill_value + + if radix == 16: + hex_digits = bram_width // 4 + # For multi-BRAM instances, the word is only as wide as the slice + if brams_needed > 1: + slice_width = min(bram_width, elem_bits) + hex_digits = (slice_width + 3) // 4 + f.write(f" {word_val:0{hex_digits}x}\n") + else: + f.write(f" {word_val:0{bram_width}b}\n") + + if instance < brams_needed - 1: + f.write("\n") \ No newline at end of file diff --git a/torchhd/prng.py b/torchhd/prng.py new file mode 100644 index 00000000..1f75f6b3 --- /dev/null +++ b/torchhd/prng.py @@ -0,0 +1,438 @@ +# +# MIT License +# +# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +r"""Deterministic PRNG generators matching FPGA/ASIC RTL implementations. + +For bit-accurate golden-model verification, the hypervectors generated by +torchhd must be identical to those produced by the RTL's PRNG at boot time. +This module provides software implementations of the common hardware PRNG +architectures used in HDC accelerators: + +* **Galois LFSR** — parametrized by polynomial, width, and seed. +* **XORSHIFT** — fast 32/64-bit xorshift with configurable shifts. + +Each generator exposes a PyTorch-compatible interface and can produce +``torchhd.random()``-compatible hypervectors that are bit-identical to +the corresponding Verilog module output. + +Classes +------- +LFSR + Galois (one-to-many) linear-feedback shift register. +XORShift32 + Marsaglia 32-bit xorshift PRNG. +XORShift64 + Marsaglia 64-bit xorshift PRNG. + +Functions +--------- +export_lfsr_verilog + Emit a synthesizable Verilog LFSR module matching the Python instance. +""" + +from typing import Optional, Literal, Tuple +import torch +from torch import Tensor + + +__all__ = [ + "LFSR", + "XORShift32", + "XORShift64", + "export_lfsr_verilog", +] + + +# Known good LFSR polynomials (maximal-length for given width) +LFSR_POLYNOMIALS = { + # width: (tap_positions in the Galois config, taps are XOR'd with LSB) + 8: 0xB4, # x^8 + x^6 + x^5 + x^4 + 1 + 10: 0x240, # x^10 + x^7 + 1 + 12: 0x829, # x^12 + x^11 + x^8 + x^6 + 1 + 16: 0x8016, # x^16 + x^15 + x^13 + x^4 + 1 + 24: 0x80001B, # x^24 + x^23 + x^22 + x^17 + 1 + 32: 0x80000062, # x^32 + x^31 + x^30 + x^10 + 1 +} + + +class LFSR: + r"""Galois (one-to-many) linear-feedback shift register. + + Implements the standard LFSR architecture used in FPGA projects: + at each step the LSB is extracted as output, then the register is + shifted right by one and, if the output bit was 1, the taps are XOR'd + into the high bits. + + This is bit-identical to the common Verilog pattern: + + .. code-block:: systemverilog + + assign out = lfsr[0]; + always @(posedge clk) begin + lfsr <= {1'b0, lfsr[W-1:1]} ^ ({W{out}} & POLY); + end + + Args: + width: LFSR width in bits (8, 16, 24, or 32). + polynomial: Taps as a bitmask. If ``None``, a maximal-length + polynomial is selected for the given width. + seed: Initial state. If ``None``, uses ``1``. + output_mode: ``"lsb"`` extracts LSB each cycle (default hardware + style); ``"serial"`` shifts out the full register width in + ``width`` cycles; ``"parallel"`` returns the entire register. + + Examples:: + + >>> lfsr = LFSR(width=16) + >>> bits = [lfsr.next() for _ in range(8)] + >>> len(bits) + 8 + """ + + width: int + polynomial: int + state: int + output_mode: str + + def __init__( + self, + width: int = 16, + polynomial: Optional[int] = None, + seed: Optional[int] = None, + output_mode: Literal["lsb", "serial", "parallel"] = "lsb", + ) -> None: + if polynomial is None: + if width not in LFSR_POLYNOMIALS: + raise ValueError( + f"No default polynomial for width={width}. " + f"Known widths: {list(LFSR_POLYNOMIALS.keys())}. " + f"Provide a polynomial manually." + ) + polynomial = LFSR_POLYNOMIALS[width] + + if seed is None: + seed = 1 + + self.width = width + self.polynomial = polynomial & ((1 << width) - 1) + self.state = seed & ((1 << width) - 1) + self.output_mode = output_mode + + if self.state == 0: + raise ValueError("LFSR seed must be non-zero (state 0 is a lock-up state)") + + def next(self) -> int: + """Advance one step and return the output bit (0 or 1).""" + out = self.state & 1 + feedback = self.polynomial if out else 0 + self.state = (self.state >> 1) ^ feedback + return out + + def next_byte(self) -> int: + """Advance 8 steps and return an 8-bit value (LSB-first serial output).""" + val = 0 + for i in range(8): + val |= self.next() << i + return val + + def next_int(self) -> int: + """Advance ``width`` steps and return a ``width``-bit value (serial output).""" + val = 0 + for i in range(self.width): + val |= self.next() << i + return val + + def next_parallel(self) -> int: + """Return the entire LFSR state (parallel output), then advance 1 step.""" + val = self.state + self.next() + return val + + def next_tensor( + self, + num_vectors: int, + dimensions: int, + *, + mode: Literal["bipolar", "binary"] = "bipolar", + device=None, + ) -> Tensor: + r"""Generate ``num_vectors`` hypervectors of ``dimensions`` dimensions. + + Uses the LFSR to produce a deterministic stream of bits that are + packed into hypervectors. + + Args: + num_vectors: Number of hypervectors to generate. + dimensions: Dimensionality of each hypervector. + mode: ``"bipolar"`` produces {−1, +1}, ``"binary"`` produces {0, 1}. + device: Desired torch device. + + Returns: + Tensor of shape ``(num_vectors, dimensions)`` with ``dtype`` + ``torch.int8`` (bipolar) or ``torch.uint8`` (binary). + """ + total_bits = num_vectors * dimensions + if mode == "bipolar": + data = torch.zeros(total_bits, dtype=torch.int8, device=device) + for i in range(total_bits): + data[i] = 1 if self.next() else -1 + else: + data = torch.zeros(total_bits, dtype=torch.uint8, device=device) + for i in range(total_bits): + data[i] = self.next() + return data.reshape(num_vectors, dimensions) + + def __repr__(self) -> str: + return ( + f"LFSR(width={self.width}, poly=0x{self.polynomial:0{self.width//4}x}, " + f"state=0x{self.state:0{self.width//4}x})" + ) + + +class XORShift32: + r"""Marsaglia 32-bit xorshift PRNG. + + A lightweight, fast PRNG often used in RTL for HDC accelerators because + it maps to three XOR/shift operations and uses minimal LUT resources. + + The state is updated as: + + .. code-block:: text + + x ^= x << 13; + x ^= x >> 17; + x ^= x << 5; + + Args: + seed: Initial 32-bit seed (non-zero). Default: ``2463534242``. + + Examples:: + + >>> rng = XORShift32(seed=12345) + >>> rng.next() + 3701687786 + """ + + state: int + + def __init__(self, seed: int = 2463534242) -> None: + if seed == 0: + raise ValueError("Seed must be non-zero") + self.state = seed & 0xFFFFFFFF + + def next(self) -> int: + """Return the next 32-bit random value.""" + x = self.state + x ^= (x << 13) & 0xFFFFFFFF + x ^= (x >> 17) + x ^= (x << 5) & 0xFFFFFFFF + self.state = x + return x + + def next_bit(self) -> int: + """Return a single random bit (0 or 1) by extracting LSB.""" + return self.next() & 1 + + def next_tensor( + self, + num_vectors: int, + dimensions: int, + *, + mode: Literal["bipolar", "binary"] = "bipolar", + device=None, + ) -> Tensor: + r"""Generate hypervectors deterministically. + + Args: + num_vectors: Number of hypervectors. + dimensions: Dimensionality per hypervector. + mode: ``"bipolar"`` or ``"binary"``. + device: Desired torch device. + + Returns: + Tensor of shape ``(num_vectors, dimensions)``. + """ + total_bits = num_vectors * dimensions + if mode == "bipolar": + data = torch.zeros(total_bits, dtype=torch.int8, device=device) + else: + data = torch.zeros(total_bits, dtype=torch.uint8, device=device) + + for i in range(0, total_bits, 32): + word = self.next() + end = min(i + 32, total_bits) + for j in range(i, end): + data[j] = 1 if (word & 1) else (-1 if mode == "bipolar" else 0) + word >>= 1 + + return data.reshape(num_vectors, dimensions) + + def __repr__(self) -> str: + return f"XORShift32(state=0x{self.state:08x})" + + +class XORShift64: + r"""Marsaglia 64-bit xorshift PRNG. + + Same algorithm as :class:`XORShift32` but for 64-bit state: + + .. code-block:: text + + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + + Args: + seed: Initial 64-bit seed (non-zero). Default: ``88172645463325252``. + + Examples:: + + >>> rng = XORShift64(seed=12345) + >>> rng.next() + 13219153759679143922 + """ + + state: int + + def __init__(self, seed: int = 88172645463325252) -> None: + if seed == 0: + raise ValueError("Seed must be non-zero") + self.state = seed & 0xFFFFFFFFFFFFFFFF + + def next(self) -> int: + """Return the next 64-bit random value.""" + x = self.state + x ^= (x << 13) & 0xFFFFFFFFFFFFFFFF + x ^= (x >> 7) + x ^= (x << 17) & 0xFFFFFFFFFFFFFFFF + self.state = x + return x + + def next_bit(self) -> int: + """Return a single random bit.""" + return self.next() & 1 + + def next_tensor( + self, + num_vectors: int, + dimensions: int, + *, + mode: Literal["bipolar", "binary"] = "bipolar", + device=None, + ) -> Tensor: + r"""Generate hypervectors deterministically. + + Args: + num_vectors: Number of hypervectors. + dimensions: Dimensionality per hypervector. + mode: ``"bipolar"`` or ``"binary"``. + device: Desired torch device. + + Returns: + Tensor of shape ``(num_vectors, dimensions)``. + """ + total_bits = num_vectors * dimensions + if mode == "bipolar": + data = torch.zeros(total_bits, dtype=torch.int8, device=device) + else: + data = torch.zeros(total_bits, dtype=torch.uint8, device=device) + + for i in range(0, total_bits, 64): + word = self.next() + end = min(i + 64, total_bits) + for j in range(i, end): + data[j] = 1 if (word & 1) else (-1 if mode == "bipolar" else 0) + word >>= 1 + + return data.reshape(num_vectors, dimensions) + + def __repr__(self) -> str: + return f"XORShift64(state=0x{self.state:016x})" + + +# --------------------------------------------------------------------------- +# Verilog export for LFSR +# --------------------------------------------------------------------------- + +def export_lfsr_verilog( + lfsr: LFSR, + filepath: str, + *, + module_name: str = "lfsr", + clk_name: str = "clk", + rst_name: str = "rst_n", + en_name: str = "en", + out_name: str = "out", +) -> None: + r"""Export a synthesizable Verilog LFSR module matching the Python instance. + + The generated module is a standard Galois LFSR that, when reset and clocked + identically, produces the same output bitstream as the Python :class:`LFSR`. + + Args: + lfsr: The Python LFSR instance to export. + filepath: Output ``.v`` file path. + module_name: Verilog module name. + clk_name: Clock signal name. + rst_name: Reset signal name (active low). + en_name: Enable signal name. + out_name: Output bit signal name. + """ + width = lfsr.width + poly = lfsr.polynomial + seed = lfsr.state # current state (or original seed would be saved) + + # Format the polynomial as a bit string for readability + poly_bits = f"{poly:0{width}b}" + + lines = [] + lines.append("// Auto-generated Galois LFSR") + lines.append(f"// Width: {width}, Polynomial: 0x{poly:0{width//4}x}") + lines.append(f"// Seed: 0x{seed:0{width//4}x}") + lines.append("") + lines.append(f"module {module_name} (") + lines.append(f" input wire {clk_name},") + lines.append(f" input wire {rst_name},") + lines.append(f" input wire {en_name},") + lines.append(f" output wire {out_name}") + lines.append(");") + lines.append("") + lines.append(f" reg [{width-1}:0] lfsr_state;") + lines.append("") + lines.append(f" assign {out_name} = lfsr_state[0];") + lines.append("") + lines.append(f" always @(posedge {clk_name} or negedge {rst_name}) begin") + lines.append(f" if (!{rst_name}) begin") + lines.append(f" lfsr_state <= {width}'h{seed:0{(width+3)//4}x};") + lines.append(f" end else if ({en_name}) begin") + lines.append(f" if (lfsr_state[0])") + lines.append(f" lfsr_state <= (lfsr_state >> 1) ^ {width}'h{poly:0{(width+3)//4}x};") + lines.append(f" else") + lines.append(f" lfsr_state <= lfsr_state >> 1;") + lines.append(f" end") + lines.append(f" end") + lines.append("") + lines.append(f"endmodule") + + with open(filepath, "w") as f: + f.write("\n".join(lines) + "\n") \ No newline at end of file diff --git a/torchhd/verif.py b/torchhd/verif.py new file mode 100644 index 00000000..bffaa838 --- /dev/null +++ b/torchhd/verif.py @@ -0,0 +1,429 @@ +# +# MIT License +# +# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +r"""Golden trace generation and SystemVerilog testbench export. + +This module captures the cycle-by-cycle behavior of a torchhd model and +renders it into synthesizable SystemVerilog testbench components. The +workflow is: + +1. Wrap your model in :class:`GoldenTrace`. +2. Run your training or inference loop — every ``step``, ``add``, or encoding + call is recorded with its inputs and outputs. +3. Call :meth:`GoldenTrace.write_sv_stimulus` to emit an ``initial`` block + that replays every input transition at the correct cycle offset. +4. Call :meth:`GoldenTrace.write_sv_checker` to emit a parallel checker block + with ``assert`` statements comparing RTL outputs to golden expected values. + +Classes +------- +GoldenTrace + Cycle-accurate log of model I/O for testbench generation. +""" + +from typing import Optional, List, Dict, Any, Tuple, Union +import math + +import torch +from torch import Tensor + + +__all__ = [ + "GoldenTrace", +] + + +class GoldenTrace: + r"""Cycle-accurate trace of model inputs and expected outputs. + + Records every operation performed on a wrapped model: encoding calls, + weight updates, forward passes. Produces a timestamped log that can be + rendered into SystemVerilog stimulus/checker blocks. + + Args: + name: Module display name for generated comments. + clk_period_ns: Clock period for the stimulus file in nanoseconds. + Default: 10 (100 MHz). + tolerance: Acceptable absolute error between RTL output and golden + output for floating-point comparisons. For integer/fixed-point + comparisons tolerance is in LSBs. + + Attributes: + events (list): List of recorded events. Each event is a dict with: + ``cycle``, ``kind``, ``input_data``, ``output_data``, ``metadata``. + """ + + name: str + clk_period_ns: float + tolerance: float + events: List[Dict[str, Any]] + + def __init__( + self, + name: str = "hd_model", + clk_period_ns: float = 10.0, + tolerance: float = 1e-6, + ) -> None: + self.name = name + self.clk_period_ns = clk_period_ns + self.tolerance = tolerance + self.events = [] + self._cycle = 0 + + def _record( + self, + kind: str, + inputs: Dict[str, Any], + outputs: Any, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Record one event.""" + event = { + "cycle": self._cycle, + "kind": kind, + "inputs": inputs, + "outputs": outputs, + "metadata": metadata or {}, + } + self.events.append(event) + + def record_encode( + self, + input_data: Tensor, + encoded: Tensor, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + r"""Record an encoding step (input → hypervector). + + Args: + input_data: Raw input features. + encoded: Resulting hypervector(s). + metadata: Optional extra info (feature names, etc.). + """ + self._record( + "encode", + {"input": input_data.detach().clone()}, + encoded.detach().clone(), + metadata, + ) + self._cycle += 1 + + def record_step( + self, + input_hv: Tensor, + target: Any, + outputs: Optional[Tensor] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + r"""Record a training step (hypervector + label → update). + + Args: + input_hv: Input hypervector. + target: Class label or target vector. + outputs: Model output logits (if prediction was made). + metadata: Optional extra info. + """ + if isinstance(target, torch.Tensor): + tgt_copy = target.detach().clone() + elif isinstance(target, (int, float)): + tgt_copy = target + else: + tgt_copy = str(target) + + self._record( + "step", + {"input": input_hv.detach().clone(), "target": tgt_copy}, + outputs.detach().clone() if outputs is not None and isinstance(outputs, torch.Tensor) else outputs, + metadata, + ) + self._cycle += 1 + + def record_forward( + self, + input_hv: Tensor, + outputs: Tensor, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + r"""Record an inference forward pass. + + Args: + input_hv: Input hypervector. + outputs: Model output logits or class scores. + metadata: Optional extra info. + """ + self._record( + "forward", + {"input": input_hv.detach().clone()}, + outputs.detach().clone(), + metadata, + ) + self._cycle += 1 + + def record_custom( + self, + kind: str, + inputs: Dict[str, Any], + outputs: Any, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + r"""Record a custom event with arbitrary I/O. + + Args: + kind: Event type label. + inputs: Dict of input name → tensor/scalar. + outputs: Output tensor or scalar. + metadata: Optional extra info. + """ + inputs_clean = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + inputs_clean[k] = v.detach().clone() + else: + inputs_clean[k] = v + + if isinstance(outputs, torch.Tensor): + outputs_clean = outputs.detach().clone() + else: + outputs_clean = outputs + + self._record(kind, inputs_clean, outputs_clean, metadata) + self._cycle += 1 + + # ---------------------------------------------------------------- + # SystemVerilog export + # ---------------------------------------------------------------- + + def write_sv_stimulus( + self, + filepath: str, + *, + input_bits: int = 8, + output_bits: int = 32, + signal_map: Optional[Dict[str, str]] = None, + ) -> None: + r"""Emit a SystemVerilog stimulus ``initial`` block. + + Generates an ``initial begin ... end`` block that replays every + recorded input at the corresponding cycle boundary using ``#CLK_PERIOD`` + delays. The resulting file can be ``\`include``-d into a testbench. + + Args: + filepath: Output ``.sv`` file path. + input_bits: Bit-width of input signals (used for formatting). + output_bits: Bit-width of output signals. + signal_map: Optional mapping from trace signal names to RTL + signal names (e.g. ``{"input": "hv_in", "target": "class_in"}``). + + Example output:: + + // Stimulus for hd_model + // Cycle 0: encode + hv_in = 256'hA3F4...; + #10; // wait for clock + // Cycle 1: step + hv_in = 256'h12C0...; + class_in = 8'd3; + #10; + """ + if signal_map is None: + signal_map = { + "input": "hv_in", + "target": "class_in", + } + + lines = [] + lines.append(f"// Auto-generated stimulus for {self.name}") + lines.append(f"// {len(self.events)} events, clk_period = {self.clk_period_ns} ns") + lines.append("") + lines.append(f"initial begin") + lines.append(f" // Reset / wait for reset") + lines.append(f" #({self.clk_period_ns} * 2);") + lines.append("") + + for i, event in enumerate(self.events): + kind = event["kind"] + cycle = event["cycle"] + lines.append(f" // Cycle {cycle} ({kind})") + + # Emit input assignments + for iname, ivalue in event["inputs"].items(): + sig = signal_map.get(iname, iname) + if isinstance(ivalue, torch.Tensor): + flat = ivalue.flatten() + if flat.numel() == 1: + val = int(flat[0]) & ((1 << input_bits) - 1) + lines.append(f" {sig} = {input_bits}'d{val};") + else: + # Multi-element vector: emit as concatenation + hex_str = "" + for v in flat: + hex_str += f"{int(v) & ((1 << input_bits) - 1):02x}" + # SystemVerilog hex literal + bit_width = flat.numel() * input_bits + lines.append(f" {sig} = {bit_width}'h{hex_str};") + elif isinstance(ivalue, int): + lines.append(f" {sig} = {input_bits}'d{ivalue};") + else: + lines.append(f" {sig} = {ivalue};") + + lines.append(f" #({self.clk_period_ns}); // wait for clock") + lines.append("") + + lines.append(f" // Done") + lines.append(f" #({self.clk_period_ns} * 5);") + lines.append(f" $finish;") + lines.append(f"end") + + with open(filepath, "w") as f: + f.write("\n".join(lines) + "\n") + + def write_sv_checker( + self, + filepath: str, + *, + input_bits: int = 8, + output_bits: int = 32, + signal_map: Optional[Dict[str, str]] = None, + tolerance_lsb: int = 1, + ) -> None: + r"""Emit a SystemVerilog checker block with ``assert`` statements. + + For each recorded event that has output data, generates an assertion + comparing the RTL output signal to the golden expected value. + + Args: + filepath: Output ``.sv`` file path. + input_bits: Bit-width of input signals. + output_bits: Bit-width of output signals. + signal_map: Mapping from trace signal names to RTL signal names. + Default maps ``"output"`` → ``"hv_out"``, ``"score"`` → ``"score_out"``. + tolerance_lsb: Allowed LSB error for integer/fixed-point comparisons. + + Example output:: + + // Checker for hd_model + always @(posedge clk) begin + if (check_cycle == 0) begin + assert (hv_out == 256'hEXPECTED) else $error("Cycle 0 mismatch"); + end + ... + end + """ + if signal_map is None: + signal_map = { + "output": "hv_out", + "score": "score_out", + "class": "class_out", + } + + lines = [] + lines.append(f"// Auto-generated checker for {self.name}") + lines.append(f"// {len(self.events)} events, tolerance = ±{tolerance_lsb} LSB") + lines.append("") + + # Attempt to infer a sensible output signal name from the events + output_signal = "hv_out" + for event in self.events: + if event["outputs"] is not None: + if isinstance(event["outputs"], torch.Tensor) and event["outputs"].numel() > 1: + output_signal = signal_map.get("output", "hv_out") + else: + output_signal = signal_map.get("score", "score_out") + break + + lines.append(f"// Output signal: {output_signal}") + lines.append("") + + # Counter + lines.append(f"reg [31:0] check_cycle = 0;") + lines.append(f"reg [31:0] check_errors = 0;") + lines.append("") + lines.append(f"always @(posedge clk) begin") + lines.append(f" if (!rst_n) begin") + lines.append(f" check_cycle <= 0;") + lines.append(f" check_errors <= 0;") + lines.append(f" end else begin") + lines.append(f" case (check_cycle)") + + for i, event in enumerate(self.events): + if event["outputs"] is None: + continue + outputs = event["outputs"] + if not isinstance(outputs, torch.Tensor): + continue + + flat = outputs.flatten() + if flat.numel() == 1: + expected = int(flat[0]) + lines.append(f" {i}: begin") + lines.append(f" if ({output_signal} < ({expected} - {tolerance_lsb}) || " + f"{output_signal} > ({expected} + {tolerance_lsb})) begin") + lines.append(f' $error("Cycle %0d: expected ~%0d, got %0d", ' + f"check_cycle, {expected}, {output_signal});") + lines.append(f" check_errors <= check_errors + 1;") + lines.append(f" end") + lines.append(f" end") + else: + # Multi-element: emit concatenated expected + hex_str = "" + for v in flat: + hex_str += f"{int(v) & ((1 << output_bits) - 1):08x}" + bit_width = flat.numel() * output_bits + lines.append(f" {i}: begin") + lines.append(f" if ({output_signal} !== {bit_width}'h{hex_str}) begin") + lines.append(f' $error("Cycle %0d: output mismatch", check_cycle);') + lines.append(f" check_errors <= check_errors + 1;") + lines.append(f" end") + lines.append(f" end") + + lines.append(f" default: ;") + lines.append(f" endcase") + lines.append(f" check_cycle <= check_cycle + 1;") + lines.append(f" end") + lines.append(f"end") + + with open(filepath, "w") as f: + f.write("\n".join(lines) + "\n") + + # ---------------------------------------------------------------- + # Summary + # ---------------------------------------------------------------- + + def summary(self) -> str: + """Return a human-readable summary of the trace.""" + kinds = {} + total_inputs = 0 + total_outputs = 0 + for e in self.events: + kinds[e["kind"]] = kinds.get(e["kind"], 0) + 1 + if e["outputs"] is not None: + total_outputs += 1 + + lines = [ + f"GoldenTrace({self.name!r}): {len(self.events)} events over {self._cycle} cycles", + f" Clock: {self.clk_period_ns} ns", + f" Tolerance: {self.tolerance}", + ] + for kind, count in sorted(kinds.items()): + lines.append(f" {kind}: {count}") + return "\n".join(lines) \ No newline at end of file