因果推断14--DRNet论文和代码学习

目录

论文介绍

代码实现

DRNet

ReadMe

因果森林


论文介绍

因果推断3--DRNet(个人笔记)_万三豹的博客-CSDN博客

摘要:估计个体在不同程度的治疗暴露下的潜在反应,对于医疗保健、经济学和公共政策等几个重要领域具有很高的实际意义。然而,现有的从观察数据中估计反事实结果的学习方法要么专注于估计平均剂量-反应曲线,要么局限于只有两种没有相关剂量参数的治疗方法。在这里,我们提出了一种新的机器学习方法,用于学习反事实表示,用于使用神经网络估计具有连续剂量参数的任意数量治疗的单个剂量-反应曲线。在已建立的潜在结果框架的基础上,我们引入了性能指标、模型选择标准、模型架构和用于估计单个剂量反应曲线的开放基准。我们的实验表明,在这项工作中开发的方法在估计个体剂量反应方面设置了一个新的最先进的方法。

代码实现

GitHub - d909b/drnet: 💉📈 Dose response networks (DRNets) are a method for learning to estimate individual dose-response curves for multiple parametric treatments from observational data using neural networks.

DRNet

def get_method_name_map():
        return {
            'knn': KNearestNeighbours,
            'ols1': OrdinaryLeastSquares1,
            'ols2': OrdinaryLeastSquares2,
            'cf': CausalForest,
            'rf': RandomForest,
            'bart': BayesianAdditiveRegressionTrees,
            'nn': TFNeuralNetwork,
            'nn+': NeuralNetwork,
            'xgb': GradientBoostedTrees,
            'gp': GaussianProcess,
            'psm': PSM,
            'psmpbm': PSM_PBM,
            'ganite': GANITE,
            'gps': GPS,
        }

    def _build_graph(self, input_dim, num_units,
                     num_representation_layers, num_regression_layers, weight_initialisation_std,
                     reweight_sample=False, loss_function="l2",
                     imbalance_penalty_function="wass", rbf_sigma=0.1,
                     wass_lambda=10.0, wass_iterations=10, wass_bpt=True):
        """
        Constructs a TensorFlow subgraph for counterfactual regression.
        Sets the following member variables (to TF nodes):
        self.output         The output prediction "y"
        self.tot_loss       The total objective to minimize
        self.imb_loss       The imbalance term of the objective
        self.pred_loss      The prediction term of the objective
        self.weights_in     The input/representation layer weights
        self.weights_out    The output/post-representation layer weights
        self.weights_pred   The (linear) prediction layer weights
        self.h_rep          The layer of the penalized representation
        """
        ''' Initialize input placeholders '''
        self.x = tf.placeholder("float", shape=[None, input_dim], name='x')
        self.t = tf.placeholder("float", shape=[None, 1], name='t')
        self.y_ = tf.placeholder("float", shape=[None, 1], name='y_')

        ''' Parameter placeholders '''
        self.imbalance_loss_weight = tf.placeholder("float", name='r_alpha')
        self.l2_weight = tf.placeholder("float", name='r_lambda')
        self.dropout_representation = tf.placeholder("float", name='dropout_in')
        self.dropout_regression = tf.placeholder("float", name='dropout_out')
        self.p_t = tf.placeholder("float", name='p_treated')

        dim_input = input_dim
        dim_in = num_units
        dim_out = num_units

        weights_in, biases_in = [], []

        if num_representation_layers == 0:
            dim_in = dim_input
        if num_regression_layers == 0:
            dim_out = dim_in

        ''' Construct input/representation layers '''
        h_rep, weights_in, biases_in = build_mlp(self.x, num_representation_layers, dim_in,
                          self.dropout_representation, self.nonlinearity,
                          weight_initialisation_std=weight_initialisation_std)

        # Normalize representation.
        h_rep_norm = h_rep / safe_sqrt(tf.reduce_sum(tf.square(h_rep), axis=1, keep_dims=True))

        ''' Construct ouput layers '''
        y, y_concat, weights_out, weights_pred = self._build_output_graph(h_rep_norm, self.t, dim_in, dim_out,
                                                                          self.dropout_regression,
                                                                          num_regression_layers,
                                                                          weight_initialisation_std)

        ''' Compute sample reweighting '''
        if reweight_sample:
            w_t = self.t/(2*self.p_t)
            w_c = (1-self.t)/(2*(1-self.p_t))
            sample_weight = w_t + w_c
        else:
            sample_weight = 1.0

        self.sample_weight = sample_weight

        ''' Construct factual loss function '''
        if self.with_pehe_loss:
            risk = pred_error = tf.reduce_mean(sample_weight*tf.square(self.y_ - y)) + \
                                pehe_loss(self.y_, y_concat, self.t, self.x, self.num_treatments) / 10.
        elif loss_function == 'log':
            y = 0.995/(1.0+tf.exp(-y)) + 0.0025
            res = self.y_*tf.log(y) + (1.0-self.y_)*tf.log(1.0-y)

            risk = -tf.reduce_mean(sample_weight*res)
            pred_error = -tf.reduce_mean(res)
        else:
            risk = tf.reduce_mean(sample_weight*tf.square(self.y_ - y))
            pred_error = tf.sqrt(tf.reduce_mean(tf.square(self.y_ - y)))

        ''' Regularization '''
        for i in range(0, num_representation_layers):
            self.weight_decay_loss += tf.nn.l2_loss(weights_in[i])

        p_ipm = 0.5

        if self.imbalance_loss_weight_param == 0.0:
            imb_dist = tf.reduce_mean(self.t)
            imb_error = 0
        elif imbalance_penalty_function == 'mmd2_rbf':
            imb_dist = mmd2_rbf(h_rep_norm, self.t, p_ipm, rbf_sigma)
            imb_error = self.imbalance_loss_weight * imb_dist
        elif imbalance_penalty_function == 'mmd2_lin':
            imb_dist = mmd2_lin(h_rep_norm, self.t, p_ipm)
            imb_error = self.imbalance_loss_weight * mmd2_lin(h_rep_norm, self.t, p_ipm)
        elif imbalance_penalty_function == 'mmd_rbf':
            imb_dist = tf.abs(mmd2_rbf(h_rep_norm, self.t, p_ipm, rbf_sigma))
            imb_error = safe_sqrt(tf.square(self.imbalance_loss_weight) * imb_dist)
        elif imbalance_penalty_function == 'mmd_lin':
            imb_dist = mmd2_lin(h_rep_norm, self.t, p_ipm)
            imb_error = safe_sqrt(tf.square(self.imbalance_loss_weight) * imb_dist)
        elif imbalance_penalty_function == 'wass':
            imb_dist, imb_mat = wasserstein(h_rep_norm, self.t, p_ipm, sq=True,
                                            its=wass_iterations, lam=wass_lambda, backpropT=wass_bpt)
            imb_error = self.imbalance_loss_weight * imb_dist
            self.imb_mat = imb_mat  # FOR DEBUG
        elif imbalance_penalty_function == 'wass2':
            imb_dist, imb_mat = wasserstein(h_rep_norm, self.t, p_ipm, sq=True,
                                            its=wass_iterations, lam=wass_lambda, backpropT=wass_bpt)
            imb_error = self.imbalance_loss_weight * imb_dist
            self.imb_mat = imb_mat  # FOR DEBUG
        else:
            imb_dist = lindisc(h_rep_norm, p_ipm, self.t)
            imb_error = self.imbalance_loss_weight * imb_dist

        ''' Total error '''
        tot_error = risk
        if self.imbalance_loss_weight_param != 0.0:
            tot_error = tot_error + imb_error
        tot_error = tot_error + self.l2_weight*self.weight_decay_loss

        self.output = y
        self.tot_loss = tot_error
        self.imb_loss = imb_error
        self.imb_dist = imb_dist
        self.pred_loss = pred_error
        self.weights_in = weights_in
        self.weights_out = weights_out
        self.weights_pred = weights_pred
        self.h_rep = h_rep
        self.h_rep_norm = h_rep_norm

