
Tutorial | From Scratch: Implementing the Decision Tree Algorithm in Python (6)

Time: 2017-02-20 20:26 Source: 本港台直播 Author: 118开奖

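The code also includes a print_tree() function that recursively prints the nodes of the decision tree, one node per line. The result is not as clear as a drawn tree diagram, but it gives a rough picture of the tree's structure and of the decisions made at each node.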

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

# Calculate the Gini index for a split dataset,
# weighting each group's impurity by its share of all rows
def gini_index(groups, class_values):
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = len(group)
        # skip empty groups to avoid dividing by zero
        if size == 0:
            continue
        score = 0.0
        for class_value in class_values:
            proportion = [row[-1] for row in group].count(class_value) / float(size)
            score += proportion * proportion
        gini += (1.0 - score) * (size / n_instances)
    return gini

# Select the best split point for a dataset
def get_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    # try every value of every input attribute (the last column is the class)
    for index in range(len(dataset[0]) - 1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index': b_index, 'value': b_value, 'groups': b_groups}

# Create a terminal node value (the most common class in the group)
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
    left, right = node['groups']
    del node['groups']
    # check for a no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    # process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth + 1)
    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth + 1)

# Build a decision tree from the training data
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root

# Print a decision tree, one node per line, indented by depth
def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % (depth * ' ', node['index'] + 1, node['value']))
        print_tree(node['left'], depth + 1)
        print_tree(node['right'], depth + 1)
    else:
        print('%s[%s]' % (depth * ' ', node))

dataset = [[2.771244718, 1.784783929, 0],
           [1.728571309, 1.169761413, 0],
           [3.678319846, 2.81281357, 0],
           [3.961043357, 2.61995032, 0],
           [2.999208922, 2.209014212, 0],
           [7.497545867, 3.162953546, 1],
           [9.00220326, 3.339047188, 1],
           [7.444542326, 0.476683375, 1],
           [10.12493903, 3.234550982, 1],
           [6.642287351, 3.319983761, 1]]

tree = build_tree(dataset, 1, 1)
print_tree(tree)
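To see how print_tree() conveys structure through indentation alone, this minimal sketch runs the same print_tree() on a small hand-built tree and captures what it prints. The dict `example_tree` and its split values are hypothetical, chosen only for illustration; the sketch assumes, as in the listing, that decision nodes are dicts with 'index', 'value', 'left' and 'right' keys and that leaves are plain class labels.

```python
import io
from contextlib import redirect_stdout

# print_tree() as in the listing: one node per line,
# indented by one space per level of depth.
def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % (depth * ' ', node['index'] + 1, node['value']))
        print_tree(node['left'], depth + 1)
        print_tree(node['right'], depth + 1)
    else:
        print('%s[%s]' % (depth * ' ', node))

# Hypothetical hand-built tree (not produced by build_tree()):
# the root tests X1, its left child tests X2, leaves are class labels.
example_tree = {
    'index': 0, 'value': 5.0,
    'left': {'index': 1, 'value': 2.0, 'left': 0, 'right': 1},
    'right': 1,
}

# Capture the printed lines so they can be inspected.
buffer = io.StringIO()
with redirect_stdout(buffer):
    print_tree(example_tree)
lines = buffer.getvalue().splitlines()
print('\n'.join(lines))
```

Each extra space marks one more level of depth, and every decision node is followed directly by its left subtree and then its right subtree.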

When running the example, we can change the maximum depth of the tree and observe the effect on the printed tree.
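One way to check the effect of the maximum depth is to measure how many levels of splits a tree contains. This sketch uses a hypothetical tree_depth() helper, which is not part of the tutorial, on hand-built trees in the nested-dict shape that build_tree() returns. In the listing the root is created at depth 1 and split() stops adding decision nodes once depth reaches max_depth, so for a real tree the result should not exceed max_depth.

```python
# Hypothetical helper (not from the tutorial): count the levels of
# decision nodes in a nested-dict tree whose leaves are class labels.
def tree_depth(node):
    if not isinstance(node, dict):
        return 0  # a terminal value adds no split level
    return 1 + max(tree_depth(node['left']), tree_depth(node['right']))

# Hand-built example trees; split values are arbitrary illustrations.
stump = {'index': 0, 'value': 5.0, 'left': 0, 'right': 1}
two_levels = {
    'index': 0, 'value': 5.0,
    'left': {'index': 0, 'value': 2.0, 'left': 0, 'right': 0},
    'right': {'index': 1, 'value': 3.0, 'left': 1, 'right': 1},
}

print(tree_depth(stump))       # 1
print(tree_depth(two_levels))  # 2
```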

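With a maximum depth of 1 (the second argument in the call to build_tree()), we can see that the tree uses the perfect split point we found earlier as its only split. The tree has a single node and is also called a decision stump.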

[X1 < 6.642]
 [0]
 [1]

(Responsible editor: 本港台直播)