Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 142 additions & 84 deletions src/common/schedulers/slurm_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
import logging
import os
import re
import shlex

# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the input of the module subprocess is trusted.
import subprocess # nosec B404
from datetime import datetime, timezone
from typing import Dict, List

Expand Down Expand Up @@ -60,11 +65,25 @@
SCONTROL = f"sudo {SLURM_BINARIES_DIR}/scontrol"
SINFO = f"{SLURM_BINARIES_DIR}/sinfo"

SCONTROL_OUTPUT_AWK_PARSER = (
'awk \'BEGIN{{RS="\\n\\n" ; ORS="######\\n";}} {{print}}\' | '
+ "grep -oP '^(NodeName=\\S+)|(NodeAddr=\\S+)|(NodeHostName=\\S+)|(?<!Next)(State=\\S+)|"
+ "(Partitions=\\S+)|(SlurmdStartTime=\\S+)|(LastBusyTime=\\S+)|(ReservationName=\\S+)"
+ "|(InstanceId=\\S+)|(Reason=.*)|(######)'"
SCONTROL_NODE_INFO_FIELD_REGEX = re.compile(
r"^(NodeName=\S+)"
r"|(NodeAddr=\S+)"
r"|(NodeHostName=\S+)"
r"|(?<!Next)(State=\S+)"
r"|(Partitions=\S+)"
r"|(SlurmdStartTime=\S+)"
r"|(LastBusyTime=\S+)"
r"|(ReservationName=\S+)"
r"|(InstanceId=\S+)"
r"|(Reason=.*)",
re.MULTILINE,
)

# Pattern applied to `scontrol show partitions` output to pick out the `PartitionName` and `State` tokens.
# Only `PartitionName` is anchored to the beginning of a line; the `(?<!Next)` lookbehind keeps
# `NextState=...` from being picked up as `State=...`.
SCONTROL_PARTITION_INFO_FIELD_REGEX = re.compile(
    r"^(PartitionName=\S+)"
    r"|(?<!Next)(State=\S+)",
    re.MULTILINE,
)

# Set default timeouts for running different slurm commands.
Expand Down Expand Up @@ -336,6 +355,83 @@ def set_nodes_idle(nodes, reason=None, reset_node_addrs_hostname=False):
update_nodes(nodes=nodes, state="resume", reason=reason, raise_on_error=False)


def _run_scontrol_command(
scontrol_args: str,
command_timeout: int = DEFAULT_GET_INFO_COMMAND_TIMEOUT,
raise_on_error: bool = True,
) -> str:
"""
Run a `scontrol <args>` command as a standalone subprocess and return its stdout.

scontrol's exit code, stderr and raw stdout are logged. When scontrol exits non-zero but still returns
output (e.g. one of the requested nodes does not exist: Slurm prints "Node <x> not found" and exits 1
while still returning the other nodes), the output is returned and a warning is logged. When it exits
non-zero with no output, an error is logged and, when raise_on_error, a CalledProcessError is raised.
"""
command = f"{SCONTROL} {scontrol_args}"
log.debug("Executing scontrol command: %s", command)
try:
# nosec B603: command is built from trusted/validated input and run without a shell.
result = subprocess.run( # nosec B603
shlex.split(command),
timeout=command_timeout,
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
except subprocess.TimeoutExpired:
log.error("scontrol command timed out after %s seconds: '%s'", command_timeout, command)
raise
except OSError as e:
log.error("Unable to execute scontrol command '%s'. Failed with exception: %s", command, e)
raise

stdout = result.stdout or ""
stderr = (result.stderr or "").strip()
# Raw scontrol output is logged at DEBUG only, so it is not emitted at the default INFO level.
log.debug("scontrol command '%s' exited with code %s. Raw output:\n%s", command, result.returncode, stdout)

if result.returncode != 0:
if stdout.strip():
log.warning(
"scontrol command '%s' returned non-zero exit code %s but produced output; proceeding with "
"the output. stderr: '%s'",
command,
result.returncode,
stderr,
)
else:
log.error(
"scontrol command '%s' failed with exit code %s and produced no output. stderr: '%s'",
command,
result.returncode,
stderr,
)
if raise_on_error:
raise subprocess.CalledProcessError(result.returncode, command, output=stdout, stderr=result.stderr)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Block]

def _run_command(command_function, command, env=None, raise_on_error=True, execute_as_user=None, log_error=True):
try:
if env is None:
env = {}
env.update(os.environ.copy())
if execute_as_user:
log.debug("Executing command as user '%s': %s", execute_as_user, command)
pw_record = pwd.getpwnam(execute_as_user)
user_uid = pw_record.pw_uid
user_gid = pw_record.pw_gid
preexec_fn = _demote(user_uid, user_gid)
return command_function(command, env, preexec_fn)
else:
log.debug("Executing command: %s", command)
return command_function(command, env, None)
except subprocess.CalledProcessError as e:
# CalledProcessError.__str__ already produces a significant error message
if raise_on_error:
if log_error:
log.error(e)
raise
else:
if log_error:
log.warning(e)
return e
except OSError as e:
log.error("Unable to execute the command %s. Failed with exception: %s", command, e)
raise

The old `_run_command` also catches `OSError`; can we add an `except OSError` here to align with it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Added in lines 386-388.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!


return stdout


def _extract_scontrol_records(raw_output: str, field_regex) -> List[Dict[str, str]]:
"""
Split raw scontrol output into records and extract the relevant `key=value` fields of each record.

Records are separated by blank lines. For each record, only the fields matched by field_regex are kept,
returned as a dict mapping the Slurm field name to its value. Records with no matching field are skipped.
"""
records = []
for record in re.split(r"\n\n+", raw_output.strip()):
fields = {}
for match in field_regex.finditer(record):
key, _, value = match.group(0).partition("=")
fields[key] = value
if fields:
records.append(fields)
return records


def get_nodes_info(nodes: str = "", command_timeout=DEFAULT_GET_INFO_COMMAND_TIMEOUT) -> List[SlurmNode]:
"""
Retrieve SlurmNode list from slurm nodelist notation.
Expand All @@ -351,15 +447,11 @@ def get_nodes_info(nodes: str = "", command_timeout=DEFAULT_GET_INFO_COMMAND_TIM
if nodes == "":
nodes = _get_all_partition_nodes(",".join(PartitionNodelistMapping.instance().get_partitions()))

# Validation to sanitize the input argument and make it safe to use the function affected by B604
validate_subprocess_argument(nodes)

# awk is used to replace the \n\n record separator with '######\n'
# Note: In case the node does not belong to any partition the Partitions field is missing from Slurm output
show_node_info_command = f"{SCONTROL} show nodes {nodes} | {SCONTROL_OUTPUT_AWK_PARSER}"
nodeinfo_str = check_command_output(show_node_info_command, timeout=command_timeout, shell=True) # nosec B604

return _parse_nodes_info(nodeinfo_str)
raw_node_info = _run_scontrol_command(f"show nodes {nodes}", command_timeout=command_timeout)
return _parse_nodes_info(_extract_scontrol_records(raw_node_info, SCONTROL_NODE_INFO_FIELD_REGEX))


def get_partitions_info(command_timeout=DEFAULT_GET_INFO_COMMAND_TIMEOUT) -> List[SlurmPartition]:
Expand All @@ -368,16 +460,10 @@ def get_partitions_info(command_timeout=DEFAULT_GET_INFO_COMMAND_TIMEOUT) -> Lis

This function considers only partitions managed by ParallelCluster.
"""
partitions = list(PartitionNodelistMapping.instance().get_partitions())
grep_filter = _get_partition_grep_filter(partitions)
show_partition_info_command = (
f"{SCONTROL} show partitions -o {grep_filter} " + '| grep -oP "^PartitionName=\\K(\\S+)| State=\\K(\\S+)"'
)
# It's safe to use the function affected by B604 since the command is fully built in this code
partition_info_str = check_command_output(
show_partition_info_command, timeout=command_timeout, shell=True # nosec B604
)
partitions_info = _parse_partition_name_and_state(partition_info_str)
partitions = set(PartitionNodelistMapping.instance().get_partitions())
raw_partition_info = _run_scontrol_command("show partitions", command_timeout=command_timeout)
partition_records = _extract_scontrol_records(raw_partition_info, SCONTROL_PARTITION_INFO_FIELD_REGEX)
partitions_info = _parse_partitions_info(partition_records, managed_partitions=partitions or None)
return [
SlurmPartition(
partition_name,
Expand All @@ -388,15 +474,6 @@ def get_partitions_info(command_timeout=DEFAULT_GET_INFO_COMMAND_TIMEOUT) -> Lis
]


def _get_partition_grep_filter(partitions: List[str]) -> str:
grep_filter = ""
if partitions:
grep_filter += " | grep"
for partition in partitions:
grep_filter += f' -e "PartitionName={partition}"'
return grep_filter


def resume_powering_down_nodes():
"""Resume nodes that are powering_down so that are set in power state right away."""
# TODO: This function was added due to Slurm ticket 12915. The bug is not reproducible and the ticket was then
Expand All @@ -406,9 +483,23 @@ def resume_powering_down_nodes():
update_nodes(nodes=powering_down_nodes, state="resume", raise_on_error=False)


def _parse_partition_name_and_state(partition_info):
    """
    Group the lines of the filtered scontrol partition output two at a time.

    The text is split into lines and handed to grouper with a group size of 2; the lines are presumably
    expected to alternate between a partition name and its state.
    """
    output_lines = partition_info.splitlines()
    return grouper(output_lines, 2)
def _parse_partitions_info(partition_records, managed_partitions=None):
"""
Extract (partition_name, partition_state) tuples from scontrol partition records.

Only partitions in managed_partitions are returned; if managed_partitions is None, all partitions are
returned. Records missing the name or state are skipped.
"""
partitions_info = []
for fields in partition_records:
partition_name = fields.get("PartitionName")
partition_state = fields.get("State")
if not partition_name or not partition_state:
continue
if managed_partitions is not None and partition_name not in managed_partitions:
continue
partitions_info.append((partition_name, partition_state))
return partitions_info


def _get_all_partition_nodes(partition_name, command_timeout=DEFAULT_GET_INFO_COMMAND_TIMEOUT):
Expand All @@ -433,43 +524,14 @@ def _get_slurm_nodes(states=None, partition_name=None, command_timeout=DEFAULT_G
return check_command_output(sinfo_command, timeout=command_timeout, shell=True).splitlines() # nosec B604


def _parse_nodes_info(slurm_node_info: str) -> List[SlurmNode]:
"""Parse slurm node info into SlurmNode objects."""
# [ec2-user@ip-10-0-0-58 ~]$ /opt/slurm/bin/scontrol show nodes compute-dy-c5xlarge-[1-3],compute-dy-c5xlarge-50001\
# | awk 'BEGIN{{RS="\n\n" ; ORS="######\n";}} {{print}}' | grep -oP "^(NodeName=\S+)|(NodeAddr=\S+)
# |(NodeHostName=\S+)|(?<!Next)(State=\S+)|(Partitions=\S+)|(SlurmdStartTime=\S+)|(LastBusyTime=\\S+)
# |(ReservationName=\S+)|(Reason=.*)|(######)"
# NodeName=compute-dy-c5xlarge-1
# NodeAddr=1.2.3.4
# NodeHostName=compute-dy-c5xlarge-1
# State=IDLE+CLOUD+POWER
# Partitions=compute,compute2
# SlurmdStartTime=2023-01-26T09:57:15
# Reason=some reason
# ReservationName=root_1
# ######
# NodeName=compute-dy-c5xlarge-2
# NodeAddr=1.2.3.4
# NodeHostName=compute-dy-c5xlarge-2
# State=IDLE+CLOUD+POWER
# Partitions=compute,compute2
# SlurmdStartTime=2023-01-26T09:57:15
# Reason=(Code:InsufficientInstanceCapacity)Failure when resuming nodes
# ######
# NodeName=compute-dy-c5xlarge-3
# NodeAddr=1.2.3.4
# NodeHostName=compute-dy-c5xlarge-3
# State=IDLE+CLOUD+POWER
# Partitions=compute,compute2
# SlurmdStartTime=2023-01-26T09:57:15
# ######
# NodeName=compute-dy-c5xlarge-50001
# NodeAddr=1.2.3.4
# NodeHostName=compute-dy-c5xlarge-50001
# State=IDLE+CLOUD+POWER
# SlurmdStartTime=None
# ######
def _parse_nodes_info(node_records: List[Dict[str, str]]) -> List[SlurmNode]:
"""
Build SlurmNode objects from scontrol node records extracted by _extract_scontrol_records.

Each record is a dict mapping Slurm field names to their values, e.g.:
{"NodeName": "compute-dy-c5xlarge-1", "NodeAddr": "1.2.3.4", "NodeHostName": "compute-dy-c5xlarge-1",
"State": "IDLE+CLOUD+POWER", "Partitions": "compute,compute2", "SlurmdStartTime": "2023-01-26T09:57:15"}
"""
map_slurm_key_to_arg = {
"NodeName": "name",
"NodeAddr": "nodeaddr",
Expand All @@ -485,13 +547,12 @@ def _parse_nodes_info(slurm_node_info: str) -> List[SlurmNode]:

date_fields = ["SlurmdStartTime", "LastBusyTime"]

node_info = slurm_node_info.split("######\n")
slurm_nodes = []
for node in node_info:
lines = node.splitlines()
for fields in node_records:
if "NodeName" not in fields:
continue
kwargs = {}
for line in lines:
key, value = line.split("=", 1)
for key, value in fields.items():
if key in date_fields:
if value not in ["None", "Unknown"]:
value = datetime.strptime(value, "%Y-%m-%dT%H:%M:%S").astimezone(tz=timezone.utc)
Expand All @@ -501,15 +562,12 @@ def _parse_nodes_info(slurm_node_info: str) -> List[SlurmNode]:
# Slurm reports an unset InstanceId as "(null)"
value = None
kwargs[map_slurm_key_to_arg[key]] = value
if lines:
try:
if is_static_node(kwargs["name"]):
node = StaticNode(**kwargs)
slurm_nodes.append(node)
else:
node = DynamicNode(**kwargs)
slurm_nodes.append(node)
except InvalidNodenameError:
log.warning("Ignoring node %s because it has an invalid name", kwargs["name"])
try:
if is_static_node(kwargs["name"]):
slurm_nodes.append(StaticNode(**kwargs))
else:
slurm_nodes.append(DynamicNode(**kwargs))
except InvalidNodenameError:
log.warning("Ignoring node %s because it has an invalid name", kwargs["name"])

return slurm_nodes
67 changes: 30 additions & 37 deletions src/common/schedulers/slurm_reservation_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,20 @@
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re

# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the input of the module subprocess is trusted.
import subprocess # nosec B404
from datetime import datetime
from typing import List, Union

from common.schedulers.slurm_commands import DEFAULT_SCONTROL_COMMAND_TIMEOUT, SCONTROL
from common.schedulers.slurm_commands import (
DEFAULT_SCONTROL_COMMAND_TIMEOUT,
SCONTROL,
_extract_scontrol_records,
_run_scontrol_command,
)
from common.utils import (
SlurmCommandError,
SlurmCommandErrorHandler,
Expand All @@ -30,9 +36,11 @@
logger = logging.getLogger(__name__)


SCONTROL_SHOW_RESERVATION_OUTPUT_AWK_PARSER = (
'awk \'BEGIN{{RS="\\n\\n" ; ORS="######\\n";}} {{print}}\' | '
+ "grep -oP '^(ReservationName=\\S+)|(?<!Next)(State=\\S+)|(Users=\\S+)|(Nodes=\\S+)|(######)'"
# Fields extracted from raw `scontrol show reservations` output. Only `ReservationName` is anchored at the
# start of a line; `(?<!Next)` ensures `State` is matched but `NextState` is not.
SCONTROL_RESERVATION_INFO_FIELD_REGEX = re.compile(
r"^(ReservationName=\S+)" r"|(?<!Next)(State=\S+)" r"|(Users=\S+)" r"|(Nodes=\S+)",
re.MULTILINE,
)


Expand Down Expand Up @@ -256,41 +264,26 @@ def get_slurm_reservations_info(

Official documentation is https://slurm.schedmd.com/reservations.html
"""
# awk is used to replace the \n\n record separator with '######\n'
show_reservations_command = f"{SCONTROL} show reservations | {SCONTROL_SHOW_RESERVATION_OUTPUT_AWK_PARSER}"
slurm_reservations_info = check_command_output(
show_reservations_command, raise_on_error=raise_on_error, timeout=command_timeout, shell=True
) # nosec B604

return _parse_reservations_info(slurm_reservations_info)


def _parse_reservations_info(slurm_reservations_info: str) -> List[SlurmReservation]:
"""Parse slurm reservations info into SlurmReservation objects."""
# $ /opt/slurm/bin/scontrol show reservations awk 'BEGIN{{RS="\n\n" ; ORS="######\n";}} {{print}}' |
# grep -oP '^(ReservationName=\S+)|(?<!Next)(State=\S+)|(Users=\S+)|(Nodes=\S+)|(######)'
# ReservationName=root_8
# Nodes=queuep4d-dy-crp4d-[1-5]
# Users=root
# State=ACTIVE
# ######
# ReservationName=root_9
# Nodes=queue1-st-crt2micro-1
# Users=root
# State=ACTIVE
# ######
raw_reservations_info = _run_scontrol_command(
"show reservations", command_timeout=command_timeout, raise_on_error=raise_on_error
)
reservation_records = _extract_scontrol_records(raw_reservations_info, SCONTROL_RESERVATION_INFO_FIELD_REGEX)

return _parse_reservations_info(reservation_records)


def _parse_reservations_info(reservation_records: List[dict]) -> List[SlurmReservation]:
"""
Build SlurmReservation objects from scontrol reservation records extracted by _extract_scontrol_records.

Each record is a dict mapping Slurm field names to their values, e.g.:
{"ReservationName": "root_8", "Nodes": "queuep4d-dy-crp4d-[1-5]", "Users": "root", "State": "ACTIVE"}
"""
map_slurm_key_to_arg = {"ReservationName": "name", "Nodes": "nodes", "Users": "users", "State": "state"}

reservation_info = slurm_reservations_info.split("######\n")
slurm_reservations = []
for reservation in reservation_info:
lines = reservation.splitlines()
kwargs = {}
for line in lines:
key, value = line.split("=")
kwargs[map_slurm_key_to_arg[key]] = value
if lines:
reservation = SlurmReservation(**kwargs)
slurm_reservations.append(reservation)
for fields in reservation_records:
kwargs = {map_slurm_key_to_arg[key]: value for key, value in fields.items()}
slurm_reservations.append(SlurmReservation(**kwargs))

return slurm_reservations
Loading
Loading