本文是参数服务器系列第二篇,介绍ps-lite的通信模块 Van。
本系列其他文章是:
[源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice
邮局里有了地址簿,就需要有货车来负责拉送物件,Van 就是整个Parameter Server的通信模块,其特点如下。
VAN 目前有两个实现:
首先给出 UML 图。
下面我们只给出Van对象关键变量和成员函数说明。
其主要变量如下:
Node scheduler_ :Scheduler 节点参数,每一个node都会记录Scheduler 节点的信息;
Node my_node_ : 本节点参数。如果本节点是Scheduler,则 my_node_ 会指向上面的 scheduler_ ;
bool is_scheduler_ : 本节点是否是 scheduler;
std::unique_ptr< std::thread> receiver_thread_ :接收消息线程指针;
std::unique_ptr< std::thread> heartbeat_thread_ :发送心跳线程指针;
std::vector
Resender *resender_ = nullptr :重新发送消息指针;
std::atomic
std::unordered_map<std::string, int> connected_nodes_ : 记录了目前连接到哪些 nodes;
其主要函数功能如下:
start :建立通信初始化;
Receiving :接收消息线程的处理函数;
Heartbeat :发送心跳线程的处理函数;
ProcessAddNodeCommandAtScheduler :scheduler 的 AddNode 消息处理函数;
ProcessHearbeat:心跳包处理函数;
ProcessDataMsg :数据消息(push & pull)处理函数;
ProcessAddNodeCommand :worker 和 server 的 AddNode 消息处理函数;
ProcessBarrierCommand :Barrier 消息处理函数;
PS Lite 定义的三种角色采用多线程机制工作,每个线程承担特定的职责,在所属的 Van 实例启动时被创建。
具体描述如下:
详细代码(摘要)如下:
class Van { public: static Van *Create(const std::string &type); virtual void Start(int customer_id); int Send(const Message &msg); virtual void Stop(); inline int GetTimestamp() { return timestamp_++; } inline bool IsReady() { return ready_; } protected: //连结节点 virtual void Connect(const Node &node) = 0; //绑定到自己节点之上 virtual int Bind(const Node &node, int max_retry) = 0; //接收消息,用阻塞方式 virtual int RecvMsg(Message *msg) = 0; //发送消息 virtual int SendMsg(const Message &msg) = 0; /** * \brief pack meta into a string */ void PackMeta(const Meta &meta, char **meta_buf, int *buf_size); /** * \brief pack meta into protobuf */ void PackMetaPB(const Meta &meta, PBMeta *pb); /** * \brief unpack meta from a string */ void UnpackMeta(const char *meta_buf, int buf_size, Meta *meta); Node scheduler_; Node my_node_; bool is_scheduler_; std::mutex start_mu_; private: /** thread function for receving */ void Receiving(); /** thread function for heartbeat */ void Heartbeat(); // node's address string (i.e. ip:port) -> node id // this map is updated when ip:port is received for the first time std::unordered_map<std::string, int> connected_nodes_; // maps the id of node which is added later to the id of node // which is with the same ip:port and added first std::unordered_map<int, int> shared_node_mapping_; /** whether it is ready for sending */ std::atomic<bool> ready_{false}; std::atomic<size_t> send_bytes_{0}; size_t recv_bytes_ = 0; int num_servers_ = 0; int num_workers_ = 0; /** the thread for receiving messages */ std::unique_ptr<std::thread> receiver_thread_; /** the thread for sending heartbeat */ std::unique_ptr<std::thread> heartbeat_thread_; std::vector<int> barrier_count_; /** msg resender */ Resender *resender_ = nullptr; int drop_rate_ = 0; std::atomic<int> timestamp_{0}; int init_stage = 0; //以下是处理各种类型消息 void ProcessAddNodeCommandAtScheduler(Message *msg, Meta *nodes, Meta *recovery_nodes); void ProcessTerminateCommand(); void ProcessAddNodeCommand(Message *msg, Meta *nodes, Meta *recovery_nodes); void ProcessBarrierCommand(Message *msg); void ProcessHearbeat(Message *msg); void ProcessDataMsg(Message *msg); //更新本地NodeID void UpdateLocalID(Message *msg, std::unordered_set<int> *deadnodes_set, Meta *nodes, Meta *recovery_nodes); const char *heartbeat_timeout_val = Environment::Get()->find("PS_HEARTBEAT_TIMEOUT"); int heartbeat_timeout_ = heartbeat_timeout_val ? atoi(heartbeat_timeout_val) : 0; DISALLOW_COPY_AND_ASSIGN(Van); };
Van对象的初始化函数作用就是依据本地节点类型的不同,做不同设置,从而启动端口,建立到scheduler的连结,启动接收消息线程,心跳线程等,这样就可以进行通信了。具体如下:
receiver_thread_
,执行Van::Receiving
;关于7,8两点的进一步说明就是:
具体代码如下:
void Van::Start(int customer_id) { // get scheduler info start_mu_.lock(); if (init_stage == 0) { // 初始化scheduler_这个成员变量 scheduler_.hostname = std::string( CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_URI"))); scheduler_.port = atoi(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_PORT"))); scheduler_.role = Node::SCHEDULER; scheduler_.id = kScheduler; // 确认本节点是scheduler节点 is_scheduler_ = Postoffice::Get()->is_scheduler(); // get my node info if (is_scheduler_) { // 初始化本节点,因为是scheduler,所以直接赋值 my_node_ = scheduler_; } else { auto role = Postoffice::Get()->is_worker() ? Node::WORKER : Node::SERVER; const char* nhost = Environment::Get()->find("DMLC_NODE_HOST"); std::string ip; if (nhost) ip = std::string(nhost); if (ip.empty()) { const char* itf = Environment::Get()->find("DMLC_INTERFACE"); std::string interface; if (itf) interface = std::string(itf); if (interface.size()) { GetIP(interface, &ip); } else { GetAvailableInterfaceAndIP(&interface, &ip); } } int port = GetAvailablePort(); const char* pstr = Environment::Get()->find("PORT"); if (pstr) port = atoi(pstr); my_node_.hostname = ip; my_node_.role = role; my_node_.port = port; // cannot determine my id now, the scheduler will assign it later // set it explicitly to make re-register within a same process possible my_node_.id = Node::kEmpty; my_node_.customer_id = customer_id; } // bind. //绑定接口,把本节点绑定到ip:port这个socket上,理论来说这个函数就是初始化了receiver_ my_node_.port = Bind(my_node_, is_scheduler_ ? 0 : 40); // connect to the scheduler // 连接上scheduler_,由于本节点就是scheduler_,其实就是初始化senders_,由于发送的节点很多,所以这里是一个map<int,void*> // 在这里就是senders_[1] = socket_1, socket_1中的body设置一点字符“ps1***”, 注意链接不是sendMsg Connect(scheduler_); // for debug use if (Environment::Get()->find("PS_DROP_MSG")) { drop_rate_ = atoi(Environment::Get()->find("PS_DROP_MSG")); } // start receiver // 开启一个接收消息的线程,这里就是处理消息 receiver_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Receiving, this)); init_stage++; } start_mu_.unlock(); if (!is_scheduler_) { // let the scheduler know myself // worker和server节点会通过 ADD_NODE 消息把本地节点的信息告诉scheduler,比如角色,ip,port... Message msg; Node customer_specific_node = my_node_; customer_specific_node.customer_id = customer_id; msg.meta.recver = kScheduler; msg.meta.control.cmd = Control::ADD_NODE; msg.meta.control.node.push_back(customer_specific_node); msg.meta.timestamp = timestamp_++; Send(msg); } // wait until ready // 等待 ready_ 从false变成true,当是scheduler的时候,必须要有等worker和server节点过来,不然一直都是阻塞在这,如果是 worker/server,则是等待 scheduler 发送系统allready消息。 while (!ready_.load()) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } start_mu_.lock(); if (init_stage == 1) { // resender if (Environment::Get()->find("PS_RESEND") && atoi(Environment::Get()->find("PS_RESEND")) != 0) { int timeout = 1000; if (Environment::Get()->find("PS_RESEND_TIMEOUT")) { timeout = atoi(Environment::Get()->find("PS_RESEND_TIMEOUT")); } // 如果设置了超时重传,就初始化resender_这个变量 resender_ = new Resender(timeout, 10, this); } if (!is_scheduler_) { // start heartbeat thread // 初始化心跳线程 heartbeat_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Heartbeat, this)); } init_stage++; } start_mu_.unlock(); }
我们首先介绍后台线程是如何运行,然后会具体分析如何处理各种消息。
ps-lite 启动了一个后台线程 receiver_thread_ 进行接受/处理消息。
// start receiver receiver_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Receiving, this));
receiver_thread_ 使用 Receiving 函数进行消息处理。
除了传递参数的数据消息外,各个节点之间控制信息有:
因此在 Receiving 之中会调用 不同处理函数处理不同类型的消息:
线程内有两个变量,因为其是在 while (true) 循环之外,所以属于线程内的全局变量,这点在阅读代码时候需要注意。
Receiving 逻辑如下:
具体代码如下
void Van::Receiving() { Meta nodes; // 以下两个可以认为是全局变量 Meta recovery_nodes; // store recovery nodes 储存康复重启的节点 recovery_nodes.control.cmd = Control::ADD_NODE; // 康复重启节点的control.cmd 都设置为 ADD_NODE while (true) { Message msg; int recv_bytes = RecvMsg(&msg); //利用receiver_ 变量拿到消息 // For debug, drop received message if (ready_.load() && drop_rate_ > 0) { unsigned seed = time(NULL) + my_node_.id; if (rand_r(&seed) % 100 < drop_rate_) { LOG(WARNING) << "Drop message " << msg.DebugString(); continue; } } CHECK_NE(recv_bytes, -1); recv_bytes_ += recv_bytes; //收到的字节数累加 if (Postoffice::Get()->verbose() >= 2) { PS_VLOG(2) << msg.DebugString(); } // duplicated message if (resender_ && resender_->AddIncomming(msg)) continue; //重传确认机制 if (!msg.meta.control.empty()) { //如果是控制类型的消息 // control msg auto& ctrl = msg.meta.control; if (ctrl.cmd == Control::TERMINATE) { ProcessTerminateCommand(); break; } else if (ctrl.cmd == Control::ADD_NODE) { ProcessAddNodeCommand(&msg, &nodes, &recovery_nodes); //当执行到这个位置的时候继续跳转 } else if (ctrl.cmd == Control::BARRIER) { ProcessBarrierCommand(&msg); } else if (ctrl.cmd == Control::HEARTBEAT) { ProcessHearbeat(&msg); // 发回Heartbeat的ACK } else { LOG(WARNING) << "Drop unknown typed message " << msg.DebugString(); } } else { //非控制类型的消息处理方式 ProcessDataMsg(&msg); } } }
ADD_NODE 是 worker / server 用来向 scheduler 注册自身的控制消息。
先回忆下注册基本思路。
ProcessAddNodeCommand 逻辑如下。
具体代码如下:
void Van::ProcessAddNodeCommand(Message* msg, Meta* nodes, Meta* recovery_nodes) { auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);//查出心跳包超时的id std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());//转存到dead_set之中 auto& ctrl = msg->meta.control; //拿到收到消息里面的control信息 UpdateLocalID(msg, &dead_set, nodes, recovery_nodes); if (is_scheduler_) { // Scheduler 节点 ProcessAddNodeCommandAtScheduler(msg, nodes, recovery_nodes); } else { // Worker & Server 节点 for (const auto& node : ctrl.node) { std::string addr_str = node.hostname + ":" + std::to_string(node.port); if (connected_nodes_.find(addr_str) == connected_nodes_.end()) { // 现有节点会在自己连接之中查找这个新节点,发现现有连接中没有这个新节点 // 如果是新节点,则会连接现有节点(非同类型) Connect(node); // 与新节点进行连接 connected_nodes_[addr_str] = node.id; // 加入已经连接的节点 } if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_; if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_; } ready_ = true; } }
此函数作用是更新节点内部的node id 信息,也是分为两种情况,函数逻辑如下:
具体代码如下:
void Van::UpdateLocalID(Message* msg, std::unordered_set<int>* deadnodes_set, Meta* nodes, Meta* recovery_nodes) { auto& ctrl = msg->meta.control; size_t num_nodes = Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers(); // assign an id if (msg->meta.sender == Meta::kEmpty) { //如果sender未设定,则处理此message的一定是Scheduler CHECK(is_scheduler_); CHECK_EQ(ctrl.node.size(), 1); //msg中的control命令中的节点集合就是worker自己,所以就是1个节点 if (nodes->control.node.size() < num_nodes) { //没有到齐 nodes->control.node.push_back(ctrl.node[0]); } else { //如果所有work和server到齐了,就进入else // some node dies and restarts CHECK(ready_.load()); for (size_t i = 0; i < nodes->control.node.size() - 1; ++i) { const auto& node = nodes->control.node[i]; if (deadnodes_set->find(node.id) != deadnodes_set->end() && node.role == ctrl.node[0].role) { auto& recovery_node = ctrl.node[0]; // assign previous node id recovery_node.id = node.id; recovery_node.is_recovery = true; nodes->control.node[i] = recovery_node; recovery_nodes->control.node.push_back(recovery_node); break; } } } } // update my id / 对普通的node,更新其rank,scheduler 节点不会起作用(因为找不到)。 // schedule发给此work节点的消息,如果发现本地的ip和port和消息中的某个一点重合,那么就把本地节点的ID(初始化时候没有ID,只是等于Empty)改为schedule发过来的 node id。 for (size_t i = 0; i < ctrl.node.size(); ++i) { const auto& node = ctrl.node[i]; if (my_node_.hostname == node.hostname && my_node_.port == node.port) { if (getenv("DMLC_RANK") == nullptr || my_node_.id == Meta::kEmpty) { my_node_ = node; std::string rank = std::to_string(Postoffice::IDtoRank(node.id)); #ifdef _MSC_VER _putenv_s("DMLC_RANK", rank.c_str()); #else setenv("DMLC_RANK", rank.c_str(), true); #endif } } } }
ProcessAddNodeCommandAtScheduler 是在 Scheduler 之内运行,是对控制类型消息的处理。
对于Scheduler节点来说,scheduler收到所有worker和server的ADD_NODE的消息后进行节点id分配并应答,即,需要设定 最新的所有node的 全局rank 并发送给所有Worker和Server。
nodes->control.node.size() == num_nodes
):
ready_ = true
; 即 scheduler 是一个 ready 状态了,不管 worker 和 server 是否确认收到ADD_NODE消息。!recovery_nodes->control.node.empty()
,这就表明是处理某些重启节点的注册行为:
CHECK_EQ(recovery_nodes->control.node.size(), 1)
来确认重启节点为 1 个)。具体代码如下:
void Van::ProcessAddNodeCommandAtScheduler(Message* msg, Meta* nodes, Meta* recovery_nodes) { recovery_nodes->control.cmd = Control::ADD_NODE; time_t t = time(NULL); size_t num_nodes = Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers(); // scheduler收到所有worker和server的ADD_NODE的消息后进行节点id分配并应答 if (nodes->control.node.size() == num_nodes) { // 节点收集完全 // sort the nodes according their ip and port, 根据IP和port给worker,server排个序 std::sort(nodes->control.node.begin(), nodes->control.node.end(), [](const Node& a, const Node& b) { return (a.hostname.compare(b.hostname) | (a.port < b.port)) > 0; }); // assign node rank for (auto& node : nodes->control.node) { // 建立连接、更新心跳时间戳,给 scheduler所有连接的节点分配全局 rank。 std::string node_host_ip = node.hostname + ":" + std::to_string(node.port); if (connected_nodes_.find(node_host_ip) == connected_nodes_.end()) { //如果ip:port不存在van_中的话 CHECK_EQ(node.id, Node::kEmpty); //判断是不是初始化节点 int id = node.role == Node::SERVER ? Postoffice::ServerRankToID(num_servers_) : Postoffice::WorkerRankToID(num_workers_); //如果是sever的话,就id产生一个id号,num_servers_初始化为0 node.id = id; //将这个新节点的id赋值为id Connect(node); //连接这个新节点, 即建立一个socket, 然后senders_[id] = sender; 就是将目标id的socket存放起来后面使用 Postoffice::Get()->UpdateHeartbeat(node.id, t);//更新心跳包 connected_nodes_[node_host_ip] = id; //既然 worker, server 已经发message来了,scheduler要把这个节点作为已经链接的节点 } else { int id = node.role == Node::SERVER ? Postoffice::ServerRankToID(num_servers_) : Postoffice::WorkerRankToID(num_workers_); shared_node_mapping_[id] = connected_nodes_[node_host_ip]; node.id = connected_nodes_[node_host_ip]; } if (node.role == Node::SERVER) num_servers_++;//更新rank if (node.role == Node::WORKER) num_workers_++; } nodes->control.node.push_back(my_node_); //把本节点放到里面 nodes->control.cmd = Control::ADD_NODE; Message back; back.meta = *nodes; // 向所有的worker和server发送ADD_NODE消息 for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) { int recver_id = r; if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) { back.meta.recver = recver_id; back.meta.timestamp = timestamp_++; Send(back); } } ready_ = true; //scheduler已经准备好了 } else if (!recovery_nodes->control.node.empty()) { // 节点没有收集完全 auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);//查出心跳包超时的id std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());//转存到dead_set // send back the recovery node CHECK_EQ(recovery_nodes->control.node.size(), 1); Connect(recovery_nodes->control.node[0]); Postoffice::Get()->UpdateHeartbeat(recovery_nodes->control.node[0].id, t); Message back; for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) { if (r != recovery_nodes->control.node[0].id && dead_set.find(r) != dead_set.end()) { // do not try to send anything to dead node continue; } // only send recovery_node to nodes already exist // but send all nodes to the recovery_node back.meta = (r == recovery_nodes->control.node[0].id) ? *nodes : *recovery_nodes; back.meta.recver = r; back.meta.timestamp = timestamp_++; Send(back); } } }
此部分流程逻辑如下:
+ Scheduler | Worker | + | + | | | | | | v | | Postoffice::Start +----> Van::Start | | + | | | | | | | | v | | Connect--do nothing | | + | v | | | | Postoffice::Start +-----> Van::Start | | + v | | receiver_thread_ +---+ | | + | | v | | | Connect--to scheduler | | | + | | | | | | | | | | | | | | | v | | | receiver_thread_ +----->+ | | | + | | | | | | | | | | | | | | v | | | <---------------------------------------+ Send | | | | ADD_NODE + | | v | | | | | | | | ProcessAddNodeCommand | | | | + | | | | | | | | | | All nodes OK | | | | | | | | v | | | | | set rank | | | wait until ready | | | | + | | | | | +----------------------------------------------------------------> | | | | ADD_NODE response(nodes info) | | | | | | ProcessAddNodeCommand | | | v | | | | | | <--------------+ | wait until ready | | ready_ = true | + | | | | <---------------+ +-------------------+ v | | | | +--------------------+ v | | | v | | | v Postoffice::Barrier | | Postoffice::Barrier +
手机如下,左侧是 Scheduler,右侧是 worker:
其互联过程可以分为3步:
第一步:worker/server节点初始化的时候,向schedular节点发送一个连接信息,假定自身是节点 2;
if (!is_scheduler_) { // let the scheduler know myself Message msg; Node customer_specific_node = my_node_; customer_specific_node.customer_id = customer_id; msg.meta.recver = kScheduler; msg.meta.control.cmd = Control::ADD_NODE; msg.meta.control.node.push_back(customer_specific_node); msg.meta.timestamp = timestamp_++; Send(msg); //发送给schedular, 建立链接信息。 }
第二步:Scheduler 节点收到信息后,在 ProcessAddNodeCommandAtScheduler 之中,首先会和 节点 2 建立一个连接。会向所有已经和schedular建立连接的worker节点/server节点 广播此 "节点的加入信息“,并把 节点 2 请求连接的信息放入meta信息中。
// assign node rank for (auto& node : nodes->control.node) { std::string node_host_ip = node.hostname + ":" + std::to_string(node.port); if (connected_nodes_.find(node_host_ip) == connected_nodes_.end()) { int id = node.role == Node::SERVER ? Postoffice::ServerRankToID(num_servers_) : Postoffice::WorkerRankToID(num_workers_); node.id = id; Connect(node); // 连接这个新节点, 即建立一个socket, 然后senders_[id] = sender; 就是将目标id的socket存放起来后面使用 Postoffice::Get()->UpdateHeartbeat(node.id, t); connected_nodes_[node_host_ip] = id; } else { int id = node.role == Node::SERVER ? Postoffice::ServerRankToID(num_servers_) : Postoffice::WorkerRankToID(num_workers_); shared_node_mapping_[id] = connected_nodes_[node_host_ip]; node.id = connected_nodes_[node_host_ip]; } if (node.role == Node::SERVER) num_servers_++; if (node.role == Node::WORKER) num_workers_++; } nodes->control.node.push_back(my_node_); nodes->control.cmd = Control::ADD_NODE; Message back; back.meta = *nodes; // 向所有已经和schedular建立连接的worker节点/server节点 广播此 "节点的加入信息“,并把 节点 2 请求连接的信息放入meta信息中。 for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) { int recver_id = r; if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) { back.meta.recver = recver_id; back.meta.timestamp = timestamp_++; Send(back); } }
第三步:现有worker/server节点收到这个命令后,在 ProcessAddNodeCommand 之中 会和 节点 2 形成连接。
for (const auto& node : ctrl.node) { std::string addr_str = node.hostname + ":" + std::to_string(node.port); if (connected_nodes_.find(addr_str) == connected_nodes_.end()) { // 现有连接中没有这个新节点 Connect(node); // 与新节点进行连接 connected_nodes_[addr_str] = node.id; } if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_; if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_;
至此,整个过程就描述完了。每个新节点加入后,已经加入的节点都会通过schedular节点和这个新节点建立连接。
我们接下来分析心跳机制。
为了记录网络的可达性,PS Lite 设计了心跳机制。具体而言:
具体如下:
std::unordered_map<int, time_t> heartbeats_ 就是存储了心跳关联的节点的活跃信息。键为节点编号,值为上次收到其 HEARTBEAT 消息的时间戳。
UpdateHeartbeat 会定期更新心跳。
void UpdateHeartbeat(int node_id, time_t t) { std::lock_guard<std::mutex> lk(heartbeat_mu_); heartbeats_[node_id] = t; } std::unordered_map<int, time_t> heartbeats_;
在这两种节点中,启动了一个线程,每一个 Worker/Server 节点,每隔 PS_HEARTBEAT_INTERVAL 秒向 Scheduler 发送一条 HEARTBEAT 消息:
if (!is_scheduler_) { // start heartbeat thread heartbeat_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Heartbeat, this)); }
具体心跳函数是:
void Van::Heartbeat() { const char* val = Environment::Get()->find("PS_HEARTBEAT_INTERVAL"); const int interval = val ? atoi(val) : kDefaultHeartbeatInterval; while (interval > 0 && ready_.load()) { std::this_thread::sleep_for(std::chrono::seconds(interval)); Message msg; msg.meta.recver = kScheduler; msg.meta.control.cmd = Control::HEARTBEAT; msg.meta.control.node.push_back(my_node_); msg.meta.timestamp = timestamp_++; Send(msg); } }
Scheduler 节点收到后 HEARTBEAT 消息后,响应一个 HEARTBEAT 消息。UpdateHeartbeat 会定期更新心跳。
void Van::ProcessHearbeat(Message* msg) { auto& ctrl = msg->meta.control; time_t t = time(NULL); for (auto& node : ctrl.node) { Postoffice::Get()->UpdateHeartbeat(node.id, t); if (is_scheduler_) { Message heartbeat_ack; heartbeat_ack.meta.recver = node.id; heartbeat_ack.meta.control.cmd = Control::HEARTBEAT; heartbeat_ack.meta.control.node.push_back(my_node_); heartbeat_ack.meta.timestamp = timestamp_++; // send back heartbeat Send(heartbeat_ack); } } }
Scheduler 在处理 ADD_NODE 消息时候,会看看是否已经有死亡节点,具体判通过当前时间戳与心跳包接收时间戳之差判断是否alive。
std::vector<int> Postoffice::GetDeadNodes(int t) { std::vector<int> dead_nodes; if (!van_->IsReady() || t == 0) return dead_nodes; time_t curr_time = time(NULL); const auto& nodes = is_scheduler_ ? GetNodeIDs(kWorkerGroup + kServerGroup) : GetNodeIDs(kScheduler); { std::lock_guard<std::mutex> lk(heartbeat_mu_); for (int r : nodes) { auto it = heartbeats_.find(r); if ((it == heartbeats_.end() || it->second + t < curr_time) && start_time_ + t < curr_time) { dead_nodes.push_back(r); } } } return dead_nodes; }
逻辑如下:
+----------------------------------------------------+ | Scheduler | | | | | | | | heartbeats_ | | | | receiver_thread_+--------> ProcessHearbeat | | ^ + ^ + | | | | | | | | | | | | | | | | | | | +----------------------------------------------------+ | | | | | | | | RESPONSE | | | +-------------------------------------+ | | | | | | +-------------------------------+ | | | | | HEARTBEAT | | RESPONSE HEARTBEAT | | | | | | +-----------------------------------------+ +-----------------------------------------+ | Worker | | | | Server | | | | | | | | | | | | | | | | | | | | | | | | | | | | heartbeats_ | | | | heartbeats_ | | | | + | | | + | | | heartbeat_thread_+----> Heartbeat | | | heartbeat_thread_+--> Heartbeat | | | | | | | | | v | | v | | receiver_thread_ +---> ProcessHearbeat | | receiver_thread_ +--> ProcessHearbeat | | | | | | | | | | | | | +-----------------------------------------+ +-----------------------------------------+
ProcessTerminateCommand 会处理结束消息,具体就是设定 ready_ 为 false。
这样就预示着 Van 状态不对,不可以继续处理。
void Van::ProcessTerminateCommand() { PS_VLOG(1) << my_node().ShortDebugString() << " is stopped"; ready_ = false; } inline bool IsReady() { return ready_; }
在分布式系统中,通信也是不可靠的,丢包、延时都是必须考虑的场景。PS Lite 设计了 Resender类来提高通信的可靠性,它引入了 ACK 机制。即:
定义如下,其中 send_buff_ 就是发送缓存,用来存储发送了的消息列表。acked_ 就是已经确认的消息。
class Resender { std::thread* monitor_; std::unordered_set<uint64_t> acked_; std::atomic<bool> exit_{false}; std::mutex mu_; int timeout_; int max_num_retry_; Van* van_; using Time = std::chrono::milliseconds; // the buffer entry struct Entry { Message msg; Time send; int num_retry = 0; }; std::unordered_map<uint64_t, Entry> send_buff_; };
监控线程以及函数如下如下,就是被唤醒时候,从send_buff_(本地缓存)找到每个消息的发送时间戳和当前时间,找出超时的消息进行重发,并累加其重试次数。 :
monitor_ = new std::thread(&Resender::Monitoring, this); void Monitoring() { while (!exit_) { std::this_thread::sleep_for(Time(timeout_)); std::vector<Message> resend; Time now = Now(); mu_.lock(); for (auto& it : send_buff_) { if (it.second.send + Time(timeout_) * (1+it.second.num_retry) < now) { resend.push_back(it.second.msg); ++it.second.num_retry; CHECK_LT(it.second.num_retry, max_num_retry_); } } mu_.unlock(); for (const auto& msg : resend) van_->Send(msg); } }
当 Van 发送消息时候,如果配置了重传,就调用AddOutgoing函数把消息加入到发送缓存。
int Van::Send(const Message& msg) { int send_bytes = SendMsg(msg); CHECK_NE(send_bytes, -1); send_bytes_ += send_bytes; if (resender_) resender_->AddOutgoing(msg); if (Postoffice::Get()->verbose() >= 2) { PS_VLOG(2) << msg.DebugString(); } return send_bytes; }
下面函数就是加入到发送缓存。
/** * \brief add an outgoining message * */ void AddOutgoing(const Message& msg) { if (msg.meta.control.cmd == Control::ACK) return; CHECK_NE(msg.meta.timestamp, Meta::kEmpty) << msg.DebugString(); auto key = GetKey(msg); std::lock_guard<std::mutex> lk(mu_); // already buffered, which often due to call Send by the monitor thread if (send_buff_.find(key) != send_buff_.end()) return; auto& ent = send_buff_[key]; ent.msg = msg; ent.send = Now(); ent.num_retry = 0; }
下面函数有两个作用:
/** * \brief add an incomming message * \brief return true if msg has been added before or a ACK message */ bool AddIncomming(const Message& msg) { // a message can be received by multiple times if (msg.meta.control.cmd == Control::TERMINATE) { return false; } else if (msg.meta.control.cmd == Control::ACK) { mu_.lock(); auto key = msg.meta.control.msg_sig; auto it = send_buff_.find(key); if (it != send_buff_.end()) send_buff_.erase(it); mu_.unlock(); return true; } else { mu_.lock(); auto key = GetKey(msg); auto it = acked_.find(key); bool duplicated = it != acked_.end(); if (!duplicated) acked_.insert(key); mu_.unlock(); // send back ack message (even if it is duplicated) Message ack; ack.meta.recver = msg.meta.sender; ack.meta.sender = msg.meta.recver; ack.meta.control.cmd = Control::ACK; ack.meta.control.msg_sig = key; van_->Send(ack); // warning if (duplicated) LOG(WARNING) << "Duplicated message: " << msg.DebugString(); return duplicated; } }
ProcessDataMsg 用来处理 worker 发过来的数据消息(就是worker向server更新梯度),具体是取得对应的Customer后,调用 Customer 的方法进行处理,直接将msg
放入处理队列中。
我们会放在 Customer 之中进行介绍。
void Van::ProcessDataMsg(Message* msg) { // data msg int app_id = msg->meta.app_id; int customer_id = Postoffice::Get()->is_worker() ? msg->meta.customer_id : app_id; auto* obj = Postoffice::Get()->GetCustomer(app_id, customer_id, 5); obj->Accept(*msg); // 这里给 Customer 添加消息 }
ZMQVan是基于zeromq的Van的实现,即为用zmq库实现了连接的底层细节(zmq库是一个开源库,对socket进行了优良的封装,他使得Socket编程更加简单、简洁和性能更高)。
ZMQVan定义如下:
ZMQVan 继承于Van ,在这个类的基础上加了两个成员变量,分别是:
具体如下:
class ZMQVan : public Van { void *context_ = nullptr; /** * \brief node_id to the socket for sending data to this node */ std::unordered_map<int, void*> senders_; std::mutex mu_; void *receiver_ = nullptr; };
Van类 有如下函数会调用到 ZMQVan 或者被 ZMQVan 调用。
Send 函数就是调用 ZMQVan 的 SendMsg 函数进行发送消息,发送之后如果设定了ACK机制,则会调用 resender_->AddOutgoing。
int Van::Send(const Message& msg) { int send_bytes = SendMsg(msg); CHECK_NE(send_bytes, -1); send_bytes_ += send_bytes; if (resender_) resender_->AddOutgoing(msg); if (Postoffice::Get()->verbose() >= 2) { PS_VLOG(2) << msg.DebugString(); } return send_bytes; }
Meta封装了元数据,发送者,接受者,时间戳,请求还是响应等。
/** * \brief meta info of a message */ struct Meta { /** \brief the empty value */ static const int kEmpty; /** \brief an int head */ int head; /** \brief the unique id of the application of messsage is for*/ int app_id; /** \brief customer id*/ int customer_id; /** \brief the timestamp of this message */ int timestamp; /** \brief the node id of the sender of this message */ int sender; /** \brief the node id of the receiver of this message */ int recver; /** \brief whether or not this is a request message*/ bool request; /** \brief whether or not a push message */ bool push; /** \brief whether or not a pull message */ bool pull; /** \brief whether or not it's for SimpleApp */ bool simple_app; /** \brief an string body */ std::string body; /** \brief data type of message.data[i] */ std::vector<DataType> data_type; /** \brief system control message */ Control control; /** \brief the byte size */ int data_size = 0; /** \brief message priority */ int priority = 0; };
为了缓解通信压力,ps-lite 使用了Protobuf对 Meta 进行数据压缩。
就是按照 protobuf 来进行数据压缩。
void Van::PackMeta(const Meta& meta, char** meta_buf, int* buf_size) { // convert into protobuf PBMeta pb; pb.set_head(meta.head); if (meta.app_id != Meta::kEmpty) pb.set_app_id(meta.app_id); if (meta.timestamp != Meta::kEmpty) pb.set_timestamp(meta.timestamp); if (meta.body.size()) pb.set_body(meta.body); pb.set_push(meta.push); pb.set_pull(meta.pull); pb.set_request(meta.request); pb.set_simple_app(meta.simple_app); pb.set_priority(meta.priority); pb.set_customer_id(meta.customer_id); for (auto d : meta.data_type) pb.add_data_type(d); if (!meta.control.empty()) { auto ctrl = pb.mutable_control(); ctrl->set_cmd(meta.control.cmd); if (meta.control.cmd == Control::BARRIER) { ctrl->set_barrier_group(meta.control.barrier_group); } else if (meta.control.cmd == Control::ACK) { ctrl->set_msg_sig(meta.control.msg_sig); } for (const auto& n : meta.control.node) { auto p = ctrl->add_node(); p->set_id(n.id); p->set_role(n.role); p->set_port(n.port); p->set_hostname(n.hostname); p->set_is_recovery(n.is_recovery); p->set_customer_id(n.customer_id); } } // to string *buf_size = pb.ByteSize(); *meta_buf = new char[*buf_size + 1]; CHECK(pb.SerializeToArray(*meta_buf, *buf_size)) << "failed to serialize protobuf"; }
按照protobuf 预先生成的 PBMeta 格式进行解压。
void Van::UnpackMeta(const char* meta_buf, int buf_size, Meta* meta) { // to protobuf PBMeta pb; CHECK(pb.ParseFromArray(meta_buf, buf_size)) << "failed to parse string into protobuf"; // to meta meta->head = pb.head(); meta->app_id = pb.has_app_id() ? pb.app_id() : Meta::kEmpty; meta->timestamp = pb.has_timestamp() ? pb.timestamp() : Meta::kEmpty; meta->request = pb.request(); meta->push = pb.push(); meta->pull = pb.pull(); meta->simple_app = pb.simple_app(); meta->priority = pb.priority(); meta->body = pb.body(); meta->customer_id = pb.customer_id(); meta->data_type.resize(pb.data_type_size()); for (int i = 0; i < pb.data_type_size(); ++i) { meta->data_type[i] = static_cast<DataType>(pb.data_type(i)); } if (pb.has_control()) { const auto& ctrl = pb.control(); meta->control.cmd = static_cast<Control::Command>(ctrl.cmd()); meta->control.barrier_group = ctrl.barrier_group(); meta->control.msg_sig = ctrl.msg_sig(); for (int i = 0; i < ctrl.node_size(); ++i) { const auto& p = ctrl.node(i); Node n; n.role = static_cast<Node::Role>(p.role()); n.port = p.port(); n.hostname = p.hostname(); n.id = p.has_id() ? p.id() : Node::kEmpty; n.is_recovery = p.is_recovery(); n.customer_id = p.customer_id(); meta->control.node.push_back(n); } } else { meta->control.cmd = Control::EMPTY; } }
PackMetaPB 从注释看,是字节跳动提交的,主要用于 ibverbs_van.h,所以我们不做深入研究。
void Van::PackMetaPB(const Meta& meta, PBMeta* pb) { pb->set_head(meta.head); if (meta.app_id != Meta::kEmpty) pb->set_app_id(meta.app_id); if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp); if (meta.body.size()) pb->set_body(meta.body); pb->set_push(meta.push); pb->set_request(meta.request); pb->set_simple_app(meta.simple_app); pb->set_priority(meta.priority); pb->set_customer_id(meta.customer_id); for (auto d : meta.data_type) pb->add_data_type(d); if (!meta.control.empty()) { auto ctrl = pb->mutable_control(); ctrl->set_cmd(meta.control.cmd); if (meta.control.cmd == Control::BARRIER) { ctrl->set_barrier_group(meta.control.barrier_group); } else if (meta.control.cmd == Control::ACK) { ctrl->set_msg_sig(meta.control.msg_sig); } for (const auto& n : meta.control.node) { auto p = ctrl->add_node(); p->set_id(n.id); p->set_role(n.role); p->set_port(n.port); p->set_hostname(n.hostname); p->set_is_recovery(n.is_recovery); p->set_customer_id(n.customer_id); } } pb->set_data_size(meta.data_size); }
ZMQVan 有如下重要的派生函数。
Bind 逻辑如下: