importwarningsfromtypingimportList,Optionalimporttorchdef_mps_is_available()->bool:# require `torch.backends.mps.is_available()` for a reasonable check in torch<2.0returntorch.backends.mps.is_built()andtorch.backends.mps.is_available()
def pick_devices(
    architecture_devices: List[str],
    desired_device: Optional[str] = None,
) -> List[torch.device]:
    """Pick (best) devices for training.

    The choice is made on the intersection of the ``architecture_devices`` and
    the available devices on the current system. If no ``desired_device`` is
    provided the first device of this intersection will be returned.

    :param architecture_devices: Devices supported by the architecture. The
        list should be sorted by the preference of the architecture while the
        most preferred device should be first and the least one last.
    :param desired_device: desired device by the user. For example, ``"cpu"``,
        ``"cuda"``, ``"multi-gpu"``, etc.
    :return: list of :class:`torch.device`; one entry per GPU for
        ``"multi-cuda"``, otherwise a single-entry list.
    :raises ValueError: if no architecture device is available on this system,
        or if the requested device cannot be resolved or used here.
    """
    # Enumerate what this system can actually run on.
    available_devices = ["cpu"]
    if torch.cuda.is_available():
        available_devices.append("cuda")
        if torch.cuda.device_count() > 1:
            available_devices.append("multi-cuda")
    if _mps_is_available():
        available_devices.append("mps")

    # intersect between available and architecture's devices. keep order of architecture
    possible_devices = [d for d in architecture_devices if d in available_devices]

    if not possible_devices:
        raise ValueError(
            f"No matching device found! The architecture requires "
            f"{', '.join(architecture_devices)}; but your system only has "
            f"{', '.join(available_devices)}."
        )

    # If desired device given compare the possible devices and try to find a match
    if desired_device is None:
        desired_device = possible_devices[0]
    else:
        desired_device = desired_device.lower()

    # we copy whatever the input device string is, to avoid that some strings
    # that do not get resolved but passed directly do not get converted
    user_requested_device = desired_device

    # convert "gpu" and "multi-gpu" to "cuda" or "mps" if available
    if desired_device == "gpu":
        if torch.cuda.is_available():
            desired_device = "cuda"
        elif _mps_is_available():
            desired_device = "mps"
        else:
            raise ValueError(
                "Requested 'gpu' device, but found no GPU (CUDA or MPS) devices."
            )
    elif desired_device == "cuda" and not torch.cuda.is_available():
        raise ValueError("Requested 'cuda' device, but cuda is not available.")
    elif desired_device == "mps" and not _mps_is_available():
        raise ValueError("Requested 'mps' device, but mps is not available.")

    if desired_device == "multi-gpu":
        desired_device = "multi-cuda"

    if desired_device not in architecture_devices:
        raise ValueError(
            f"Desired device {user_requested_device!r} name resolved to "
            f"{desired_device!r} is not supported by the selected "
            f"architecture. Please choose from {', '.join(possible_devices)}."
        )

    if desired_device not in available_devices:
        raise ValueError(
            f"Desired device {user_requested_device!r} name resolved to "
            f"{desired_device!r} is not available on your current system. "
            f"Please choose from {', '.join(possible_devices)}."
        )

    # Warn if the user picked something the architecture likes less than its
    # first choice (index 0 is the architecture's preferred device).
    if possible_devices.index(desired_device) > 0:
        warnings.warn(
            f"Device {user_requested_device!r} — name resolved to "
            f"{desired_device!r} — requested, but {possible_devices[0]!r} is "
            "preferred by the architecture and available on current system.",
            stacklevel=2,
        )

    # Hint at multi-GPU training when a single 'cuda' was requested although
    # several GPUs exist and the architecture supports "multi-cuda".
    # (The original also tested the impossible literal "multi_gpu", which is
    # never added to available_devices and hence never in possible_devices.)
    if (
        desired_device == "cuda"
        and torch.cuda.device_count() > 1
        and "multi-cuda" in possible_devices
    ):
        warnings.warn(
            f"Requested single 'cuda' device by specifying {user_requested_device!r} "
            "but current system has "
            f"{torch.cuda.device_count()} cuda devices and architecture supports "
            "multi-gpu training. Consider using 'multi-gpu' to accelerate "
            "training.",
            stacklevel=2,
        )

    # convert the requested device to a list of torch devices
    if desired_device == "multi-cuda":
        return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
    else:
        return [torch.device(desired_device)]