def _architecture_module_exists(module_name: str) -> bool:
    """Return whether ``module_name`` can be located by the import system.

    :func:`importlib.util.find_spec` imports the parent packages of a dotted name
    and raises :class:`ModuleNotFoundError` when one of them is missing; that case
    is treated as "the module does not exist".

    :param module_name: fully qualified module name to look up
    :returns: ``True`` if a module spec was found, ``False`` otherwise
    """
    try:
        return find_spec(module_name) is not None
    except ModuleNotFoundError:
        return False


def check_architecture_name(name: str) -> None:
    """Check if the requested architecture is available.

    If the architecture is not found a :class:`ValueError` is raised. If an
    experimental or deprecated architecture with the same name exists, that
    architecture is suggested. Otherwise the closest architecture name is given to
    help debugging typos.

    :param name: name of the architecture
    :raises ValueError: if the architecture is not found
    """
    if _architecture_module_exists(f"metatrain.{name}"):
        return

    if _architecture_module_exists(f"metatrain.experimental.{name}"):
        msg = (
            f"Architecture {name!r} is not a stable architecture. An "
            "experimental architecture with the same name was found. Set "
            f"`name: experimental.{name}` in your options file to use this "
            "experimental architecture."
        )
    elif _architecture_module_exists(f"metatrain.deprecated.{name}"):
        msg = (
            f"Architecture {name!r} is not a stable architecture. A "
            "deprecated architecture with the same name was found. Set "
            f"`name: deprecated.{name}` in your options file to use this "
            "deprecated architecture."
        )
    else:
        msg = f"Architecture {name!r} is not a valid architecture."
        closest_match = difflib.get_close_matches(
            word=name, possibilities=find_all_architectures()
        )
        if closest_match:
            msg += f" Do you mean '{closest_match[0]}'?"

    # Raised outside of any ``except`` block so the traceback is not polluted with
    # an unrelated ``ModuleNotFoundError`` context.
    raise ValueError(msg)
def check_architecture_options(
    name: str,
    options: Dict,
) -> None:
    """Verify that an options instance only contains valid keys.

    The options are validated against the ``schema-hypers.json`` file in the
    architecture directory. If the architecture developer does not provide such a
    validation schema, the ``options`` will not be checked.

    :param name: name of the architecture
    :param options: architecture options to check
    :raises: whatever ``validate`` raises for options not matching the schema
        (presumably ``jsonschema.ValidationError`` — confirm against the import)
    """
    schema_path = get_architecture_path(name) / "schema-hypers.json"

    if not schema_path.exists():
        # Lazy %-formatting: logging interpolates the architecture name only if
        # the debug record is actually emitted.
        logging.debug(
            "No schema found for %r architecture. Skipping validation.", name
        )
        return

    # JSON text is UTF-8; be explicit so the platform default encoding is not used.
    with open(schema_path, "r", encoding="utf-8") as f:
        schema = json.load(f)

    validate(instance=options, schema=schema)
def get_architecture_name(path: Union[str, Path]) -> str:
    """Name of an architecture based on a path pointing inside an architecture.

    The function should be used to determine the ``ARCHITECTURE_NAME`` based on the
    name of the folder.

    :param path: path to an architecture directory, or to a file inside one
    :returns: architecture name, i.e. the directory path relative to the package
        root with components joined by ``"."``
    :raises ValueError: if ``path`` does not exist or does not point to a valid
        architecture directory.

    .. seealso::
        :py:func:`get_architecture_path` to get the path of an architecture
        directory from its name.
    """
    path = Path(path)

    if path.is_dir():
        directory = path
    elif path.is_file():
        directory = path.parent
    else:
        raise ValueError(f"`path` {str(path)!r} does not exist")

    invalid_msg = f"`path` {str(path)!r} does not point to a valid architecture folder"

    try:
        architecture_path = directory.relative_to(PACKAGE_ROOT)
    except ValueError as err:
        # ``directory`` lies outside the package root.
        raise ValueError(invalid_msg) from err

    # Join the path components instead of replacing "/" so the name is also
    # correct when the platform uses a different path separator (e.g. Windows).
    name = ".".join(architecture_path.parts)
    if not name:
        # ``directory`` is the package root itself.
        raise ValueError(invalid_msg)

    try:
        check_architecture_name(name)
    except ValueError as err:
        raise ValueError(invalid_msg) from err

    return name
def import_architecture(name: str):
    """Import the Python module implementing an architecture.

    :param name: name of the architecture
    :returns: the imported ``metatrain.<name>`` module
    :raises ValueError: if ``name`` is not a valid architecture
    :raises ImportError: if the architecture dependencies are not met
    """
    check_architecture_name(name)

    module_name = f"metatrain.{name}"
    try:
        module = importlib.import_module(module_name)
    except ImportError as err:
        # The extras name is meant to be consistent with pyproject.toml's
        # `optional-dependencies` section: for experimental/deprecated names the
        # first dotted component is dropped, and underscores become dashes.
        has_prefix = "experimental." in name or "deprecated." in name
        extra = ".".join(name.split(".")[1:]) if has_prefix else name
        extra = extra.replace("_", "-")
        raise ImportError(
            f"Trying to import '{name}' but architecture dependencies "
            f"seem not be installed. \n"
            f"Try to install them with `pip install metatrain[{extra}]`"
        ) from err

    return module
def get_architecture_path(name: str) -> Path:
    """Return the path of an architecture directory inside the metatrain package.

    The dotted architecture ``name`` is turned into nested directories and joined
    onto ``PACKAGE_ROOT``.

    :param name: name of the architecture
    :returns: path to the architecture directory
    :raises ValueError: if ``name`` is not a valid architecture

    .. seealso::
        :py:func:`get_architecture_name` to get the name based on a path inside
        an architecture directory.
    """
    check_architecture_name(name)
    # "experimental.foo" -> Path("experimental", "foo")
    subdirectory = Path(*name.split("."))
    return PACKAGE_ROOT / subdirectory
def find_all_architectures() -> List[str]:
    """Find all currently available architectures.

    Every architecture directory must contain a ``default-hypers.yaml`` file, so
    the package tree is searched recursively for that file and each hit is mapped
    to its architecture name.

    :returns: list of architecture names
    """
    return [
        get_architecture_name(hypers_file)
        for hypers_file in PACKAGE_ROOT.rglob("default-hypers.yaml")
    ]
def get_default_hypers(name: str) -> Dict:
    """Dictionary of the default architecture hyperparameters.

    :param name: name of the architecture
    :returns: default hyperparameters of the architecture
    :raises ValueError: if ``name`` is not a valid architecture
    """
    # ``get_architecture_path`` already validates ``name`` through
    # ``check_architecture_name``, so no separate check is needed here.
    hypers_path = get_architecture_path(name) / "default-hypers.yaml"
    default_hypers = OmegaConf.load(hypers_path)

    # We present the `default-hypers.yaml` file inside the documentation. For a better
    # user experience we store these yaml files with an additional level of indentation
    # (`"architecture"`), which we have to remove here to get the raw default hypers.
    return OmegaConf.to_container(default_hypers)["architecture"]