
CPU overhead optimizations for te autocast #2957

Open

vthumbe1503 wants to merge 1 commit into NVIDIA:main from vthumbe1503:cpu_opt_te_autocast

Conversation

@vthumbe1503 (Collaborator) commented May 4, 2026

Description

te autocast has quite a bit of CPU overhead on Grace systems.
Here are the results on GB200 before and after the optimizations.

  • Without optimizations
    [profiler screenshot]

  • Optimization 1: cache the recipe's string representation used to build the unique autocast key
    [profiler screenshot]

  • Optimization 2: use __enter__/__exit__ methods instead of contextlib.contextmanager
    [profiler screenshot]
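For illustration, here is a minimal, self-contained sketch of what the two optimizations amount to. Every name, field, and state-handling detail below is a stand-in, not the actual transformer_engine code; see the diff excerpts in the review below for the real snippets.

# Minimal sketch of both optimizations; all names here are illustrative
# stand-ins, not the actual transformer_engine API.
from typing import Optional


class Recipe:
    """Stand-in for a TE recipe whose repr is costly to rebuild."""

    def __init__(self, margin: int = 0) -> None:
        self.margin = margin
        self._cached_repr: Optional[str] = None

    def __repr__(self) -> str:
        # Optimization 1: compute the string once, reuse it on the hot path.
        if self._cached_repr is None:
            self._cached_repr = f"Recipe(margin={self.margin})"
        return self._cached_repr


def get_unique_autocast_key(recipe: Optional[Recipe], group: object = None) -> str:
    # The autocast key is a cheap f-string over the cached repr.
    group_id = id(group) if group is not None else 0
    return f"{recipe!r}|{group_id}"


class autocast:
    # Optimization 2: a plain class with __enter__/__exit__ instead of a
    # @contextlib.contextmanager generator; __slots__ also avoids allocating
    # a per-instance __dict__.
    __slots__ = ("_enabled", "_recipe", "_saved_state")

    def __init__(self, enabled: bool = True, recipe: Optional[Recipe] = None) -> None:
        self._enabled = enabled
        self._recipe = recipe
        self._saved_state = None

    def __enter__(self) -> "autocast":
        self._saved_state = "previous global state"  # placeholder save
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self._saved_state = None  # placeholder restore


with autocast(enabled=True, recipe=Recipe(margin=1)):
    pass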

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Cache each recipe's string representation and reuse it when building the unique autocast key
  • Replace the @contextmanager-based autocast with a class-based context manager using __enter__/__exit__ and __slots__

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@greptile-apps bot (Contributor) commented May 4, 2026

Greptile Summary

This PR reduces CPU overhead on the FP8 autocast hot path via two complementary changes: (1) lazy _cached_repr caching on all Recipe subclasses and the frozen MMParams/QParams dataclasses, with invalidation through an overridden Recipe.__setattr__, and (2) converting the autocast context manager from a @contextmanager generator to a __slots__-bearing class to avoid ~0.5 µs per-invocation overhead from contextlib.GeneratorContextManager. The fp8_autocast compatibility shim is updated accordingly, and get_unique_autocast_key switches to a faster f-string key format. All findings are P2.

Confidence Score: 4/5

Safe to merge — no P0/P1 defects; all findings are style or edge-case hardening suggestions.

All three comments are P2: one on a fragile key separator, one on a silent-corruption edge case for same-instance nested reuse (not a realistic usage pattern), and one on undeclared dataclass field storage. The core logic of caching, invalidation, and context manager enter/exit is sound.

transformer_engine/pytorch/quantization.py — specifically get_unique_autocast_key and the autocast.__enter__/__exit__ pair.

Important Files Changed

  • transformer_engine/pytorch/quantization.py: Converts autocast from a @contextmanager generator to a class-based context manager with __slots__, and fp8_autocast to a thin wrapper returning the class instance; also optimizes get_unique_autocast_key. Two P2 concerns: an unescaped "|" key separator and silent state corruption on nested reuse of the same instance.
  • transformer_engine/common/recipe/__init__.py: Adds lazy _cached_repr caching to all Recipe subclasses and the frozen dataclasses MMParams/QParams, with invalidation via Recipe.__setattr__. One P2 concern: _cached_repr is stored outside the declared dataclass fields on frozen classes.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant AC as autocast (class)
    participant FP8GM as FP8GlobalStateManager
    participant Recipe as Recipe.__repr__

    User->>AC: autocast(enabled, recipe, ...)
    AC->>AC: __init__: store params, _fp8_state=None

    User->>AC: __enter__()
    AC->>AC: check_recipe_support (if enabled)
    AC->>FP8GM: get_autocast_state()
    FP8GM-->>AC: fp8_state saved in self._fp8_state
    AC->>FP8GM: autocast_enter(enabled, recipe, group, ...)
    FP8GM->>FP8GM: get_unique_autocast_key(recipe, group)
    FP8GM->>Recipe: recipe.__dict__.get(_cached_repr)
    alt cache hit
        Recipe-->>FP8GM: cached repr string
    else cache miss
        Recipe->>Recipe: compute repr, store in __dict__
        Recipe-->>FP8GM: repr string
    end
    FP8GM->>FP8GM: key = recipe_repr|group_id
    AC-->>User: self

    User->>AC: __exit__(exc_type, exc_val, exc_tb)
    AC->>FP8GM: set_autocast_state(self._fp8_state)
    AC->>FP8GM: autocast_exit(enabled, _graph)
    AC-->>User: None (exceptions not suppressed)

Comments Outside Diff (3)

  1. transformer_engine/pytorch/quantization.py, line 595-599 (link)

    P2 Key format change could produce ambiguous keys

    The new key format f"{recipe_repr}|{group_id}" uses | as a separator without escaping. If a future recipe's __repr__ ever emits a | character, two distinct (recipe, group) pairs could map to the same string. The old str(tuple) format was unambiguous because it quoted the recipe repr. A safer pattern uses a separator that cannot appear in repr output, or encodes the parts deterministically.

  2. transformer_engine/pytorch/quantization.py, line 911-929 (link)

    P2 Nested reuse of the same instance silently corrupts state

    The old generator-based implementation raised RuntimeError: generator already executing if you tried to enter the same context manager object twice concurrently. The new class-based implementation silently accepts nested reuse, but the second __enter__ call overwrites self._fp8_state with the state captured inside the first context, so the outer __exit__ restores the wrong state permanently.

    ctx = autocast(enabled=True, recipe=recipe)
    with ctx:           # _fp8_state = pre_context_state
        with ctx:       # _fp8_state = state_inside_first_block  ← overwrites!
            pass        # __exit__: restores state_inside_first_block
        # _fp8_state is now state_inside_first_block
    # __exit__: restores state_inside_first_block, NOT pre_context_state  ← bug

    Adding a guard in __enter__ would preserve the old safety behavior:

    def __enter__(self) -> "autocast":
        if self._fp8_state is not None:
            raise RuntimeError("autocast context manager cannot be entered more than once concurrently")
        ...
  3. transformer_engine/common/recipe/__init__.py, line 63-69 (link)

    P2 _cached_repr stored outside declared dataclass fields

    MMParams is @dataclass(frozen=True). Storing _cached_repr via object.__setattr__ bypasses the frozen guard correctly in CPython, but _cached_repr is not a declared dataclass field — it won't appear in dataclasses.fields(), dataclasses.asdict(), dataclasses.astuple(), or copy.replace(). If downstream code serializes or copies an MMParams instance, the cached repr would be lost silently. Documenting this with a comment or declaring it as field(init=False, repr=False, compare=False) would make the intent clearer. The same applies to QParams.
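
As a hedged sketch of that suggestion (the field name and repr format mirror the snippet quoted later in this thread; this is not the actual patch), declaring the cache as a proper dataclass field would look roughly like:

from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class MMParams:
    use_split_accumulator: bool = True
    # Declared field: now visible to dataclasses.fields() and asdict(),
    # but excluded from __init__, repr output, and comparisons.
    _cached_repr: Optional[str] = field(
        init=False, repr=False, compare=False, default=None
    )

    def __repr__(self) -> str:
        if self._cached_repr is None:
            # frozen=True forbids plain assignment, hence object.__setattr__.
            object.__setattr__(
                self,
                "_cached_repr",
                f"MMParams(use_split_accumulator={self.use_split_accumulator})",
            )
        return self._cached_repr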

Reviews (1): Last reviewed commit: "cpu optimizations for te autocast"

Comment on lines +595 to +599
recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
if recipe_repr is None:
    recipe_repr = str(recipe)
group_id = id(group) if group is not None else 0
return f"{recipe_repr}|{group_id}"
Contributor

P2 Key format change could produce ambiguous keys

The new key format f"{recipe_repr}|{group_id}" uses | as a separator without escaping. If a future recipe's __repr__ ever emits a | character, two distinct (recipe, group) pairs could map to the same string. The old str(tuple) format was unambiguous because it quoted the recipe repr. A safer pattern uses a separator that cannot appear in repr output, or encodes the parts deterministically.

Collaborator

Suggested change

-recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
-if recipe_repr is None:
-    recipe_repr = str(recipe)
-group_id = id(group) if group is not None else 0
-return f"{recipe_repr}|{group_id}"
+group_id = id(group) if group is not None else None
+return f"recipe=({str(recipe)}),group={group_id}"

Comment on lines +911 to +929
+    def __enter__(self) -> "autocast":
+        if self._enabled:
+            check_recipe_support(self._recipe)
+        # Save current state so we always restore it on exit.
+        self._fp8_state = FP8GlobalStateManager.get_autocast_state()
+        FP8GlobalStateManager.autocast_enter(
+            enabled=self._enabled,
+            calibrating=self._calibrating,
+            fp8_recipe=self._recipe,
+            fp8_group=self._amax_reduction_group,
+            _graph=self._graph,
+        )
+        return self
-    FP8GlobalStateManager.autocast_enter(
-        enabled=enabled,
-        calibrating=calibrating,
-        fp8_recipe=recipe,
-        fp8_group=amax_reduction_group,
-        _graph=_graph,
-    )
-    try:
-        yield
-    finally:
-        FP8GlobalStateManager.set_autocast_state(fp8_state)
-        FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        FP8GlobalStateManager.set_autocast_state(self._fp8_state)
+        FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph)
+        # Do not suppress exceptions.
+        return None
Contributor

P2 Nested reuse of the same instance silently corrupts state

The old generator-based implementation raised RuntimeError: generator already executing if you tried to enter the same context manager object twice concurrently. The new class-based implementation silently accepts nested reuse, but the second __enter__ call overwrites self._fp8_state with the state captured inside the first context, so the outer __exit__ restores the wrong state permanently.

ctx = autocast(enabled=True, recipe=recipe)
with ctx:           # _fp8_state = pre_context_state
    with ctx:       # _fp8_state = state_inside_first_block  ← overwrites!
        pass        # __exit__: restores state_inside_first_block
    # _fp8_state is now state_inside_first_block
# __exit__: restores state_inside_first_block, NOT pre_context_state  ← bug

Adding a guard in __enter__ would preserve the old safety behavior:

def __enter__(self) -> "autocast":
    if self._fp8_state is not None:
        raise RuntimeError("autocast context manager cannot be entered more than once concurrently")
    ...

Comment on lines +63 to +69
def __repr__(self) -> str:
    cached = self.__dict__.get("_cached_repr")
    if cached is not None:
        return cached
    result = f"MMParams(use_split_accumulator={self.use_split_accumulator})"
    object.__setattr__(self, "_cached_repr", result)
    return result
Contributor

P2 _cached_repr stored outside declared dataclass fields

MMParams is @dataclass(frozen=True). Storing _cached_repr via object.__setattr__ bypasses the frozen guard correctly in CPython, but _cached_repr is not a declared dataclass field — it won't appear in dataclasses.fields(), dataclasses.asdict(), dataclasses.astuple(), or copy.replace(). If downstream code serializes or copies an MMParams instance, the cached repr would be lost silently. Documenting this with a comment or declaring it as field(init=False, repr=False, compare=False) would make the intent clearer. The same applies to QParams.

Collaborator

I see that this is why we're doing the funny accesses with __dict__. I agree that bypassing frozen=True is iffy, so I wonder if we could set _cached_repr in __post_init__? If the class is frozen, its repr must also be frozen and I don't see a benefit in lazy evaluation.
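
For illustration, the __post_init__ variant could look roughly like this (assumed names mirroring the quoted snippet; not the actual patch). Computing the repr eagerly at construction removes the lazy check and the __dict__ access entirely:

from dataclasses import dataclass, field


@dataclass(frozen=True)
class MMParams:
    use_split_accumulator: bool = True
    _cached_repr: str = field(init=False, repr=False, compare=False, default="")

    def __post_init__(self) -> None:
        # The instance is frozen, so even __post_init__ must assign through
        # object.__setattr__; the repr is fixed for the object's lifetime.
        object.__setattr__(
            self,
            "_cached_repr",
            f"MMParams(use_split_accumulator={self.use_split_accumulator})",
        )

    def __repr__(self) -> str:
        return self._cached_repr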

# changes. This makes repeated ``str(recipe)`` calls (e.g. on the hot
# path in ``FP8GlobalStateManager.get_unique_autocast_key``) essentially
# free after the first call.
_cached_repr: Optional[str] = None
Collaborator

Three problems:

  • _cached_repr is being set as a class attr, not an instance attr.
  • Accessing _cached_repr via __dict__ is non-standard and bug-prone.
  • Splitting the cache logic between the base class and child classes results in code duplication and more risk of bugs, especially if it involves non-standard __dict__ accesses.

What if we concentrated the caching logic in the base class:

class Recipe:

    def __init__(self) -> None:
        self._cached_repr: Optional[str] = None

    @abc.abstractmethod
    def _make_repr(self) -> str:
        ...

    def __repr__(self) -> str:
        if self._cached_repr is None:
            self._cached_repr = self._make_repr()
        return self._cached_repr

    ...

class DelayedScaling(Recipe):

    def _make_repr(self) -> str:
        return f"..."
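
A runnable rendering of that sketch, with an assumed DelayedScaling subclass for demonstration (the field and the repr format are illustrative):

import abc
from typing import Optional


class Recipe(abc.ABC):
    def __init__(self) -> None:
        self._cached_repr: Optional[str] = None

    @abc.abstractmethod
    def _make_repr(self) -> str:
        ...

    def __repr__(self) -> str:
        # Build the string once; afterwards repr() is an attribute read.
        if self._cached_repr is None:
            self._cached_repr = self._make_repr()
        return self._cached_repr


class DelayedScaling(Recipe):
    def __init__(self, margin: int = 0) -> None:
        super().__init__()
        self.margin = margin

    def _make_repr(self) -> str:
        return f"DelayedScaling(margin={self.margin})"


r = DelayedScaling(margin=2)
assert repr(r) is repr(r)  # second call returns the same cached string object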

Comment on lines +593 to +594
# directly getting the cached repr is about 40 ns faster than str(recipe)
# on grace systems.
Collaborator

This is good to mention in the PR description, but not that useful in the code itself. Profiling becomes outdated once we move on to the next architecture.

Comment on lines +883 to +886
# Class-based context manager (instead of ``@contextmanager`` from contextlib)
# to avoid the ~0.5us / invocation overhead of contextlib's generator-driven
# ``GeneratorContextManager``. ``__slots__`` further avoids per-instance
# dict allocation.
Collaborator

Why are we mentioning the context manager here? It makes sense for this PR, but once the code is merged it will be completely random. This comment should explain what we are doing with __slots__, and we should explain the custom context manager logic in __enter__ and __exit__.

Comment on lines +928 to +929
# Do not suppress exceptions.
return None
Collaborator

Nit: the function already returns None implicitly, and the comment is trivially true (Python code outside a try statement never suppresses exceptions).

Suggested change

-# Do not suppress exceptions.
-return None
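
As background for this nit (a general Python fact, not part of the patch): an exception is suppressed only when __exit__ returns a truthy value, so falling off the end of the method, which returns None, already propagates it:

class Demo:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Implicitly returns None (falsy), so the exception propagates.
        print(f"exiting with exc_type={exc_type}")


try:
    with Demo():
        raise ValueError("boom")
except ValueError:
    print("not suppressed")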