ReadMe

剂量反应网络(DRNets)是一种用于学习使用神经网络从观测数据估计多参数治疗的个体剂量反应曲线的方法。该存储库包含用于评估DRNets的源代码以及用于评估个体治疗效果的最相关的现有最先进方法(有关结果,请参阅我们的手稿)。为了便于将来的研究,源代码被设计为易于使用(1)新方法和(2)新的基准数据集进行扩展。

作者:Patrick Schwab,苏黎世ETHpatrick.schwab@hest.ethz.ch苏黎世联邦理工学院Lorenz Linhardtllorenz@student.ethz.ch,Stefan Bauer,MPI for Intelligent Systemsstefan.bauer@tuebingen.mpg.de,苏黎世联邦理工学院Joachim M.Buhmannjbuhmann@inf.ethz.ch苏黎世联邦理工学院Walter Karlenwalter.karlen@hest.ethz.ch

许可证:MIT,请参阅License.txt

引用

如果您在工作中引用或使用我们的方法、代码或结果,请考虑引用:

@在过程中{schwab2020-剂量反应,

title={{学习用于估计个体剂量响应曲线的反事实表示}},

作者={施瓦布、帕特里克和林哈特、洛伦兹和鲍尔、斯特凡和布曼、约阿希姆·M和卡伦、沃尔特},

booktitle={{AAAI人工智能会议}},

年={2020}

}

用法:

可运行的脚本位于drnet/apps/子目录中。

drnet/apps/main.py是运行实验的主要可运行脚本。

drnet/apps/parameters.py中描述了可运行脚本的可用命令行参数

您可以通过将drnet/models/baseline/baseline.py子类化,将新的基线方法添加到评估中

有关如何实现自己的基线方法的示例,请参见drnet/models/baselines/neural_network.py。

通过向drnet/apps/main.py中的get_method_name_map方法添加新条目,可以从命令行注册新方法以供使用

您可以通过实现基准接口来添加新的基准,有关如何将自己的基准添加到基准套件的示例,请参见drnet/models/benchmarks。

通过向drnet/apps/evaluate.py中的get_benchmark_name_map方法添加新条目,可以从命令行注册新的基准测试以供使用

要求和相关性

该项目设计用于Python 2.7。我们不能保证,也没有测试过与Python 3的兼容性。

要运行TCGA和News基准,需要下载包含这些基准的原始数据样本的SQLite数据库(News.db和TCGA.db)。

您可以使用以下链接下载原始数据:tcga.db和news.db。

请注意,您需要大约10GB的可用磁盘空间来存储数据库。

将数据库文件保存到/数据目录,以便与下面的分步指南兼容或相应地调整命令。

要运行MVICU基准测试,您需要访问MIMIC-III数据库,由于数据集的敏感性,这需要经过审批过程。

注意,您需要大约75GB的可用磁盘空间来存储带有索引的MIMIC-III数据库。

访问数据集并将MIMIC-III数据加载到SQLite数据库(保存为例如/your/path/to/mimic3.db)后,可以使用drnet/apps/load_db_icu.py脚本将MVICU基准数据从MIMIC-IIII数据库提取到中的单独数据库中/数据文件夹,通过运行:

python drnet/apps/load_db_icu.py/your/path/to/mimic3.db./data

一旦建立,基准数据库将使用大约43MB的磁盘空间。

要运行BART、因果森林和GPS,并再现需要安装R的数字。看见https://www.r-project.org/安装说明。

要运行BART,需要安装R包rJava和bartMachine。看见https://github.com/kapelner/bartMachine安装说明。注意,rJava也需要一个工作的Java安装。

要运行因果森林,需要安装R包grf。看见https://github.com/grf-labs/grf安装说明。

要运行GPS,您需要安装R包causaldrf,例如在R-shell中运行install.packages(“causaldrv”)。

要复制论文的数字,您需要安装R-packagelatex2exp。看见https://cran.r-project.org/web/packages/latex2exp/vignettes/using-latex2exp.html安装说明。

有关python依赖关系,请参阅setup.py。您可以使用pipinstall。安装drnet包及其python依赖项。请注意,如果您的系统上没有正常的R安装,rpy2的安装将失败(请参见上文)。

再现实验

确保您具备上面列出的必要要求,包括/与此文件相关的数据目录以及所需的数据库(参见上文)。

您可以使用脚本drnet/apps/run_all_experiments.py获取main.py使用的精确参数,以重现论文中的实验结果。

drnet/apps/run_all_experiments.py脚本打印。

因果森林

https://grf-labs.github.io/grf/articles/grf.h

return self.grf.causal_forest(x,
                                      FloatVector([float(yy) for yy in y]),
                                      FloatVector([float(tt) for tt in t]), seed=909)

W
The treatment assignment (must be a binary or real numeric vector with no NAs).

W

治疗分配(必须是没有NA的二进制或实数矢量)。

class CausalForest(PickleableMixin, Baseline):
    def __init__(self):
        super(CausalForest, self).__init__()
        self.bart = None

    def install_grf(self):
        from rpy2.robjects.packages import importr
        import rpy2.robjects.packages as rpackages
        from rpy2.robjects.vectors import StrVector
        import rpy2.robjects as robjects

        # robjects.r.options(download_file_method='curl')

        # package_names = ["grf"]
        # utils = rpackages.importr('utils')
        # utils.chooseCRANmirror(ind=0)
        # utils.chooseCRANmirror(ind=0)
        #
        # names_to_install = [x for x in package_names if not rpackages.isinstalled(x)]
        # if len(names_to_install) > 0:
        #     utils.install_packages(StrVector(names_to_install))

        return importr("grf")

    def _build(self, **kwargs):
        from rpy2.robjects import numpy2ri
        from sklearn import linear_model
        grf = self.install_grf()

        self.grf = grf
        numpy2ri.activate()
        num_treatments = kwargs["num_treatments"]
        self.with_exposure = kwargs["with_exposure"]

        return [linear_model.Ridge(alpha=.5)] +\
               [None for _ in range(num_treatments)]

    def predict_for_model(self, model, x):
        base_y = Baseline.predict_for_model(self, self.model[0], x)
        if model == self.model[0]:
            return base_y
        else:
            import rpy2.robjects as robjects
            r = robjects.r
            result = r.predict(model, self.preprocess(x))
            y = np.array(result[0])
            return y[:, -1] + base_y

    def fit_grf_model(self, x, t, y):
        from rpy2.robjects.vectors import StrVector, FactorVector, FloatVector, IntVector
        return self.grf.causal_forest(x,
                                      FloatVector([float(yy) for yy in y]),
                                      FloatVector([float(tt) for tt in t]), seed=909)

参考:

  1. 《因果学习周刊》第8期:因果反事实预测 - 知乎
  2. 因果效应估计:用数据和模型指导决策 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/10248.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GFD563A101 3BHE046836R0101

GFD563A101 3BHE046836R0101 ABB 7寸触摸屏 PP874K 3BSE069273R1 控制面板 原装进口 ABB 7寸触摸屏 PP874M 3BSE069279R1 黑色坚固 船用认证面板 ABB AC 800M PM865K01 处理器单元 3BSE031151R6 PLC库存 ABB AC 800M控制器模块 PM861AK01 3BSE018157R1 PM861A ABB AC 800PEC PC…

Kafka系统整理 一

一、Kafka 概述 1.1 定义 Kafka传统定义:Kafka是一个分布式的基于发布/订阅模式的消息队列 (Message Queue), 主要应用于大数据实时处理领域。 kafka最新定义:kafka是一个开源的分布式事件流平台(Event Streaming Platform), 被…

实验二 图像空间域频率域滤波

一.实验目的: 1. 模板运算是空间域图象增强的方法,也叫模板卷积。 (1)平滑:平滑的目的是模糊和消除噪声。平滑是用低通滤波器来完成,在空域中全是正值。 (2)锐化&…

Centos7安装部署Jenkins

Jenkins简介: Jenkins只是一个平台,真正运作的都是插件。这就是jenkins流行的原因,因为jenkins什么插件都有 Hudson是Jenkins的前身,是基于Java开发的一种持续集成工具,用于监控程序重复的工作,Hudson后来被…

【如何使用Arduino控制WS2812B可单独寻址的LED】

【如何使用Arduino控制WS2812B可单独寻址的LED】 1. 概述2. WS2812B 发光二极管的工作原理3. Arduino 和 WS2812B LED 示例3.1 例 13.2 例 24. 使用 WS2812B LED 的交互式 LED 咖啡桌4.1 原理图4.2 源代码在本教程中,我们将学习如何使用 Arduino 控制可单独寻址的 RGB LED 或 …

教育大数据总体解决方案(3)

为区县教育局提供标准制定、流程把控、实施监控、决策支持等服务,支持在全市统一的评价指标体系基础上,为各个区县提供个性化定制功能,各县能够在市统一评价指标体系内任意调整、增加二三级评价指标项,并可以调整对应指标项的分数…

SpringBoot 介绍

1.简介 SpringBoot最开始基于Spring4.0设计,是由Pivotal公司提供的框架。 SpringBoot发展史: 2003年Rod Johnson成立Interface公司,产品是SpringFramework2004年,Spring框架开源,公司改名为Spring Source2008年&…

我的面试八股(Java集合篇)

