Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 4 additions & 29 deletions cuda_core/cuda/core/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from cuda.core.graph._graph_builder import (
Graph,
GraphBuilder,
GraphCompleteOptions,
GraphDebugPrintOptions,
)
from cuda.core.graph._graph_def import (
Condition,
GraphAllocOptions,
GraphDef,
)
from cuda.core.graph._graph_node import GraphNode
from cuda.core.graph._subclasses import (
AllocNode,
ChildGraphNode,
ConditionalNode,
EmptyNode,
EventRecordNode,
EventWaitNode,
FreeNode,
HostCallbackNode,
IfElseNode,
IfNode,
KernelNode,
MemcpyNode,
MemsetNode,
SwitchNode,
WhileNode,
)
from ._graph_builder import *
from ._graph_def import *
from ._graph_node import *
from ._subclasses import *
3 changes: 3 additions & 0 deletions cuda_core/cuda/core/graph/_graph_builder.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ from cuda.core._utils.cuda_utils import (
handle_return,
)

__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions']


@dataclass
class GraphDebugPrintOptions:
"""Options for debug_dot_print().
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/cuda/core/graph/_graph_def.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ from dataclasses import dataclass

from cuda.core._utils.cuda_utils import driver

__all__ = ['Condition', 'GraphAllocOptions', 'GraphDef']


cdef class Condition:
"""A condition variable for conditional graph nodes.
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/cuda/core/graph/_graph_node.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ import weakref
from cuda.core.graph._adjacency_set_proxy import AdjacencySetProxy
from cuda.core._utils.cuda_utils import driver

__all__ = ['GraphNode']

# See _cpp/REGISTRY_DESIGN.md (Level 2: Resource Handle -> Python Object)
_node_registry = weakref.WeakValueDictionary()

Expand Down
18 changes: 18 additions & 0 deletions cuda_core/cuda/core/graph/_subclasses.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ from cuda.core.graph._utils cimport _is_py_host_trampoline

from cuda.core._utils.cuda_utils import driver, handle_return

__all__ = [
'AllocNode',
'ChildGraphNode',
'ConditionalNode',
'EmptyNode',
'EventRecordNode',
'EventWaitNode',
'FreeNode',
'HostCallbackNode',
'IfElseNode',
'IfNode',
'KernelNode',
'MemcpyNode',
'MemsetNode',
'SwitchNode',
'WhileNode',
]


cdef bint _has_cuGraphNodeGetParams = False
cdef bint _version_checked = False
Expand Down
Loading