# Source code for protomotions.utils.fabric_config

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Configuration classes for Lightning Fabric distributed training."""

from typing import Dict, Any, Union, Optional, List
from omegaconf import DictConfig
from dataclasses import dataclass, field
from lightning import fabric

from protomotions.utils.hydra_replacement import instantiate


@dataclass
class FabricConfig:
    """Settings bundle for Lightning Fabric distributed training.

    Each field carries a ``help`` entry in its dataclass metadata describing
    its purpose. ``strategy`` and the individual entries of ``loggers`` and
    ``callbacks`` may be supplied either as ready-made objects or as
    ``dict`` / ``DictConfig`` specifications; specifications are replaced by
    the result of ``instantiate(...)`` in ``__post_init__``.
    """

    accelerator: str = field(
        default="gpu",
        metadata={"help": "Hardware accelerator: 'gpu', 'cpu', 'tpu', 'auto'."},
    )
    devices: Union[int, str] = field(
        default=1,
        metadata={"help": "Number of devices or 'auto' for all available."},
    )
    num_nodes: Union[int, str] = field(
        default=1,
        metadata={"help": "Number of nodes for distributed training.", "min": 1},
    )
    # A fresh DDPStrategy is built per instance when no strategy is given.
    strategy: Union[Dict, fabric.strategies.Strategy] = field(
        default_factory=fabric.strategies.DDPStrategy,
        metadata={"help": "Distributed training strategy (DDP, FSDP, etc)."},
    )
    precision: Union[str, int] = field(
        default="32-true",
        metadata={
            "help": "Training precision: '32-true', '16-mixed', 'bf16-mixed'."
        },
    )
    loggers: Optional[List[Union[Dict, fabric.loggers.Logger]]] = field(
        default=None,
        metadata={"help": "List of logging backends (WandB, TensorBoard, etc)."},
    )
    callbacks: Optional[List[Union[Dict, Any]]] = field(
        default=None,
        metadata={"help": "List of training callbacks."},
    )

    def __post_init__(self):
        """Replace dict-style specifications with instantiated objects.

        ``strategy`` is passed through ``instantiate`` when it is a ``dict``
        or ``DictConfig``. ``loggers`` and ``callbacks``, when not ``None``,
        are rebuilt as new lists in which every ``dict`` / ``DictConfig``
        entry has been passed through ``instantiate`` and every other entry
        is kept unchanged.
        """
        spec_types = (dict, DictConfig)

        def resolve(entry):
            # Only config-like entries are instantiated; anything else is
            # assumed to be an already-constructed object and passes through.
            return instantiate(entry) if isinstance(entry, spec_types) else entry

        # None is never a dict/DictConfig, so no separate None check is needed.
        if isinstance(self.strategy, spec_types):
            self.strategy = instantiate(self.strategy)

        if self.loggers is not None:
            self.loggers = [resolve(entry) for entry in self.loggers]

        if self.callbacks is not None:
            self.callbacks = [resolve(entry) for entry in self.callbacks]