Java集合 两个抽象接口派生:一个是Collection接口,存放单一元素;一个是Map接口存放键值对。 Vector为什么是线程安全 简单,因为官方在可能涉及到线程不安全的操作都进行了synchronized操作,就自身源码就给你加了把锁。 Vector…

走进Vue【三】vue-router详解

目录🌟前言🌟路由🌟什么是前端路由?🌟前端路由优点缺点🌟vue-router🌟安装🌟路由初体验1.路由组件router-linkrouter-view2.步骤1. 定义路由组件2. 定义路由3. 创建 router 实例4. 挂…

【Spark】RDD缓存机制

1. RDD缓存机制是什么? 把RDD的数据缓存起来,其他job可以从缓存中获取RDD数据而无需重复加工。 2. 如何对RDD进行缓存? 有两种方式,分别调用RDD的两个方法:persist 或 cache。 注意:调用这两个方法后并不…

处理用户输入

shell脚本编程系列 传递参数 向shell脚本传递数据的最简单方法是使用命令行参数 比如 ./add 10 30读取参数 bash shell会将所有的命令行参数都指派给位置参数的特殊变量。其中$0对应脚本名、$1是第一个参数、$2是第二个参数,依次类推,直到$9 #!/bin/b…

【星界探索——通信卫星】铱星:从“星光坠落”到“涅槃重生”,万字长文分析铱星卫星系统市场

【星界探索——通信卫星】铱星:从“星光坠落”到“涅槃重生”一、铱星简介二、铱星系统设计思路2.1 工作原理2.2 铱星布局三、铱星优势四、发展历程五、第一代铱星公司的破产原因分析5.1 终端和资费价格高昂,市场用户群体小5.2 财务危机5.3 市场分析不足…

深入讲解Linux内核中常用的数据结构和算法

Linux内核代码中广泛使用了数据结构和算法,其中最常用的两个是链表和红黑树。 链表 Linux内核代码大量使用了链表这种数据结构。链表是在解决数组不能动态扩展这个缺陷而产生的一种数据结构。链表所包含的元素可以动态创建并插入和删除。链表的每个元素都是离散存…

每日一问-ChapGPT-20230409-中医基础-四诊之望诊

文章目录每日一问-ChapGPT系列起因每日一问-ChapGPT-20230409-中医基础-四诊之望诊中医中的望闻问切介绍,以及对应的名家望诊的具体细节望诊拓展当日总结每日一问-ChapGPT系列起因 近来看了新闻,看了各种媒体,抖音,官媒&#xff…

Python 小型项目大全 46~50

# 四十六、百万骰子投掷统计模拟器 原文:http://inventwithpython.com/bigbookpython/project46.html 当你掷出两个六面骰子时,有 17%的机会掷出 7。这比掷出 2 的几率好得多:只有 3%。这是因为只有一种掷骰子的组合给你 2(当两个…

ptuning v2 的 chatglm垂直领域训练记录

thunlp chatglm 6B是一款基于海量高质量中英文语料训练的面向文本对话场景的语言模型。 THUDM/ChatGLM-6B: ChatGLM-6B:开源双语对话语言模型 | An Open Bilingual Dialogue Language Model (github.com) 国内的一位大佬把chatglm ptuning 的训练改成了多层多卡并…

期刊论文图片代码复现【由图片还原代码】(OriginMatlab)

👨‍🎓个人主页:研学社的博客💥💥💞💞欢迎来到本博客❤️❤️💥💥🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密…

【Golang入门】简介与基本语法学习

下面是一篇关于Golang新手入门的博客,记录一下。(如果有语言基础基本可以1小时入门) 一、什么是Golang? Golang(又称Go)是一种由谷歌公司开发的编程语言。它是一种静态类型、编译型、并发型语言&#xff0…

【JLink仿真器】盗版检测、连接故障、检测不到芯片问题

【JLink仿真器】盗版检测、连接故障、检测不到芯片问题一、问题描述二、解决方法1、降低驱动(解决非法问题以及连接故障)2、SWD引脚被锁(解决检测不到芯片)三、说明一、问题描述 盗版检测:the connected probe appear…

【Linux】网络原理

本篇博客让我们一起来了解一下网络的基本原理 1.网络发展背景 关于网络发展的历史背景这种东西就不多bb了,网上很容易就能找到参考资料,我的专业性欠缺,文章参考意义也不大。这里只做简单说明。 网络发展经过了如下几个模式 独立模式&…
最新文章