参考:scipy.sparse.csr_matrix — SciPy v1.8.0 Manual
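CSR（Compressed Sparse Row，压缩稀疏行）矩阵的存储结构由三个一维数组组成：data 按行优先顺序存放所有非零元素的值；indices 存放每个非零元素对应的列索引；indptr 为行指针，长度为行数 + 1，第 i 行的非零元素为 data[indptr[i]:indptr[i+1]]，其列索引为 indices[indptr[i]:indptr[i+1]]。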
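以上图为例（图略）。用下方代码中的矩阵 [[1, 0, 3], [0, 5, 7], [0, 0, 9], [2, 4, 0]] 来说明，其 CSR 表示的三个数组为：data = [1, 3, 5, 7, 9, 2, 4]，indices = [0, 2, 1, 2, 2, 0, 1]，indptr = [0, 2, 4, 5, 7]。例如第 1 行（从 0 开始计数）的非零元素为 data[2:4] = [5, 7]，分别位于第 indices[2:4] = [1, 2] 列。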
代码示例:
from scipy.sparse import csr_matrix
import numpy as np


def construct_csr(data):
    """Build a ``scipy.sparse.csr_matrix`` from a dense, row-major 2-D sequence.

    Every element is passed through ``int()`` (so numeric strings such as
    ``"3"`` are accepted and floats are truncated toward zero). Every entry
    whose converted value is non-zero -- positive or negative -- is stored.

    Args:
        data: An iterable of rows; each row must support ``len()`` and
            iteration. Rows may have different lengths: the matrix width is
            the length of the longest row, and shorter rows are treated as
            padded with zeros.

    Returns:
        A ``csr_matrix`` of shape ``(number_of_rows, widest_row_length)``.
        An empty ``data`` yields a ``(0, 0)`` matrix.
    """
    indptr = [0]
    col_indices = []
    values = []
    n_cols = 0
    for row in data:
        n_cols = max(n_cols, len(row))
        for col_index, raw in enumerate(row):
            value = int(raw)
            # Keep every non-zero entry; a `> 0` test would drop negatives.
            if value != 0:
                col_indices.append(col_index)
                values.append(value)
        # indptr[i + 1] is the running count of non-zeros after row i.
        indptr.append(len(values))
    # Count rows from indptr so `data` may be any iterable (even a generator)
    # and an empty input does not hit an unbound loop variable.
    n_rows = len(indptr) - 1
    return csr_matrix((values, col_indices, indptr), shape=(n_rows, n_cols))


if __name__ == '__main__':
    d_A = [[1, 0, 3], [0, 5, 7], [0, 0, 9], [2, 4, 0]]
    s_A = csr_matrix(np.array(d_A))
    s_B = construct_csr(d_A)
    print(f's_A:\n{s_A}\n')
    print(f's_B:\n{s_B}\n')
    print(s_A.toarray() == s_B.toarray())
执行结果:
s_A:
  (0, 0)	1
  (0, 2)	3
  (1, 1)	5
  (1, 2)	7
  (2, 2)	9
  (3, 0)	2
  (3, 1)	4

s_B:
  (0, 0)	1
  (0, 2)	3
  (1, 1)	5
  (1, 2)	7
  (2, 2)	9
  (3, 0)	2
  (3, 1)	4

[[ True  True  True]
 [ True  True  True]
 [ True  True  True]
 [ True  True  True]]