Source code for metatrain.utils.devices

import warnings
from typing import List, Optional

import torch


def _mps_is_available() -> bool:
    # require `torch.backends.mps.is_available()` for a reasonable check in torch<2.0
    return torch.backends.mps.is_built() and torch.backends.mps.is_available()


[docs] def pick_devices( architecture_devices: List[str], desired_device: Optional[str] = None, ) -> List[torch.device]: """Pick (best) devices for training. The choice is made on the intersection of the ``architecture_devices`` and the available devices on the current system. If no ``desired_device`` is provided the first device of this intersection will be returned. :param architecture_devices: Devices supported by the architecture. The list should be sorted by the preference of the architecture while the most preferred device should be first and the least one last. :param desired_device: desired device by the user. For example, ``"cpu"``, "``cuda``", ``"multi-gpu"``, etc. """ available_devices = ["cpu"] if torch.cuda.is_available(): available_devices.append("cuda") if torch.cuda.device_count() > 1: available_devices.append("multi-cuda") if _mps_is_available(): available_devices.append("mps") # intersect between available and architecture's devices. keep order of architecture possible_devices = [d for d in architecture_devices if d in available_devices] if not possible_devices: raise ValueError( f"No matching device found! The architecture requires " f"{', '.join(architecture_devices)}; but your system only has " f"{', '.join(available_devices)}." ) # If desired device given compare the possible devices and try to find a match if desired_device is None: desired_device = possible_devices[0] else: desired_device = desired_device.lower() # we copy whatever the input device string is, to avoid that some strings # that do not get resolved but passed directly do not get converted user_requested_device = desired_device # convert "gpu" and "multi-gpu" to "cuda" or "mps" if available if desired_device == "gpu": if torch.cuda.is_available(): desired_device = "cuda" elif _mps_is_available(): desired_device = "mps" else: raise ValueError( "Requested 'gpu' device, but found no GPU (CUDA or MPS) devices." ) elif desired_device == "cuda" and not torch.cuda.is_available(): raise ValueError("Requested 'cuda' device, but cuda is not available.") elif desired_device == "mps" and not _mps_is_available(): raise ValueError("Requested 'mps' device, but mps is not available.") if desired_device == "multi-gpu": desired_device = "multi-cuda" if desired_device not in architecture_devices: raise ValueError( f"Desired device {user_requested_device!r} name resolved to " f"{desired_device!r} is not supported by the selected " f"architecture. Please choose from {', '.join(possible_devices)}." ) if desired_device not in available_devices: raise ValueError( f"Desired device {user_requested_device!r} name resolved to " f"{desired_device!r} is not supported by the selected " f"your current system. Please choose from {', '.join(possible_devices)}." ) if possible_devices.index(desired_device) > 0: warnings.warn( f"Device {user_requested_device!r} — name resolved to " f"{desired_device!r} — requested, but {possible_devices[0]!r} is " "preferred by the architecture and available on current system.", stacklevel=2, ) if ( desired_device == "cuda" and torch.cuda.device_count() > 1 and any(d in possible_devices for d in ["multi-cuda", "multi_gpu"]) ): warnings.warn( f"Requested single 'cuda' device by specifying {user_requested_device!r} " "but current system has " f"{torch.cuda.device_count()} cuda devices and architecture supports " "multi-gpu training. Consider using 'multi-gpu' to accelerate " "training.", stacklevel=2, ) # convert the requested device to a list of torch devices if desired_device == "multi-cuda": return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())] else: return [torch.device(desired_device)]