Abstract
Causal inference plays a critical role in decision-making about whether to provide treatment to individuals across various domains, such as education, medicine, and e-commerce. One of the fundamental tasks in causal inference is to estimate the individual treatment effect (ITE), which represents the effect of a treatment on an individual outcome. Recently, many studies have focused on estimating ITE from graph data, taking into account not only the covariates of units but also the connections among them. In such a case, the outcome of a unit can be affected not only by its own covariates and treatment but also by those of its neighbors, which is referred to as interference. Existing methods have utilized graph neural networks (GNNs) to capture interference and achieved improvements in estimating ITE on graph data. However, these methods are not computationally efficient and therefore cannot be applied to large graph data. To overcome this problem, we propose a novel method that reduces redundant computation in interference modeling while maintaining the prediction performance of ITE estimation. Our key idea is to model the propagation of interference by aggregating the information of neighbors before training and preserving the aggregated results for training our networks. We conduct intensive experiments on graph data consisting of up to a hundred thousand units and millions of edges. We show that the proposed method achieves superior or comparable performance to existing GNN-based methods in ITE estimation, while being executed much faster than those methods.
1 Introduction
Estimating treatment effects is beneficial for decision-making across various domains, such as e-commerce (Nabi et al., 2022), education (Raudenbush & Schwartz, 2020), and medicine (Schnitzer, 2022). For instance, estimating treatment effects fosters the understanding of whether a promotion activity encourages customers to engage with the promoted item. Estimating the individual treatment effect (ITE), which is the effect of a treatment on a specific individual, has been a popular task because the ITE can assess the effect of a promotion activity on a particular customer and thus enables us to determine whether we should conduct the promotion for that customer.
In this study, we estimate ITE from large observational graph data, which include records of covariates of units, connections among units, treatment assignments, and observed outcomes. As neighboring units communicate with each other, the outcome of an individual can be influenced by the treatments assigned to its neighboring units. In causal inference, this phenomenon is known as interference (Rakesh et al., 2018). When interference propagates over an entire graph, it is referred to as networked interference (Ma & Tresp, 2021). We show an example of networked interference in Fig. 1: an advertisement assigned to a unit can propagate among units and affect another unit through multi-hop connections, whereas a unit isolated in the graph does not receive interference. We need to properly model networked interference; otherwise, we end up with inaccurate ITE estimation.
Networked interference received by the black unit, the aggregation mechanism of GCN, and the aggregation mechanism of GraphSAGE with neighbor sampling. Blue units cause interference to the black unit, and their information is captured for the black unit. White units cause interference to the black unit, but their information is not captured for the black unit (Color figure online)
A prior study (Ma & Tresp, 2021) proposed methods to capture networked interference (an example is shown in Fig. 2a) by utilizing graph neural networks (GNNs), such as graph convolutional networks (GCNs) (Welling & Kipf, 2016) and GraphSAGE (Hamilton et al., 2017). The key idea of these methods is to use layer-by-layer GNNs to model the multi-hop propagation of interference, as shown in Fig. 2b. However, the GNN-based methods are computationally inefficient to train on a large input graph and require a huge amount of GPU memory. Although the original GraphSAGE attempts to reduce the heavy computation of GNNs with its neighbor sampling method (Hamilton et al., 2017), naively applying it degrades the ITE estimation quality because part of the interference information is lost at every training iteration, as shown in Fig. 2c. This phenomenon can be seen in our experiments (Sect. 5.4). Capturing full networked interference while reducing the computation of information aggregation is therefore a challenge for estimating ITE from a large graph.
To overcome this challenge, we propose the Scalable Individual Treatment Effect estimator on graphs (SITE), which estimates ITE under networked interference comparably to the GNN-based ITE estimators (Ma & Tresp, 2021) while greatly reducing the training cost. We show the architecture of the proposed method in Fig. 3. We address the following sub-challenges: (I) how to reduce the heavy computation of information aggregation; (II) how to reduce the GPU memory cost of information aggregation; and (III) how to capture the full networked interference while overcoming (I) and (II). To this end, we first aggregate the information of other units on the graph only once before training, using the CPU, and reserve the aggregated results for training. This jointly reduces the computation and the GPU memory cost of information aggregation. Sampling-based methods, such as neighbor sampling (Hamilton et al., 2017), reduce computation at each training iteration by limiting the number of units involved in the calculations, and thus often fail to capture the full networked interference, as shown in Fig. 2c. By contrast, our pre-aggregation reduces the repeated computation of interference modeling during training without losing information about networked interference. Next, we train a map function that takes the aggregated results as input and generates representations of units, which are supposed to capture networked interference. This map function is realized by a simple neural network instead of a GNN and is more efficient to train. Results of extensive experiments (presented in Sect. 5) on five public datasets show that the proposed method achieves performance in ITE estimation under networked interference close to that of the GNN-based ITE estimator (Ma & Tresp, 2021), while greatly reducing the training cost.
The contributions of this study are summarized as follows: (1) we introduce a new problem, i.e., ITE estimation under networked interference on a large graph; (2) we propose a new framework to capture networked interference and estimate ITE from large graph data; and (3) we conduct extensive experiments to verify the efficacy of the proposed method for ITE estimation and its efficiency in training.
2 Related work
Learning to predict ITE from observational data has received considerable attention in machine learning and data mining (Yao et al., 2021). However, most of the previous approaches have assumed that there is no interference among units. For example, the balancing neural network (BNN) (Johansson et al., 2016) and counterfactual regression (CFR) (Shalit et al., 2017) adopt this assumption and employ a two-step process that uses multiple MLPs to learn balanced representations of covariates and to predict outcomes. Unfortunately, this assumption hardly holds for real-world data, as units in real-world data are usually connected and often propagate information (Wang et al., 2013; Li et al., 2015; He & McAuley, 2016), which results in interference. For a more thorough literature review on this line of work, refer to the survey paper (Yao et al., 2021).
In the modeling of group-level interference (Hudgens & Halloran, 2008; Liu & Hudgens, 2014; Tchetgen & VanderWeele, 2012), units are partitioned into several subgroups, and interference is assumed to exist only within subgroups. Group-level interference may not hold in some real-world situations, as interference can occur among units in different subgroups. Follow-up studies (Aronow & Samii, 2017; Forastiere et al., 2021; Viviano, 2019) assume that units can receive interference from their close neighbors. Nevertheless, interference can propagate widely over a graph, not only to close neighbors. To take this into account, networked interference is assumed, which considers interference propagation over an entire graph of units. Ma and Tresp (2021) applied the propagation mechanism of GNNs (Welling & Kipf, 2016) to capture networked interference. Moreover, Ma et al. (2022) proposed a hypergraph-GCN-based method for dealing with high-order interference on hypergraphs. Lin et al. (2023) proposed a heterogeneous-GCN-based method for modeling interference propagation on heterogeneous graphs. However, GNN-based methods are inefficient to train when graphs are large, because we have to stack a sufficiently large number of layers to capture networked interference over the entire graph.
Various sampling-based techniques have been proposed to accelerate GNN training: neighbor sampling methods that sample neighbors for every node (Hamilton et al., 2017; Bojchevski et al., 2020; Zeng et al., 2020), a layer-wise sampling method that samples neighbors for every GNN layer (Chen et al., 2018), and graph sampling methods that sample subgraphs for training (Feng et al., 2022; Shi et al., 2023). These sampling-based methods may generally fail to capture networked interference over the entire graph. By contrast, the simplifying graph convolutional network (SGC) (Wu et al., 2019) takes a different approach from sampling and simplifies the GCN computation by removing the nonlinear functions in the message propagation of a GCN. The advantage of SGC is that it reduces the computational cost during training without losing information about neighbors. We extend SGC to ITE estimation under networked interference and confounders (described in Sect. 3); this extension of SGC is a baseline in our experiments. We give a comparison of existing methods and the proposed method in Table 1.
3 Problem setting
In this study, we aim to estimate ITE from observational graphs, which contain covariates of units, connections among units, treatment assignments, and observed outcomes. Herein, we use \({\varvec{x}}_i \in {\mathbb {R}}^d\) to denote the covariates of the unit i, \(t_i \in \{0,1\}\) to denote the treatment assigned to the unit i, \(y_{i} \in {\mathbb {R}}\) to denote the observed or factual outcome under the received treatment of the unit i, and N to denote the number of units. Let \({\varvec{X}}\) be the covariates of all units, \({\varvec{T}}\) be all treatment assignments, and \({\varvec{Y}}\) be all observed outcomes. We use non-bold, italicized, and capitalized letters (e.g., \(X_i\)) to denote random variables, and letters with the subscript \(-i\) to denote all other units except i, such as \({\varvec{X}}_{-i}\). Importantly, we consider large-scale graphs, so N can be large, such as 100,000. A unit with \(t_i=1\) is treated, and a unit with \(t_i=0\) is controlled.
Graphs. Graph data usually include not only information about treatments and covariates of units but also information about connections among units. We use \({\varvec{A}} \in \{0,1\}^{N \times N}\) to denote the adjacency matrix of a directed graph. If there is an edge to a unit i from a unit j, \(A_{ij}\) = 1; otherwise, \(A_{ij} = 0\).
ITE estimation under networked interference. An observational graph can be denoted by \(({\varvec{X}},{\varvec{T}},{\varvec{Y}},{\varvec{A}})\). We assume that there exists networked interference among units in a graph. In this case, the outcome of a unit is influenced not only by its own treatment and covariates but also by those of its immediate and multi-hop neighbors (Ma & Tresp, 2021). To formalize networked interference, we use \({\varvec{s}}_i\) to denote a summary representation of \({\varvec{X}}_{-i}\) and \({\varvec{T}}_{-i}\) on a graph \({\varvec{A}}\). Following the previous studies (Ma et al., 2022; Ma & Tresp, 2021; Lin et al., 2023), we assume that networked interference can be captured by some aggregation function \(\textrm{AGG}(\cdot )\) that merges information of the other units on the graph as \({\varvec{s}}_i=\text {AGG}({\varvec{T}}_{-i},{\varvec{X}}_{-i},{\varvec{A}})\). Here, \({\varvec{s}}_i\in {\mathbb {R}}^{d'}\) is supposed to capture the networked interference received by the unit i on the graph \({\varvec{A}}\), where \(d'\) depends on the dimension of the parameters used for \(\textrm{AGG}(\cdot )\). In this case, the potential outcomes of the unit i under networked interference \({\varvec{s}}_i\) and treatment values \(t_i=1\) and \(t_i = 0\) are denoted by \(y^1_i({\varvec{s}}_i)\) and \(y^0_i({\varvec{s}}_i)\), respectively. Then, we adopt the definition of ITE under networked interference provided by Ma and Tresp (2021): \(\tau _i = y^1_i({\varvec{s}}_i) - y^0_i({\varvec{s}}_i).\)
Confounders. Confounders are the part of the covariates that affects the treatment assignment and the outcome jointly, which introduces confounding bias into ITE estimation (Yao et al., 2021). In observational studies, the existence of confounders is a typical issue (Shalit et al., 2017). For instance, consider a scenario where a customer is treated with an advertisement. Young customers may have more chances to see the advertisement, and at the same time, young customers often prefer shopping more than elderly customers. In this case, age is a confounder. If such confounders are not properly addressed, ITE estimation will be biased. We address confounders in our proposal independently from capturing networked interference.
Identifiability of ITE. Subsequently, we show that ITE is identifiable under a set of assumptions. First, we extend the neighbor interference assumption (Forastiere et al., 2021) to networked interference, as follows:
Assumption 1
For \(\forall i\), \(\forall {\varvec{T}}_{-i},{\varvec{T}}'_{-i}\), \(\forall {\varvec{X}}_{-i},{\varvec{X}}'_{-i}\), and \(\forall {\varvec{A}},{\varvec{A}}'\): when \({\varvec{s}}_i=\mathrm{{AGG}}({\varvec{T}}_{-i},{\varvec{X}}_{-i},{\varvec{A}}) = \mathrm{{AGG}}({\varvec{T}}'_{-i},{\varvec{X}}'_{-i},{\varvec{A}}') = {\varvec{s}}'_i\), \(Y^t_i(S_i={\varvec{s}}_i) = Y^t_i(S_i={\varvec{s}}'_i)\) holds.
This assumption allows for the existence of networked interference, i.e., the outcome of a unit can receive interference from other units on the graph through \({\varvec{s}}_i\), which is generated by the aggregation function \(\text {AGG}(\cdot )\). Next, we extend the consistency assumption (Forastiere et al., 2021) to networked interference, as follows:
Assumption 2
\(Y_i = Y_i^{t_i}(S_i = {\varvec{s}}_i)\) on the graph \({\varvec{A}}\) for the unit i with \(t_i\) and \({\varvec{s}}_i\).
This assumption means that the observed outcome is equal to the potential outcome under \(t_i\) and \({\varvec{s}}_i\). Lastly, for simplicity, we adopt the following unconfoundedness assumption widely used in existing studies on interference (Lin et al., 2023; Ma et al., 2022):
Assumption 3
For any unit i, given the covariates, the treatment assignment and output of the aggregation function are independent of potential outcomes, i.e., \(T_i,S_i \perp \!\!\! Y_i^1({\varvec{s}}_i),Y_i^0({\varvec{s}}_i) \vert X_i\).
This assumption says that we can observe every feature that describes the difference between the treatment and the control group (Guo et al., 2020). Under the above assumptions, the expected potential outcomes \(Y_i^{t}({\varvec{s}}_i)\) (\(t=1\) or \(t=0\)) are identifiable, which can be seen as follows: \({\mathbb {E}}[Y_i^{t}({\varvec{s}}_i) \mid X_i] = {\mathbb {E}}[Y_i^{t}({\varvec{s}}_i) \mid T_i=t, S_i={\varvec{s}}_i, X_i] = {\mathbb {E}}[Y_i \mid T_i=t, S_i={\varvec{s}}_i, X_i],\) where the first equality follows from Assumption 3 and the second from Assumption 2, so that the right-hand side consists only of observable quantities.
Based on the above proof, the identifiability of ITE can be straightforwardly derived. This suggests that once we accurately aggregate \({\varvec{X}}_{-i}\) and \({\varvec{T}}_{-i}\) on the graph \({\varvec{A}}\) into \({\varvec{s}}_i\), we can identify the potential outcomes \(Y_i^{1}({\varvec{s}}_i)\) and \(Y_i^{0}({\varvec{s}}_i)\), and hence the ITE.
4 Proposed method
We estimate ITE under networked interference from large graph data \(({\varvec{X}},{\varvec{T}},{\varvec{Y}},{\varvec{A}})\). To this end, we propose SITE, whose architecture is shown in Fig. 3. It contains four components. The first component aggregates neighbor information before training to avoid redundant computation in interference modeling (Sect. 4.1). The second component \(\psi\) transforms the aggregated results of the first component into summary representations, which are supposed to capture the networked interference received by units (Sect. 4.1). The third component \(\phi\) learns representations of covariates to address confounding bias (Sect. 4.2). Finally, the last component predicts potential outcomes using the covariate and summary representations (Sect. 4.3). The code for SITE is available at https://github.com/LINXF208/SITE/tree/main. In the following subsections, we present the details of these components.
4.1 Neighbor information aggregation and mapping functions
To model networked interference on a large graph, it is important to precisely and efficiently aggregate information for units from their one-hop and multi-hop neighbors. The core idea of our proposed architecture lies in reducing the repeated computation of aggregation by aggregating information for every unit before training. This pre-processing makes training significantly faster than full training of GNNs, without sacrificing the information of networked interference.
Inspired by Ma and Tresp (2021), networked interference can be captured using a multi-layered GCN (Welling & Kipf, 2016). Specifically, let \(\psi\) be a map function, \({\varvec{s}}^{(l)}_i\) be the summary representation of the unit i at the l-th layer, \({\varvec{S}}^{(l)}\) be the summary representations of all units at the l-th layer, and \({\varvec{s}}^{(0)}_i\) be the concatenation of \({\varvec{x}}_i\) and \(t_i\). We add a self-loop for every unit to retain its own information during propagation. Let \({\varvec{W}}^{(l)}_{\textrm{GNN}}\) be the parameter matrix of the l-th GCN layer, and \({\varvec{L}}=\hat{{\varvec{D}}}^{-\frac{1}{2}} \hat{{\varvec{A}}}\hat{{\varvec{D}}}^{-\frac{1}{2}}\) be the normalized adjacency matrix of \({\varvec{A}}\) (Welling & Kipf, 2016). Here, \(\hat{{\varvec{A}}}\) is the adjacency matrix with the self-loops counted, and \(\hat{{\varvec{D}}}\) is the degree matrix of \(\hat{{\varvec{A}}}\). The propagation of a single GCN layer is defined as follows: \({\varvec{S}}^{(l)}=\sigma ({\varvec{L}}{\varvec{S}}^{(l-1)}{\varvec{W}}^{(l)}_{\textrm{GNN}}),\)
where \(\sigma\) is a non-linear activation function, such as ReLU. Here, a single GCN layer can capture the interference from only the immediate or one-hop neighbors of a unit. Networked interference can be captured by stacking multiple GCN layers. For example, the output of the first layer \({\varvec{S}}^{(1)}\) captures the interference of one-hop neighbors; by using \({\varvec{S}}^{(1)}\) as input, the second layer can capture the interference from two-hop neighbors. A K-layered GCN is given as follows: \({\varvec{S}}^{(K)}=\sigma ({\varvec{L}}\,\sigma (\cdots \sigma ({\varvec{L}}{\varvec{S}}^{(0)}{\varvec{W}}^{(1)}_{\textrm{GNN}})\cdots ){\varvec{W}}^{(K)}_{\textrm{GNN}}).\)
However, repetitive computations of the layer-by-layer interference propagation severely slow down training and consume a huge amount of GPU memory when the graph is large.
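To make this cost concrete, the following minimal PyTorch sketch (our illustration; the class and argument names are ours, not from the released code) shows a single GCN layer. Note that the sparse product with \({\varvec{L}}\) spans all N nodes and is recomputed at every layer in every training iteration.

```python
import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    """One GCN layer: S^(l) = sigma(L S^(l-1) W^(l))."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)  # W^(l)

    def forward(self, L_norm, S):
        # L_norm: sparse normalized adjacency over *all* N nodes.
        # This graph-wide sparse product is repeated K times per
        # forward pass and at every iteration, which is exactly the
        # computation SITE's pre-aggregation removes from training.
        return torch.relu(torch.sparse.mm(L_norm, self.linear(S)))
```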
Our key idea is to reduce the repetitive computations of aggregation to efficiently obtain the summary \({\varvec{s}}_i\). To this end, we simplify the computation of every GCN layer by removing the non-linear activation functions and learnable parameters, which is inspired by Wu et al. (2019). In this case, the aggregations of multiple GCN layers can be simplified as \({\varvec{L}}^K{\varvec{S}}^{(0)}\). Let \({\varvec{G}}={\varvec{L}}^K{\varvec{S}}^{(0)}\) be the summary vectors and \({\varvec{g}}_i\) be the summary vector of the unit i. To reduce the GPU memory cost of aggregation, we calculate the summary vectors \({\varvec{G}}\) only once before training by using CPUs and reserve the calculated \({\varvec{G}}\) for training. This simultaneously reduces the GPU memory cost and the computational cost of aggregation. Here, \({\varvec{G}}\) contains information on the networked interference received by units. We then train a feed-forward network \(\psi\) that takes \({\varvec{G}}\) as input and is expected to generate the summary representations \({\varvec{S}}\), defined as follows: \({\varvec{S}}=\psi ({\varvec{G}})=\sigma (\sigma (\cdots \sigma ({\varvec{G}}{\varvec{W}}^{(1)})\cdots ){\varvec{W}}^{(L)}),\)
where L is the number of feed-forward layers and \({\varvec{W}}^{(l)}\) is the parameter matrix of the l-th layer of \(\psi\). By using such a \(\psi\), we can model networked interference while achieving efficient training without losing information about neighbors.
K and L are two different hyperparameters. A larger K means that a unit can receive interference from more distant neighbors, while increasing L enhances the ability of the model to extract information from \({\varvec{G}}\).
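The pre-aggregation can be implemented with sparse matrix products on the CPU, as in the following sketch (a minimal illustration under our own naming, not the released implementation):

```python
import numpy as np
import scipy.sparse as sp

def precompute_summaries(adj, X, T, K):
    """Compute G = L^K S^(0) once, on the CPU, before training.
    `adj` is the sparse adjacency matrix A, `X` the covariates,
    `T` the treatment vector, and `K` the number of hops."""
    N = adj.shape[0]
    # S^(0): concatenate covariates and treatments.
    S0 = np.hstack([X, T.reshape(-1, 1)])
    # Self-loops let each unit keep its own information.
    A_hat = adj + sp.eye(N, format="csr")
    # Symmetric normalization: L = D^{-1/2} A_hat D^{-1/2}.
    deg = np.asarray(A_hat.sum(axis=1)).flatten()
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(deg))
    L = d_inv_sqrt @ A_hat @ d_inv_sqrt
    # No nonlinearities or learnable weights are involved, so K
    # sparse products suffice and the result can be stored and reused.
    G = S0
    for _ in range(K):
        G = L @ G
    return G
```

During training, the rows of \({\varvec{G}}\) corresponding to a mini-batch are simply looked up and fed to \(\psi\), so no graph operation appears in the training loop.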
Representation balancing. A bias might exist in interference (Jiang et al., 2023), which can result in additional bias in ITE estimation. Consider again the advertising example in Sect. 3: young customers typically have more young friends, who usually have higher rates of both purchasing and seeing advertisements, whereas elderly customers have more elderly friends. To mitigate this imbalance in interference, we add a discrepancy penalty to our loss function. A common choice of discrepancy is the maximum mean discrepancy (MMD) (Shalit et al., 2017). Specifically, we minimize the MMD between the distributions of the summary representations in the control and treated groups. Let \(\text {MMD}_{\psi }\) denote the estimated MMD of the summary representations in the different treatment groups.
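For illustration, the discrepancy penalty can be estimated from a mini-batch as below; the RBF kernel and its bandwidth are our assumptions for this sketch, since MMD estimators such as the one in CFR (Shalit et al., 2017) also admit other kernels:

```python
import torch

def mmd_rbf(rep_treated, rep_control, sigma=1.0):
    """Squared MMD between two sets of representations with an RBF
    kernel; used here as MMD_psi on the summary representations."""
    def gram(a, b):
        # Pairwise squared Euclidean distances -> RBF kernel matrix.
        d2 = torch.cdist(a, b) ** 2
        return torch.exp(-d2 / (2 * sigma ** 2))
    return (gram(rep_treated, rep_treated).mean()
            + gram(rep_control, rep_control).mean()
            - 2 * gram(rep_treated, rep_control).mean())
```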
4.2 Covariate representation learner
To mitigate confounding bias, a straightforward solution is to map the covariates into a representation space where the distributions of the covariate representations are well aligned (Johansson et al., 2016; Shalit et al., 2017). To this end, we minimize the MMD between the distributions of the covariate representations in the control and treated groups, following Shalit et al. (2017).
To be specific, let \(\phi\) be a feed-forward network that takes the covariates \({\varvec{X}}\) as input and outputs the covariate representations \({\varvec{U}}=\phi ({\varvec{X}})\), and let \({\varvec{u}}_i\) be the representation of the covariates of the unit i. Similar to the representation balancing for interference, we minimize the MMD between the covariate representation distributions in the control and treated groups to mitigate confounding bias (Shalit et al., 2017). Let \(\text {MMD}_{\phi }\) be the estimated MMD of the covariate representations in the different treatment groups.
4.3 Outcome predictors and ITE estimator
Given the summary representations \({\varvec{S}}\) (Sect. 4.1) and the covariate representations \({\varvec{U}}\) (Sect. 4.2), we use feed-forward neural networks to parameterize predictors \(h_0\) and \(h_1\) for predicting the potential outcomes under \(t=0\) and \(t=1\), respectively. These networks are learned by minimizing the mean squared error (MSE) between the outputs of the predictors \({\hat{y}}_i=h_{t_i}({\varvec{u}}_i,{\varvec{s}}_i)\) and the observed outcomes \(y_{i}\): \({\mathcal {L}}_{h}=\frac{1}{N}\sum _{i=1}^{N}\left( {\hat{y}}_i-y_{i}\right) ^2.\)
Let \(\Theta\) be all learnable parameters of the proposed SITE. We add L2 regularization to our loss function to avoid over-fitting. The loss function \({\mathcal {L}}\) of the proposed SITE consists of \({\mathcal {L}}_{h}\), the sum of \(\textrm{MMD}_{\psi }\) and \(\textrm{MMD}_{\phi }\), and the L2 regularization (denoted as \(\Vert \Theta \Vert ^2\)), where the terms are traded off by the hyperparameters \(\alpha\) and \(\lambda\), as follows: \({\mathcal {L}}={\mathcal {L}}_{h}+\alpha \left( \textrm{MMD}_{\psi }+\textrm{MMD}_{\phi }\right) +\lambda \Vert \Theta \Vert ^2.\)
With the outcome predictors, we can estimate the ITE under networked interference as \({\hat{\tau }}_i=h_1({\varvec{u}}_i,{\varvec{s}}_i)-h_0({\varvec{u}}_i,{\varvec{s}}_i).\)
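Putting the pieces together, a sketch of the training objective is given below. The argument names and the linear-kernel MMD are our simplifications for brevity; the RBF sketch from Sect. 4.1 could be substituted:

```python
import torch

def linear_mmd(a, b):
    # Linear-kernel MMD: squared distance between mean embeddings.
    return ((a.mean(dim=0) - b.mean(dim=0)) ** 2).sum()

def site_loss(model, y_hat, y, s_rep, u_rep, t, alpha=0.1, lam=0.01):
    """L = L_h + alpha * (MMD_psi + MMD_phi) + lam * ||Theta||^2."""
    loss_h = ((y_hat - y) ** 2).mean()                    # factual MSE
    mmd_psi = linear_mmd(s_rep[t == 1], s_rep[t == 0])    # balance interference reps
    mmd_phi = linear_mmd(u_rep[t == 1], u_rep[t == 0])    # balance covariate reps
    l2 = sum(p.pow(2).sum() for p in model.parameters())  # ||Theta||^2
    return loss_h + alpha * (mmd_psi + mmd_phi) + lam * l2
```

Here the factual prediction of each unit comes from the predictor matching its treatment, e.g., `y_hat = torch.where(t == 1, h1(z).squeeze(), h0(z).squeeze())` with `z = torch.cat([u_rep, s_rep], dim=1)`.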
4.4 Time complexity analysis
Here, we discuss the time complexity of SITE for training. Let B be the training batch size, N be the number of nodes, and E be the number of edges. To simplify the results, we set the same number of hidden layers and the same dimension for \(\phi\), \(\psi\), \(h_1\), and \(h_0\); let L be the number of hidden layers and D be the dimension of every hidden layer. Moreover, our analysis is based on the straightforward runtime of the matrix multiplications for simplicity. To be specific, the complexity of \(\phi\) is \(O(BdDL)\), the complexity of \(\psi\) is \(O(B(d+1)DL)\), and the complexity of each outcome predictor \(h_1\) and \(h_0\) is \(O(2BD^2L)\). Combining the complexities of the different components, the training complexity of SITE is \(O(B(2d+1)DL+4BD^2L)\). To take a closer look at the training efficiency of SITE, we also provide a comparison of the time complexity of the proposed method and existing methods in Table 2. It reveals that the training efficiency of the proposed SITE, BNN, and CFR depends on the batch size and feature dimension, whereas that of the GNN-based methods (Ma & Tresp, 2021) depends on the number of nodes and edges. This means that we can set an appropriate batch size for our method to make training efficient without losing much performance.
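As a rough illustrative calculation with the Flickr settings of Sect. 5.3 (\(B=1024\), \(D=100\), \(L=3\), \(d=1206\)), one SITE iteration costs about \(B(2d+1)DL \approx 7.4\times 10^{8}\) plus \(4BD^2L \approx 1.2\times 10^{8}\), i.e., roughly \(8.6\times 10^{8}\) multiply-accumulate operations, regardless of N and E; a full-batch K-layer GCN instead touches all N nodes and all E edges in every iteration.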
5 Experiment
In this section, we conduct experiments to answer the following questions. Question 1: Can SITE capture networked interference and deliver accurate estimates of ITE and outcomes under networked interference? Question 2: Can SITE be trained and make predictions efficiently? Question 3: Are representation balancing and L2 regularization important to SITE? Question 4: How do hyperparameters affect SITE?
5.1 Datasets
To verify the performance of SITE, we conducted experiments on five public datasets: Flickr dataset (Wang et al., 2013), BlogCatalog dataset (Li et al., 2015), Twitch dataset (Sarkar & Rózemberczki, 2021), Amazon positive (AMZ-P) dataset (He & McAuley, 2016), and Amazon negative (AMZ-N) dataset (He & McAuley, 2016).
Synthetic outcome. As the ground truth of ITE is hard to collect, similar to the outcome simulation in Ma et al. (2022), we transformed the original data and graph structures to simulate outcomes under interference for every unit: \(y_i=f_{0}({\varvec{x}}_i)+f_{t}(t_i,{\varvec{x}}_i)+f_s({\varvec{T}},{\varvec{X}},{\varvec{N}}_i)+\epsilon _i, \qquad (7)\)
where \(f_{0}({\varvec{x}}_i)={\varvec{w}}_0^{\top }{\varvec{x}}_i\) is the synthetic outcome of the unit i under treatment \(t_i=0\) without interference, and every element of \({\varvec{w}}_0\) independently follows either the Gaussian distribution \({\mathcal {N}}(0,1)\) or uniform distribution \({\mathcal {U}}(0,1)\). The ITE of the unit i is synthesized by \(f_{t}(t_i,{\varvec{x}}_i)=t_i\cdot {\varvec{w}}_1^{\top }{\varvec{x}}_i\), where \({\varvec{w}}_1\) also follows either the element-wise Gaussian or uniform distribution. We simulated effects caused by interference through \(f_s({\varvec{T}},{\varvec{X}},{\varvec{N}}_i)= \sum _{z=1}^{Z} f_{\text {Agg}}({\varvec{O}}^{z-1},{\varvec{N}}_i)\). Here, \({\varvec{N}}_i\) is the set of neighbors of the unit i, \({\varvec{O}}^0=\text {Concat}({\varvec{X}},{\varvec{T}})\), Z is a hyperparameter to simulate how far the interference propagates in a network, and the aggregation function \(f_{\text {Agg}}\) is defined as \(f_{\text {Agg}}({\varvec{O}},{\varvec{N}}_i)= \frac{1}{|{\varvec{N}}_{i}|}\sum _{j \in {\varvec{N}}_{i}}{\varvec{w}}_{\text {s}}^{\top }{\varvec{o}}_j\). Every element of \({\varvec{w}}_{\text {s}}\) also follows either the Gaussian or uniform distribution. Lastly, \(\epsilon _i\sim {\mathcal {N}}(0,1)\) is a random noise.
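As a reference point, the simulation can be sketched as follows. Only \({\varvec{O}}^0\) is specified in the text, so the hop update for \({\varvec{O}}^{z}\) (row-normalized neighbor averaging) is our assumption, and Gaussian weights are used throughout for brevity:

```python
import numpy as np
import scipy.sparse as sp

def simulate_outcomes(X, T, adj, Z, rng):
    """Simulate y_i = f_0(x_i) + f_t(t_i, x_i) + f_s(T, X, N_i) + eps_i
    per Eq. (7); `rng` is e.g. np.random.default_rng(0)."""
    N, d = X.shape
    w0 = rng.normal(size=d)       # weights of f_0
    w1 = rng.normal(size=d)       # weights of f_t (the ITE part)
    ws = rng.normal(size=d + 1)   # weights of f_Agg
    # Row-normalized adjacency: averaging over each unit's neighbors N_i.
    deg = np.maximum(np.asarray(adj.sum(axis=1)).flatten(), 1)
    P = sp.diags(1.0 / deg) @ adj
    O = np.hstack([X, T.reshape(-1, 1)])   # O^0 = Concat(X, T)
    f_s = np.zeros(N)
    for _ in range(Z):
        f_s += P @ (O @ ws)   # f_Agg(O^{z-1}, N_i)
        O = P @ O             # assumed propagation to O^z
    return X @ w0 + T * (X @ w1) + f_s + rng.normal(size=N)
```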
We synthesized outcomes for the Flickr, BlogCatalog, and Twitch datasets by using Eq. (7). For the AMZ-N and AMZ-P datasets, we used the outcome provided by Rakesh et al. (2018). Importantly, the outcomes of the AMZ-N and AMZ-P datasets are not generated by Eq. (7).
Dataset description. We now introduce the details of each dataset.
Flickr dataset (Wang et al., 2013): Flickr is an online social website where users share their images; the Flickr dataset (Wang et al., 2013) is collected from this website. In this dataset, each unit is a user of Flickr. There are 7,575 units with 479,476 edges. Here, we aim to estimate how much recommending a hot photo (treatment) to a user affects the user's experience (outcome) of this photo. In this case, users may share recommended photos with their friends (neighbors), which constitutes networked interference. We used the 1,206-dimensional embeddings of user profiles provided by Guo et al. (2020), and simulated the treatments as follows: \(t_i={\mathbb {1}}\left[ {\varvec{w}}_t^{\top }{\varvec{x}}_i+\epsilon _{t_i}>0\right] , \qquad (8)\)
where \({\varvec{w}}_t\) is a vector in which every element follows \({\mathcal {U}}(-1,1)\) independently, and \(\epsilon _{t_i}\) is random Gaussian noise.
BlogCatalog dataset (Li et al., 2015): BlogCatalog is an online community where users post their blogs. The BlogCatalog (abbreviated as Blog) dataset (Li et al., 2015, 2019; Guo et al., 2020) is collected from this community. Every node in the graph is a user of BlogCatalog. There are 5,196 units with 343,486 edges. Here, we aim to estimate how much recommending a blog (treatment) to a user affects the user's experience (outcome) of this blog. In this case, users may share recommended blogs with their friends (neighbors), which constitutes networked interference. The treatments for the Blog dataset were simulated using Eq. (8).
Twitch dataset (Sarkar & Rózemberczki, 2021): Twitch is an online website where users stream and watch live gameplay. The Twitch dataset (Sarkar & Rózemberczki, 2021) is collected from this website. Every node in the graph is a user of Twitch. There are 168,114 units with 6,797,557 edges. Here, we aim to estimate how much recommending a channel (treatment) to a user affects the user's experience (outcome) of this channel. In this case, users may share recommended channels with the users connected to them (neighbors), which constitutes networked interference. The treatments for the Twitch dataset were simulated using Eq. (8).
Amazon datasets (He & McAuley, 2016): Rakesh et al. (2018) extracted a co-purchase graph of the electronics category from the Amazon dataset (He & McAuley, 2016) to study the effect of positive or negative reviews on the sales of products and the issue of interference. Every unit in the Amazon dataset is an item, and every edge indicates that the two items are purchased together by customers. Rakesh et al. (2018) split the dataset and generated the AMZ-N (containing only negative reviews) and AMZ-P (containing only positive reviews) datasets. Units are associated with a directed graph structure in the AMZ-N and AMZ-P datasets. In the AMZ-N dataset, there are 14,538 units with 15,011 connections; in the AMZ-P dataset, there are 42,134 units with 52,999 connections. The treatment \(t \in \{0,1\}\) depends on the number of negative (or positive) reviews: \(t=1\) if a unit has more than three negative (or positive) reviews, and \(t=0\) if it has fewer than three (Rakesh et al., 2018). The covariates \({\varvec{x}}\) (with 300 features) of each unit are created by applying the doc2vec method (Le & Mikolov, 2014) to the reviews of the item. Our goal is to predict the real-valued sales y. As y often fluctuates over a large range, we apply z-score normalization to y during the training and testing phases. The ITE under interference (i.e., \(\tau\)) of each unit is approximately estimated by matching methods (Rakesh et al., 2018).
5.2 Baselines
We compared the proposed method with the following baseline methods:
BNN (Johansson et al., 2016): BNN addresses confounders by minimizing the discrepancy (Mansour et al., 2009) of covariate distributions of different treatment groups. It is an MLP-based method that uses a single neural network to predict outcomes for different treatment groups. Importantly, BNN cannot handle interference. Following Johansson et al. (2016), we consider two structures: BNN-4-0, which has four representation layers and a linear output layer, and BNN-2-2, which has two representation layers, two hidden layers of prediction networks, and a linear output layer.
CFR (Shalit et al., 2017): CFR adopts the same strategy as BNN, minimizing the MMD (Gretton et al., 2012) or the Wasserstein distance (Villani, 2009) between the distributions of covariate representations of different treatment groups. It is also an MLP-based method, which uses two neural networks to predict outcomes for the different treatment groups. CFR also does not consider interference. Following Shalit et al. (2017), we consider two schemes: CFR-MMD and CFR-Wass, which minimize the MMD and the Wasserstein distance, respectively, to address confounders.
TARNet (Shalit et al., 2017): TARNet has the same model architecture as the CFR but without the MMD or Wasserstein penalty terms.
GCN-based methods (Ma & Tresp, 2021): The GCN-based methods use GCN (Welling & Kipf, 2016) to capture networked interference and use the HSIC regularization (Gretton et al., 2005) to balance distributions of different treatment groups. This method is abbreviated as GCN-ITE.
GraphSAGE-based methods (Ma & Tresp, 2021): The GraphSAGE-based methods use GraphSAGE (Hamilton et al., 2017) to capture networked interference and use the HSIC regularization to balance the distributions of different treatment groups. Here, Ma and Tresp (2021) removed the neighbor sampling mechanism of GraphSAGE, which first randomly samples neighbors and then aggregates the information of only the sampled neighbors for every unit. In our experiments, we consider two schemes: GraphSAGE-ITE and GraphSAGE-ITE-S. GraphSAGE-ITE does not apply the neighbor sampling mechanism and aggregates information from the full neighbors of units. GraphSAGE-ITE-S extends GraphSAGE-ITE by applying the neighbor sampling mechanism (Hamilton et al., 2017).
SGC (Wu et al., 2019): SGC is a traditional GNN method for tasks on graphs, such as node classification. SGC simplifies the computation of graph convolution and removes the non-linear activation function. As a baseline, we extend SGC to estimate ITE on graphs by replacing the GCN of the GCN-based method (Ma & Tresp, 2021) with SGC. This extension of SGC is abbreviated as SGC-ITE.
5.3 Experimental setting
For all datasets, we calculated \(\epsilon _\mathrm{{MSE}}\) and \(\epsilon _\mathrm{{PEHE}}\) to evaluate the errors in outcome and ITE estimation, respectively. \(\epsilon _\mathrm{{MSE}}\) quantifies the performance in outcome prediction, while \(\epsilon _\mathrm{{PEHE}}\) quantifies the performance in ITE estimation. They are defined as follows: \(\epsilon _{\textrm{MSE}}=\frac{1}{N}\sum _{i=1}^{N}\left( y_i-{\hat{y}}_i\right) ^2\) and \(\epsilon _{\textrm{PEHE}}=\frac{1}{N}\sum _{i=1}^{N}\left( \tau _i-{\hat{\tau }}_i\right) ^2\), where \(\tau _i\) and \({\hat{\tau }}_i\) are the true and estimated ITE of the unit i, respectively.
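Both errors are straightforward to compute once the factual outcomes and ITEs are predicted, as in this short sketch:

```python
import numpy as np

def evaluation_errors(y, y_hat, tau, tau_hat):
    """eps_MSE on factual outcomes and eps_PEHE on ITEs, as defined above."""
    eps_mse = np.mean((y - y_hat) ** 2)
    eps_pehe = np.mean((tau - tau_hat) ** 2)
    return eps_mse, eps_pehe
```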
We randomly partitioned all datasets in the ratio \(70\%/15\%/15\%\) into training/validation/test splits and averaged results over ten repeated executions. Following Ma and Tresp (2021), the graph structure \({\varvec{A}}\), covariates \({\varvec{X}}\), and treatments \({\varvec{T}}\) were given during the training, validation, and testing phases, whereas only the observed outcomes of individuals in the training dataset were provided during training. We simulated outcomes for the Flickr, Blog, and Twitch datasets and computed \({\varvec{G}}\) (see Sect. 4.1) using an AMD EPYC Milan 7763 DP/UP CPU. We trained all methods using an RTX A6000 GPU. We varied the parameter Z of the outcome simulation to simulate different levels of the effect caused by interference; specifically, we set \(Z=2\) for the Blog and Flickr datasets and \(Z=3\) for the Twitch dataset.
We searched for hyperparameters by checking the results on the validation set. Here, we set \(K=2\) for SITE in experiments on the Blog and Flickr datasets, \(K=3\) for the Twitch dataset, and searched for K from \(\{2,3\}\) for the AMZ-N and AMZ-P datasets. We set the maximum training iterations to 2,000, the dimension for all layers to 100, the batch size to 1,024, \(\alpha\) to 0.1, the number of the layers for \(\phi\), \(\psi\), and h to 3, and searched \(\lambda\) in the range of \(\{0.0,0.001,0.01,0.025\}\). We adopted ReLU for the activation function. Early stopping and dropout were applied to the proposed method to avoid over-fitting.
We used the default hyperparameters or searched for hyperparameters in the ranges suggested in the literature to implement the baseline methods. To avoid over-fitting, we also applied early stopping and dropout to the baselines for all datasets. We give a summary of the architectures of the baseline methods in Table 3. In addition, we implemented GraphSAGE-ITE using the aggregation function of GraphSAGE in Ma and Tresp (2021); as every unit has a different number of neighbors, we aggregate neighbor information directly using \({\varvec{A}}\) to simplify the implementation.
5.4 Outcome and ITE estimation performance
To answer Question 1, we conducted experiments on all five datasets. We first compared the performance of the proposed method with the baselines on small graphs (the Blog and Flickr datasets). The results are presented in Table 4. They reveal that the performance of the proposed method is better than or comparable to that of GCN-ITE and GraphSAGE-ITE and better than that of SGC-ITE and GraphSAGE-ITE-S in outcome and ITE estimation. This demonstrates the strong performance of the proposed SITE in capturing networked interference and implies that propagating interference before training and subsequently learning the interference representation can effectively capture networked interference. Moreover, we observe that the performance gap in ITE estimation between GraphSAGE-ITE and GraphSAGE-ITE-S is large, which implies that aggregation with neighbor sampling cannot precisely capture networked interference and thus degrades ITE estimation.
Next, we validated the proposed method on large graphs (the AMZ-N, AMZ-P, and Twitch datasets) and present the results in Table 5. Here, we trained GCN-ITE and SGC-ITE on the AMZ-P dataset using a mini-batch size of 64, as they exceed the GPU memory with a larger batch size. We observe that the proposed method achieves performance superior or close to that of the GNN-based methods. We also observe that the performance of GCN-ITE and SGC-ITE can decrease when they are trained with a small batch size. Moreover, most of the GNN-based methods exceed the GPU memory on the Twitch dataset. These results show that the proposed method is scalable and that its capability to capture networked interference remains effective on large graphs.
5.5 Comparison of computational efficiency
To answer Question 2, we measured the training time of the different methods on all datasets using a batch size of 1,024 and 2,000 iterations without early stopping. The results are presented in Table 6. They show that the training efficiency of the proposed SITE is close to that of MLP-based methods (such as CFR) and that its training time changes little across datasets of different sizes. This reveals that the proposed method is scalable and efficient to train.
We also compared the inference time and GPU memory consumption of SITE and the baseline methods. The results are shown in Table 7. They reveal that the inference time and GPU memory consumption do not change significantly across datasets of different sizes.
5.6 Ablation study
To answer Question 3, we conducted ablation experiments on the Flickr and Blog datasets. We first introduce two variants of SITE: SITE-B ablates the MMD by setting \(\alpha =0\), and SITE-R ablates the L2 regularization by setting \(\lambda =0\).
The results of the ablation experiments are presented in Fig. 4. These results show degradation in ITE and outcome estimation without MMD regularization. This confirms that balancing representations of covariates and interference of different treatment groups is important. These results also show degradation in ITE and outcome estimation without L2 regularization. This confirms that L2 regularization is also important for the proposed method.
5.7 Sensitivity study
To answer Question 4, we conducted experiments on the Flickr dataset with different values of \(\alpha\) and \(\lambda\) (\(\{0.0,0.01,0.1,0.5,1.0\}\)). The results of the sensitivity experiments are presented in Fig. 5. We observe that the performance changes little with different values of \(\alpha\) on the Flickr dataset, which reveals that the proposed method is not particularly sensitive to \(\alpha\). By contrast, we observe large performance changes with different values of \(\lambda\).
To further investigate why the performance changes significantly with different values of \(\lambda\), we recorded \(\epsilon _{\text {MSE}}\) on the training and validation sets with \(\lambda =0.0\) and \(\lambda =1.0\) during training on the Flickr dataset, as presented in Fig. 6. Let \({\bar{\epsilon }}_{\text {MSE}}\) be the average value of \(\epsilon _{\text {MSE}}\) over every 20 iterations; we plot \(M=100\) points, so \(\text {iterations} = 20M\). The results show that with \(\lambda =0.0\) the model over-fits, whereas with \(\lambda =1.0\) it under-fits. We suggest searching for \(\lambda\) in the range of [0.01, 0.1].
6 Conclusion
In this study, we introduced a new problem: ITE estimation under networked interference on a large graph. To overcome it, we proposed SITE, which can model networked interference and estimate ITE on large graphs. We evaluated the proposed method on five public datasets to verify its effectiveness in ITE estimation and its efficiency in training. The results reveal that the performance of the proposed method is comparable to that of the GNN-based methods (Ma & Tresp, 2021), while its training time is close to that of the MLP-based methods (Shalit et al., 2017) and much shorter than that of the GNN-based methods (Ma & Tresp, 2021).
In this study, we considered graphs with a single type of edge and interaction among units. However, real-world graphs can be more complex, such as hypergraphs (Ma et al., 2022) and heterogeneous graphs (Lin et al., 2023). Extending our proposed SITE to such complex graphs is left as future work.
References
Aronow, P. M., & Samii, C. (2017). Estimating average causal effects under general interference, with application to a social network experiment. The Annals of Applied Statistics, 11, 1912–1947.
Bojchevski, A., Gasteiger, J., Perozzi, B., Kapoor, A., Blais, M., Rózemberczki, B., Lukasik, M., & Günnemann, S. (2020). Scaling graph neural networks with approximate PageRank. In: Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 2464–2473.
Chen, J., Ma, T., & Xiao, C. (2018). FastGCN: Fast learning with graph convolutional networks via importance sampling. In: Proceedings of the 6th International Conference on Learning Representations.
Feng, W., Dong, Y., Huang, T., Yin, Z., Cheng, X., Kharlamov, E., & Tang, J. (2022). GRAND+: Scalable graph random neural networks. In: Proceedings of the 31st ACM Web Conference, pp. 3248–3258.
Forastiere, L., Airoldi, E. M., & Mealli, F. (2021). Identification and estimation of treatment and interference effects in observational studies on networks. Journal of the American Statistical Association, 116(534), 901–918.
Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). A kernel two-sample test. The Journal of Machine Learning Research, 13(1), 723–773.
Gretton, A., Bousquet, O., Smola, A., & Schölkopf, B. (2005). Measuring statistical dependence with Hilbert-Schmidt norms. In: Proceedings of the 16th International Conference on Algorithmic Learning Theory, pp. 63–77.
Guo, R., Li, J., & Liu, H. (2020). Learning individual causal effects from networked observational data. In: Proceedings of the 13th International Conference on Web Search and Data Mining, pp. 232–240.
Hamilton, W. L., Ying, R., & Leskovec, J. (2017). Inductive representation learning on large graphs. In: Advances in Neural Information Processing Systems.
He, R., & McAuley, J. (2016). Ups and downs: Modeling the visual evolution of fashion trends with one-class collaborative filtering. In: Proceedings of the 25th World Wide Web Conference, pp. 507–517.
Hudgens, M. G., & Halloran, M. E. (2008). Toward causal inference with interference. Journal of the American Statistical Association, 103(482), 832–842.
Jiang, S., Huang, Z., Luo, X., & Sun, Y. (2023). CF-GODE: Continuous-time causal inference for multi-agent dynamical systems. In: Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, pp. 997–1009.
Johansson, F., Shalit, U., & Sontag, D. (2016). Learning representations for counterfactual inference. In: Proceedings of the 33rd International Conference on Machine Learning, pp. 3020–3029.
Le, Q., & Mikolov, T. (2014). Distributed representations of sentences and documents. In: Proceedings of the 31st International Conference on Machine Learning, pp. 1188–1196.
Li, J., Hu, X., Tang, J., & Liu, H. (2015). Unsupervised streaming feature selection in social media. In: Proceedings of the 24th ACM International Conference on Information and Knowledge Management, pp. 1041–1050.
Li, J., Guo, R., Liu, C., & Liu, H. (2019). Adaptive unsupervised feature selection on attributed networks. In: Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 92–100.
Lin, X., Zhang, G., Lu, X., Bao, H., Takeuchi, K., & Kashima, H. (2023). Estimating treatment effects under heterogeneous interference. In: Joint European Conference on Machine Learning and Knowledge Discovery in Databases, pp. 576–592. Springer.
Liu, L., & Hudgens, M. G. (2014). Large sample randomization inference of causal effects in the presence of interference. Journal of the American Statistical Association, 109(505), 288–301.
Ma, J., Wan, M., Yang, L., Li, J., Hecht, B., & Teevan, J. (2022). Learning causal effects on hypergraphs. In: Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, pp. 1202–1212.
Ma, Y., & Tresp, V. (2021). Causal inference under networked interference and intervention policy enhancement. In: Proceedings of the 24th International Conference on Artificial Intelligence and Statistics, 130, 3700–3708.
Mansour, Y., Mohri, M., & Rostamizadeh, A. (2009). Domain adaptation: Learning bounds and algorithms. arXiv preprint arXiv:0902.3430.
Nabi, R., Pfeiffer, J., Charles, D., & Kıcıman, E. (2022). Causal inference in the presence of interference in sponsored search advertising. Frontiers in Big Data, 5, 888592.
Rakesh, V., Guo, R., Moraffah, R., Agarwal, N., & Liu, H. (2018). Linked causal variational autoencoder for inferring paired spillover effects. In: Proceedings of the 27th ACM International Conference on Information and Knowledge Management, pp. 1679–1682.
Raudenbush, S. W., & Schwartz, D. (2020). Randomized experiments in education, with implications for multilevel causal inference. Annual Review of Statistics and its Application, 7(1), 177–208.
Sarkar, R., & Rózemberczki, B. (2021). Twitch gamers: A dataset for evaluating proximity preserving and structural role-based node embeddings. In: Workshop on Graph Learning Benchmarks @ TheWebConf 2021.
Schnitzer, M. E. (2022). Estimands and estimation of COVID-19 vaccine effectiveness under the test-negative design: Connections to causal inference. Epidemiology, 33(3), 325.
Shalit, U., Johansson, F. D., & Sontag, D. (2017). Estimating individual treatment effect: Generalization bounds and algorithms. In: Proceedings of the 34th International Conference on Machine Learning, 70, 3076–3085.
Shi, Z., Liang, X., & Wang, J. (2023). LMC: Fast training of GNNs via subgraph sampling with provable convergence. In: Proceedings of the 11th International Conference on Learning Representations.
Tchetgen, E. J. T., & VanderWeele, T. J. (2012). On causal inference in the presence of interference. Statistical Methods in Medical Research, 21(1), 55–75.
Villani, C. (2009). Optimal Transport: Old and New (Vol. 338). Springer.
Viviano, D. (2019). Policy targeting under network interference. arXiv preprint arXiv:1906.10258
Wang, X., Tang, L., Liu, H., & Wang, L. (2013). Learning with multi-resolution overlapping communities. Knowledge and Information Systems, 36, 517–535.
Welling, M., & Kipf, T. N. (2016). Semi-supervised classification with graph convolutional networks. In: Proceedings of the 4th International Conference on Learning Representations.
Wu, F., Souza, A., Zhang, T., Fifty, C., Yu, T., & Weinberger, K. (2019). Simplifying graph convolutional networks. In: Proceedings of the 36th International Conference on Machine Learning, pp. 6861–6871.
Yao, L., Chu, Z., Li, S., Li, Y., Gao, J., & Zhang, A. (2021). A survey on causal inference. ACM Transactions on Knowledge Discovery from Data, 15(5), 1–46.
Zeng, H., Zhou, H., Srivastava, A., Kannan, R., & Prasanna, V. (2020). GraphSAINT: Graph sampling based inductive learning method. In: Proceedings of the 8th International Conference on Learning Representations.
Acknowledgements
This work was supported by JST SPRING, Grant Number JPMJSP2110.
Funding
JST SPRING Grant Number JPMJSP2110.
Ethics declarations
Conflict of interest
Not applicable.
Ethics approval and consent to participate
Not applicable.