Tensorflow2.0搭建网络八股扩展

目录

一、自制数据集

准备:txt和图片

制作函数

二、断点继训,存取模型

1.读取保存的模型

2.保存模型

3.正确使用

三、参数提取,把参数存入txt

参数提取

四、acc/loss可视化,查看效果

1.前提开启:获取history

2.history参数表

3.代码

五、应用程序,给图识物

代码实现

总结


一、自制数据集

准备:txt和图片

需要有图片的名称信息,并且有txt文件进行汇总,下面是txt信息

制作函数

//路径
train_path = './mnist_image_label/mnist_train_jpg_60000/'
train_txt = './mnist_image_label/mnist_train_jpg_60000.txt'
x_train_savepath = './mnist_image_label/mnist_x_train.npy'
y_train_savepath = './mnist_image_label/mnist_y_train.npy'

//函数
def generateds(path, txt):
    f = open(txt, 'r')  # 以只读形式打开txt文件
    contents = f.readlines()  # 读取文件中所有行
    f.close()  # 关闭txt文件
    x, y_ = [], []  # 建立空列表
    for content in contents:  # 逐行取出
        value = content.split()  # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表
        img_path = path + value[0]  # 拼出图片路径和文件名
        img = Image.open(img_path)  # 读入图片
        img = np.array(img.convert('L'))  # 图片变为8位宽灰度值的np.array格式
        img = img / 255.  # 数据归一化 (实现预处理)
        x.append(img)  # 归一化后的数据,贴到列表x
        y_.append(value[1])  # 标签贴到列表y_
        print('loading : ' + content)  # 打印状态提示

    x = np.array(x)  # 变为np.array格式
    y_ = np.array(y_)  # 变为np.array格式
    y_ = y_.astype(np.int64)  # 变为64位整型
    return x, y_  # 返回输入特征x,返回标签y_
//调用函数,并保存numpy格式的训练数据
    x_train, y_train = generateds(train_path, train_txt)
    x_train_save = np.reshape(x_train, (len(x_train), -1))
    np.save(x_train_savepath, x_train_save)
    np.save(y_train_savepath, y_train)

二、断点继训,存取模型

1.读取保存的模型

load_weights(路径文件名)

2.保存模型

