机器学习——编程实现从零构造训练集的决策树

自己搭建一棵决策树【长文预警】

忙了一个周末就写到了“构建决策树”这一步,还没有考虑划分测试集、验证集、“缺失值、连续值”,预剪枝、后剪枝的部分,后面再补吧(挖坑)

第二节内容:验证集划分\k折交叉检验 机器学习——编程从零实现决策树【二】-CSDN博客

目录

1、信息

1)基本算法过程

2)信息熵和信息增益的计算方式

2、做点假设,简化运算

3、拆解算法过程

0)结点类

1)同类样本判断

2)数据集能否再拆解

3)选取最优属性

步骤1

步骤2:

步骤3:

步骤4: 

 步骤5:

4)构造新结点

4、完整的结点类代码

5、完整的构造树的过程

6、建树

1)准备数据集

2)建树

7、绘图查看树的结构

1)绘图代码

2)结果

3)预测


完整的代码指路

DrawPixel/decisionTree.ipynb at main · ndsoi/DrawPixel (github.com)

1、信息

1)基本算法过程

2)信息熵和信息增益的计算方式

2、做点假设,简化运算

 ① 为了选择最优的属性进行划分,我们需要计算信息增益,而计算信息增益需要用到信息:

        1、选取的属性attr有多少种取值?

                (用西瓜分类的例子,考虑属性”纹理“,就有3种取值——”清晰“、”稍糊“和”模糊“)

        2、每种取值有哪些数据?这些数据中有多少是A类别的,又有多少是B类别的..?

                比如对原始数据集考虑”纹理=清晰“的数据,那么有7个是好瓜,有2个是坏瓜

② 计算完信息增益之后,我们选信息增益最大的属性,按照这个属性划分数据集,生成子结点

注意这里的划分数据集,事实上我们在完成①.2问题的时候就已经“划分了”一次数据集,只是我们没有记录下来,类似这样的“冗余”计算有很多,为了尽量减少“重复”计算,我重规划算法的步骤如下:

1、设总共有class_num个类别,假设我们初始化结点node的时候就知道了这个数据集的如下信息:

数据集        self.data

属性集        self.attr

该数据集内样本数量最多的类别        self.max

该数据集内每个类别的样本数量        self.cal_class 是一个列表,每一个元素是|Dv|

2、基于假设1: 

计算Ent(D):

def Ent(D,cal_class):
  sum = len(D) # 样本总数
  # 求占比
  re = 0
  for k in cal_class:
    pk_class = k/sum
    if pk_class != 0:
      re -= pk_class* math.log(pk_class,2)
  return re

3、拆解算法过程

0)结点类

class Node():
  def __init__(self,D,A,max,cal_class,class_num):
    self.data = D
    self.attr = A
    self.class_num = class_num
    self.cal_class = cal_class

    self.max = max
    self.label = 0 # 0表示非叶结点 1表示叶结点
    self.Class = 0 # 默认一个
    self.flag = "init"

1)同类样本判断

若要判断D中的样本是否同属于一个类别:只需要判断self.max的数量是否等于class_num

  def isSameClass(self):
    if self.cal_class[self.max] == len(self.data):
      return True
    return False

2)数据集能否再拆解

若D中样本不属于同一类,那么接下来要看D中的样本是否还能再分解:

def isNoAttr(self):
    # 属性集为空
    if self.attr == None or self.attr==[]:
      return True,[]
    # 存储取值不同的属性
    self.Attr_Div = []
    for a in self.attr:
      a1 = self.data[0][a]
      for d in self.data:
        if d[a] != a1:
          self.Attr_Div.append(a)
          break
    # 无可分的属性
    if self.Attr_Div == []:
      return True,[]

    return False,self.Attr_Div

3)选取最优属性

从2)中获取了当前node数据集进一步可以分解的属性范围(self.Attr_Div),对于self.Attr_Div中的每一个attr,我们需要做的事情还有:

1. 找出属性attr的所有取值

2. 按照attr的不同取值将self.data划分成互斥的子集Dv_{attr=a1},Dv_{attr_{a_2}}...Dv_{a_n} 简称为Dv

3. 计算|Dv|和 Ent(Dv) 

4. 计算出attr的Gain

5. 重复步骤2-4 计算出所有attr的Gain, 选出Gain最大的attr

步骤1

