In the previous articles we studied the basic modules of PyTorch distributed and walked through several official examples. We now turn to PyTorch elastic training. This is the third article of the series; it looks at the basic functionality of the elastic agent.
The articles in the elastic-training series so far:

[Source code analysis] PyTorch distributed elastic training (1) --- Overall idea
[Source code analysis] PyTorch distributed elastic training (2) --- Launch & single-node flow
Let us start with an overview. The two most important concepts in TE are Agent and Rendezvous.

Rendezvous implements mutual discovery among workers and, when membership changes, synchronizes the change across the cluster. TE consists of multiple elastic agents built on top of Rendezvous; this is a separation of responsibilities: the agent cares about the workers on its own node, while rendezvous handles cluster-wide coordination. When a rendezvous completes, a shared key-value store is created that implements the torch.distributed.Store API. This store is shared only by the members that have completed the rendezvous, and it is intended for Torch Distributed Elastic to exchange control and data information while initializing the job. Let us first look at a schematic taken from the source code to get an overall picture.
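As a quick illustration of the Store interface mentioned above, here is a sketch using torch.distributed.TCPStore (the store actually returned by a rendezvous is created by the rendezvous backend; this only demonstrates the same set/get contract):

from datetime import timedelta
import torch.distributed as dist

# one member hosts the store, the other members connect to it
server_store = dist.TCPStore("127.0.0.1", 29500, 2, True, timedelta(seconds=30))
client_store = dist.TCPStore("127.0.0.1", 29500, 2, False)

server_store.set("control/plan", "warmup")  # one member publishes control data
print(client_store.get("control/plan"))     # any member can read it -> b'warmup'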
This article only touches on rendezvous briefly; the focus is the agent.

In the context of Torch Distributed Elastic, we use the term rendezvous for one specific functionality: a distributed synchronization primitive combined with peer discovery.

Torch Distributed Elastic uses Rendezvous to gather the participants (nodes) of a training job, so that the participants can agree on the participant list and each participant's role, and can also make a consistent collective decision about when training starts or resumes.
Rendezvous splits the functionality into decoupled pieces: the business logic is abstracted into a set of operators, such as _RendezvousJoinOp, while Rendezvous itself maintains an internal state machine. An executor such as _RendezvousOpExecutor runs the operators and, from each operator's result, derives the next Action to execute, then applies it to itself. For example, in _DistributedRendezvousOpExecutor, if the current action turns out to be ADD_TO_WAIT_LIST, it executes _add_to_wait_list, which in turn calls self._state.wait_list.add(self._node):
if action == _Action.KEEP_ALIVE:
    self._keep_alive()
elif action == _Action.ADD_TO_PARTICIPANTS:
    self._add_to_participants()
elif action == _Action.ADD_TO_WAIT_LIST:  # the current Action is identified here
    self._add_to_wait_list()              # and then executed
elif action == _Action.REMOVE_FROM_PARTICIPANTS:
    self._remove_from_participants()
elif action == _Action.REMOVE_FROM_WAIT_LIST:
    self._remove_from_wait_list()
elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
    self._mark_rendezvous_complete()
elif action == _Action.MARK_RENDEZVOUS_CLOSED:
    self._mark_rendezvous_closed()
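To make the operator/executor split concrete, here is a minimal, self-contained sketch of the pattern (hypothetical names, not the actual PyTorch implementation): an operator only decides the next action from the current state, and the executor loops, applying each action until the operator signals completion.

from enum import Enum

class Action(Enum):          # hypothetical stand-in for _Action
    MARK_COMPLETE = 1
    FINISH = 2

def join_op(state):
    # operator: pure decision logic, no side effects
    return Action.FINISH if state["complete"] else Action.MARK_COMPLETE

def run(op, state):
    # executor: applies the chosen action to the state, then asks again
    while True:
        action = op(state)
        if action == Action.FINISH:
            break
        if action == Action.MARK_COMPLETE:
            state["complete"] = True

state = {"complete": False}
run(join_op, state)
assert state["complete"]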
The elastic agent is the control plane of torchelastic. It is an independent process that launches and manages the underlying worker processes. Concretely, the agent is responsible for:

- Working with distributed torch: the workers are started with all the information they need to successfully and trivially call torch.distributed.init_process_group().
- Fault tolerance: monitoring the workers; once a worker failure or unhealthiness is detected, tearing down all the workers and restarting everyone.
- Elasticity: reacting to membership changes and restarting the workers with the new members.

The figure below, taken from Zhihu, can be seen as a refinement of the previous one.
The torchelastic agent and the user's workers operate according to a failover contract.

The simplest agents are deployed per node and work with local processes. More advanced agents can launch and manage workers remotely. Agents can be fully decentralized, communicating and coordinating with the other agents (which manage workers of the same job) to reach a collective decision based on the situation of the workers each of them manages.
As for configuration, the source code gives an example: to train a job with 8 trainers on GPUs (one trainer per GPU), any of the following layouts works.
1. Use 8 x single GPU instances, place an agent per instance, managing 1 worker per agent.
2. Use 4 x double GPU instances, place an agent per instance, managing 2 workers per agent.
3. Use 2 x quad GPU instances, place an agent per instance, managing 4 workers per agent.
4. Use 1 x 8 GPU instance, place an agent per instance, managing 8 workers per agent.
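For instance, topology 2 would look roughly like the following on each of the 4 hosts. This is a sketch that mirrors the WorkerSpec/LocalElasticAgent docstring examples shown later in this article; trainer is a hypothetical user function and the required rendezvous parameters are elided:

spec = WorkerSpec(
    role="trainer",
    local_world_size=2,   # 2 workers per agent; 4 agents x 2 workers = 8 trainers
    entrypoint=trainer,   # hypothetical user training function
    args=(),
    # ... rdzv_handler and other required parameters elided ...
)
agent = LocalElasticAgent(spec, start_method="spawn")
results = agent.run()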
The base class ElasticAgent is an abstract class; every agent that actually runs derives from it. From the ElasticAgent comments we know: the agent process is responsible for managing one or more worker processes. The worker processes are assumed to be regular distributed PyTorch scripts. When a worker process is created by the agent, the agent provides the information the worker process needs to properly initialize the torch process group. In terms of deployment, the exact topology and agent-to-worker ratio depend on the concrete implementation of the agent and on the user's job-placement preferences.
class ElasticAgent(abc.ABC):
    """
    Agent process responsible for managing one or more worker processes.
    The worker processes are assumed to be regular distributed PyTorch scripts.
    When the worker process is created by the agent, the agent provides the
    necessary information for the worker processes to properly initialize
    a torch process group.

    The exact deployment topology and ratio of agent-to-worker is dependent
    on the specific implementation of the agent and the user's job placement
    preferences.

    Usage
    ::

        group_result = agent.run()
        if group_result.is_failed():
            # workers failed
            failure = group_result.failures[0]
            log.exception(f"worker 0 failed with exit code : {failure.exit_code}")
        else:
            return group_result.return_values[0]  # return rank 0's results
    """

    @abc.abstractmethod
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        """
        Runs the agent, retrying the worker group on failures up to
        ``max_restarts``.

        Returns:
            The result of the execution, containing the return values or
            failure details for each worker mapped by the worker's global rank.

        Raises:
            Exception - any other failures NOT related to worker process
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
        """
        Returns:
            The ``WorkerGroup`` for the given ``role``.
            Note that the worker group is a mutable object and hence in a
            multi-threaded/process environment it may change state.
            Implementors are encouraged (but not required) to return
            a defensive read-only copy.
        """
        raise NotImplementedError()
ElasticAgent has two derived classes:

- SimpleElasticAgent implements some of the base class's methods; its purpose is to make it easier to implement new agents.
- LocalElasticAgent derives from SimpleElasticAgent and is the agent elastic training actually uses today; it operates locally and manages all the worker processes on one machine.

We first look at Worker, the entity the Agent manages.
The Worker class represents one worker instance. We introduced WorkerSpec earlier; a Worker is built from a WorkerSpec. Its key member variables are:

- id (Any): uniquely identifies a worker; the interpretation is up to the specific ElasticAgent implementation. For a local agent it can be the worker's pid (int); for a remote agent it could be encoded as host:port (string).
- local_rank: the worker's local rank.
- global_rank: the worker's global rank.
- role_rank: the rank of the worker among all workers that have the same role.
- world_size: the total number of workers, globally.
- role_world_size: the number of workers that have the same role.
class Worker:
    """
    Represents a worker instance. Contrast this with ``WorkerSpec`` that
    represents the specifications of a worker. A ``Worker`` is created from
    a ``WorkerSpec``. A ``Worker`` is to a ``WorkerSpec`` as an object is to
    a class.

    The ``id`` of the worker is interpreted by the specific implementation of
    ``ElasticAgent``. For a local agent, it could be the ``pid (int)`` of the
    worker, for a remote agent it could be encoded as ``host:port (string)``.

    Args:
        id (Any): uniquely identifies a worker (interpreted by the agent)
        local_rank (int): local rank of the worker
        global_rank (int): global rank of the worker
        role_rank (int): rank of the worker across all workers that have the same role
        world_size (int): number of workers (globally)
        role_world_size (int): number of workers that have the same role
    """

    def __init__(
        self,
        local_rank: int,
        global_rank: int = -1,
        role_rank: int = -1,
        world_size: int = -1,
        role_world_size: int = -1,
    ):
        # unique identifier for this worker
        self.id: Any = None

        # rank of the worker among workers with the same role being monitored
        # by the same ``agent`` instance.
        self.local_rank: int = local_rank

        # rank of the worker among all the workers across all roles
        # across all ``agent`` instances.
        # Global rank is not stable between re-rendezvous.
        self.global_rank: int = global_rank

        # rank of the worker among all the workers with the same role
        # across all ``agent`` instances.
        # Global rank is not stable between re-rendezvous.
        self.role_rank: int = role_rank

        # total number of workers (globally). Due to elasticity
        # the world size may change between re-rendezvous.
        self.world_size: int = world_size

        # total number of workers that share the same role. Due to elasticity
        # the role world size may change between re-rendezvous.
        self.role_world_size: int = role_world_size
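As a quick illustration with hypothetical values: a worker is first created with only its local rank, and everything else is filled in later by the agent (rank assignment is covered below):

w = Worker(local_rank=0)  # as in WorkerGroup.__init__: other ranks default to -1 before rdzv
print(w.global_rank)      # -1 until _assign_worker_ranks builds ranked workers (see below)
w.id = 12345              # set by the agent after start; e.g. the pid for a local agent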
WorkerGroup represents a group of workers: it manages multiple workers as a single unit, so they can be handled in batch.
class WorkerGroup:
    """
    Represents the set of ``Worker`` instances for the given ``WorkerSpec``
    managed by ``ElasticAgent``. Whether the worker group contains cross
    instance workers or not depends on the implementation of the agent.
    """

    def __init__(self, spec: WorkerSpec):
        self.spec = spec
        self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]

        # assigned after rdzv
        self.store = None
        self.group_rank = None
        self.group_world_size = None

        self.state = WorkerState.INIT
During the initialization of SimpleElasticAgent, a WorkerGroup is built.
class SimpleElasticAgent(ElasticAgent):
    """
    An ``ElasticAgent`` that manages workers (``WorkerGroup``)
    for a single ``WorkerSpec`` (e.g. one particular type of worker role).
    """

    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None
        self._exit_barrier_timeout = exit_barrier_timeout
        self._total_execution_time = 0
WorkerState represents the state of a WorkerGroup. All workers in the group change state as a unit; if one worker in the group fails, the whole group is considered failed:
UNKNOWN - agent lost track of worker group state, unrecoverable
INIT - worker group object created not yet started
HEALTHY - workers running and healthy
UNHEALTHY - workers running and unhealthy
STOPPED - workers stopped (interrupted) by the agent
SUCCEEDED - workers finished running (exit 0)
FAILED - workers failed to successfully finish (exit !0)
A worker group starts from the initial INIT state, then progresses to HEALTHY or UNHEALTHY, and finally reaches the terminal SUCCEEDED or FAILED state. A worker group can be interrupted by the agent and temporarily put into the STOPPED state; workers in the STOPPED state may be scheduled for a restart in the near future. Examples of workers being put into the STOPPED state are:

1. A worker-group failure or unhealthiness is observed.
2. A membership change is detected.

When an operation on the worker group (start, stop, rdzv, retry, etc.) fails and ends up only partially applied, the state becomes UNKNOWN. This typically happens when an exception during a state change goes uncaught/unhandled. The agent is not expected to recover a worker group in UNKNOWN state; it is better off terminating itself and letting the job manager retry the node.
The concrete definition of WorkerState is:
class WorkerState(str, Enum):
    """
    State of the ``WorkerGroup``. Workers in a worker group change state as a unit.
    If a single worker in a worker group fails the entire set is considered
    failed::

      UNKNOWN - agent lost track of worker group state, unrecoverable
      INIT - worker group object created not yet started
      HEALTHY - workers running and healthy
      UNHEALTHY - workers running and unhealthy
      STOPPED - workers stopped (interrupted) by the agent
      SUCCEEDED - workers finished running (exit 0)
      FAILED - workers failed to successfully finish (exit !0)

    A worker group starts from an initial ``INIT`` state,
    then progresses to ``HEALTHY`` or ``UNHEALTHY`` states,
    and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state.

    Worker groups can be interrupted and temporarily put into ``STOPPED`` state
    by the agent. Workers in ``STOPPED`` state are scheduled to be restarted
    in the near future by the agent. Some examples of workers being put into
    ``STOPPED`` state are:

    1. Worker group failure|unhealthy observed
    2. Membership change detected

    When actions (start, stop, rdzv, retry, etc) on worker group fails
    and results in the action being partially applied to the worker group
    the state will be ``UNKNOWN``. Typically this happens
    on uncaught/unhandled exceptions during state change events on the agent.
    The agent is not expected to recover worker groups in ``UNKNOWN`` state
    and is better off self terminating and allowing the job manager to retry the
    node.
    """

    UNKNOWN = "UNKNOWN"
    INIT = "INIT"
    HEALTHY = "HEALTHY"
    UNHEALTHY = "UNHEALTHY"
    STOPPED = "STOPPED"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"

    @staticmethod
    def is_running(state: "WorkerState") -> bool:
        """
        Returns:
             True if the worker state represents workers still running
             (e.g. that the process exists but not necessarily healthy).
        """
        return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY}
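A quick illustration of the running vs. terminal distinction encoded by is_running:

assert WorkerState.is_running(WorkerState.HEALTHY)        # process exists, healthy
assert WorkerState.is_running(WorkerState.UNHEALTHY)      # process exists, unhealthy
assert not WorkerState.is_running(WorkerState.STOPPED)    # interrupted by the agent
assert not WorkerState.is_running(WorkerState.SUCCEEDED)  # terminal state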
SimpleElasticAgent is one implementation of the Agent. This abstraction exists to make it easy to extend new agent implementations. As we will see, the built-in LocalElasticAgent manages all worker processes on a single machine; if a user wants a single agent that manages workers across several machines rather than only the local ones, they can implement a custom Agent by extending SimpleElasticAgent.
class SimpleElasticAgent(ElasticAgent):
    """
    An ``ElasticAgent`` that manages workers (``WorkerGroup``)
    for a single ``WorkerSpec`` (e.g. one particular type of worker role).
    """

    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None
        self._exit_barrier_timeout = exit_barrier_timeout
        self._total_execution_time = 0
The main loop of SimpleElasticAgent, _invoke_run, holds the core logic (it assumes that the agent and the workers run on the same machine). It does the following:

- It uses self._initialize_workers(self._worker_group) to do the initialization work, for example starting the workers and assigning each worker its rank.
- It then loops: every monitor_interval seconds it checks the workers with _monitor_workers and, depending on the resulting state, returns on success, retries or fails on errors, and restarts all workers when new nodes are waiting to join.

def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
    # NOTE: currently only works for a single role

    spec = self._worker_group.spec
    role = spec.role

    self._initialize_workers(self._worker_group)  # start the workers
    monitor_interval = spec.monitor_interval
    rdzv_handler = spec.rdzv_handler

    while True:
        assert self._worker_group.state != WorkerState.INIT
        # monitor periodically
        time.sleep(monitor_interval)
        # check on the user program
        run_result = self._monitor_workers(self._worker_group)  # get the processes' results
        state = run_result.state
        self._worker_group.state = state

        put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
        put_metric(f"workers.{role}.{state.name.lower()}", 1)

        if state == WorkerState.SUCCEEDED:
            # the program finished normally
            self._exit_barrier()
            return run_result
        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            # the program failed
            if self._remaining_restarts > 0:  # retry
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group)
            else:
                self._stop_workers(self._worker_group)  # out of retries: stop the workers
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result
        elif state == WorkerState.HEALTHY:
            # membership may have changed, e.g. on scale-up new nodes are waiting
            # membership changes do not count as retries
            num_nodes_waiting = rdzv_handler.num_nodes_waiting()
            group_rank = self._worker_group.group_rank
            # if new nodes are waiting, restart all workers
            if num_nodes_waiting > 0:
                self._restart_workers(self._worker_group)
        else:
            raise Exception(f"[{role}] Worker group in {state.name} state")
The above was only a rough pass over the overall flow; we now analyze it step by step.
In the agent's main loop, self._initialize_workers(self._worker_group) is the first call, and it starts the workers. Inside _initialize_workers:

- self._rendezvous(worker_group) performs the synchronization/consensus operation among the nodes, the rank handling, and so on.
- _start_workers starts the workers. _start_workers here is an abstract method that derived classes must implement.

@prof
def _initialize_workers(self, worker_group: WorkerGroup) -> None:
    r"""
    Starts a fresh set of workers for the worker_group.
    Essentially a rendezvous followed by a start_workers.

    The caller should first call ``_stop_workers()`` to stop running workers
    prior to calling this method.

    Optimistically sets the state of the worker group that
    just started as ``HEALTHY`` and delegates the actual monitoring
    of state to ``_monitor_workers()`` method
    """
    role = worker_group.spec.role

    # TODO after stopping workers, wait at least monitor_interval*2 for
    # workers on different nodes to fail on a collective op before waiting
    # on the rdzv barrier, this way we ensure that nodes enter rdzv
    # at around the same time and reduce false positive rdzv timeout errors
    self._rendezvous(worker_group)  # synchronization/consensus among nodes

    worker_ids = self._start_workers(worker_group)  # start the workers
    for local_rank, w_id in worker_ids.items():
        worker = worker_group.workers[local_rank]
        worker.id = w_id

    worker_group.state = WorkerState.HEALTHY
Let us first look at _rendezvous, which does the following:

- It calls next_rendezvous() to handle membership changes; this returns the store, the group rank and the group world size.
- It calls _assign_worker_ranks to build ranks for the workers, based on the group rank.

Both steps work with information coming out of the rendezvous; for instance, the ranks are extracted from the rendezvous result.

@prof
def _rendezvous(self, worker_group: WorkerGroup) -> None:
    r"""
    Runs rendezvous for the workers specified by worker spec.
    Assigns workers a new global rank and world size.
    Updates the rendezvous store for the worker group.
    """

    spec = worker_group.spec

    # handle membership changes; note that what we get here is the group rank!
    store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
    self._store = store  # kept on the Agent; the store can be seen as a remote KV store

    # build ranks for the workers based on the group rank
    workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
    worker_group.workers = workers
    worker_group.store = store
    worker_group.group_rank = group_rank
    worker_group.group_world_size = group_world_size

    if group_rank == 0:
        self._set_master_addr_port(store, spec.master_addr, spec.master_port)
    master_addr, master_port = self._get_master_addr_port(store)
    restart_count = spec.max_restarts - self._remaining_restarts
Elastic calls rdzv_handler.next_rendezvous() to handle membership changes; its goal is to kick off the next round of rendezvous (this worker has started and needs to join the cluster).

Note that next_rendezvous is an internal method of RendezvousHandler. The call blocks until the required number of workers has been reached. It is invoked both when a worker is initialized and when it is restarted. When the function returns, different worker groups are uniquely identified by the returned rank. Internally it:

- uses _RendezvousExitOp to make the node exit, and
- uses _RendezvousJoinOp to make the node join again.

def next_rendezvous(self) -> Tuple[Store, int, int]:
    """See base class."""
    self._stop_heartbeats()

    # Delay the execution for a small random amount of time if this is our
    # first run. This will slightly skew the rendezvous attempts across the
    # nodes and reduce the load on the backend.
    if self._state_holder.state.round == 0:
        _delay(seconds=(0, 0.3))

    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()

    deadline = self._get_deadline(self._settings.timeout.join)
    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)

    self._start_heartbeats()

    rank, world_size = self._get_world()
    store = self._get_store()

    return store, rank, world_size  # what is returned is the worker group's rank
Next, _assign_worker_ranks is called to build ranks for the workers. The rank-assignment algorithm is:

1. Each agent writes its configuration (group_rank, group_world_size, num_workers) to the common store.
2. Each agent retrieves the configuration of all agents and performs a two-level sort using role and rank.
3. Determine the global rank: the global rank of the current agent's workers is the offset of the infos array up to the agent's group_rank. The offset is computed as the sum of the local_world_size of all agents that have a rank lower than group_rank. The workers get the ranks [offset, offset+local_world_size).
4. Determine the role rank: same algorithm as in point 3, except that the offset is computed from the first agent that has the same role as the current one and has the minimum group rank.

It then generates the workers and assigns them all into worker_group.workers.

@prof
def _assign_worker_ranks(
    self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List[Worker]:
    """
    Determines proper ranks for worker processes. The rank assignment
    is done according to the following algorithm:

    1. Each agent writes its configuration(group_rank, group_world_size
       , num_workers) to the common store.
    2. Each agent retrieves configuration for all agents
       and performs two level sort using role and rank.
    3. Determine the global rank: the global rank of the workers for the current
       agent is the offset of the infos array up to group_rank of the agent.
       The offset is computed as a sum of local_world_size of all agents that
       have rank less than the group_rank. The workers would have the ranks:
       [offset, offset+local_world_size)
    4. Determine the role rank: The role rank is determined using the algorithms
       in the point 3 with the exception that the offset is done from the first
       agent that has the same role as current one and has the minimum group rank.
    """

    # Each agent writes its configuration (group_rank, group_world_size,
    # num_workers) to the common store.
    role_infos = self._share_and_gather(store, group_rank, group_world_size, spec)
    # Each agent retrieves the configuration of all agents and performs a
    # two-level sort using role and rank.
    my_role_info = role_infos[group_rank]
    # Determine the global rank: the global rank of this agent's workers is the
    # offset of the infos array up to this agent's group_rank. The offset is the
    # sum of local_world_size over all agents with a lower rank. The workers get
    # the ranks [offset, offset+local_world_size).
    worker_world_size, worker_global_ranks = self._get_ranks(role_infos, group_rank)
    role_infos = sorted(
        role_infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare)
    )
    role_start_idx, role_end_idx = _RoleInstanceInfo.find_role_boundaries(
        role_infos, my_role_info.role
    )
    role_pos = next(
        idx
        for idx, role_info in enumerate(role_infos)
        if _RoleInstanceInfo.compare(role_info, my_role_info) == 0
    )
    # Determine the role rank: same algorithm as point 3, except that the offset
    # starts from the first agent that has the same role as the current one and
    # has the minimum group rank.
    role_world_size, role_ranks = self._get_ranks(
        role_infos, role_pos, role_start_idx, role_end_idx + 1
    )
    # Generate the workers and assign them all into worker_group.workers.
    workers = []
    for ind in range(spec.local_world_size):
        worker = Worker(
            local_rank=ind,
            global_rank=worker_global_ranks[ind],
            role_rank=role_ranks[ind],
            world_size=worker_world_size,
            role_world_size=role_world_size,
        )
        workers.append(worker)
    return workers
The derived class's _start_workers is then called to start the worker processes; the base class does not implement it, and we will see later how derived classes do.

@abc.abstractmethod
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
    r"""
    Starts ``worker_group.spec.local_world_size`` number of workers
    according to worker spec for the worker group.

    Returns a map of ``local_rank`` to worker ``id``.
    """
    raise NotImplementedError()
The logic so far is as follows:
_initialize_workers
        +
        |
        v
_rendezvous(worker_group)
        +
        |  1
        v
rdzv_handler.next_rendezvous()
        +
        |  2  ranks
        v
_assign_worker_ranks
        +
        |  3
        v
WorkerGroup.workers = [Worker0(rank 0), Worker1(rank 1), ... Workern(rank n)]
(inside LocalElasticAgent; the WorkerGroup also holds spec, store, group_rank,
 group_world_size)
Next, we list the rank-related and the worker-related functions separately, to make them easier to understand.
The _assign_worker_ranks shown earlier builds ranks for the workers, but a few of its internal details still deserve a closer look.

First, the _RoleInstanceInfo data structure. The agent uses this class to exchange information with the other agents; the information is used to determine the ranks of this agent's workers. The agents operate in heterogeneous environments, where different agents may have different numbers of workers. Its construction arguments are:
class _RoleInstanceInfo:
    """
    The class is used by the agent to exchange the information with other agents.
    The information is used to determine the rank of the workers that agent
    manages in heterogeneous environments, where different agents can have
    different number of workers.
    """

    __slots__ = ["role", "rank", "local_world_size"]

    def __init__(self, role: str, rank: int, local_world_size: int):
        r"""
        Args:
            role (str): user-defined role for the workers with this spec
            rank (int): the rank of the agent
            local_world_size (int): number of local workers to run
        """
        self.role = role
        self.rank = rank
        self.local_world_size = local_world_size

    def serialize(self) -> bytes:
        dict_data = {
            "role": self.role,
            "rank": self.rank,
            "local_world_size": self.local_world_size,
        }
        return json.dumps(dict_data).encode(encoding="UTF-8")

    @staticmethod
    def deserialize(data: bytes):
        dict_data = json.loads(data.decode(encoding="UTF-8"))
        return _RoleInstanceInfo(
            dict_data["role"], dict_data["rank"], dict_data["local_world_size"]
        )

    @staticmethod
    def compare(obj1, obj2) -> int:
        if obj1.role == obj2.role:
            return obj1.rank - obj2.rank
        elif obj1.role > obj2.role:
            return 1
        else:
            return -1

    @staticmethod
    def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]:
        start_idx, end_idx = -1, -1
        for idx, role_info in enumerate(roles_infos):
            if role_info.role == role:
                if start_idx == -1:
                    start_idx = idx
                end_idx = idx
        return (start_idx, end_idx)
The role of _share_and_gather is to synchronize among the agents and obtain the overall role information. Each agent writes its configuration (group_rank, group_world_size, num_workers) to the common store. The store used here is exactly the one returned earlier by Rendezvous; that is how the information is shared.
def _share_and_gather(
    self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
) -> List:
    agent_role_info = _RoleInstanceInfo(
        spec.role, group_rank, spec.local_world_size
    )
    key_prefix = "torchelastic/role_info"
    agent_config_enc = agent_role_info.serialize()
    role_infos_bytes = store_util.synchronize(
        store, agent_config_enc, group_rank, group_world_size, key_prefix
    )
    role_infos = [
        _RoleInstanceInfo.deserialize(role_info_bytes)
        for role_info_bytes in role_infos_bytes
    ]
    return role_infos
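store_util.synchronize hides the actual exchange; conceptually it is the classic gather-via-store pattern. A minimal sketch of that pattern (not the real implementation; it assumes a Store-like object whose get blocks until the key exists):

def synchronize_sketch(store, payload: bytes, rank: int, world_size: int, prefix: str):
    store.set(f"{prefix}{rank}", payload)  # publish this agent's serialized info
    # read everyone's info; get() blocks until the corresponding agent has written
    return [store.get(f"{prefix}{i}") for i in range(world_size)]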
The global rank is then determined from the role infos: the global rank of the current agent's workers is the offset, within the infos array, up to the agent's group_rank. The offset is the sum of the local_world_size of all agents whose rank is lower than group_rank. The workers get the ranks [offset, offset+local_world_size).
def _get_ranks(
    self,
    role_infos: List[_RoleInstanceInfo],
    role_idx: int,
    start_idx: int = 0,
    end_idx: int = -1,
) -> Tuple[int, List[int]]:
    if end_idx == -1:
        end_idx = len(role_infos)
    prefix_sum = 0
    total_sum = 0
    for idx in range(start_idx, end_idx):
        if role_idx > idx:
            prefix_sum += role_infos[idx].local_world_size
        total_sum += role_infos[idx].local_world_size
    return (
        total_sum,
        list(range(prefix_sum, prefix_sum + role_infos[role_idx].local_world_size)),
    )
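To make the offset arithmetic concrete, here is a small worked example with hypothetical numbers (three agents, all with the same role, running 2, 2 and 4 workers respectively); the helper below mirrors the prefix-sum logic of _get_ranks above:

def get_ranks_sketch(local_world_sizes, role_idx):
    # offset = sum of local_world_size over all agents with a smaller rank
    offset = sum(local_world_sizes[:role_idx])
    total = sum(local_world_sizes)
    return total, list(range(offset, offset + local_world_sizes[role_idx]))

print(get_ranks_sketch([2, 2, 4], 1))  # (8, [2, 3])       -> agent 1's workers
print(get_ranks_sketch([2, 2, 4], 2))  # (8, [4, 5, 6, 7]) -> agent 2's workers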
The logic now extends as follows:
_initialize_workers
        +
        |  1
        v
_rendezvous(worker_group)
        +
        |  2
        v
rdzv_handler.next_rendezvous()
        +
        |  3  ranks
        v
_assign_worker_ranks
        |  4  role_infos = _share_and_gather(store)
        |  5  _get_ranks(world...), _get_ranks(role...)
        |  6  new Worker(local_rank, global_rank, role_rank,
        |                world_size, role_world_size)
        |  7  workers
        v
WorkerGroup: spec, store, group_rank, group_world_size,
             workers = [Worker0(rank 0), Worker1(rank 1), ... Workern(rank n)]
After the _rendezvous operation, the Worker instances exist; the next step is to see how the worker processes are created. Because those methods are not implemented in SimpleElasticAgent, we can only extend our logic diagram further in the section analyzing the derived class LocalElasticAgent.

First, let us look at the two remaining worker-related functions of SimpleElasticAgent.
_restart_workers restarts the workers.
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
#  `torch.distributed.elastic.metrics.prof`.
@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)
In practice it is nearly impossible to guarantee that all workers of a DDP job finish at the same time, so TE provides a finalization barrier whose job is to wait for worker finalization with a timeout (5 minutes by default).
def _exit_barrier(self):
    """
    Wait for ``exit_barrier_timeout`` seconds for all agents to finish
    executing their local workers (either successfully or not). This
    acts as a safety guard against user scripts that terminate at different
    times. This barrier keeps the agent process alive until all workers finish.
    """
    start = time.time()
    try:
        store_util.barrier(
            self._store,
            self._worker_group.group_rank,
            self._worker_group.group_world_size,
            key_prefix=_TERMINAL_STATE_SYNC_ID,
            barrier_timeout=self._exit_barrier_timeout,
        )
    except Exception:
        log.exception(
            f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
        )
LocalElasticAgent is the agent that elastic training ultimately uses; it derives from SimpleElasticAgent, operates locally, and manages all the worker processes on one machine.

This agent is deployed on every host and is configured to spawn n worker processes. When GPUs are used, n is the number of GPUs available on the host. The local agent does not communicate with local agents deployed on other hosts, even though the workers may communicate across hosts. A worker id is interpreted as a local process. The agent starts and stops all of the host's worker processes as a single unit.

The worker function and the arguments passed to it must be compatible with python multiprocessing. To pass multiprocessing data structures to the workers, the user may create the data structure in the same multiprocessing context as the configured start_method and pass it as a function argument.

exit_barrier_timeout specifies how long (in seconds) to wait for the other agents to finish. It acts as a safety net for workers that finish at different times, preventing an agent from treating early-finishing workers as a scale-down event. It is strongly advised that user code ensure the workers terminate in a synchronous manner rather than rely on exit_barrier_timeout.
SimpleElasticAgent mainly provides the initialization and the overall run flow, but it leaves several abstract methods unimplemented: _start_workers, _stop_workers, _monitor_workers and _shutdown. LocalElasticAgent fills in exactly these methods.
class LocalElasticAgent(SimpleElasticAgent):
    """
    An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
    that handles host-local workers.
    This agent is deployed per host and is configured to spawn ``n`` workers.
    When using GPUs, ``n`` maps to the number of GPUs available on the host.

    The local agent does not communicate to other local agents deployed on
    other hosts, even if the workers may communicate inter-host. The worker id
    is interpreted to be a local process. The agent starts and stops all worker
    processes as a single unit.

    The worker function and argument passed to the worker function must be
    python multiprocessing compatible. To pass multiprocessing data structures
    to the workers you may create the data structure in the same multiprocessing
    context as the specified ``start_method`` and pass it as a function argument.

    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
    for other agents to finish. This acts as a safety net to handle cases where
    workers finish at different times, to prevent agents from viewing workers
    that finished early as a scale-down event. It is strongly advised that the
    user code deal with ensuring that workers are terminated in a synchronous
    manner rather than relying on the exit_barrier_timeout.
    """

    def __init__(
        self,
        spec: WorkerSpec,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        self._pcontext: Optional[PContext] = None
        rdzv_run_id = spec.rdzv_handler.get_run_id()
        self._log_dir = self._make_log_dir(log_dir, rdzv_run_id)

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        return dir
Let us first pull the example code out of its comments to see how it is used. The following shows how to launch with a function as the entrypoint.
def trainer(args) -> str:
    return "do train"

def main():
    start_method = "spawn"
    shared_queue = multiprocessing.get_context(start_method).Queue()
    spec = WorkerSpec(
                role="trainer",
                local_world_size=nproc_per_process,
                entrypoint=trainer,
                args=("foobar",),
                ...<OTHER_PARAMS...>)
    agent = LocalElasticAgent(spec, start_method)
    results = agent.run()

    if results.is_failed():
        print("trainer failed")
    else:
        print(f"rank 0 return value: {results.return_values[0]}")
        # prints -> rank 0 return value: do train
The following shows how to launch with a binary as the entrypoint.
def main():
    spec = WorkerSpec(
                role="trainer",
                local_world_size=nproc_per_process,
                entrypoint="/usr/local/bin/trainer",
                args=("--trainer_args", "foobar"),
                ...<OTHER_PARAMS...>)
    agent = LocalElasticAgent(spec)
    results = agent.run()

    if not results.is_failed():
        print("binary launches do not have return values")
The following functions stop the workers.
@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()

def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
Continuing from earlier: after the _rendezvous operation the Worker instances exist, so we can now see how the worker processes are created. These methods were not implemented in SimpleElasticAgent, so in this section we extend our logic diagram.

Let us revisit worker initialization. In _initialize_workers, _rendezvous is first used to build the Worker instances, and then _start_workers is called to start them.
@prof
def _initialize_workers(self, worker_group: WorkerGroup) -> None:
    r"""
    Starts a fresh set of workers for the worker_group.
    Essentially a rendezvous followed by a start_workers.

    The caller should first call ``_stop_workers()`` to stop running workers
    prior to calling this method.

    Optimistically sets the state of the worker group that
    just started as ``HEALTHY`` and delegates the actual monitoring
    of state to ``_monitor_workers()`` method
    """
    role = worker_group.spec.role

    # TODO after stopping workers, wait at least monitor_interval*2 for
    # workers on different nodes to fail on a collective op before waiting
    # on the rdzv barrier, this way we ensure that nodes enter rdzv
    # at around the same time and reduce false positive rdzv timeout errors
    self._rendezvous(worker_group)  # the Worker instances now exist

    worker_ids = self._start_workers(worker_group)  # start the worker processes
    for local_rank, w_id in worker_ids.items():
        worker = worker_group.workers[local_rank]
        worker.id = w_id  # record the process id

    worker_group.state = WorkerState.HEALTHY
The _start_workers method calls start_processes to launch the worker processes; the default _start_method is "spawn". That is, it starts multiple processes that run the user program in parallel, and the running of these processes is monitored. Among the start_processes arguments, entrypoint and args are the user command and its parameters; entrypoint can be a function or a string.

_start_workers saves the result of start_processes in _pcontext, and _pcontext is used for all further control; ending the workers, for example, simply means calling _pcontext's close method.
@prof
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
    spec = worker_group.spec
    store = worker_group.store
    assert store is not None
    master_addr, master_port = super()._get_master_addr_port(store)
    restart_count = spec.max_restarts - self._remaining_restarts

    use_agent_store = spec.rdzv_handler.get_backend() == "static"

    args: Dict[int, Tuple] = {}
    envs: Dict[int, Dict[str, str]] = {}
    for worker in worker_group.workers:
        local_rank = worker.local_rank
        worker_env = {
            "LOCAL_RANK": str(local_rank),
            "RANK": str(worker.global_rank),
            "GROUP_RANK": str(worker_group.group_rank),
            "ROLE_RANK": str(worker.role_rank),
            "ROLE_NAME": spec.role,
            "LOCAL_WORLD_SIZE": str(spec.local_world_size),
            "WORLD_SIZE": str(worker.world_size),
            "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
            "ROLE_WORLD_SIZE": str(worker.role_world_size),
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": str(master_port),
            "TORCHELASTIC_RESTART_COUNT": str(restart_count),
            "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
            "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
            "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
            "NCCL_ASYNC_ERROR_HANDLING": str(1),
        }
        if "OMP_NUM_THREADS" in os.environ:
            worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
        envs[local_rank] = worker_env
        worker_args = list(spec.args)
        worker_args = macros.substitute(worker_args, str(local_rank))
        args[local_rank] = tuple(worker_args)

    # scaling events do not count towards restarts (gets same attempt #)
    # remove existing log dir if this restart is due to a scaling event
    attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
    shutil.rmtree(attempt_log_dir, ignore_errors=True)
    os.makedirs(attempt_log_dir)

    self._pcontext = start_processes(  # save the result of start_processes in _pcontext
        name=spec.role,
        entrypoint=spec.entrypoint,
        args=args,
        envs=envs,
        log_dir=attempt_log_dir,
        start_method=self._start_method,
        redirects=spec.redirects,
        tee=spec.tee,
    )

    return self._pcontext.pids()
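On the worker side, these environment variables are what a regular training script consumes. A minimal sketch of a worker entrypoint (assuming the standard env:// initialization, which reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the environment):

import os
import torch.distributed as dist

def worker_main():
    local_rank = int(os.environ["LOCAL_RANK"])  # set per worker by _start_workers above
    rank = int(os.environ["RANK"])              # the global rank assigned by the agent
    world_size = int(os.environ["WORLD_SIZE"])
    # MASTER_ADDR / MASTER_PORT are also set, so the default env:// rendezvous
    # of init_process_group needs no extra arguments.
    dist.init_process_group(backend="gloo")
    print(f"worker {rank}/{world_size} (local rank {local_rank}) ready")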
Once running, TE calls _monitor_workers to monitor the workers. The result of start_processes was saved in _pcontext earlier; _pcontext is now used to watch how the run is going.
@prof
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
    role = worker_group.spec.role
    worker_pids = {w.id for w in worker_group.workers}
    assert self._pcontext is not None
    pc_pids = set(self._pcontext.pids().values())
    if worker_pids != pc_pids:
        return RunResult(state=WorkerState.UNKNOWN)

    result = self._pcontext.wait(0)  # check on the running processes
    if result:
        if result.is_failed():  # the processes failed
            # map local rank failure to global rank
            worker_failures = {}
            # the returned result contains the outcome of every process
            for local_rank, failure in result.failures.items():
                worker = worker_group.workers[local_rank]
                worker_failures[worker.global_rank] = failure
            return RunResult(
                state=WorkerState.FAILED,
                failures=worker_failures,  # return the run result
            )
        else:
            # copy ret_val_queue into a map with a global ranks
            workers_ret_vals = {}
            for local_rank, ret_val in result.return_values.items():
                worker = worker_group.workers[local_rank]
                workers_ret_vals[worker.global_rank] = ret_val
            return RunResult(
                state=WorkerState.SUCCEEDED,
                return_values=workers_ret_vals,  # return the run result
            )
    else:
        return RunResult(state=WorkerState.HEALTHY)
Because starting and monitoring touch the system's overall running logic and are best understood together with rendezvous, we postpone that part of the analysis and will make an overall pass after covering Rendezvous.

The overall logic is now:
_initialize_workers
        +
        |  1
        v
_rendezvous(worker_group)
        +
        |  2
        v
rdzv_handler.next_rendezvous()
        +
        |  3  ranks
        v
_assign_worker_ranks
        |  4  role_infos = _share_and_gather(store)
        |  5  _get_ranks(world...), _get_ranks(role...)
        |  6  new Worker(local_rank, global_rank, role_rank,
        |                world_size, role_world_size)
        |  7  workers --> WorkerGroup.workers
        +
        |  8, 9
        v
_start_workers
        +
        |  10
        v
state = _monitor_workers  <-----+
        +                       |  LOOP every 30s
        +-----------------------+
        |  11
        v
_exit_barrier
References:

TorchElastic - 弹性、容错的分布式训练 (Zhihu)