diff --git a/cuda_core/cuda/core/graph/__init__.py b/cuda_core/cuda/core/graph/__init__.py index 7c608584ae..3f81098628 100644 --- a/cuda_core/cuda/core/graph/__init__.py +++ b/cuda_core/cuda/core/graph/__init__.py @@ -2,32 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from cuda.core.graph._graph_builder import ( - Graph, - GraphBuilder, - GraphCompleteOptions, - GraphDebugPrintOptions, -) -from cuda.core.graph._graph_def import ( - Condition, - GraphAllocOptions, - GraphDef, -) -from cuda.core.graph._graph_node import GraphNode -from cuda.core.graph._subclasses import ( - AllocNode, - ChildGraphNode, - ConditionalNode, - EmptyNode, - EventRecordNode, - EventWaitNode, - FreeNode, - HostCallbackNode, - IfElseNode, - IfNode, - KernelNode, - MemcpyNode, - MemsetNode, - SwitchNode, - WhileNode, -) +from ._graph_builder import * +from ._graph_def import * +from ._graph_node import * +from ._subclasses import * diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyx b/cuda_core/cuda/core/graph/_graph_builder.pyx index 9639c323c7..eff81ace96 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/graph/_graph_builder.pyx @@ -21,6 +21,9 @@ from cuda.core._utils.cuda_utils import ( handle_return, ) +__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions'] + + @dataclass class GraphDebugPrintOptions: """Options for debug_dot_print(). diff --git a/cuda_core/cuda/core/graph/_graph_def.pyx b/cuda_core/cuda/core/graph/_graph_def.pyx index db9e384ff9..a83d30a56b 100644 --- a/cuda_core/cuda/core/graph/_graph_def.pyx +++ b/cuda_core/cuda/core/graph/_graph_def.pyx @@ -27,6 +27,8 @@ from dataclasses import dataclass from cuda.core._utils.cuda_utils import driver +__all__ = ['Condition', 'GraphAllocOptions', 'GraphDef'] + cdef class Condition: """A condition variable for conditional graph nodes. diff --git a/cuda_core/cuda/core/graph/_graph_node.pyx b/cuda_core/cuda/core/graph/_graph_node.pyx index ea52be75ed..e80f42f30a 100644 --- a/cuda_core/cuda/core/graph/_graph_node.pyx +++ b/cuda_core/cuda/core/graph/_graph_node.pyx @@ -61,6 +61,8 @@ import weakref from cuda.core.graph._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver +__all__ = ['GraphNode'] + # See _cpp/REGISTRY_DESIGN.md (Level 2: Resource Handle -> Python Object) _node_registry = weakref.WeakValueDictionary() diff --git a/cuda_core/cuda/core/graph/_subclasses.pyx b/cuda_core/cuda/core/graph/_subclasses.pyx index 277fea4f78..0f06e9e81e 100644 --- a/cuda_core/cuda/core/graph/_subclasses.pyx +++ b/cuda_core/cuda/core/graph/_subclasses.pyx @@ -34,6 +34,24 @@ from cuda.core.graph._utils cimport _is_py_host_trampoline from cuda.core._utils.cuda_utils import driver, handle_return +__all__ = [ + 'AllocNode', + 'ChildGraphNode', + 'ConditionalNode', + 'EmptyNode', + 'EventRecordNode', + 'EventWaitNode', + 'FreeNode', + 'HostCallbackNode', + 'IfElseNode', + 'IfNode', + 'KernelNode', + 'MemcpyNode', + 'MemsetNode', + 'SwitchNode', + 'WhileNode', +] + cdef bint _has_cuGraphNodeGetParams = False cdef bint _version_checked = False