书中代码1:
def binSplitDataSet(dataSet, feature, value):
    """Split a 2-D numpy data set into two row subsets on one feature.

    Args:
        dataSet: 2-D numpy array or matrix, one sample per row.
        feature: column index to split on.
        value: threshold; rows with dataSet[:, feature] > value go to mat0,
            the rest go to mat1.

    Returns:
        (mat0, mat1): the two row subsets, same container type as dataSet.
    """
    # nonzero(...)[0] yields the row indices where the condition holds.
    # NOTE: the book's version appended a trailing `[0]` here, which kept only
    # the first matching row instead of the whole subset.
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
改成:
def binSplitDataSet(dataSet, feature, value):
    """Split a list of rows into two lists on one feature.

    Args:
        dataSet: iterable of rows (indexable sequences).
        feature: index of the feature to compare within each row.
        value: threshold; rows whose feature is > value go to mat0, every
            other row (including ones that do not compare greater) goes to mat1.

    Returns:
        (mat0, mat1): two lists holding the original row objects, in input order.
    """
    mat0 = []
    mat1 = []
    # Walk the rows themselves: looking a row up again via
    # featList.index(feat) returns the FIRST row with that feature value, which
    # duplicates that row and drops the others whenever values repeat.
    for row in dataSet:
        if row[feature] > value:
            mat0.append(row)
        else:
            mat1.append(row)
    return mat0, mat1
书中代码2:
def regLeaf(dataSet):
    """Return the leaf value for a regression tree: mean of the last column.

    Args:
        dataSet: 2-D numpy array or matrix whose last column is the target.
    """
    targetColumn = dataSet[:, -1]
    return mean(targetColumn)
改成:
def regLeaf(dataSet):
    """Return the mean of the last element of every row.

    Args:
        dataSet: iterable of rows; the last element of each row is the target.
    """
    targets = [row[-1] for row in dataSet]
    return mean(targets)
书中代码3:
def regErr(dataSet):
    """Return the total squared error of the target (last) column.

    Computed as the row count times the population variance of the column.

    Args:
        dataSet: 2-D numpy array or matrix whose last column is the target.
    """
    numRows = shape(dataSet)[0]
    return numRows * var(dataSet[:, -1])
改成:
def regErr(dataSet):
    """Return the total squared error of the target values.

    This is sum((y - mean(y)) ** 2) over the last element y of every row,
    i.e. population variance times the number of rows.

    Args:
        dataSet: sequence of rows; the last element of each row is the target.

    Returns:
        The total squared error as a float, or nan for an empty data set.
    """
    targets = [row[-1] for row in dataSet]
    if not targets:
        # Matches the numpy version (var of an empty column is nan) instead of
        # raising ZeroDivisionError. nan compares False with everything, so a
        # caller doing `newS < bestS` simply skips such a candidate split.
        return float('nan')
    # Local names avoid shadowing numpy's mean/var. Summing the squared
    # deviations directly also skips the redundant /n then *n round trip.
    avg = sum(targets) / len(targets)
    return sum((y - avg) ** 2 for y in targets)
书中代码4:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best binary split of a numpy-matrix data set.

    Args:
        dataSet: 2-D numpy matrix, one sample per row, target in the last
            column. The `.T.tolist()[0]` flattening below relies on matrix
            columns staying 2-D.
        leafType: callable producing a leaf value from a data set.
        errType: callable returning the error of a data set.
        ops: (tolS, tolN). tolS is the minimum error reduction needed to split.
            tolN is the minimum number of rows allowed on each side.

    Returns:
        (None, leafType(dataSet)) when the node should be a leaf, otherwise
        (bestIndex, bestValue) describing the chosen split.
    """
    tolS = ops[0]
    tolN = ops[1]
    # Exit cond 1: every target value is identical.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    _, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    # The last column is the target, so only columns 0..n-2 are features.
    for featIndex in range(n - 1):
        # Iterating a matrix column yields 1x1 matrices, which are unhashable,
        # so set(dataSet[:, featIndex]) raises TypeError. Flatten it to a plain
        # list first, the same way the target column is handled above.
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Exit cond 2: the best split does not reduce the error by at least tolS.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit cond 3: re-check that both halves of the chosen split have >= tolN rows.
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
改成:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best binary split of a list-of-rows data set.

    Args:
        dataSet: list of rows of equal length (shape() must see a 2-D table).
            The last element of each row is the target.
        leafType: callable producing a leaf value from a data set.
        errType: callable returning the error of a data set.
        ops: (tolS, tolN). tolS is the minimum error reduction needed to split.
            tolN is the minimum number of rows allowed on each side.

    Returns:
        (None, leafType(dataSet)) when the node should be a leaf, otherwise
        (bestIndex, bestValue) describing the chosen split.
    """
    tolS = ops[0]
    tolN = ops[1]
    # Exit cond 1: every target value is identical.
    if len({row[-1] for row in dataSet}) == 1:
        return None, leafType(dataSet)
    _, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    # The last column is the target, so only columns 0..n-2 are features.
    for featIndex in range(n - 1):
        candidates = {row[featIndex] for row in dataSet}
        for splitVal in candidates:
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Exit cond 2: the best split does not reduce the error by at least tolS.
    if S - bestS < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit cond 3: re-check that both halves of the chosen split have >= tolN rows.
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)
    return bestIndex, bestValue
运行结果: