In the previous posts we studied the basic modules of PyTorch distributed and walked through several official examples. We now turn to PyTorch elastic training. This is the fifth post of the series; it looks at the internal engine of Rendezvous, for example how it handles node joins, node departures, waiting, heartbeats, and so on.
The other posts in the elastic training series are:

[Source code analysis] PyTorch distributed elastic training (1) --- Overall ideas
[Source code analysis] PyTorch distributed elastic training (2) --- Startup & single-node flow
[Source code analysis] PyTorch distributed elastic training (3) --- Agent
[Source code analysis] PyTorch distributed elastic training (4) --- Rendezvous architecture and logic
Elastic training can be understood as a runtime system built on top of Rendezvous.

- The Agent focuses on the logic that runs on an individual node.
- Rendezvous is responsible for the cluster-level logic, making sure that all nodes reach a strongly consistent consensus on "which nodes take part in training".
- After the rendezvous completes, a shared key-value store is created that implements the torch.distributed.Store API. This store is shared only by the members that have completed the rendezvous, and Torch Distributed Elastic uses it to exchange control and data information while initializing the job.

So far, the Rendezvous picture looks as follows. DynamicRendezvousHandler holds the dynamic logic, and _RendezvousStateHolder is the (static) structure that stores the state and other meta information. You may notice that the diagram also contains a _RendezvousOpExecutor that has not been introduced yet: it is the runtime engine, and this post looks at how _RendezvousOpExecutor works.
+-------------------------------+
| LocalElasticAgent             |
|                               |
|   +------------------------+  |
|   | WorkerGroup            |  |
|   |    spec                |  |
|   |    workers             |  |
|   |    store               |  |
|   |    group_rank          |  |
|   |    group_world_size    |  |
|   +------------------------+  |
|                               |
|   rdzv_run_id                 |
|   store                       |
+-------------------------------+
                |
                |  spec
                v
+----------------------------------------------+
| WorkerSpec                                   |
|                                              |
|   rdzv_handler = {DynamicRendezvousHandler}  |
|   entry        = worker_fn                   |
|   role         = {str} 'trainer'             |
+----------------------------------------------+
                |
                |  rdzv_handler
                v
+------------------------------------------+
| DynamicRendezvousHandler                 |
|                                          |
|   _settings: RendezvousSettings          |
|   _store: Store                          |
|   _state_holder: _RendezvousStateHolder  |
|   _op_executor: _RendezvousOpExecutor    |
+------------------------------------------+
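As a side note, here is a rough sketch of how the pieces shown above are typically wired together on one node, based on the public docstrings of the c10d backend and DynamicRendezvousHandler. The host name, port and run id are placeholders, and only a single node is shown:

from torch.distributed import TCPStore
from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import C10dRendezvousBackend
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import DynamicRendezvousHandler

# One node hosts the TCPStore server; every node connects to it.
store = TCPStore("localhost", 29400, is_master=True)

# The C10d backend keeps the rendezvous state in that store.
backend = C10dRendezvousBackend(store, "example_run_id")

# The handler wires together the settings, store, state holder and op executor
# shown in the diagram above.
rdzv_handler = DynamicRendezvousHandler.from_backend(
    run_id="example_run_id",
    store=store,
    backend=backend,
    min_nodes=1,
    max_nodes=4,
)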
_RendezvousOpExecutor splits the functionality apart and decouples it: the user-facing operations are abstracted into a series of operators, such as _RendezvousJoinOp, while the internal business logic is implemented as a set of business functions. The _RendezvousOpExecutor engine executes the operators, derives an Action from each operator's result, and then uses that Action to call the business functions. This post mainly covers the Rendezvous engine used by the C10d backend.
_RendezvousOpExecutor is the base class of the engine; it only defines the abstract method run.
class _RendezvousOpExecutor(ABC):
    """Executes rendezvous operations."""

    @abstractmethod
    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """Executes a rendezvous operation.

        An operation is run inside a state machine and is expected to
        transition the rendezvous from one state to another.

        Args:
            state_handler:
                A callable that is expected to return the next state transition
                action based on the current state of the rendezvous.
            deadline:
                The time, in seconds, at which the operation will be considered
                timed-out.
        """
_RendezvousContext is used here; it wraps up the various pieces of Rendezvous information and provides them to the operation engine. This is where _RendezvousState and RendezvousSettings come into play.
class _RendezvousContext:
    """Holds the context of the rendezvous.

    Attributes:
        node:
            The node descriptor associated with the current rendezvous handler
            instance.
        state:
            The current state of the rendezvous.
        settings:
            The rendezvous settings.
    """

    node: _NodeDesc
    state: _RendezvousState
    settings: RendezvousSettings

    def __init__(
        self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings
    ) -> None:
        self.node = node
        self.state = state
        self.settings = settings
_DistributedRendezvousOpExecutor extends _RendezvousOpExecutor and is the actual executor in TorchElastic. Similar to a Looper, it is responsible for dispatching messages, calling the business logic, and maintaining state.

Compared with its base class, _DistributedRendezvousOpExecutor adds member variables such as the node information, the state, and the settings.
class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
    """Executes rendezvous operations using a shared state.

    Args:
        node:
            The node descriptor associated with the current rendezvous handler
            instance.
        state_holder:
            The ``RendezvousStateHolder`` to use to sync the rendezvous state
            with other nodes.
        settings:
            The rendezvous settings.
    """

    _node: _NodeDesc
    _state: _RendezvousState
    _state_holder: _RendezvousStateHolder
    _settings: RendezvousSettings

    def __init__(
        self,
        node: _NodeDesc,
        state_holder: _RendezvousStateHolder,
        settings: RendezvousSettings,
    ) -> None:
        self._node = node
        self._state_holder = state_holder
        self._settings = settings
Its overall structure is as follows:
+----------------------------------+
| _DistributedRendezvousOpExecutor |
+----------------------------------+
    |
    |                     +------------------------+
    +-- _state ---------> | _RendezvousState       |
    |                     |    participants        |
    |                     |    wait_list           |
    |                     |    last_heartbeats     |
    |                     |    deadline            |
    |                     +------------------------+
    |
    |                     +------------------------+
    +-- _settings ------> | RendezvousSettings     |
    |                     +------------------------+
    |
    |                     +-----------------------------------+
    +-- _state_holder --> | _BackendRendezvousStateHolder     |
    |                     |    _backend: RendezvousBackend    |
    |                     |    _state: _RendezvousState       |
    |                     |    _settings: RendezvousSettings  |
    |                     +-----------------------------------+
    |
    |                     +------------------------+
    +-- _node ----------> | _NodeDesc              |
                          |    fqdn: str           |
                          |    pid: int            |
                          |    local_id: int       |
                          +------------------------+
Here are a few examples of how the engine is invoked: in each case an operator is created first, and then the engine's run function is called with it.
def _keep_alive(self) -> None:
    self._heartbeat_lock.acquire()
    op = _RendezvousKeepAliveOp()  # create the operator
    deadline = self._get_deadline(self._settings.timeout.heartbeat)
    self._op_executor.run(op, deadline)  # invoke the engine
def _close(self) -> None:
    op = _RendezvousCloseOp()  # create the operator
    deadline = self._get_deadline(self._settings.timeout.close)
    self._op_executor.run(op, deadline)  # invoke the engine
def next_rendezvous(self) -> Tuple[Store, int, int]:
    """See base class."""
    self._stop_heartbeats()

    # Delay the execution for a small random amount of time if this is our
    # first run. This will slightly skew the rendezvous attempts across the
    # nodes and reduce the load on the backend.
    if self._state_holder.state.round == 0:
        _delay(seconds=(0, 0.3))

    exit_op = _RendezvousExitOp()  # create the operator
    join_op = _RendezvousJoinOp()  # create the operator

    deadline = self._get_deadline(self._settings.timeout.join)

    self._op_executor.run(exit_op, deadline)  # invoke the engine
    self._op_executor.run(join_op, deadline)  # invoke the engine

    self._start_heartbeats()

    rank, world_size = self._get_world()
    store = self._get_store()

    return store, rank, world_size
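For illustration only: the (store, rank, world_size) triple returned by next_rendezvous is everything needed to set up a process group. In TorchElastic the agent does this on behalf of the workers, but a hypothetical direct caller could use it roughly like this (assuming the tuple-returning signature shown above and the rdzv_handler sketch from earlier):

import torch.distributed as dist

store, rank, world_size = rdzv_handler.next_rendezvous()

# The returned store is scoped to this rendezvous round and can be passed
# straight to init_process_group.
dist.init_process_group(backend="gloo", store=store, rank=rank, world_size=world_size)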
Inside _DistributedRendezvousOpExecutor, the run function implements the basic logic: it performs different operations depending on the type of action.

The code of run is as follows:
def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None

    while action != _Action.FINISH:  # loop until a FINISH action is returned
        # Reads or writes the latest rendezvous state shared by all nodes in
        # the rendezvous. Note that our local changes might get overridden
        # by another node if that node synced its changes before us.
        #
        # This is the important part: the state is synchronized across all nodes here.
        has_set = self._state_holder.sync()

        self._state = self._state_holder.state

        ctx = _RendezvousContext(self._node, self._state, self._settings)

        # Determine the next action to take based on the current state of
        # the rendezvous. state_handler is the operator.
        action = state_handler(ctx, deadline)

        if action == _Action.FINISH:
            continue

        if action == _Action.ERROR_CLOSED:
            raise RendezvousClosedError()

        if action == _Action.ERROR_TIMEOUT:
            raise RendezvousTimeoutError()

        if action == _Action.SYNC:
            # Delay the execution by one second to avoid overloading the
            # backend if we are asked to poll for state changes.
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_PARTICIPANTS:
                self._add_to_participants()
            elif action == _Action.ADD_TO_WAIT_LIST:
                self._add_to_wait_list()
            elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                self._remove_from_participants()
            elif action == _Action.REMOVE_FROM_WAIT_LIST:
                self._remove_from_wait_list()
            elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                self._mark_rendezvous_complete()
            elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                self._mark_rendezvous_closed()

            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()
The overall picture is shown below.
+------------------------------------------+
| DynamicRendezvousHandler                 |
|                                          |
|   _settings: RendezvousSettings          |
|   _store: Store                          |
|   _state_holder: _RendezvousStateHolder  |
|   _op_executor                           |
+-------+----------------------------------+
        |
        |  _op_executor.run(_RendezvousJoinOp())
        v
+----------------------------------+
| _DistributedRendezvousOpExecutor |
+----------------------------------+
    |
    |                     +------------------------+
    +-- _state ---------> | _RendezvousState       |
    |                     |    participants        |
    |                     |    wait_list           |
    |                     |    last_heartbeats     |
    |                     |    deadline            |
    |                     +------------------------+
    |
    |                     +------------------------+
    +-- _settings ------> | RendezvousSettings     |
    |                     +------------------------+
    |
    |                     +-----------------------------------+
    +-- _state_holder --> | _BackendRendezvousStateHolder     |
    |                     |    _backend: RendezvousBackend    |
    |                     |    _state: _RendezvousState       |
    |                     |    _settings: RendezvousSettings  |
    |                     +-----------------------------------+
    |
    |                     +------------------------+
    +-- _node ----------> | _NodeDesc              |
                          |    fqdn: str           |
                          |    pid: int            |
                          |    local_id: int       |
                          +------------------------+
An important detail in run: before executing any operator, it calls self._state_holder.sync() to synchronize the state across the workers and reach a consensus.
def sync(self) -> Optional[bool]:
    """See base class."""
    state_bits: Optional[bytes] = None

    token = None

    has_set: Optional[bool]

    if self._dirty:  # our local state has changed
        has_set = False

        state_bits = pickle.dumps(self._state)

        # Write our own state into the backend.
        set_response = self._backend.set_state(state_bits, self._token)
        if set_response is not None:
            state_bits, token, has_set = set_response
    else:  # nothing changed locally, so we can only read from the backend
        has_set = None

        if self._cache_duration > 0:
            # Avoid overloading the backend if we are asked to retrieve the
            # state repeatedly. Try to serve the cached state.
            if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0):
                return None

        get_response = self._backend.get_state()  # fetch the latest state of the other nodes
        if get_response is not None:
            state_bits, token = get_response

    if state_bits is not None:
        try:
            self._state = pickle.loads(state_bits)  # update our own state from the backend
        except pickle.PickleError as exc:
            raise RendezvousStateError(
                "The rendezvous state is corrupt. See inner exception for details."
            ) from exc
    else:
        self._state = _RendezvousState()

    if has_set and self._dead_nodes and log.isEnabledFor(logging.DEBUG):
        node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)

        msg = (
            f"As part of the sync operation the node(s) {node_list} have been removed from the "
            f"rendezvous '{self._settings.run_id}' since they had no heartbeat."
        )
        self._record(message=msg)

    self._token = token

    self._dirty = False

    self._last_sync_time = time.monotonic()

    self._sanitize()

    return has_set
The corresponding backend code lives in torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py.

The backend uses a store as centralized storage, which acts as the master. Each node is a client that contacts the master to update its own state and to fetch the state of the other nodes; this way all nodes exchange information and reach a consensus. Clients that stop updating their metadata are also removed periodically.
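A minimal sketch of this master/client pattern with a raw torch.distributed.TCPStore is shown below. The host, port, and keys are placeholders, and in a real job the two store objects would live in different processes on different nodes; the real backend additionally serializes the whole rendezvous state under a single key, as shown next.

from torch.distributed import TCPStore

# The "master" side owns the TCP server; every other node connects as a client.
master_store = TCPStore("localhost", 29401, is_master=True)
client_store = TCPStore("localhost", 29401, is_master=False)

# A node publishes its own information under some key ...
client_store.set("node1/state", b"alive")

# ... and any node can read what the others published.
print(master_store.get("node1/state"))  # b'alive'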
get_state simply reads the value from the store.
def get_state(self) -> Optional[Tuple[bytes, Token]]:
    """See base class."""
    base64_state: bytes = self._call_store("get", self._key)

    return self._decode_state(base64_state)
set_state performs a compare-and-set; it returns the new state together with a flag indicating whether our state was actually written.
def set_state(
    self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
    """See base class."""
    base64_state_str: str = b64encode(state).decode()

    if token:
        # Shortcut if we know for sure that the token is not valid.
        if not isinstance(token, bytes):
            result = self.get_state()
            if result is not None:
                tmp = *result, False
                # Python 3.6 does not support tuple unpacking in return
                # statements.
                return tmp
            return None

        token = token.decode()
    else:
        token = self._NULL_SENTINEL

    base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)

    state_token_pair = self._decode_state(base64_state)
    if state_token_pair is None:
        return None

    new_state, new_token = state_token_pair

    # C10d Store's compare_set method does not offer an easy way to find out
    # whether our write attempt was successful. As a brute-force solution we
    # perform a bitwise comparison of our local state and the remote state.
    return new_state, new_token, new_state == state
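The compare_set primitive that this relies on can be tried directly on a TCPStore. A small sketch (the port and key are placeholders; the real backend also base64-encodes the pickled state and uses the previously seen value as the token):

from torch.distributed import TCPStore

store = TCPStore("localhost", 29402, is_master=True)

store.set("state", "v1")

# compare_set(key, expected, desired) only writes `desired` if the current
# value equals `expected`; it always returns the value that ends up stored.
print(store.compare_set("state", "v1", "v2"))  # b'v2' -> our write won
print(store.compare_set("state", "v1", "v3"))  # b'v2' -> stale token, write lost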
The _sanitize method cleans up the state based on information from the other nodes, for example by removing faulty nodes: if a node's last heartbeat is older than a certain threshold, it is marked as a dead node and removed from the participants and from the wait list.
def _sanitize(self) -> None:
    state = self._state

    expire_time = datetime.utcnow() - (
        self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
    )

    # Filter out the dead nodes.
    self._dead_nodes = [
        node
        for node, last_heartbeat in state.last_heartbeats.items()
        if last_heartbeat < expire_time
    ]

    participant_removed = False

    for dead_node in self._dead_nodes:
        del state.last_heartbeats[dead_node]  # remove the faulty node

        try:
            del state.participants[dead_node]  # remove the faulty node
            participant_removed = True
        except KeyError:
            pass

        try:
            state.wait_list.remove(dead_node)  # remove the faulty node
        except KeyError:
            pass

    if participant_removed:
        # Common epilogue shared with the _remove_from_participants()
        # function of _DistributedRendezvousOpExecutor.
        _remove_participant_epilogue(state, self._settings)
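The filtering at the top of _sanitize boils down to a timestamp comparison. A self-contained toy version (the helper name is hypothetical; in the real code the interval and attempt count come from RendezvousSettings):

from datetime import datetime, timedelta

def find_dead_nodes(last_heartbeats, keep_alive_interval, keep_alive_max_attempt):
    # A node is considered dead after keep_alive_max_attempt missed intervals.
    expire_time = datetime.utcnow() - keep_alive_interval * keep_alive_max_attempt
    return [node for node, ts in last_heartbeats.items() if ts < expire_time]

now = datetime.utcnow()
heartbeats = {"node-a": now, "node-b": now - timedelta(minutes=2)}

print(find_dead_nodes(heartbeats, timedelta(seconds=5), 3))
# -> ['node-b']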
Now that we have seen how the engine runs, let us look at the concrete operators.
The business logic of the _RendezvousOpExecutor engine is split into two decoupled layers: user operations and internal business logic.

- User operations are expressed as operators for heartbeat, join, close and exit; for example, the join operator is _RendezvousJoinOp.
- The internal business logic is expressed as business functions; for example, _add_to_participants removes a node from the wait list and adds it to the participants.

Operators and business functions do not map one-to-one, so a state-machine-like mechanism is needed in between. The engine executes the concrete business logic according to an Action; in other words, the Action is what decouples the two layers.

Logically, the engine therefore consists of three layers: the operator layer at the top, the Action layer in the middle, and the business-function layer at the bottom, as shown below (a toy sketch of this pattern follows the diagram).
 Operator layer:            _RendezvousKeepAliveOp   _RendezvousCloseOp   _RendezvousExitOp   _RendezvousJoinOp
                                       |
                                       v
 Action layer:              KEEP_ALIVE   ADD_TO_PARTICIPANTS   ADD_TO_WAIT_LIST   REMOVE_FROM_WAIT_LIST   ......
                                       |
                                       v
 Business-function layer:   _add_to_participants   _remove_from_participants   _add_to_wait_list   ......
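To make the decoupling concrete, here is a self-contained toy sketch of the same pattern (all names are made up; this is not the real engine): an operator is just a callable that maps (context, deadline) to an Action, and the executor dispatches each Action to an internal business function until FINISH is returned.

from enum import Enum, auto

class Action(Enum):
    KEEP_ALIVE = auto()
    ADD_TO_PARTICIPANTS = auto()
    FINISH = auto()

# An "operator": pure decision logic, no side effects.
def toy_join_op(ctx, deadline):
    return Action.FINISH if ctx["joined"] else Action.ADD_TO_PARTICIPANTS

class ToyExecutor:
    """Dispatches Actions returned by an operator to internal business functions."""

    def __init__(self, ctx):
        self.ctx = ctx
        self._handlers = {
            Action.KEEP_ALIVE: self._keep_alive,
            Action.ADD_TO_PARTICIPANTS: self._add_to_participants,
        }

    def run(self, op, deadline):
        action = None
        while action != Action.FINISH:
            action = op(self.ctx, deadline)
            if action != Action.FINISH:
                self._handlers[action]()

    def _keep_alive(self):
        print("heartbeat")

    def _add_to_participants(self):
        print("joining")
        self.ctx["joined"] = True

ToyExecutor({"joined": False}).run(toy_join_op, deadline=0.0)
# prints "joining"; the next iteration returns FINISH and the loop exits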
We go through them one by one.

Let us start with the middle layer and see which Actions exist. Based on the rendezvous state, the engine's possible actions are listed below; the code is in torch/distributed/elastic/rendezvous/dynamic_rendezvous.py.
class _Action(Enum):
    """Specifies the possible actions based on the state of the rendezvous."""

    KEEP_ALIVE = 1
    ADD_TO_PARTICIPANTS = 2
    ADD_TO_WAIT_LIST = 3
    REMOVE_FROM_PARTICIPANTS = 4
    REMOVE_FROM_WAIT_LIST = 5
    MARK_RENDEZVOUS_COMPLETE = 6
    MARK_RENDEZVOUS_CLOSED = 7
    SYNC = 8
    ERROR_CLOSED = 9
    ERROR_TIMEOUT = 10
    FINISH = 11
The engine comes with a number of operators; roughly speaking, one operation corresponds to one operator. An operator simply maps the current rendezvous state to the type of operation to perform. A few examples follow.

_RendezvousKeepAliveOp determines the next Action based on the current state and time; it mainly drives the periodic heartbeat check for this node.
class _RendezvousKeepAliveOp:
    """Represents a rendezvous keep-alive update operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if _should_keep_alive(ctx):
            if time.monotonic() > deadline:
                return _Action.ERROR_TIMEOUT
            return _Action.KEEP_ALIVE
        return _Action.FINISH
The _should_keep_alive helper is:
def _should_keep_alive(ctx: _RendezvousContext) -> bool:
    """Determines whether a keep-alive heartbeat should be sent."""
    try:
        last_heartbeat = ctx.state.last_heartbeats[ctx.node]
    except KeyError:
        return False

    return last_heartbeat <= datetime.utcnow() - ctx.settings.keep_alive_interval
Note that sync is called before any operator runs, and sync synchronizes the state across the nodes; since heartbeats are periodic, the state synchronization is therefore also periodic.

DynamicRendezvousHandler starts a timer that periodically calls the _keep_alive_weak method.
def _start_heartbeats(self) -> None:
    self._keep_alive_timer = _PeriodicTimer(
        self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self)
    )

    self._keep_alive_timer.set_name(f"RendezvousKeepAliveTimer_{self._this_node.local_id}")

    self._keep_alive_timer.start()
Next, _keep_alive_weak calls self._keep_alive().
@staticmethod
def _keep_alive_weak(weak_self) -> None:
    self = weak_self()
    if self is not None:
        self._keep_alive()
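As an aside, the timer-plus-weak-reference pattern above can be sketched with plain threading. ToyPeriodicTimer and ToyHandler below are hypothetical stand-ins, not the real _PeriodicTimer: passing weakref.ref(self) instead of self means the background thread does not keep the handler object alive, and once the handler is garbage collected the callback silently becomes a no-op.

import threading
import weakref

class ToyPeriodicTimer:
    """Call fn(*args) every `interval` seconds until cancelled."""

    def __init__(self, interval, fn, *args):
        self._interval = interval
        self._fn = fn
        self._args = args
        self._stopped = threading.Event()
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def start(self):
        self._thread.start()

    def cancel(self):
        self._stopped.set()

    def _loop(self):
        while not self._stopped.wait(self._interval):
            self._fn(*self._args)

class ToyHandler:
    def __init__(self):
        # A weak reference, mirroring _keep_alive_weak above.
        self._timer = ToyPeriodicTimer(1.0, ToyHandler._keep_alive_weak, weakref.ref(self))
        self._timer.start()

    @staticmethod
    def _keep_alive_weak(weak_self):
        self = weak_self()
        if self is not None:
            self._keep_alive()

    def _keep_alive(self):
        print("sending heartbeat")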
_keep_alive then runs a _RendezvousKeepAliveOp through the engine.
def _keep_alive(self) -> None:
    self._heartbeat_lock.acquire()

    op = _RendezvousKeepAliveOp()

    deadline = self._get_deadline(self._settings.timeout.heartbeat)

    try:
        self._op_executor.run(op, deadline)

        msg = (
            f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous "
            f"'{self._settings.run_id}'."
        )
        self._record(message=msg)
        log.debug(msg)
    except RendezvousError as ex:
        msg = (
            f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the "
            f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}."
        )
        self._record(message=msg, node_state=NodeState.FAILED)
    finally:
        self._heartbeat_lock.release()
Note that _DistributedRendezvousOpExecutor also has a method named _keep_alive, which implements the internal business logic; we will get to it later.

_RendezvousCloseOp determines the next Action based on the current state and time.
class _RendezvousCloseOp:
    """Represents a rendezvous close operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if ctx.state.closed:
            return _Action.FINISH
        if time.monotonic() > deadline:
            return _Action.ERROR_TIMEOUT
        return _Action.MARK_RENDEZVOUS_CLOSED
_RendezvousExitOp also determines the next Action based on the current state and time: if this node is not among the participants, nothing needs to be done; otherwise it returns an Action that removes the node from the participants list, or a timeout Action if the deadline has passed.
class _RendezvousExitOp:
    """Represents a rendezvous exit operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if ctx.node in ctx.state.participants:
            if time.monotonic() > deadline:
                return _Action.ERROR_TIMEOUT
            return _Action.REMOVE_FROM_PARTICIPANTS
        return _Action.FINISH
_RendezvousJoinOp behaves differently depending on the system state, for example trying to add this node to the participants or to the wait list, or continuing to wait; see the comments in the code for details.
class _RendezvousJoinOp:
    """Represents a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        state = ctx.state  # extract the _RendezvousState from the context

        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            return _Action.ERROR_CLOSED  # the rendezvous is closed, return ERROR_CLOSED

        is_participant = ctx.node in state.participants  # are we a participant?

        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant:  # participant of a completed rendezvous: FINISH
            return _Action.FINISH

        now = time.monotonic()
        if now > deadline:  # the join has timed out
            rollback_period = 5  # 5 seconds

            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period:  # still inside the rollback window
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant:  # min_nodes was not reached in time; remove ourselves
                    return _Action.REMOVE_FROM_PARTICIPANTS
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list:  # we could not wait for the next round
                    return _Action.REMOVE_FROM_WAIT_LIST
            return _Action.ERROR_TIMEOUT  # report the timeout

        if state.complete:  # the rendezvous is already complete
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes:  # max_nodes not reached yet
                if ctx.node not in state.wait_list:  # we are not in the wait list yet
                    return _Action.ADD_TO_WAIT_LIST  # join the wait list
        elif is_participant:  # we are already a participant
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes:  # min_nodes reached
                if cast(datetime, state.deadline) < datetime.utcnow():  # last-call deadline passed
                    return _Action.MARK_RENDEZVOUS_COMPLETE  # mark the rendezvous as complete
        else:
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS

        if _should_keep_alive(ctx):  # a heartbeat is due, return KEEP_ALIVE
            return _Action.KEEP_ALIVE

        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC  # otherwise keep syncing the state
The decision logic is summarized below:
                    |  state.closed
                    +-----------------------------> _Action.ERROR_CLOSED
                    |
                    |  complete & is_participant
                    +-----------------------------> _Action.FINISH
                    |
                    |  timeout & is_participant
                    +-----------------------------> _Action.REMOVE_FROM_PARTICIPANTS
                    |
                    |  timeout & in wait list
                    +-----------------------------> _Action.REMOVE_FROM_WAIT_LIST
                    |
                    |  timeout
_RendezvousJoinOp --+-----------------------------> _Action.ERROR_TIMEOUT
                    |
                    |  complete & < max & not in wait list
                    +-----------------------------> _Action.ADD_TO_WAIT_LIST
                    |
                    |  not complete & is_participant & >= min & past last-call deadline
                    +-----------------------------> _Action.MARK_RENDEZVOUS_COMPLETE
                    |
                    |  not complete & not is_participant
                    +-----------------------------> _Action.ADD_TO_PARTICIPANTS
                    |
                    |  _should_keep_alive
                    +-----------------------------> _Action.KEEP_ALIVE
                    |
                    |  else
                    +-----------------------------> _Action.SYNC
For comparison, the PyTorch source code contains a state diagram for the etcd-backed Rendezvous, which we can roughly set against the c10d states. With the etcd backend, a join goes through four phases:

- setup: an exclusive value is written to a fixed location; if the write fails, a rendezvous process is already in progress.
- join (joinable): nodes join the current rendezvous round.
- frozen (confirm): the set of participants is confirmed and frozen.
- final: the rendezvous is complete; ranks are assigned and the instance with RANK 0 becomes the master.

Modeled on that diagram, we can extend the c10d picture as follows.
      +
      |
      v
+-----+------+
|            |
|   closed   +----------------> ERROR_CLOSED
|            |
+-----+------+
      |
      v
+-----+------+  complete & is_participant
|            |
|  complete  +----------------> FINISH
|            |
+-----+------+
      |
      v
+-----+------+  now > deadline   +-----------+  within rollback window  +-----------+  is_participant ---> REMOVE_FROM_PARTICIPANTS
|    join    +-----------------> |  timeout  +------------------------> | rollback  |
+-----+------+                   +-----+-----+                          +-----------+  in wait_list   ---> REMOVE_FROM_WAIT_LIST
      |                                |
      |                                |  past the rollback window
      |  now <= deadline               +-----------------------> ERROR_TIMEOUT
      |
      |  complete & not is_participant & < max & not in wait_list
      +----------------------------------------------------------> ADD_TO_WAIT_LIST
      |
      |  not complete & is_participant & >= min & past last-call deadline
      +----------------------------------------------------------> MARK_RENDEZVOUS_COMPLETE
      |
      |  not complete & not is_participant
      +----------------------------------------------------------> ADD_TO_PARTICIPANTS
      |
      |  _should_keep_alive
      +----------------------------------------------------------> KEEP_ALIVE
      |
      v
    SYNC
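To reason about which branch wins in a given situation, here is a self-contained toy restatement of the decision order above (hypothetical names and plain arguments; this is not the real _RendezvousJoinOp):

import time
from datetime import datetime
from enum import Enum, auto

class Action(Enum):
    ERROR_CLOSED = auto()
    FINISH = auto()
    REMOVE_FROM_PARTICIPANTS = auto()
    REMOVE_FROM_WAIT_LIST = auto()
    ERROR_TIMEOUT = auto()
    ADD_TO_WAIT_LIST = auto()
    MARK_RENDEZVOUS_COMPLETE = auto()
    ADD_TO_PARTICIPANTS = auto()
    KEEP_ALIVE = auto()
    SYNC = auto()

def toy_join_decision(closed, complete, is_participant, in_wait_list,
                      num_participants, min_nodes, max_nodes,
                      deadline, last_call_deadline, heartbeat_due):
    """Mirrors the branch order of the join operator with plain arguments."""
    now = time.monotonic()
    if closed:
        return Action.ERROR_CLOSED
    if complete and is_participant:
        return Action.FINISH
    if now > deadline:                        # join timeout
        if now <= deadline + 5:               # short rollback window
            if is_participant:
                return Action.REMOVE_FROM_PARTICIPANTS
            if in_wait_list:
                return Action.REMOVE_FROM_WAIT_LIST
        return Action.ERROR_TIMEOUT
    if complete:                              # completed round, we are an outsider
        if num_participants < max_nodes and not in_wait_list:
            return Action.ADD_TO_WAIT_LIST
    elif is_participant:                      # waiting for the round to fill up
        if num_participants >= min_nodes and last_call_deadline < datetime.utcnow():
            return Action.MARK_RENDEZVOUS_COMPLETE
    else:
        return Action.ADD_TO_PARTICIPANTS
    return Action.KEEP_ALIVE if heartbeat_due else Action.SYNC

# Example: min_nodes is reached and the last-call deadline has passed,
# so this participant marks the rendezvous as complete.
print(toy_join_decision(closed=False, complete=False, is_participant=True,
                        in_wait_list=False, num_participants=2, min_nodes=2,
                        max_nodes=4, deadline=time.monotonic() + 30,
                        last_call_deadline=datetime(2000, 1, 1),
                        heartbeat_due=False))
# -> Action.MARK_RENDEZVOUS_COMPLETE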
Inside _DistributedRendezvousOpExecutor.run, the corresponding business function is selected and executed according to the action:
if action == _Action.KEEP_ALIVE:
    self._keep_alive()
elif action == _Action.ADD_TO_PARTICIPANTS:
    self._add_to_participants()
elif action == _Action.ADD_TO_WAIT_LIST:
    self._add_to_wait_list()
elif action == _Action.REMOVE_FROM_PARTICIPANTS:
    self._remove_from_participants()
elif action == _Action.REMOVE_FROM_WAIT_LIST:
    self._remove_from_wait_list()
elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
    self._mark_rendezvous_complete()
elif action == _Action.MARK_RENDEZVOUS_CLOSED:
    self._mark_rendezvous_closed()
Let us now look at the logic of these internal functions.

Upon ADD_TO_PARTICIPANTS, _add_to_participants is called: it removes the node from the wait list and adds it to the participants.
def _add_to_participants(self) -> None:
    state = self._state

    try:
        state.wait_list.remove(self._node)
    except KeyError:
        pass

    # The ranks of the participants will be set once the rendezvous is
    # complete.
    state.participants[self._node] = 0

    self._keep_alive()

    if len(state.participants) == self._settings.min_nodes:
        state.deadline = datetime.utcnow() + self._settings.timeout.last_call

    if len(state.participants) == self._settings.max_nodes:
        self._mark_rendezvous_complete()
Upon REMOVE_FROM_PARTICIPANTS, _remove_from_participants is called: it deletes the participant from participants and from last_heartbeats.
def _remove_from_participants(self) -> None:
    state = self._state

    del state.participants[self._node]
    del state.last_heartbeats[self._node]

    if state.complete:
        # If we do not have any participants left, move to the next round.
        if not state.participants:
            state.complete = False
            state.round += 1
    else:
        if len(state.participants) < self._settings.min_nodes:
            state.deadline = None
Upon ADD_TO_WAIT_LIST, _add_to_wait_list is called to add the node to the wait_list.
def _add_to_wait_list(self) -> None:
    self._state.wait_list.add(self._node)

    self._keep_alive()
Upon REMOVE_FROM_WAIT_LIST, _remove_from_wait_list is called to remove the node from the wait_list.
def _remove_from_wait_list(self) -> None:
    self._state.wait_list.remove(self._node)

    del self._state.last_heartbeats[self._node]
Upon MARK_RENDEZVOUS_COMPLETE, once the rendezvous gathering has finished, each participant is assigned a rank. Every node sorts the participants with the same algorithm, so the resulting ranks are identical on every node.
def _mark_rendezvous_complete(self) -> None:
    state = self._state

    state.complete = True
    state.deadline = None

    # Assign the ranks.
    for rank, node in enumerate(sorted(state.participants)):
        state.participants[node] = rank

def _mark_rendezvous_closed(self) -> None:
    self._state.closed = True
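A toy illustration of why the ranks agree across nodes: as long as every node holds the same participant set after sync, sorting it yields the same order everywhere. Node below is a hypothetical sortable stand-in for _NodeDesc.

from dataclasses import dataclass

@dataclass(frozen=True, order=True)
class Node:
    fqdn: str
    pid: int
    local_id: int

# The same participant dict (values are placeholder ranks) on every node ...
participants = {
    Node("host-b", 4242, 0): 0,
    Node("host-a", 1111, 1): 0,
    Node("host-a", 1111, 0): 0,
}

# ... sorted identically, so the assigned ranks are identical everywhere.
for rank, node in enumerate(sorted(participants)):
    participants[node] = rank

print(participants[Node("host-a", 1111, 0)])  # 0
print(participants[Node("host-a", 1111, 1)])  # 1
print(participants[Node("host-b", 4242, 0)])  # 2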
Upon a KEEP_ALIVE action, _keep_alive is called to maintain the heartbeat (it is also called inside methods such as _add_to_participants). It updates last_heartbeats in the local state; at the next sync, last_heartbeats is written to the key-value store, which is how the other nodes learn about this node's status, while locally _sanitize uses last_heartbeats to clean up dead nodes, as mentioned earlier.
def _keep_alive(self) -> None:
    msg = (
        f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
        f"'{self._settings.run_id}'. Pending sync."
    )
    self._record(message=msg)

    self._state.last_heartbeats[self._node] = datetime.utcnow()
The _record method is as follows:
def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
    construct_and_record_rdzv_event(
        name=f"{self.__class__.__name__}.{get_method_name()}",
        run_id=self._settings.run_id,
        message=message,
        node_state=node_state,
        hostname=self._node.fqdn,
        pid=self._node.pid,
        local_id=self._node.local_id,
    )
It simply calls the following code to record the event:
def record_rdzv_event(event: RdzvEvent) -> None:
    _get_or_create_logger("dynamic_rendezvous").info(event.serialize())


def construct_and_record_rdzv_event(
    run_id: str,
    message: str,
    node_state: NodeState,
    name: str = "",
    hostname: str = "",
    pid: Optional[int] = None,
    master_endpoint: str = "",
    local_id: Optional[int] = None,
    rank: Optional[int] = None,
) -> None:
    # We don't want to perform an extra computation if not needed.
    if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
        return

    # Set up parameters.
    if not hostname:
        hostname = socket.getfqdn()
    if not pid:
        pid = os.getpid()

    # Determines which file called this function.
    callstack = inspect.stack()
    filename = "no_file"
    if len(callstack) > 1:
        stack_depth_1 = callstack[1]
        filename = os.path.basename(stack_depth_1.filename)
        if not name:
            name = stack_depth_1.function

    # Delete the callstack variable. If kept, this can mess with python's
    # garbage collector as we are holding on to stack frame information in
    # the inspect module.
    del callstack

    # Set up error trace if this is an exception
    if node_state == NodeState.FAILED:
        error_trace = traceback.format_exc()
    else:
        error_trace = ""

    # Initialize event object
    event = RdzvEvent(
        name=f"{filename}:{name}",
        run_id=run_id,
        message=message,
        hostname=hostname,
        pid=pid,
        node_state=node_state,
        master_endpoint=master_endpoint,
        rank=rank,
        local_id=local_id,
        error_trace=error_trace,
    )

    # Finally, record the event.
    record_rdzv_event(event)
This concludes the analysis of the engine. In the next post we will try to go through the whole mechanism once more from a global perspective.