​ 最近工作中用到因果推理相关知识,发现自己基础太弱,读读论文填充一下。首先介绍一篇经典的论文。这篇文章提出了一种利用领域适应和深度神经网络表示学习的框架方法来进行反事实结果推理

  1. 公式化反事实推理问题为领域适应问题,更具体一点,转化为协变量转变问题。
  2. 利用深度神经网络表示,线性模型和变量选择来进行反事实推理。
  3. 利用reweighting samples的方法使treatment和control groups distribution balanced

Abstract

Observational studies are rising in importance due to the widespread accumulation of data in fields such as healthcare, education, employment and ecology. We consider the task of answering counterfactual questions such as, “Would this patient have lower blood sugar had she received a different medication?”. We propose a new algorithmic framework for counterfactual inference which brings together ideas from domain adaptation and representation learning. In addition to a theoretical justification, we perform an empirical comparison with previous approaches to causal inference from observational data. Our deep learning algorithm significantly outperforms the previous state-of-the-art.

​ 本文主要解决的问题类似于:对于一个患有低血糖的病人来说,接受不同的治疗方式是否是有用的?为了解决这个问题,本文基于域自适应和表征学习提出了一个用于进行反事实推理的算法框架。
​ 这里的域自适应指的是我们在一种样本集上训练出来的模型应该在该数据集的各领域内都是适应的,有点拗口,举个例子,我们希望预测对一个病人使用某种治疗手段是否会让它对病情好转,那么不管这个病人是高中学历还是大学学历,模型的预测结果都是一样的。但是我们实际的训练样本中,大多数的病人都是高中学历,那么我们使用这个模型去预测大学学历的病人时,效果就不一定会好。

​ 本文将域自适应和反事实的推理进行了结合,主要方法是提出了一种正则公式让处于不同干预的人群的representations的分布更加接近,后面会详细介绍是怎么做的,本文的主要contributions有以下两点:

  • 我们证明了如何将反事实推理问题看作一个域自适应问题,也可以看作一个特殊的协变量转移问题。

  • 我们推导出一系列反事实推理的表征算法。一种是基于线性模型和变量选择,另一种是基于表征的深度学习结构。

    最后,我们的方法要优于基于重加权的样本采样方式,证明了拉进treatment组和control组特征表示对于反事实的推理是有益的。

Problem Setup

​ 这一节作者介绍了一些在因果推理领域常用的定义及标识,这里简单介绍一下。

  • $\tau$:我们希望去预估的一些潜在的干预或者行为。比如对于低血糖的病人来说,我们有两种治疗手段。$t\in\tau$
  • $\chi$:用户侧的特征,代表这些病人的背景(contexts),$x\in\chi$
  • $y$:施加某种干预后的结果,例如给病人采用A手段进行治疗后血糖增加了多少。$Y_t(x)\in y$

​ 我们希望计算的是individualized treatment effect(ITE):$Y_1(x)-Y_0(x)$,但是在现实情况下,我们只能得到对用户施加两种treatment的其中一种的outcome。所以通常我们会计算拥有相似分布$x$的人群计算他们的average treatment effect(ATE):$ATE=E[Y_1(x)−Y_0(x)]$。一个更常用的方法是通过模型去预估ITE。

​ 其中$yi^F(x)$是我们观察到的事实结果,$y_i^{CF}$是模型预估出来的反事实结果。这种方式不同于普通的深度学习建模,会有如下问题:假设我们观察到的样本为集合$\hat{P}^F=\left{x_i,t_i\right}^n{i=1}$,在计算ITE时我们需要计算这群样本另一个treatment下的结果$\hat{P}^{CF}=\left{xi,1-t_i\right}^n{i=1}$。我们定义$P^F$是$\hat{P}^F$经过模型抽取特征后到分布factual distribution,同理$P^{CF}$是$\hat{P}^{CF}$经过模型抽取特征后到分布counterfactual distribution。显然这两个分布是不相等的,但如果这个分布的差异较大的话,模型可能会侧重学习到这两个分布的不同对于最终outcome的影响,而不是我们的treatment对于outcome的影响。在深度学习中这被叫做 covariate shift,同时这也是一种域自适应问题。举个列子,在进行随机实验时,不同的人群获得不同treatment的概率应该是随机的,但是在实际的观察研究中,往往有很多因素影响着treament的发放,这就导致了$t$与$x$不独立。所以上述的两个分布的差异会比较大。

Balancing counterfactual regression

​ 我们的方法如下图所示,首先我们对于Context学习一个表达$\Phi:\chi \to \mathbb{R}^d$,这里我们可以使用深度网络也可以使用类似对样本进行加权筛选的方式。同时我去学习一个函数$h$,这个函数通过预估对当前的用户表征施加不同的treatment后的结果,会同时权衡三个目标:1)使得我们能够观察到的正常样本的预测误差尽可能小。2)使得对于反事实样本的预测的误差尽可能小,这里对于反事实样本,由于我们并没有它的真实label,我们可以通过一些方式来构建一个伪label(后面会讲)。3)这个函数还要对不同干预下的representation进行一个平衡。
​ 针对第一点,可以通过使用最小化error和一些正则方式来确保这个误差尽可能小,第二点,我们利用的是最近邻的方法,来构造反事实,即$y{i:t_i=0}^{CF}=y{NN_i:i\neq0}^{CF}$其中$NN_i$表示最近邻的邻居。本质是在模拟样本的反事实,有点类似于matching的方法。平衡好不同干预下的representation。第三点,我们通过最小化discrepancy distance这是域自适应的距离度量公式,后文会有详细介绍。

​ 直观地说,减少treatment人群和control人群之间差异的表示防止模型在试图从事实领域推广到反事实领域时使用数据的“unreliable”方面。例如,如果在我们的样本中,几乎没有男性接受过药物A,那么推断男性对药物A的反应很容易出错,因此可能需要更谨慎地使用性别特征。

​ 我们希望最小化的损失函数如下:

​ 上式中,$\alpha,\gamma>0$是用来控制我们对于两个分布的相似度的限制力度,其中$\text{disc}$是在后文定义的一种度量分布之间距离的函数。$j(i)\in \text{arg min}_{j\in s.t. t_j=1-t_i}\text{d}(x_i,x_j)$ 即$j$是在样本空间里和$i$最相近的一个样本,注意每个样本只能被选择一次。

Method

​ 一个简单的方法去做样本间分布的平衡是只选择那些在treatment组和control组分布相似的特征,但是,那些被舍弃的不相似的特征往往是对预测结果很重要的,直接把他们忽略会影响预测的准确性。一种解决方式是限制这些不相似特征对于预测结果的影响,我们基于这种方式,学习一组稀疏的权重参数,通过平衡模型的预测能力已经样本之间的相似度来确定特征的重要性。作者介绍了两种方法,1)线性模型。2)深度模型。这里主要对深度模型做一些介绍

​ 如上图所示,作者在标准的前馈神经网络做了一些改进,首先,$dr$层通过学习一个表达$\Phi(x)$将输入特征$x$进行映射,这一这里使用的样本是没有带上treatment特征的,同时,这一层的输出会被用来计算上文提到的$\text{disc}\mathcal{H}(\hat{P}\Phi^F,\hat{P}\Phi^{CF})$,之后的$d_o$层会将treatment特征作为额外特征加入。同时,注意到treatment的还有一根线连接到$\text{disc}$的计算上,这是因为我们需要这个标签来指导当前样本是属于treatment组还是control组。
​ 为什么要在网络中间将treatment作为额外特征加入而不是在网络一开始就一起训练呢,个人理解这是因为treatment往往很稀疏,为了避免treatment特征在网络前半段被大量其他特征淹没,所以在网络中间部分加入。同时,这样生成的表征$\Phi$也完全代表了与treatment无关的特征。

​ 最后作者通过一系列公式证明了上文中提到的分布差异对反事实的预测结果是有影响的,这里就不重复证明了,感兴趣可以去看原论文。最后的最后作者也指出当我们遇到多treatment情况时,可能需要更适合的算法。