How do causal explanations help (or hinder) learning? This document presents a computational model proposing that explanations function as attention masks during error-driven learning—they selectively weight which causal variables are updated when prediction errors occur.
This model was developed in Chapter II of my PhD dissertation (Causal Explanations and Continuous Computation, 2025). It applies to experimental data from a rule-inference paradigm where participants learned causal rules from observations, either with or without accompanying causal selection explanations.
Causal selection explanations cite specific causes from among multiple sufficient causes (e.g., “The match caused the fire” rather than “The oxygen caused the fire”). In our experiments, we compared learning outcomes when participants received causal selection explanations versus other types of explanations versus no explanations at all.
In earlier work (Navarre, Konuk, Bramley & Mascarenhas, 2024), we proposed a model of inferences from explanation inspired by the Rational Speech Act (RSA) framework (Frank & Goodman, 2012; Goodman & Frank, 2016), on which listeners (explainees) learn from explanations by reasoning about why a cooperative speaker chose one explanation over the alternatives.
Here I propose to capture the same patterns of inference using a different (neural) learning architecture: people learn from observations by updating an internal model of the causal system via gradient-based learning procedures. The key twist is a much simpler mechanism for explanation effects: when a variable is mentioned in a (causal selection) explanation, it inflects that learning process by amplifying the input activation of the neurons encoding the corresponding variable(s) by some constant factor. This can be modeled very easily using attention masks. Below I show an implementation that captures the results of our experiment and then some (see dissertation for a more detailed discussion of the advantages of the model). To make the proof-of-concept more salient, I make the basic learning architecture as minimal as possible.
First, let’s fetch the experimental data of responses per condition and plot it here again for reference.
############################################################
## LOAD DATA
############################################################
## We keep the extraction procedures simpler than in the script that
## accompanied the publication, since much of what it does (e.g., plotting
## the most popular rules) is not needed here. For those analyses, refer
## to the other repo.
# 1. Simplified multmerge function
multmerge <- function(mypath) {
  # Get only .csv files
  filenames <- list.files(path = mypath, pattern = "\\.csv$", full.names = TRUE)
  if (length(filenames) == 0) {
    stop("No CSV files found in: ", mypath)
  }
  # Read each .csv into a list of data frames
  datalist <- lapply(filenames, function(x) {
    # or read_csv(x) if you want
    read.csv(x, header = TRUE, na.strings = c("", "NA"))
  })
  # Combine them into one data frame (bind_rows from dplyr)
  merged_df <- bind_rows(datalist)
  return(merged_df)
}
#2. Extracting the data
data_path <- "../data-files/actual-cause-grid-1" # Relative to models/ folder
if (!dir.exists(data_path)) {
  stop("Directory not found: ", data_path,
       ". Make sure you downloaded data-files in the same parent folder as code-files.")
}
the_data <- multmerge(data_path)
df_test <- the_data %>%
filter(questionId == "trial_test")
# 3. Simplified parser
parse_json_array <- function(x) {
  # Remove brackets
  x2 <- gsub("\\[|\\]", "", x)
  if (nchar(x2) == 0) return(NA) # or return numeric(0)
  # Split by comma and coerce to numeric
  as.numeric(strsplit(x2, ",")[[1]])
}
# Apply it to each row in df_test
df_test$parsed_key <- lapply(df_test$test_key, parse_json_array)
df_test$parsed_resp <- lapply(df_test$response, parse_json_array)
df_test$proportion_correct <- mapply(function(key_vec, resp_vec) {
  # 3a) If either vector is NA or the lengths mismatch, return NA
  if (any(is.na(key_vec)) || any(is.na(resp_vec))) return(NA_real_)
  if (length(key_vec) != length(resp_vec)) return(NA_real_)
  # 3b) Compare element by element; despite the column name, this is a
  # count of correct responses (out of 16), not a proportion
  sum(key_vec == resp_vec)
}, df_test$parsed_key, df_test$parsed_resp)
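As a quick self-contained check of this scoring rule (with made-up key and response vectors, not the real 16-item test data):

```r
# Hypothetical 4-item key and response; the real test has 16 items
key  <- c(1, -1, 1, -1)
resp <- c(1, -1, -1, -1)
score <- sum(key == resp)  # count of matching responses
score
# [1] 3
```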
df_subject <- df_test %>%
  group_by(subject_ID, group) %>%
  summarise(
    mean_accuracy = mean(proportion_correct, na.rm = TRUE),
    .groups = "drop"
  )
df_subject$Condition <- factor(df_subject$group,
                               levels = c(1, 2, 3),
                               labels = c("OBS", "AC", "CS"))
# Reorder the factor so that AC < OBS < CS
df_subject$Condition <- factor(df_subject$Condition,
                               levels = c("AC", "OBS", "CS"))
df_means <- df_subject %>%
  group_by(Condition) %>%
  summarise(mean_acc = mean(mean_accuracy, na.rm = TRUE), .groups = "drop")
custom_colors <- c("AC"="red","OBS"="blue","CS"="green")
accuracy_plot <- ggplot(df_subject, aes(x = Condition, y = mean_accuracy, fill = Condition)) +
  geom_violin(width = 1, alpha = 0.4) +
  # Overlay a point for the mean
  geom_point(data = df_means,
             aes(x = Condition, y = mean_acc),
             shape = 23, size = 3, fill = "red", color = "black") +
  # Label with numeric means
  geom_text(data = df_means,
            aes(x = Condition, y = mean_acc, label = round(mean_acc, 2)),
            vjust = -1, color = "black", fontface = "bold") +
  scale_fill_manual(values = custom_colors) +
  labs(
    title = "Observed Accuracy by Condition",
    x = "Condition",
    y = "Mean Accuracy"
  ) +
  theme_minimal()

Dissertation Reference: This figure corresponds to Figure 2.3 in Chapter II of the dissertation (Causal Explanations and Continuous Computation, 2025). The conditions are: OBS (Observations Only), AC (Any Cause explanations), and CS (Causal Selection explanations).
Observed accuracy distribution by experimental condition. Diamond markers indicate condition means.
The distribution of response times during the learning phase informs how we model individual differences in the number of “learning epochs.” An exponential distribution of RTs suggests exponentially distributed learning durations across participants.
We talked in the relevant thesis chapter about how the distribution
of response times for the crucial trial
(trial_observations)—where people sample from urns and form
a judgment about the underlying rule from their observations (+
explanations)—follows an exponential distribution. Relevant code is in
the cell below.
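As a minimal sketch of how an exponential duration assumption translates into integer epoch counts per simulated subject (the rate here is arbitrary, chosen only for illustration; the rates actually used appear in the simulation code below):

```r
set.seed(1)  # for reproducibility of this illustration
raw_durations  <- rexp(5, rate = 1/2)           # exponentially distributed "learning time"
sampled_epochs <- pmax(1, round(raw_durations)) # every subject gets at least one full pass
sampled_epochs
```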
############################################################
## RT Distributions
############################################################
# We'll:
# (1) Identify the RT data for relevant trials and plot the distribution
# (2) Fit ex-Gaussian, normal, and exponential distributions
# (3) Compare fits via AIC/BIC
############################################################
crucial_data <- the_data %>%
filter(questionId == "trial_observations")
# I remove one extreme outlier with a response time higher than 30 min.
# Note that this subject's responses are not otherwise unusual (so it may
# not be a true outlier), and that the removal is conservative with respect
# to the ex-Gaussian hypothesis (which would be favored even more strongly
# if we included it).
crucial_data <- crucial_data %>%
  filter(rt > 0, !is.na(rt), rt < 200000) # removing the most extreme outlier
options(scipen = 999)
# 2) Quick look at the distribution
rt_plot <- ggplot(crucial_data, aes(x = rt)) +
  geom_density() +
  labs(
    title = "Participant response time distributions for learning trial",
    x = "Response Time (ms)",
    y = "Density"
  ) +
  theme_minimal()
rt_plot
# ggsave("RT_distribution.pdf", plot = rt_plot, width = 6, height = 4)
# 3) Extract the vector of RT
rt_data <- crucial_data$rt
# 4) Fit distributions using GAMLSS:
# - Normal (NO)
# - Ex-Gaussian (exGAUS)
# - Exponential (EXP)
fit_norm <- gamlss(rt_data ~ 1, family = NO)     # Normal
## GAMLSS-RS iteration 1: Global Deviance = 6794.479
## GAMLSS-RS iteration 2: Global Deviance = 6794.479
fit_exg <- gamlss(rt_data ~ 1, family = exGAUS)  # Ex-Gaussian
## GAMLSS-RS iteration 1: Global Deviance = 6743.017
## GAMLSS-RS iteration 2: Global Deviance = 6724.212
## GAMLSS-RS iteration 3: Global Deviance = 6709.461
## GAMLSS-RS iteration 4: Global Deviance = 6702.68
## GAMLSS-RS iteration 5: Global Deviance = 6700.124
## GAMLSS-RS iteration 6: Global Deviance = 6699.248
## GAMLSS-RS iteration 7: Global Deviance = 6698.963
## GAMLSS-RS iteration 8: Global Deviance = 6698.869
## GAMLSS-RS iteration 9: Global Deviance = 6698.838
## GAMLSS-RS iteration 10: Global Deviance = 6698.829
## GAMLSS-RS iteration 11: Global Deviance = 6698.826
## GAMLSS-RS iteration 12: Global Deviance = 6698.825
## GAMLSS-RS iteration 13: Global Deviance = 6698.824
fit_exp <- gamlss(rt_data ~ 1, family = EXP)     # Exponential
## GAMLSS-RS iteration 1: Global Deviance = 6852.432
## GAMLSS-RS iteration 2: Global Deviance = 6852.432
cat("\n## Model comparison: AIC\n")
print(AIC(fit_exg, fit_norm, fit_exp))
cat("\n## Model comparison: BIC\n")
print(BIC(fit_norm, fit_exg, fit_exp))
##
## ## Model comparison: AIC
## df AIC
## fit_exg 3 6704.824
## fit_norm 2 6798.479
## fit_exp 1 6854.432
##
## ## Model comparison: BIC
## df BIC
## fit_norm 2 6805.812
## fit_exg 3 6715.824
## fit_exp 1 6858.098
The model is a simple feed-forward neural network with four input nodes (A, B, C, D) encoding the observed variables as ±1, a single hidden layer of tanh units, and one tanh output node predicting the (±1) outcome.
The key innovation is the attention mechanism: explanations modulate input activations during backpropagation, effectively acting as attention masks that weight which variables are updated when errors occur.
Feedforward neural network architecture. Input nodes (A, B, C, D) correspond to urn draws, hidden layer processes feature combinations, and output node W predicts win/lose outcome.
The attention mechanism is the core innovation. Each explanation \(\xi\) highlights a subset of input variables as causally relevant. We formalize this as an attention mask applied to input activations.
For each variable \(x \in \{A, B, C, D\}\), the explanation \(\xi\) assigns:
\[\xi(x) = \begin{cases} +1 & \text{if variable } x \text{ is mentioned in the explanation} \\ -1 & \text{otherwise} \end{cases}\]
The attention weight for variable \(x\) given explanation \(\xi\) is:
\[\text{Att}(x, \xi) = \exp\bigl[\alpha \cdot \xi(x)\bigr]\]
where \(\alpha\) is the attention parameter controlling the strength of the attention effect.
The network’s input activation becomes:
\[x_{\text{input}} \leftarrow (\pm 1) \times \text{Att}(x, \xi)\]
Attention mechanism: Explanation ‘C caused the outcome’ amplifies input C (larger node) while attenuating A, B, D (smaller nodes).
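To make this concrete, here is a minimal sketch of the mask for the explanation "C caused the outcome" (the α value is arbitrary and purely illustrative; the exponential scheme is the one defined above):

```r
input_neurons <- c("A", "B", "C", "D")
alpha <- 1                                         # illustrative value
xi  <- setNames(c(-1, -1, +1, -1), input_neurons)  # +1 iff mentioned in the explanation
att <- exp(alpha * xi)                             # attention weights
x   <- setNames(c(1, -1, 1, -1), input_neurons)    # one observation, coded as ±1
x_adj <- x * att                                   # attention-scaled input activations
round(att, 3)
#     A     B     C     D
# 0.368 0.368 2.718 0.368
```

The mentioned variable C is amplified by e^α while the unmentioned variables are attenuated by e^(−α), exactly as in the training code below.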
At each observation \(\omega\), the model computes the squared error:
\[L_{\omega} = \bigl(y_{\omega} - \hat{y}_{\omega}\bigr)^2\]
where \(y_{\omega}\) is the target (\(\pm 1\) for win/lose) and \(\hat{y}_{\omega}\) is the network’s prediction.
Parameters are updated via gradient descent:
\[w_i \leftarrow w_i - \eta \cdot \frac{\partial L_{\omega}}{\partial w_i}\]
The attention mechanism affects this update by modulating input activations, which in turn affects gradient magnitudes for weights connected to attended vs. unattended inputs.
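The effect on gradients can be seen in a stripped-down case. With a single linear output unit (a simplification of the full network; all numbers here are illustrative), the gradient for each input weight is proportional to the attention-scaled input, so a mentioned variable's weight moves e^(2α) times as much as an unmentioned one's:

```r
alpha <- 1
x   <- c(A = 1, B = 1, C = 1, D = 1)                  # one observation
att <- exp(alpha * c(A = -1, B = -1, C = 1, D = -1))  # explanation mentions C
x_adj <- x * att
w <- rep(0, 4)                      # weights of a single linear output unit
y <- 1                              # target
y_hat <- sum(w * x_adj)             # prediction (0 here)
grad  <- 2 * (y_hat - y) * x_adj    # gradient of (y - y_hat)^2 w.r.t. w
unname(grad["C"] / grad["A"])       # ratio of update magnitudes = exp(2 * alpha)
# [1] 7.389056
```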
Dissertation Reference: The full theoretical motivation and mathematical derivation appear in Chapter II, Sections 2.4–2.6 of the dissertation. See Algorithm 6 (Attention-Based Inductive Learning) for the complete procedure.
We build two lists of training data, one for the Causal Selection condition, one for the Any Cause condition. They differ only with respect to the list of explanations that accompany each observation.
It is not necessary to build a list for the “Observation only” condition, because the model for that condition is just a special case of the model for the other two conditions, where the learning pipeline does not change as a function of the explanations given. We just implement that condition by feeding the training function the training data for CS or AC by forcing the alpha parameter to zero.
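This works because setting α = 0 makes the mask neutral, so the OBS condition is recovered exactly (a one-line check):

```r
att_obs <- exp(0 * c(A = 1, B = -1, C = -1, D = -1))  # alpha = 0: every weight is exp(0) = 1
stopifnot(all(att_obs == 1))  # inputs pass through unscaled
```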
Each dataset is a list of the same four observations, with repetition counts matching how often each observation was presented to subjects, and a list of possible explanations attached to each observation, depending on the condition.
# 'CS' condition (Causal Selection) data
training_data_CS <- list(
list(
input = c(A=1, B=-1, C=1, D=-1),
target = c(E=1),
possible_explanations = c("C"),
repetition_count = 4
),
list(
input = c(A=1, B=1, C=1, D=1),
target = c(E=1),
possible_explanations = c("C"),
repetition_count = 1
),
list(
input = c(A=-1, B=1, C=-1, D=-1),
target = c(E=-1),
possible_explanations = c("C"),
repetition_count = 1
),
list(
input = c(A=1, B=1, C=-1, D=-1),
target = c(E=-1),
possible_explanations = c("C&D"),
repetition_count = 4
)
)
# 'AC' condition (Any Cause) data
training_data_AC <- list(
list(
input = c(A=1, B=-1, C=1, D=-1),
target = c(E=1),
possible_explanations = c("A&C"),
repetition_count = 4
),
list(
input = c(A=1, B=1, C=1, D=1),
target = c(E=1),
possible_explanations = c("A", "D", "A&D", "A&C", "C&D", "A&C&D"),
repetition_count = 1
),
list(
input = c(A=-1, B=1, C=-1, D=-1),
target = c(E=-1),
possible_explanations = c("A", "D", "A&D", "A&C", "A&C&D"),
repetition_count = 1
),
list(
input = c(A=1, B=1, C=-1, D=-1),
target = c(E=-1),
possible_explanations = c("C", "D"),
repetition_count = 4
)
)

# Testing data (16 possible input combos)
testing_data <- list(
# 1) A= 1, B= 1, C= 1, D= 1 => E=1
list(input = c(A=1, B=1, C=1, D=1), target = c(E=1)),
# 2) A= 1, B= 1, C= 1, D=-1 => E=1
list(input = c(A=1, B=1, C=1, D=-1), target = c(E=1)),
# 3) A= 1, B= 1, C=-1, D= 1 => E=1
list(input = c(A=1, B=1, C=-1, D=1), target = c(E=1)),
# 4) A= 1, B= 1, C=-1, D=-1 => E=-1
list(input = c(A=1, B=1, C=-1, D=-1), target = c(E=-1)),
# 5) A= 1, B=-1, C= 1, D= 1 => E=1
list(input = c(A=1, B=-1, C=1, D=1), target = c(E=1)),
# 6) A= 1, B=-1, C= 1, D=-1 => E=1
list(input = c(A=1, B=-1, C=1, D=-1), target = c(E=1)),
# 7) A= 1, B=-1, C=-1, D= 1 => E=1
list(input = c(A=1, B=-1, C=-1, D=1), target = c(E=1)),
# 8) A= 1, B=-1, C=-1, D=-1 => E=-1
list(input = c(A=1, B=-1, C=-1, D=-1), target = c(E=-1)),
# 9) A=-1, B= 1, C= 1, D= 1 => E=1
list(input = c(A=-1, B=1, C=1, D=1), target = c(E=1)),
# 10) A=-1, B= 1, C= 1, D=-1 => E=1
list(input = c(A=-1, B=1, C=1, D=-1), target = c(E=1)),
# 11) A=-1, B= 1, C=-1, D= 1 => E=-1
list(input = c(A=-1, B=1, C=-1, D=1), target = c(E=-1)),
# 12) A=-1, B= 1, C=-1, D=-1 => E=-1
list(input = c(A=-1, B=1, C=-1, D=-1), target = c(E=-1)),
# 13) A=-1, B=-1, C= 1, D= 1 => E=1
list(input = c(A=-1, B=-1, C=1, D=1), target = c(E=1)),
# 14) A=-1, B=-1, C= 1, D=-1 => E=1
list(input = c(A=-1, B=-1, C=1, D=-1), target = c(E=1)),
# 15) A=-1, B=-1, C=-1, D= 1 => E=-1
list(input = c(A=-1, B=-1, C=-1, D=1), target = c(E=-1)),
# 16) A=-1, B=-1, C=-1, D=-1 => E=-1
list(input = c(A=-1, B=-1, C=-1, D=-1), target = c(E=-1))
)

The lists above contain all the necessary information, but it must be converted into a format that our training function below can read. That involves:

- "Unfolding" the repetition counts, turning each 4-observation list into a 10-trial list with the appropriate repetitions
- Randomizing the order
- Turning each explanation into a vector format usable by the attention mechanism
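As a self-contained illustration of that last step, the encoding turns an explanation string such as "A&C" (one of the AC-condition explanations) into a 0/1 vector over the input neurons:

```r
input_neurons <- c("A", "B", "C", "D")
vars <- trimws(strsplit("A&C", "&")[[1]])  # split the explanation string on "&"
expl_vector <- setNames(rep(0, length(input_neurons)), input_neurons)
expl_vector[vars] <- 1                     # mark the mentioned variables
expl_vector
# A B C D
# 1 0 1 0
```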
We do that with the helper below:
### Generate Subject Training Data:
### Takes one of the training-set lists and returns a set of 10 observations
### with the appropriate repetitions, together with the explanation vector
### associated with each observation, in accordance with the condition.
### It is run once per subject, ensuring that the mapping between observations
### and explanations stays constant across training epochs for a given subject.
generate_subject_training_data <- function(base_training_data) {
  # For each observation in base_training_data, pick one explanation
  # for all of its repeated trials, then shuffle.
  expanded_list <- list()
  input_neurons <- c("A", "B", "C", "D")
  for (obs in base_training_data) {
    chosen_expl <- NULL
    if (!is.null(obs$possible_explanations) && length(obs$possible_explanations) > 0) {
      chosen_expl <- sample(obs$possible_explanations, 1)
    }
    # Build explanation vector (1 for mentioned variables, 0 otherwise)
    expl_vector <- setNames(rep(0, length(input_neurons)), input_neurons)
    if (!is.null(chosen_expl)) {
      vars <- trimws(strsplit(chosen_expl, "&")[[1]])
      expl_vector[vars] <- 1
    }
    # Replicate obs repetition_count times
    for (i in seq_len(obs$repetition_count)) {
      expanded_list[[length(expanded_list) + 1]] <-
        list(input = obs$input,
             target = obs$target,
             explanation = expl_vector)
    }
  }
  # Shuffle
  expanded_list <- sample(expanded_list)
  return(expanded_list)
}

We define the central network training pipeline. It consists of standard backpropagation training, with the addition of the attention mechanism:
# Helper functions:
tanh_activation <- function(x) {
  tanh(x)
}
tanh_derivative <- function(x) {
  1 - tanh(x)^2
}
# Global learning parameters
learning_rate <- 1      # Fixed to 1 throughout; kept as a parameter for readers who want to play with it
num_hidden_neurons <- 4 # Likewise fixed
### Early-stopping safeguard in case of a mistake with the number of epochs;
### it is never reached in normal use.
threshold <- 1e-6
max_epochs <- 5000
# Training function
train_network <- function(num_epochs,
                          num_hidden_neurons,
                          alpha,
                          training_data_sampled) {
  # 1) Define neuron sets
  input_neurons  <- c("A", "B", "C", "D")
  output_neurons <- c("E")
  hidden_neurons <- paste0("H", seq_len(num_hidden_neurons))
  num_inputs  <- length(input_neurons)
  num_hidden  <- num_hidden_neurons
  num_outputs <- length(output_neurons)
  # 2) Initialize weights & biases
  W_input_hidden <- matrix(
    rnorm(num_inputs * num_hidden, mean = 0, sd = 0.1),
    nrow = num_inputs, ncol = num_hidden,
    dimnames = list(input_neurons, hidden_neurons)
  )
  W_hidden_output <- matrix(
    rnorm(num_hidden * num_outputs, mean = 0, sd = 0.1),
    nrow = num_hidden, ncol = num_outputs,
    dimnames = list(hidden_neurons, output_neurons)
  )
  bias_hidden <- setNames(runif(num_hidden, min = -0.1, max = 0.1), hidden_neurons)
  bias_output <- setNames(runif(num_outputs, min = -0.1, max = 0.1), output_neurons)
  previous_loss <- Inf
  # 3) Training loop
  for (epoch in seq_len(num_epochs)) {
    total_loss <- 0
    # Shuffle the data inside each epoch
    training_data_epoch <- sample(training_data_sampled)
    for (trial in training_data_epoch) {
      activations_input <- trial$input  # e.g. c(A=..., B=..., C=..., D=...)
      target_output     <- trial$target # c(E = ±1)
      # Written here as a 0/1 vector so we can use ifelse() below; the
      # -1/+1 notation in the thesis chapter is more natural
      explanation_vec <- trial$explanation[input_neurons]
      # Single-alpha attention scheme:
      # e^(alpha) for mentioned, e^(-alpha) for unmentioned
      attention_vector <- ifelse(explanation_vec == 1, exp(alpha), exp(-alpha))
      # Adjust inputs
      activations_input_adj <- activations_input * attention_vector
      # Forward pass
      net_hidden <- as.vector(activations_input_adj %*% W_input_hidden + bias_hidden)
      act_hidden <- tanh_activation(net_hidden)
      net_output <- as.vector(act_hidden %*% W_hidden_output + bias_output)
      act_output <- tanh_activation(net_output)
      # Error
      error_output <- act_output - target_output
      loss_sample  <- sum(error_output^2) / num_outputs
      total_loss   <- total_loss + loss_sample
      # Backprop
      delta_output <- error_output * tanh_derivative(net_output)
      delta_hidden <- as.vector((W_hidden_output %*% delta_output) *
                                  tanh_derivative(net_hidden))
      # Gradients (outer products)
      grad_W_input_hidden  <- activations_input_adj %*% t(delta_hidden)
      grad_bias_hidden     <- delta_hidden
      grad_W_hidden_output <- act_hidden %*% t(delta_output)
      grad_bias_output     <- delta_output
      # Update (no regularization)
      W_hidden_output <- W_hidden_output - learning_rate * grad_W_hidden_output
      bias_output     <- bias_output     - learning_rate * grad_bias_output
      W_input_hidden  <- W_input_hidden  - learning_rate * grad_W_input_hidden
      bias_hidden     <- bias_hidden     - learning_rate * grad_bias_hidden
    }
    # Early stopping
    avg_loss <- total_loss / length(training_data_sampled)
    if (abs(previous_loss - avg_loss) < threshold) {
      break
    }
    previous_loss <- avg_loss
  }
  # 4) Store final network
  trained_network <- list(
    W_input_hidden  = W_input_hidden,
    W_hidden_output = W_hidden_output,
    bias_hidden     = bias_hidden,
    bias_output     = bias_output
  )
  # 5) Evaluate on the test data with the final network
  total_error_original <- 0
  for (test_sample in testing_data) {
    test_inp   <- test_sample$input
    target_out <- test_sample$target
    net_h <- as.vector(test_inp %*% W_input_hidden + bias_hidden)
    act_h <- tanh_activation(net_h)
    net_o <- as.vector(act_h %*% W_hidden_output + bias_output)
    act_o <- tanh_activation(net_o)
    # Binarize the prediction
    pred <- ifelse(act_o > 0, +1, -1)
    err  <- sum(abs(pred - target_out)) / 2
    total_error_original <- total_error_original + err
  }
  num_samples_test  <- length(testing_data)
  accuracy_original <- (num_samples_test - total_error_original)
  # 6) Return final object
  list(
    W_input_hidden    = W_input_hidden,
    W_hidden_output   = W_hidden_output,
    bias_hidden       = bias_hidden,
    bias_output       = bias_output,
    accuracy_original = accuracy_original
  )
}

This wraps around the network training function to call it and collect the results.
### Then a function to aggregate the results of several networks as we generate predictions out of each of them:
process_distribution <- function(dist_name,
                                 sampled_epochs,
                                 num_hidden_neurons,
                                 alpha,
                                 training_data) {
  # Filter out invalid epochs
  sampled_epochs <- sampled_epochs[sampled_epochs > 0 & sampled_epochs <= max_epochs]
  if (length(sampled_epochs) == 0) {
    warning("No valid epochs for ", dist_name)
    return(NULL)
  }
  results_list <- list()
  for (num_epochs in sampled_epochs) {
    # 1) Generate subject-specific training data
    training_data_subj <- generate_subject_training_data(training_data)
    # 2) Train
    res <- train_network(
      num_epochs            = num_epochs,
      num_hidden_neurons    = num_hidden_neurons,
      alpha                 = alpha,
      training_data_sampled = training_data_subj
    )
    results_list[[length(results_list) + 1]] <- res
  }
  # 3) Build data frame from results
  df <- data.frame(
    Epochs            = sampled_epochs,
    Accuracy_Original = sapply(results_list, function(x) x$accuracy_original),
    Distribution      = dist_name,
    NumHiddenNeurons  = num_hidden_neurons,
    Alpha             = alpha
  )
  df
}

The attention-based model predicts a specific ordering of accuracy across conditions: AC < OBS < CS. This matches the experimental data, while alternative models (e.g., Bayesian rule inference) predict different orderings.
# =============================================================================
# TOY PARAMETERS SIMULATION (α=1, λ=1000)
# Results are cached to avoid re-running expensive simulations
# =============================================================================
cache_file_toy <- "cache/toy_params_results.rds"
# Check if cached results exist
if (file.exists(cache_file_toy)) {
cat("Loading cached toy parameters results...\n")
cached_toy <- readRDS(cache_file_toy)
df_AC <- cached_toy$df_AC
df_CS <- cached_toy$df_CS
df_OBS <- cached_toy$df_OBS
} else {
cat("Running toy parameters simulation (n=1000)...\n")
# Parameters for this demonstration
rate <- 1000
num_samples <- 1000 # Full simulation
num_hidden <- 4
alpha <- 1
# Generate the distribution of epochs
set.seed(999)
sampled_epochs <- rexp(num_samples, rate=rate)
sampled_epochs <- pmax(1, round(sampled_epochs))
# (a) AC condition, alpha=1
df_AC <- process_distribution(
dist_name = "AC_alpha1",
sampled_epochs = sampled_epochs,
num_hidden_neurons = num_hidden,
alpha = alpha,
training_data = training_data_AC
)
# (b) CS condition, alpha=1
df_CS <- process_distribution(
dist_name = "CS_alpha1",
sampled_epochs = sampled_epochs,
num_hidden_neurons = num_hidden,
alpha = alpha,
training_data = training_data_CS
)
# (c) Observation-only condition => alpha=0
df_OBS <- process_distribution(
dist_name = "OBS_alpha0",
sampled_epochs = sampled_epochs,
num_hidden_neurons = num_hidden,
alpha = 0,
training_data = training_data_AC
)
# Save to cache
dir.create("cache", showWarnings = FALSE)
saveRDS(list(df_AC = df_AC, df_CS = df_CS, df_OBS = df_OBS), cache_file_toy)
cat("Results cached to", cache_file_toy, "\n")
}
## Loading cached toy parameters results...
# 4) Combine
df_combined <- bind_rows(df_AC, df_OBS, df_CS)
# 5) Plot distributions of accuracy on a violin plot with means
# Tag each distribution with a Condition label
df_combined <- df_combined %>%
  mutate(Condition = case_when(
    grepl("AC", Distribution) ~ "AC",
    grepl("OBS", Distribution) ~ "OBS",
    grepl("CS", Distribution) ~ "CS",
    TRUE ~ "Unknown"
  ))
df_combined$Condition <- factor(df_combined$Condition, levels=c("AC","OBS","CS"))
# Colors
custom_colors <- c("AC"="red", "OBS"="blue", "CS"="green")
# Summarise means
df_means_model <- df_combined %>%
  group_by(Condition) %>%
  summarise(mean_acc = mean(Accuracy_Original))
print(df_means_model)
## # A tibble: 3 × 2
## Condition mean_acc
## <fct> <dbl>
## 1 AC 13.1
## 2 OBS 13.3
## 3 CS 13.6
model_acc <- ggplot(df_combined, aes(x = Condition, y = Accuracy_Original, fill = Condition)) +
  geom_violin(width = 1, alpha = 0.4) +
  geom_point(data = df_means_model,
             aes(x = Condition, y = mean_acc),
             shape = 23, size = 3, fill = "red", color = "black") +
  geom_text(data = df_means_model,
            aes(x = Condition, y = mean_acc,
                label = round(mean_acc, 2)),
            vjust = -1, color = "black", fontface = "bold") +
  scale_fill_manual(values = custom_colors) +
  labs(
    title = "Model accuracy distributions for alpha = 1, lambda = 1000",
    x = "Condition", y = "Accuracy Score"
  ) +
  theme_minimal()

Dissertation Reference: This demonstration plot corresponds to Figure 2.5 in Chapter II. The toy parameters (α=1, λ=1000) show the model captures the correct qualitative pattern before parameter fitting.
Model accuracy distributions with toy parameters (α=1, λ=1000). The ordering AC < OBS < CS matches the experimental data.
Two things are apparent from visual inspection: the model reproduces the qualitative ordering AC < OBS < CS observed in the data, but the absolute accuracy levels and the gaps between conditions do not yet match the observed means. We therefore compare model and data quantitatively, starting from the empirical condition statistics.
df_stats_emp <- df_subject %>%
  group_by(Condition) %>%
  summarise(
    observed_mean = mean(mean_accuracy, na.rm = TRUE),
    observed_sd   = sd(mean_accuracy, na.rm = TRUE),
    .groups = "drop"
  )
# Extract each condition’s stats
mu_AC <- df_stats_emp$observed_mean[df_stats_emp$Condition=="AC"]
mu_OBS <- df_stats_emp$observed_mean[df_stats_emp$Condition=="OBS"]
mu_CS <- df_stats_emp$observed_mean[df_stats_emp$Condition=="CS"]
### The SDs are collected just for inspection
sigma_AC <- df_stats_emp$observed_sd[df_stats_emp$Condition=="AC"]
sigma_OBS <- df_stats_emp$observed_sd[df_stats_emp$Condition=="OBS"]
sigma_CS <- df_stats_emp$observed_sd[df_stats_emp$Condition=="CS"]
cat("Observed Means:\n")
cat("AC =", mu_AC, "\nOBS =", mu_OBS, "\nCS =", mu_CS, "\n")
cat("Observed SDs:\n")
cat("AC =", sigma_AC, "\nOBS =", sigma_OBS, "\nCS =", sigma_CS, "\n")
## Observed Means:
## AC = 11.0102
## OBS = 11.90291
## CS = 12.76289
## Observed SDs:
## AC = 1.880419
## OBS = 2.315822
## CS = 2.330831
get_predicted_stats <- function(condition,
                                exp_rate,
                                alpha,
                                num_samples = 50,
                                num_hidden = 4,
                                training_data_AC,
                                training_data_CS) {
  # 1) Pick the appropriate training data
  if (condition == "AC") {
    base_data <- training_data_AC
  } else if (condition == "CS") {
    base_data <- training_data_CS
  } else if (condition == "OBS") {
    base_data <- training_data_AC
  } else {
    stop("Condition must be AC, CS, or OBS")
  }
  # 2) If condition == "OBS", force alpha = 0
  if (condition == "OBS") {
    alpha_used <- 0
  } else {
    alpha_used <- alpha
  }
  # 3) Sample the epoch distribution
  set.seed(999) # optional, for reproducibility
  sampled_epochs <- rexp(num_samples, rate = exp_rate)
  sampled_epochs <- pmax(1, round(sampled_epochs))
  dist_name <- paste0(condition, "_Alpha_", alpha_used, "_Exp_", exp_rate)
  # 4) Call the aggregator
  df_res <- process_distribution(
    dist_name          = dist_name,
    sampled_epochs     = sampled_epochs,
    num_hidden_neurons = num_hidden,
    alpha              = alpha_used,
    training_data      = base_data
  )
  # 5) From df_res, get the distribution of Accuracy_Original
  predicted_mean <- mean(df_res$Accuracy_Original, na.rm = TRUE)
  # predicted_sd <- sd(df_res$Accuracy_Original, na.rm = TRUE)
  return(list(
    mean = predicted_mean
    # sd = predicted_sd
  ))
}

A function to compute the SSE given those stats:
sse_total <- function(M_AC, M_OBS, M_CS,
                      mu_AC, mu_OBS, mu_CS) {
  # Difference-based SSE (3 pairwise differences)
  diff_ac_obs_model <- (M_AC - M_OBS)
  diff_cs_obs_model <- (M_CS - M_OBS)
  diff_cs_ac_model  <- (M_CS - M_AC)
  diff_ac_obs_data  <- (mu_AC - mu_OBS)
  diff_cs_obs_data  <- (mu_CS - mu_OBS)
  diff_cs_ac_data   <- (mu_CS - mu_AC)
  SSE_diff <- (diff_ac_obs_model - diff_ac_obs_data)^2 +
              (diff_cs_obs_model - diff_cs_obs_data)^2 +
              (diff_cs_ac_model  - diff_cs_ac_data)^2
  return(SSE_diff)
}

We examine how accuracy changes as a function of the attention parameter α, with the number of epochs fixed to 1 (equivalent to λ=1000).
Dissertation Reference: This analysis corresponds to Figure 2.7 in Chapter II, showing how both CS and AC conditions benefit from low attention levels, but diverge as α increases.
############################################################
## Evolution of accuracy as function of alpha (epochs fixed to 1)
## Results are cached to avoid re-running expensive simulations
############################################################
cache_file_alpha <- "cache/alpha_analysis_results.rds"
# Check if cached results exist
if (file.exists(cache_file_alpha)) {
cat("Loading cached alpha analysis results...\n")
cached_alpha <- readRDS(cache_file_alpha)
results_ac <- cached_alpha$results_ac
results_cs <- cached_alpha$results_cs
results_obs <- cached_alpha$results_obs
} else {
cat("Running alpha analysis simulations (n=500 per condition per alpha)...\n")
alpha_values <- c(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1, 1.5, 2)
# Store results
results_ac <- data.frame(alpha=numeric(), mean_accuracy=numeric(), Condition=character())
results_cs <- data.frame(alpha=numeric(), mean_accuracy=numeric(), Condition=character())
results_obs <- data.frame(alpha=numeric(), mean_accuracy=numeric(), Condition=character())
# For each alpha, run 500 simulations with 1 epoch each (matching dissertation)
num_runs <- 500
fixed_epochs <- rep(1, num_runs)
for (a in alpha_values) {
cat(" Processing alpha =", a, "\n")
# AC condition
df_ac <- process_distribution(
dist_name = paste0("AC_alpha=", round(a, 3)),
sampled_epochs = fixed_epochs,
num_hidden_neurons = 4,
alpha = a,
training_data = training_data_AC
)
mean_acc_ac <- mean(df_ac$Accuracy_Original)
results_ac <- rbind(results_ac, data.frame(alpha=a, mean_accuracy=mean_acc_ac, Condition="AC"))
# CS condition
df_cs <- process_distribution(
dist_name = paste0("CS_alpha=", round(a, 3)),
sampled_epochs = fixed_epochs,
num_hidden_neurons = 4,
alpha = a,
training_data = training_data_CS
)
mean_acc_cs <- mean(df_cs$Accuracy_Original)
results_cs <- rbind(results_cs, data.frame(alpha=a, mean_accuracy=mean_acc_cs, Condition="CS"))
# OBS condition (always alpha=0 in practice, but we compute for reference)
df_obs_alpha <- process_distribution(
dist_name = paste0("OBS_alpha=", round(a, 3)),
sampled_epochs = fixed_epochs,
num_hidden_neurons = 4,
alpha = 0, # OBS always has no attention
training_data = training_data_AC
)
mean_acc_obs <- mean(df_obs_alpha$Accuracy_Original)
results_obs <- rbind(results_obs, data.frame(alpha=a, mean_accuracy=mean_acc_obs, Condition="OBS"))
}
# Save to cache
dir.create("cache", showWarnings = FALSE)
saveRDS(list(results_ac = results_ac, results_cs = results_cs, results_obs = results_obs), cache_file_alpha)
cat("Results cached to", cache_file_alpha, "\n")
}
## Loading cached alpha analysis results...
alpha_values <- c(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1, 1.5, 2)
# Combine results
results_combined <- bind_rows(results_ac, results_cs)
# Plot
alpha_plot <- ggplot(results_combined, aes(x = alpha, y = mean_accuracy, color = Condition)) +
  geom_line(size = 1.2) +
  geom_point(size = 2.5) +
  # Add OBS baseline as horizontal dashed line
  geom_hline(yintercept = mean(results_obs$mean_accuracy),
             linetype = "dashed", color = "blue", alpha = 0.7) +
  annotate("text", x = max(alpha_values) * 0.9,
           y = mean(results_obs$mean_accuracy) + 0.02,
           label = "OBS baseline", color = "blue", size = 3) +
  scale_color_manual(values = c("AC" = "red", "CS" = "green"),
                     labels = c("AC" = "Other explanations", "CS" = "Causal selection")) +
  labs(
    title = "Accuracy vs. Attention Parameter (epochs = 1)",
    x = expression(paste("Attention parameter ", alpha)),
    y = "Mean Accuracy",
    color = "Condition"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    legend.position = "bottom",
    panel.grid.minor = element_blank()
  )
alpha_plot

Mean accuracy as a function of the attention parameter α. At low α, both conditions benefit from attention. At higher α, CS continues to improve while AC deteriorates.
Three patterns emerge from this analysis: at α = 0, both explanation conditions coincide with the OBS baseline; at low α, both CS and AC benefit from attention; and as α increases, the two conditions diverge, with CS continuing to improve while AC deteriorates.
This pattern supports the attention-based account: explanations help when they highlight relevant variables (CS), but hurt when they highlight irrelevant ones (AC).