CPU overhead optimizations for te autocast #2957
vthumbe1503 wants to merge 1 commit into NVIDIA:main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Greptile Summary

This PR reduces CPU overhead on the FP8 autocast hot path via two complementary changes: (1) lazy caching of the recipe's string representation, and (2) a class-based context manager replacing the contextlib generator-based one.

Confidence Score: 4/5. Safe to merge — no P0/P1 defects; all findings are style or edge-case hardening suggestions. All three comments are P2: one on a fragile key separator, one on a silent-corruption edge case for same-instance nested reuse (not a realistic usage pattern), and one on undeclared dataclass field storage. The core logic of caching, invalidation, and context manager enter/exit is sound.

Important Files Changed: transformer_engine/pytorch/quantization.py — specifically get_unique_autocast_key and the autocast __enter__/__exit__ pair.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant AC as autocast (class)
    participant FP8GM as FP8GlobalStateManager
    participant Recipe as Recipe.__repr__
    User->>AC: autocast(enabled, recipe, ...)
    AC->>AC: __init__: store params, _fp8_state=None
    User->>AC: __enter__()
    AC->>AC: check_recipe_support (if enabled)
    AC->>FP8GM: get_autocast_state()
    FP8GM-->>AC: fp8_state saved in self._fp8_state
    AC->>FP8GM: autocast_enter(enabled, recipe, group, ...)
    FP8GM->>FP8GM: get_unique_autocast_key(recipe, group)
    FP8GM->>Recipe: recipe.__dict__.get(_cached_repr)
    alt cache hit
        Recipe-->>FP8GM: cached repr string
    else cache miss
        Recipe->>Recipe: compute repr, store in __dict__
        Recipe-->>FP8GM: repr string
    end
    FP8GM->>FP8GM: key = recipe_repr|group_id
    AC-->>User: self
    User->>AC: __exit__(exc_type, exc_val, exc_tb)
    AC->>FP8GM: set_autocast_state(self._fp8_state)
    AC->>FP8GM: autocast_exit(enabled, _graph)
    AC-->>User: None (exceptions not suppressed)
```
```python
recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
if recipe_repr is None:
    recipe_repr = str(recipe)
group_id = id(group) if group is not None else 0
return f"{recipe_repr}|{group_id}"
```
Key format change could produce ambiguous keys
The new key format f"{recipe_repr}|{group_id}" uses | as a separator without escaping. If a future recipe's __repr__ ever emits a | character, two distinct (recipe, group) pairs could map to the same string. The old str(tuple) format was unambiguous because it quoted the recipe repr. A safer pattern uses a separator that cannot appear in repr output, or encodes the parts deterministically.
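One way to make the key unambiguous, sketched below under the assumption that the function is a free-standing helper (the real one lives on FP8GlobalStateManager): apply repr() to the recipe string so the string is quoted and escaped, and an embedded `|` can no longer collide with the separator.

```python
def get_unique_autocast_key(recipe=None, group=None) -> str:
    """Sketch of an unambiguous key: repr() quotes the recipe string."""
    recipe_repr = str(recipe)
    group_id = id(group) if group is not None else 0
    # repr() wraps the string in quotes and escapes embedded quote characters,
    # so the closing quote unambiguously delimits the recipe part even when
    # recipe_repr itself contains '|'.
    return f"{recipe_repr!r}|{group_id}"
```

Distinct (recipe, group) pairs can then never map to the same key string, at the cost of one extra repr() call per key.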
Suggested change:

```diff
-recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
-if recipe_repr is None:
-    recipe_repr = str(recipe)
-group_id = id(group) if group is not None else 0
-return f"{recipe_repr}|{group_id}"
+group_id = id(group) if group is not None else None
+return f"recipe=({str(recipe)}),group={group_id}"
```
New class-based __enter__:

```python
def __enter__(self) -> "autocast":
    if self._enabled:
        check_recipe_support(self._recipe)
    # Save current state so we always restore it on exit.
    self._fp8_state = FP8GlobalStateManager.get_autocast_state()
    FP8GlobalStateManager.autocast_enter(
        enabled=self._enabled,
        calibrating=self._calibrating,
        fp8_recipe=self._recipe,
        fp8_group=self._amax_reduction_group,
        _graph=self._graph,
    )
    return self
```

Old generator-based body being replaced:

```python
FP8GlobalStateManager.autocast_enter(
    enabled=enabled,
    calibrating=calibrating,
    fp8_recipe=recipe,
    fp8_group=amax_reduction_group,
    _graph=_graph,
)
try:
    yield
finally:
    FP8GlobalStateManager.set_autocast_state(fp8_state)
    FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
```

New class-based __exit__:

```python
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    FP8GlobalStateManager.set_autocast_state(self._fp8_state)
    FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph)
    # Do not suppress exceptions.
    return None
```
Nested reuse of the same instance silently corrupts state
The old generator-based implementation raised RuntimeError: generator already executing if you tried to enter the same context manager object twice concurrently. The new class-based implementation silently accepts nested reuse, but the second __enter__ call overwrites self._fp8_state with the state captured inside the first context, so the outer __exit__ restores the wrong state permanently.
```python
ctx = autocast(enabled=True, recipe=recipe)
with ctx:        # _fp8_state = pre_context_state
    with ctx:    # _fp8_state = state_inside_first_block  <- overwrites!
        pass     # inner __exit__: restores state_inside_first_block
    # _fp8_state is now state_inside_first_block
# outer __exit__: restores state_inside_first_block, NOT pre_context_state  <- bug
```

Adding a guard in __enter__ would preserve the old safety behavior:

```python
def __enter__(self) -> "autocast":
    if self._fp8_state is not None:
        raise RuntimeError(
            "autocast context manager cannot be entered more than once concurrently"
        )
    ...
```

```python
def __repr__(self) -> str:
    cached = self.__dict__.get("_cached_repr")
    if cached is not None:
        return cached
    result = f"MMParams(use_split_accumulator={self.use_split_accumulator})"
    object.__setattr__(self, "_cached_repr", result)
    return result
```
_cached_repr stored outside declared dataclass fields
MMParams is @dataclass(frozen=True). Storing _cached_repr via object.__setattr__ bypasses the frozen guard correctly in CPython, but _cached_repr is not a declared dataclass field — it won't appear in dataclasses.fields(), dataclasses.asdict(), dataclasses.astuple(), or copy.replace(). If downstream code serializes or copies an MMParams instance, the cached repr would be lost silently. Documenting this with a comment or declaring it as field(init=False, repr=False, compare=False) would make the intent clearer. The same applies to QParams.
I see that this is why we're doing the funny accesses with __dict__. I agree that bypassing frozen=True is iffy, so I wonder if we could set _cached_repr in __post_init__? If the class is frozen, its repr must also be frozen and I don't see a benefit in lazy evaluation.
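The __post_init__ variant suggested here can be sketched as follows (illustrative only; QParams' real fields aren't shown in this thread, so MMParams is reused). Since the instance is frozen, its repr can never change, so it is computed eagerly at construction and __repr__ becomes a plain attribute read.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class MMParams:
    use_split_accumulator: bool = True
    _cached_repr: str = field(default="", init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        # Compute the repr once, eagerly; frozen=True requires the
        # object.__setattr__ escape hatch even inside __post_init__.
        object.__setattr__(
            self,
            "_cached_repr",
            f"MMParams(use_split_accumulator={self.use_split_accumulator})",
        )

    def __repr__(self) -> str:
        return self._cached_repr
```

This removes the lazy branch from the hot path entirely, at the cost of computing the string even for instances that are never printed.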
```python
# changes. This makes repeated ``str(recipe)`` calls (e.g. on the hot
# path in ``FP8GlobalStateManager.get_unique_autocast_key``) essentially
# free after the first call.
_cached_repr: Optional[str] = None
```
Three problems:

- `_cached_repr` is being set as a class attr, not an instance attr.
- Accessing `_cached_repr` via `__dict__` is non-standard and bug-prone.
- Splitting the cache logic between the base class and child classes results in code duplication and more risk of bugs, especially if it involves non-standard `__dict__` accesses.

What if we concentrated the caching logic in the base class:

```python
class Recipe:
    def __init__(self) -> None:
        self._cached_repr: Optional[str] = None

    @abc.abstractmethod
    def _make_repr(self) -> str:
        ...

    def __repr__(self) -> str:
        if self._cached_repr is None:
            self._cached_repr = self._make_repr()
        return self._cached_repr

...

class DelayedScaling(Recipe):
    def _make_repr(self) -> str:
        return f"..."
```

```python
# directly getting the cached repr is about 40 ns faster than str(recipe)
# on grace systems.
```
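Filling in the elided pieces, the reviewer's base-class pattern can be sketched as a runnable example. The DelayedScaling constructor argument here is hypothetical, chosen only to give _make_repr something to format; the real recipe classes have different fields.

```python
import abc
from typing import Optional

class Recipe(abc.ABC):
    """All caching lives here; subclasses only describe how to build the repr."""

    def __init__(self) -> None:
        self._cached_repr: Optional[str] = None

    @abc.abstractmethod
    def _make_repr(self) -> str:
        ...

    def __repr__(self) -> str:
        # Compute once on first call, then serve the cached string.
        if self._cached_repr is None:
            self._cached_repr = self._make_repr()
        return self._cached_repr

class DelayedScaling(Recipe):
    def __init__(self, amax_history_len: int = 1024) -> None:  # hypothetical field
        super().__init__()
        self.amax_history_len = amax_history_len

    def _make_repr(self) -> str:
        return f"DelayedScaling(amax_history_len={self.amax_history_len})"
```

Subclasses never touch the cache or `__dict__` directly, so the non-standard access pattern and the class-attr/instance-attr confusion both disappear.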
This is good to mention in the PR description, but not that useful in the code itself. Profiling becomes outdated once we move on to the next architecture.
```python
# Class-based context manager (instead of ``@contextmanager`` from contextlib)
# to avoid the ~0.5us / invocation overhead of contextlib's generator-driven
# ``GeneratorContextManager``. ``__slots__`` further avoids per-instance
# dict allocation.
```
Why are we mentioning the context manager here? It makes sense for this PR, but once the code is merged it will be completely random. This comment should explain what we are doing with __slots__, and we should explain the custom context manager logic in __enter__ and __exit__.
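For reference, what __slots__ actually changes (generic Python demo, unrelated to the TE classes themselves): instances get no per-instance `__dict__`, so attribute storage is a fixed-size slot array and typo'd attribute writes fail fast instead of silently creating new attributes.

```python
class WithSlots:
    __slots__ = ("x",)  # fixed slot storage, no per-instance __dict__

    def __init__(self) -> None:
        self.x = 1

class WithoutSlots:
    def __init__(self) -> None:
        self.x = 1  # stored in a freshly allocated per-instance __dict__
```

This is the behavior the comment should explain: the slot-based instance skips one dict allocation per autocast entry.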
```python
# Do not suppress exceptions.
return None
```
Nit: The function already returns None and the comment is trivially true (all Python outside of a try statement is not suppressing exceptions).
Suggested change:

```diff
-# Do not suppress exceptions.
-return None
```
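The protocol detail behind this nit, as a generic demo: an exception raised inside a `with` block is suppressed only when __exit__ returns a truthy value; returning None (which a bare return or falling off the end also produces) lets it propagate.

```python
class Propagates:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return None  # falsy -> exception propagates out of the with block

class Suppresses:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return True  # truthy -> exception is swallowed
```

So the explicit `return None` is redundant but harmless; the suppression decision is entirely in the truthiness of __exit__'s return value.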
Description
te autocast has quite a bit of CPU overhead on Grace systems. Here are the results on GB200 after the optimizations:
Without Optimizations

Optimization 1: cache the recipe string representation used to build the unique autocast key.
Optimization 2: replace the contextlib-based autocast with a class-based context manager.
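A micro-benchmark sketch of Optimization 1's idea (hypothetical classes, not TE code; absolute numbers vary by machine and none from the PR are claimed here): repeated str() on an object whose __repr__ does real formatting, versus one that serves a cached string.

```python
import timeit

class Uncached:
    def __repr__(self) -> str:
        # Formatting work happens on every call.
        return f"Recipe(fields={list(range(8))})"

class Cached:
    def __init__(self) -> None:
        self._cached_repr = None

    def __repr__(self) -> str:
        # Formatting work happens once; later calls return the stored string.
        if self._cached_repr is None:
            self._cached_repr = f"Recipe(fields={list(range(8))})"
        return self._cached_repr

u, c = Uncached(), Cached()
t_uncached = timeit.timeit(lambda: str(u), number=100_000)
t_cached = timeit.timeit(lambda: str(c), number=100_000)
print(f"uncached={t_uncached:.4f}s cached={t_cached:.4f}s")
```

On the autocast hot path this repr is rebuilt every iteration, which is why even a small per-call saving shows up in end-to-end CPU overhead.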
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: