Keras多分类鸢尾花DEMO

完整的一个小demo:

pandas==1.2.4

numpy==1.19.2

python==3.9.2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pandas import DataFrame
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout
from sklearn import preprocessing
from sklearn.datasets import load_iris
# 映射函数iris_type: 将string的label映射至数字label
import os

# data.to_csv('data.csv',index=False)  #cvs保存文件不会保存index列
# data = pd.read_csv('data.csv',index_col=0)  #读取csv文件的时候选择不读取第一列信息
def downLoad():
    path="../httdemo/"
    iris = load_iris()
    data = iris.data #获取特征数据
    target = iris.target#获取目标数据
    data_information = DataFrame(data, columns=['bcalyx', 'scalyx', 'length', 'width']) #重新定义特征数据的列名
    data_target = DataFrame(target, columns=['target'])#目标数据列名target
    data_csv = pd.concat([data_information, data_target], axis=1) #合并特征数据和目标数据到一个DataFrame
    if not os.path.exists(path):#把DataFrame数据保存到本地,以.CVS的格式保存
        os.makedirs(path)
    filename = path + 'iris.csv'  #定义保存路径
    data_csv.to_csv(filename,index=False) #index==False表示,序号下表列不做保存

    # 本地数据保存为excel文件
    # outputfile = "iris.xls"  # 保存文件路径名
    # column = list(data['feature_names'])
    # dd = pd.DataFrame(data.data, index=range(150), columns=column)
    # dt = pd.DataFrame(data.target, index=range(150), columns=['outcome'])
    # jj = dd.join(dt, how='outer')  # 用到DataFrame的合并方法,将data.data数据与data.target数据合并
    # jj.to_excel(outputfile)  # 将数据保存到outputfile文件中


def readData(path):
    Data = pd.read_csv(path,names=['bcalyx', 'scalyx', 'length', 'width','target']) #读取本地保存的CVS数据
    Data.head(10)#展示前10
    # 变量初始化
    # 最后一列为y,其余为x
    cols = Data.shape[1]  # 获取列数 shape[0]行数 [1]列数
    X = Data.iloc[1:, 0:cols - 1].astype(float)  # 获取得到特征数据,转换为Float的格式,如果输入str,会报错的,取前cols-1列,即输入向量
    y = Data.iloc[1:, cols - 1:cols]  # 取最后一列,即目标变量
    X = np.array(X)
    y = np.array(y)
    print(y)
    return X,y

def startM():
    path = "../httdemo/iris.csv"
    X,y=readData(path)  #加载数据

    from sklearn.preprocessing import OneHotEncoder
    # 创建独热编码器对象
    encoder = OneHotEncoder() #sklearn创建热编码器对象
    # 训练独热编码器 (将目标数据进行训练)
    encoder.fit(y)

    # 转换特征向量 (将目标数据y转换为特征向量[[0,0,1][0,1,0][0,0,1]])格式
    encoded_data = encoder.transform(y).toarray()

    # shuffle = True 随机打乱后再进行分割数据
    X_train, X_test, y_train, y_test = train_test_split(X, encoded_data, test_size=0.3,shuffle=True)


    #构建网络模型
    model = Sequential()
    model.add(Dense(units=1024, activation='relu', input_dim=4))  # 输入层,1024个激活单元,激活函数为relu,输入数据维度为(4,)
    model.add(Dense(units=512, activation='relu'))  # 隐藏层,512个激活单元,激活函数为relu
    model.add(Dense(units=256, activation='relu'))  # 隐藏层,256个激活单元,激活函数为relu
    model.add(Dropout(0.1)) #丢到10%的数据
    model.add(Dense(units=3, activation='softmax'))  # 输出层,3个输出单元,激活函数为softmax)


    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    #开始训练
    model.fit(X_train, y_train, batch_size=30, epochs=32)

    #预测测试集的结果
    result = model.predict(X_test)
    yTest=np.round(result, 2)#保留俩位小数
    print(yTest)

    #测试机准确率评估
    score = model.evaluate(X_test, y_test)
    print('loss值为:', score[0])
    print('准确率为:', score[1])



if __name__=='__main__':
    startM()
    # downLoad()



 

特征数据是str,需要转换成float 

X = Data.iloc[1:, 0:cols - 1].astype(float)

 

target的数据打印:

热编码转换之后的数据:

测试集预测结果:表示的位概率值,那个数值比较大,就是哪一个类别,每一个数组表示A,B,C

[[0.01 0.28 0.71]
 [0.91 0.06 0.03]
 [0.01 0.28 0.71]
 [0.02 0.33 0.66]
 [0.01 0.28 0.71]
 [0.06 0.51 0.44]
 [0.92 0.05 0.02]
 [0.04 0.43 0.53]
 [0.02 0.38 0.6 ]
 [0.01 0.31 0.67]
 [0.03 0.42 0.55]
 [0.01 0.24 0.76]
 [0.01 0.32 0.67]
 [0.06 0.49 0.45]
 [0.01 0.25 0.74]
 [0.01 0.31 0.68]
 [0.08 0.51 0.42]
 [0.92 0.06 0.02]
 [0.88 0.09 0.04]
 [0.02 0.33 0.65]
 [0.01 0.29 0.7 ]
 [0.01 0.28 0.71]
 [0.04 0.47 0.49]
 [0.9  0.07 0.03]
 [0.91 0.06 0.03]
 [0.86 0.1  0.04]
 [0.18 0.5  0.31]
 [0.89 0.08 0.03]
 [0.91 0.07 0.03]
 [0.06 0.47 0.48]
 [0.02 0.37 0.61]
 [0.04 0.39 0.57]
 [0.87 0.09 0.04]
 [0.05 0.46 0.49]
 [0.01 0.27 0.72]
 [0.02 0.34 0.64]
 [0.05 0.45 0.5 ]
 [0.92 0.06 0.02]
 [0.09 0.53 0.38]
 [0.04 0.48 0.48]
 [0.95 0.04 0.02]
 [0.01 0.26 0.73]
 [0.   0.24 0.76]
 [0.78 0.15 0.07]
 [0.   0.21 0.79]]

运行结果:

训练完成之后保存模型,然后测试模型:

 

读取模型,开始预测:

from tensorflow.keras.models import load_model
import numpy as np
# 模型的导入
model = load_model('../httdemo/httmodel.h5')
# 对数据的预测输入分别为[花萼长,花萼宽,花瓣长,花瓣宽]
y_pred = model.predict([[2,1,5.5,2],[2.3,4.5,5.2,9]])
print(y_pred)
for i in y_pred:
    a = np.argmax(i)
    if a == 0 : print('该花为A')
    elif a == 1 : print('该花为B')
    elif a == 2 : print('该花为C')

测试结果:准确预测出来为C种类 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/276702.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Selenium库和ChromeDriver谷歌驱动最新版安装

1.安装selenium库 使用pip安装第三方库selenium,速度较慢。 pip install selenium 使用国内清华源安装第三方库selenium,速度较快。 pip install selenium -i https://pypi.tuna.tsinghua.edu.cn/simple 2.安装谷歌浏览器驱动 驱动下载链接&#x…

GoogleNetv1:Going deeper with convolutions更深的卷积神经网络

文章目录 GoogleNetv1全文翻译论文结构摘要1 引言2 相关工作3 动机和高层考虑稀疏矩阵 4 结构细节引入1x1卷积核可以减少通道数 5 GoogleNet6 训练方法7 ILSVRC 2014 分类挑战赛设置和结果8 ILSVRC 2014检测挑战赛设置和结果9 总结 论文研究背景、成果及意义论文图表 GoogleNet…

iPhone 13 Pro 更换『移植电芯』和『超容电池』体验

文章目录 考虑换电池Ⅰ 方案一Ⅱ 方案二 总结危险 Note系列地址 简 述: 首发买的iPhone 13P &#xff08;2021.09&#xff09;&#xff0c;随性使用一年出头&#xff0c;容量就暴跌 85%&#xff0c;对比朋友一起买的同款&#xff0c;还是95%。这已经基本得一天两充 >_<&a…

代码随想录刷题笔记(DAY2)

今日总结&#xff1a;今天在学 vue 做项目&#xff0c;学校还有很多作业要完成&#xff0c;熬到现在写完了三道题&#xff0c;有点太晚了&#xff0c;最后一道题的题解明天早起补上。 Day 2 01. 有序数组的平方&#xff08;No. 977&#xff09; 给你一个按 非递减顺序 排序的…

搭建简单的GPT聊天机器人

目录 第一步 进行语料库读取、文本预处理&#xff0c;完成data_utls.py 第二步 进行Seq2Seq模型的构建&#xff0c;完成Seq2Seq.py 第三步 进行模型参数设置、加载词典和数据、数据准备、GPU设置、构建优化器和损失函数&#xff0c;进行模型的训练和测试&#xff0c;完成…

使用vue3实现echarts漏斗图表以及实现echarts全屏放大效果

1.首先安装echarts 安装命令&#xff1a;npm install echarts --save 2.页面引入 echarts import * as echarts from echarts; 3.代码 <template> <div id"main" :style"{ width: 400px, height: 500px }"></div> </template> …

OSPF被动接口配置-新版(14)

目录 整体拓扑 操作步骤 1.基本配置 1.1 配置R1的IP 1.2 配置R2的IP 1.4 配置R4的IP 1.5 配置R5的IP 1.6 配置PC-1的IP地址 1.7 配置PC-2的IP地址 1.8 配置PC-3的IP地址 1.9 配置PC-4的IP地址 1.10 检测R1与PC3连通性 1.11 检测R2与PC4连通性 1.12 检测R4与PC1连…

SuperMap iServer发布的ArcGIS REST 地图服务如何通过ArcGIS API进行要素查询

作者&#xff1a;yx 前言 前面我们介绍了SuperMap iServer发布的ArcGIS REST 地图服务如何通过ArcGIS API加载&#xff0c;这里呢我们再来看看如何进行要素查询呢&#xff1f; 一、服务发布 SuperMap iServer发布的ArcGIS REST 地图服务如何通过ArcGIS API加载已经介绍如何发…

【Vue3】创建项目的方式

1. 基于 vue-cli 创建 ## 查看vue/cli版本&#xff0c;确保vue/cli版本在4.5.0以上 vue --version## 安装或者升级你的vue/cli npm install -g vue/cli## 执行创建命令 vue create vue_test本质上使用webpack&#xff0c;默认安装以下依赖&#xff1a; 2. 基于 vite 创建 官…

mac 生成 本地.ssh

输入下面命令行 ssh-keygen 默认回车得到下面的 Generating public/private rsa key pair. Enter file in which to save the key (/Users/{用户名}/.ssh/id_rsa): Enter passphrase (empty for no passphrase): Enter same passphrase again: Your identification has be…

深入浅出理解转置卷积Conv2DTranspose

一、参考资料 【keras/Tensorflow/pytorch】Conv2D和Conv2DTranspose详解 怎样通俗易懂地解释反卷积&#xff1f; 转置卷积&#xff08;Transposed Convolution&#xff09; 抽丝剥茧&#xff0c;带你理解转置卷积&#xff08;反卷积&#xff09; 二、标准卷积(Conv2D) 1. Co…

Linux中的gcc\g++使用

文章目录 gcc\g的使用预处理编译汇编链接函数库gcc选项 gcc\g的使用 这里我们需要知道gcc和g实际上是对应的c语言和c编译器&#xff0c;而其他的Java&#xff08;半解释型&#xff09;&#xff0c;PHP&#xff0c;Python等语言实际上是解释型语言&#xff0c;因此我们经常能听…

【VB测绘程序设计】案例6——Mid\Right等字符串函数的应用(附源代码)

【VB测绘程序设计】案例6——Mid\Right等字符串函数的应用(附源代码) 文章目录 前言一、程序界面二、程序说明三、程序代码四、数据演示总结前言 VB编程中内部函数主要供用户调用,主要有数学运算符函数、字符串函数、转换函数、日期与时间函数、判断函数和格式输出函数等。…

YOLOv5改进 | ODConv卷积助力极限涨点(附修改后的C2f、Bottleneck模块代码)

一、本文介绍 这篇文章给大家带来的是发表于2022年的ODConv(Omni-Dimensional Dynamic Convolution)中文名字全维度动态卷积&#xff0c;该卷积可以即插即用&#xff0c;可以直接替换网络结构中的任何一个卷积模块&#xff0c;在本文的末尾提供可以直接替换卷积模块的ODConv&a…

HLS 2017.4 导出 RTL 报错:ERROR: [IMPL 213-28] Failed to generate IP.

软件版本&#xff1a;HLS 2017.4 在使用 HLS 导出 RTL 的过程中产生如下错误&#xff1a; 参考 Xilinx 解决方案&#xff1a;https://support.xilinx.com/s/article/76960?languageen_US 问题描述 DESCRIPTION As of January 1st 2022, the export_ip command used by Vivad…

SpingBoot的项目实战--模拟电商【2.登录】

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于SpringBoot电商项目的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一.功能需求 二.代码编写 …

【STM32】程序在SRAM中运行

程序在RAM中运行 1、配置内存分配。 2、修改跳转文件 FUNC void Setup(void) { SP _RDWORD(0x20000000); PC _RDWORD(0x20000004); } LOAD RAM\Obj\Project.axf INCREMENTAL Setup(); 3、修改下载ROM地址和RAM地址&#xff1b; 中断向量表映射 中断向量表映射到SRA…

AI大模型引领未来智慧科研暨丨ChatGPT在地学、GIS、气象、农业、生态、环境等领域中的高级应用

以ChatGPT、LLaMA、Gemini、DALLE、Midjourney、Stable Diffusion、星火大模型、文心一言、千问为代表AI大语言模型带来了新一波人工智能浪潮&#xff0c;可以面向科研选题、思维导图、数据清洗、统计分析、高级编程、代码调试、算法学习、论文检索、写作、翻译、润色、文献辅助…

某后台管理系统加密参数逆向分析

前言 在我们日常的渗透中经常会遇到开局一个登录框的情况&#xff0c;弱口令爆破当然是我们的首选。但是有的网站会对账号密码等登录信息进行加密处理&#xff0c;这一步不由得阻碍了很多人的脚步。前端的加解密是比较常见的&#xff0c;无论是 web 后台还是小程序&#xff0c…

ThinkPHP6.0任意文件上传 PHPSESSION 已亲自复现

ThinkPHP6.0任意文件上传 PHPSESSION 已亲自复现 漏洞名称漏洞描述影响版本 漏洞复现环境搭建安装thinkphp6漏洞信息配置 漏洞利用 修复建议 漏洞名称 漏洞描述 2020年1月10日&#xff0c;ThinkPHP团队发布一个补丁更新&#xff0c;修复了一处由不安全的SessionId导致的任意文…
最新文章