diff --git a/README.md b/README.md index 0555b85..6e9e9b9 100644 --- a/README.md +++ b/README.md @@ -348,6 +348,14 @@ pytest - 📚 Documentation - 🐛 Bug fixes +**Quick template for new architecture support:** +```python +from quantllm import register_architecture, turbo + +register_architecture("new-arch", base_model_type="llama") +model = turbo("org/new-arch-7b", base_model_fallback=True, trust_remote_code=True) +``` + --- ## 📜 License diff --git a/docs/guide/loading-models.md b/docs/guide/loading-models.md index 3bf1010..12de986 100644 --- a/docs/guide/loading-models.md +++ b/docs/guide/loading-models.md @@ -74,6 +74,59 @@ model = turbo( ) ``` +### New Architecture Fallbacks (for very recent model releases) + +If `transformers` does not recognize a just-released architecture yet, register a fallback family: + +```python +from quantllm import turbo, register_architecture + +# Map new architecture/model_type to a compatible base family +register_architecture("newmodel", base_model_type="llama") + +model = turbo( + "new-model-org/NewModel-7B", + model_type_override="llama", # optional explicit override + base_model_fallback=True, # enabled by default; can be disabled + trust_remote_code=True, +) +``` + +> ⚠️ **Security note:** `trust_remote_code=True` executes model-provided code. +> Only enable it for trusted publishers, especially when loading unregistered or very new architectures. + +You can also load from config only (no checkpoint weights) while waiting for upstream support: + +```python +model = turbo( + "new-model-org/NewModel-7B", + from_config_only=True, + trust_remote_code=True, +) +``` + +#### Fast contribution template for new architectures + +1. Add a registration in your code or PR: + - `register_architecture("new-arch", base_model_type="llama")` +2. Validate loading with: + - `turbo("org/model", base_model_fallback=True, trust_remote_code=True)` +3. Add/extend a focused test in `tests/test_architecture_fallback.py`. 
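+
+A minimal sketch of such a test, following the registry-monkeypatch pattern used in `tests/test_architecture_fallback.py` (model names are illustrative):
+
+```python
+from quantllm.core.turbo_model import TurboModel
+
+
+def test_new_arch_resolves_to_llama(monkeypatch):
+    # Isolate the class-level registry so the test leaves no global state behind.
+    monkeypatch.setattr(TurboModel, "_architecture_registry", {})
+    TurboModel.register_architecture("new-arch", base_model_type="llama")
+    assert TurboModel.resolve_model_type("org/new-arch-7b") == "llama"
+```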
+ +#### Real-world style "released yesterday" example + +```python +from quantllm import turbo, register_architecture + +# Example: transformers doesn't recognize Qwen3 yet +register_architecture("qwen3", base_model_type="qwen2") + +model = turbo( + "Qwen/Qwen3-8B", + trust_remote_code=True, +) +``` + ### Memory Options ```python diff --git a/quantllm/__init__.py b/quantllm/__init__.py index 6f2933b..da0b7fe 100644 --- a/quantllm/__init__.py +++ b/quantllm/__init__.py @@ -35,6 +35,7 @@ from .core import ( turbo, TurboModel, + register_architecture, SmartConfig, HardwareProfiler, ModelAnalyzer, @@ -117,6 +118,7 @@ def show_banner(force: bool = False): # Main API "turbo", "TurboModel", + "register_architecture", "SmartConfig", "HardwareProfiler", "ModelAnalyzer", diff --git a/quantllm/core/__init__.py b/quantllm/core/__init__.py index 5e64f1a..823ca59 100644 --- a/quantllm/core/__init__.py +++ b/quantllm/core/__init__.py @@ -8,7 +8,7 @@ from .hardware import HardwareProfiler from .smart_config import SmartConfig from .model_analyzer import ModelAnalyzer -from .turbo_model import TurboModel, turbo +from .turbo_model import TurboModel, turbo, register_architecture from .compilation import ( compile_model, compile_for_inference, @@ -51,6 +51,7 @@ "ModelAnalyzer", "TurboModel", "turbo", + "register_architecture", # Compilation "compile_model", "compile_for_inference", diff --git a/quantllm/core/turbo_model.py b/quantllm/core/turbo_model.py index 53ec668..ffabb37 100644 --- a/quantllm/core/turbo_model.py +++ b/quantllm/core/turbo_model.py @@ -5,9 +5,12 @@ """ import os +import re import shutil import tempfile -from typing import Optional, Dict, Any, Union, List +import copy +from functools import lru_cache +from typing import Optional, Dict, Any, Union, List, Type import torch import torch.nn as nn from transformers import ( @@ -32,6 +35,14 @@ "quantization": "Q4_K_M", "push_quantization": None, } +DEFAULT_ARCHITECTURE_FALLBACKS = { + "llama": "llama", + "mistral": "mistral", + "mixtral": "mistral", + "qwen": "qwen2", + "phi": "phi", + "gemma": "gemma", +} class TurboModel: @@ -58,6 +69,9 @@ class TurboModel: >>> model.export("gguf", "my_model.gguf") """ + _architecture_registry: Dict[str, str] = {} + _model_class_registry: Dict[str, Type[PreTrainedModel]] = {} + def __init__( self, model: PreTrainedModel, @@ -82,6 +96,180 @@ def __init__( self._lora_applied = False self.export_push_config = self._build_export_push_config(export_push_config) self.verbose = verbose + + @classmethod + def register_architecture( + cls, + architecture: str, + *, + base_model_type: Optional[str] = None, + model_class: Optional[Type[PreTrainedModel]] = None, + ) -> None: + """ + Register a new architecture alias and optional explicit model class. + + Args: + architecture: Architecture or model type name to register + base_model_type: Base model family to fall back to (e.g., "llama") + model_class: Explicit model class with from_pretrained() + """ + normalized = architecture.lower().strip() + if not normalized: + raise ValueError("architecture must be a non-empty string") + + if base_model_type: + cls._architecture_registry[normalized] = base_model_type.lower().strip() + + if model_class is not None: + cls._model_class_registry[normalized] = model_class + + @classmethod + def resolve_model_type( + cls, + model_name: str, + *, + config_model_type: Optional[str] = None, + model_type_override: Optional[str] = None, + ) -> Optional[str]: + """ + Resolve model type using override, registry, and default family patterns. 
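+
+        Resolution order: model_type_override wins, then the architecture
+        registry keyed by the config's model_type, then registered and
+        default name-pattern fallbacks matched against model_name.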
+ + If config_model_type is provided but unregistered, the original config value + is returned unchanged. + """ + if model_type_override: + return model_type_override.lower().strip() + + model_type = (config_model_type or "").lower().strip() + if model_type: + return cls._architecture_registry.get(model_type, model_type) + + name = model_name.lower() + for pattern, fallback in cls._architecture_registry.items(): + if cls._matches_model_name_pattern(name, pattern): + return fallback + + for pattern, fallback in DEFAULT_ARCHITECTURE_FALLBACKS.items(): + if cls._matches_model_name_pattern(name, pattern): + return fallback + + return None + + @classmethod + def _matches_model_name_pattern(cls, model_name: str, pattern: str) -> bool: + """Return True when pattern appears as a token in model_name.""" + return cls._compiled_model_name_pattern(pattern).search(model_name) is not None + + @staticmethod + @lru_cache(maxsize=None) + def _compiled_model_name_pattern(pattern: str): + """Compile and cache token-boundary regex patterns for model-name matching.""" + escaped = re.escape(pattern) + # Match architecture tokens as standalone chunks split by separators. + return re.compile(rf"(^|[^a-z0-9]){escaped}([^a-z0-9]|$)") + + @staticmethod + def _should_apply_quantization( + quantize: bool, + bits: int, + from_config_only: bool, + ) -> bool: + """Check whether quantization arguments should be added for loading.""" + return quantize and bits < 16 and not from_config_only + + @classmethod + def _load_model_with_fallback( + cls, + model_name: str, + model_kwargs: Dict[str, Any], + *, + trust_remote_code: bool, + hf_config: Optional[Any], + model_type_override: Optional[str], + base_model_fallback: bool, + from_config_only: bool, + ) -> PreTrainedModel: + """Load model with architecture fallback and optional config-only mode.""" + config_model_type = (getattr(hf_config, "model_type", None) or "").lower().strip() + is_registered_architecture = config_model_type in cls._architecture_registry if config_model_type else False + resolved_model_type = cls.resolve_model_type( + model_name, + config_model_type=getattr(hf_config, "model_type", None), + model_type_override=model_type_override, + ) + resolved_config = hf_config + + if hf_config is not None and resolved_model_type: + current_model_type = getattr(hf_config, "model_type", None) + if current_model_type != resolved_model_type: + resolved_config = copy.deepcopy(hf_config) + setattr(resolved_config, "model_type", resolved_model_type) + + if ( + trust_remote_code + and config_model_type + and not is_registered_architecture + and config_model_type not in DEFAULT_ARCHITECTURE_FALLBACKS.values() + ): + logger.warning( + "trust_remote_code=True is enabled for unregistered architecture '%s' " + "(resolved fallback: '%s'). Only use this for models from trusted sources.", + config_model_type, + resolved_model_type, + ) + + if from_config_only: + if resolved_config is None: + raise ValueError( + "from_config_only=True requires a loadable config. " + "Try trust_remote_code=True or set model_type_override." + ) + return AutoModelForCausalLM.from_config( + resolved_config, + trust_remote_code=trust_remote_code, + torch_dtype=model_kwargs.get("torch_dtype"), + ) + + try: + return AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + except Exception as primary_error: + if not base_model_fallback: + raise + fallback_error = None + # Fallback priority: resolved config model_type -> explicitly registered model class. 
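+            # First retry: reload with the resolved base-family config so that
+            # transformers dispatches to a known model class while the original
+            # checkpoint weights are still downloaded and loaded.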
+ if resolved_config is not None: + fallback_kwargs = dict(model_kwargs) + fallback_kwargs["config"] = resolved_config + try: + return AutoModelForCausalLM.from_pretrained(model_name, **fallback_kwargs) + except Exception as fallback_config_error: + fallback_error = fallback_config_error + + if resolved_model_type: + registered_cls = cls._model_class_registry.get(resolved_model_type) + if registered_cls is not None: + class_kwargs = dict(model_kwargs) + if resolved_config is not None: + class_kwargs["config"] = resolved_config + try: + return registered_cls.from_pretrained(model_name, **class_kwargs) + except Exception as fallback_registered_error: + fallback_error = fallback_registered_error + + error_details = f" Last fallback error: {fallback_error}" if fallback_error else "" + architecture_label = config_model_type or "" + resolved_label = resolved_model_type or "" + + raise RuntimeError( + "Failed to load model with AutoModelForCausalLM and fallback resolution.\n" + f"Architecture '{architecture_label}' resolved to base model type '{resolved_label}'.\n" + "Try one of:\n" + f"1) register_architecture('{architecture_label}', base_model_type='llama').\n" + "2) Use model_type_override='llama' (or your compatible base family).\n" + "3) Use from_config_only=True with a loadable config " + "(usually trust_remote_code=True)." + + error_details + ) from (fallback_error or primary_error) @classmethod def from_pretrained( @@ -96,6 +284,9 @@ def from_pretrained( # Advanced options trust_remote_code: bool = True, quantize: bool = True, + model_type_override: Optional[str] = None, + base_model_fallback: bool = True, + from_config_only: bool = False, config_override: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, verbose: bool = True, @@ -117,6 +308,9 @@ def from_pretrained( dtype: Override dtype (default: bf16 if available, else fp16) trust_remote_code: Trust remote code in model quantize: Whether to quantize the model + model_type_override: Override detected model_type for very new architectures + base_model_fallback: Retry loading with resolved base model config on failure + from_config_only: Build model from config only (without loading weights) config_override: Dict to override any auto-detected settings config: Shared export/push config (format, quantization, push_format, etc.) 
verbose: Print loading progress @@ -196,6 +390,8 @@ def from_pretrained( "torch_dtype": smart_config.dtype, } + hf_config = None + # Check if model is already quantized to prevent conflicts try: from transformers import AutoConfig @@ -225,7 +421,7 @@ def from_pretrained( pass # Ignore config loading errors, proceed with defaults # Apply quantization if requested - if quantize and smart_config.bits < 16: + if cls._should_apply_quantization(quantize, smart_config.bits, from_config_only): model_kwargs.update(cls._get_quantization_kwargs(smart_config)) # Device map for memory management @@ -240,9 +436,14 @@ def from_pretrained( if verbose: task = p.add_task("Downloading & Loading model...", total=None) - model = AutoModelForCausalLM.from_pretrained( + model = cls._load_model_with_fallback( model_name, - **model_kwargs, + model_kwargs, + trust_remote_code=trust_remote_code, + hf_config=hf_config, + model_type_override=model_type_override, + base_model_fallback=base_model_fallback, + from_config_only=from_config_only, ) if verbose: @@ -1892,6 +2093,25 @@ def _replace_with_triton(self, module: nn.Module, bits: int) -> int: return count +def register_architecture( + architecture: str, + *, + base_model_type: Optional[str] = None, + model_class: Optional[Type[PreTrainedModel]] = None, +) -> None: + """ + Register a new architecture alias and optional explicit model class. + + Example: + >>> register_architecture("my-new-model", base_model_type="llama") + """ + TurboModel.register_architecture( + architecture, + base_model_type=base_model_type, + model_class=model_class, + ) + + def turbo( model: str, *, @@ -1899,6 +2119,7 @@ def turbo( max_length: Optional[int] = None, device: Optional[str] = None, dtype: Optional[str] = None, + base_model_fallback: bool = True, config: Optional[Dict[str, Any]] = None, **kwargs, ) -> TurboModel: @@ -1914,6 +2135,7 @@ def turbo( max_length: Override max sequence length (default: auto) device: Override device (default: best GPU) dtype: Override dtype (default: bf16/fp16) + base_model_fallback: Retry with resolved base model config on first-load failure config: Shared export/push config (format, quantization, push_format, etc.) **kwargs: Additional options passed to from_pretrained @@ -1945,6 +2167,7 @@ def turbo( max_length=max_length, device=device, dtype=dtype, + base_model_fallback=base_model_fallback, config=config, **kwargs, ) diff --git a/tests/test_architecture_fallback.py b/tests/test_architecture_fallback.py new file mode 100644 index 0000000..9382178 --- /dev/null +++ b/tests/test_architecture_fallback.py @@ -0,0 +1,266 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import transformers + +from quantllm.core.turbo_model import TurboModel +import quantllm.core.turbo_model as turbo_model_module + + +class _DummySmartConfig(SimpleNamespace): + def print_summary(self): + return None + + +def _make_smart_config(): + return _DummySmartConfig( + bits=16, + effective_loading_bits=16, + dtype="float16", + cpu_offload=False, + device="cpu", + gradient_checkpointing=False, + use_flash_attention=False, + compile_model=False, + ) + + +def _make_tokenizer(): + return SimpleNamespace(pad_token=None, eos_token="", eos_token_id=2) + + +def test_resolve_model_type_detects_common_patterns(): + assert TurboModel.resolve_model_type("meta-llama/Llama-3.2-3B") == "llama" + # Newer Qwen names still fall back to the qwen2 base family. 
+ assert TurboModel.resolve_model_type("Qwen/Qwen3-8B") == "qwen2" + assert TurboModel.resolve_model_type("org/custom-arch-1b") is None + + +def test_register_architecture_maps_new_model_to_base_family(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + TurboModel.register_architecture("newmodel", base_model_type="llama") + + assert TurboModel.resolve_model_type("org/newmodel-7b") == "llama" + + +def test_registered_class_fallback_is_used(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *args, **kwargs: _make_smart_config(), + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _make_tokenizer(), + ) + monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: SimpleNamespace( + model_type="newmodel", + quantization_config=None, + ), + ) + + class _FakeAutoModel: + @staticmethod + def from_pretrained(*args, **kwargs): + raise ValueError("Unrecognized configuration class") + + @staticmethod + def from_config(*args, **kwargs): + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + registered_call = Mock() + + def _registered_from_pretrained(cls, *args, **kwargs): + registered_call() + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + class _RegisteredModel: + from_pretrained = classmethod(_registered_from_pretrained) + + monkeypatch.setattr( + turbo_model_module, + "AutoModelForCausalLM", + _FakeAutoModel, + ) + + TurboModel.register_architecture("newmodel", base_model_type="llama") + TurboModel.register_architecture("llama", model_class=_RegisteredModel) + + loaded = TurboModel.from_pretrained( + "org/newmodel-7b", + quantize=False, + verbose=False, + ) + + assert registered_call.called is True + assert loaded.model.config.model_type == "llama" + + +def test_from_pretrained_supports_from_config_only(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *args, **kwargs: _make_smart_config(), + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _make_tokenizer(), + ) + monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: SimpleNamespace( + model_type="llama", + quantization_config=None, + ), + ) + + class _FakeAutoModel: + called_from_pretrained = False + called_from_config = False + + @classmethod + def from_pretrained(cls, *args, **kwargs): + cls.called_from_pretrained = True + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + @classmethod + def from_config(cls, *args, **kwargs): + cls.called_from_config = True + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + monkeypatch.setattr( + turbo_model_module, + "AutoModelForCausalLM", + _FakeAutoModel, + ) + + loaded = TurboModel.from_pretrained( + "org/llama-like-7b", + quantize=False, + verbose=False, + from_config_only=True, + ) + + assert _FakeAutoModel.called_from_pretrained is False + assert _FakeAutoModel.called_from_config is True + assert loaded.model.config.model_type == "llama" + + +def 
test_trust_remote_code_warns_for_unregistered_architecture(monkeypatch, caplog): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *args, **kwargs: _make_smart_config(), + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _make_tokenizer(), + ) + monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: SimpleNamespace( + model_type="newmodel", + quantization_config=None, + ), + ) + + class _FakeAutoModel: + @staticmethod + def from_pretrained(*args, **kwargs): + if "config" in kwargs: + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + raise ValueError("Unrecognized configuration class") + + monkeypatch.setattr( + turbo_model_module, + "AutoModelForCausalLM", + _FakeAutoModel, + ) + + with caplog.at_level("WARNING"): + loaded = TurboModel.from_pretrained( + "org/newmodel-7b", + quantize=False, + verbose=False, + base_model_fallback=True, + trust_remote_code=True, + ) + + assert loaded.model.config.model_type == "llama" + assert ( + "trust_remote_code=True is enabled for unregistered architecture 'newmodel'" + in caplog.text + ) + + +def test_quantization_kwargs_are_preserved_during_fallback(monkeypatch): + monkeypatch.setattr(TurboModel, "_architecture_registry", {}) + monkeypatch.setattr(TurboModel, "_model_class_registry", {}) + smart_config = _make_smart_config() + smart_config.bits = 4 + monkeypatch.setattr( + turbo_model_module.SmartConfig, + "detect", + lambda *args, **kwargs: smart_config, + ) + monkeypatch.setattr( + turbo_model_module.AutoTokenizer, + "from_pretrained", + lambda *args, **kwargs: _make_tokenizer(), + ) + monkeypatch.setattr( + transformers.AutoConfig, + "from_pretrained", + lambda *args, **kwargs: SimpleNamespace( + model_type="newmodel", + quantization_config=None, + ), + ) + monkeypatch.setattr( + TurboModel, + "_get_quantization_kwargs", + classmethod(lambda cls, cfg: {"quantization_config": "nf4-sentinel"}), + ) + + calls = [] + + class _FakeAutoModel: + @staticmethod + def from_pretrained(*args, **kwargs): + calls.append(kwargs) + if len(calls) == 1: + raise ValueError("Unrecognized configuration class") + return SimpleNamespace(config=SimpleNamespace(model_type="llama")) + + monkeypatch.setattr( + turbo_model_module, + "AutoModelForCausalLM", + _FakeAutoModel, + ) + + loaded = TurboModel.from_pretrained( + "org/newmodel-7b", + quantize=True, + verbose=False, + base_model_fallback=True, + ) + + assert loaded.model.config.model_type == "llama" + assert len(calls) == 2 + assert calls[0]["quantization_config"] == "nf4-sentinel" + assert calls[1]["quantization_config"] == "nf4-sentinel"
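---

The guide above only demonstrates the `base_model_type` alias path of `register_architecture`. A sketch of the explicit `model_class` path, mirroring what `test_registered_class_fallback_is_used` exercises (model names are illustrative, and `LlamaForCausalLM` stands in for any class exposing `from_pretrained`):

```python
from transformers import LlamaForCausalLM

from quantllm import register_architecture, turbo

# Alias the unknown architecture onto the llama family...
register_architecture("new-arch", base_model_type="llama")
# ...and pin an explicit class that the loader tries when
# AutoModelForCausalLM still cannot dispatch on the resolved config.
register_architecture("llama", model_class=LlamaForCausalLM)

model = turbo("org/new-arch-7b", base_model_fallback=True, trust_remote_code=True)
```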