Paracel是豆瓣开发的一个分布式计算框架,它基于参数服务器范式来解决机器学习的问题:逻辑回归、SVD、矩阵分解(BFGS,sgd,als,cg),LDA,Lasso...。
Paracel支持数据和模型的并行,为用户提供简单易用的通信接口,比mapreduce式的系统要更加灵活。Paracel同时支持异步的训练模式,使迭代问题收敛地更快。此外,Paracel程序的结构与串行程序十分相似,用户可以更加专注于算法本身,不需将精力过多放在分布式逻辑上。
因为我们之前已经用ps-lite对参数服务器的基本功能做了介绍,所以在本文中,我们主要与ps-lite比对大的方面和一些关键技术点(paracel没有开源容错机制,是个不小的遗憾),而不会像对 ps-lite 那样做较详细的分析。
对于本文来说,ps-lite的主要逻辑如下:
本系列其他文章是:
[源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice
[源码解析] 机器学习参数服务器ps-lite(2) ----- 通信模块Van
[源码解析] 机器学习参数服务器ps-lite 之(3) ----- 代理人Customer
[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现
本文在解析时候会删除部分非主体代码。
我们首先通过源码提供的LR算法看看如何使用。
我们从源码中找到 LR 相关部分来看,以下就是一些必要配置,在其中我做了部分翻译,需要留意的是:用一条命令可以启动若干不同类型的实例,实例运行的都是可执行程序 lr。
- Enter Paracel's home directory 进入Paracel工作目录
```cd paracel;```
- Generate training dataset for classification 产生训练数据集
```python ./tool/datagen.py -m classification -o training.dat -n 2500 -k 100```
- Set up link library path: 设置链接库路径
```export LD_LIBRARY_PATH=your_paracel_install_path/lib```
Create a json file named
cfg.json
, see example in Parameters section below. 创建配置文件Run (4 workers, local mode in the following example) 运行(4个worker,2个参数服务器)
```./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr```
Default parameters are set in a JSON format file. For example, we create a cfg.json as below(modify
your_paracel_install_path
):{
"training_input" : "training.dat", 训练集
"test_input" : "training.dat", 验证集
"predict_input" : "training.dat", label数据
"output" : "./lr_result/",
"update_file" : "your_paracel_install_path/lib/liblr_update.so",
"update_func" : "lr_theta_update", 更新函数
"method" : "ipm",
"rounds" : 100,
"alpha" : 0.001,
"beta" : 0.01,
"debug" : false
}
通过makefile我们可以看到,是把 lr_driver.cpp, lr.cpp一起编译成为 lr 可执行文件。把 update.cpp 编译成库,被服务器加载调用。
add_library(lr_update SHARED update.cpp) # 参数服务器如何更新 target_link_libraries(lr_update ${CMAKE_DL_LIBS}) install(TARGETS lr_update LIBRARY DESTINATION lib) add_library(lr_method SHARED lr.cpp) # 算法代码 target_link_libraries(lr_method ${Boost_LIBRARIES} comm scheduler) install(TARGETS lr_method LIBRARY DESTINATION lib) add_executable(lr lr_driver.cpp) # 驱动代码 target_link_libraries(lr ${Boost_LIBRARIES} comm scheduler lr_method) install(TARGETS lr RUNTIME DESTINATION bin)
对于 LR,有四种 大规模深度神经网络的随机梯度下降法 可以选择
dgd: distributed gradient descent learning
ipm: iterative parameter mixtures learning
downpour: asynchrounous gradient descent learning
agd: slow asynchronous gradient descent learning
我们选择 agd 算法来学习分析:http://www.eecs.berkeley.edu/~brecht/papers/hogwildTR.pdf
首先,我们看看驱动代码 lr_driver.cpp,逻辑就是:
DEFINE_string(server_info, "host1:7777PARACELhost2:8888", "hosts name string of paracel-servers.\n"); DEFINE_string(cfg_file, "", "config json file with absolute path.\n"); int main(int argc, char *argv[]) { // 配置运行环境和通信 paracel::main_env comm_main_env(argc, argv); paracel::Comm comm(MPI_COMM_WORLD); google::SetUsageMessage("[options]\n\t--server_info\n\t--cfg_file\n"); google::ParseCommandLineFlags(&argc, &argv, true); // 读取分析参数 paracel::json_parser pt(FLAGS_cfg_file); std::string training_input, test_input, predict_input, output, update_file, update_func, method; try { training_input = pt.check_parse<std::string>("training_input"); test_input = pt.check_parse<std::string>("test_input"); predict_input = pt.check_parse<std::string>("predict_input"); output = pt.parse<std::string>("output"); update_file = pt.check_parse<std::string>("update_file"); update_func = pt.parse<std::string>("update_func"); method = pt.parse<std::string>("method"); } catch (const std::invalid_argument & e) { std::cerr << e.what(); return 1; } int rounds = pt.parse<int>("rounds"); double alpha = pt.parse<double>("alpha"); double beta = pt.parse<double>("beta"); bool debug = pt.parse<bool>("debug"); // 生成 logistic_regression,进行训练,验证,预测 paracel::alg::logistic_regression lr_solver(comm, FLAGS_server_info, training_input, output, update_file, update_func, method, rounds, alpha, beta, debug); lr_solver.solve(); std::cout << "final loss: " << lr_solver.calc_loss() << std::endl; lr_solver.test(test_input); lr_solver.predict(predict_input); lr_solver.dump_result(); return 0; }
从之前的配置中我们知道更新部分是:
"update_file" : "your_paracel_install_path/lib/liblr_update.so", "update_func" : "lr_theta_update",
所以我们从 alg/classification/logistic_regression/update.cpp 中得到更新函数如下:
具体就是合并两个参数然后返回。这部分代码被编译成库,在server之中被加载运行。
#include <vector> #include "proxy.hpp" #include "paracel_types.hpp" using std::vector; extern "C" { extern paracel::update_result lr_theta_update; } vector<double> local_update(vector<double> a, vector<double> b) { vector<double> r; for(int i = 0; i < (int)a.size(); ++i) { r.push_back(a[i] + b[i]); } return r; } paracel::update_result lr_theta_update = paracel::update_proxy(local_update);
logistic_regression 是类定义,位于lr.hpp。logistic_regression 需要继承 paracel::paralg 才能使用。
namespace paracel { namespace alg { class logistic_regression: public paracel::paralg { public: logistic_regression(paracel::Comm, string, string _input, string output, string update_file_name, string update_func_name, string = "ipm", int _rounds = 1, double _alpha = 0.002, double _beta = 0.1, bool _debug = false); virtual ~logistic_regression(); double lr_hypothesis(const vector<double> &); void dgd_learning(); // distributed gradient descent learning void ipm_learning(); // by default: iterative parameter mixtures learning void downpour_learning(); // asynchronous gradient descent learning void agd_learning(); // slow asynchronous gradient descent learning virtual void solve(); double calc_loss(); void dump_result(); void print(const vector<double> &); void test(const std::string &); void predict(const std::string &); private: void local_parser(const vector<string> &, const char); void local_parser_pred(const vector<string> &, const char); private: string input; string update_file, update_func; std::string learning_method; int worker_id; int rounds; double alpha, beta; bool debug = false; vector<vector<double> > samples, pred_samples; vector<double> labels; vector<double> theta; vector<double> loss_error; vector<std::pair<vector<double>, double> > predv; int kdim; // not contain 1 }; } // namespace alg } // namespace paracel
solve 是主体代码,依据不同配置选择不同的随机梯度下降法来训练。
void logistic_regression::solve() { auto lines = paracel_load(input); local_parser(lines); paracel_sync(); if(learning_method == "dgd") { dgd_learning(); } else if(learning_method == "ipm") { ipm_learning(); } else if(learning_method == "downpour") { downpour_learning(); } else if(learning_method == "agd") { agd_learning(); } else { ERROR_ABORT("method do not support"); } paracel_sync(); }
我们找出论文中的算法比对:
下面代码和论文算法基本一一对应,逻辑如下。
void logistic_regression::agd_learning() { int data_sz = samples.size(); int data_dim = samples[0].size(); theta = paracel::random_double_list(data_dim); paracel_write("theta", theta); // first push // 首先把 theta 推送到参数服务器 vector<int> idx; for(int i = 0; i < data_sz; ++i) { idx.push_back(i); } paracel_register_bupdate(update_file, update_func); double coff2 = 2. * beta * alpha; vector<double> delta(data_dim); unsigned time_seed = std::chrono::system_clock::now().time_since_epoch().count(); // train loop for(int rd = 0; rd < rounds; ++rd) { std::shuffle(idx.begin(), idx.end(), std::default_random_engine(time_seed)); theta = paracel_read<vector<double> >("theta"); // 从参数服务器读取最新的 theta vector<double> theta_old(theta); // traverse data for(auto sample_id : idx) { theta = paracel_read<vector<double> >("theta"); theta_old = theta; double coff1 = alpha * (labels[sample_id] - lr_hypothesis(samples[sample_id])); for(int i = 0; i < data_dim; ++i) { double t = coff1 * samples[sample_id][i] - coff2 * theta[i]; theta[i] += t; } if(debug) { loss_error.push_back(calc_loss()); } for(int i = 0; i < data_dim; ++i) { delta[i] = theta[i] - theta_old[i]; } // 把计算结果推送到参数服务器 paracel_bupdate("theta", delta); // you could push a batch of delta into a queue to optimize } // traverse } // rounds theta = paracel_read<vector<double> >("theta"); // last pull // 得到最终结果 }
lr的逻辑图如下:
+------------+ +-------------------------------------------------+ | lr_driver | |logistic_regression | | | | | | +---------------------------------------> solve | +------------+ lr_solver.solve() | + | | | | | | | | | | | +---------------------+-----------------------+ | | | agd_learning | | | | +-----------------------+ | | | | | | | | | | | v | | | | | theta = paracel_read("theta") | | | | | | | | | | | | | | | | | v | | | | | | | | | | delta[i] = theta[i] - theta_old[i] | | | | | + | | | | | | | | | | | | | | | | | v | | | | | paracel_bupdate("theta", delta) | | | | | + + | | | | | | | | | | | +-----------------------+ | | | | +---------------------------------------------+ | | | | +-------------------------------------------------+ | Worker | +------------------------------------------------------------------------------------+ Server | +---------------------+ | Server | | | | | | v | | local_update | | | +---------------------+
至此,我们知道了Paracel如何使用,实现是以driver为核心进行展开,用户需要编写 update函数和算法函数。但是距离深入了解还差得很远。
我们目前有几个问题需要解决:
我们需要通过启动部分来继续研究。
如前所述./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr
是启动命令,paracel 通过 prun.py 进入系统,所以我们分析这个脚本。
下面我们省略一些非主体代码,比如处理参数,逻辑如下:
if __name__ == '__main__': optpar = OptionParser() # 省略处理参数 (options, args) = optpar.parse_args() nsrv = 1 nworker = 1 if options.parasrv_num: nsrv = options.parasrv_num if options.worker_num: nworker = options.worker_num if not options.method_server: options.method_server = options.method if not options.ppn_server: options.ppn_server = options.ppn if not options.mem_limit_server: options.mem_limit_server = options.mem_limit if not options.hostfile_server: options.hostfile_server = options.hostfile # 利用 init_starter 得到如何启动server,worker,构建出相应字符串 server_starter = init_starter(options.method_server, str(options.mem_limit_server), str(options.ppn_server), options.hostfile_server, options.server_group) worker_starter = init_starter(options.method, str(options.mem_limit), str(options.ppn), options.hostfile, options.worker_group) #initport = random.randint(30000, 65000) #initport = get_free_port() initport = 11777 start_parasrv_cmd_lst = [server_starter, str(nsrv), os.path.join(PARACEL_INSTALL_PREFIX, 'bin/start_server --start_host'), socket.gethostname(), ' --init_port', str(initport)] start_parasrv_cmd = ' '.join(start_parasrv_cmd_lst) # 利用 subprocess.Popen 启动server,其中server的执行程序是 bin/start_server procs = subprocess.Popen(start_parasrv_cmd, shell=True, preexec_fn=os.setpgrp) try: serverinfo = paracelrun_cpp_proxy(nsrv, initport) entry_cmd = '' if args: entry_cmd = ' '.join(args) alg_cmd_lst = [worker_starter, str(nworker), entry_cmd, '--server_info', serverinfo, '--cfg_file', options.config] alg_cmd = ' '.join(alg_cmd_lst) # 利用 os.system 启动 worker os.system(alg_cmd) os.killpg(procs.pid, 9) except Exception as e: logger.exception(e) os.killpg(procs.pid, 9)
init_starter 函数会依据配置构建一个字符串。其中 paracel 有三种启动方式:
The –m_server and -m options above refer to what type of cluster you use. Paracel support mesos clusters, mpi clusters and multiprocessers in a single machine.
我们利用前面horovod文章的知识可以知道,mpirun 是可以启动多个进程。
结合之前的命令行,./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr
,可以知道 local 就是 mpirun,所以paracel 通过 mpirun 来启动了 4 个 lr 进程。
具体代码如下:
def init_starter(method, mem_limit, ppn, hostfile, group): '''Assemble commands for running paracel programs''' starter = '' if not hostfile: hostfile = '~/.mpi/large.18' if method == 'mesos': if group: starter = '%s/mrun -m %s -p %s -g %s -n ' % (PARACEL_INSTALL_PREFIX, mem_limit, ppn, group) else: starter = '%s/mrun -m %s -p %s -n ' % (PARACEL_INSTALL_PREFIX, mem_limit, ppn) elif method == 'mpi': starter = 'mpirun --hostfile %s -n ' % hostfile elif method == 'local': starter = 'mpirun -n ' else: print 'method %s not supported.' % method sys.exit(1) return starter
前面提到,server 执行程序对应的是 bin/start_server。
我们看看其构建 src/CMakeLists.txt,于是我们可以去查找 start_server.cpp。
add_library(comm SHARED comm.cpp) # 通信相关库 install(TARGETS comm LIBRARY DESTINATION lib) add_library(scheduler SHARED scheduler.cpp # 调度 install(TARGETS scheduler LIBRARY DESTINATION lib) add_library(default SHARED default.cpp) # 缺省库 install(TARGETS default LIBRARY DESTINATION lib) # 这里可以看到start_server.cpp add_executable(start_server start_server.cpp) target_link_libraries(start_server ${Boost_LIBRARIES} ${CMAKE_DL_LIBS}) install(TARGETS start_server RUNTIME DESTINATION bin) add_executable(paracelrun_cpp_proxy paracelrun_cpp_proxy.cpp) target_link_libraries(paracelrun_cpp_proxy ${Boost_LIBRARIES} ${CMAKE_DL_LIBS}) install(TARGETS paracelrun_cpp_proxy RUNTIME DESTINATION bin)
src/start_server.cpp 是服务器主体代码。
结合之前的命令行,./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr
,可以知道 local 就是 mpirun,所以paracel 通过 mpirun 来启动了 2 个 start_server 进程,即两个参数服务器。
#include <gflags/gflags.h> #include "server.hpp" DEFINE_string(start_host, "beater7", "host name of start node\n"); DEFINE_string(init_port, "7773", "init port"); int main(int argc, char *argv[]) { google::SetUsageMessage("[options]\n\ --start_host\tdefault: balin\n\ --init_port\n"); google::ParseCommandLineFlags(&argc, &argv, true); paracel::init_thrds(FLAGS_start_host, FLAGS_init_port); // join inside return 0; }
在 include/server.hpp 文件之中,init_thrds 函数启动了一系列线程,具体逻辑如下。
// init_host is the hostname of starter void init_thrds(const paracel::str_type & init_host, const paracel::str_type & init_port) { // 构建 zmq 环境 zmq::context_t context(2); zmq::socket_t sock(context, ZMQ_REQ); paracel::str_type info = "tcp://" + init_host + ":" + init_port; sock.connect(info.c_str()); char hostname[1024], freeport[1024]; size_t size = sizeof(freeport); // hostname of servers gethostname(hostname, sizeof(hostname)); paracel::str_type ports = hostname; ports += ":"; // create sock in every thrd 为每个线程建立了socket std::vector<zmq::socket_t *> sock_pt_lst; for(int i = 0; i < paracel::threads_num; ++i) { zmq::socket_t *tmp; tmp = new zmq::socket_t(context, ZMQ_REP); sock_pt_lst.push_back(tmp); sock_pt_lst.back()->bind("tcp://*:*"); sock_pt_lst.back()->getsockopt(ZMQ_LAST_ENDPOINT, &freeport, &size); if(i == paracel::threads_num - 1) { ports += local_parse_port(paracel::str_type(freeport)); } else { ports += local_parse_port(std::move(paracel::str_type(freeport))) + ","; } } zmq::message_t request(ports.size()); std::memcpy((void *)request.data(), &ports[0], ports.size()); sock.send(request); zmq::message_t reply; sock.recv(&reply); // 建立服务器处理线程 thrd_exec paracel::list_type<std::thread> threads; for(int i = 0; i < paracel::threads_num - 1; ++i) { threads.push_back(std::thread(thrd_exec, std::ref(*sock_pt_lst[i]))); } // 建立ssp线程 thrd_exec_ssp threads.push_back(std::thread(thrd_exec_ssp, std::ref(*sock_pt_lst.back()))); // 等待线程结束 for(auto & thrd : threads) { thrd.join(); } for(int i = 0; i < paracel::threads_num; ++i) { delete sock_pt_lst[i]; } zmq_ctx_destroy(context); } // init_thrds
./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr
的对应启动逻辑图具体如下:
prun.py + | | | +----------------+ | +--> | start_server | v | +----------------+ server_starter = init_starter +--> mpirun -n 2 +----+ + | +----------------+ | | | start_server | | | | + | | +--> | | | v | | | worker_starter = init_starter +--> mpirun -n 4 | | | + | v | | | init_thrds | | | + | | | | | +-------+----+--+-------+ | | | | | | | | | | | | | | | v | v v v v | thrd_exec | bin/lr bin/lr bin/lr bin/lr | + | | | | | | | | | | | v | | thrd_exec_ssp | +----------------+
目前我们知道了,worker和server都有多种启动方式,比如用 mpi 的方式来启动多个进程。
worker 端就是通过 driver.cpp 为主体,启动多个进程。
server端就是通过 start_server 为主体,启动多个进程,就是多个进程(参数服务器)组成了一个集群。
以上这些和ps-lite非常类似。
下面我们要分别深入这两个角色的内部。
通过之前ps-lite我们知道,参数服务器大多使用 KV 存储来保存参数,所以我们先介绍KV存储。
在 include/kv_def.hpp 给出了server 端使用的KV存储。
#include "paracel_types.hpp" #include "kv.hpp" namespace paracel { paracel::kvs<paracel::str_type, int> ssp_tbl; // 用来协助实现 SSP paracel::kvs<paracel::str_type, paracel::str_type> tbl_store; // 主要的kv存储 }
KV 存储的定义在 include/kv.hpp,下面省略了部分代码。
可以看出来,基本功能就是维护了内存table,提供了set系列函数和get系列函数,其中当需要返回 value, unique 的时候,就采用hash函数处理。
template <class K, class V> struct kvs { public: bool contains(const K & k) { return kvdct.count(k); } void set(const K & k, const V & v) { kvdct[k] = v; } void set_multi(const paracel::dict_type<K, V> & kvdict) { for(auto & kv : kvdict) { set(kv.first, kv.second); } } boost::optional<V> get(const K & k) { auto fi = kvdct.find(k); if(fi != kvdct.end()) { return boost::optional<V>(fi->second); } else return boost::none; } bool get(const K & k, V & v) { auto fi = kvdct.find(k); if(fi != kvdct.end()) { v = fi->second; return true; } else { return false; } } paracel::list_type<V> get_multi(const paracel::list_type<K> & keylst) { paracel::list_type<V> valst; for(auto & key : keylst) { valst.push_back(kvdct.at(key)); } return valst; } void get_multi(const paracel::list_type<K> & keylst, paracel::list_type<V> & valst) { for(auto & key : keylst) { valst.push_back(kvdct.at(key)); } } void get_multi(const paracel::list_type<K> & keylst, paracel::dict_type<K, V> & valdct) { valdct.clear(); for(auto & key : keylst) { auto it = kvdct.find(key); if(it != kvdct.end()) { valdct[key] = it->second; } } } // 这里使用了 hash 函数 // gets(key) -> value, unique boost::optional<std::pair<V, paracel::hash_return_type> > gets(const K & k) { if(auto v = get(k)) { std::pair<V, paracel::hash_return_type> ret(*v, hfunc(*v)); return boost::optional< std::pair<V, paracel::hash_return_type> >(ret); } else { return boost::none; } } // compare-and-set, cas(key, value, unique) -> True/False bool cas(const K & k, const V & v, const paracel::hash_return_type & uniq) { if(auto r = gets(k)) { if(uniq == (*r).second) { set(k, v); return true; } else { return false; } } else { kvdct[k] = v; } return true; } paracel::dict_type<K, V> getall() { return kvdct; } private: //std::tr1::unordered_map<K, V> kvdct; paracel::dict_type<K, V> kvdct; paracel::hash_type<V> hfunc; };
thrd_exec 线程实现了参数服务器的基本处理逻辑:就是针对worker传来的不同的命令进行相关处理(大部分就是针对KV存储进行处理),比如:
需要注意的是,这里使用了用户定义的update函数,即:
下面删除了部分非主体代码。
// thread entry void thrd_exec(zmq::socket_t & sock) { paracel::packer<> pk; update_result update_f; filter_result pullall_special_f; filter_result remove_special_f; // 这里使用了dlopen_update_lambda来对用户设置的update函数进行生成,赋值为 update_f auto dlopen_update_lambda = [&](const paracel::str_type & fn, const paracel::str_type & fcn) { void *handler = dlopen(fn.c_str(), RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE); auto local = dlsym(handler, fcn.c_str()); update_f = *(std::function<paracel::str_type(paracel::str_type, paracel::str_type)>*) local; dlclose(handler); }; // 主体逻辑 while(1) { zmq::message_t s; sock.recv(&s); auto scrip = paracel::str_type(static_cast<const char *>(s.data()), s.size()); auto msg = paracel::str_split_by_word(scrip, paracel::seperator); auto indicator = pk.unpack(msg[0]); if(indicator == "pull") { // 如果是从参数服务器读取参数,则直接返回 auto key = pk.unpack(msg[1]); paracel::str_type result; auto exist = paracel::tbl_store.get(key, result); // 读取kv if(!exist) { paracel::str_type tmp = "nokey"; rep_send(sock, tmp); } else { rep_send(sock, result); // 返回 } } if(indicator == "pull_multi") { // 读取多个参数 paracel::packer<paracel::list_type<paracel::str_type> > pk_l; auto key_lst = pk_l.unpack(msg[1]); auto result = paracel::tbl_store.get_multi(key_lst); rep_pack_send(sock, result); } if(indicator == "pullall") { // 读取所有参数 auto dct = paracel::tbl_store.getall(); rep_pack_send(sock, dct); } mutex.lock(); if(indicator == "push") { // 插入参数 auto key = pk.unpack(msg[1]); paracel::tbl_store.set(key, msg[2]); bool result = true; rep_pack_send(sock, result); } if(indicator == "push_multi") { // 插入多个参数 paracel::packer<paracel::list_type<paracel::str_type> > pk_l; paracel::dict_type<paracel::str_type, paracel::str_type> kv_pairs; auto key_lst = pk_l.unpack(msg[1]); auto val_lst = pk_l.unpack(msg[2]); assert(key_lst.size() == val_lst.size()); for(int i = 0; i < (int)key_lst.size(); ++i) { kv_pairs[key_lst[i]] = val_lst[i]; } paracel::tbl_store.set_multi(kv_pairs); //插入kv bool result = true; rep_pack_send(sock, result); } if(indicator == "update" || indicator == "bupdate") { // 更新参数 if(msg.size() > 3) { if(msg.size() != 5) { ERROR_ABORT("invalid invoke in server end"); } // open request func auto file_name = pk.unpack(msg[3]); auto func_name = pk.unpack(msg[4]); dlopen_update_lambda(file_name, func_name); } else { if(!update_f) { dlopen_update_lambda("../local/build/lib/default.so", "default_incr_i"); } } auto key = pk.unpack(msg[1]); // 这里使用用户的update函数来对kv进行处理 std::string result = kv_update(key, msg[2], update_f); rep_send(sock, result); } if(indicator == "remove") { // 删除参数 auto key = pk.unpack(msg[1]); auto result = paracel::tbl_store.del(key); rep_pack_send(sock, result); } mutex.unlock(); } // while } // thrd_exec
简化如图:
+--------------------------------------------------------------------------------------+ | thrd_exec | | | | +---------------------------------> while(1) | | | + | | | | | | | | | | | +----------+----------+--------+--+------+----------+---------+---------+ | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | v v v v v v v v | | | | | | pull pull_multi pullall push push_multi update bupdate remove | | | + + + + + + + + | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | v v v v v v v v | | | +----------+----------+--------+----+----+----------+---------+---------+ | | | | | | | | | | | | | | | | | | +-----------------------------------------+ | | | +--------------------------------------------------------------------------------------+
目前为止,我们可以看到,Paracel和ps-lite也很类似,服务器维护了一个存储,服务器也可以处理客户端的请求。
Worker 就是用来训练算法的进程。从前面我们了解,算法需要继承paracel::paralg才能使用参数服务器功能。
namespace paracel { namespace alg { class logistic_regression: public paracel::paralg { .....
paracel::paralg 就可以认为是参数服务器的API,或者代理,我们下面就看看。
Paralg是提供Paracel主要功能的基本类,可以理解为一个算法API类,或者对外功能API类。
我们只给出其成员变量,暂时省略其函数实现。最主要几个为:
class paralg { private: class parasrv { // 可以理解为是参数服务器类 using l_type = paracel::list_type<paracel::kvclt>; using dl_type = paracel::list_type<paracel::dict_type<paracel::str_type, paracel::str_type> >; public: parasrv(paracel::str_type hosts_dct_str) { // init dct_lst dct_lst = paracel::get_hostnames_dict(hosts_dct_str); // init srv_sz srv_sz = dct_lst.size(); // init kvm for(auto & srv : dct_lst) { paracel::kvclt kvc(srv["host"], srv["ports"]); kvm.push_back(std::move(kvc)); } // init servers for(auto i = 0; i < srv_sz; ++i) { servers.push_back(i); } // init hashring p_ring = new paracel::ring<int>(servers); } virtual ~parasrv() { delete p_ring; } public: dl_type dct_lst; int srv_sz = 1; l_type kvm; paracel::list_type<int> servers; // 具体服务器列表 paracel::ring<int> *p_ring; // hash ring }; // nested class parasrv private: int stale_cache, clock, total_iters; // 同步需要 int clock_server = 0; paracel::Comm worker_comm; //通信类,比如 MPI 通信 paracel::str_type output; int nworker = 1; int rounds = 1; int limit_s = 0; bool ssp_switch = false; parasrv *ps_obj; // 可以理解为是正式的参数服务器类。 paracel::dict_type<paracel::default_id_type, paracel::default_id_type> rm; paracel::dict_type<paracel::default_id_type, paracel::default_id_type> cm; paracel::dict_type<paracel::default_id_type, paracel::default_id_type> dm; paracel::dict_type<paracel::default_id_type, paracel::default_id_type> col_dm; paracel::dict_type<paracel::str_type, paracel::str_type> keymap; paracel::dict_type<paracel::str_type, boost::any> cached_para; paracel::update_result update_f; int npx = 1, npy = 1; }
编写一个Paracel程序需要对paralg基类进行子类化,并且必须重写virtual solve方法。其中一些是SPMD iterfaces 并行接口。
我们从之前 LR 的实现可以看到需要继承 paracel::paralg 。
class logistic_regression: public paracel::paralg
就是说,用户的solve函数可以直接调用 Paralg 的函数来完成基本功能。
我们以 paracel::paracel_read 为例,可以看到是使用 parasrv.kvm 的功能,我们后续会继续介绍 parasrv。
template <class V> V paracel_read(const paracel::str_type & key, int replica_id = -1) { if(ssp_switch) { // 如果应用ssp,应该如何处理。我们下文就将具体介绍ssp如何处理 V val; if(clock == 0 || clock == total_iters) { cached_para[key] = boost::any_cast<V>(ps_obj-> kvm[ps_obj->p_ring->get_server(key)]. pull<V>(key)); val = boost::any_cast<V>(cached_para[key]); } else if(stale_cache + limit_s > clock) { val = boost::any_cast<V>(cached_para[key]); } else { while(stale_cache + limit_s < clock) { stale_cache = ps_obj-> kvm[clock_server].pull_int(paracel::str_type("server_clock")); } cached_para[key] = boost::any_cast<V>(ps_obj-> kvm[ps_obj->p_ring->get_server(key)]. pull<V>(key)); val = boost::any_cast<V>(cached_para[key]); } return val; } // 否则直接返回 return ps_obj->kvm[ps_obj->p_ring->get_server(key)].pull<V>(key); }
worker逻辑如下:
+---------------------------------------------------------------------------+ | Algorithm | | ^ +------------------------------v | | | | | | | | | | | v | | | +----------------------------+------------------------------+ | | | | paracel_read | | | | | | | | | | ps_obj+>kvm[ps_obj+>p_ring+>get_server(key)].pull<V>(key) | | | | | | | | | +----------------------------+------------------------------+ | | | | | | | | | | | | | | | v | | | Compute | | | + | | | | | | | | | | | v | | | +---------------------------+-------------------------------+ | | | | paracel_bupdate | | | | | ps_obj->kvm[indx].bupdate | | | | | | | | | +---------------------------+-------------------------------+ | | | | | | | | | | | | | | | | | | +-----<--------------------------+ | | | +---------------------------------------------------------------------------+
Worker端的机理也类似ps-lite,通过read,pull等操作,向服务器提出请求。
在沐神论文中,Ring hash 是与数据一致性,容错,可扩展等机制联系在一起,比如:
parameter server 在数据一致性上,使用的是传统的一致性哈希算法,参数key与server node id被插入到一个hash ring中。
但可惜的是,ps-lite 没有提供这部分代码,paracel 虽然有 ring hash,但也不齐全,豆瓣没有开源容错和一致性等部分。我们只能基于已有代码进行学习分析。
这里只是大致讲解下,有需求的同学可以去网上搜索详细文章。
从拗口的技术术语来解释,一致性哈希的技术关键点是:按照常用的hash算法来将对应的key哈希到一个具有2^32次方个桶的空间中,即0 ~ (2^32)-1的数字空间。我们可以将这些数字头尾相连,想象成一个闭合的环形。
用通俗白话来理解,这个关键点就是:在部署服务器的时候,服务器的序号空间已经配置成了一个固定的非常大的数字 1~2^32(不需要再改变)。服务器可以分配为 1~2^32 中任一序号。这样服务器集群可以固定大多数算法规则 (因为序号空间是算法的重要参数),这样面对扩容等变化只有"分配规则" 需要根据实际系统容量做相应微调。从而对整体系统影响较小。
ring 就是hash 环的实现类,这里主要功能就是把 服务器 加入到 hash ring 之中,以及从ring之中取出服务器。
// T rep type of server name template <class T> class ring { public: ring(paracel::list_type<T> names) { for(auto & name : names) { add_server(name); } } ring(paracel::list_type<T> names, int cp) : replicas(cp) { for(auto & name : names) { add_server(name); } } void add_server(const T & name) { //std::hash<paracel::str_type> hfunc; paracel::hash_type<paracel::str_type> hfunc; std::ostringstream tmp; tmp << name; auto name_str = tmp.str(); for(int i = 0; i < replicas; ++i) { //对每一个副本进行处理 std::ostringstream cvt; cvt << i; auto n = name_str + ":" + cvt.str(); auto key = hfunc(n); // 依据name生成一个key srv_hashring_dct[key] = name; //添加value srv_hashring.push_back(key); //往list添加内容 } // sort srv_hashring std::sort(srv_hashring.begin(), srv_hashring.end()); } void remove_server(const T & name) { //std::hash<paracel::str_type> hfunc; paracel::hash_type<paracel::str_type> hfunc; std::ostringstream tmp; tmp << name; auto name_str = tmp.str(); for(int i = 0; i < replicas; ++i) { // 对每个副本进行处理 std::ostringstream cvt; cvt << i; auto n = name_str + ":" + cvt.str(); auto key = hfunc(n);// 依据name生成一个key srv_hashring_dct.erase(key);// 删除value auto iter = std::find(srv_hashring.begin(), srv_hashring.end(), key); if(iter != srv_hashring.end()) { srv_hashring.erase(iter); // 删除list中的内容 } } } // TODO: relief load of srv_hashring_dct[srv_hashring[0]] template <class P> T get_server(const P & skey) { //std::hash<P> hfunc; paracel::hash_type<P> hfunc; auto key = hfunc(skey);// 依据name生成一个key auto server = srv_hashring[paracel::ring_bsearch(srv_hashring, key)];//获取server return srv_hashring_dct[server]; } private: int replicas = 32; // 分别用list和dict存储 paracel::list_type<paracel::hash_return_type> srv_hashring; paracel::dict_type<paracel::hash_return_type, T> srv_hashring_dct; };
我们使用 paracel_read 来看,可以发现调用顺序是
V paracel_read(const paracel::str_type & key, int replica_id = -1) { ...... ps_obj->kvm[ps_obj->p_ring->get_server(key)].pull<V>(key); }
这里是和ps-lite的不同之处,就是用ring-hash来维护数据一致性,容错等,比如把 服务器 加入到 hash ring 之中,以及从ring之中取出服务器。
我们把目前逻辑梳理一下,综合看看。
如何使用ring hash,需要从 parasrv 说起。
我们知道,paralg 是基础API类,其中在 paralg 中有如下定义 以及 构建了 ps_obj , ps_obj是一个 parasrv 类型的实例。
注:以下都是在worker端使用的类型。
// paralg 内代码 parasrv *ps_obj; // 成员变量定义,参数服务器接口 paralg(paracel::str_type hosts_dct_str, paracel::Comm comm, paracel::str_type _output = "", int _rounds = 1, int _limit_s = 0, bool _ssp_switch = false) : worker_comm(comm), output(_output), nworker(comm.get_size()), rounds(_rounds), limit_s(_limit_s), ssp_switch(_ssp_switch) { ps_obj = new parasrv(hosts_dct_str); // 构建参数服务器,一个parasrv的实例 init_output(_output); clock = 0; stale_cache = 0; clock_server = 0; total_iters = rounds; if(worker_comm.get_rank() == 0) { paracel::str_type key = "worker_sz"; (ps_obj->kvm[clock_server]). push_int(key, worker_comm.get_size()); // 初始化时钟服务器 } paracel_sync(); // mpi barrier同步一下 }
parasrv 的定义如下,其中 p_ring 就是 ring 实例,使用 p_ring = new paracel::ring<int>(servers)
来完成了构建。
其中p_ring 是 ring hash,kvm是具体的kv存储列表。
class parasrv { using l_type = paracel::list_type<paracel::kvclt>; using dl_type = paracel::list_type<paracel::dict_type<paracel::str_type, paracel::str_type> >; public: parasrv(paracel::str_type hosts_dct_str) { // 初始化host信息,srv大小,kvm,servers,ring hash // init dct_lst dct_lst = paracel::get_hostnames_dict(hosts_dct_str); // init srv_sz srv_sz = dct_lst.size(); // init kvm for(auto & srv : dct_lst) { paracel::kvclt kvc(srv["host"], srv["ports"]); kvm.push_back(std::move(kvc)); } // init servers for(auto i = 0; i < srv_sz; ++i) { servers.push_back(i); } // init hashring p_ring = new paracel::ring<int>(servers); // 构建 } virtual ~parasrv() { delete p_ring; } public: dl_type dct_lst; int srv_sz = 1; l_type kvm; // 具体KV存储接口 paracel::list_type<int> servers; paracel::ring<int> *p_ring; // ring hash }; // nested class parasrv
kvm 初始化如下:
// init kvm for(auto & srv : dct_lst) { paracel::kvclt kvc(srv["host"], srv["ports"]); kvm.push_back(std::move(kvc)); }
kvclt 是 kv control 的抽象。
只摘取部分代码,就是找到对应的服务器进行交互。
namespace paracel { struct kvclt { public: kvclt(paracel::str_type hostname, paracel::str_type ports) : host(hostname), context(1) { ports_lst = paracel::str_split(ports, ','); conn_prefix = "tcp://" + host + ":"; } template <class V, class K> bool pull(const K & key, V & val) { // 从参数服务器拉取 if(p_pull_sock == nullptr) { p_pull_sock.reset(create_req_sock(ports_lst[0])); } auto scrip = paste(paracel::str_type("pull"), key); // paracel::str_type return req_send_recv(*p_pull_sock, scrip, val); } template <class K, class V> bool push(const K & key, const V & val) { // 往参数服务器推送 if(p_push_sock == nullptr) { p_push_sock.reset(create_req_sock(ports_lst[1])); } auto scrip = paste(paracel::str_type("push"), key, val); bool stat; auto r = req_send_recv(*p_push_sock, scrip, stat); return r && stat; } template <class V> bool req_send_recv(zmq::socket_t & sock, const paracel::str_type & scrip, V & val) { zmq::message_t req_msg(scrip.size()); std::memcpy((void *)req_msg.data(), &scrip[0], scrip.size()); sock.send(req_msg); zmq::message_t rep_msg; sock.recv(&rep_msg); paracel::packer<V> pk; if(!rep_msg.size()) { ERROR_ABORT("paracel internal error!"); } else { std::string data = paracel::str_type( static_cast<char*>(rep_msg.data()), rep_msg.size()); if(data == "nokey") return false; val = pk.unpack(data); } return true; } private: paracel::str_type host; paracel::list_type<paracel::str_type> ports_lst; paracel::str_type conn_prefix; zmq::context_t context; std::unique_ptr<zmq::socket_t> p_contains_sock = nullptr; std::unique_ptr<zmq::socket_t> p_pull_sock = nullptr; std::unique_ptr<zmq::socket_t> p_pull_multi_sock = nullptr; std::unique_ptr<zmq::socket_t> p_pullall_sock = nullptr; std::unique_ptr<zmq::socket_t> p_push_sock = nullptr; std::unique_ptr<zmq::socket_t> p_push_multi_sock = nullptr; std::unique_ptr<zmq::socket_t> p_update_sock = nullptr; std::unique_ptr<zmq::socket_t> p_bupdate_sock = nullptr; std::unique_ptr<zmq::socket_t> p_bupdate_multi_sock = nullptr; std::unique_ptr<zmq::socket_t> p_remove_sock = nullptr; std::unique_ptr<zmq::socket_t> p_clear_sock = nullptr; std::unique_ptr<zmq::socket_t> p_ssp_sock = nullptr; }; // struct kvclt } // namespace paracel
所以目前总体逻辑如下:
+------------------+ worker + server | paralg | | | | | | | | | parasrv *ps_obj | | | + | | +------------------+ | | | | | start_server | +------------------+ | | | | | | | | | | | v | | | +------------+-----+ +------------------+ +---------+ | | thrd_exec | | parasrv | |kvclt | | kvclt | | | | | | | | | | | | | | | | host | | | | | thrd_exec_ssp | | servers | | | | | | | | | | | ports_lst | | | | | | | kvm +-----------> | |.....| | | | ssp_tbl | | | | context | | | | | | | p_ring | | | | | | | | | + | | conn_prefix | | | | | tbl_store | | | | | | | | | | | +------------------+ | p_pull_sock+---+ | | | | | | | | | | | | | | | | p_push_sock | | | | | | | | | + | | | | | | | v | | | | | | | | | +------------+------+ +------------------+ | +---------+ | | | | ring | | | | +---+---+----------+ | | | | | ^ ^ | | | | | | | | srv_hashring | | +-----------------------+ | | | +------------------------------------+ | srv_hashring_dct | | | | | +-------------------+ +
手机如下:
★★★★★★关于生活和技术的思考★★★★★★
微信公众账号:罗西的思考
如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。
PARACEL:让分布式机器学习变得简单
参数服务器——分布式机器学习的新杀器