Introduction

💡 Key Insight

How do causal explanations help (or hinder) learning? This document presents a computational model proposing that explanations function as attention masks during error-driven learning—they selectively weight which causal variables are updated when prediction errors occur.

This model was developed in Chapter II of my PhD dissertation (Causal Explanations and Continuous Computation, 2025). It applies to experimental data from a rule-inference paradigm where participants learned causal rules from observations, either with or without accompanying causal selection explanations.

Causal selection explanations cite specific causes from among multiple sufficient causes (e.g., “The match caused the fire” rather than “The oxygen caused the fire”). In our experiments, we compared learning outcomes when participants received causal selection explanations versus other types of explanations versus no explanations at all.

In earlier work (Navarre, Konuk, Bramley & Mascarenhas, 2024), we proposed a model of inferences from explanation inspired by the Rational Speech Act (RSA) framework (Frank & Goodman, 2012; Goodman & Frank, 2016). It says that listeners (explainees) use explanations in learning by:

  1. Reverse-engineering the explainer’s causal beliefs based on the explanations: “If the explainer said ‘because of C’, then they must believe that the ground truth causal model is M (relying on shared conventions for what explainers tend to say given their causal beliefs)”
  2. Using that information to update their own beliefs about the ground-truth causal system using Bayesian inference

Here I propose to capture the same patterns of inference using a different (neural) learning architecture: people learn from observations by updating an internal model of the causal system via gradient-based learning procedures. The key twist is a much simpler mechanism for explanation effects: when a variable is mentioned in a (causal selection) explanation, it inflects that learning process by amplifying the input activation of the neurons encoding the corresponding variable(s) by some constant factor. This can be modeled very easily using attention masks. Below I show an implementation that captures the results of our experiment and then some (see dissertation for a more detailed discussion of the advantages of the model). To make the proof-of-concept more salient, I make the basic learning architecture as minimal as possible.


1 Data Extraction

1.1 Extracting and cleaning raw data

First, let’s fetch the experimental data of responses per condition and plot it here again for reference.

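# ## LOAD DATA

# Packages used throughout: dplyr (bind_rows, %>%, filter), ggplot2 (plots), gamlss (RT fits)
library(dplyr)
library(ggplot2)
library(gamlss)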

## The extraction procedures are simpler than in the script that accompanied the publication, because we skip several analyses here (e.g., plotting the most popular rules). For those, refer to the other repo.

#1. Simplified multmerge function.

multmerge <- function(mypath) {
  # Get only .csv files
  filenames <- list.files(path = mypath, pattern = "\\.csv$", full.names = TRUE)
  
  if (length(filenames) == 0) {
    stop("No CSV files found in: ", mypath)
  }
  
  # Read each .csv into a list of data frames
  datalist <- lapply(filenames, function(x) {
    # or read_csv(x) if you want
    read.csv(x, header = TRUE, na.strings = c("", "NA"))
  })
  
  # Combine them into one data frame
  merged_df <- bind_rows(datalist)
  
  return(merged_df)
}

#2. Extracting the data 


data_path <- "../data-files/actual-cause-grid-1"  # Relative to models/ folder


if (!dir.exists(data_path)) {
  stop("Directory not found: ", data_path, 
       ". Make sure you downloaded data-files in the same parent folder as code-files.")
}


the_data <- multmerge(data_path)



df_test <- the_data %>%
  filter(questionId == "trial_test")


#3. Simplified parser

parse_json_array <- function(x) {
  # Remove brackets
  x2 <- gsub("\\[|\\]", "", x)
  if (nchar(x2) == 0) return(NA)  # or return numeric(0)
  
  # Split by comma
  as.numeric(strsplit(x2, ",")[[1]])
}

# Apply it to each row in df_test
df_test$parsed_key <- lapply(df_test$test_key, parse_json_array)
df_test$parsed_resp <- lapply(df_test$response, parse_json_array)


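# Score each test trial. Despite its name, `proportion_correct` holds the
# number of matching items (a count), not a proportion.
df_test$proportion_correct <- mapply(function(key_vec, resp_vec) {
  # a) If either vector contains NA, or the lengths differ, return NA
  if (any(is.na(key_vec)) || any(is.na(resp_vec))) return(NA_real_)
  if (length(key_vec) != length(resp_vec)) return(NA_real_)
  
  # b) Compare element by element and count the matches
  sum(key_vec == resp_vec)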
}, df_test$parsed_key, df_test$parsed_resp)


df_subject <- df_test %>%
  group_by(subject_ID, group) %>%
  summarise(
    mean_accuracy = mean(proportion_correct, na.rm=TRUE),
    .groups = "drop"
  )


df_subject$Condition <- factor(df_subject$group,
   levels = c(1,2,3),
   labels = c("OBS", "AC","CS")
)

# Reorder the factor so that AC < OBS < CS
df_subject$Condition <- factor(df_subject$Condition,
   levels = c("AC","OBS","CS")
)

df_means <- df_subject %>%
  group_by(Condition) %>%
  summarise(mean_acc = mean(mean_accuracy, na.rm=TRUE), .groups="drop")



custom_colors <- c("AC"="red","OBS"="blue","CS"="green")

accuracy_plot <- ggplot(df_subject, aes(x=Condition, y=mean_accuracy, fill=Condition)) +
  geom_violin(width=1, alpha=0.4) +

  # Overlay a point for the mean
  geom_point(data=df_means,
             aes(x=Condition, y=mean_acc),
             shape=23, size=3, fill="red", color="black") +

  # Label with numeric means
  geom_text(data=df_means,
            aes(x=Condition, y=mean_acc, label=round(mean_acc,2)),
            vjust=-1, color="black", fontface="bold") +

  scale_fill_manual(values=custom_colors) +

  labs(
    title="Observed Accuracy by Condition",
    x="Condition",
    y="Mean Accuracy"
  ) +
  theme_minimal()

Dissertation Reference: This figure corresponds to Figure 2.3 in Chapter II of the dissertation (Causal Explanations and Continuous Computation, 2025). The conditions are: OBS (Observations Only), AC (Any Cause explanations), and CS (Causal Selection explanations).

accuracy_plot
Observed accuracy distribution by experimental condition. Diamond markers indicate condition means.


1.2 Response Time Distributions

💡 Why Response Times Matter

The distribution of response times during the learning phase informs how we model individual differences in the number of “learning epochs.” The exponential component of the RT distribution suggests exponentially distributed learning durations across participants.

📊 Technical Details: RT Distribution Analysis

The relevant thesis chapter discusses the distribution of response times for the crucial trial (trial_observations), where people sample from urns and judge the underlying rule from their observations (plus explanations). Of the three candidate distributions fitted below, the ex-Gaussian (a Gaussian convolved with an exponential) fits best by both AIC and BIC, which is what motivates the exponential component. The relevant code is in the cell below.

############################################################
##  RT Distributions 
############################################################

# We'll:
# (1) Identify the RT data for relevant trials and plot the distribution
# (2) Fit normal, ex-Gaussian, and exponential distributions
# (3) Compare fits via AIC/BIC

############################################################




crucial_data <- the_data %>%
  filter(questionId == "trial_observations")

# One extreme outlier with a response time above 30 min is removed. This subject's responses are not otherwise unusual (so it may not be a true outlier), and removing it is conservative with respect to the ex-Gaussian hypothesis, which would be favored even more strongly if it were included.

crucial_data <- crucial_data %>%
  filter(rt > 0, !is.na(rt), rt < 200000) # drop invalid RTs and the extreme outlier (cutoff: 200,000 ms)

options(scipen = 999)
# 2) Quick look at the distribution
rt_plot <- ggplot(crucial_data, aes(x=rt)) +
  geom_density() +
  labs(
    title="Participant response time distributions for learning trial",
    x="Response Time (ms)",
    y="Density"
  ) +
  theme_minimal()

rt_plot

# ggsave("RT_distribution.pdf", plot= rt_plot, width = 6, height = 4)

# 3) Extract the vector of RT
rt_data <- crucial_data$rt

# 4) Fit distributions using GAMLSS:
#    - Normal (NO)
#    - Ex-Gaussian (exGAUS)
#    - Exponential (EXP)
fit_norm <- gamlss(rt_data ~ 1, family=NO)      # Normal
## GAMLSS-RS iteration 1: Global Deviance = 6794.479 
## GAMLSS-RS iteration 2: Global Deviance = 6794.479
fit_exg  <- gamlss(rt_data ~ 1, family=exGAUS)  # ex-Gaussian
## GAMLSS-RS iteration 1: Global Deviance = 6743.017 
## GAMLSS-RS iteration 2: Global Deviance = 6724.212 
## GAMLSS-RS iteration 3: Global Deviance = 6709.461 
## GAMLSS-RS iteration 4: Global Deviance = 6702.68 
## GAMLSS-RS iteration 5: Global Deviance = 6700.124 
## GAMLSS-RS iteration 6: Global Deviance = 6699.248 
## GAMLSS-RS iteration 7: Global Deviance = 6698.963 
## GAMLSS-RS iteration 8: Global Deviance = 6698.869 
## GAMLSS-RS iteration 9: Global Deviance = 6698.838 
## GAMLSS-RS iteration 10: Global Deviance = 6698.829 
## GAMLSS-RS iteration 11: Global Deviance = 6698.826 
## GAMLSS-RS iteration 12: Global Deviance = 6698.825 
## GAMLSS-RS iteration 13: Global Deviance = 6698.824
fit_exp  <- gamlss(rt_data ~ 1, family=EXP)     # Exponential
## GAMLSS-RS iteration 1: Global Deviance = 6852.432 
## GAMLSS-RS iteration 2: Global Deviance = 6852.432
# Compare via AIC or BIC
cat("\n## Model comparison: AIC\n")
## 
## ## Model comparison: AIC
print(AIC(fit_norm, fit_exg, fit_exp))
##          df      AIC
## fit_exg   3 6704.824
## fit_norm  2 6798.479
## fit_exp   1 6854.432
cat("\n## Model comparison: BIC\n")
## 
## ## Model comparison: BIC
print(BIC(fit_norm, fit_exg, fit_exp))
##          df      BIC
## fit_norm  2 6805.812
## fit_exg   3 6715.824
## fit_exp   1 6858.098
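
As a sanity check on the comparison above, the AIC values can be recomputed by hand from the final global deviances in the log, using AIC = deviance + 2 × (number of parameters). This is only a sketch that hard-codes the numbers printed above; the names dev, n_par, aic and delta_aic are illustrative and do not appear in the original script.

```r
# Final global deviances and parameter counts, copied from the GAMLSS output above
dev   <- c(exGAUS = 6698.824, NO = 6794.479, EXP = 6852.432)
n_par <- c(exGAUS = 3,        NO = 2,        EXP = 1)

# AIC = deviance + 2 * (number of fitted parameters)
aic <- dev + 2 * n_par

# Difference from the best (lowest-AIC) model
delta_aic <- aic - min(aic)

print(round(aic, 3))
print(round(delta_aic, 3))
```

The recomputed values match the printed AIC table, and the ex-Gaussian comes out ahead of the normal by roughly 94 AIC units and ahead of the exponential by roughly 150.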

2 Building the Model

Model Architecture

The model is a simple feed-forward neural network with:

  • 4 input neurons (A, B, C, D) representing causal variables
  • 4 hidden neurons with tanh activation
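  • 1 output neuron with tanh activation (W in the diagram, E in the code) representing the outcome (win/lose)

The key innovation is the attention mechanism: explanations rescale input activations during training, which in turn changes how strongly backpropagation updates the weights attached to each variable. They effectively act as attention masks that weight which variables are updated when errors occur.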

2.1 Network Architecture Diagram

Feedforward neural network architecture. Input nodes (A, B, C, D) correspond to urn draws, hidden layer processes feature combinations, and output node W predicts win/lose outcome.


2.2 The Attention Mechanism

The attention mechanism is the core innovation. Each explanation \(\xi\) highlights a subset of input variables as causally relevant. We formalize this as an attention mask applied to input activations.

2.2.1 Mathematical Formalization

For each variable \(x \in \{A, B, C, D\}\), the explanation \(\xi\) assigns:

\[\xi(x) = \begin{cases} +1 & \text{if variable } x \text{ is mentioned in the explanation} \\ -1 & \text{otherwise} \end{cases}\]

The attention weight for variable \(x\) given explanation \(\xi\) is:

\[\text{Att}(x, \xi) = \exp\bigl[\alpha \cdot \xi(x)\bigr]\]

where \(\alpha\) is the attention parameter controlling the strength of the attention effect.

  • If \(x\) is mentioned: weight = \(e^{\alpha}\) (amplified)
  • If \(x\) is not mentioned: weight = \(e^{-\alpha}\) (attenuated)

The network’s input activation becomes:

\[x_{\text{input}} \leftarrow (\pm 1) \times \text{Att}(x, \xi)\]
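
To make the formalization concrete, here is a minimal sketch of the attention weights for a single observation, assuming the explanation mentions only C and α = 1. The helper attention_weights() is a hypothetical name introduced for illustration; the training function later in this document computes the same quantity inline from a 0/1 vector.

```r
# Attention weights for one observation, following Att(x, xi) = exp(alpha * xi(x))
input_neurons <- c("A", "B", "C", "D")

# Hypothetical helper (not part of the original script)
attention_weights <- function(mentioned, alpha) {
  # xi(x) = +1 for mentioned variables, -1 otherwise
  xi <- setNames(ifelse(input_neurons %in% mentioned, 1, -1), input_neurons)
  exp(alpha * xi)
}

x   <- c(A = 1, B = -1, C = 1, D = -1)   # one observation, coded as +/-1
att <- attention_weights(mentioned = "C", alpha = 1)

x_adj <- x * att        # attended inputs fed to the network
print(round(x_adj, 3))

# With alpha = 0 every weight is exp(0) = 1, so the inputs are unchanged
stopifnot(all(attention_weights("C", alpha = 0) == 1))
```

Here C enters the network with magnitude e ≈ 2.72, while A, B and D enter with magnitude 1/e ≈ 0.37, each keeping its original sign.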

Attention mechanism: Explanation 'C caused the outcome' amplifies input C (larger node) while attenuating A, B, D (smaller nodes).


2.2.2 Loss Function and Learning Rule

At each observation \(\omega\), the model computes the squared error:

\[L_{\omega} = \bigl(y_{\omega} - \hat{y}_{\omega}\bigr)^2\]

where \(y_{\omega}\) is the target (\(\pm 1\) for win/lose) and \(\hat{y}_{\omega}\) is the network’s prediction.

Parameters are updated via gradient descent:

\[w_i \leftarrow w_i - \eta \cdot \frac{\partial L_{\omega}}{\partial w_i}\]

The attention mechanism affects this update by modulating input activations, which in turn affects gradient magnitudes for weights connected to attended vs. unattended inputs.
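
The effect on gradient magnitudes can be checked numerically. The sketch below deliberately simplifies the network to a single tanh unit (an assumption made for illustration, not the 4-4-1 architecture used in this document), so that the gradient has a closed form. The weights are arbitrary. Note that the training code in this document omits the constant factor 2 from the derivative of the squared error, which only rescales the learning rate.

```r
# Single tanh unit: y_hat = tanh(sum(w * x_adj) + b), loss L = (y - y_hat)^2
alpha <- 1
x     <- c(A = 1, B = -1, C = 1, D = -1)
xi    <- c(A = -1, B = -1, C = 1, D = -1)    # explanation mentions C only
x_adj <- x * exp(alpha * xi)                 # attended inputs

w <- c(A = 0.05, B = -0.02, C = 0.01, D = 0.03)  # arbitrary small weights
b <- 0
y <- 1                                           # target: win

y_hat <- tanh(sum(w * x_adj) + b)

# dL/dw_i = -2 * (y - y_hat) * (1 - y_hat^2) * x_adj_i
grad_w <- -2 * (y - y_hat) * (1 - y_hat^2) * x_adj

# Ratio of gradient magnitudes: mentioned (C) vs. unmentioned (A)
ratio <- abs(grad_w[["C"]]) / abs(grad_w[["A"]])
print(ratio)
```

Because each weight's gradient is proportional to its attended input, the weight on the mentioned variable is updated exp(2α) times more strongly than the weight on any unmentioned variable in this simplified setting.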

Dissertation Reference: The full theoretical motivation and mathematical derivation appear in Chapter II, Sections 2.4–2.6 of the dissertation. See Algorithm 6 (Attention-Based Inductive Learning) for the complete procedure.

2.3 Data structures

2.3.1 Condition-Specific Training Data

We build two lists of training data, one for the Causal Selection condition, one for the Any Cause condition. They differ only with respect to the list of explanations that accompany each observation.

It is not necessary to build a list for the “Observation only” condition, because the model for that condition is a special case of the model for the other two, in which the learning pipeline does not depend on the explanations given. We implement that condition by feeding the training function the CS or AC training data while forcing the alpha parameter to zero, so that every attention weight equals exp(0) = 1.

Each dataset is a list of the same four observations, with a repetition count for each observation that matches the number of times it was presented to subjects, and a list of explanations attached to each observation that depends on the condition.

# 'CS' condition (Causal Selection) data
training_data_CS <- list(
  list(
    input = c(A=1,  B=-1, C=1,  D=-1), 
    target = c(E=1), 
    possible_explanations = c("C"),
    repetition_count = 4
  ),
  list(
    input = c(A=1,  B=1,  C=1,  D=1),  
    target = c(E=1), 
    possible_explanations = c("C"),
    repetition_count = 1
  ),
  list(
    input = c(A=-1, B=1,  C=-1, D=-1), 
    target = c(E=-1),  
    possible_explanations = c("C"),
    repetition_count = 1
  ),
  list(
    input = c(A=1,  B=1,  C=-1, D=-1), 
    target = c(E=-1),  
    possible_explanations = c("C&D"),
    repetition_count = 4
  )
)

# 'AC' condition (Any Cause) data
training_data_AC <- list(
  list(
    input = c(A=1,  B=-1, C=1,  D=-1), 
    target = c(E=1),
    possible_explanations = c("A&C"),
    repetition_count = 4
  ),
  list(
    input = c(A=1,  B=1,  C=1,  D=1),
    target = c(E=1),
    possible_explanations = c("A", "D", "A&D", "A&C", "C&D", "A&C&D"),
    repetition_count = 1
  ),
  list(
    input = c(A=-1, B=1,  C=-1, D=-1),
    target = c(E=-1),
    possible_explanations = c("A", "D", "A&D", "A&C", "A&C&D"),
    repetition_count = 1
  ),
  list(
    input = c(A=1,  B=1,  C=-1, D=-1),
    target = c(E=-1),
    possible_explanations = c("C", "D"),
    repetition_count = 4
  )
)

2.3.2 Testing data (common to all conditions)

# Testing data (16 possible input combos)

testing_data <- list(
  #  1) A= 1, B= 1, C= 1, D= 1 => E=1
  list(input = c(A=1,  B=1,  C=1,  D=1),  target = c(E=1)),
  
  #  2) A= 1, B= 1, C= 1, D=-1 => E=1
  list(input = c(A=1,  B=1,  C=1,  D=-1), target = c(E=1)),
  
  #  3) A= 1, B= 1, C=-1, D= 1 => E=1
  list(input = c(A=1,  B=1,  C=-1, D=1),  target = c(E=1)),
  
  #  4) A= 1, B= 1, C=-1, D=-1 => E=-1
  list(input = c(A=1,  B=1,  C=-1, D=-1), target = c(E=-1)),
  
  #  5) A= 1, B=-1, C= 1, D= 1 => E=1
  list(input = c(A=1,  B=-1, C=1,  D=1),  target = c(E=1)),
  
  #  6) A= 1, B=-1, C= 1, D=-1 => E=1
  list(input = c(A=1,  B=-1, C=1,  D=-1), target = c(E=1)),
  
  #  7) A= 1, B=-1, C=-1, D= 1 => E=1
  list(input = c(A=1,  B=-1, C=-1, D=1),  target = c(E=1)),
  
  #  8) A= 1, B=-1, C=-1, D=-1 => E=-1
  list(input = c(A=1,  B=-1, C=-1, D=-1), target = c(E=-1)),
  
  #  9) A=-1, B= 1, C= 1, D= 1 => E=1
  list(input = c(A=-1, B=1,  C=1,  D=1),  target = c(E=1)),
  
  # 10) A=-1, B= 1, C= 1, D=-1 => E=1
  list(input = c(A=-1, B=1,  C=1,  D=-1), target = c(E=1)),
  
  # 11) A=-1, B= 1, C=-1, D= 1 => E=-1
  list(input = c(A=-1, B=1,  C=-1, D=1),  target = c(E=-1)),
  
  # 12) A=-1, B= 1, C=-1, D=-1 => E=-1
  list(input = c(A=-1, B=1,  C=-1, D=-1), target = c(E=-1)),
  
  # 13) A=-1, B=-1, C= 1, D= 1 => E=1
  list(input = c(A=-1, B=-1, C=1,  D=1),  target = c(E=1)),
  
  # 14) A=-1, B=-1, C= 1, D=-1 => E=1
  list(input = c(A=-1, B=-1, C=1,  D=-1), target = c(E=1)),
  
  # 15) A=-1, B=-1, C=-1, D= 1 => E=-1
  list(input = c(A=-1, B=-1, C=-1, D=1),  target = c(E=-1)),
  
  # 16) A=-1, B=-1, C=-1, D=-1 => E=-1
  list(input = c(A=-1, B=-1, C=-1, D=-1), target = c(E=-1))
)
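
For readers who prefer a compact construction, the same 16 test items can be generated programmatically. The sketch below assumes the ground-truth rule is E = C OR (A AND D); that rule is inferred from the targets listed above rather than stated in the script, and the names grid and testing_data_sketch are illustrative.

```r
# All 16 combinations of A, B, C, D in {+1, -1}.
# expand.grid varies its first argument fastest, so D is listed first
# to reproduce the ordering above (A slowest, D fastest).
grid <- expand.grid(D = c(1, -1), C = c(1, -1), B = c(1, -1), A = c(1, -1))
grid <- grid[, c("A", "B", "C", "D")]

# Assumed rule, inferred from the listed targets: E = C OR (A AND D)
grid$E <- ifelse(grid$C == 1 | (grid$A == 1 & grid$D == 1), 1, -1)

# Convert each row into the list(input = ..., target = ...) format used above
testing_data_sketch <- lapply(seq_len(nrow(grid)), function(i) {
  list(
    input  = unlist(grid[i, c("A", "B", "C", "D")]),
    target = c(E = grid$E[i])
  )
})

print(table(grid$E))
```

Under this rule, 10 of the 16 items are wins and 6 are losses, which agrees with the hand-written list.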

2.3.3 Functions for data pre-processing

The lists above contain all the necessary information, but it has to be converted into a format that our training function below can read. That involves:

  1. “Unfolding” the repetition counts by turning the 4-observation lists into 10-item lists with appropriate repetitions

  2. Randomizing the order

  3. Turning the explanation into vector formats that can be used in the attention mechanism

We do that with the helper below:

###  Generate Subject Training Data:
### This takes one of the training-set lists and returns a set of 10 observations with the appropriate repetitions, along with the explanation vector associated with each observation, in accordance with the condition.
### It is run once per subject, so that the mapping between observations and explanations stays constant across training epochs for a given subject.


generate_subject_training_data <- function(base_training_data) {
  # For each observation in base_training_data, pick one explanation 
  # for all of its repeated trials, then shuffle.
  expanded_list <- list()
  input_neurons <- c("A","B","C","D")
  
  for (obs in base_training_data) {
    chosen_expl <- NULL
    if (!is.null(obs$possible_explanations) && length(obs$possible_explanations) > 0) {
      chosen_expl <- sample(obs$possible_explanations, 1)
    }
    
    # Build explanation vector
    if (!is.null(chosen_expl)) {
      vars <- strsplit(chosen_expl, "&")[[1]]
      vars <- trimws(vars)
      expl_vector <- setNames(rep(0, length(input_neurons)), input_neurons)
      expl_vector[vars] <- 1
    } else {
      expl_vector <- setNames(rep(0, length(input_neurons)), input_neurons)
    }
    
    # replicate obs for repetition_count times
    for (i in seq_len(obs$repetition_count)) {
      expanded_list[[length(expanded_list)+1]] <-
        list(input = obs$input,
             target = obs$target,
             explanation = expl_vector)
    }
  }
  
  # Shuffle
  expanded_list <- sample(expanded_list)
  return(expanded_list)
}
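
The explanation-encoding step can be illustrated in isolation. The sketch below converts an explanation string into the 0/1 vector used by the code and then into the +1/−1 coding of Section 2.2.1. The helper explanation_to_vector() is a hypothetical name, and the validity check on variable labels is an addition that the original helper does not perform.

```r
input_neurons <- c("A", "B", "C", "D")

# Hypothetical helper: turn a string such as "A&C" into a named 0/1 vector
explanation_to_vector <- function(expl) {
  vars <- trimws(strsplit(expl, "&", fixed = TRUE)[[1]])
  stopifnot(all(vars %in% input_neurons))   # added guard against mistyped labels
  v <- setNames(rep(0, length(input_neurons)), input_neurons)
  v[vars] <- 1
  v
}

v01 <- explanation_to_vector("A&C")   # 0/1 coding used by the code
xi  <- 2 * v01 - 1                    # +1/-1 coding used in Section 2.2.1

print(v01)
print(xi)
```

The two codings carry the same information; the 0/1 form is simply convenient for the ifelse() call in the training function.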

2.4 Training functions

2.4.1 Central network training function

We define the central network training pipeline. It consists of standard backpropagation training, with the addition of the attention mechanism:

# Helper functions:
tanh_activation <- function(x) {
  tanh(x)
}

tanh_derivative <- function(x) {
  1 - tanh(x)^2
}

# Global learning parameters
learning_rate <- 1     # Fixed to 1 throughout; exposed as a parameter for readers who want to experiment with it.
num_hidden_neurons <- 4 # Likewise fixed, and exposed for the same reason

### Early-stopping rule, as a safeguard against mistakes in setting the number of epochs; it should never be triggered in normal use.
threshold     <- 1e-6
max_epochs    <- 5000

#training function
train_network <- function(
    num_epochs,
    num_hidden_neurons,
    alpha,
    training_data_sampled
) {
  # 1) Define neuron sets
  input_neurons  <- c("A","B","C","D")
  output_neurons <- c("E")
  hidden_neurons <- paste0("H", seq_len(num_hidden_neurons))
  
  num_inputs  <- length(input_neurons)
  num_hidden  <- num_hidden_neurons
  num_outputs <- length(output_neurons)
  
  # 2) Initialize weights & biases
  W_input_hidden <- matrix(
    rnorm(num_inputs * num_hidden, mean=0, sd=0.1),
    nrow = num_inputs, ncol = num_hidden,
    dimnames = list(input_neurons, hidden_neurons)
  )
  W_hidden_output <- matrix(
    rnorm(num_hidden * num_outputs, mean=0, sd=0.1),
    nrow = num_hidden, ncol = num_outputs,
    dimnames = list(hidden_neurons, output_neurons)
  )
  
  bias_hidden <- setNames(runif(num_hidden, min=-0.1, max=0.1), hidden_neurons)
  bias_output <- setNames(runif(num_outputs, min=-0.1, max=0.1), output_neurons)
  
  previous_loss <- Inf
  
  # 3) Training loop
  for (epoch in seq_len(num_epochs)) {
    total_loss <- 0
    
    # Shuffle the data inside each epoch
    training_data_epoch <- sample(training_data_sampled)
    
    for (obs in training_data_epoch) {   # named `obs` rather than `sample` to avoid masking base::sample()
      activations_input <- obs$input   # e.g. c(A=..., B=..., C=..., D=...)
      target_output     <- obs$target  # c(E= ±1)
      # Explanation as a 0/1 vector so that ifelse() can be used below;
      # the -1/+1 notation of the thesis chapter is more natural but equivalent
      explanation_vec   <- obs$explanation[input_neurons]
      
      # Single-alpha attention scheme:
      # e^(alpha) for mentioned, e^(-alpha) for unmentioned
      attention_vector <- ifelse(explanation_vec == 1, exp(alpha), exp(-alpha))
      
      # Adjust inputs
      activations_input_adj <- activations_input * attention_vector
      
      # Forward pass
      net_hidden <- as.vector(activations_input_adj %*% W_input_hidden + bias_hidden)
      act_hidden <- tanh_activation(net_hidden)
      
      net_output <- as.vector(act_hidden %*% W_hidden_output + bias_output)
      act_output <- tanh_activation(net_output)
      
      # Error
      error_output <- act_output - target_output
      loss_sample  <- sum(error_output^2)/num_outputs
      total_loss   <- total_loss + loss_sample
      
      # Backprop
      delta_output <- error_output * tanh_derivative(net_output)
      delta_hidden <- as.vector((W_hidden_output %*% delta_output) *
                                tanh_derivative(net_hidden))
      
      # Gradients
      grad_W_input_hidden <- (activations_input_adj %*% t(delta_hidden))
      grad_bias_hidden    <- delta_hidden
      
      grad_W_hidden_output<- (act_hidden %*% t(delta_output))
      grad_bias_output    <- delta_output
      
      # Update (no regularization)
      W_hidden_output <- W_hidden_output - learning_rate * grad_W_hidden_output
      bias_output     <- bias_output     - learning_rate * grad_bias_output
      
      W_input_hidden  <- W_input_hidden  - learning_rate * grad_W_input_hidden
      bias_hidden     <- bias_hidden     - learning_rate * grad_bias_hidden
    }
    
    # Early stopping
    avg_loss <- total_loss / length(training_data_sampled)
    if (abs(previous_loss - avg_loss) < threshold) {
      break
    }
    previous_loss <- avg_loss
  }
  
  # 4) Store final network
  trained_network <- list(
    W_input_hidden = W_input_hidden,
    W_hidden_output= W_hidden_output,
    bias_hidden    = bias_hidden,
    bias_output    = bias_output
  )
  
  # 5) Evaluate on the test data with the final network
  total_error_original <- 0
  
  for (test_sample in testing_data) {
    test_inp   <- test_sample$input
    target_out <- test_sample$target
    
    net_h  <- as.vector(test_inp %*% W_input_hidden + bias_hidden)
    act_h  <- tanh_activation(net_h)
    net_o  <- as.vector(act_h %*% W_hidden_output + bias_output)
    act_o  <- tanh_activation(net_o)
    
    # binarize
    pred   <- ifelse(act_o>0, +1, -1)
    err    <- sum(abs(pred - target_out))/2
    total_error_original <- total_error_original + err
  }
  
  # Accuracy is stored as the number of correct test items (out of 16), not as a proportion
  num_samples_test   <- length(testing_data)
  accuracy_original  <- (num_samples_test - total_error_original)
  
  # 6) Return final object 
  list(
    W_input_hidden    = W_input_hidden,
    W_hidden_output   = W_hidden_output,
    bias_hidden       = bias_hidden,
    bias_output       = bias_output,
    accuracy_original = accuracy_original
  )
}

2.4.2 Aggregator function

This wraps around the network training function to call it and collect the results.

### Then a function to aggregate the results of several networks as we generate predictions out of each of them:

process_distribution <- function(dist_name,
                                 sampled_epochs,
                                 num_hidden_neurons,
                                 alpha,
                                 training_data) 
{
  # Filter out invalid epochs
  sampled_epochs <- sampled_epochs[sampled_epochs>0 & sampled_epochs<=max_epochs]
  if (length(sampled_epochs)==0) {
    warning("No valid epochs for ", dist_name)
    return(NULL)
  }
  
  results_list <- list()
  
  for (num_epochs in sampled_epochs) {
    # 1) Generate subject-specific training data
    training_data_subj <- generate_subject_training_data(training_data)
    
    # 2) Train
    
    res <- train_network(
      num_epochs           = num_epochs,
      num_hidden_neurons   = num_hidden_neurons,
      alpha                = alpha,
      training_data_sampled= training_data_subj
    )
    
    results_list[[length(results_list)+1]] <- res
  }
  
  # 3) Build data frame from results
  df <- data.frame(
    Epochs            = sampled_epochs,
    Accuracy_Original = sapply(results_list, function(x) x$accuracy_original),
    Distribution      = dist_name,
    NumHiddenNeurons  = num_hidden_neurons,
    Alpha             = alpha
  )

  
  df
}

3 Model Predictions

💡 Key Prediction

The attention-based model predicts a specific ordering of accuracy across conditions: AC < OBS < CS. This matches the experimental data, while alternative models (e.g., Bayesian rule inference) predict different orderings.

3.1 Initial Demonstration (Toy Parameters)

# =============================================================================
# TOY PARAMETERS SIMULATION (α=1, λ=1000)
# Results are cached to avoid re-running expensive simulations
# =============================================================================

cache_file_toy <- "cache/toy_params_results.rds"

# Check if cached results exist
if (file.exists(cache_file_toy)) {
  cat("Loading cached toy parameters results...\n")
  cached_toy <- readRDS(cache_file_toy)
  df_AC <- cached_toy$df_AC
  df_CS <- cached_toy$df_CS
  df_OBS <- cached_toy$df_OBS
} else {
  cat("Running toy parameters simulation (n=1000)...\n")

  # Parameters for this demonstration
  rate <- 1000
  num_samples <- 1000  # Full simulation
  num_hidden <- 4
  alpha <- 1

  # Generate the distribution of epochs
  set.seed(999)
  sampled_epochs <- rexp(num_samples, rate=rate)
  sampled_epochs <- pmax(1, round(sampled_epochs))

  # (a) AC condition, alpha=1
  df_AC <- process_distribution(
    dist_name        = "AC_alpha1",
    sampled_epochs   = sampled_epochs,
    num_hidden_neurons = num_hidden,
    alpha           = alpha,
    training_data   = training_data_AC
  )

  # (b) CS condition, alpha=1
  df_CS <- process_distribution(
    dist_name        = "CS_alpha1",
    sampled_epochs   = sampled_epochs,
    num_hidden_neurons = num_hidden,
    alpha           = alpha,
    training_data   = training_data_CS
  )

  # (c) Observation-only condition => alpha=0
  df_OBS <- process_distribution(
    dist_name        = "OBS_alpha0",
    sampled_epochs   = sampled_epochs,
    num_hidden_neurons = num_hidden,
    alpha           = 0,
    training_data   = training_data_AC
  )

  # Save to cache
  dir.create("cache", showWarnings = FALSE)
  saveRDS(list(df_AC = df_AC, df_CS = df_CS, df_OBS = df_OBS), cache_file_toy)
  cat("Results cached to", cache_file_toy, "\n")
}
## Loading cached toy parameters results...
# 4) Combine
df_combined <- bind_rows(df_AC, df_OBS, df_CS)

# 5) Plot distributions of accuracy on a violin plot with means


# Tag each distribution with a Condition label 
df_combined <- df_combined %>%
  mutate(Condition = case_when(
    grepl("AC", Distribution)  ~ "AC",
    grepl("OBS", Distribution) ~ "OBS",
    grepl("CS", Distribution)  ~ "CS",
    TRUE ~ "Unknown"
  ))


df_combined$Condition <- factor(df_combined$Condition, levels=c("AC","OBS","CS"))

# Colors
custom_colors <- c("AC"="red", "OBS"="blue", "CS"="green")

# Summarise means

df_means_model <- df_combined %>%
  group_by(Condition) %>%
  summarise(mean_acc = mean(Accuracy_Original))


print(df_means_model)
## # A tibble: 3 × 2
##   Condition mean_acc
##   <fct>        <dbl>
## 1 AC            13.1
## 2 OBS           13.3
## 3 CS            13.6
model_acc <- ggplot(df_combined, aes(x=Condition, y=Accuracy_Original, fill=Condition)) +
  geom_violin(width=1, alpha=0.4) +
  geom_point(data=df_means_model,
             aes(x=Condition, y=mean_acc),
             shape=23, size=3, fill="red", color="black") +
  geom_text(data=df_means_model,
            aes(x=Condition, y=mean_acc,
                label=round(mean_acc,2)),
            vjust=-1, color="black", fontface="bold") +
  scale_fill_manual(values=custom_colors) +
  labs(
    title="Model accuracy distributions for alpha = 1, lambda = 1000 ",
    x="Condition", y="Accuracy Score"
  ) +
  theme_minimal()

Dissertation Reference: This demonstration plot corresponds to Figure 2.5 in Chapter II. The toy parameters (α=1, λ=1000) show the model captures the correct qualitative pattern before parameter fitting.

model_acc
Model accuracy distributions with toy parameters (α=1, λ=1000). The ordering AC < OBS < CS matches the experimental data.


💡 Key Observations

Four things are worth noting:

  1. The difference between the means per condition follows the same order in our model and in the data: AC < OBS < CS
  2. The attention account captures the fact that causal selection explanations make learners more accurate than mere observation, while other kinds of explanations make them less accurate
  3. This supports the attention account: the Bayesian model we elaborated previously predicted the wrong ordering (OBS < AC < CS, see CogSci paper)
  4. The absolute accuracy values predicted by the model are higher than those observed; we address this via parameter fitting below

3.2 Grid search parameter fitting

3.2.1 Retrieve and store

df_stats_emp <- df_subject %>%
  group_by(Condition) %>%
  summarise(
    observed_mean = mean(mean_accuracy, na.rm=TRUE),
    observed_sd   = sd(mean_accuracy, na.rm=TRUE),
    .groups="drop"
  )

# Extract each condition’s stats
mu_AC  <- df_stats_emp$observed_mean[df_stats_emp$Condition=="AC"]
mu_OBS <- df_stats_emp$observed_mean[df_stats_emp$Condition=="OBS"]
mu_CS  <- df_stats_emp$observed_mean[df_stats_emp$Condition=="CS"]


### The SDs are collected just for inspection
sigma_AC  <- df_stats_emp$observed_sd[df_stats_emp$Condition=="AC"]
sigma_OBS <- df_stats_emp$observed_sd[df_stats_emp$Condition=="OBS"]
sigma_CS  <- df_stats_emp$observed_sd[df_stats_emp$Condition=="CS"]

cat("Observed Means:\n")
## Observed Means:
cat("  AC =", mu_AC,"\n")
##   AC = 11.0102
cat(" OBS =", mu_OBS,"\n")
##  OBS = 11.90291
cat("  CS =", mu_CS,"\n\n")
##   CS = 12.76289
cat("Observed SDs:\n")
## Observed SDs:
cat("  AC =", sigma_AC,"\n")
##   AC = 1.880419
cat(" OBS =", sigma_OBS,"\n")
##  OBS = 2.315822
cat("  CS =", sigma_CS,"\n")
##   CS = 2.330831

3.2.2 Fetching the model's predicted statistics for a given set of parameters

get_predicted_stats <- function(
    condition,
    exp_rate,
    alpha,
    num_samples     = 50,    
    num_hidden      = 4,     
    training_data_AC,
    training_data_CS
) {
  # 1) pick the appropriate training_data
  if (condition=="AC") {
    base_data <- training_data_AC
  } else if (condition=="CS") {
    base_data <- training_data_CS
  } else if (condition=="OBS") {
    base_data <- training_data_AC
  } else {
    stop("Condition must be AC, CS, or OBS")
  }
  
  # 2) If condition=="OBS", force alpha=0
  if (condition=="OBS") {
    alpha_used <- 0
  } else {
    alpha_used <- alpha
  }
  
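  # 3) Sample the number of training epochs for each simulated run:
  #    draw from an exponential with rate exp_rate, round to the nearest
  #    integer, and floor at 1 epoch.
  #    Note: set.seed() resets the global RNG state on every call, so each
  #    parameter setting starts from the same random stream.
  set.seed(999)
  sampled_epochs <- rexp(num_samples, rate=exp_rate)
  sampled_epochs <- pmax(1, round(sampled_epochs))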
  
  dist_name <- paste0(condition, "_Alpha_", alpha_used, "_Exp_", exp_rate)
  
  # 4) Call aggregator
  df_res <- process_distribution(
    dist_name         = dist_name,
    sampled_epochs    = sampled_epochs,
    num_hidden_neurons= num_hidden,
    alpha             = alpha_used,
    training_data     = base_data
  )
  
  # 5) from df_res, we get the distribution of Accuracy_Original
  predicted_mean     <- mean(df_res$Accuracy_Original, na.rm=TRUE)
  # predicted_sd       <- sd(df_res$Accuracy_Original, na.rm=TRUE)
  
  # Return the predicted mean in a list (the SD computation is commented out)
  return(list(
    mean = predicted_mean
    # sd   = predicted_sd
  ))
}
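
To see what step 3 does to the epoch counts, the transformation can be run on its own. This is a minimal sketch: the wrapper name sample_epochs and the two rates are illustrative and are not part of the original code.

```r
# Same transformation as step 3 of get_predicted_stats():
# draw from an exponential, round, and floor at one epoch.
# sample_epochs is a hypothetical wrapper used only for this illustration.
sample_epochs <- function(n, exp_rate) {
  pmax(1, round(rexp(n, rate = exp_rate)))
}

set.seed(999)  # illustrative seed
epochs_slow <- sample_epochs(50, exp_rate = 0.1)   # raw draws have mean 1 / 0.1 = 10
epochs_fast <- sample_epochs(50, exp_rate = 1000)  # raw draws have mean 1 / 1000

print(summary(epochs_slow))
print(table(epochs_fast))
```

A low rate yields a spread of epoch counts, whereas a very high rate produces raw draws far below 0.5, which round to 0 and are then raised to the minimum of 1 epoch.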

3.2.3 Compute SSE

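The fit criterion is a sum of squared errors (SSE) over the three pairwise differences between condition means, comparing the model predictions (M_AC, M_OBS, M_CS) with the observed means (mu_AC, mu_OBS, mu_CS). Because only differences enter the sum, the criterion is unaffected by a constant offset shared by all three conditions.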

sse_total <- function(
    M_AC,
    M_OBS,
    M_CS, 
    mu_AC, 
    mu_OBS, 
    mu_CS
) {
  #  difference-based SSE (3 pairwise differences)
  diff_ac_obs_model <- (M_AC  - M_OBS)
  diff_cs_obs_model <- (M_CS  - M_OBS)
  diff_cs_ac_model  <- (M_CS  - M_AC)
  
  diff_ac_obs_data  <- (mu_AC  - mu_OBS)
  diff_cs_obs_data  <- (mu_CS  - mu_OBS)
  diff_cs_ac_data   <- (mu_CS  - mu_AC)
  
  SSE_diff <- (diff_ac_obs_model - diff_ac_obs_data)^2 +
              (diff_cs_obs_model - diff_cs_obs_data)^2 +
              (diff_cs_ac_model  - diff_cs_ac_data)^2
  

  return(SSE_diff)
}
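
One way the predicted statistics and the SSE criterion could be combined into a grid search is sketched below. This is an illustration under stated assumptions, not the search used for the reported fit: sse_diff is a compact restatement of sse_total() taking named vectors, and toy_predict is a purely hypothetical closed-form stand-in for get_predicted_stats() so that the example runs without the simulation code. The grid values are arbitrary.

```r
# Compact restatement of the difference-based SSE (same three pairwise terms)
sse_diff <- function(M, mu) {
  model_diffs <- c(M["AC"] - M["OBS"], M["CS"] - M["OBS"], M["CS"] - M["AC"])
  data_diffs  <- c(mu["AC"] - mu["OBS"], mu["CS"] - mu["OBS"], mu["CS"] - mu["AC"])
  sum((model_diffs - data_diffs)^2)
}

# Observed means, copied from the printed output in 3.2.1
mu <- c(AC = 11.0102, OBS = 11.90291, CS = 12.76289)

# HYPOTHETICAL predictor (not the neural model): exp_rate shifts all
# conditions equally, alpha helps CS and hurts AC by the same amount
toy_predict <- function(condition, exp_rate, alpha) {
  base  <- 14 - 2 * exp_rate
  shift <- switch(condition, OBS = 0, CS = alpha, AC = -alpha)
  base + shift
}

# Evaluate the criterion at every (exp_rate, alpha) combination
grid <- expand.grid(exp_rate = c(0.1, 0.5, 1), alpha = c(0, 0.5, 1, 1.5))
grid$sse <- mapply(function(r, a) {
  M <- c(AC  = toy_predict("AC",  r, a),
         OBS = toy_predict("OBS", r, a),
         CS  = toy_predict("CS",  r, a))
  sse_diff(M, mu)
}, grid$exp_rate, grid$alpha)

# Parameter setting with the smallest SSE
best <- grid[which.min(grid$sse), ]
print(best)
```

In this toy setup exp_rate only moves all three conditions up or down together, so it leaves the SSE unchanged and the search effectively selects alpha alone. With the real model, get_predicted_stats() would take the place of toy_predict.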

3.3 Effect of Attention Parameter α

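We examine how accuracy changes as a function of the attention parameter α, with the number of epochs fixed to 1. This is equivalent to sampling epochs with λ=1000: at a rate that high, every sampled value rounds to 0 and is then raised to the minimum of 1 epoch.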

Dissertation Reference: This analysis corresponds to Figure 2.7 in Chapter II, showing how both CS and AC conditions benefit from low attention levels, but diverge as α increases.

############################################################
## Evolution of accuracy as function of alpha (epochs fixed to 1)
## Results are cached to avoid re-running expensive simulations
############################################################

cache_file_alpha <- "cache/alpha_analysis_results.rds"

# Check if cached results exist
if (file.exists(cache_file_alpha)) {
  cat("Loading cached alpha analysis results...\n")
  cached_alpha <- readRDS(cache_file_alpha)
  results_ac <- cached_alpha$results_ac
  results_cs <- cached_alpha$results_cs
  results_obs <- cached_alpha$results_obs
} else {
  cat("Running alpha analysis simulations (n=500 per condition per alpha)...\n")

  alpha_values <- c(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1, 1.5, 2)

  # Store results
  results_ac <- data.frame(alpha=numeric(), mean_accuracy=numeric(), Condition=character())
  results_cs <- data.frame(alpha=numeric(), mean_accuracy=numeric(), Condition=character())
  results_obs <- data.frame(alpha=numeric(), mean_accuracy=numeric(), Condition=character())

  # For each alpha, run 500 simulations with 1 epoch each (matching dissertation)
  num_runs <- 500
  fixed_epochs <- rep(1, num_runs)

  for (a in alpha_values) {
    cat("  Processing alpha =", a, "\n")

    # AC condition
    df_ac <- process_distribution(
      dist_name        = paste0("AC_alpha=", round(a, 3)),
      sampled_epochs   = fixed_epochs,
      num_hidden_neurons = 4,
      alpha           = a,
      training_data   = training_data_AC
    )
    mean_acc_ac <- mean(df_ac$Accuracy_Original)
    results_ac <- rbind(results_ac, data.frame(alpha=a, mean_accuracy=mean_acc_ac, Condition="AC"))

    # CS condition
    df_cs <- process_distribution(
      dist_name        = paste0("CS_alpha=", round(a, 3)),
      sampled_epochs   = fixed_epochs,
      num_hidden_neurons = 4,
      alpha           = a,
      training_data   = training_data_CS
    )
    mean_acc_cs <- mean(df_cs$Accuracy_Original)
    results_cs <- rbind(results_cs, data.frame(alpha=a, mean_accuracy=mean_acc_cs, Condition="CS"))

    # OBS condition (always alpha=0 in practice, but we compute for reference)
    df_obs_alpha <- process_distribution(
      dist_name        = paste0("OBS_alpha=", round(a, 3)),
      sampled_epochs   = fixed_epochs,
      num_hidden_neurons = 4,
      alpha           = 0,  # OBS always has no attention
      training_data   = training_data_AC
    )
    mean_acc_obs <- mean(df_obs_alpha$Accuracy_Original)
    results_obs <- rbind(results_obs, data.frame(alpha=a, mean_accuracy=mean_acc_obs, Condition="OBS"))
  }

  # Save to cache
  dir.create("cache", showWarnings = FALSE)
  saveRDS(list(results_ac = results_ac, results_cs = results_cs, results_obs = results_obs), cache_file_alpha)
  cat("Results cached to", cache_file_alpha, "\n")
}
## Loading cached alpha analysis results...
alpha_values <- c(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1, 1.5, 2)

# Combine results
results_combined <- bind_rows(results_ac, results_cs)

# Plot
alpha_plot <- ggplot(results_combined, aes(x = alpha, y = mean_accuracy, color = Condition)) +
  geom_line(linewidth = 1.2) +  # `size` for lines is deprecated since ggplot2 3.4.0
  geom_point(size = 2.5) +
  # Add OBS baseline as horizontal dashed line
  geom_hline(yintercept = mean(results_obs$mean_accuracy), 
             linetype = "dashed", color = "blue", alpha = 0.7) +
  annotate("text", x = max(alpha_values) * 0.9, y = mean(results_obs$mean_accuracy) + 0.02,
           label = "OBS baseline", color = "blue", size = 3) +
  scale_color_manual(values = c("AC" = "red", "CS" = "green"),
                     labels = c("AC" = "Other explanations", "CS" = "Causal selection")) +
  labs(
    title = "Accuracy vs. Attention Parameter (epochs = 1)",
    x = expression(paste("Attention parameter ", alpha)),
    y = "Mean Accuracy",
    color = "Condition"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    legend.position = "bottom",
    panel.grid.minor = element_blank()
  )

alpha_plot
Mean accuracy as a function of the attention parameter α. At low α, both conditions benefit from attention. At higher α, CS continues to improve while AC deteriorates.

💡 Key Patterns

Three patterns emerge from this analysis:

  1. Initial benefits for both conditions: When α increases from 0 to ~0.1, both CS and AC conditions improve relative to the OBS baseline
  2. Divergence at higher α: Past this point, CS continues to benefit while AC performance drops
  3. Overshoot at very high α: Even CS eventually suffers if attention is too strong, as the learner becomes too narrowly focused

These patterns support the attention-based account: explanations help when they highlight relevant variables (CS), but hurt when they highlight irrelevant ones (AC).
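
To make the mechanism behind this account concrete, here is a minimal sketch of attention-weighted, error-driven learning for a single logistic unit. It is not the network used in the simulations above. The function names (apply_attention, delta_update), the learning rate, and in particular the assumption that a cited variable's input is scaled by (1 + α) are illustrative choices.

```r
# ASSUMPTION: inputs cited in the explanation are scaled by (1 + alpha);
# alpha = 0 leaves every input unchanged
apply_attention <- function(x, mask, alpha) {
  x * (1 + alpha * mask)
}

# One error-driven (delta-rule) update for a single logistic unit
delta_update <- function(w, x, target, lr = 0.5) {
  pred <- 1 / (1 + exp(-sum(w * x)))   # logistic prediction
  w + lr * (target - pred) * x         # update proportional to error and input
}

x    <- c(A = 1, B = 1, C = 1)  # three causal variables, all present
mask <- c(A = 1, B = 0, C = 0)  # the explanation cites A only
w0   <- c(A = 0, B = 0, C = 0)  # initial weights

# Same observation and outcome, without and with attention on A
w_plain    <- delta_update(w0, apply_attention(x, mask, alpha = 0), target = 1)
w_attended <- delta_update(w0, apply_attention(x, mask, alpha = 1), target = 1)

print(w_plain)
print(w_attended)
```

After one update the weight on the cited variable A has moved twice as far under attention, while B and C are updated exactly as before. Whether that extra credit helps or hurts then depends on whether the cited variable is actually relevant to the rule.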