[PyTorch][chapter 6][李宏毅深度学习][Logistic Regression]

前言:

         logistic回归又称logistic回归分析,是一种广义的线性回归分析模型,常用于数据挖掘,疾病自动诊断,经济预测等领域。 逻辑回归根据给定的自变量数据集来估计事件的发生概率,由于结果是一个概率,因此因变量的范围在 0 和 1 之间。 [3]例如,探讨引发疾病的危险因素,并根据危险因素预测疾病发生的概率等。

         训练样本特别小的时候用 Generative  Model会有较好的效果,大的样本使用Discriminative Model,Discriminative Model里面常用的二分类模型sigmoid ,多分类模型softmax


sigmoid 简介(Discriminative Model

    二分类模型

    1.1   模型定义

           使用了sigmoid 函数作为激活函数

           f(x)=\sigma(z)=\frac{1}{1+e^{-z}}

           z=wx+b=\sum_i w_ix_i+b

           输出 (0,1)

    1.2  损失函数

           假设有N个二分类样本

          

           \left\{\begin{matrix} \hat{y}=1\, \, \, \, ,if \, c_1 \\ \hat{y}=0\, \, \, \, \, , if \, c_2 \end{matrix}\right.

            损失函数定义为

            L(w,b)=f(x^1)f(x^2)(1-f(x^3))..

           我们要找到参数w,b使得上面概率最大

            w^{*},b^{*}=argmax_{w,b}L(w,b)

            根据交叉熵原理:我们对式子取对数。因为是求式子的最大值,可以转换成式子乘以负1,之后求最小值

            w^*,b^*=argmin_{w,b} -lnL(w,b)

            L(w,b)=-\sum_{i}^{N}-\begin{Bmatrix} \hat{y^i}ln f(x^i)+(1-\hat{y^i})ln (1-f(x^i)) \end{Bmatrix}

    1.3 梯度

           对w的求导分为两部分

          \frac{\partial lnf}{\partial f}\frac{\partial f}{\partial z}\frac{\partial z}{\partial w}=\frac{1}{f}f(1-f)x

                               =(1-f)x

           \frac{\partial ln1-f}{\partial f}\frac{\partial f}{\partial z}\frac{\partial z}{\partial w}=\frac{-1}{1-f}f(1-f)x

                                   =-fx

          合并起来

               \frac{\partial L}{\partial w}=\sum_i -\begin{Bmatrix} \hat{y^{i}}(1-f)x^{i}-(1-\hat{y^{i}})fx^{i} \end{Bmatrix}

                     =\sum_{i}-(\hat{y^{i}}-f)x^{i}

                    =\sum_{i}(f-\hat{y^{i}})x^{i}

     1.4 跟Linear 区别



二  Multi-class Classification(softmax)

      多分类模型

    2.1  模型定义

          使用了 softmax 作为激活函数 

           y=\sigma(z_i)=\frac{e^{z_i}}{\sum_{j=1}^{K}e^{z_j}}  

 

  2.2 损失函数

            使cross Entropy

             标签是一个one-hot 向量,非零项代表其类别

            L_{w,b}(\hat{y},y)=\sum_{i=1}^{K}\hat{y_i}logy_i

 

2.3 梯度

       

          y_i=softmax(z_i)=\frac{e^{z_i}}{\sum_j e^{z_j}}

         \left\{\begin{matrix} \frac{\partial y_i}{\partial z_j}=y_i*(1-y_j),j=i\\ \frac{\partial y_i}{\partial z_j}=-y_iy_j ,j\neq i \end{matrix}\right.

     损失函数为

       L=-\sum_{i}^{K}\hat{y_k}logy_k

      只跟其中的非零项有关系,假设非零项为y_i

          \frac{\partial L}{\partial z_j}=\left\{\begin{matrix} \frac{-1}{y_i}*y_i*(1-y_i)=y_i-1,j=i\\ \frac{-1}{y_i}*(-y_i y_j)=y_j-0,j\neq i \end{matrix}\right.

           因为标签值是one-hot

             \left\{\begin{matrix} \hat{y_j}=1,i=j\\ \hat{y_j}=0,i \neq j \end{matrix}\right.

             所以

              \frac{\partial L}{\partial z_j}=y_j-\hat{y_j}


三 代码

  

任务:

         给定的个人资料,预测此人的年收入是否大于50k

数据集说明:
                共有32561训练集数据,16281 测试集数据
(8140 in private test set and 8141 in public test set)

数据集情况:共14个feature 

 ?  代表不确定性
1 age 年龄: continuous.
2 workclass 工作性质: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
3 fnlwgt: continuous. *The number of people the census takers believe that observation represents.人口普查员认为这一观察结果所代表的人数。

4 education 教育水平: 
   Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.

5 education-num: continuous.

6 marital-status 婚姻状况: 
    Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.

7 occupation 工作: 
   Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.

8 relationship 关系: 
    Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.

9 race 种族: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
10 sex 性别: Female, Male.
11 capital-gain 资本收益: continuous.
12 capital-loss资本损失: continuous.
13 hours-per-week 每周工作时长: continuous.
14  native-country原国际: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.

 针对非数值型的属性,采用了one-hot 编码

分为两个文件:

dataLoader.py: csv文件读取,特征工程

lr.py:  模型训练  y=xw

          其中

                   x=[x,1]增广矩阵,

                   w =[b,w]增广矩阵

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 12 14:51:45 2023

@author: chengxf2
"""

import numpy as np
import pandas as pd
from random import shuffle
from math import floor, log


def sample(X, Y):                                 #X and Y are np.array
    randomize = np.arange(X.shape[0])
    np.random.shuffle(randomize)
    return (X[randomize], Y[randomize])


def split_valid_set(X, Y, percentage):
    m = X.shape[0]
    valid_size = int(floor(m * percentage))

    X, Y = sample(X, Y)
    X_valid, Y_valid = X[ : valid_size], Y[ : valid_size]
    X_train, Y_train = X[valid_size:], Y[valid_size:]

    return X_train, Y_train, X_valid, Y_valid

def dataProcess_Y(rawData):
    
    df_y = rawData['income']
    y = pd.DataFrame((df_y==' >50K').astype("int64"), columns=["income"])
    print('\n y',y.shape)
    return y

def dataProcess_X(rawData):

    #axis=1, 删除列 axis=0 删除 index
    if "income" in rawData.columns:
        Data = rawData.drop(["sex", 'income'], axis=1)
        #(32561, 13) 
    else:
        Data = rawData.drop(["sex"], axis=1)
    
    #读取非数字的column
    listObjectColumn = [col for col in Data.columns if Data[col].dtypes == "object"] 
    #数字的column
    listNonObjedtColumn = [x for x in list(Data) if x not in listObjectColumn] 
   

    ObjectData = Data[listObjectColumn]
    NonObjectData = Data[listNonObjedtColumn]

    #insert set into nonobject data with male = 0 and female = 1
    NonObjectData.insert(0 ,"sex", (rawData["sex"] == " Female").astype(int))
    #set every element in object rows as an attribute,相当于one-hot 编码
    ObjectData = pd.get_dummies(ObjectData)

    Data = pd.concat([NonObjectData, ObjectData], axis=1)
    Data_x = Data.astype("int64")
    # Data_y = (rawData["income"] == " <=50K").astype(np.int)
    print("\n data_x: ",Data_x.shape)
    #normalize
    Data_x = (Data_x - Data_x.mean()) / Data_x.std()

    return Data_x


def data_loader():
    
    trainData =  pd.read_csv("data/train.csv")
    testData =  pd.read_csv("data/test.csv")
    test_label = pd.read_csv("data/correct_answer.csv")
 
    # here is one more attribute in trainData
    x_train = dataProcess_X(trainData).drop(['native_country_ Holand-Netherlands'], axis=1).values
    x_test = dataProcess_X(testData).values
    
    
    y_train = dataProcess_Y(trainData).values
    y_test =  test_label['label'].values

    #x=>x[1,x]
    x_train = np.concatenate((np.ones((x_train.shape[0], 1)), x_train), axis=1)
    x_test = np.concatenate((np.ones((x_test.shape[0], 1)), x_test), axis=1)

    valid_set_percentage = 0.1
    X_train, Y_train, X_valid, Y_valid = split_valid_set(x_train, y_train, valid_set_percentage)
    
    return X_train, Y_train, X_valid, Y_valid ,x_test,y_test



import numpy as np

from numpy.linalg import inv
import matplotlib.pyplot as plt
from dataLoader import data_loader
from dataLoader import sample
import os
from math import floor, log
import pandas as pd


output_dir = "output/"





def sigmoid(z):
    res = 1 / (1.0 + np.exp(-z))
    return np.clip(res, 1e-8, (1-(1e-8)))






def valid(X, Y, w):
    a = np.dot(w,X.T)
    y = sigmoid(a)
    y_ = np.around(y)
    result = (np.squeeze(Y) == y_)
    print('Valid acc = %f' % (float(result.sum()) / result.shape[0]))
    return y_

def train(X_train, Y_train):
  
    n= len(X_train[0])
    print("\n n ",n)
    w = np.zeros(n)

    l_rate = 0.001
    batch_size = 32
    m = len(X_train)
    step_num = int(floor(m / batch_size))
    epoch_num = 30
    list_cost = []
    total_loss = 0.0
    
    
    for epoch in range(1, epoch_num):
        total_loss = 0.0
        X_train, Y_train = sample(X_train, Y_train)

        for idx in range(1, step_num):
            X = X_train[idx*batch_size:(idx+1)*batch_size]
            Y = Y_train[idx*batch_size:(idx+1)*batch_size]

            s_grad = np.zeros(len(X[0]))


            z = np.dot(X, w)
            y = sigmoid(z)
            #squeeze 即把shape中为1的维度去掉
            loss = y - np.squeeze(Y)
            cross_entropy = -1 * (np.dot(np.squeeze(Y.T), np.log(y)) + np.dot((1 - np.squeeze(Y.T)), np.log(1 - y)))/ len(Y)
            total_loss += cross_entropy

            grad = np.sum( X * (y-np.squeeze(Y)).reshape((batch_size, 1)), axis=0)
            # grad = np.dot(X.T, loss)
            w = w - l_rate * grad
            
        #print("\n epoch :%d, total_loss: %7.3f"%(epoch, total_loss/batch_size))

       

        list_cost.append(total_loss)

    # valid(X_valid, Y_valid, w)
    plt.plot(np.arange(len(list_cost)), list_cost)
    plt.title("Train Process")
    plt.xlabel("epoch_num")
    plt.ylabel("Cost Function (Cross Entropy)")
    plt.savefig(os.path.join(os.path.dirname(output_dir), "TrainProcess"))
    plt.show()

    return w

if __name__ == "__main__":
   
    X_train, Y_train, X_valid, Y_valid,x_test,y_test  = data_loader()
    w_train = train(X_train, Y_train)
    valid(X_valid, Y_valid, w_train)

    print("\n x_test",x_test.shape, "\t y_test ",y_test.shape,"\t w",w_train.shape)

    valid(x_test, y_test, w_train)

    df = pd.DataFrame({"id": np.arange(1, 16282), "label": y_test})
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    df.to_csv(os.path.join(output_dir + 'lr_output.csv'), sep='\t', index=False)

https://github.com/maplezzz/ML2017S_Hung-yi-Lee_HW
动手学深度学习——softmax回归(原理解释+代码详解)-CSDN博客

https://www.cnblogs.com/hider/p/15431858.html 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/248773.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

linux内核使用ppm图片开机

什么是ppm图片 PPM&#xff08;Portable Pixmap&#xff09;是一种用于存储图像的文件格式。PPM图像文件以二进制或ASCII文本形式存储&#xff0c;并且是一种简单的、可移植的图像格式。PPM格式最初由Jef Poskanzer于1986年创建&#xff0c;并经过了多次扩展和修改。 PPM图像…

Axure的动态面板的使用

目录 1.什么是动态面板&#xff1f; 2.使用动态面板 ​编辑 轮播图 erp的登录系统 erp侧边栏 1.什么是动态面板&#xff1f; 动态面板是Axure的高级交互元件&#xff0c;由不同的状态面板组成&#xff0c;是我们制作交互过程中运用频率最高的元件&#xff0c;很多交互效果需…

【Docker】实战:nginx、redis

▒ 目录 ▒ &#x1f6eb; 导读开发环境 1️⃣ Nginx 拉取 Nginx 镜像nginx.conf启动 Nginx访问 Nginx 2️⃣ redis拉取 Redis 镜像启动 Redis 容器测试 Redis &#x1f4d6; 参考资料 &#x1f6eb; 导读 开发环境 版本号描述文章日期2023-12-15操作系统Win10 - 22H222621.2…

网站服务器/域名/备案到底有什么关联?

​  在一个网站的组成中&#xff0c;网站服务器、域名、备案这几个要素是要被常提到的。在谈及三者关联之前&#xff0c;我们先了解下三者的各自概念。 域名&#xff1a;它是网站的唯一标识符&#xff0c;通俗理解来说就是用户在浏览器地址栏中输入的网址。一般来说&#xff…

MySQL作为服务端的配置过程与实际案例

MySQL是一款流行的关系型数据库管理系统&#xff0c;广泛应用于各种业务场景中。作为服务端&#xff0c;MySQL的配置过程对于数据库的性能、安全性和稳定性至关重要。本文将详细介绍MySQL作为服务端的配置过程&#xff0c;并通过一个实际案例进行举例说明。 一、MySQL服务端配…

新手HTML和CSS的常见知识点

​​​​ 目录 1.HTML标题标签&#xff08;到&#xff09;用于定义网页中的标题&#xff0c;并按照重要性递减排列。例如&#xff1a; 2.HTML段落标签&#xff08;&#xff09;用于定义网页中的段落。例如&#xff1a; 3.HTML链接标签&#xff08;&#xff09;用于创建链接…

首次使用 git 配置 github,gitee 密钥

gitee 和 github 密钥配置 1. 检查配置信息 使用命令 git config --global --list 检查邮箱是否一致 不一致可以使用如下命令进行设置 git config --global user.name "name" git config --global user.email "emailqq.com" 2. 生成 SSH 密钥 # 为 G…

2023年19款最佳3D打印软件

https://wenku.baidu.com/view/c0551497cf7931b765ce0508763231126edb77e3.html 2023年19款最佳3D打印软件 有免费也有付费&#xff01;十款入门至专业级的3D打印软件推荐 【云图创智】23种最好用的3D打印软件&#xff0c;常用的3D打印软件都在这里了

HTML基础标签

但实际上无论声明为中文还是英文都可以写&#xff0c;中文/英文 主要是浏览器在进行调用翻译功能的时候&#xff0c;会按照声明的语言来进行翻译。 标签语义&#xff1a; 标签的属性一般都是在第一个标签中定义该标签效果所拥有的属性。 即标签的作用是什么 <>标签功能…

LeetCode(61)删除链表的倒数第 N 个结点【链表】【中等】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; 删除链表的倒数第 N 个结点 1.题目 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], n 2 输出&#xff1a;[1,2,3,5]示例…

关于“Python”的核心知识点整理大全22

目录 ​编辑 9.4.2 在一个模块中存储多个类 虽然同一个模块中的类之间应存在某种相关性&#xff0c;但可根据需要在一个模块中存储任意数量的 类。类Battery和ElectricCar都可帮助模拟汽车&#xff0c;因此下面将它们都加入模块car.py中&#xff1a; car.py my_electric_car…

Leetcode sql50基础题最后的4题啦

算是结束了这个阶段了&#xff0c;之后的怎么学习mysql的方向还没确定&#xff0c;但是不能断掉&#xff0c;而且路是边走边想出来的。我无语了写完了我点进去看详情都不让&#xff0c;还得重新开启计划&#xff0c;那我之前的题解不都没有了&#xff01;&#xff01; 1.第二高…

10个国内外素材网站,提供免费 Photoshop 素材下载资源

即时设计 被很多人视为免费的PS素材网站——即时设计提供了资源广场版块&#xff0c;方便用户查找材料。对于提供的PS材料&#xff0c;即时设计也做了详细的分类工作&#xff0c;用户可以根据不同的使用标签快速找到相应的PS材料。 进入资源广场&#xff0c;在搜索框中输入要…

jmeter,动态参数之随机数、随机日期

通过函数助手&#xff0c;执行以下配置&#xff1a; 执行后的结果树&#xff1a; 数据库中也成功添加了数据&#xff0c;对应字段是随机值&#xff1a;

C#Winform菜鸟驿站管理系统-快递信息管理界面多条件查询实现方法

1&#xff0c;具体的页面设计如下&#xff0c; 2&#xff0c; 关于下拉框数据填充实现&#xff0c;站点选择代码实现如下&#xff0c;因为站点加载在很多界面需要用到&#xff0c;所以把加载站点的方法独立出来如下&#xff1b; /// <summary>/// 加载站点下拉框/// <…

案例064:基于微信小程序的考研论坛设计与实现

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;SSM JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder X 小程序…

cat EOF快速创建一个文件,并写入内容

在linux系统中&#xff0c;如果你有这个需求 vi一个文件 /etc/docker/daemon.json 在这个文件中写入内容 { "registry-mirrors": ["https://iw3lcsa3.mirror.aliyuncs.com","http://10.1.8.151:8082"],"insecure-registries":[&quo…

本地项目添加到gitlab命令操作

gitlab上面创建一个跟项目名同名的文件夹 创建文件夹&#xff0c;填写信息 添加readme文档&#xff0c;先保存下创建的文件夹 回到项目&#xff0c;复制项目的git 地址 然后进入到本地项目的文件夹&#xff0c;如d:/workspace/spring-demo&#xff0c;右键打开git bash弹框 命令…

【深度学习】机器学习概述(二)优化算法之梯度下降法(批量BGD、随机SGD、小批量)

​ 文章目录 一、基本概念二、机器学习的三要素1. 模型a. 线性模型b. 非线性模型 2. 学习准则a. 损失函数b. 风险最小化准则 3. 优化机器学习问题转化成为一个最优化问题a. 参数与超参数b. 梯度下降法梯度下降法的迭代公式具体的参数更新公式学习率的选择 c. 随机梯度下降批量…

提升英语学习效率,尽在Eudic欧路词典 for Mac

Eudic欧路词典 for Mac是一款专为英语学习者打造的强大工具。无论您是初学者还是高级学习者&#xff0c;这款词典都能满足您的需求。 首先&#xff0c;Eudic欧路词典 for Mac具备丰富的词库&#xff0c;涵盖了各个领域的单词和释义。您可以轻松查询并学习单词的意思、用法和例…