# Source code for rxn_network.utils.ray

"""Functions for working with Ray (parallelization library)."""

from __future__ import annotations

import os

import ray

from rxn_network.utils.funcs import get_logger

logger = get_logger(__name__)


def _existing_cluster_detected() -> bool:
    """Return True if the environment indicates a multi-node Ray cluster to join.

    Two scheduler conventions are recognized:

    - SLURM: an ``ip_head`` (or ``IP_HEAD``) environment variable is set, as
      exported by the NERSC slurm-ray-cluster submission scripts.
    - PBS: the ``PBS_NNODES`` environment variable is an integer greater than 1.

    A non-integer ``PBS_NNODES`` value is logged as a warning and treated as a
    single-node job instead of raising ``ValueError``.
    """
    # Environment variable names are case-sensitive on POSIX; the NERSC scripts
    # export lowercase ``ip_head``, while earlier versions of this function only
    # looked for ``IP_HEAD``. Accept both so neither setup breaks.
    if os.environ.get("ip_head") or os.environ.get("IP_HEAD"):
        return True

    raw_nnodes = os.environ.get("PBS_NNODES", "0")
    try:
        return int(raw_nnodes) > 1
    except ValueError:
        logger.warning(f"Ignoring non-integer PBS_NNODES value: {raw_nnodes!r}")
        return False


def initialize_ray(quiet: bool = False) -> None:
    """Initialize Ray, connecting to an existing multi-node cluster if one is found.

    Basic support for running Ray on multiple nodes is provided for the SLURM and
    PBS job schedulers:

    SLURM:
        Checks the environment for an "ip_head" (or "IP_HEAD") variable, which
        indicates the user is running on multiple nodes. See
        https://github.com/NERSC/slurm-ray-cluster/
    PBS:
        Checks the environment for PBS_NNODES > 1.

    If no existing cluster is detected, a new Ray instance is created. If Ray is
    already initialized, nothing is done apart from logging that fact.

    Args:
        quiet: If False (default), set this module's logger level to INFO so
            progress messages are emitted. If True, the logger level is left
            unchanged.
    """
    if not quiet:
        logger.setLevel("INFO")

    if ray.is_initialized():
        logger.info("Ray is already initialized.")
        return

    logger.info("Ray is not initialized. Checking for existing cluster...")
    if _existing_cluster_detected():
        # Connect to the already-running cluster instead of starting a new one.
        ray.init(address="auto")
    else:
        logger.info("Could not identify existing Ray instance. Creating a new one...")
        ray.init()

    resources = ray.cluster_resources()
    logger.info(
        f"HOST: {ray.nodes()[0]['NodeManagerHostname']}, "
        f"Num CPUs: {resources['CPU']}, "
        f"Total Memory: {resources['memory']}"
    )
def to_iterator(obj_ids, get_obj_ids: bool = False):
    """Yield the results of Ray object refs in the order they become ready.

    Intended for use in a ``for`` loop so results can be processed as soon as
    each task finishes, rather than waiting for all of them.

    Args:
        obj_ids: Iterable of Ray object refs. It is copied into a list up front
            because ``ray.wait`` only accepts lists, so tuples and generators
            also work. The caller's container is not modified.
        get_obj_ids: If True, yield ``(object_ref, result)`` tuples instead of
            bare results.

    Yields:
        Each completed result, or an ``(object_ref, result)`` pair when
        ``get_obj_ids`` is True.
    """
    pending = list(obj_ids)
    while pending:
        # ray.wait defaults to num_returns=1, so ``done`` holds a single ref and
        # ``pending`` holds the refs that are not yet ready.
        done, pending = ray.wait(pending)
        ref = done[0]
        if get_obj_ids:
            yield ref, ray.get(ref)
        else:
            yield ray.get(ref)