tf.keras.callbacks.ModelCheckpoint(
filepath=路径文件名, 
save_weights_only=True/False, 
save_best_only=True/False)
history = model.fit( callbacks=[cp_callback] 

3.正确使用

# 保存训练的模型路径
checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    # 读取文件模型名字
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
# 回调函数,将训练好的模型返回history
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

三、参数提取,把参数存入txt

提取可训练参数

model.trainable_variables 返回模型中可训练的参数

参数提取

print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

四、acc/loss可视化,查看效果

1.前提开启:获取history

history=model.fit(训练集数据, 训练集标签, batch_size= , epochs=, validation_split=用作测试数据的比例,validation_data=测试集, validation_freq=测试频率)

2.history参数表

训练集loss: loss

测试集loss: val_loss

训练集准确率: sparse_categorical_accuracy

测试集准确率: val_sparse_categorical_accuracy

3.代码

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

五、应用程序,给图识物

前向传播执行应用

predict(输入特征, batch_size=整数) 返回前向传播计算结果

代码实现

from PIL import Image
import numpy as np
import tensorflow as tf
# 训练好的模型
model_save_path = './checkpoint/mnist.ckpt'

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')])
# 读取模型
model.load_weights(model_save_path)
# 测试数据
preNum = int(input("input the number of test pictures:"))
# 将输入的数据修改成我们想要的数据
for i in range(preNum):
    image_path = input("the path of test picture:")
    img = Image.open(image_path)
    img = img.resize((28, 28), Image.ANTIALIAS)
    img_arr = np.array(img.convert('L'))

    for i in range(28):
        for j in range(28):
            if img_arr[i][j] < 200:
                img_arr[i][j] = 255
            else:
                img_arr[i][j] = 0

    img_arr = img_arr / 255.0
    x_predict = img_arr[tf.newaxis, ...]
    result = model.predict(x_predict)

    pred = tf.argmax(result, axis=1)

    print('\n')
    tf.print(pred)
















总结

  • 本文主要借鉴:mooc曹健老师的《人工智能实践:Tensorflow笔记》
  • 可以利用这几种方式将八股神经网络进行优化应用到我们想要的场景中
  • 可差分的神经网络体系方便我们使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/91647.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DETR-《End-to-End Object Detection with Transformers》论文精读笔记

DETR&#xff08;基于Transformer架构的目标检测方法开山之作&#xff09; End-to-End Object Detection with Transformers 参考&#xff1a;跟着李沐学AI-DETR 论文精读【论文精读】 摘要 在摘要部分作者&#xff0c;主要说明了如下几点&#xff1a; DETR是一个端到端&am…

金蝶云星空对接打通管易云分布式调入单查询接口与其他入库单新增完结接口接口

金蝶云星空对接打通管易云分布式调入单查询接口与其他入库单新增完结接口接口 对接系统金蝶云星空 金蝶K/3Cloud&#xff08;金蝶云星空&#xff09;是移动互联网时代的新型ERP&#xff0c;是基于WEB2.0与云技术的新时代企业管理服务平台。金蝶K/3Cloud围绕着“生态、人人、体验…

深度学习12:胶囊神经网络

目录 研究动机 CNN的缺陷 逆图形法 胶囊网络优点 胶囊网络缺点 研究内容 胶囊是什么 囊间动态路由算法 整体框架 编码器 损失函数 解码器 传统CNN存在着缺陷&#xff08;下面会详细说明&#xff09;&#xff0c;如何解决CNN的不足&#xff0c;Hinton提出了一种对于图…

Blender给一个对象添加多个动画

最近在做一个类似元宇宙的项目&#xff0c;需要使用3D建模软件来给3D模型添加动画&#xff0c;3D建模软件选择Blender&#xff08;因为开源免费…&#xff09;&#xff0c;版本: V3.5 遇到的需求是同一个对象要添加多个动画&#xff0c;然后在代码里根据需要调取动画来执行。本…

Excel 打开文件提示内存或磁盘不足

Excel表格打开文件时&#xff0c;提示内存或磁盘空间不足&#xff0c;Microsoft Excel 无法再次打开或保存任何文档&#xff0c;这是很多人都会遇到的问题&#xff0c;该如何解决这个问题呢&#xff1f;如果你是用Excel表格打开某个文件时遇到提示内存或磁盘空间不足&#xff0…

java八股文面试[JVM]——垃圾回收器

jvm结构总结 常见的垃圾回收器有哪些&#xff1f; CMS&#xff08;Concurrent Mark Sweep&#xff09; 整堆收集器&#xff1a; G1 由于整个过程中耗时最长的并发标记和并发清除过程中&#xff0c;收集器线程都可以与用户线程一起工作&#xff0c;所以总体上来说&#xff0c;…

求生之路2社区服务器sourcemod安装配置搭建教程centos

求生之路2社区服务器sourcemod安装配置搭建教程centos 大家好我是艾西&#xff0c;通过上文我们已经成功搭建了求生之路2的服务端。但是这个服务端是纯净的服务端&#xff0c;就是那种最纯粹的原版。如果想要实现插件、sm开头的命令等功能&#xff0c;需要安装这个sourcemod。…

机器人制作开源方案 | 桌面级机械臂--本体说明+驱动及控制

一、本体说明 1. 机械臂整体描述 该桌面级机械臂为模块化设计&#xff0c;包含主机模块1个、转台模块1个、二级摆动模块1个、可编程示教盒1个、2种末端执行器、高清摄像头&#xff0c;以及适配器、组装工具、备用零件等。可将模块快速组合为一个带被动关节的串联3自由度机械臂…

vue 简单实验 v-model 变量和htm值双向绑定

1.代码 <script src"https://unpkg.com/vuenext" rel"external nofollow" ></script> <div id"two-way-binding"><p>{{ message }}</p><input v-model"message" /> </div> <script>…

GPT---1234

GPT:《Improving Language Understanding by Generative Pre-Training》 下载地址:https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdfhttps://cdn.openai.com/research-covers/language-unsupervised/language_understa…

初识【类和对象】

目录 1.面向过程和面向对象初步认识 2.类的引入 3.类的定义 4.类的访问限定符及封装 5.类的作用域 6.类的实例化 7.类的对象大小的计算 8.类成员函数的this指针 1.面向过程和面向对象初步认识 C语言是面向过程的&#xff0c;关注的是过程&#xff0c;分析出求解问题的…

软件测试技术分享丨使用Postman搞定各种接口token实战

现在许多项目都使用jwt来实现用户登录和数据权限&#xff0c;校验过用户的用户名和密码后&#xff0c;会向用户响应一段经过加密的token&#xff0c;在这段token中可能储存了数据权限等&#xff0c;在后期的访问中&#xff0c;需要携带这段token&#xff0c;后台解析这段token才…

css background实现四角边框

2023.8.27今天我学习了如何使用css制作一个四角边框&#xff0c;效果如下&#xff1a; .style{background: linear-gradient(#33cdfa, #33cdfa) left top,linear-gradient(#33cdfa, #33cdfa) left top,linear-gradient(#33cdfa, #33cdfa) right top,linear-gradient(#33cdfa, #…

网络安全(自学黑客)一文全解

目录 特别声明&#xff1a;&#xff08;文末附资料笔记工具&#xff09; 一、前言 二、定义 三、分类 1.白帽黑客&#xff08;White Hat Hacker&#xff09; 2.黑帽黑客&#xff08;Black Hat Hacker&#xff09; 3.灰帽黑客&#xff08;Gray Hat Hacker&#xff09; 四…

RT-Thread 时钟管理

时钟节拍 任何操作系统都需要提供一个时钟节拍&#xff0c;以供系统处理所有和时间有关的事件&#xff0c;如线程的延时、时间片的轮转调度以及定时器超时等。 RTT中&#xff0c;时钟节拍的长度可以根据RT_TICK_PER_SECOND的定义来调整。rtconfig.h配置文件中定义&#xff1a…

Android 11/12 app-lint 系统Update-API时Lint检查问题

有以下两种解决方法 1. 加SupressLint注解 这种方式你可以其他博客也有 但是要每个类和方法都加上SuppressLint 太麻烦了 我才不要这样呢 2. 添加 --api-lint-ignore-prefix 参数直接跳过代码检查 1. 打开 frameworks/base/Android.bp 文件 2. 搜索找到这个字段 metalava…

windows可视化界面管理服务器上的env文件

需求&#xff1a;在 Windows 环境中通过可视化界面编辑位于 Linux 主机上的 env 文件的情况&#xff0c;我现在环境是windows环境&#xff0c;我的env文件在linux的192.168.20.124上&#xff0c;用户是op&#xff0c;密码是op&#xff0c;文件绝对路径是/home/op/compose/env …

nginx生成自定义证书

1、创建key文件夹 [rootlocalhost centos]# mkdir key 进入key文件夹 [rootlocalhost centos]# cd key/ 2、生成私钥文件 [rootlocalhost key]# openssl genrsa -des3 -out ssl.key 4096 输入这个key文件的密码。不推荐输入&#xff0c;因为以后要给nginx使用。每次reload ngin…

WPF中的数据转换-StringFormat

WPF中的数据转换-StringFormat 前言 字符串格式化。使用该功能可以通过设置Binding.StringFormat属性对文本形式的数据进行转换——例如包含日期和数字的字符串。对于至少一半的格式化任务&#xff0c;字符串格式化是一种便捷的技术。 使用 当设置Binding.StringFormat属性…

机器学习基础之《分类算法(4)—案例:预测facebook签到位置》

一、背景 1、说明 2、数据集 row_id&#xff1a;签到行为的编码 x y&#xff1a;坐标系&#xff0c;人所在的位置 accuracy&#xff1a;定位的准确率 time&#xff1a;时间戳 place_id&#xff1a;预测用户将要签到的位置 3、数据集下载 https://www.kaggle.com/navoshta/gr…
最新文章