# 属性attr的取值大全
def attrAllvalue(D,attr):
  Allvalue = {}
  for d in D:
    Allvalue[d[attr]] = 0
  return Allvalue

步骤2:

def divDataByattr(D,attr):
  # 建立一个字典,key是attr的取值,已初始化数值为0
  re = attrAllvalue(D,attr)
  n = len(re) # 要划分出n个子数据集
  SubDataSets = {}
  for key,value in re.items():
    SubDataSets[key] = []
  for d in D:
    SubDataSets[d[attr]].append(d)
  return SubDataSets

divDataByattr获得形如: {'清晰':[数据1,数据2],'模糊':[数据3],'稍糊':[数据4]} 的字典

步骤3:

为了计算Ent(Dv)我们需要获得Dv的cal_class,下列函数计算了数据集子集Dv的max和cal_class

# 获取maxnumClass
def calMaxClass(D,class_num):
  # 统计数据集D中各类样本的数目
  cal_class = [0 for i in range(class_num)]
  max = 0
  for d in D:
    cal_class[d['Class']]+=1
    if cal_class[d['Class']] > cal_class[max]:
      max = d['Class']
  return max,cal_class

步骤4: 

确定一个attr,划分出子集的集合,遍历子集集合,然后调用Ent函数,组合计算(加粗部分就是Gain函数所做的事情)

# 信息增益
def Gain(D,attr,class_info):
  max,cal_class = calMaxClass(D,class_num)
  EntD = Ent(D,cal_class)
  SubDataSets = divDataByattr(D,attr)
  EntDv = 0
  for value,Dv in SubDataSets.items():
    # cal_class
    max,cal_class=calMaxClass(Dv,class_num)
    class_info.append([max,cal_class])
    EntDv +=len(Dv)/len(D)*Ent(Dv,cal_class)
  Gain_D_attr = EntD-EntDv
  return Gain_D_attr

补充:这里的class_info就是记录下每一个Dv的max和cal_class,用于后续传参给node 初始化

 步骤5:

def choseAttr(D,attrSet):
  compar = 0

  Gain_D = {}
  for attr in attrSet:
    SubDataSets = divDataByattr(D,attr)
    EntDv = 0
    # 补充上Dv额外的参数
    class_info = []
    Gain_D_attr = Gain(D,attr,class_info)
    # 记录数据集D用属性attr做划分时所有的已知信息,包括gain,数据子集,数据子集的class_num和max类
    Gain_D[attr] = {'gain':Gain_D_attr,'Dv':SubDataSets,'Dv_info':class_info}

  # 找gain最高的attr
  target = attrSet[0]
  for attr in attrSet:
    if Gain_D[attr]['gain'] > compar:
      compar = Gain_D[attr]['gain']
      target = attr

  return target,Gain_D

4)构造新结点

在完成3)的步骤5后,应该为选定的attr划分的子集生成新结点,新结点

# 选取最优属性
  attr,info = node.bestAttr()
  # 获取划分好的数据集
  SubDataSets = info[attr]['Dv']
  SubInfo = info[attr]['Dv_info']


  # 生成子node
  Attr = copy.deepcopy(Attr_Div)
  Attr.remove(attr)
  st = 0
  for value,subds in SubDataSets.items():
    # 因为假设是离散属性,所以新的self.attr必然要去掉已经选出的attr
    subnodeAttr = copy.deepcopy(Attr)
    # 获取已经算好的Dv的max和cal_class
    submax = SubInfo[st][0]
    subcal_class = SubInfo[st][1]
    st+=1

    # 生成新结点
    subnode = Node(subds,subnodeAttr,submax,subcal_class,class_num)
    subnode.setflag(attr)
    # 新结点还要继续加入tree进行讨论
    tree.put(subnode)

    # 父结点记录子结点的指引
    node.addsubDs(subnode,value)

4、完整的结点类代码

# 说明:
# 设数据集是[{},{},{},...,{}]的格式
# {}的格式是{'attr1':'value1,'attr2':'value2',..,'label':'class'}
# label是结点node:表明其为叶节点还是非叶节点
# Class 是当node为叶结点时,该集合的类别
#
# 类别的数量
class_num = 2
class Node():
  def __init__(self,D,A,max,cal_class,class_num):
    self.data = D
    self.attr = A
    self.class_num = class_num
    self.cal_class = cal_class

    self.max = max
    self.label = 0 # 0表示非叶结点 1表示叶结点
    self.Class = 0 # 默认一个
    self.flag = "init"

  def isSameClass(self):
    if self.cal_class[self.max] == len(self.data):
      return True
    return False

  def isNoAttr(self):
    # 属性集为空
    if self.attr == None or self.attr==[]:
      return True,[]
    # 存储取值不同的属性
    self.Attr_Div = []
    for a in self.attr:
      a1 = self.data[0][a]
      for d in self.data:
        if d[a] != a1:
          self.Attr_Div.append(a)
          break
    # 无可分的属性
    if self.Attr_Div == []:
      return True,[]

    return False,self.Attr_Div



  # 计算选取最优划分属性
  def bestAttr(self):
    # 指向划分的子结点
    self.subDs = {}
    self.bestattr,self.Gain_D = choseAttr(self.data,self.Attr_Div)
    return self.bestattr,self.Gain_D

  def setflag(self,attr):
    self.flag = attr


  # 设置subDs
  def addsubDs(self,node,value):
    self.subDs[value] = node




5、完整的构造树的过程

import copy
import queue

def do_tree(tree):
  node = tree.get()
  print(node.data)

  # 判断D中的类别是不是都是一类
  re = node.isSameClass()

  if re:
    print("当前node都属于同一类别")
    # 如果D中的数据都属于同一个类别
    node.Class = node.max
    node.label = 1 # 标记为叶子结点
    return

  # D中的数据并不属于同一个类别
  # 判断属性是否可分
  boolre,Attr_Div = node.isNoAttr()
  print(f"Attr_Div={Attr_Div}")
  # D中的属性不可再分
  if boolre == True:
    print("当前类别属性不可再分")
    node.label = 1
    node.Class = node.max
    return

  # 选取最优属性
  attr,info = node.bestAttr()
  # 获取划分好的数据集
  SubDataSets = info[attr]['Dv']
  SubInfo = info[attr]['Dv_info']


  # 生成子node
  Attr = copy.deepcopy(Attr_Div)
  Attr.remove(attr)
  st = 0
  for value,subds in SubDataSets.items():
    # 因为假设是离散属性,所以新的self.attr必然要去掉已经选出的attr
    subnodeAttr = copy.deepcopy(Attr)
    # 获取已经算好的Dv的max和cal_class
    submax = SubInfo[st][0]
    subcal_class = SubInfo[st][1]
    st+=1

    # 生成新结点
    subnode = Node(subds,subnodeAttr,submax,subcal_class,class_num)
    subnode.setflag(attr)
    # 新结点还要继续加入tree进行讨论
    tree.put(subnode)

    # 父结点记录子结点的指引
    node.addsubDs(subnode,value)

def TreeGenerate(D,A):
  # 计算初始数据集的max和cal_class
  max,cal_class = calMaxClass(D,class_num)
  # 生成根结点
  node = Node(D,A,max,cal_class,class_num)

  tree = queue.Queue()
  tree.put(node)
  while tree.empty() == False:
    do_tree(tree)
  return node



6、建树

1)准备数据集


dataSet = [
        # 1
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 2
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        # 3
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 4
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        # 5
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 6
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        # 7
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        # 8
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],

        # ----------------------------------------------------
        # 9
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        # 10
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        # 11
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        # 12
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        # 13
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        # 14
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        # 15
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        # 16
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        # 17
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜']
    ]
Attr = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']

# 硬编码类别
class_dict = {'坏瓜':0,'好瓜':1}

# 将数据合并格式
D = []
for i in range(len(dataSet)):
  d = {}
  for j in range(len(Attr)):
    d[Attr[j]] = dataSet[i][j]
  d['Class'] = class_dict[dataSet[i][-1]]
  D.append(d)

print(D)

2)建树

root = TreeGenerate(D,Attr)

7、绘图查看树的结构

1)绘图代码

只是打印每层的结点,通过分支数目得知父子结点的关系

cur = root
# 表示区分的属性

q = queue.Queue()
q.put(cur)
while q.empty()==False:
  # 这层的宽度
  width = q.qsize()
  for i in range(width):
    # 用/**/包住一个node
    print(" /*",end="")
    cur = q.get()
    if cur.label == 1:
      # 叶子结点
      print(f"叶子:{cur.Class,cur.flag,cur.data[0][cur.flag]}",end="")
    else:
      l = len(cur.subDs)
      print(f"被分类依据:{cur.flag}",end="")
      if cur.flag != "init":
        print(f"值:{cur.data[0][cur.flag]}",end="  ")
      print(f",分支:{l}个",end="")

      for key,nod in cur.subDs.items():
        q.put(nod)

    print("*/ ",end="")
  print("")


2)结果

手绘还原:

3)预测

投入一个样本,返回好瓜/坏瓜判断

def predict(data,root):
  cur = root
  while cur.label != 1:
    attr = cur.bestattr
    cur = cur.subDs[data[attr]]
  return cur.Class

for d in D:
  pd_label = predict(d,root)
  if pd_label == 0:
    print("坏瓜")
  else:
    print("好瓜")

结果打印8行好瓜,9行坏瓜

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/470561.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

线性代数基础概念和在AI中的应用

基本概念 线性代数是数学的一个分支,专注于向量、向量空间(也称为线性空间)、线性变换和矩阵的研究。这些概念在数据科学、人工智能、工程学和物理学等多个领域都有广泛应用。以下是这些基本概念的详细解释和它们在数据处理和AI中的应用。 …

Jackson 2.x 系列【2】生成器 JsonGenerator

有道无术,术尚可求,有术无道,止于术。 本系列Jackson 版本 2.17.0 源码地址:https://gitee.com/pearl-organization/study-seata-demo 文章目录 1. 前言2. 案例演示2.1 创建 JsonFactory2.2 创建 JsonGenerator2.3 写入操作2.4 查…

期权波动率是什么?怎么计算?

期权波动率 历史波动率:基于历史行情计算出来的历史波动率 我们现在站在现实时点B回顾过去,从A到B这段时间的历史行情我们是知道的,但是基于过去一段时间,标的价格的历史数据计算出来的波动率,就是历史波动率&#x…

分布式搜索引擎elasticsearch专栏二

上一篇的传送门: 分布式搜索引擎elasticsearch专栏一-CSDN博客 这一篇博文主要讲解elasticsearch的数据搜索功能。下面会分别使用DSL和RestClient实现搜索。 1.DSL查询文档 elasticsearch的查询依然是基于JSON风格的DSL来实现的。 1.1.DSL查询分类 Elasticsea…

Mac nvm install failed python: not found

报错 $>./configure --prefix/Users/xxx/.nvm/versions/node/v12.22.12 < ./configure: line 3: exec: python: not found nvm: install v12.22.12 failed!解决方法 到 App 文件夹&#xff0c;并且打开 cd /System/Applications/Utilities/ open .记得改完 Rosetta 之…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之三 简单卡通漫画风格效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之三 简单卡通漫画风格效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之三 简单卡通漫画风格效果 一、简单介绍 二、简单卡通漫画风格效果实现原理 A、边缘蒙版的处理说明&#xff1…

FPGA控制AD7606_AD7606解读

目录 一、AD7606解读二、引脚说明三、时序图 一、AD7606解读 AD7606特点&#xff1a; 8通道同步采样模拟通道数为8分辨率&#xff1a;16bit&#xff0c;即最小采样的电压为5V/(2^16) 0,00007V&#xff0c;即数字量的1就代表模拟量的0,00007V&#xff0c;2代表0,00014V有效位数…

鸿蒙开发实战:【系统服务管理部件】

简介 samgr组件是OpenHarmony的核心组件&#xff0c;提供OpenHarmony系统服务启动、注册、查询等功能。 系统架构 图 1 系统服务管理系统架构图 目录 /foundation/systemabilitymgr ├── samgr │ ├── bundle.json # 部件描述及编译文件 │ ├── frameworks …

linux之centos7vmware虚拟机的安装

目录 一、下载合适的vmware和操作系统镜像安装文件 来自引用文章的软件下载本片文章使用的软件下载 二、根据教程进行安装 三、网络配置解说 四、配置网络 编辑虚拟机网络 对VMWARE虚拟机网络进行配置 设置虚拟机网络为NAT模式 设置自定义网络为 VMnet8(NAT模式) 编辑li…

[ C++ ] STL---list的使用指南

目录 list简介 list的常用接口 构造函数 赋值运算符重载 迭代器 容量相关接口 元素访问接口 修改相关接口 头插push_front() 头删pop_front() 尾插push_back() 尾删pop_back() insert() erase() list的迭代器失效 list简介 1. list是可以以O(1)的时间复杂度在任意…

机器人现有力控技术检索

文章目录 力控技术1 基本柔顺力控1.1 直接力控1.2 间接力控1.2.1 被动和主动柔顺控制1.2.2 混合力位控制1.2.3 阻抗&导纳力控1.2.3.1 原理1.2.3.2 区别和联系1.2.3.3 工程应用 1.2.4 阻抗力控1.2.5 导纳力控 2 先进力控2.1 自适应力控2.2 鲁棒力控2.3 学习力控 3 智能力控3…

【LabVIEW FPGA入门】使用FPGA实现串行同步接口(SSI)

SSI&#xff08;串行同步接口&#xff09;是连接绝对位置传感器和控制器的广泛应用的串行接口。SSI利用控制器发出一个时钟脉冲序列&#xff0c;初始化传感器的门限输出。 传感器不断更新位置数据&#xff0c;并传送到移位寄存器中。在每一个时钟脉冲序列之间&#xff…

HTML_CSS学习:超链接、列表、表格、表格常用属性

一、超链接_唤起指定应用 1.相关代码 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><title>超链接_唤起指定应用</title> </head> <body><a href"tel:10010">电话联…

JavaScript初学心得

JavaScript JavaScript原名是livescript&#xff0c;是由美国网景开发的一种用于对网页操作的脚本语言 网页操作&#xff08;图片切换&#xff09; 脚本语言&#xff08;不需要编译 sql,html,css,javascript,由某种解释器直接可以运行&#xff09; livescript也是面向对象的…

基于Go的开源全栈测试工具RunnerGo使用详解

前言 测试人员目前广泛使用的测试工具如Postman、JMeter等&#xff0c;如Postman缺少对API的性能测试部分、JMeter则缺少测试报告和测试脚本的统一管理系统及UI测试功能&#xff0c;而RunnerGo的出现填补了传统测试工具的空缺&#xff0c;为测试人员提供了一个可以满足测试全流…

计算机毕业设计-基于Python的“哔哩哔哩视频网”视频热度分析

概要 在21世纪的今天&#xff0c;网络发展越来越快&#xff0c;网上的娱乐方式也越来越多样化&#xff0c;而如今在网上观看视频消遣时间越来越受到大众的青睐。Bilibili视频网站是现当下年轻人最受欢迎的一个视频网站。有调查显示&#xff0c;直到2019年的10月份&#xff0c;B…

基于python车辆故障管理系统的设计与实现flask-django-nodejs-php

网络发展前景无限&#xff0c;及早与网络结合&#xff0c;与信息时代同步&#xff0c;与高科技汇合&#xff0c;定会给社会各行各业的发展注入新鲜的活力。站在网络时代的前夜&#xff0c;我们清晰地听到了网络时代的宣言&#xff1a;谁掌握了网络&#xff0c;谁就掌握了未来。…

视觉信息处理和FPGA实现第5次作业-Matlab实现图像逆时针旋转90度

一、Matlab2022a安装 链接&#xff1a;https://pan.quark.cn/s/6e177bc7c11d 提取码&#xff1a;dKNN 二、Matlab使用 2.1 新建一个脚本文件&#xff08;.m文件&#xff09; 2.2 另存为到便于归档的地方 考虑到.m文件如果不是全英文路径&#xff0c;也有可能会出问题&#…

51单片机学习笔记7 串转并操作方法

51单片机学习笔记7 串转并操作方法 一、串转并操作简介二、74HC595介绍1. **功能**&#xff1a;2. **引脚**&#xff1a;3. **工作原理**&#xff1a;4. 开发板原理图&#xff08;1&#xff09;8*8 LED点阵&#xff1a;&#xff08;2&#xff09;74HC595 串转并&#xff1a; 三…

qt-pdf-viewer-library 编译过程记录

1.qtpdfviewerinitializer.h 中 类模板问题需要修改为下面代码: https://github.com/develtar/qt-pdf-viewer-library 下载代码&#xff1a; 编译出现错误 修改代码&#xff0c;如下: 2.无法触发onViewerLoaded 事件&#xff0c;就是界面无法显示PDF文件 修改下面代码&#…