Skip to content

Commit 3680413

Browse files
authored
Slight Improvements to GATE model (#213)
* enabled two more parameters to GATE model * added a env setup shell script * fixing tree attention * Reduced complexity of initial parameters in GATE
1 parent 0612db5 commit 3680413

File tree

4 files changed

+42
-9
lines changed

4 files changed

+42
-9
lines changed

examples/__only_for_dev__/adhoc_scaffold.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def print_metrics(y_true, y_pred, tag):
5050

5151
from pytorch_tabular import TabularModel # noqa: E402
5252
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig # noqa: E402
53-
from pytorch_tabular.models import CategoryEmbeddingModelConfig # noqa: E402
53+
from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig # noqa: E402
5454

5555
data_config = DataConfig(
5656
# target should always be a list. Multi-targets are only supported for regression.
@@ -68,10 +68,10 @@ def print_metrics(y_true, y_pred, tag):
6868
fast_dev_run=True,
6969
)
7070
optimizer_config = OptimizerConfig()
71-
model_config = CategoryEmbeddingModelConfig(
71+
model_config = GatedAdditiveTreeEnsembleConfig(
7272
task="classification",
73-
# gflu_stages=3,
74-
# tree_depth=2,
73+
gflu_stages=3,
74+
tree_depth=2,
7575
# layers="4096-4096-512", # Number of nodes in each layer
7676
# activation="LeakyReLU", # Activation between each layers
7777
learning_rate=1e-3,

setup_env.sh

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#!/bin/bash
2+
3+
# Prompt the user for their name.
4+
echo "What is the name of the environment?"
5+
read -r name
6+
7+
mkdir -p .env
8+
# Create a virtual environment
9+
python3 -m venv .env/$name
10+
11+
# Activate the virtual environment
12+
source .env/$name/bin/activate
13+
14+
# Create a temporary requirements file
15+
# Read the contents of the file into a variable.
16+
contents=$(cat requirements.txt)
17+
# Replace all occurrences of ">=" with "==" in the variable.
18+
contents=$(echo "$contents" | sed 's/>=$/==/g')
19+
# Write the contents of the variable to the file.
20+
echo "$contents" > requirements.tmp
21+
22+
# Install the required dependencies from the temporary file
23+
pip install -r requirements.tmp
24+
25+
rm requirements.tmp
26+
27+
# Install an editable version of the package
28+
pip install -e .[dev]

src/pytorch_tabular/models/gate/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ class GatedAdditiveTreeEnsembleConfig(ModelConfig):
9494
default=0.0, metadata={"help": "Dropout rate for the feature abstraction layer. Defaults to 0.0"}
9595
)
9696

97-
tree_depth: int = field(default=5, metadata={"help": "Depth of the tree. Defaults to 5"})
97+
tree_depth: int = field(default=4, metadata={"help": "Depth of the tree. Defaults to 5"})
9898

9999
num_trees: int = field(
100-
default=20,
100+
default=10,
101101
metadata={"help": "Number of trees to use in the ensemble. Defaults to 20"},
102102
)
103103

src/pytorch_tabular/models/gate/gate_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ def _build_network(self):
9595
)
9696
if self.tree_wise_attention:
9797
self.tree_attention = nn.MultiheadAttention(
98-
self.output_dim,
99-
1,
98+
embed_dim=self.output_dim,
99+
num_heads=1,
100+
batch_first=False,
100101
dropout=self.tree_wise_attention_dropout,
101102
)
102103

@@ -123,7 +124,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
123124
tree_input = torch.cat([tree_input, tree_output], 1)
124125
tree_outputs = torch.cat(tree_outputs, dim=-1)
125126
if self.tree_wise_attention:
126-
tree_outputs, _ = self.tree_attention(tree_outputs)
127+
tree_outputs = tree_outputs.permute(2, 0, 1)
128+
tree_outputs, _ = self.tree_attention(tree_outputs, tree_outputs, tree_outputs)
129+
tree_outputs = tree_outputs.permute(1, 2, 0)
127130
return tree_outputs
128131

129132

@@ -210,6 +213,8 @@ def _build_network(self):
210213
feature_mask_function=self.hparams.feature_mask_function,
211214
batch_norm_continuous_input=self.hparams.batch_norm_continuous_input,
212215
chain_trees=self.hparams.chain_trees,
216+
tree_wise_attention=self.hparams.tree_wise_attention,
217+
tree_wise_attention_dropout=self.hparams.tree_wise_attention_dropout,
213218
)
214219
# Embedding Layer
215220
self._embedding_layer = self._backbone._build_embedding_layer()

0 commit comments

Comments
 (0)