利用pytorch两层线性网络对titanic数据集进行分类(kaggle)

利用pytorch两层线性网络对titanic数据集进行分类

最近在看pytorch的入门课程,做了一下在kaggle网站上的作业,用的是titanic数据集,因为想搭一下神经网络,所以数据加载部分简单的把训练集和测试集中有缺失值的列还有含有字符串的列去除了,加入了DataLoader模块,其实这个数据集很小,用不到,本人还没入门,小白一枚。

import torch 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets
from torchvision import transforms
import pandas as pd

class titanicDataset(Dataset):
    def __init__(self,filepath):
        xy=np.loadtxt(filepath,delimiter=',',skiprows=1,usecols=[1,2,7,8],dtype=np.float32)
        self.len=xy.shape[0]
        # print(self.len)
        self.y_data=torch.from_numpy(xy[:,[0]])
        self.x_data=torch.from_numpy(xy[:,1:])
        
    def __getitem__(self,index):#获取索引元素 
        return self.x_data[index],self.y_data[index]
    def __len__(self):
        return self.len
dataset=titanicDataset('./pytorch/dataset/titanic/train.csv')
train_loader=DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)

# print(dataset.x_data,dataset.y_data)
test_loader=DataLoader(dataset=np.loadtxt('./pytorch/dataset/titanic/test.csv',delimiter=',',skiprows=1,usecols=[1,6,7],dtype=np.float32),batch_size=32,shuffle=False,num_workers=0)
print(next(iter(test_loader)))

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        # self.linear1=torch.nn.Linear(4,3)
        self.linear2=torch.nn.Linear(3,2)
        self.linear3=torch.nn.Linear(2,1)
        self.sigmoid=torch.nn.Sigmoid()
    def forward(self,x):
        # x=self.sigmoid(self.linear1(x))
        x=self.sigmoid(self.linear2(x))
        x=self.sigmoid(self.linear3(x))
        return x
model=Model()
criterion=torch.nn.BCELoss(size_average=True)
optimizer=torch.optim.SGD(model.parameters(),lr=0.1,momentum=0.9)
for epoch in range(10000):
    acc_num=0
    for i,data in enumerate(train_loader,0):
        #1.Prepare data
        inputs,labels=data
        # print(inputs.shape[0])
        #2.Forward
        y_pred=model(inputs)
        loss=criterion(y_pred,labels)
        # print(epoch,i,loss.item())
        #3.Backward
        optimizer.zero_grad()
        loss.backward()
        #4.Update
        optimizer.step()
        y_pred_label=torch.where(y_pred>0.5,torch.tensor([1.0]),torch.tensor([0.0]))
        acc_num+=torch.eq(y_pred_label,labels).sum().item()
    # print(acc_num,len(dataset),len(train_loader.dataset))
    acc=acc_num/len(dataset)
print(acc)
# print(test_loader)
# print(test_loader.dataset.shape)
out = model(torch.tensor(test_loader.dataset))
y_pred = torch.where(out>0.5,torch.tensor([1.0]),torch.tensor([0.0]))[:,0]
print(y_pred)
print(pd.Series(y_pred))
id=pd.read_csv('./pytorch/dataset/titanic/test.csv',usecols=['PassengerId']).iloc[:,0]
# print(type(id))

pd.DataFrame({'PassengerId':id,'Survived':pd.Series(y_pred,dtype=int)}).to_csv('pred.csv',index=None)
a=pd.DataFrame([id,pd.Series(y_pred)])
print(a)
# print(y_pred[-10:])


# for x in test_loader:
#     print(x.shape)
#     out = model(x)
#     y_pred = torch.where(out>0.5,torch.tensor([1.0]),torch.tensor([0.0]))
# print(y_pred)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/591335.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

每日OJ题_贪心算法二⑤_力扣870. 优势洗牌(田忌赛马)

目录 力扣870. 优势洗牌(田忌赛马) 解析代码 力扣870. 优势洗牌(田忌赛马) 870. 优势洗牌 难度 中等 给定两个长度相等的数组 nums1 和 nums2,nums1 相对于 nums2 的优势可以用满足 nums1[i] > nums2[i] 的索引…

EDA(一)Verilog

EDA(一)Verilog Verilog是一种用于电子系统设计自动化(EDA)的硬件描述语言(HDL),主要用于设计和模拟电子系统,特别是在集成电路(IC)和印刷电路板(…

使用OpenCV绘制两幅图验证DSC和IoU以及BCELoss的计算程序

1.创作灵感 很多小伙伴在玩深度学习模型的时候,需要计算Groudtruth和predict图的dsc、IOU以及BCELoss。这两个关键的指标的程序有很多种写法,今天使用OpenCV绘制两张已知分布的图像,计算其dsc、IOU以及BCELoss。 2、图像如图所示 在一个100100的区域内,红色框范围为预测…

访问jwt生成token404解决方法

背景: 1.在部署新的阿里云环境后发现调用jwt生成token的方法404,前端除了404,台不报任何错误 在本地好用,在老的阿里云环境好用, 2.缩短生成私钥的参数报错,以为私钥太长改了tomcat参数也无效&#xff0…

启动任何类型操作系统:不需要检索 ISO 文件 | 开源日报 No.243

netbootxyz/netboot.xyz Stars: 7.7k License: Apache-2.0 netboot.xyz 是一个方便的平台,可以不需要检索 ISO 文件就能启动任何类型操作系统或实用工具磁盘。它使用 iPXE 提供用户友好的 BIOS 菜单,让您轻松选择所需的操作系统以及特定版本或可引导标志…

华为云耀云服务器开放端口

博客主页:花果山~程序猿-CSDN博客 关注我一起学习,一起进步,一起探索编程的无限可能吧!让我们一起努力,一起成长! 目录 一.华为云控制台开放端口 寻找到安全组信息 2. 添加开放的端口信息 3. 检查是否成…

【C++】对文章分词,并对词频用不同排序方法排序,比较各排序算法效率(功能全面,通俗易懂)

文章分词 1.问题描述2.需求分析3.概要设计3.1 主程序流程3.2 函数调用关系 4.主函数实现4.1 main.h4.2 main.cpp 5. 函数实现5.1 processDic函数5.2 forwardMax函数5.3 countWordFreq函数5.4 quickResult函数5.5 其它排序算法效率…

计算机视觉科普到实践

第一部分:计算机视觉基础 引言: 计算机视觉作为人工智能领域的一个重要分支,近年来取得了显著的进展。本文将带领读者深入了解计算机视觉的基础知识,并通过实践案例展示其应用。让我们一同探索这个令人着迷的领域吧!…

SpringSecurity6 学习

学习介绍 网上关于SpringSecurity的教程大部分都停留在6以前的版本 但是,SpringSecurity6.x版本后的内容进行大量的整改,网上的教程已经不能够满足 最新的版本使用。这里我查看了很多教程 发现一个宝藏课程,并且博主也出了一个关于SpringSec…

解锁AI新纪元:如何用好大语言模型?

在20世纪末和21世纪初,⼈类经历了两次信息⾰命的浪潮: 第⼀次是互联网时代的兴起,将世界各地连接在⼀起,改变了⼈们获取信息和交流的⽅式。 第⼆次则是移动互联网时代的到来,智能⼿机和移动应⽤程序的普及使⼈们可以…

Oracle 数据库全面升级为 23ai

从 11g 到 12c 再到 19c,今天,我们迎来了 23ai ! “ Oracle AI Vector Search allows documents, images, and relational data that are stored in mission-critical databases to be easily searched based on their conceptual content Ge…

平平科技工作室-Python-猜数字游戏

一.代码展示 import random print(__猜数字游戏__) print(由平平科技工作室制作) print(游戏规则:1至10随机数随便猜) print (三次没猜对游戏结束) numrandom.randint (1,10) for i in range(3):aint(input(输入你想要猜测的数字))if a>num:print (数字猜的有点大了)elif a…

MySQL-数据缓冲池(Buffer Pool)

InnoDB存储引擎以 页 为单位管理存储空间,增删改查的本质就是访问页面。为提高查询效率,DBMS会占用内存作为缓冲池,在执行SQL之前,会将磁盘上的页 缓存到内存中的 缓冲池(Buffer Pool)后执行相关SQL语句。 …

git学习指南

文章目录 一.版本控制1.认识版本控制2.版本控制功能3.集中式版本控制4.分布式版本控制 二.Git的环境安装搭建1.Git的安装2.Git配置分类3.Git配置选项 三.Git初始化本地仓库1. git init/git clone-获取Git仓库2. 本地仓库文件的划分3. git status-检测文件的状态4. git add-文件…

数据库基础--MySQL多表查询之外键约束

MySQL多表关系 一对一 顾名思义即一个对应一个的关系,例如身份证号对于每个人来说都是唯一的,即个人信息表与身份证号信息表是一对一的关系。车辆信息表与车牌信息表也是属于一对一的关系。 一对多 即一个表当中的一个字段信息,对应另一张…

黑马面试篇1

目录 一、面试准备 二、Redis篇 ​编辑1. 布隆过滤器: 2. 缓存击穿概念&解决方案 3. 双写一致 4. 持久化 1)RDB的执行原理? 2)AOF vs RDB 5. 数据过期策略 6. 数据淘汰策略 7. 分布式锁 8. Redis集群 1&#xff…

如何选择一个出色的APP内测分发平台 - 探讨小猪APP分发平台

在众多APP内测分发平台中如何选择一个出色的APP内测分发平台 - 探讨小猪APP分发平台,小猪APP分发平台(zixun.ppzhu.net)以其出色的服务和高效的推广机制成为行业佼佼者。 小猪APP分发平台的核心优势 小猪APP分发平台不仅以其用户友好的界面赢…

Coze扣子开发指南:搭建一个免费的微信公众号AI客服

运营微信公众号的自媒体,现在借助Coze扣子可以非常好用而且免费的7*24客服了,完全不需要任何编程基础,操作非常简单: 打开Coze扣子,新建一个bot,输入bot名称、功能介绍和图标: 选择大语言模型&…

论文笔记(四十五)Attention Is All You Need

Attention Is All You Need 文章概括摘要1. 介绍2. 背景3. 模型架构3.1 编码器和解码器堆栈3.2 Attention3.2.1 按比例点积Attention3.2.2 Multi-Head Attention3.2.3 注意力在模型中的应用 3.3 定位前馈网络3.4 嵌入与 Softmax3.5 位置编码 4 为什么 Self-Attention5. Trainin…

OpenWRT部署Zerotier虚拟局域网实现内网穿透

前言 细心的小伙伴肯定已经发现了:电脑上部署了Zerotier,如果路由器也部署了OpenWRT,那是否能远程访问呢? 答案是肯定的。 OpenWRT部署Zerotier有啥好处? 那好处必须多,其中的一个便是在外远程控制家里…
最新文章