Introduction

Dissertation Reference: This implementation corresponds to Chapter 4 of the dissertation: “A Neural Approach to Causal Selection Judgments” (Part I). It implements the neural network model and Layer-wise Relevance Propagation approach, tested against experimental data from Chapter 3: “Plural Causes in Causal Judgment.”

🎮 Interactive Demo

Try the interactive Shiny app to explore the model with your own parameter values. Adjust the normality parameters and see how they affect predicted causal importance scores across different scenarios.

Background: Causal Selection and Counterfactual Theories

Causal selection is the process underlying our intuition that an outcome happened because of a given event, or that an event is the cause of an outcome. Unlike judgments of mere actual causation (identifying which events can be counted as causes), causal selection induces a ranking over these events, singling out some as being more important than others in bringing about the outcome.

The Forest Fire Example

When a forest catches fire after a lightning strike, people tend to say that the lightning bolt was the cause of the fire, not mentioning the presence of oxygen in the air—although oxygen was no less indispensable for the fire to occur. This preference reflects causal selection at work: we single out the abnormal event (lightning) over the normal one (oxygen).

Two well-documented patterns characterize how normality interacts with causal structure:

  • Abnormal inflation: In conjunctive structures (where multiple causes are jointly necessary), abnormal/rare causes receive greater causal credit than normal ones.
  • Abnormal deflation: In disjunctive structures (where each cause alone is sufficient), the pattern reverses—normal causes receive greater credit.

Counterfactual simulation models explain these patterns by proposing that people mentally simulate alternative scenarios—with normal scenarios coming more readily to mind—then assess how each cause covaries with the outcome across these simulations.
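The R sketch below makes the two patterns concrete, and it is deliberately simplified. It assumes that counterfactual worlds are drawn independently from each cause's prior, with no anchoring to the actual world. It measures covariation as a probability-weighted Pearson correlation between cause and outcome. Both choices, and the helper names `weighted_cor()` and `covariation_scores()`, are illustrative assumptions rather than the model developed below.

```r
# Probability-weighted Pearson correlation (weights must sum to 1)
weighted_cor <- function(x, y, w) {
  mx <- sum(w * x)
  my <- sum(w * y)
  cov_xy <- sum(w * (x - mx) * (y - my))
  cov_xy / sqrt(sum(w * (x - mx)^2) * sum(w * (y - my)^2))
}

# Enumerate all worlds over two binary causes, weight each world by its
# prior probability, and correlate each cause with the outcome
covariation_scores <- function(p_a, p_b, rule) {
  worlds  <- expand.grid(A = 0:1, B = 0:1)
  w       <- ifelse(worlds$A == 1, p_a, 1 - p_a) *
             ifelse(worlds$B == 1, p_b, 1 - p_b)
  outcome <- as.numeric(rule(worlds$A, worlds$B))
  c(A = weighted_cor(worlds$A, outcome, w),
    B = weighted_cor(worlds$B, outcome, w))
}

# A is abnormal (rare), B is normal (common)
conj <- covariation_scores(0.1, 0.9, function(a, b) a & b)  # jointly necessary
disj <- covariation_scores(0.1, 0.9, function(a, b) a | b)  # each sufficient

round(conj, 3)  # A ≈ 0.943, B ≈ 0.105: abnormal inflation
round(disj, 3)  # A ≈ 0.105, B ≈ 0.943: abnormal deflation
```

The worlds are enumerated exactly rather than sampled, so the scores are deterministic. In the conjunctive rule the rare cause outscores the common one, and in the disjunctive rule the ranking reverses.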

The Neural Model: Why Go Beyond Structural Causal Models?

A key premise of most existing counterfactual models is that Structural Causal Models (SCMs) provide the appropriate representational scaffold. However, recent experimental evidence suggests that people’s representations of causal relationships possess more structure than SCMs typically capture.

The Program Perspective

Consider a weighted sum equation: \(S := A + 2B + C\). This relation can be computed by different neural architectures—one with direct connections from inputs to output, another routing signals through intermediate hidden nodes. Although both encode the same function, they differ in their internal structure. My contention is that people represent causal rules through similarly non-minimal architectures, and that this hidden structure plays a crucial role in how we generate causal explanations.
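A small R sketch illustrates this point with linear units. The intermediate routing (hidden units \(h_1 = A + B\) and \(h_2 = B + C\)) is a hypothetical decomposition chosen purely for illustration, not one taken from the dissertation. Both networks return identical outputs on every input, even though only the second groups B with A and with C.

```r
# Architecture 1: direct input -> output connections with weights (1, 2, 1)
direct_net <- function(x) {
  w_out <- c(A = 1, B = 2, C = 1)
  sum(w_out * x[c("A", "B", "C")])
}

# Architecture 2: inputs routed through two hidden linear units
#   h1 = A + B, h2 = B + C, then S = h1 + h2
hidden_net <- function(x) {
  W_in <- rbind(A = c(1, 0),
                B = c(1, 1),
                C = c(0, 1))                 # rows: inputs, cols: hidden units
  h <- as.vector(x[c("A", "B", "C")] %*% W_in)
  sum(c(1, 1) * h)                           # hidden -> output weights
}

# Compare both networks on every binary input
inputs <- expand.grid(A = 0:1, B = 0:1, C = 0:1)
same <- apply(inputs, 1, function(row) {
  x <- setNames(as.numeric(row), names(inputs))
  direct_net(x) == hidden_net(x)
})
all(same)  # TRUE: same function, different internal structure
```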

This document implements a computational model that operationalizes counterfactual simulation at a sub-symbolic level, using neural network representations that mirror the logical structure of causal rules.

Model Architecture: Five Core Components

The model consists of five interconnected components, each corresponding to a section of Part I of the dissertation:

1. Logic Programs as Representations of Causal Knowledge

We assume people represent Boolean causal rules (like “Win if you get two colored balls out of three”) via declarative logic programs, that is, sets of Horn clauses of the form (Head ← Body):

\[E \leftarrow A, B \quad \text{(To prove E, prove both A and B)}\] \[E \leftarrow C, D \quad \text{(Or prove both C and D)}\]

This representation emphasizes sufficient conditions: each clause body is a conjunction that suffices to prove the head. This aligns with mental model theories positing that we grasp propositions through their minimal verification conditions rather than tracking all compatible worlds.
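To illustrate, the two clauses above can be held as plain data and evaluated by checking whether any clause body is fully satisfied. The sketch below uses the same `list(head = ..., body = ...)` layout as the code later in this document. The helper `prove_head()` is hypothetical and handles positive literals only.

```r
# Clauses for E <- A, B and E <- C, D
clauses <- list(
  list(head = "E", body = c("A", "B")),
  list(head = "E", body = c("C", "D"))
)

# A head is proved if at least one of its clauses has every body atom TRUE
prove_head <- function(clauses, head, world) {
  bodies <- Filter(function(cl) cl$head == head, clauses)
  any(vapply(bodies, function(cl) all(world[cl$body]), logical(1)))
}

# Check against the Boolean rule (A and B) or (C and D) on all 16 worlds
worlds <- expand.grid(A = c(FALSE, TRUE), B = c(FALSE, TRUE),
                      C = c(FALSE, TRUE), D = c(FALSE, TRUE))
agree <- apply(worlds, 1, function(row) {
  world <- setNames(as.logical(row), names(worlds))
  prove_head(clauses, "E", world) ==
    ((world[["A"]] && world[["B"]]) || (world[["C"]] && world[["D"]]))
})
all(agree)  # TRUE
```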

2. Program-to-Network Translation (CILP Algorithm)

The logic program is translated into a neural network via the CILP (Connectionist Inductive Learning and Logic Programming) algorithm. The resulting network has:

  • Input layer: Nodes for each propositional variable (A, B, C, D)
  • Hidden layer: One node per clause, implementing conjunctive gates
  • Output layer: One node for the outcome (E), implementing a disjunctive gate

The hidden layer structure thus mirrors the disjunctive normal form of the rule, preserving information about which variables form minimal sufficient conditions together.

3. MCMC Sampling for Counterfactual Simulation

Starting from the actual world state, the model explores alternative scenarios through Markov Chain Monte Carlo sampling (Metropolis-Hastings or Gibbs variants). Transition probabilities depend on:

  1. Normality biases on input nodes (normal events sampled more frequently)
  2. Hidden node activations (transitions that would flip hidden nodes are resisted)

The second factor introduces autocorrelation that keeps sampled counterfactuals close to the actual world, while also grouping variables by their shared clause membership.

4. Layer-wise Feedback Propagation (Weight Updates)

As the model visits different counterfactual states, it performs weight updates using Layer-wise Feedback Propagation (LFP)—a procedure inspired by Layer-wise Relevance Propagation (LRP) from interpretable machine learning. In each state:

  • A reward signal is assigned based on the outcome
  • Rewards propagate backward, adjusting weights based on each connection’s contribution
  • Connections from inputs aligned with the outcome are strengthened

This process encodes the covariation between causes and outcomes directly into the network’s weight parameters.

5. Causal Importance via Relevance Propagation

After sampling, causal importance scores are computed by propagating relevance from the output back through the network. The resulting measure \(\kappa\) combines:

  • Relevance scores from LRP (proportional to connection weights)
  • Path complexity \(\mathcal{C}\), which penalizes explanations that recruit multiple disjoint clauses

\[\kappa(C, O) = \frac{\sum_{c \in C} R_c}{\mathcal{C}(C, O)}\]

where \(\mathcal{C}(C, O)\) counts the number of edge-disjoint paths from causes \(C\) to outcome \(O\).
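As a worked reading of this definition, take the Experiment 2 rule \((A \land B) \lor (C \land D)\) described below. Its CILP network has one hidden node per clause, \(H_{ab}\) and \(H_{cd}\). The same-side pair \(\{A, B\}\) reaches \(E\) only through the single edge \(H_{ab} \to E\), so \(\mathcal{C} = 1\). The cross-disjunction pair \(\{A, C\}\) reaches it through the two edges \(H_{ab} \to E\) and \(H_{cd} \to E\), so \(\mathcal{C} = 2\):

\[\kappa(\{A, B\}, E) = \frac{R_A + R_B}{1}, \qquad \kappa(\{A, C\}, E) = \frac{R_A + R_C}{2}\]

With equal summed relevance, the cross-disjunction pair therefore receives half the score of the same-side pair.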

Experimental Data

This model was tested against human causal-selection judgments collected in two experiments involving games of chance with urns containing colored balls.

Experiment 1: The “Two out of Three” Rule

Participants judged causal importance for an outcome determined by: \(\text{WIN} := A + B + C \geq 2\)

Three urns had different proportions of colored balls (probabilities 0.05, 0.5, 0.95). Key finding: plural causes (e.g., “urns A and B caused the win”) showed non-linear effects that couldn’t be derived from singular judgments alone.
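The base rate of winning under this rule can be computed by enumerating the eight possible draws and weighting each by the urn probabilities, assuming the three draws are independent. A short R sketch follows, where 1 means a colored ball was drawn.

```r
# Urn probabilities for Experiment 1 (draws assumed independent)
p <- c(A = 0.05, B = 0.5, C = 0.95)

# All 2^3 combinations of draws (1 = colored ball drawn)
worlds <- expand.grid(A = 0:1, B = 0:1, C = 0:1)

# Probability of each world under independence
world_prob <- apply(worlds, 1, function(row) {
  prod(ifelse(row == 1, p[names(worlds)], 1 - p[names(worlds)]))
})

# WIN := A + B + C >= 2
win   <- rowSums(worlds) >= 2
p_win <- sum(world_prob[win])
p_win  # 0.5
```

With these probabilities the outcome itself is a coin flip. Most of the winning probability comes from the world in which B and C are drawn but A is not.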

Experiment 1 paradigm: Participants played rounds with urns of varying probabilities, then judged what caused a fictitious player's win.

Query format in Experiment 1: Causal judgments were collected for both singular and plural causes.

Experiment 2: The Disjunctive Rule

Participants judged outcomes under: \(\text{WIN} := (A \land B) \lor (C \land D)\)

This rule explicitly groups variables into “teams” (purple balls vs. orange balls). Key findings:

  • Plural causes on the same side of the disjunction (A∧B or C∧D) were preferred over cross-disjunction pairs (A∧C)
  • When explaining losses, the disjunctive structure effect disappeared—consistent with the hypothesis that negated plurals are interpreted homogeneously (like linguistic plurals)

Experiment 2 stimulus: Four urns grouped into two 'teams' by ball color. Win requires two balls of the same color.

1 Components of the Model


1.1 A function for generating Neural Networks out of Logic Programs: generate_network_from_logic_program

Dissertation Reference: This function implements the CILP (Connectionist Inductive Learning and Logic Programming) translation algorithm from Part I, Section 2 of the dissertation.

This code implements the transformation from a General Logic Program (with clauses of the shape \(A \leftarrow B_1, B_2\), possibly with negations) into a Neural Network that captures the same logical relationships.

As described in the relevant chapters, the translated Neural Network will contain:

  • A hidden neuron for each clause.
  • An input neuron for each distinct atom (a literal with any negation removed) appearing in a clause body.
  • An output neuron for each distinct atom appearing in a clause head.

All edges from the input to hidden layer and from hidden to output layer have weight \(\pm W\) (see Equation 4.2 in the dissertation), with the sign determined by whether the literal appears positively or negatively in the clause.

Thresholds for hidden and output neurons are computed from \(A_{\min}\) and \(\text{MAX}_{P}\), the largest of the clause body lengths and the numbers of clauses sharing a head. These thresholds ensure that each hidden neuron approximates a conjunctive gate and each output neuron approximates a disjunctive gate, as in the original CILP approach (d’Avila Garcez et al., 2002).

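The compute_parameters() function implements the relevant equations from the thesis:

\[ A_{\text{threshold}} = \frac{\text{MAX}_{P} - 1}{\text{MAX}_{P} + 1} \]

\[ A_{\min} = A_{\text{threshold}} + \frac{1}{\text{MAX}_{P}} \]

\[ W = \frac{2}{\beta} \left[ \frac{\ln(1 + A_{\min}) - \ln(1 - A_{\min})}{\text{MAX}_{P}\,(A_{\min} - 1) + A_{\min} + 1} \right] \]

The denominator of \(W\) is positive exactly when \(A_{\min} > A_{\text{threshold}}\), which is why \(A_{\min}\) is set above that threshold. The function also computes two thresholds. For a clause with \(k_l\) body literals, the hidden threshold is \(\theta_l = \frac{(1 + A_{\min})(k_l - 1)}{2} W\). For a head shared by \(\mu_A\) clauses, the output threshold is \(\theta_A = \frac{(1 + A_{\min})(1 - \mu_A)}{2} W\). The corresponding neuron biases are \(-\theta_l\) and \(-\theta_A\).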
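The equations can be checked numerically. The sketch below is a standalone re-computation using a hypothetical helper, `cilp_params()`. Inside the document, the same work is done by the nested compute_parameters() with its defaults `beta = 1` and `D = 1`. The example is the two-out-of-three program used later, where every body has two literals and three clauses share the head, so \(\text{MAX}_{P} = 3\).

```r
# Standalone re-computation of the CILP parameters (hypothetical helper)
cilp_params <- function(body_sizes, clauses_per_head, beta = 1) {
  MAX_P       <- max(c(body_sizes, clauses_per_head))
  A_threshold <- (MAX_P - 1) / (MAX_P + 1)
  A_min       <- A_threshold + 1 / MAX_P
  W <- (2 / beta) * (log(1 + A_min) - log(1 - A_min)) /
    (MAX_P * (A_min - 1) + A_min + 1)
  list(
    MAX_P   = MAX_P,
    A_min   = A_min,
    W       = W,
    theta_l = (1 + A_min) * (body_sizes - 1) / 2 * W,        # hidden thresholds
    theta_A = (1 + A_min) * (1 - clauses_per_head) / 2 * W   # output threshold
  )
}

# Two-out-of-three program: three clauses, two body literals each, one head
params <- cilp_params(body_sizes = c(2, 2, 2), clauses_per_head = 3)
round(unlist(params[c("MAX_P", "A_min", "W")]), 3)  # MAX_P = 3, A_min = 0.833, W = 3.597

# Net inputs with +/-1 activations and biases -theta
W <- params$W
hidden_input <- function(x1, x2) W * x1 + W * x2 - params$theta_l[1]
output_input <- function(h) W * sum(h) - params$theta_A

# Conjunction: fires only when both body inputs are +1
c(hidden_both = hidden_input(1, 1) > 0, hidden_one = hidden_input(1, -1) > 0)    # TRUE, FALSE
# Disjunction: fires when at least one hidden neuron is +1
c(output_one  = output_input(c(1, -1, -1)) > 0,
  output_none = output_input(c(-1, -1, -1)) > 0)                                 # TRUE, FALSE
```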

The resulting network is then converted into a “sampling_network” object, which stores nodes, edges, biases, and layer structure in a form suited to subsequent sampling and weight updates.

See the thesis discussion in “An Algorithm for Translating Programs into Neural Networks”
for more context on how each step implements the logical relationships.


generate_network_from_logic_program <- function(clauses, input_probabilities, plot_network = FALSE) {
  
  # =====================================================================
  # Nested function: parse_clauses()
  # Gathers heads, bodies, negative literal counts, etc.
  # =====================================================================
  parse_clauses <- function(clauses) {
    all_literals        <- c()
    heads               <- c()
    body_literals_list  <- list()
    n_l_list            <- c()
    k_l_list            <- c()
    mu_l_list           <- c()
    
    # Collect head names without '!'
    head_names <- sapply(clauses, function(clause) gsub("!", "", clause$head))
    # Count how many clauses share each head
    head_counts <- table(head_names)
    
    for (i in seq_along(clauses)) {
      clause <- clauses[[i]]
      head   <- clause$head
      body   <- clause$body
      heads  <- c(heads, head)

      # --------------------------------------------------------
      # Record each literal (removing '!')
      # --------------------------------------------------------
      literals   <- gsub("!", "", body)
      head_name  <- gsub("!", "", head)
      all_literals <- c(all_literals, literals, head_name)
      
      # --------------------------------------------------------
      # Negative literal count & total literal count
      # --------------------------------------------------------
      n_l       <- sum(grepl("^!", body))
      n_l_list  <- c(n_l_list, n_l)
      k_l       <- length(body)
      k_l_list  <- c(k_l_list, k_l)
      
      # --------------------------------------------------------
      # Store the body in a list
      # --------------------------------------------------------
      body_literals_list[[i]] <- body
      
      # --------------------------------------------------------
      # mu_l = # of clauses sharing the same head
      # --------------------------------------------------------
      mu_l <- as.numeric(head_counts[head_name])
      mu_l_list <- c(mu_l_list, mu_l)
    }
    
    # Remove duplicates for all_literals
    all_literals <- unique(all_literals)
    
    return(list(
      all_literals       = all_literals,
      heads              = heads,
      body_literals_list = body_literals_list,
      n_l_list           = n_l_list,
      k_l_list           = k_l_list,
      mu_l_list          = mu_l_list
    ))
  }
  
  # =====================================================================
  # Nested function: compute_parameters()
  # Computes thresholds for A_min and W based on the maximum body size
  # and #clauses per head. Reflects Equations from CILP approach.
  # =====================================================================
  compute_parameters <- function(parsed_clauses, D = 1, beta = 1) {
    k_l_list  <- parsed_clauses$k_l_list
    mu_l_list <- parsed_clauses$mu_l_list
    
    # --------------------------------------------------------
    # MAX_P is the largest of (max #body-literals, max #clauses sharing the same head)
    # --------------------------------------------------------
    MAX_P <- max(c(k_l_list, mu_l_list))
    
    # --------------------------------------------------------
    # A_min slightly above (MAX_P - 1)/(MAX_P + 1)
    # --------------------------------------------------------
    A_min_threshold <- (MAX_P - 1) / (MAX_P + 1)
    A_min           <- A_min_threshold + 1 / MAX_P
    
    # --------------------------------------------------------
    # Weight magnitude W from eqn in thesis
    # --------------------------------------------------------
    numerator   <- log(1 + A_min) - log(1 - A_min)
    denominator <- MAX_P * (A_min - 1) + A_min + 1
    W           <- (2 / beta) * (numerator / denominator)
    W           <- W / D
    
    # --------------------------------------------------------
    # Theta (hidden thresholds): ensures a conjunctive gate approximation
    # --------------------------------------------------------
    theta_l_list <- ((1 + A_min) * (k_l_list - 1) / 2) * W
    
    # --------------------------------------------------------
    # Theta (output thresholds): ensures a disjunctive gate approximation
    # --------------------------------------------------------
    output_neurons <- unique(gsub("!", "", parsed_clauses$heads))
    mu_A_list      <- sapply(output_neurons, function(output_neuron) {
      sum(gsub("!", "", parsed_clauses$heads) == output_neuron)
    })
    
    theta_A_list <- ((1 + A_min) * (1 - mu_A_list) / 2) * W
    names(theta_A_list) <- output_neurons
    
    return(list(
      A_min         = A_min / D,
      W             = W,
      theta_l_list  = theta_l_list,
      theta_A_list  = theta_A_list
    ))
  }
  
  # =====================================================================
  # Nested function: create_network()
  # Builds a network object with:
  #   - W_input_hidden
  #   - W_hidden_output
  #   - bias_input, bias_hidden, bias_output
  # For subsequent translation into the final "sampling format."
  # =====================================================================
  create_network <- function(parsed_clauses, params, input_probabilities) {
    all_literals       <- parsed_clauses$all_literals
    body_literals_list <- parsed_clauses$body_literals_list
    heads              <- parsed_clauses$heads
    n_clauses          <- length(heads)
    
    # --------------------------------------------------------
    # Identify input vs. hidden vs. output neurons
    # --------------------------------------------------------
    input_neurons <- unique(unlist(body_literals_list))
    input_neurons <- gsub("!", "", input_neurons)
    input_neurons <- unique(input_neurons)
    
    hidden_neurons <- sapply(seq_len(n_clauses), function(i) {
      clause_body       <- body_literals_list[[i]]       # e.g. c("A", "B") or c("!A", "B")
      clause_body_clean <- gsub("!", "", clause_body)    # remove any "!"
      clause_body_lower <- tolower(clause_body_clean)    # turn A->a, B->b, etc.
      
      # Collapse the body-literal names into one string, e.g. "ab"
      subscript <- paste0(clause_body_lower, collapse = "")
      
      # Construct a label like "Hab"
      paste0("H", subscript)
    })
    
    # Bodies that differ only in negation (e.g. "A, B" vs. "!A, B") would get the
    # same label; make.unique() appends ".1", ".2", ... to keep names distinct
    hidden_neurons <- make.unique(hidden_neurons)

    output_neurons <- unique(gsub("!", "", heads))
    
    # Index mappings (not strictly needed, but helps clarity)
    input_indices  <- setNames(seq_along(input_neurons),   input_neurons)
    hidden_indices <- setNames(seq_along(hidden_neurons),  hidden_neurons)
    output_indices <- setNames(seq_along(output_neurons),  output_neurons)
    
    # --------------------------------------------------------
    # Initialize empty weight matrices & biases
    # --------------------------------------------------------
    W_input_hidden  <- matrix(0, nrow = length(input_neurons),  ncol = n_clauses)
    W_hidden_output <- matrix(0, nrow = n_clauses,              ncol = length(output_neurons))
    
    bias_hidden     <- params$theta_l_list * (-1)
    names(bias_hidden) <- hidden_neurons
    
    bias_output     <- -1 * params$theta_A_list[output_neurons]
    names(bias_output) <- output_neurons
    
    W <- params$W
    
    # --------------------------------------------------------
    # Fill W_input_hidden with ±W (sign depends on negation)
    # --------------------------------------------------------
    for (i in seq_len(n_clauses)) {
      clause_body <- body_literals_list[[i]]
      for (literal in clause_body) {
        is_negative  <- startsWith(literal, "!")
        literal_name <- gsub("!", "", literal)
        input_idx    <- input_indices[[literal_name]]
        
        weight <- ifelse(is_negative, -W, W)
        W_input_hidden[input_idx, i] <- weight
      }
    }
    
    # --------------------------------------------------------
    # Fill W_hidden_output: also ±W depending on negation in heads
    # --------------------------------------------------------
    for (i in seq_len(n_clauses)) {
      head        <- heads[[i]]
      is_negative <- startsWith(head, "!")
      head_name   <- gsub("!", "", head)
      output_idx  <- output_indices[[head_name]]
      
      weight <- ifelse(is_negative, -W, W)
      W_hidden_output[i, output_idx] <- weight
    }
    
    # --------------------------------------------------------
    # Compute biases for input neurons from input_probabilities
    # --------------------------------------------------------
    bias_input <- numeric(length(input_neurons))
    names(bias_input) <- input_neurons
    for (input_neuron in input_neurons) {
      p <- input_probabilities[[input_neuron]]
      b <- log(p / (1 - p))  # logistic transform to get bias
      bias_input[input_neuron] <- b
    }
    
    # Label the rows & cols for clarity
    rownames(W_input_hidden) <- input_neurons
    colnames(W_input_hidden) <- hidden_neurons
    rownames(W_hidden_output) <- hidden_neurons
    colnames(W_hidden_output) <- output_neurons
    
    return(list(
      W_input_hidden  = W_input_hidden,
      W_hidden_output = W_hidden_output,
      bias_input      = bias_input,
      bias_hidden     = bias_hidden,
      bias_output     = bias_output,
      input_neurons   = input_neurons,
      hidden_neurons  = hidden_neurons,
      output_neurons  = output_neurons
    ))
  }
  
  # =====================================================================
  # Nested function: convert_network_to_sampling_format()
  # Turns the above raw network into a list of edges, biases, layers, etc.
  # for easy sampling or forward passes.
  # =====================================================================
  convert_network_to_sampling_format <- function(network) {
    nodes <- c(network$input_neurons, network$hidden_neurons, network$output_neurons)
    
    # --------------------------------------------------------
    # Build a named list of biases (input, hidden, output)
    # --------------------------------------------------------
    biases <- list()
    for (input_neuron in network$input_neurons) {
      biases[[input_neuron]] <- network$bias_input[[input_neuron]]
    }
    for (hidden_neuron in network$hidden_neurons) {
      biases[[hidden_neuron]] <- network$bias_hidden[[hidden_neuron]]
    }
    for (output_neuron in network$output_neurons) {
      biases[[output_neuron]] <- network$bias_output[[output_neuron]]
    }
    
    # --------------------------------------------------------
    # Build edges (from input->hidden, hidden->output)
    # --------------------------------------------------------
    edges <- list()
    
    W_input_hidden  <- network$W_input_hidden
    input_neurons   <- network$input_neurons
    hidden_neurons  <- network$hidden_neurons
    
    for (i in seq_along(input_neurons)) {
      for (j in seq_along(hidden_neurons)) {
        weight <- W_input_hidden[i, j]
        if (weight != 0) {
          edges[[length(edges) + 1]] <- list(
            from   = input_neurons[i],
            to     = hidden_neurons[j],
            weight = weight
          )
        }
      }
    }
    
    W_hidden_output <- network$W_hidden_output
    output_neurons  <- network$output_neurons
    
    for (i in seq_along(hidden_neurons)) {
      for (j in seq_along(output_neurons)) {
        weight <- W_hidden_output[i, j]
        if (weight != 0) {
          edges[[length(edges) + 1]] <- list(
            from   = hidden_neurons[i],
            to     = output_neurons[j],
            weight = weight
          )
        }
      }
    }
    
    # --------------------------------------------------------
    # Define layer structure
    # --------------------------------------------------------
    layers <- list(
      network$input_neurons,
      network$hidden_neurons,
      network$output_neurons
    )
    
    return(list(
      nodes   = nodes,
      biases  = biases,
      edges   = edges,
      layers  = layers,
      network = network  # Original network object for reference
    ))
  }
  
  # =====================================================================
  # Nested function: create_causal_graph()
  # Optionally plots the generated network as a diagram, using DiagrammeR.
  # =====================================================================
  create_causal_graph <- function(network) {
    library(DiagrammeR)
    
    graph       <- DiagrammeR::create_graph()
    nodes       <- list()
    node_labels <- list()
    
    # 1) Input nodes
    for (input in network$input_neurons) {
      bias <- round(network$bias_input[input], 2)
      input_label <- paste0(input, "\nBias: ", bias)
      graph <- graph %>%
        add_node(
          label   = input_label,
          type    = "input",
          node_aes= node_aes(shape = "circle", color = "lightblue")
        )
      nodes             <- c(nodes, input)
      node_labels[[input]] <- input_label
    }
    
    # 2) Hidden nodes
    for (i in seq_along(network$hidden_neurons)) {
      hidden_node  <- network$hidden_neurons[i]
      bias         <- round(network$bias_hidden[i], 2)
      hidden_label <- paste0(hidden_node, "\nBias: ", bias)
      graph <- graph %>%
        add_node(
          label   = hidden_label,
          type    = "hidden",
          node_aes= node_aes(shape = "circle", color = "orange")
        )
      nodes             <- c(nodes, hidden_node)
      node_labels[[hidden_node]] <- hidden_label
    }
    
    # 3) Output nodes
    for (i in seq_along(network$output_neurons)) {
      output_node  <- network$output_neurons[i]
      bias         <- round(network$bias_output[i], 2)
      output_label <- paste0(output_node, "\nBias: ", bias)
      graph <- graph %>%
        add_node(
          label   = output_label,
          type    = "output",
          node_aes= node_aes(shape = "circle", color = "lightgreen")
        )
      nodes                <- c(nodes, output_node)
      node_labels[[output_node]] <- output_label
    }
    
    # 4) Edges: input->hidden
    for (i in seq_along(network$input_neurons)) {
      input_node <- network$input_neurons[i]
      for (j in seq_along(network$hidden_neurons)) {
        hidden_node <- network$hidden_neurons[j]
        weight      <- network$W_input_hidden[i, j]
        if (weight != 0) {
          edge_color <- ifelse(weight > 0, "green", "red")
          graph <- graph %>% add_edge(
            from    = node_labels[[input_node]],
            to      = node_labels[[hidden_node]],
            rel     = "causal",
            edge_aes= edge_aes(label = paste("W:", round(weight, 2)), color = edge_color)
          )
        }
      }
    }
    
    # 5) Edges: hidden->output
    for (i in seq_along(network$hidden_neurons)) {
      hidden_node <- network$hidden_neurons[i]
      for (j in seq_along(network$output_neurons)) {
        output_node <- network$output_neurons[j]
        weight      <- network$W_hidden_output[i, j]
        if (weight != 0) {
          edge_color <- ifelse(weight > 0, "green", "red")
          graph <- graph %>% add_edge(
            from    = node_labels[[hidden_node]],
            to      = node_labels[[output_node]],
            rel     = "causal",
            edge_aes= edge_aes(label = paste("W:", round(weight, 2)), color = edge_color)
          )
        }
      }
    }
    
    # Return the rendered graph (will display in knitr chunks when called)
    render_graph(graph, layout = "tree")
  }
  
  # =====================================================================
  # Main body of generate_network_from_logic_program()
  # Orchestrates all nested steps, returning sampling_network.
  # =====================================================================
  
  # 1) Parse the user-provided clauses
  parsed_clauses <- parse_clauses(clauses)
  
  # 2) Compute parameters (A_min, W, thresholds)
  params <- compute_parameters(parsed_clauses)
  
  # 3) Create raw network object
  network <- create_network(parsed_clauses, params, input_probabilities)
  
  # 4) Convert it to sampling format (edges, layers, biases)
  sampling_network <- convert_network_to_sampling_format(network)
  
  # 5) Create graph if requested
  network_graph <- NULL
  if (plot_network) {
    network_graph <- create_causal_graph(network)
  }

  # 6) Return the final "sampling_network" object (with optional graph)
  sampling_network$graph <- network_graph
  return(sampling_network)
}

1.1.1 Example: Experiment 2a from Quillien & Lucas (2023)

This section illustrates how we can represent the causal relationships from Experiment 2a in Quillien & Lucas (2023). In that experiment, we have a simple causal model with exogenous variables \(U_{1}, U_{2}, U_{3}\) indicating whether the relevant ball is drawn from each urn, with the following probabilities:

\[ A := U_{1}, \quad P(U_{1} = 1) = 0.05 \] \[ B := U_{2}, \quad P(U_{2} = 1) = 0.5 \] \[ C := U_{3}, \quad P(U_{3} = 1) = 0.95 \]

The outcome \(\mathrm{E}\) occurs if and only if at least two of the three events \(A, B, C\) happen. Formally, in the SCM notation:

\[ \mathrm{E} := (A + B + C) \geq 2 \]

To express this same idea in Logic Program form, we can split \(\mathrm{E}\) into three clauses, each specifying that \(\mathrm{E}\) will be true if any pair of the three events \(A, B, C\) is true. This gives a Definite Logic Program with clauses:

  • \(E \leftarrow A \land B\)
  • \(E \leftarrow A \land C\)
  • \(E \leftarrow B \land C\)

Below is how we represent it in our code via the generate_network_from_logic_program() function.


# -------------------------------------------------------
# Example: Logic Program for "E = (A+B+C >= 2)"
# -------------------------------------------------------

# Define three clauses:
#   D <- A, C
#   D <- A, B
#   D <- B, C
# In the code, "D" stands in for "E" or "Effect" in the original discussion;
# we can rename as needed. The clauses match the fact that "D is true if (A & C) OR (A & B) OR (B & C)."

clauses <- list(
  list(head = "D", body = c("A", "C")),
  list(head = "D", body = c("A", "B")),
  list(head = "D", body = c("B", "C"))
)

# Corresponding input probabilities for A, B, C
# These reflect the exogenous probabilities from the causal model:
#   A ~ 0.05
#   B ~ 0.5
#   C ~ 0.95
input_probabilities <- list(
  "A" = 0.05,
  "B" = 0.5,
  "C" = 0.95
)

# Generate the sampling network from the logic program
# "plot_network = TRUE" will create a DiagrammeR graph of the resulting connections
sampling_network <- generate_network_from_logic_program(
  clauses,
  input_probabilities,
  plot_network = TRUE
)

# Display the network graph
sampling_network$graph

Neural network representation of the causal structure E = (A+B+C >= 2). Input nodes (blue) represent causes with their bias terms (log-odds of prior probability). Hidden nodes (orange) represent clause detectors. The output node (green) represents the effect. Edge labels show connection weights.
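As a sanity check on the translation, the sketch below rebuilds the same three-clause network by hand from the parameter equations given earlier. It runs a deterministic forward pass with ±1 activations and a threshold at zero, which is the convention used by generate_initial_hidden_states() below. It then compares the output with \((A + B + C) \geq 2\) on all eight input patterns. The helper `forward_pass()` is hypothetical and is not part of the code above.

```r
# Recompute the CILP parameters for MAX_P = 3 (three clauses share head D)
MAX_P <- 3
A_min <- (MAX_P - 1) / (MAX_P + 1) + 1 / MAX_P
W     <- 2 * (log(1 + A_min) - log(1 - A_min)) /
  (MAX_P * (A_min - 1) + A_min + 1)

# Input -> hidden weights: one column per clause, in the order D <- A,C; A,B; B,C
W_in <- rbind(A = c(W, W, 0),
              B = c(0, W, W),
              C = c(W, 0, W))
colnames(W_in) <- c("Hac", "Hab", "Hbc")
bias_hidden <- rep(-(1 + A_min) * (2 - 1) / 2 * W, 3)   # -theta_l, body size 2
bias_out    <- -(1 + A_min) * (1 - 3) / 2 * W           # -theta_A, 3 clauses

# Deterministic forward pass with +/-1 activations, threshold at 0
forward_pass <- function(x) {
  h <- ifelse(as.vector(x %*% W_in) + bias_hidden > 0, 1, -1)
  ifelse(sum(W * h) + bias_out > 0, 1, -1)
}

inputs   <- expand.grid(A = c(-1, 1), B = c(-1, 1), C = c(-1, 1))
net_out  <- apply(inputs, 1, function(row) forward_pass(as.numeric(row)))
rule_out <- ifelse(rowSums(inputs == 1) >= 2, 1, -1)
all(net_out == rule_out)  # TRUE
```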


1.2 Sampling Over the Neural Networks

Dissertation Reference: This section implements the MCMC sampling procedures from Part I, Section 3 of the dissertation, “Sampling Counterfactual Worlds.”

As discussed in the thesis, once we have generated a neural network using the CILP approach, we want to sample possible activations of the network’s input nodes (which correspond to different counterfactual scenarios). We do this using Markov Chain Monte Carlo (MCMC) techniques, which allow us to explore nearby counterfactual worlds starting from a specific “real-world” state.

Below are the main functions implementing the two sampling schemes discussed:

  1. The mutation sampler (mh_layered_sampler) is a Metropolis-Hastings variant. It selects one input node to flip at each step and computes an acceptance probability from the resulting change in “loss”. It then either moves to the new state or stays in the old one.

  2. Layer-wise Gibbs Sampling (layer_gibbs_sampler), which similarly updates nodes in a structured order, computing local probabilities for each node being \(+1\) vs. \(-1\).

Both algorithms incorporate:

  • Biases (\(\log(\frac{p}{1-p})\)) on input nodes, reflecting the normality or prior probability of events.
  • Temperature scaling, which can amplify or reduce the effect of biases on flipping acceptance (although this parameter is never varied in the analyses reported here).
  • Deterministic updates for non-input layers, ensuring that hidden and output neurons reflect a thresholded activation based on the current input states.
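The sketch below shows a single-flip Metropolis-Hastings step restricted to the input layer, under two stated assumptions. First, the “loss” of a state is taken to be the summed negative log-likelihood of each input under \(\sigma(b)\), in the style of compute_loss_NLL(). Second, the hidden-layer resistance term is left out. Under these assumptions, flipping input \(i\) from \(-1\) to \(+1\) changes the loss by exactly \(-b_i\), so the acceptance probability is \(\min(1, e^{b_i})\). The chain then visits each input’s \(+1\) state with its prior probability. The function `toy_input_mh()` is hypothetical and is not the document’s mh_layered_sampler().

```r
# Negative log-likelihood of +1/-1 states given probabilities p of being +1
input_nll <- function(state, p) ifelse(state == 1, -log(p), -log(1 - p))

# Hypothetical toy sampler: single-flip Metropolis-Hastings over input nodes
# only; the hidden-layer resistance term of the full model is omitted
toy_input_mh <- function(biases, initial_state, iterations) {
  p       <- plogis(biases)                      # sigmoid of the log-odds biases
  state   <- initial_state
  samples <- matrix(NA_real_, nrow = iterations, ncol = length(state),
                    dimnames = list(NULL, names(state)))
  for (t in seq_len(iterations)) {
    i <- sample(length(state), 1)                # choose one input uniformly
    proposal    <- state
    proposal[i] <- -state[i]                     # flip its +1/-1 value
    delta_loss  <- sum(input_nll(proposal, p)) - sum(input_nll(state, p))
    # Accept with probability min(1, exp(-delta_loss))
    if (runif(1) < exp(-delta_loss)) state <- proposal
    samples[t, ] <- state
  }
  samples
}

set.seed(1)
probs   <- c(A = 0.05, B = 0.5, C = 0.95)
biases  <- log(probs / (1 - probs))              # same transform as create_network()
samples <- toy_input_mh(biases, initial_state = c(A = 1, B = 1, C = 1),
                        iterations = 20000)
colMeans(samples == 1)                           # close to 0.05, 0.5, 0.95
```

In the full model, the hidden-layer term described earlier would also penalize flips that change a hidden node’s activation. That penalty is what keeps the chain close to the actual world.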

These functions are central to exploring “nearby” states in the network, anchoring the sampling process to the real world while still allowing occasional transitions to less likely worlds.


# ===============================================================
# generate_initial_hidden_states():
#   Initializes the state of hidden neurons given input states
#   using simple sign checks on net inputs (bias + weighted sum).
# ===============================================================
generate_initial_hidden_states <- function(initial_state_io, network) {
  input_neurons    <- network$input_neurons
  hidden_neurons   <- network$hidden_neurons
  W_input_hidden   <- network$W_input_hidden
  bias_hidden      <- network$bias_hidden
  
  # Initialize the full state with whatever inputs/outputs were provided
  initial_state <- initial_state_io
  
  # For each hidden neuron, compute the net input from input neurons
  # and assign state = +1 if net_input>0, else -1.
  for (hidden_neuron in hidden_neurons) {
    net_input <- bias_hidden[[hidden_neuron]]
    
    # Sum over all input neurons
    for (input_neuron in input_neurons) {
      weight <- W_input_hidden[input_neuron, hidden_neuron]
      if (!is.null(initial_state[[input_neuron]])) {
        net_input <- net_input + weight * initial_state[[input_neuron]]
      } else {
        # If no explicit value, assume -1
        net_input <- net_input + weight * (-1)
      }
    }
    
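    # Threshold at 0. The samplers below use net_in >= 0 instead; the two
    # rules differ only when the net input is exactly 0.
    state_value <- if (net_input > 0) 1 else -1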
    initial_state[[hidden_neuron]] <- state_value
  }
  
  return(initial_state)
}

# ===============================================================
# Simple sigmoid function for computing probabilities from net inputs
# ===============================================================
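sigmoid <- function(x) {
  # Debugging aid: flag non-numeric input (e.g., a list element that was
  # never unlisted) before computing the sigmoid
  if (!is.numeric(x)) {
    warning("sigmoid() received input of class ", paste(class(x), collapse = "/"))
  }
  1 / (1 + exp(-x))
}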

# ===============================================================
# Negative log-likelihood of a {+1, -1} target, where prob is the
# predicted probability that the node is +1
# ===============================================================
compute_loss_NLL <- function(target, prob) {
  if (target == 1) {
    -log(prob)
  } else if (target == -1) {
    -log(1 - prob)
  } else {
    stop("Target must be 1 or -1")
  }
}

# ===============================================================
# mh_layered_sampler():
#   Implements a layered Metropolis-Hastings approach.
#   - Randomly flips one input node at a time, then computes acceptance prob
#   - Deterministic updates for hidden/output layers
# ===============================================================
mh_layered_sampler <- function(
  iterations, nodes, layers, edges, biases, initial_state = NULL, temperature = 1
) {
  # Identify which nodes are input nodes (first layer)
  input_nodes <- layers[[1]]
  
  # Scale biases by temperature
  biases_T <- list()
  for (node in nodes) {
    biases_T[[node]] <- biases[[node]] / temperature
  }
  
  # Scale weights by temperature
  edges_T <- lapply(edges, function(e) list(
    from   = e$from,
    to     = e$to,
    weight = e$weight / temperature
  ))
  
  # Initialize state: either all +1 or a user-supplied initial_state,
  # stored as a named list so that state[[node]] works the same either way
  if (is.null(initial_state)) {
    state <- as.list(setNames(rep(1, length(nodes)), nodes))
  } else {
    state <- as.list(initial_state)
  }
  
  # Prepare a data.frame to store the sequence of sampled states
  states <- data.frame(matrix(ncol = length(nodes), nrow = iterations + 1))
  colnames(states) <- nodes
  
  # Record the initial state in the first row
  states[1, ] <- sapply(nodes, function(n) state[[n]])
  
  # Build a "parent map" for quick reference: which nodes feed into each node?
  parent_map <- list()
  for (node in nodes) {
    parent_map[[node]] <- c()
  }
  for (e in edges_T) {
    parent_map[[e$to]] <- c(parent_map[[e$to]], e$from)
  }
  
  # Also build an edge_map from "X->Y" to the scaled weight
  edge_map <- list()
  for (e in edges_T) {
    key <- paste(e$from, e$to, sep = "->")
    edge_map[[key]] <- e$weight
  }
  
  # ===============================================================
  # MAIN SAMPLING LOOP
  # ===============================================================
  for (iter in 1:iterations) {
    # -------------------------------------------
    # 1) Metropolis-Hastings step for input layer
    # -------------------------------------------
    # Randomly select one input node to flip
    node_to_flip <- sample(input_nodes, 1)
    
    # Proposed state: flip that node
    proposed_state <- state
    proposed_state[[node_to_flip]] <- -proposed_state[[node_to_flip]]
    
    # We'll define an internal function to compute total loss for
    # the input nodes + their immediate children
    compute_total_loss <- function(state_snapshot) {
      total_loss <- 0
      
      # For each input node, compute negative log-likelihood based on its bias
      for (inp in input_nodes) {
        net_input_inp <- biases_T[[inp]]
        prob_inp       <- sigmoid(net_input_inp)
        total_loss     <- total_loss + compute_loss_NLL(state_snapshot[[inp]], prob_inp)
        
        # Now consider children: each child is a hidden node that depends on "inp"
        children_edges <- edges_T[sapply(edges_T, function(e) e$from == inp)]
        for (child_edge in children_edges) {
          child <- child_edge$to
          net_input_child <- biases_T[[child]]
          parents         <- parent_map[[child]]
          
          # Sum up parent's contributions
          for (parent in parents) {
            edge_key       <- paste(parent, child, sep = "->")
            net_input_child <- net_input_child + edge_map[[edge_key]] * state_snapshot[[parent]]
          }
          
          # Evaluate child's predicted probability
          prob_child <- sigmoid(net_input_child)
          
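          # The child's observed state is +1 or -1, so compute its NLL.
          # The loss is divided by the number of children of `inp` and by the
          # magnitude of the last parent edge visited above (edge_key), to reduce
          # extremes; this equals |w(inp -> child)| whenever all of the child's
          # incoming weights share one magnitude
          this_loss_child <- compute_loss_NLL(state_snapshot[[child]], prob_child) /
            length(children_edges) / abs(edge_map[[edge_key]])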
          
          total_loss <- total_loss + this_loss_child
        }
      }
      return(total_loss)
    }
    
    # Compute loss for current vs. proposed states
    loss_current  <- compute_total_loss(state)
    loss_proposed <- compute_total_loss(proposed_state)
    
    # Metropolis-Hastings acceptance probability
    delta_loss       <- loss_proposed - loss_current
    acceptance_prob  <- min(1, exp(-delta_loss))
    
    # Decide whether to accept
    if (runif(1) < acceptance_prob) {
      # Accept the proposed flip
      state[input_nodes] <- proposed_state[input_nodes]
    }
    
    # ------------------------------------------------
    # 2) Deterministic update for subsequent layers
    # ------------------------------------------------
    for (layer_idx in seq_along(layers)[-1]) {  # every layer after the input layer
      layer <- layers[[layer_idx]]
      for (node in layer) {
        # Summation of parent's influences + bias
        net_in <- biases_T[[node]]
        parents <- parent_map[[node]]
        for (parent in parents) {
          edge_key <- paste(parent, node, sep = "->")
          net_in   <- net_in + edge_map[[edge_key]] * state[[parent]]
        }
        # Threshold at 0 => +1 or -1
        if (net_in >= 0) {
          state[[node]] <- 1
        } else {
          state[[node]] <- -1
        }
      }
    }
    
    # Record the new state
    states[iter + 1, ] <- sapply(nodes, function(n) state[[n]])
  }
  
  return(list(states = states))
}

# ===============================================================
# layer_gibbs_sampler():
#   Implements a layer-wise Gibbs sampling approach.
#   - Each input node is sampled from a local posterior ~ exp(-loss)
#   - Hidden/output nodes updated deterministically
# ===============================================================
layer_gibbs_sampler <- function(
  iterations, nodes, layers, edges, biases, initial_state = NULL, temperature = 1
) {
  # Identify input nodes
  input_nodes <- layers[[1]]
  
  # Scale biases by temperature
  biases_T <- list()
  for (node in nodes) {
    biases_T[[node]] <- biases[[node]] / temperature
  }
  
  # Scale weights by temperature
  edges_T <- lapply(edges, function(e) list(
    from   = e$from,
    to     = e$to,
    weight = e$weight / temperature
  ))
  
  # Initialize the chain
  if (is.null(initial_state)) {
    # By default, random assignment
    state <- setNames(sample(c(-1, 1), length(nodes), replace = TRUE), nodes)
  } else {
    state <- initial_state
  }
  if (!is.list(state)) {
    state <- as.list(state)
  }
  
  # Set up a data.frame to record sampled states
  states <- data.frame(matrix(ncol = length(nodes), nrow = iterations + 1))
  colnames(states) <- nodes
  
  # Start with the initial state
  states[1, ] <- sapply(nodes, function(n) state[[n]])
  
  # Build parent & child maps
  parent_map <- list()
  child_map  <- list()
  for (node in nodes) {
    parent_map[[node]] <- c()
    child_map[[node]]  <- c()
  }
  for (e in edges_T) {
    parent_map[[e$to]] <- c(parent_map[[e$to]], e$from)
    child_map[[e$from]]<- c(child_map[[e$from]], e$to)
  }
  
  # ===============================================================
  # MAIN SAMPLING LOOP (layer-wise updates)
  # ===============================================================
  for (iter in 1:iterations) {
    # For each layer, update the nodes in that layer
    for (layer_idx in 1:length(layers)) {
      layer <- layers[[layer_idx]]
      for (node in layer) {
        # If it's an input node, we do a mini Gibbs step
        if (node %in% input_nodes) {
          
          # Evaluate the loss if node = +1 vs. node = -1
          # We'll store them in 'losses' for each possible value
          losses <- c()
          for (val in c(1, -1)) {
            temp_state      <- state
            temp_state[[node]] <- val
            
            # Compute the "input node" negative log-likelihood
            net_in_node     <- biases_T[[node]]
            prob_node       <- sigmoid(net_in_node)
            loss_node       <- compute_loss_NLL(val, prob_node)
            
            total_loss      <- loss_node
            
            # Also incorporate each child's penalty
            for (child in child_map[[node]]) {
              net_input_child <- biases_T[[child]]
              # Sum up parent's influences
              for (parent in parent_map[[child]]) {
                edge <- Filter(function(e) e$from == parent && e$to == child, edges_T)[[1]]
                net_input_child <- net_input_child + edge$weight * temp_state[[parent]]
              }
              prob_child     <- sigmoid(net_input_child)
              # Divide by the number of children of `node` and by the magnitude of
              # the last parent edge visited above (`edge`), mirroring mh_layered_sampler()
              loss_child     <- compute_loss_NLL(state[[child]], prob_child) /
                length(child_map[[node]]) / abs(edge$weight)
              
              total_loss     <- total_loss + loss_child
            }
            
            losses <- c(losses, total_loss)
          }
          
          # Convert those two losses into probabilities
          probs <- exp(-losses)
          probs <- probs / sum(probs)
          
          # Sample the node's new value from {+1, -1}
          state[[node]] <- sample(c(1, -1), 1, prob = probs)
          
        } else {
          # If it's not an input node, do a deterministic update
          net_in <- biases_T[[node]]
          for (parent in parent_map[[node]]) {
            edge <- Filter(function(e) e$from == parent && e$to == node, edges_T)[[1]]
            net_in <- net_in + edge$weight * state[[parent]]
          }
          # Threshold at 0
          if (net_in >= 0) {
            state[[node]] <- 1
          } else {
            state[[node]] <- -1
          }
        }
      }
    }
    # Store the new overall state
    states[iter + 1, ] <- sapply(nodes, function(n) state[[n]])
  }
  
  return(list(states = states))
}
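A single-node sanity check shows what the two loss-based update rules do in the simplest case: an input node with no children, whose loss is just the negative log-likelihood under its own bias. Under that assumption, both the Metropolis-Hastings acceptance rule and the Gibbs conditional should leave the node at \(+1\) with probability \(p\). The sketch re-implements this case from scratch rather than calling mh_layered_sampler() or layer_gibbs_sampler(); \(p = 0.05\) and the helper names are illustrative.

```r
# Single isolated input node (no children): the loss is only its own NLL.
# p = 0.05 and all names below are illustrative assumptions.
set.seed(1)
p    <- 0.05
bias <- log(p / (1 - p))

node_loss <- function(s) {
  prob <- 1 / (1 + exp(-bias))          # P(s = +1) = sigmoid(bias)
  if (s == 1) -log(prob) else -log(1 - prob)
}

# Metropolis-Hastings: propose a flip, accept with min(1, exp(-delta_loss))
run_mh <- function(n_iter, s = 1) {
  visits <- numeric(n_iter)
  for (i in seq_len(n_iter)) {
    proposal <- -s
    delta    <- node_loss(proposal) - node_loss(s)
    if (runif(1) < min(1, exp(-delta))) s <- proposal
    visits[i] <- s
  }
  visits
}

# Gibbs: P(s = +1) = exp(-L(+1)) / (exp(-L(+1)) + exp(-L(-1)))
gibbs_prob_plus_one <- exp(-node_loss(1)) /
  (exp(-node_loss(1)) + exp(-node_loss(-1)))

mh_visits <- run_mh(50000, s = 1)   # start in the "actual" (abnormal) state
mean(mh_visits == 1)                # close to p
gibbs_prob_plus_one                 # equals p up to floating-point error
```

With children attached, the child penalty terms enter the same exp(-loss) expressions, so sampled frequencies can depart from \(p\); this is how the network's structure, and not only the input biases, affects which worlds are visited.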

1.3 Transition Trees for MCMC Sampling

After obtaining sampled state sequences from either the Metropolis-Hastings (mh_layered_sampler()) or Gibbs (layer_gibbs_sampler()) algorithms, we can gain insight into the most common transitions and states visited by constructing a transition tree.

Two functions implement this:

  1. build_transition_counts_and_stays()
    • Takes a list of state sequences (typically one per MCMC simulation).
    • Records the length of every run of consecutive identical states (the stay length) and counts each move from one state to a different next state.
    • Returns the aggregated stay lengths, transition counts, and total number of transitions.
  2. generate_transition_tree()
    • Runs num_simulations independent MCMC chains (10 by default) of iterations_per_simulation steps each (1000 by default).
    • Aggregates all sampled states into a single dataset.
    • Identifies the top \(N\) states by visit frequency (N_top_states, 10 by default).
    • Builds a tree (using the data.tree package) whose root is the initial state. Each node branches out to at most branching_factor of its most probable next states.
    • Truncates the tree at max_depth to avoid an overly large diagram.
    • Also draws a state transition graph of the top states via create_state_transition_graph().
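The bookkeeping in build_transition_counts_and_stays() amounts to run-length encoding a chain of state keys. As an alternative formulation (a sketch, not the code used in this document), base R's rle() yields both quantities for a single chain; the toy chain below is invented for illustration, and rle() records every run, including the last one.

```r
# Toy chain of comma-separated state keys (illustrative, not sampler output)
chain <- c("1,1", "1,1", "-1,1", "-1,1", "-1,1", "1,1", "-1,-1")

runs <- rle(chain)   # consecutive identical keys collapse into runs

# Stay lengths per state (cf. stay_counts) and their mean (cf. avg_stay_length)
stay_lengths <- split(runs$lengths, runs$values)
avg_stay     <- sapply(stay_lengths, mean)

# Transitions between consecutive runs, i.e. moves to a *different* state
transitions      <- table(from = head(runs$values, -1), to = tail(runs$values, -1))
transition_probs <- prop.table(transitions, margin = 1)  # row-normalized
```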

1.3.1 Reading the Transition Tree

  • Root node: the initial state of the first simulation. Labels show only the input nodes; the full state key (including hidden and output nodes) is stored internally.
  • Child nodes: the most frequent next states (at most branching_factor) reached from the parent. Each child displays a transition probability (Prob), the fraction of the parent's outgoing transitions to a different state that go to that child.
  • Average stay (avg_stay_length): the average number of consecutive samples the chain spent in this state before moving to another.
  • Depth: runs from 1 (the root) to max_depth. Nodes at max_depth are not expanded and do not display an average stay.
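As a worked reading of the Prob values: a child's probability is its transition count divided by all transitions out of the parent, and only the top branching_factor children are drawn, so the displayed probabilities need not sum to 1. The counts below are hypothetical and use the nested-list shape returned by build_transition_counts_and_stays().

```r
# Hypothetical transition counts out of one parent state
transition_counts <- list(
  "1,1,1" = list("-1,1,1" = 30, "1,-1,1" = 15, "1,1,-1" = 5)
)
branching_factor <- 2

from_state  <- "1,1,1"
counts      <- sort(unlist(transition_counts[[from_state]]), decreasing = TRUE)
shown       <- head(counts, branching_factor)   # children drawn in the tree
shown_probs <- shown / sum(counts)              # "Prob" displayed on each child
```

Here the third successor (5 of 50 transitions) is pruned, so the two children shown carry Prob values of 0.6 and 0.3.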

Such a tree can be generated for Experiment 2a (Quillien & Lucas, 2023) or for the “Boys and Girls” party example from the dissertation. Switching sampler_function between mh_layered_sampler and layer_gibbs_sampler shows how the transition structure differs between the two MCMC approaches.

The annotated code follows. It ends with an example that:

  • defines clauses and input probabilities (the example uses the “\(D = A + B + C \geq 2\)” scenario; others, such as “Boys and Girls”, follow the same pattern);
  • builds the sampling network;
  • runs the tree-generation function to produce and plot the transition tree for both the Gibbs and MH samplers.


# ================================================================
# build_transition_counts_and_stays():
#   Aggregates how often the chain stays in the same state or moves
#   to a new state across a list of state sequences.
#   - 'state_sequences' is a list of data.frames, each row = a sampled state.
#   - 'nodes' is a character vector of node names (variables).
# ================================================================
build_transition_counts_and_stays <- function(state_sequences, nodes) {
  transition_counts  <- list()
  stay_counts        <- list()
  total_transitions  <- 0
  
  # Loop over each simulation's sequence
  for (seq in state_sequences) {
    num_states <- nrow(seq)
    
    i <- 1
    # Advance 'i' run by run through the chain; '<=' (rather than '<') makes
    # sure a final run of length 1 is also recorded in stay_counts
    while (i <= num_states) {
      # Convert the current row's node values into a comma-separated string
      current_state <- paste(seq[i, nodes], collapse = ",")
      stay_length   <- 1
      j <- i + 1
      
      # Count how many consecutive rows remain the same as 'current_state'
      while (j <= num_states && all(seq[j, nodes] == seq[i, nodes])) {
        stay_length <- stay_length + 1
        j <- j + 1
      }
      
      # Record how long the chain stayed in 'current_state'
      if (!is.null(stay_counts[[current_state]])) {
        stay_counts[[current_state]] <- c(stay_counts[[current_state]], stay_length)
      } else {
        stay_counts[[current_state]] <- stay_length
      }
      
      # If there's a *next* state (i.e. j <= num_states), record the transition
      if (j <= num_states) {
        next_state <- paste(seq[j, nodes], collapse = ",")
        
        # If we already have transitions from 'current_state'
        if (!is.null(transition_counts[[current_state]])) {
          # If we have a count for 'next_state' from 'current_state'
          if (!is.null(transition_counts[[current_state]][[next_state]])) {
            transition_counts[[current_state]][[next_state]] <-
              transition_counts[[current_state]][[next_state]] + 1
          } else {
            transition_counts[[current_state]][[next_state]] <- 1
          }
        } else {
          # If this is our first time seeing 'current_state'
          transition_counts[[current_state]] <- list()
          transition_counts[[current_state]][[next_state]] <- 1
        }
        total_transitions <- total_transitions + 1
      }
      
      # Now skip past the block of repeated states
      i <- j
    }
  }
  
  return(list(
    transition_counts = transition_counts,
    stay_counts       = stay_counts,
    total_transitions = total_transitions
  ))
}





# ================================================================
# generate_transition_tree():
#   1) Runs multiple MCMC simulations (MH or Gibbs) to get state sequences.
#   2) Aggregates states into a single data.frame, identifies top states.
#   3) Builds a 'data.tree' that starts from the initial state and expands
#      up to 'branching_factor' next states, to a depth 'max_depth'.
#   4) Plots the tree with edges labeled by transition probabilities,
#      and nodes labeled with avg stay length and state name.
# ================================================================

# Packages loaded in setup chunk

generate_transition_tree <- function(
  nodes, layers, edges, biases,
  iterations_per_simulation = 1000,  # how many steps per chain
  num_simulations           = 10,    # how many chains
  max_depth                 = 5,     # tree expansion limit
  branching_factor          = 3,     # how many top next-states to expand
  initial_state             = NULL,
  temperature               = 1,
  sampler_function          = mh_layered_sampler, # which MCMC approach
  N_top_states             = 10
) {
  # Identify which nodes are "input" so we only label those in the tree.
  # We assume layers[[1]] is the input layer.
  input_nodes <- layers[[1]]  
  
  # -------------------------------------------------------------------
  # Same helper as before, but we'll pass input_nodes instead of all 'nodes'.
  # -------------------------------------------------------------------
  label_state_key <- function(state_str, node_names) {
    # Example: state_str = "1,1,-1,1", node_names = c("A","B")...
    # We'll only pair up the first length(node_names) entries in state_str.
    state_vals <- unlist(strsplit(state_str, ",", fixed = TRUE))
    
    if (length(state_vals) == length(nodes)) {
      # If the raw string has as many entries as *all* nodes,
      # we only want the first `length(node_names)` chunk for labeling.
      # (Because node_names here is actually input_nodes.)
      wanted_vals <- state_vals[seq_along(node_names)]
      
      labeled <- mapply(
        function(nm, val) paste0(nm, "=", val),
        node_names, wanted_vals,
        SIMPLIFY = TRUE
      )
      paste(labeled, collapse = ", ")
    } else {
      # Fallback
      state_str
    }
  }
  
  # 1) Collect multiple MCMC runs
  state_sequences <- list()
  all_states      <- list()
  
  for (sim in 1:num_simulations) {
    result_mh <- sampler_function(
      iterations     = iterations_per_simulation,
      nodes          = nodes,
      layers         = layers,
      edges          = edges,
      biases         = biases,
      initial_state  = initial_state,
      temperature    = temperature
    )
    states_mh <- result_mh$states
    state_sequences[[sim]] <- states_mh
    all_states[[sim]]      <- states_mh
  }
  
  # 2) Flatten and identify top states
  all_states_df <- do.call(rbind, all_states)
  state_keys    <- apply(all_states_df, 1, function(rowvals) paste(rowvals[nodes], collapse = ","))
  state_table   <- table(state_keys)
  sorted_states <- sort(state_table, decreasing = TRUE)
  top_states    <- head(sorted_states, N_top_states)
  
  # Create a nicely formatted data frame for the top states
  top_states_df <- data.frame(
    State = names(top_states),
    `Visit Count` = as.integer(top_states),
    `Proportion` = round(as.integer(top_states) / sum(state_table), 3),
    check.names = FALSE
  )

  # Note: top_states_df is returned and should be displayed with kable in calling chunk
  
  # 3) Build transitions, stays, etc.
  transition_data   <- build_transition_counts_and_stays(state_sequences, nodes)
  transition_counts <- transition_data$transition_counts
  stay_counts       <- transition_data$stay_counts
  total_transitions <- transition_data$total_transitions
  
  # 4) Identify the initial state's key (full set of nodes)
  raw_initial_key <- paste(state_sequences[[1]][1, nodes], collapse = ",")
  
  # Build a label for the root, but show only input nodes
  labeled_initial <- label_state_key(raw_initial_key, input_nodes)
  
  # Root node with the label
  root <- Node$new(paste0("State: ", labeled_initial))
  root$state_key <- raw_initial_key  # store the FULL state key internally
  
  # -------------------------------------------------------------------
  # Recursive function to expand children up to max_depth
  # -------------------------------------------------------------------
  build_tree <- function(node, current_depth) {
    if (current_depth >= max_depth) return()
    
    state_key   <- node$state_key   # raw full state key, e.g. "1,1,-1,1"
    transitions <- transition_counts[[state_key]]
    stays       <- stay_counts[[state_key]]
    
    if (!is.null(stays)) {
      node$avg_stay_length <- mean(stays)
    }
    if (!is.null(transitions)) {
      sorted_transitions <- sort(unlist(transitions), decreasing = TRUE)
      top_state_keys     <- names(sorted_transitions)[1:min(branching_factor, length(sorted_transitions))]
      top_counts         <- sorted_transitions[top_state_keys]
      total_counts       <- sum(unlist(transitions))
      
      for (i in seq_along(top_state_keys)) {
        next_state_key <- top_state_keys[i]
        count          <- top_counts[i]
        prob           <- count / total_counts
        
        # Now only the input nodes get displayed in the child's label
        labeled_child <- label_state_key(next_state_key, input_nodes)
        
        child_node <- node$AddChild(paste0("State: ", labeled_child))
        child_node$state_key       <- next_state_key
        child_node$transitionCount <- count
        child_node$prob            <- prob
        
        build_tree(child_node, current_depth + 1)
      }
    }
  }
  
  build_tree(root, 1)
  
  # 5) Style the tree
  SetGraphStyle(root, rankdir = "TB")
  SetEdgeStyle(root, fontname = 'helvetica', arrowhead = "vee", arrowsize = 0.5)
  SetNodeStyle(root, fontname = 'helvetica', shape = "box", label = function(node) {
    lbl <- node$name
    if (!is.null(node$avg_stay_length)) {
      lbl <- paste0(lbl, "\nAvg Stay: ", round(node$avg_stay_length, 2))
    }
    if (!is.null(node$prob)) {
      lbl <- paste0(lbl, "\nProb: ", round(node$prob, 3))
    }
    lbl
  })
  
  # 6) Convert data.tree => DiagrammeR => HTML widget
  g <- ToDiagrammeRGraph(root)
  graph_widget <- render_graph(g)
  
  # 7) Create a Rehder & Davis style state transition graph
  # This shows all visited states and their transition probabilities
  # Pass the actual initial state key so the graph correctly highlights it
  state_graph <- create_state_transition_graph(transition_counts, stay_counts, input_nodes, top_states,
                                               initial_state_key = raw_initial_key)

  # 8) Return everything
  list(
    tree               = root,
    top_states         = top_states,
    top_states_df      = top_states_df,
    all_states         = all_states_df,
    graph_widget       = graph_widget,
    state_graph        = state_graph,
    transition_counts  = transition_counts
  )
}


# ================================================================
# create_state_transition_graph():
#   Creates a Rehder & Davis (2012) style state transition diagram
#   showing states as nodes and transitions as directed edges.
#   Node size reflects visit frequency; edge width reflects transition probability.
#   Uses ggraph for separate curved arrows between states.
# ================================================================
create_state_transition_graph <- function(transition_counts, stay_counts, input_nodes, top_states,
                                          initial_state_key = "1,1,1",
                                          size_exponent = 0.5,
                                          min_node_size = 12,
                                          max_node_size = 40,
                                          min_edge_width = 0.3,
                                          max_edge_width = 4,
                                          edge_min_prob = 0.01,
                                          show_labels = TRUE) {
  library(igraph)
  library(ggraph)
  library(ggplot2)

  state_names <- names(top_states)
  state_counts <- as.integer(top_states)
  total_visits <- sum(state_counts)
  n_inputs <- length(input_nodes)

  # Label function: compact notation with negation symbol
  # Only uses the first n_inputs values from the state key (the input nodes)
  compact_label <- function(state_str) {
    vals <- as.numeric(strsplit(state_str, ",")[[1]])
    # Only take first n_inputs values (the input nodes, not hidden/output)
    vals <- vals[1:min(n_inputs, length(vals))]
    parts <- sapply(seq_along(vals), function(j) {
      if (j <= length(input_nodes)) {
        if (vals[j] == 1) input_nodes[j] else paste0("\u00AC", input_nodes[j])
      } else {
        as.character(vals[j])
      }
    })
    paste(parts, collapse = " ")
  }

  # Check if outcome is positive. Scenario-specific: assumes the outcome
  # occurs when at least 2 input nodes are 1 (as in D = A+B+C >= 2)
  is_win <- function(state_str) {
    vals <- as.numeric(strsplit(state_str, ",")[[1]])
    # Only consider input nodes for win/loss determination
    input_vals <- vals[1:min(n_inputs, length(vals))]
    sum(input_vals == 1) >= 2
  }

  # Calculate proportions
  proportions <- state_counts / total_visits
  max_prop <- max(proportions)

  # Build edge list
  edges_df <- data.frame(from = character(), to = character(),
                         prob = numeric(), stringsAsFactors = FALSE)

  for (from_state in state_names) {
    if (!is.null(transition_counts[[from_state]])) {
      trans <- transition_counts[[from_state]]
      total_trans <- sum(unlist(trans))

      for (to_state in names(trans)) {
        if (to_state %in% state_names && from_state != to_state) {
          trans_prob <- trans[[to_state]] / total_trans
          if (trans_prob >= edge_min_prob) {
            edges_df <- rbind(edges_df, data.frame(
              from = from_state, to = to_state, prob = trans_prob,
              stringsAsFactors = FALSE
            ))
          }
        }
      }
    }
  }

  max_trans_prob <- max(edges_df$prob, na.rm = TRUE)
  edges_df$width <- min_edge_width + (max_edge_width - min_edge_width) * (edges_df$prob / max_trans_prob)
  edges_df$label <- ifelse(edges_df$prob > 0.03 & show_labels,
                           paste0(round(edges_df$prob * 100), "%"), "")

  # Build node data
  nodes_df <- data.frame(
    name = state_names,
    label = sapply(state_names, compact_label),
    proportion = proportions,
    pct_label = paste0(round(proportions * 100, 1), "%"),
    is_win = sapply(state_names, is_win),
    is_initial = state_names == initial_state_key,
    stringsAsFactors = FALSE
  )

  # Power-law scaling of relative visit frequency; exponents below 1 (the
  # default is 0.5) compress differences so rarely visited states stay visible
  nodes_df$size_factor <- (nodes_df$proportion / max_prop)^size_exponent
  nodes_df$size <- min_node_size + (max_node_size - min_node_size) * nodes_df$size_factor

  # Create igraph object
  g <- graph_from_data_frame(edges_df, directed = TRUE, vertices = nodes_df)

  # Create a deterministic circular layout: state names are sorted
  # alphabetically and spaced evenly around a circle, so the same set of
  # states gets the same positions in every graph
  n_nodes <- nrow(nodes_df)
  # Sort state names alphabetically to get consistent ordering
  sorted_names <- sort(nodes_df$name)
  name_to_angle <- setNames(seq(0, 2*pi*(1 - 1/n_nodes), length.out = n_nodes), sorted_names)

  # Create layout matrix with x, y coordinates on a circle
  layout_matrix <- matrix(0, nrow = n_nodes, ncol = 2)
  for (i in seq_len(n_nodes)) {
    angle <- name_to_angle[nodes_df$name[i]]
    layout_matrix[i, 1] <- cos(angle)
    layout_matrix[i, 2] <- sin(angle)
  }

  # Plot with ggraph - using arc for curved separate edges
  ggraph(g, layout = "manual", x = layout_matrix[,1], y = layout_matrix[,2]) +
    geom_edge_arc(aes(width = width, alpha = prob),
                  arrow = arrow(length = unit(4, "mm"), type = "closed"),
                  end_cap = circle(8, "mm"),
                  start_cap = circle(8, "mm"),
                  strength = 0.2,
                  color = "gray40") +
    geom_edge_arc(aes(label = label),
                  strength = 0.2,
                  label_dodge = unit(4, "mm"),
                  angle_calc = "along",
                  label_size = 3.5,
                  color = NA) +
    geom_node_point(aes(size = size,
                        fill = is_win,
                        stroke = ifelse(is_initial, 3, 0.8)),
                    shape = 21, color = "gray30") +
    geom_node_text(aes(label = paste0(label, "\n", pct_label)),
                   size = 3.5, vjust = 0.5) +
    scale_fill_manual(values = c("TRUE" = "#90EE90", "FALSE" = "#FFB6C1"),
                      guide = "none") +
    scale_edge_width(range = c(min_edge_width, max_edge_width), guide = "none") +
    scale_edge_alpha(range = c(0.4, 1), guide = "none") +
    scale_size_identity() +
    theme_void() +
    theme(legend.position = "none") +
    coord_fixed()
}
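Because the default size_exponent is 0.5, node size grows sublinearly with visit frequency, which keeps rarely visited states legible next to dominant ones. The sketch restates only that mapping with the function's default size bounds, applied to three proportions from the MH table below; node_size() is an illustrative name, not part of the code above.

```r
# Restatement of the node-size mapping in create_state_transition_graph()
node_size <- function(proportion, size_exponent = 0.5,
                      min_node_size = 12, max_node_size = 40) {
  size_factor <- (proportion / max(proportion))^size_exponent
  min_node_size + (max_node_size - min_node_size) * size_factor
}

props <- c(0.480, 0.029, 0.001)        # three visit proportions (MH sampler)
node_size(props)                       # square-root scaling (default)
node_size(props, size_exponent = 1)    # linear scaling, for comparison
```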

1.3.2 Experiment 2 Example

Below is the example for Experiment 1’s “\(D = A + B + C \geq 2\)” scenario, where \(D\) is true if at least two of \(A\), \(B\), and \(C\) are true. We set the input probabilities of \(A\), \(B\), and \(C\) to match the LOW, INTERMEDIATE, and HIGH urns from the experiment, generate the sampling network, and then run Metropolis-Hastings and Gibbs sampling to produce and visualize a transition tree. We use num_simulations = 2000 with 100 iterations per simulation to keep Monte Carlo noise low.
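Before looking at sampler output, it can help to enumerate the eight possible worlds under these input probabilities, treating \(A\), \(B\), and \(C\) as independent (an assumption of this sketch) and computing \(D\) with the at-least-two rule. The resulting probabilities are a reference point for the visit proportions reported below, not a prediction the samplers must match: the chains start in the actual world, and their losses also include hidden-node terms.

```r
p <- c(A = 0.05, B = 0.5, C = 0.95)

# All 8 combinations of input states, coded +1 / -1
worlds <- expand.grid(A = c(1, -1), B = c(1, -1), C = c(1, -1))

# Probability of each world, assuming independent inputs
worlds$prob <- apply(worlds[, c("A", "B", "C")], 1,
                     function(s) prod(ifelse(s == 1, p, 1 - p)))

# Outcome: D = +1 when at least two of A, B, C are +1
worlds$D <- ifelse(rowSums(worlds[, c("A", "B", "C")] == 1) >= 2, 1, -1)

p_D <- sum(worlds$prob[worlds$D == 1])   # prior probability that D occurs
worlds[order(-worlds$prob), ]            # worlds from most to least probable
```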

# ================================================================
# Example code for Experiment 1 (D = A+B+C >= 2 scenario)
# Uses 2000 simulations x 100 iterations for stable results
# Results are cached to an .rds file so the expensive MCMC runs only once
# ================================================================

cache_file_exp1 <- "cache/experiment1_mcmc_results.rds"

if (file.exists(cache_file_exp1)) {
  # Load cached results
  cached_exp1 <- readRDS(cache_file_exp1)
  gibbs_tree_exp2 <- cached_exp1$gibbs_tree
  mh_tree_exp2 <- cached_exp1$mh_tree
  sampling_network_experiment2 <- cached_exp1$network
  initial_state_experiment2 <- cached_exp1$initial_state
} else {
  # Clauses: D <- A,B; D <- A,C; D <- B,C (any 2 of 3)
  clauses_experiment2 <- list(
    list(head = "D", body = c("A", "B")),
    list(head = "D", body = c("A", "C")),
    list(head = "D", body = c("B", "C"))
  )

  # Input probabilities matching Experiment 1:
  # A = LOW (0.05), B = INTERMEDIATE (0.5), C = HIGH (0.95)
  input_probabilities_experiment2 <- list(
    "A" = 0.05,
    "B" = 0.5,
    "C" = 0.95
  )

  # 1) Generate the sampling network
  sampling_network_experiment2 <- generate_network_from_logic_program(
    clauses_experiment2,
    input_probabilities_experiment2,
    plot_network = FALSE
  )

  # 2) Define initial observed state (all balls drawn, player wins)
  initial_state_io_experiment2 <- list("A"=1, "B"=1, "C"=1, "D"=1)

  # 3) Create full initial state, including hidden neurons
  network_experiment2       <- sampling_network_experiment2$network
  initial_state_experiment2 <- generate_initial_hidden_states(initial_state_io_experiment2, network_experiment2)

  # 4) Build transition tree with Gibbs sampler (2000 sims x 100 iters)
  gibbs_tree_exp2 <- generate_transition_tree(
    nodes                    = sampling_network_experiment2$nodes,
    layers                   = sampling_network_experiment2$layers,
    edges                    = sampling_network_experiment2$edges,
    biases                   = sampling_network_experiment2$biases,
    iterations_per_simulation= 100,
    num_simulations          = 2000,
    max_depth                = 3,
    branching_factor         = 2,
    initial_state            = initial_state_experiment2,
    temperature              = 1,
    sampler_function         = layer_gibbs_sampler
  )

  # 5) Build transition tree with Metropolis-Hastings sampler
  mh_tree_exp2 <- generate_transition_tree(
    nodes                    = sampling_network_experiment2$nodes,
    layers                   = sampling_network_experiment2$layers,
    edges                    = sampling_network_experiment2$edges,
    biases                   = sampling_network_experiment2$biases,
    iterations_per_simulation= 100,
    num_simulations          = 2000,
    max_depth                = 3,
    branching_factor         = 2,
    initial_state            = initial_state_experiment2,
    temperature              = 1,
    sampler_function         = mh_layered_sampler
  )

  # Save results to cache
  dir.create("cache", showWarnings = FALSE)
  saveRDS(list(
    gibbs_tree = gibbs_tree_exp2,
    mh_tree = mh_tree_exp2,
    network = sampling_network_experiment2,
    initial_state = initial_state_experiment2
  ), cache_file_exp1)
}

1.3.2.1 Metropolis-Hastings Results

Most frequently visited states (MH Sampler):

State                 Visits  Proportion
-1,-1,1,-1,-1,-1,-1   97058   0.480
-1,1,1,-1,-1,1,1      89881   0.445
1,1,1,1,1,1,1         5815    0.029
-1,-1,-1,-1,-1,-1,-1  3878    0.019
-1,1,-1,-1,-1,-1,-1   2734    0.014
1,-1,1,-1,1,-1,1      2470    0.012
1,-1,-1,-1,-1,-1,-1   118     0.001
1,1,-1,1,-1,-1,1      46      0.000
State Transition Graph (MH Sampler). Node size reflects visit frequency; edge width reflects transition probability.
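
A table of this kind can be produced by collapsing each visited state into a comma-separated key and counting the keys. The sketch below assumes the visited states are available as a numeric matrix with one ±1 row per visit. How generate_transition_tree() actually stores its states is not shown here, so `states` is a hypothetical stand-in and top_visited_states() is an illustrative helper.

```r
# Hypothetical stand-in for sampler output: one row per visited state (+1/-1 coding).
states <- rbind(
  c(-1, -1,  1, -1, -1, -1, -1),
  c(-1,  1,  1, -1, -1,  1,  1),
  c(-1, -1,  1, -1, -1, -1, -1),
  c( 1,  1,  1,  1,  1,  1,  1)
)

top_visited_states <- function(states, n = 10) {
  # Collapse each row into a key such as "-1,-1,1,-1,-1,-1,-1"
  keys <- apply(states, 1, paste, collapse = ",")
  # Count visits per key, most frequent first, and keep at most n keys
  counts <- sort(table(keys), decreasing = TRUE)
  counts <- counts[seq_len(min(n, length(counts)))]
  data.frame(
    State      = names(counts),
    Visits     = as.integer(counts),
    Proportion = round(as.numeric(counts) / length(keys), 3),
    stringsAsFactors = FALSE
  )
}

top_visited_states(states)
```

The proportion is taken over all recorded visits, so the proportions of the full table sum to 1 (up to rounding).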


1.3.2.2 Gibbs Sampling Results

Most frequently visited states (Gibbs Sampler):

State                 Visits  Proportion
-1,-1,1,-1,-1,-1,-1   95235   0.471
-1,1,1,-1,-1,1,1      90451   0.448
-1,1,-1,-1,-1,-1,-1   4967    0.025
-1,-1,-1,-1,-1,-1,-1  3808    0.019
1,1,1,1,1,1,1         3800    0.019
1,-1,1,-1,1,-1,1      3341    0.017
1,-1,-1,-1,-1,-1,-1   236     0.001
1,1,-1,1,-1,-1,1      162     0.001
State Transition Graph (Gibbs Sampler). Node size reflects visit frequency; edge width reflects transition probability.
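
The two samplers visit the same eight states with similar frequencies, and each table accounts for 202,000 visits in total. One way to quantify the agreement, offered here as an illustration rather than as part of the original analysis, is the total variation distance between the two empirical distributions, computed from the counts in the MH and Gibbs tables above. total_variation() is an illustrative helper.

```r
# Visit counts copied from the MH and Gibbs tables, aligned by state.
states <- c("-1,-1,1,-1,-1,-1,-1", "-1,1,1,-1,-1,1,1", "1,1,1,1,1,1,1",
            "-1,-1,-1,-1,-1,-1,-1", "-1,1,-1,-1,-1,-1,-1", "1,-1,1,-1,1,-1,1",
            "1,-1,-1,-1,-1,-1,-1", "1,1,-1,1,-1,-1,1")
mh_counts    <- c(97058, 89881, 5815, 3878, 2734, 2470, 118, 46)
gibbs_counts <- c(95235, 90451, 3800, 3808, 4967, 3341, 236, 162)
names(mh_counts)    <- states
names(gibbs_counts) <- states

# Total variation distance: half the L1 distance between normalized counts.
total_variation <- function(p_counts, q_counts) {
  p <- p_counts / sum(p_counts)
  q <- q_counts[names(p_counts)] / sum(q_counts)
  0.5 * sum(abs(p - q))
}

tv <- total_variation(mh_counts, gibbs_counts)
tv
```

For these counts the distance is about 0.019: less than 2% of the probability mass would have to be reassigned to turn one empirical distribution into the other.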



1.3.3 Experiment 2 Example: The Disjunctive Rule

This example illustrates the disjunctive rule of Experiment 2, \(F := (A \land B) \lor (C \land D)\), analogous to the “Boys and Girls” scenario from the dissertation. Here, the outcome occurs if all members of one group or all members of another group are present, capturing a structure in which variables are grouped into “teams.” The example below uses three members per group.

In this example, \(F\) (“the party is great”) is true if all three members of the “Boys” group (J, B, D) attend OR all three members of the “Girls” group (M, S, C) attend. We set J (one of the boys) to have lower probability (0.5) than the others (0.8), making his attendance the “abnormal” event.
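
Under the input probabilities used in the code below, the prior probability that the party is great can be checked by hand by enumerating all \(2^6\) attendance patterns. This sketch assumes the six attendance variables are independent Bernoulli variables; it is a sanity check of the rule, not part of the sampling procedure.

```r
# Attendance probabilities, as in input_probabilities_boys_girls (independence assumed).
p <- c(J = 0.5, B = 0.8, D = 0.8, M = 0.8, S = 0.8, C = 0.8)

# All 2^6 attendance patterns (TRUE = attends).
worlds <- expand.grid(rep(list(c(FALSE, TRUE)), length(p)))
names(worlds) <- names(p)

# Probability of each pattern under independence.
world_prob <- apply(worlds, 1, function(x) prod(ifelse(x, p, 1 - p)))

# F := (J & B & D) | (M & S & C)
f_true <- (worlds$J & worlds$B & worlds$D) | (worlds$M & worlds$S & worlds$C)

p_F               <- sum(world_prob[f_true])
p_F_given_J       <- sum(world_prob[f_true & worlds$J]) / sum(world_prob[worlds$J])
p_F_given_not_J   <- sum(world_prob[f_true & !worlds$J]) / sum(world_prob[!worlds$J])
c(p_F, p_F_given_J, p_F_given_not_J)
```

Without J, only the girls can make \(F\) true, so \(P(F \mid \lnot J) = 0.8^3 = 0.512\); J's attendance raises this to about 0.824, and the unconditional probability is 0.66816.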

# ================================================================
# Experiment 2 / Boys and Girls example:
# F := (J ∧ B ∧ D) ∨ (M ∧ S ∧ C)
# Uses 2000 simulations x 100 iterations for stable results
# Results cached to avoid re-running expensive MCMC simulations
# ================================================================

cache_file_bg <- "cache/boys_girls_mcmc_results.rds"

if (file.exists(cache_file_bg)) {
  # Load cached results
  cached_bg <- readRDS(cache_file_bg)
  gibbs_tree_bg <- cached_bg$gibbs_tree
  mh_tree_bg <- cached_bg$mh_tree
  sampling_network_boys_girls <- cached_bg$network
  initial_state_bg <- cached_bg$initial_state
} else {
  # Clauses for "Boys and Girls" scenario
  # F is true if all Boys attend OR all Girls attend
  clauses_boys_girls <- list(
    list(head = "F", body = c("J", "B", "D")),  # Boys group
    list(head = "F", body = c("M", "S", "C"))   # Girls group
  )

  # Input probabilities: J is less likely (abnormal), others are normal
  input_probabilities_boys_girls <- list(
    "J" = 0.5,   # Less likely (abnormal)
    "B" = 0.8,
    "D" = 0.8,
    "M" = 0.8,
    "S" = 0.8,
    "C" = 0.8
  )

  # 1) Generate the sampling network
  sampling_network_boys_girls <- generate_network_from_logic_program(
    clauses_boys_girls,
    input_probabilities_boys_girls,
    plot_network = FALSE
  )

  # 2) Define initial state: in the actual world, everyone came and F is true
  initial_state_io_bg <- list(
    "J"=1, "B"=1, "D"=1,
    "M"=1, "S"=1, "C"=1,
    "F"=1
  )

  # 3) Get the underlying raw network object
  networkent_bg    <- sampling_network_boys_girls$network
  initial_state_bg <- generate_initial_hidden_states(initial_state_io_bg, networkent_bg)

  # 4) Build transition tree with Gibbs sampler (2000 sims x 100 iters)
  gibbs_tree_bg <- generate_transition_tree(
    nodes                    = sampling_network_boys_girls$nodes,
    layers                   = sampling_network_boys_girls$layers,
    edges                    = sampling_network_boys_girls$edges,
    biases                   = sampling_network_boys_girls$biases,
    iterations_per_simulation= 100,
    num_simulations          = 2000,
    max_depth                = 3,
    branching_factor         = 2,
    initial_state            = initial_state_bg,
    temperature              = 1,
    sampler_function         = layer_gibbs_sampler
  )

  # 5) Build transition tree with Metropolis-Hastings sampler
  mh_tree_bg <- generate_transition_tree(
    nodes                    = sampling_network_boys_girls$nodes,
    layers                   = sampling_network_boys_girls$layers,
    edges                    = sampling_network_boys_girls$edges,
    biases                   = sampling_network_boys_girls$biases,
    iterations_per_simulation= 100,
    num_simulations          = 2000,
    max_depth                = 3,
    branching_factor         = 2,
    initial_state            = initial_state_bg,
    temperature              = 1,
    sampler_function         = mh_layered_sampler
  )

  # Save results to cache
  dir.create("cache", showWarnings = FALSE)
  saveRDS(list(
    gibbs_tree = gibbs_tree_bg,
    mh_tree = mh_tree_bg,
    network = sampling_network_boys_girls,
    initial_state = initial_state_bg
  ), cache_file_bg)
}

1.3.3.1 Metropolis-Hastings Results

Most frequently visited states (MH Sampler):

State                   Visits  Proportion
1,1,1,1,1,1,1,1,1       83693   0.414
-1,1,1,1,1,1,-1,1,1     33619   0.166
-1,1,-1,1,1,1,-1,1,1    8066    0.040
-1,-1,1,1,1,1,-1,1,1    7608    0.038
1,-1,1,1,1,1,-1,1,1     7347    0.036
1,1,-1,1,1,1,-1,1,1     7075    0.035
1,1,1,1,1,-1,1,-1,1     6275    0.031
1,1,1,-1,1,1,1,-1,1     6143    0.030
1,1,1,1,-1,1,1,-1,1     6048    0.030
-1,1,1,1,-1,1,-1,-1,-1  3647    0.018

1.3.3.2 Gibbs Sampling Results

Most frequently visited states (Gibbs Sampler):

State                   Visits  Proportion
-1,1,1,1,1,1,-1,1,1     38151   0.189
1,1,1,1,1,1,1,1,1       26890   0.133
-1,1,1,1,1,-1,-1,-1,-1  10815   0.054
-1,1,1,1,-1,1,-1,-1,-1  10211   0.051
1,1,-1,1,1,1,-1,1,1     9760    0.048
1,-1,1,1,1,1,-1,1,1     9616    0.048
-1,1,1,-1,1,1,-1,-1,-1  9548    0.047
-1,1,-1,1,1,1,-1,1,1    8144    0.040
-1,-1,1,1,1,1,-1,1,1    8005    0.040
1,1,1,1,1,-1,1,-1,1     6949    0.034
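
The printed state vectors do not label their components. Reading them in the order J, B, D, M, S, C, h_boys, h_girls, F (inputs in clause order, then one hidden neuron per clause, then the outcome) makes every state listed in both tables logically consistent with the rule; this ordering is our inference, not something the output states. The sketch below decodes a state string under that assumed ordering and checks it against \(F := (J \land B \land D) \lor (M \land S \land C)\). decode_state() and is_consistent() are illustrative helpers.

```r
# Assumed order of the nine components (inferred, not printed by the sampler).
state_names <- c("J", "B", "D", "M", "S", "C", "h_boys", "h_girls", "F")

# Turn "1,-1,..." into a named logical vector (TRUE = +1, FALSE = -1).
decode_state <- function(key, labels = state_names) {
  values <- as.integer(strsplit(key, ",", fixed = TRUE)[[1]])
  setNames(values == 1, labels)
}

# Check hidden and outcome units against F := (J & B & D) | (M & S & C).
is_consistent <- function(s) {
  boys  <- s[["J"]] && s[["B"]] && s[["D"]]
  girls <- s[["M"]] && s[["S"]] && s[["C"]]
  s[["h_boys"]] == boys && s[["h_girls"]] == girls && s[["F"]] == (boys || girls)
}

# The ten most visited states from the Gibbs table.
gibbs_top <- c(
  "-1,1,1,1,1,1,-1,1,1",    "1,1,1,1,1,1,1,1,1",    "-1,1,1,1,1,-1,-1,-1,-1",
  "-1,1,1,1,-1,1,-1,-1,-1", "1,1,-1,1,1,1,-1,1,1",  "1,-1,1,1,1,1,-1,1,1",
  "-1,1,1,-1,1,1,-1,-1,-1", "-1,1,-1,1,1,1,-1,1,1", "-1,-1,1,1,1,1,-1,1,1",
  "1,1,1,1,1,-1,1,-1,1"
)
all(vapply(gibbs_top, function(k) is_consistent(decode_state(k)), logical(1)))
```

Under this reading, the most frequent Gibbs state (J absent, everyone else present, F true) is the counterfactual in which the girls alone bring about the outcome.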

1.4 Weight Update Methods: Layer-wise Feedback Propagation

Dissertation Reference: This section implements the core Layer-wise Relevance Propagation (LRP) technique from Part I, Section 4 of the dissertation, adapted for computing causal importance scores.

The code below implements the Layer-wise Feedback Propagation (LFP) approach described in the thesis, where updates to the weights \(\{w_{ij}\}\) are driven by (a) the reward assigned to each output node, and (b) the decomposition of each neuron’s net input \(z_j\) into individual contributions \(z_{ij} = w_{ij}\,a_i\). This procedure is reminiscent of LRP, but is adapted here to train, or “re-weight,” the network in a manner that:

  1. focuses on matching the actual outcome in the real world, rather than a purely external target; and
  2. accumulates incremental updates \(\delta_{w_{ij}}^\text{lfp}\) that reflect each neuron’s share in explaining the observed outcome.

Concretely:

  • We run a forward pass with the real-world input \(\mathbf{a}_{\text{in}}\) to obtain the outcome activation \(a_{Oc}^{AW}\) in the actual world.
  • For each sampled counterfactual state \(\{s_i\}\), we compute the outcome neuron’s activation \(a_{Oc}^i\). The reward \(r_{Oc}^i\) is \[ r_{Oc}^i = \begin{cases} a_{Oc}^{AW}\, a_{Oc}^i, & \text{if } a_{Oc}^{AW} > 0,\\ \lvert a_{Oc}^i \rvert, & \text{otherwise.} \end{cases} \] When the actual outcome is positive, counterfactuals that reproduce it are rewarded and those that reverse it are penalized; when it is negative, every sample receives the non-negative reward \(\lvert a_{Oc}^i \rvert\), whatever the sign of its outcome.
  • We do a backward pass over the layers in reverse order, distributing the reward among the parent neurons in proportion to their contributions \(z_{ij}\). Each weight is updated by \[ \delta_{w_{ij}}^\text{lfp} = \frac{\lvert w_{ij}\rvert\, a_i}{\lvert z_j^+\rvert + \lvert z_j^-\rvert + \varepsilon} \times \operatorname{sign}(z_j) \times r_j, \] where \(z_j^+\) and \(z_j^-\) are the sums of the positive and negative contributions to neuron \(j\) (the implementation below includes the bias in this split), or by the corresponding decomposition that handles positive and negative contributions more explicitly (the LFP\(_{z^+z^-}\) rule; see Equations 4.17 and 4.18 in the dissertation).
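
For a single neuron \(j\), the reward rule and the weight update can be written as two small functions. This is a minimal sketch that mirrors the per-neuron arithmetic of compute_weight_updates() (bias included in the positive/negative split, learning rate left out); the helper names lfp_reward() and lfp_delta_w() and the example numbers are illustrative, not taken from the dissertation.

```r
# Reward for the outcome neuron in one counterfactual sample.
lfp_reward <- function(a_aw, a_i) {
  if (a_aw > 0) a_aw * a_i else abs(a_i)
}

# LFP weight increments for all incoming edges of one neuron j.
#   w    : incoming weights w_ij
#   a    : presynaptic activations a_i
#   bias : bias of neuron j (counted in the z+/z- split)
#   r_j  : reward arriving at neuron j
lfp_delta_w <- function(w, a, bias, r_j, eps = 1e-9) {
  z_ij  <- c(w * a, bias)
  z_pos <- sum(z_ij[z_ij >= 0])
  z_neg <- sum(z_ij[z_ij < 0])
  z_j   <- sum(z_ij)                        # net input, bias included
  denom <- abs(z_pos) + abs(z_neg) + eps
  (abs(w) * a) / denom * sign(z_j) * r_j
}

r  <- lfp_reward(a_aw = 0.9, a_i = 0.8)     # 0.9 * 0.8
dw <- lfp_delta_w(w = c(1, -0.5), a = c(1, 1), bias = -0.2, r_j = r)
dw
```

In this example \(z_j = 0.3 > 0\) and the denominator is \(1 + 0.7 = 1.7\), so both increments are positive, including the one for the negative weight: the rule scales by \(\lvert w_{ij}\rvert\, a_i\) and takes its sign from \(z_j\) and \(r_j\).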

Below is the annotated R code for these operations.


##########################################################
# compute_weight_updates():
#   Applies LFP-style updates to the weights of a neural network
#   based on a series of sampled states (counterfactuals).
#
#   1) We run a forward pass w/ "input_values" to get a_Oc^AW.
#   2) For each sample from 'samples', we compute an outcome reward,
#      distribute it backward through each layer, and accumulate
#      incremental weight & bias updates.
#
# Arguments:
#   samples            : data.frame of counterfactual states (rows).
#   edges              : list() of edges, each w/ (from, to, weight).
#   nodes              : structure containing nodes$input, nodes$hidden, nodes$output, etc.
#   biases             : named list of neuron biases (ex: biases[["N2"]]).
#   learning_rate      : scalar "eta" used in the final update step.
#   weight_update_rule : "Additive" or "Multiplicative" for applying updates.
#   input_values       : the 'real world' input, used for the forward pass to get a_Oc^AW.
#
# Returns a list with:
#   $weight_updates  : named list of net weight increments (or exponents),
#   $bias_updates    : named list of net bias increments (or exponents).
#
##########################################################
compute_weight_updates <- function(
  samples, edges, nodes, biases,
  learning_rate = 0.1,
  weight_update_rule = 'Additive',
  input_values
) {
  ############################################################
  # 1) Determine the actual outcome activation a_Oc_AW
  #    by running a forward pass with the real-world input_values.
  ############################################################
  {
    # We'll store activations_aw as a list of node->activation
    activations_aw <- list()

    # Set input activations from 'input_values'
    for (node in nodes$input) {
      activations_aw[[node]] <- input_values[[node]]
    }

    # We'll apply tanh() to hidden+output sequentially
    activation_function <- function(x) tanh(x)
    all_nodes_aw <- c(nodes$hidden, nodes$output)
    net_inputs_aw <- list()

    for (node in all_nodes_aw) {
      # gather incoming edges
      incoming_edges <- Filter(function(e) e$to == node, edges)
      net_in <- biases[[node]]
      for (e in incoming_edges) {
        from_node <- e$from
        w <- e$weight
        a_from <- activations_aw[[from_node]]
        net_in <- net_in + w * a_from
      }
      net_inputs_aw[[node]] <- net_in
      activations_aw[[node]] <- activation_function(net_in)
    }

    # We assume a single output node for the outcome
    output_node <- nodes$output[[1]]
    a_Oc_AW <- activations_aw[[output_node]]
  }

  ############################################################
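  # 2) Initialize cumulative weight/bias updates.
  #    'Additive' sums increments, so it starts at 0; 'Multiplicative'
  #    multiplies exp(delta) factors, so it must start at 1
  #    (a product that starts at 0 stays 0).
  ############################################################
  init_value <- if (weight_update_rule == 'Multiplicative') 1 else 0
  cumulative_weight_updates <- list()
  cumulative_bias_updates   <- list()
  for (edge in edges) {
    key <- paste(edge$from, edge$to, sep = "->")
    cumulative_weight_updates[[key]] <- init_value
  }
  for (node in names(biases)) {
    cumulative_bias_updates[[node]] <- init_value
  }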

  num_samples <- nrow(samples)

  ############################################################
  # 3) For each sample in 'samples', do:
  #    - Forward pass => get net inputs & outputs
  #    - Compute a reward, r_{Oc}^i, as per eq. in the thesis
  #    - Backprop that reward using LFP logic
  #    - Accumulate the incremental delta weights
  ############################################################
  for (i in seq_len(num_samples)) {
    # Convert row i into a named list
    state_vector <- samples[i, ]
    state <- as.list(state_vector)
    names(state) <- colnames(samples)

    # Compute forward pass for this sample
    forward_result <- compute_outputs(state, edges, biases, nodes)
    outputs    <- forward_result$outputs
    net_inputs <- forward_result$net_inputs

    ########################################################
    # 3a) Compute the reward for the outcome node(s)
    #     eq: r_{Oc}^i = a_Oc^{AW} * a_Oc^i ( if a_Oc_AW > 0 )
    #         = abs(a_Oc^i)               ( if a_Oc_AW < 0 )
    ########################################################
    rewards <- list()
    output_nodes <- nodes$output
    for (out_node in output_nodes) {
      current_outcome_activation <- outputs[[out_node]]

      # If the real outcome is positive, reward is a_Oc^{AW} * a_Oc^i
      # If negative, we use abs(a_Oc^i) => flipping sign
      if (a_Oc_AW > 0) {
        r_out <- a_Oc_AW * current_outcome_activation
      } else {
        r_out <- abs(current_outcome_activation)
      }
      rewards[[out_node]] <- r_out
    }

    ############################################################
    # 3b) Backward pass of reward via LFP
    #     We process layers from output back to hidden, skipping input layer.
    ############################################################
    layers_rev <- rev(nodes$layers)

    for (layer in layers_rev) {
      # skip if it's the input layer
      if (all(layer %in% nodes$input)) {
        next
      }

      for (node_j in layer) {
        r_j <- rewards[[node_j]]
        if (is.null(r_j)) next

        # net input: z_j
        z_j <- net_inputs[[node_j]]
        z_j_sign <- sign(z_j)

        incoming_edges <- edges[sapply(edges, function(e) e$to == node_j)]
        z_j_pos <- 0
        z_j_neg <- 0
        z_ij_list <- list()

        # Decompose z_j into positive & negative contributions
        for (e in incoming_edges) {
          from_node <- e$from
          w_ij <- e$weight
          a_i  <- outputs[[from_node]]
          z_ij <- a_i * w_ij
          z_ij_list[[from_node]] <- z_ij

          if (z_ij >= 0) {
            z_j_pos <- z_j_pos + z_ij
          } else {
            z_j_neg <- z_j_neg + z_ij
          }
        }

        # Also incorporate bias in that decomposition
        z_ij_bias <- biases[[node_j]]
        z_ij_list[["bias"]] <- z_ij_bias
        if (z_ij_bias >= 0) {
          z_j_pos <- z_j_pos + z_ij_bias
        } else {
          z_j_neg <- z_j_neg + z_ij_bias
        }

        z_j_pos_abs <- abs(z_j_pos)
        z_j_neg_abs <- abs(z_j_neg)
        z_j_total_abs <- z_j_pos_abs + z_j_neg_abs + 1e-9  # small epsilon to avoid /0

        # Propagate reward backward
        for (e in incoming_edges) {
          from_node <- e$from
          w_ij <- e$weight
          a_i  <- outputs[[from_node]]
          z_ij <- z_ij_list[[from_node]]

          # fraction1 & fraction2 separate pos vs. neg side
          if (z_ij >= 0) {
            fraction1 <- z_j_pos_abs / z_j_total_abs
            fraction2 <- z_ij / (z_j_pos + 1e-9)
          } else {
            fraction1 <- z_j_neg_abs / z_j_total_abs
            fraction2 <- z_ij / (z_j_neg - 1e-9)  # z_j_neg < 0, so the epsilon is subtracted
          }

          # r_ij = fraction1 * fraction2 * sign(z_j) * r_j
          r_ij <- fraction1 * fraction2 * z_j_sign * r_j

          # Add partial reward to the from_node
          if (is.null(rewards[[from_node]])) {
            rewards[[from_node]] <- r_ij
          } else {
            rewards[[from_node]] <- rewards[[from_node]] + r_ij
          }

          # Weight update eq:
          # delta_w_ij = (|w_ij| * a_i / z_j_total_abs) * sign(z_j) * r_j
          delta_w_ij <- (abs(w_ij) * a_i) / z_j_total_abs * z_j_sign * r_j
          delta_w_ij <- learning_rate * delta_w_ij

          key <- paste(from_node, node_j, sep="->")
          if (weight_update_rule == 'Additive') {
            cumulative_weight_updates[[key]] <- cumulative_weight_updates[[key]] + delta_w_ij
          } else if (weight_update_rule == 'Multiplicative') {
            cumulative_weight_updates[[key]] <- cumulative_weight_updates[[key]] * exp(delta_w_ij)
          } else {
            stop("Invalid weight_update_rule. Use 'Additive' or 'Multiplicative'.")
          }
        }

        # Bias update (treated as from a 'bias node' with a=1)
        delta_b_j <- (abs(biases[[node_j]]) * 1) / z_j_total_abs * z_j_sign * r_j
        delta_b_j  <- learning_rate * delta_b_j

        if (weight_update_rule == 'Additive') {
          cumulative_bias_updates[[node_j]] <- cumulative_bias_updates[[node_j]] + delta_b_j
        } else if (weight_update_rule == 'Multiplicative') {
          cumulative_bias_updates[[node_j]] <- cumulative_bias_updates[[node_j]] * exp(delta_b_j)
        } else {
          stop("Invalid weight_update_rule. Use 'Additive' or 'Multiplicative'.")
        }
      } # end for node_j
    } # end for layer
  } # end for i in samples

  # Return the net updates
  list(
    weight_updates = cumulative_weight_updates,
    bias_updates   = cumulative_bias_updates
  )
}


##########################################################
# tanh activation_function
##########################################################
activation_function <- function(x) {
  tanh(x)
}

##########################################################
# compute_outputs():
#   Given a partial 'state' for input nodes and an entire net of edges/biases,
#   do a forward pass across each layer, returning net inputs & outputs.
##########################################################
compute_outputs <- function(state, edges, biases, nodes) {
  net_inputs <- list()
  outputs    <- state  # initialize with the known input states

  layers <- nodes$layers
  for (layer in layers) {
    for (node in layer) {
      # If it's an input node, skip (already set)
      if (node %in% nodes$input) {
        next
      }
      # sum up w_ij*a_i + bias
      incoming_edges <- Filter(function(edge) edge$to == node, edges)
      net_in <- biases[[node]]
      for (e in incoming_edges) {
        from_node <- e$from
        weight    <- e$weight
        act_from  <- outputs[[from_node]]
        net_in    <- net_in + weight * act_from
      }

      net_inputs[[node]] <- net_in
      # e.g. tanh
      outputs[[node]] <- activation_function(net_in)
    }
  }
  list(net_inputs = net_inputs, outputs = outputs)
}
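
compute_outputs() walks the edge list node by node. For a network with a single hidden layer, the same tanh forward pass can be written with matrices. The sketch below does this for the Boys and Girls rule using hypothetical parameters (weight +1 on every edge, biases chosen so that each hidden neuron acts as an AND over ±1 inputs and the output as an OR); these are not necessarily the parameters produced by generate_network_from_logic_program(), and forward_pass() is an illustrative helper.

```r
# Inputs in the order J, B, D, M, S, C (+1 = attends, -1 = absent).
# Hypothetical weights: +1 on every edge, 0 where there is no edge.
W_hidden <- rbind(
  boys  = c(1, 1, 1, 0, 0, 0),
  girls = c(0, 0, 0, 1, 1, 1)
)
b_hidden <- c(boys = -2, girls = -2)  # AND over n = 3 inputs: bias -(n - 1)
w_out    <- c(1, 1)
b_out    <- 1                         # OR over n = 2 inputs: bias n - 1

forward_pass <- function(a_in) {
  z_h <- as.vector(W_hidden %*% a_in) + b_hidden  # net inputs of hidden neurons
  a_h <- tanh(z_h)
  z_o <- sum(w_out * a_h) + b_out                 # net input of the outcome neuron
  list(hidden = a_h, output = tanh(z_o))
}

everyone  <- c( 1, 1, 1, 1,  1, 1)
no_j      <- c(-1, 1, 1, 1,  1, 1)
no_j_no_s <- c(-1, 1, 1, 1, -1, 1)
sapply(list(everyone, no_j, no_j_no_s), function(x) sign(forward_pass(x)$output))
```

The signs come out as 1, 1, -1: the party is still great without J (the girls suffice), but not once S is also absent.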

1.5 LRP-Based Causal Scores on the Updated Network

The final step in our procedure is to compute causal impact scores after the network has been updated via Layer-wise Feedback Propagation. As discussed in the thesis:

  1. We compute a relevance score for each input neuron by applying Layer-Wise Relevance Propagation (LRP) to the updated network \(\mathcal{N}^{(\mathrm{up})}\). This yields a distribution of “importance” or “relevance” across the input variables.

  2. To turn these “importance scores” into causal impact scores, we also factor in the structural complexity of how each cause set \(C\) triggers the outcome node \(O\). Concretely:

    • We identify \(\#_{\mathrm{ED}}(C,O)\), the maximum number of edge-disjoint positive paths from the cause set \(C\) to the outcome \(O\).
    • We track \(\mathrm{Idle}(C,O)\), the subset of \(C\) that never participates in any positive path.
    • The final “causal impact” \(\kappa(C,O)\) is a ratio:
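      \[ \kappa(C,O) = \frac{\sum_{c \in C} R_c}{\#_{\mathrm{ED}}(C,O) + \lvert \mathrm{Idle}(C,O)\rvert}\,, \] for a positive outcome. For a negative outcome, the implementation below instead uses the denominator \(1 + P(C)\), where \(P(C)\) counts the causes \(c \in C\) whose outgoing contributions \(w_{ch}\,a_c\) are all positive.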
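
The positive-outcome ratio can be checked by hand on a small case. This sketch assumes the “any 2 of 3” network from the Experiment 1 example with all three balls drawn, one hidden neuron per clause, and edges only from each clause's body. The contribution tables and the relevance scores `R` are hypothetical numbers, and kappa_positive() is an illustrative helper that follows the counting logic of compute_causal_impact_scores() for a single hidden layer.

```r
# Hypothetical contributions for "any 2 of 3" with A = B = C = 1.
# in_contrib[c, h] = w_ch * a_c (0 = no edge); out_contrib[h] = w_hO * a_h.
in_contrib <- matrix(
  c(1, 1, 0,   # A feeds h_AB and h_AC
    1, 0, 1,   # B feeds h_AB and h_BC
    0, 1, 1),  # C feeds h_AC and h_BC
  nrow = 3, byrow = TRUE,
  dimnames = list(c("A", "B", "C"), c("h_AB", "h_AC", "h_BC"))
)
out_contrib <- c(h_AB = 1, h_AC = 1, h_BC = 1)

kappa_positive <- function(C, R, in_contrib, out_contrib) {
  # Hidden neurons that pass a positive contribution to the outcome
  contributing <- names(out_contrib)[out_contrib > 0]
  pos <- in_contrib[C, contributing, drop = FALSE] > 0
  n_ed   <- sum(colSums(pos) > 0)   # positive paths C -> h -> O
  n_idle <- sum(rowSums(pos) == 0)  # causes in C on no positive path
  z <- n_ed + n_idle
  kappa <- if (z > 0) sum(R[C]) / z else 0
  list(n_ed = n_ed, n_idle = n_idle, kappa = kappa)
}

R <- c(A = 0.5, B = 0.3, C = 0.2)   # hypothetical LRP relevance scores
kappa_positive("A", R, in_contrib, out_contrib)$kappa
kappa_positive(c("A", "B"), R, in_contrib, out_contrib)$kappa
```

Here the single cause A lies on two positive paths (through h_AB and h_AC), so its relevance is divided by 2, giving 0.25; the pair {A, B} reaches all three hidden neurons, so its summed relevance of 0.8 is divided by 3.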

Below is the code that implements the LRP-based measure (compute_lrp_importance()) and the subsequent computation of causal impact scores (compute_causal_impact_scores()). We also provide two small wrapper functions:

  • perform_sampling(): runs either Metropolis-Hastings or Gibbs sampling to produce a series of states.
  • run_simulation(): orchestrates all steps (sampling, weight update, final LRP, and causal score output).

# =====================================================================
# compute_lrp_importance():
#   Applies a Layer-Wise Relevance Propagation approach to a given 
#   neural network state (weights, biases, final updated edges).
#
#   1) We do a forward pass with the actual-world input_values 
#      to get a_Oc^AW (the outcome node's activation).
#   2) We do a second forward pass but store net_inputs for 
#      each layer to guide the backward pass.
#   3) We initialize the outcome node with R=1, 
#      then propagate relevance backward via an alpha-beta rule 
#      (see the alpha-beta rule in the LRP chapter of the dissertation).
#
#   The final result is a named list R[...] of 
#   "importance" or "relevance" assigned to each input node 
#   (and hidden nodes, though typically we focus on input).
# =====================================================================
compute_lrp_importance <- function(
  input_values, 
  weights, 
  biases, 
  nodes, 
  edges,
  lrp_variant = 'alpha_beta', 
  alpha = 1, 
  beta = 0
) {
  ############################################################
  # 1) Forward pass to find a_Oc^AW for reference
  ############################################################
  {
    # For the real-world scenario
    activations_aw <- list()
    for (node in nodes$input) {
      activations_aw[[node]] <- input_values[[node]]
    }
    activation_function <- function(x) tanh(x)
    all_nodes_aw <- c(nodes$hidden, nodes$output)
    net_inputs_aw <- list()
    
    for (node in all_nodes_aw) {
      incoming_edges_aw <- Filter(function(e) e$to == node, edges)
      net_input_aw <- biases[[node]]
      for (e in incoming_edges_aw) {
        from_node <- e$from
        w <- weights[[paste(from_node, node, sep="->")]]
        a_from <- activations_aw[[from_node]]
        net_input_aw <- net_input_aw + w * a_from
      }
      net_inputs_aw[[node]] <- net_input_aw
      activations_aw[[node]] <- activation_function(net_input_aw)
    }
    # outcome node
    output_node <- nodes$output[[1]]
    a_Oc_AW <- activations_aw[[output_node]]
  }
  
  ############################################################
  # 2) Set alpha and beta from the sign of the actual outcome.
  #    Note: this always overrides the alpha/beta arguments.
  ############################################################
  epsilon <- 1e-6
  if (a_Oc_AW > 0) {
    alpha <- 1 + epsilon
    beta  <- 0 + epsilon
  } else {
    alpha <- 0 + epsilon
    beta  <- -1 + epsilon
  }
  
  ############################################################
  # 3) Full forward pass for the LRP process itself
  ############################################################
  activations <- list()
  net_inputs  <- list()
  for (node in nodes$input) {
    activations[[node]] <- input_values[[node]]
  }
  
  activation_function <- function(x) tanh(x)
  all_nodes <- c(nodes$hidden, nodes$output)
  
  for (node in all_nodes) {
    incoming_edges_node <- Filter(function(e) e$to == node, edges)
    net_in <- biases[[node]]
    for (e in incoming_edges_node) {
      from_node <- e$from
      w <- weights[[paste(from_node, node, sep="->")]]
      a_from <- activations[[from_node]]
      net_in <- net_in + w * a_from
    }
    net_inputs[[node]] <- net_in
    activations[[node]] <- activation_function(net_in)
  }
  
  ############################################################
  # 4) Initialize relevance at the outcome node
  #    We set R[outcome_node] = 1
  ############################################################
  R <- list()
  output_node <- nodes$output[[1]]
  R[[output_node]] <- 1
  
  ############################################################
  # 5) Backward pass: alpha-beta LRP or other variants
  ############################################################
  all_nodes_rev <- rev(all_nodes)
  for (node in all_nodes_rev) {
    relevance <- R[[node]]
    if (is.null(relevance)) next
    
    incoming_edges_node <- Filter(function(e) e$to == node, edges)
    if (length(incoming_edges_node) == 0) {
      next  # skip input nodes
    }
    
    # gather activations + weights
    a_i <- sapply(incoming_edges_node, function(e) activations[[e$from]])
    w_ij <- sapply(incoming_edges_node, function(e) {
      weights[[paste(e$from, e$to, sep="->")]]
    })
    z_ij <- a_i * w_ij
    
    # handle alpha-beta or other variants
    if (lrp_variant == 'alpha_beta') {
      z_ij_pos <- pmax(0, z_ij)
      z_ij_neg <- pmin(0, z_ij)
      denom_pos <- sum(z_ij_pos) + 1e-9
      denom_neg <- sum(z_ij_neg) - 1e-9
      
      relevance_pos <- alpha * (z_ij_pos / denom_pos) * relevance
      relevance_neg <- beta  * (z_ij_neg / denom_neg) * relevance
      relevance_i   <- relevance_pos - relevance_neg
    } else if (lrp_variant == 'basic') {
      denominator  <- sum(z_ij) + 1e-9
      relevance_i  <- (z_ij / denominator) * relevance
    } else if (lrp_variant == 'softmax') {
      z_ij_exp    <- exp(z_ij / 1.0)
      denom       <- sum(z_ij_exp) + 1e-9
      relevance_i <- (z_ij_exp / denom) * relevance
    } else {
      stop("Invalid LRP variant. Choose 'alpha_beta', 'basic', or 'softmax'.")
    }
    
    # Normalize so sum(relevance_i) = relevance
    sum_rel_i <- sum(relevance_i)
    if (sum_rel_i > 1e-9) {
      relevance_i <- relevance_i / sum_rel_i * relevance
    }
    
    # Assign to each predecessor
    for (idx in seq_along(incoming_edges_node)) {
      from_node <- incoming_edges_node[[idx]]$from
      if (is.null(R[[from_node]])) {
        R[[from_node]] <- 0
      }
      R[[from_node]] <- R[[from_node]] + relevance_i[idx]
    }
  }
  
  # Return final relevance distribution
  return(R)
}
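
The core of step 5 is the per-neuron alpha-beta split. The helper below, alpha_beta_step(), is an illustrative reimplementation of that step: it omits the final renormalization and uses exact alpha and beta values rather than the epsilon-shifted ones, and the contributions in `z` are made-up numbers. It shows how the sign-dependent choice of alpha and beta in step 2 moves relevance from positive to negative contributors.

```r
# Per-neuron alpha-beta redistribution, following step 5 of compute_lrp_importance()
# (bias excluded from the denominators; no final renormalization).
alpha_beta_step <- function(z_ij, relevance, alpha, beta, eps = 1e-9) {
  z_pos <- pmax(0, z_ij)
  z_neg <- pmin(0, z_ij)
  rel_pos <- alpha * (z_pos / (sum(z_pos) + eps)) * relevance
  rel_neg <- beta  * (z_neg / (sum(z_neg) - eps)) * relevance
  rel_pos - rel_neg
}

z <- c(0.8, 0.4, -0.6)   # hypothetical contributions a_i * w_ij into one neuron

# Positive actual outcome: relevance follows the positive contributions
alpha_beta_step(z, relevance = 1, alpha = 1, beta = 0)
# Negative actual outcome: relevance follows the negative contributions
alpha_beta_step(z, relevance = 1, alpha = 0, beta = -1)
```

With \(\alpha = 1, \beta = 0\) the negative contributor receives no relevance and the other two share it 2:1; with \(\alpha = 0, \beta = -1\) the negative contributor receives all of it.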


# =====================================================================
# compute_causal_impact_scores():
#   Uses the final importance_scores from LRP + the updated network
#   to compute kappa(C,O) for each cause set C.
#
#   (1) Identify # of edge-disjoint positive paths from C->O, if outcome>0
#       or a penalty if outcome<0.
#   (2) Identify idle causes in C that never appear on a positive path.
#   (3) Combine with sum(importance_scores[C]) in a ratio.
# =====================================================================
compute_causal_impact_scores <- function(
  importance_scores,
  network,
  input_values,
  outcome_node = NULL,
  cause_sets   = NULL,
  w            = 1,
  negation_handling = TRUE
) {
  ############################################################
  # PART 1: Setup & default cause_sets
  ############################################################
  if (is.null(outcome_node)) {
    outcome_node <- network$nodes$output[[1]]
  }
  if (is.null(cause_sets)) {
    # generate combos of up to size 3 from the input nodes
    input_nodes <- network$nodes$input
    cause_sets  <- list()
    for (size in 1:3) {
      if (size <= length(input_nodes)) {
        combos <- combn(input_nodes, size, simplify=FALSE)
        cause_sets <- c(cause_sets, combos)
      }
    }
  }
  
  ############################################################
  # PART 2: Forward pass to see outcome in the real world
  ############################################################
  activation_function <- function(x) tanh(x)
  activations <- list()
  for (node in network$nodes$input) {
    activations[[node]] <- input_values[[node]]
  }
  
  edges  <- network$edges
  biases <- network$biases
  all_nodes <- c(network$nodes$hidden, network$nodes$output)
  net_inputs <- list()
  
  for (nd in all_nodes) {
    incoming_e <- Filter(function(e) e$to == nd, edges)
    net_in     <- biases[[nd]]
    for (e in incoming_e) {
      from_node <- e$from
      w_ij      <- e$weight
      a_from    <- activations[[from_node]]
      net_in    <- net_in + w_ij*a_from
    }
    net_inputs[[nd]]    <- net_in
    activations[[nd]]   <- activation_function(net_in)
  }
  # real outcome
  a_O_AW <- activations[[outcome_node]]
  
  ############################################################
  # PART 3: Helpers for positive path logic
  ############################################################
  # We'll define is_hidden_contributing(), is_hidden_supported()
  # then count_edge_disjoint_paths() & count_idle_causes().
  
  is_hidden_contributing <- function(h) {
    # does h->outcome have w*h>0?
    e_hO <- Filter(function(e) e$from==h && e$to==outcome_node, edges)
    if (length(e_hO)!=1) return(FALSE)
    w_hO <- e_hO[[1]]$weight
    a_h  <- activations[[h]]
    return(w_hO*a_h>0)
  }
  
  is_hidden_supported <- function(h, C) {
    # does h get a positive push from any c in C?
    inedges_h <- Filter(function(e) e$to==h, edges)
    for (ed in inedges_h) {
      c_node <- ed$from
      if (c_node %in% C) {
        w_ch <- ed$weight
        a_c  <- activations[[c_node]]
        if (w_ch*a_c>0) {
          return(TRUE)
        }
      }
    }
    return(FALSE)
  }
  
  count_edge_disjoint_paths <- function(C) {
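    # Number of hidden neurons h that receive a positive contribution
    # (w_ch * a_c > 0) from some c in C and pass a positive contribution
    # to the outcome. In a single-hidden-layer network, each such h
    # yields one edge-disjoint positive path C -> h -> O.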
    path_count <- 0
    for (h in network$nodes$hidden) {
      if (is_hidden_contributing(h) && is_hidden_supported(h,C)) {
        path_count <- path_count+1
      }
    }
    path_count
  }
  
  count_idle_causes <- function(C) {
    # c is idle if it doesn't appear on any positive path
    idle_ct <- 0
    for (c in C) {
      used_by_any_hidden <- FALSE
      for (h in network$nodes$hidden) {
        if (!is_hidden_contributing(h)) next
        # is h supported by c?
        w_edge <- Filter(function(e) e$from==c && e$to==h, edges)
        if (length(w_edge)==1) {
          w_ch <- w_edge[[1]]$weight
          a_c  <- activations[[c]]
          if (w_ch*a_c>0) {
            used_by_any_hidden<-TRUE
            break
          }
        }
      }
      if (!used_by_any_hidden) {
        idle_ct <- idle_ct+1
      }
    }
    idle_ct
  }
  
  ############################################################
  # PART 3bis: NEGATIVE outcome logic (approx penalty)
  ############################################################
  # If outcome < 0, we use a simpler denominator:
  #   Z = 1 + penalty_count, where penalty_count is the number of causes
  #   c in C whose outgoing contributions w * a_c all point one way:
  #   all negative for a positive outcome, all positive for a negative one.
  
  get_outgoing_weights <- function(c) {
    out_e <- Filter(function(e) e$from==c, edges)
    sapply(out_e, function(x) x$weight)
  }
  all_positive <- function(ws) { length(ws)>0 && all(ws>0) }
  all_negative <- function(ws) { length(ws)>0 && all(ws<0) }
  
  compute_penalty_for_cause <- function(c, a_c, positive_outcome) {
    w_list <- get_outgoing_weights(c)
    if (length(w_list)==0) return(0)
    
    if (positive_outcome) {
      # if c=+1 but all weights<0 => penalty
      # if c=-1 but all weights>0 => penalty
      if (a_c==1) {
        if (all_negative(w_list)) return(1) else return(0)
      } else {
        if (all_positive(w_list)) return(1) else return(0)
      }
    } else {
      # negative outcome => opposite pattern
      if (a_c==1) {
        if (all_positive(w_list)) return(1) else return(0)
      } else {
        if (all_negative(w_list)) return(1) else return(0)
      }
    }
  }
  
  compute_penalty_count <- function(C) {
    penalty_ct <- 0
    positive_outcome <- (a_O_AW>0)
    for (c in C) {
      a_c_init <- input_values[[c]]
      penalty_ct <- penalty_ct + compute_penalty_for_cause(c, a_c_init, positive_outcome)
    }
    penalty_ct
  }
  
  ############################################################
  # PART 4: Loop over cause_sets, compute sum_Rc, then define
  #         Z = (# edge-disjoint paths) + (# idle causes) if the outcome is
  #         positive, or Z = 1 + penalty otherwise; then kappa = sum_Rc / (Z*w)
  ############################################################
  results <- data.frame(C=I(character(0)), 
                        Z=numeric(0), 
                        sum_Rc=numeric(0), 
                        kappa=numeric(0),
                        stringsAsFactors=FALSE)
  
  for (C in cause_sets) {
    sum_Rc <- sum(unlist(importance_scores[C]), na.rm=TRUE)
    idle_count <- count_idle_causes(C)
    
    if (a_O_AW>0) {
      num_paths <- count_edge_disjoint_paths(C)
      Z_val     <- num_paths + idle_count
    } else {
      penalty_count <- compute_penalty_count(C)
      Z_val         <- 1 + penalty_count
    }
    
    kappa_val <- 0
    if (Z_val>0) {
      kappa_val <- sum_Rc/(Z_val*w)
    }
    
    results <- rbind(results, data.frame(
      C=paste(C, collapse="+"),
      Z=Z_val, 
      sum_Rc=sum_Rc, 
      kappa=kappa_val,
      stringsAsFactors=FALSE
    ))
  }
  
  results
}


# =====================================================================
# perform_sampling():
#   A small helper to unify the call for 'MH' vs. 'Gibbs' 
#   samplers, returning the sample matrix + last state.
# =====================================================================
perform_sampling <- function(
  nodes, layers, edges, biases, 
  initial_state, n_samples, sampling_method='MH'
) {
  if (sampling_method=='MH') {
    result <- mh_layered_sampler(
      iterations   = n_samples,
      nodes        = nodes,
      layers       = layers,
      edges        = edges,
      biases       = biases,
      initial_state= initial_state,
      temperature  = 1
    )
    samples_df <- result$states
  } else if (sampling_method=='Gibbs') {
    result <- layer_gibbs_sampler(
      iterations   = n_samples,
      nodes        = nodes,
      layers       = layers,
      edges        = edges,
      biases       = biases,
      initial_state= initial_state
    )
    samples_df <- result$states
  } else {
    stop("Unsupported sampling method: ", sampling_method)
  }
  
  # last row => last sample
  last_sample <- samples_df[nrow(samples_df),]
  last_sample_list <- as.list(last_sample)
  names(last_sample_list) <- names(last_sample)
  
  list(
    samples    = samples_df,
    last_sample= last_sample_list
  )
}


# =====================================================================
# run_simulation():
#   A "master" function that orchestrates:
#    1) repeated sampling of states,
#    2) LFP-style weight updates (either 'AsWeGo' or 'Cumulative'),
#    3) final LRP to get importance_scores,
#    4) computing causal impact scores,
#    5) returning the updated network, plus optional all-samples.
# =====================================================================
run_simulation <- function(
  network, initial_state, input_values,
  num_simulations       = 100,
  n_samples             = 20,
  sampling_method       = 'MH',
  weight_update_timing  = 'Cumulative',
  learning_rate         = 1,
  weight_update_rule    = "Additive",
  lrp_variant           = 'alpha_beta',
  alpha                 = 1,
  beta                  = 0,
  return_samples        = FALSE,
  cause_sets            = NULL,
  negation_handling     = TRUE,
  w                     = 1
) {
  # cat("Starting run_simulation...\n")
  
  original_network <- network
  
  # Extract node lists
  nodes <- list(
    input  = network$network$input_neurons,
    hidden = network$network$hidden_neurons,
    output = network$network$output_neurons,
    layers = network$layers
  )
  

  
  edges_list  <- original_network$edges
  biases_list <- original_network$biases
  
  # Initialize accumulators for 'Cumulative' approach
  total_weight_updates <- list()
  total_bias_updates   <- list()
  for (edge in edges_list) {
    key <- paste(edge$from, edge$to, sep="->")
    total_weight_updates[[key]] <- 0
  }
  for (nd in names(biases_list)) {
    total_bias_updates[[nd]] <- 0
  }
  
  final_networks <- list()
  if (return_samples) {
    all_samples <- list()
  }
  

  for (sim in seq_len(num_simulations)) {
    
    # Progress reporting (commented out; uncomment to make the run verbose):
    # if (sim%%50==0) {
    #   cat("\nSimulation", sim, "of", num_simulations, "\n")
    # }
    
    # Start from original each time
    network <- original_network
    edges   <- network$edges
    biases  <- network$biases
    
    # We'll handle 'AsWeGo' or 'Cumulative' differently
    if (weight_update_timing=='AsWeGo') {
      # Weighted updates after each sample
      current_state <- initial_state
      for (sample_num in seq_len(n_samples)) {
        all_nodes <- unique(c(nodes$input, nodes$hidden, nodes$output))
        
        sampling_result <- perform_sampling(
          nodes         = all_nodes,
          layers        = nodes$layers,
          edges         = edges,
          biases        = biases,
          initial_state = current_state,
          n_samples     = 1,
          sampling_method = sampling_method
        )
        new_sample <- sampling_result$last_sample
        current_state <- new_sample
        
        if (return_samples) {
          all_samples[[length(all_samples)+1]] <- current_state
        }
        
        # LFP weight updates
        updates <- compute_weight_updates(
          samples            = sampling_result$samples,
          edges              = edges,
          nodes              = nodes,
          biases             = biases,
          learning_rate      = learning_rate,
          weight_update_rule = weight_update_rule,
          input_values       = input_values
        )
        
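        # Apply this sample's increments immediately (the 'cumulative_' names
        # below refer to this sample's updates only, not to a running total)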
        cumulative_weight_updates <- updates$weight_updates
        cumulative_bias_updates   <- updates$bias_updates
        
        for (edge_idx in seq_along(edges)) {
          key   <- paste(edges[[edge_idx]]$from, edges[[edge_idx]]$to, sep="->")
          d_w   <- cumulative_weight_updates[[key]]
          if (is.null(d_w)) d_w<-0
          if (weight_update_rule=='Additive') {
            edges[[edge_idx]]$weight <- edges[[edge_idx]]$weight + d_w
          } else {
            edges[[edge_idx]]$weight <- edges[[edge_idx]]$weight * exp(d_w)
          }
        }
        for (nd in names(biases)) {
          d_b <- cumulative_bias_updates[[nd]]
          if (is.null(d_b)) d_b<-0
          if (weight_update_rule=='Additive') {
            biases[[nd]] <- biases[[nd]]+d_b
          } else {
            biases[[nd]] <- biases[[nd]]*exp(d_b)
          }
        }
      } # end sample_num
      
      final_networks[[sim]] <- list(edges=edges,biases=biases)
      
    } else if (weight_update_timing=='Cumulative') {
      # We gather all updates first, then apply once
      current_state <- initial_state
      for (sample_num in seq_len(n_samples)) {
        all_nodes <- unique(c(nodes$input, nodes$hidden, nodes$output))
        
        sampling_result <- perform_sampling(
          nodes         = all_nodes,
          layers        = nodes$layers,
          edges         = edges,
          biases        = biases,
          initial_state = current_state,
          n_samples     = 1,
          sampling_method = sampling_method
        )
        new_sample <- sampling_result$last_sample
        current_state <- new_sample
        
        if (return_samples) {
          all_samples[[length(all_samples)+1]] <- current_state
        }
        
        updates <- compute_weight_updates(
          samples            = sampling_result$samples,
          edges              = edges,
          nodes              = nodes,
          biases             = biases,
          learning_rate      = learning_rate,
          weight_update_rule = weight_update_rule,
          input_values       = input_values
        )
        
        c_w_updates <- updates$weight_updates
        c_b_updates <- updates$bias_updates
        
        for (key in names(c_w_updates)) {
          dw <- c_w_updates[[key]]
          if (is.null(dw)) dw<-0
          total_weight_updates[[key]]<- total_weight_updates[[key]]+dw
        }
        for (nd in names(c_b_updates)) {
          db <- c_b_updates[[nd]]
          if (is.null(db)) db<-0
          total_bias_updates[[nd]] <- total_bias_updates[[nd]]+db
        }
      } # end sample_num
    } else {
      stop("Invalid weight_update_timing. Use 'AsWeGo' or 'Cumulative'.")
    }
  } # end for sim
  
  # After all simulations, finalize
  if (weight_update_timing=='AsWeGo') {
    
    # We average the final networks over num_simulations
    averaged_edges  <- original_network$edges
    averaged_biases <- original_network$biases
    
    weight_sums <- list()
    bias_sums   <- list()
    for (ed in averaged_edges) {
      k <- paste(ed$from, ed$to, sep="->")
      weight_sums[[k]] <- 0
    }
    for (nd in names(averaged_biases)) {
      bias_sums[[nd]] <- 0
    }
    
    for (sim in seq_len(num_simulations)) {
      sim_edges <- final_networks[[sim]]$edges
      sim_biases<- final_networks[[sim]]$biases
      for (ed in sim_edges) {
        k <- paste(ed$from, ed$to, sep="->")
        weight_sums[[k]]<- weight_sums[[k]]+ed$weight
      }
      for (nd in names(sim_biases)) {
        bias_sums[[nd]]<- bias_sums[[nd]]+sim_biases[[nd]]
      }
    }
    for (ed_idx in seq_along(averaged_edges)) {
      k <- paste(averaged_edges[[ed_idx]]$from, averaged_edges[[ed_idx]]$to, sep="->")
      av_w  <- weight_sums[[k]]/num_simulations
      averaged_edges[[ed_idx]]$weight <- av_w
    }
    for (nd in names(averaged_biases)) {
      averaged_biases[[nd]]<- bias_sums[[nd]]/num_simulations
    }
    
    averaged_weights<-list()
    for (ed in averaged_edges) {
      k<- paste(ed$from, ed$to, sep="->")
      averaged_weights[[k]]<- ed$weight
    }
    
    # compute LRP-based importance on the final averaged network
    importance_scores <- compute_lrp_importance(
      input_values = input_values,
      weights      = averaged_weights,
      biases       = averaged_biases,
      nodes        = nodes,
      edges        = averaged_edges,
      lrp_variant  = lrp_variant,
      alpha        = alpha,
      beta         = beta
    )
    updated_network <- list(
      nodes  = nodes,
      edges  = averaged_edges,
      biases = averaged_biases
    )
    
  } else {
    
    # average the total updates
    num_total_updates <- num_simulations*n_samples
    avg_w_updates <- list()
    for (k in names(total_weight_updates)) {
      avg_w_updates[[k]] <- total_weight_updates[[k]]/num_total_updates
    }
    avg_b_updates <- list()
    for (nd in names(total_bias_updates)) {
      avg_b_updates[[nd]] <- total_bias_updates[[nd]]/num_total_updates
    }
    
    edges <- original_network$edges
    biases<- original_network$biases
    for (ed_idx in seq_along(edges)) {
      k<- paste(edges[[ed_idx]]$from, edges[[ed_idx]]$to, sep="->")
      dw<- avg_w_updates[[k]]
      if (is.null(dw)) dw<-0
      if (weight_update_rule=='Additive') {
        edges[[ed_idx]]$weight<- edges[[ed_idx]]$weight+dw
      } else {
        edges[[ed_idx]]$weight<- edges[[ed_idx]]$weight*exp(dw)
      }
    }
    for (nd in names(biases)) {
      db<- avg_b_updates[[nd]]
      if (is.null(db)) db<-0
      if (weight_update_rule=='Additive') {
        biases[[nd]]<- biases[[nd]]+db
      } else {
        biases[[nd]]<- biases[[nd]]*exp(db)
      }
    }
    updated_weights<- list()
    for (ed in edges) {
      k<- paste(ed$from, ed$to, sep="->")
      updated_weights[[k]]<- ed$weight
    }
    
    importance_scores <- compute_lrp_importance(
      input_values = input_values,
      weights      = updated_weights,
      biases       = biases,
      nodes        = nodes,
      edges        = edges,
      lrp_variant  = lrp_variant,
      alpha        = alpha,
      beta         = beta
    )
    updated_network <- list(
      nodes  = nodes,
      edges  = edges,
      biases = biases
    )
  }
  
  if (return_samples) {
    all_samples_df <- do.call(rbind, 
      lapply(all_samples, function(x) as.data.frame(t(unlist(x))))
    )
    rownames(all_samples_df)<-NULL
  }
  result_list <- list(
    updated_network    = updated_network,
    importance_scores  = importance_scores
  )
  if (return_samples) {
    result_list$all_samples <- all_samples_df
  }
  
  
  
  # 4) Compute causal impact scores from importance_scores + updated_network
  
  impact_scores <- compute_causal_impact_scores(
    importance_scores  = result_list$importance_scores,
    network            = result_list$updated_network,
    input_values       = input_values,
    cause_sets         = cause_sets,
    w                  = w,
    negation_handling  = negation_handling
  )

  result_list$impact_scores <- impact_scores
  
  return(result_list)
}
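
The last stage of `compute_causal_impact_scores()` turns the summed relevance of a cause set into \(\kappa\) through a normaliser \(Z\). The stand-alone sketch below reproduces only that arithmetic, together with a compact restatement of the penalty rule in `compute_penalty_for_cause()`. The helper names `penalty_compact` and `kappa_from_parts` are hypothetical. The relevance sums and the path, idle and penalty counts are supplied by hand rather than computed from a network.

```r
# Hypothetical compact restatement of compute_penalty_for_cause():
# a cause is penalised when *all* its outgoing weights have the sign `target`,
# where target = -1 if "a_c == 1" agrees with "the outcome is positive",
# and target = +1 otherwise. Zero weights never trigger a penalty.
penalty_compact <- function(w_list, a_c, positive_outcome) {
  if (length(w_list) == 0) return(0)
  target <- if ((a_c == 1) == positive_outcome) -1 else 1
  as.numeric(all(sign(w_list) == target))
}

# Hypothetical helper mirroring the final step of compute_causal_impact_scores():
#   Z = n_paths + n_idle   if the outcome activation is positive,
#   Z = 1 + n_penalty      otherwise,
#   kappa = sum_Rc / (Z * w), and kappa = 0 when Z is not positive.
kappa_from_parts <- function(sum_Rc, positive_outcome,
                             n_paths = 0, n_idle = 0, n_penalty = 0, w = 1) {
  Z <- if (positive_outcome) n_paths + n_idle else 1 + n_penalty
  if (Z > 0) sum_Rc / (Z * w) else 0
}

# Made-up inputs, for illustration only
penalty_compact(c(-0.5, -1.2), a_c = 1, positive_outcome = TRUE)  # 1
penalty_compact(c(0.5, -1.2),  a_c = 1, positive_outcome = TRUE)  # 0
kappa_from_parts(0.6, TRUE,  n_paths = 1)              # 0.6
kappa_from_parts(0.9, TRUE,  n_paths = 1, n_idle = 1)  # 0.45
kappa_from_parts(0.4, FALSE, n_penalty = 1)            # 0.2
```

Writing the four branches of the original as one sign comparison makes the symmetry explicit. Flipping either the cause's value or the sign of the outcome flips which weight sign counts as a penalty.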

2 Applying the Model to Concrete Examples

We now apply our model to concrete examples, starting with two simple reference points: a conjunctive structure and a disjunctive structure. We show that the model captures two well-known, empirically observed patterns of causal selection, one in each structure:

  1. Abnormal Inflation in the conjunctive scenario (\(E \leftarrow A \land B\)), in which the causal score attributed to \(A\) grows when \(P(A)\) is low and \(P(B)\) is high.
  2. Abnormal Deflation in the disjunctive scenario (\(E \leftarrow A; E \leftarrow B\)), in which the causal score attributed to \(A\) increases when \(P(A)\) is high and \(P(B)\) is also high, and is conversely diminished (“deflated”) when \(P(A)\) is low and \(P(B)\) is low.

We also examine how these predictions vary across two different weight-update timing strategies: Cumulative vs. AsWeGo.

  • The Cumulative strategy applies all weight updates after sampling is complete.
  • The AsWeGo strategy applies updates during sampling, after each sample, so that each new sample is drawn from a network already modified by the updates from the preceding samples.

In both cases, we check for these patterns by looking at the causal importance that the model assigns to a focal variable \(A\) as we vary its probability \(P(A)\) and the probability \(P(B)\) of the alternate cause. For simplicity, we fix the number of samples per simulation to 50 (with 500 simulations per cell). We set the learning rate to \(1\) in the “Cumulative” version and to \(0.2\) in the “AsWeGo” version. The lower rate for AsWeGo compensates for its comparatively larger updates: it applies every sample's increment directly, whereas the Cumulative version averages the increments over all samples before applying them once. The results below show that the model reproduces both abnormal inflation and abnormal deflation. In the conjunctive structure, \(A\) acquires greater causal importance when \(P(A)\) is abnormally small (and \(P(B)\) large). In the disjunctive structure, \(A\) becomes more important when \(P(A)\) is large (and \(P(B)\) also large).
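
The design of this sweep can be written down as a single table before anything is run. The following sketch is not part of the original pipeline and uses only base R's `expand.grid()`. It builds the grid and checks its size. The call to `run_minimal_case()` is left as a comment so that the block runs on its own.

```r
# Sketch: materialise the (structure, timing, pA, pB) design as one data frame.
# expand.grid() varies its first argument fastest, so pB changes fastest and
# structure slowest, matching the nesting order of the loops in the grid search.
grid <- expand.grid(
  pB                   = seq(0.1, 0.9, by = 0.1),
  pA                   = seq(0.1, 0.9, by = 0.1),
  weight_update_timing = c("Cumulative", "AsWeGo"),
  structure            = c("conjunctive", "disjunctive"),
  stringsAsFactors     = FALSE
)
# The learning rate is tied to the timing strategy (0.2 for AsWeGo, 1 otherwise)
grid$learning_rate <- ifelse(grid$weight_update_timing == "AsWeGo", 0.2, 1)

nrow(grid)  # 2 structures x 2 timings x 9 pA values x 9 pB values = 324
table(grid$structure, grid$weight_update_timing)  # 81 cells per combination

# Each row could then be fed to run_minimal_case(), for example:
# rows <- lapply(seq_len(nrow(grid)), function(i) {
#   g <- grid[i, ]
#   run_minimal_case(g$structure, g$pA, g$pB, n_samples = 50,
#                    num_simulations = 500,
#                    weight_update_timing = g$weight_update_timing,
#                    learning_rate = g$learning_rate)
# })
```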


# ================================================================
# Grid search over probability combinations for abnormal inflation/deflation
# Uses 500 simulations per cell for stable results
# (2 structures x 2 timings x 9 pA x 9 pB = 324 cells)
# Results cached via knitr cache to avoid re-running
# ================================================================

####  HELPER FUNCTIONS ####
build_minimal_program <- function(structure=c("conjunctive","disjunctive")) {
  structure <- match.arg(structure)
  if (structure=="conjunctive") {
    # E <- A,B
    clauses <- list(
      list(head = "E", body = c("A","B"))
    )
  } else {
    # disjunctive: E <- A; E <- B
    clauses <- list(
      list(head = "E", body = c("A")),
      list(head = "E", body = c("B"))
    )
  }
  clauses
}

run_minimal_case <- function(
  structure        = "conjunctive",
  pA               = 0.5,
  pB               = 0.5,
  n_samples        = 50,
  num_simulations  = 300,
  weight_update_timing = "Cumulative",
  learning_rate    = 1
) {
  # Build the minimal program
  clauses <- build_minimal_program(structure)
  input_probabilities <- list(A=pA, B=pB)
  
  sampling_network <- generate_network_from_logic_program(
    clauses             = clauses,
    input_probabilities = input_probabilities,
    plot_network        = FALSE
  )
  
  # We define a "positive outcome" scenario: A=1, B=1 => E=1
  initial_state_io <- list(A=1, B=1, E=1)
  net_ent    <- sampling_network$network
  init_state <- generate_initial_hidden_states(initial_state_io, net_ent)
  
  # We run with fixed n_samples=50, num_sim=300 unless specified
  result <- run_simulation(
    network              = sampling_network,
    initial_state        = init_state,
    input_values         = initial_state_io,
    num_simulations      = num_simulations,
    n_samples            = n_samples,
    sampling_method      = 'MH',
    weight_update_timing = weight_update_timing,
    learning_rate        = learning_rate,
    weight_update_rule   = 'Additive',
    return_samples       = FALSE
  )
  
  impact_scores <- result$impact_scores
  
  # Extract kappa(A,E). We'll ignore B for now.
  rowA <- impact_scores[impact_scores$C=="A",]
  kA <- if (nrow(rowA)==1) rowA$kappa else NA_real_
  
  data.frame(
    structure            = structure,
    pA                   = pA,
    pB                   = pB,
    n_samples            = n_samples,
    num_simulations      = num_simulations,
    weight_update_timing = weight_update_timing,
    learning_rate        = learning_rate,
    kappaA               = kA
  )
}

#### STEP 1: Define the parameter grid ####
structures         <- c("conjunctive","disjunctive")
pA_values          <- seq(0.1, 0.9, by=0.1)
pB_values          <- seq(0.1, 0.9, by=0.1)
n_samples_fixed    <- 50
num_sim_fixed      <- 500
update_timing_vec  <- c("Cumulative","AsWeGo")
lr <- 1

#### STEP 2: Loop over all combinations ####

all_results <- list()
idx <- 1
for (strct in structures) {
  for (ut in update_timing_vec) {
    print(strct)
    print(ut)
    # AsWeGo applies every increment directly, so it gets a smaller learning rate
    lr <- if (ut == "AsWeGo") 0.2 else 1
    # Build up a sub-list for all (pA, pB) combinations
    sub_res <- list()
    for (pa in pA_values) {
      for (pb in pB_values) {
        row_df <- run_minimal_case(
          structure            = strct,
          pA                   = pa,
          pB                   = pb,
          n_samples            = n_samples_fixed,
          num_simulations      = num_sim_fixed,
          weight_update_timing = ut,
          learning_rate        = lr
        )
        sub_res[[length(sub_res)+1]] <- row_df
      }
    }
    # Bind them and store the block for this (structure, timing) pair
    df_sub <- bind_rows(sub_res)
    all_results[[idx]] <- df_sub
    idx <- idx+1
  }
}

df_all <- bind_rows(all_results)

#### STEP 3: Produce a separate plot for each (structure, update_timing) pair
####         (the learning rate is tied to the timing strategy).
####         x-axis = pB, y-axis = kappa(A,E), color = factor(pA).

# Function that, given structure + update timing, filters the data and plots.
plot_kappaA_pa_pb <- function(data, strct, ut) {
  df_filtered <- data %>%
    filter(structure==strct,
           weight_update_timing==ut)
  # x axis = pB
  # color = factor(pA)
  # y= kappaA
  ggplot(df_filtered, aes(x=pB, y=kappaA, color=factor(pA), group=factor(pA))) +
    geom_line(size=1) +
    geom_point(size=2) +
    labs(
      title    = paste0("Causal rule: ", strct, " | Timing strategy: ", ut),
      subtitle = "Values averaged across 500 simulations x 50 samples each.",
      x        = "Probability of B (the alternate cause)",
      y        = expression(kappa(A,E)),
      color    = "Prob(A)"
    ) +
    ylim(0, NA) +    # let upper side free
    theme_minimal(base_size=14)
}

#### STEP 4: Generate the 2x2 = 4 plots ####
for (strct in structures) {
  for (ut in update_timing_vec) {
    p_plt <- plot_kappaA_pa_pb(df_all, strct, ut)
    print(p_plt)
  }
}
Experimental results from Morris et al. showing abnormal deflation (disjunctive) and inflation (conjunctive) patterns.


2.1 Variation of kappa(A,E) with Respect to Learning Rate and Number of Samples

Next, we investigate how our model’s predictions shift, in both the “Cumulative” and “AsWeGo” versions, as we vary its two main free parameters:

  1. The number of samples (\(n\)) collected in each simulation (shown on a log scale).
  2. The learning rate (\(\eta\)), which controls the magnitude of each weight update.

We do so by fixing the probabilities of the focal and alternate cause to 0.1 and 0.9, respectively, and then tracking how \(\kappa(A,E)\) varies with changes in those parameters:


# ================================================================
# Exploring variation with learning rate and number of samples
# The number of simulations per cell is 10000 / n_samples, so every cell
# draws 10,000 samples in total
# Results cached via knitr cache to avoid re-running
# ================================================================

structure_choices   <- c("conjunctive","disjunctive") # or pick one
pA_fixed            <- 0.1
pB_fixed            <- 0.9
update_timing_vec2  <- c("Cumulative","AsWeGo")
learning_rate_vec2  <- c(0.1, 0.25, 0.5, 0.75, 1.0)
# Possibly n_samples in c(10,20,50,100,200,500) => roughly geometric
n_samples_vec2      <- c(10,20,50,100,200,500)
#num_sim_fixed2      <- 100  # or we can vary it

#### We define a new function that loops over structure, lr, n_samples, etc.
explore_lr_nsamps_case <- function(strct, pA, pB, ut, lr, nsamp, num_sim) {
  # Minimal program for strct, same old
  clauses <- build_minimal_program(strct)
  input_probabilities <- list(A=pA, B=pB)
  
  netw <- generate_network_from_logic_program(
    clauses             = clauses,
    input_probabilities = input_probabilities,
    plot_network        = FALSE
  )
  init_io <- list(A=1,B=1,E=1)
  net_ent <- netw$network
  init_state <- generate_initial_hidden_states(init_io, net_ent)
  
  res <- run_simulation(
    network              = netw,
    initial_state        = init_state,
    input_values         = init_io,
    num_simulations      = num_sim,
    n_samples            = nsamp,
    sampling_method      = 'MH',
    weight_update_timing = ut,
    learning_rate        = lr,
    weight_update_rule   = 'Additive',
    return_samples       = FALSE
  )
  # Extract kappa(A,E)
  rowA <- res$impact_scores[res$impact_scores$C=="A",]
  kA <- if (nrow(rowA)==1) rowA$kappa else NA_real_
  
  data.frame(
    structure=strct,
    pA=pA, pB=pB,
    update_timing=ut,
    learning_rate=lr,
    n_samples=nsamp,
    num_sim=num_sim,
    kappaA=kA
  )
}

#### We'll gather results for each structure & updateTiming,
#### while scanning over the learning_rate_vec2 & n_samples_vec2

df_part2 <- list()
idx2 <- 1
for (strct in structure_choices) {
  print(strct)
  for (ut in update_timing_vec2) {
    print(ut)
    for (lr in learning_rate_vec2) {
      for (nsamp in n_samples_vec2) {
        num_sim <- 10000/nsamp  # keep the total sample budget fixed at 10,000
        row_df <- explore_lr_nsamps_case(
          strct = strct,
          pA = pA_fixed,
          pB = pB_fixed,
          ut = ut,
          lr = lr,
          nsamp = nsamp,
          num_sim = num_sim
        )
        df_part2[[idx2]] <- row_df
        idx2 <- idx2+1
      }
    }
  }
}
df_part2 <- bind_rows(df_part2)

#### STEP 2: Plot them
#### One faceted figure: rows = structure, columns = update_timing;
#### x-axis = n_samples (log scale), y = kappaA, color = factor(learning_rate).

p_part2 <- ggplot(df_part2, aes(x=n_samples, y=kappaA, color=factor(learning_rate))) +
  geom_line(size=1) +
  geom_point(size=2) +
  scale_x_log10() +   # to show a "chronological" or geometric scale
  facet_grid(structure ~ update_timing, labeller=label_both) +
  labs(
    title    = "Effect of n_samples & learning_rate on kappa(A,E)",
    subtitle = paste("pA=", pA_fixed, " pB=", pB_fixed,
                     " (A=1,B=1 => E=1)"),
    x        = "Number of samples (log scale)",
    y        = expression(kappa(A,E)),
    color    = "Learning Rate"
  ) +
  theme_minimal(base_size=14)

print(p_part2)




df_conjunctive <- df_part2 %>%
  filter(structure == "conjunctive")

df_disjunctive <- df_part2 %>%
  filter(structure == "disjunctive")


p_part2_disjunctive <- ggplot(
  df_disjunctive,
  aes(x = n_samples, y = kappaA, color = factor(learning_rate))
) +
  geom_line(size=1) +
  geom_point(size=2) +
  scale_x_log10() +
  facet_grid(
    . ~ update_timing,
    # This labeller ensures we only show the factor value, not "update_timing: ..."
    labeller = labeller(update_timing = label_value)
  ) +
  labs(
    title    = "Effect of n_samples & learning_rate on kappa(A,E)",
    subtitle = "Logic Program:\n{ E \u2190 A ;\n  E \u2190 B }",
    x        = "Number of samples (log scale)",
    y        = expression(kappa(A,E)),
    color    = "Learning Rate"
  ) +
  theme_minimal(base_size = 14)


p_part2_conjunctive <- ggplot(
  df_conjunctive,
  aes(x = n_samples, y = kappaA, color = factor(learning_rate))
) +
  geom_line(size=1) +
  geom_point(size=2) +
  scale_x_log10() +
  facet_grid(
    . ~ update_timing,
    labeller = labeller(update_timing = label_value)
  ) +
  labs(
    title    = "Effect of n_samples & learning_rate on kappa(A,E)",
    subtitle = "Logic Program:\n{ E \u2190 A, B }",
    x        = "Number of samples (log scale)",
    y        = expression(kappa(A,E)),
    color    = "Learning Rate"
  ) +
  theme_minimal(base_size = 14)

print(p_part2_disjunctive)
print(p_part2_conjunctive)
Causal importance scores (kappa) for conjunctive and disjunctive structures.
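
In the learning-rate sweep, the number of simulations is not held fixed. The loop sets `num_sim <- 10000/nsamp`, so every cell draws the same total of 10,000 samples. This is presumably so that differences across \(n\) are not confounded with differences in the overall amount of sampling. The short base-R check below tabulates that budget for the values in `n_samples_vec2`.

```r
# Sample budget used in the learning-rate / n_samples sweep:
# num_sim = 10000 / n_samples simulations per cell.
n_samples_vec2 <- c(10, 20, 50, 100, 200, 500)
budget <- data.frame(
  n_samples = n_samples_vec2,
  num_sim   = 10000 / n_samples_vec2
)
budget$total_samples <- budget$n_samples * budget$num_sim
# num_sim is 1000, 500, 200, 100, 50, 20; total_samples is always 10000
budget
```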


2.1.1 Main Observations

  1. Cumulative Strategy
    In the right panel of each figure, we see that, for a given learning rate \(\eta\), the model’s causal score for \(A\) remains nearly constant as we vary \(n\). This is because the Cumulative method draws all samples from the unchanged original network and applies a single update at the end, equal to the average of the per-sample increments. Changing the number of samples therefore mainly changes the precision of that average, not its expected value. The only other effect of a larger \(n\) is that the samples become less anchored to the initial state, which matters little for these simple abnormal inflation and deflation patterns. By contrast, the effect of \(\eta\) is stable and monotonic: a larger \(\eta\) intensifies the updates. This magnifies whichever pattern of abnormal inflation or deflation emerges, in a fairly linear way.

  2. AsWeGo Strategy
    The left panel of each figure depicts the predictions of the AsWeGo approach over the same parameter ranges. Here we observe stronger, and sometimes nonlinear, trends, especially for the conjunctive logic program \(E \leftarrow A,B\):

    For small numbers of samples (up to 100), the learning rate has an effect similar to the one it has in the “Cumulative” model. As we move from a learning rate of 0.1 to 0.75, the pattern of abnormal inflation is amplified: \(A\), here the low-probability variable, gets a higher score. This dynamic is overturned, however, at the even greater learning rate of 1, which yields a lower score than any of the other learning rates (!)

    In a similar vein, the ranking of the scores obtained with learning rates of 0.5 and 0.75 reverses once the number of samples reaches 100 or more. For small numbers of samples, 0.75 generates the higher scores; beyond that point, 0.5 yields the higher scores.
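
A deliberately simplified toy illustrates two claims about the Cumulative strategy: that \(n\) mostly affects the precision of the averaged update, and that \(\eta\) enters in a simple, monotonic way. This is an assumption-laden sketch, not the model's actual update rule. Each sample is an independent Bernoulli(\(p_0\)) draw standing in for a state sampled from the unchanged original network (rather than an MH chain), and its raw increment is \(\eta(2x-1)\). As in the Cumulative branch of `run_simulation()`, all increments are summed and divided by `num_sim * n_samples`. The function name `cumulative_update` and the value \(p_0 = 0.7\) are hypothetical.

```r
# Toy Cumulative schedule: the samples do not depend on eta or on earlier
# increments, and one averaged increment is applied at the end.
cumulative_update <- function(eta, n_samples, num_sim, p0 = 0.7) {
  x <- rbinom(num_sim * n_samples, size = 1, prob = p0)  # all from the original network
  sum(eta * (2 * x - 1)) / (num_sim * n_samples)          # averaged increment
}

set.seed(1)
# The expected update is eta * (2 * p0 - 1) = 0.4 for every n_samples;
# larger n only reduces the spread around that value.
sapply(c(10, 100, 1000), function(n) cumulative_update(1, n, num_sim = 1000))

# Since sampling does not depend on eta, doubling eta doubles the update exactly
set.seed(2); u1 <- cumulative_update(0.5, n_samples = 20, num_sim = 100)
set.seed(2); u2 <- cumulative_update(1.0, n_samples = 20, num_sim = 100)
all.equal(u2, 2 * u1)  # TRUE
```

In the full model, the per-sample increments come from `compute_weight_updates()`, so whether \(\eta\) enters exactly linearly depends on that function. The toy only shows why, under the Cumulative schedule, \(n\) changes the precision of the update rather than its expected size.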

2.1.1.1 The “Runaway” Effect in the AsWeGo Strategy

Under AsWeGo, each sample modifies the network’s weights before the next sample is drawn. Hence, if the initial states or initial updates push the system in a certain direction (for instance, reinforcing the weight from \(A\) to the outcome), subsequent samples become more likely to favor that same cause. This positive feedback loop can “run away” in the sense that it heavily biases the sampling distribution toward states that keep boosting \(A\).
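
The feedback loop can be seen in the same kind of toy. In this hypothetical sketch, which does not reproduce the model's update rule, a single weight \(w\) sets the probability \(\sigma(w)\) that the focal cause is sampled as present. Each sample adds \(\eta(2x-1)\) to \(w\), reinforcing whatever was just sampled. The function `as_we_go()` applies each increment before drawing the next sample. The function `cumulative()` draws every sample from the original weight and applies one averaged increment, mirroring the two branches of `run_simulation()`. All names and parameter values are illustrative.

```r
# Toy contrast between the two update schedules (illustrative assumptions only)
sigmoid <- function(z) 1 / (1 + exp(-z))

as_we_go <- function(w0, eta, n_samples) {
  w <- w0
  for (i in seq_len(n_samples)) {
    x <- rbinom(1, 1, sigmoid(w))   # sample from the *current* weight
    w <- w + eta * (2 * x - 1)      # apply the update immediately
  }
  w
}

cumulative <- function(w0, eta, n_samples) {
  x <- rbinom(n_samples, 1, sigmoid(w0))   # all samples from the original weight
  w0 + mean(eta * (2 * x - 1))             # one averaged update at the end
}

set.seed(42)
drift_awg <- replicate(500, as_we_go(0.5, eta = 0.5, n_samples = 50)) - 0.5
drift_cum <- replicate(500, cumulative(0.5, eta = 0.5, n_samples = 50)) - 0.5

# Average size of the change in w under each schedule
c(AsWeGo = mean(abs(drift_awg)), Cumulative = mean(abs(drift_cum)))
```

Under the cumulative schedule, the change in \(w\) can never exceed \(\eta\) in absolute value. Under the AsWeGo schedule, early samples tilt the sampling distribution and later samples tend to agree with them, so the weight drifts far from its starting point in one direction or the other.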

This dynamic explains why, for a fixed learning rate, the contrasts are larger in this version of the model than in the “Cumulative” version.

When pushed past a certain limit, the feedback can have even more dramatic effects by changing the nature of the relationship between \(A\), \(B\) and the outcome. In the conjunctive example \(E \leftarrow A,B\), the initial sampling episodes see \((A=1, B=1)\), since that is the actual world’s starting point. These first steps reward the connections between \(A\) and \(B\) and their common hidden node \(H_{ab}\), and reduce the magnitude of the (negative) bias on that node. If updates are modest (low \(\eta\)), the gate effectively remains conjunctive, preserving the abnormal-inflation pattern in favor of \(A\). But if \(\eta\) is large, or \(n\) is high enough, the repeated updates can eventually push the weights to approximate a disjunctive gate instead. The weights from \(A\) and \(B\) become large enough, relative to the bias, that the outcome node may saturate whenever either cause occurs. In effect, the system transitions from a conjunctive to a nearly disjunctive gate, causing \(\kappa(A,E)\) to drop sharply: \(A\) is the abnormal variable, and in a disjunctive structure abnormal causes are deflated rather than inflated.
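
The conjunctive-to-disjunctive transition can be illustrated with a deterministic threshold unit, a simplification of the stochastic units sampled in the network. The hypothetical function `gate()` fires when \(w_A A + w_B B + b > 0\). The parameter values are made up. The first setting stands for the initial conjunctive gate with a strongly negative bias. The second stands for the same unit after updates have strengthened \(w_A\) and \(w_B\) and shrunk the bias.

```r
# Deterministic threshold unit: fires when wA*A + wB*B + bias > 0.
gate <- function(A, B, wA, wB, bias) as.integer(wA * A + wB * B + bias > 0)

# Output of the unit for every combination of its two inputs
truth_table <- function(wA, wB, bias) {
  inputs <- expand.grid(A = 0:1, B = 0:1)   # rows: (0,0), (1,0), (0,1), (1,1)
  inputs$H <- gate(inputs$A, inputs$B, wA, wB, bias)
  inputs
}

# Conjunctive regime: only A = B = 1 clears the threshold (AND)
truth_table(wA = 1, wB = 1, bias = -1.5)$H    # 0 0 0 1

# After updates strengthen wA, wB and shrink the bias, one cause suffices (OR)
truth_table(wA = 1.5, wB = 1.5, bias = -1)$H  # 0 1 1 1
```

Once the unit behaves as an OR gate, the abnormal cause \(A\) falls under the abnormal deflation pattern and receives less credit. This is the drop in \(\kappa(A,E)\) described above.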

2.1.2 Rationale for Fixing n

Because of the runaway effect, letting both \(\eta\) and \(n\) vary freely introduces a great deal of complexity. Small differences in \(\eta\) can, at high \(n\), result in qualitatively different prediction patterns, especially when updates are performed “As we go”.

Hence, we choose to fix the number of samples to \(n=20\) for subsequent analyses. This choice has several advantages:

  1. Reducing model complexity
    Leaving both \(n\) and \(\eta\) as free parameters would considerably increase model complexity without any substantial benefit, since, in the best-case scenario, the effect of each parameter is redundant with that of the other. On top of that, their interdependence makes it harder to fit either one properly to data as soon as we allow both to vary. By holding \(n\) constant, we reduce the dimensionality of the parameter space and can fit \(\eta\) more reliably.

  2. Relation to anchoring
    As discussed in the dedicated chapter, using a small number of samples ensures that the sampling process remains, to some extent, anchored to the initial state. This effectively derives the anchoring phenomenon from more fundamental facts about how counterfactual states are generated. What is more, choosing \(n=20\) puts the chain length of the MCMC process in the same range as the values used in other models of causal judgment grounded in such sampling strategies (e.g., Davies & Rehder, 2020).

  3. Psychological plausibility
    An additional argument for fixing \(n\) to a low value comes from the cognitive costs of sampling: engaging in counterfactual simulation is costly. Adopting a moderate chain length like \(n=20\) aligns with the idea that humans consider only enough alternative states to track “normal” or “plausible” alternative scenarios, without exhaustively searching the space of possibilities.

  4. Robustness of the model in the “Cumulative” version
    The precise number we choose (\(n=20\)) is arbitrary beyond the fact that it lies in a plausible range of small values, as per the arguments in (2)-(3) above. However, the model dynamics are not hurt by this arbitrariness. As shown above, the predictions barely vary with the number of samples in the “Cumulative” version. They also vary relatively little in the “AsWeGo” version, provided that the number of samples is small.
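
Point (2) can be illustrated with a toy Metropolis-Hastings chain. The sketch below is an assumption-laden stand-in, not the model's sampler. It uses \(k = 5\) independent binary variables, each with probability \(p = 0.1\) of being 1, and starts every chain from the all-ones "actual world". At each step it proposes to flip one randomly chosen variable. The function name `mh_focal_mean` and all parameter values are hypothetical. The sketch compares how often the focal variable is still on in chains of 20 steps versus 2000 steps.

```r
# Toy random-scan MH chain over k independent binary variables with P(x = 1) = p.
# Returns the share of samples in which the focal variable (x[1]) equals 1.
mh_focal_mean <- function(n_steps, k = 5, p = 0.1) {
  x <- rep(1, k)                        # initial state: all causes present
  focal <- numeric(n_steps)
  for (t in seq_len(n_steps)) {
    j <- sample.int(k, 1)               # pick one variable to propose flipping
    # Target probability ratio pi(flipped) / pi(current)
    ratio <- if (x[j] == 1) (1 - p) / p else p / (1 - p)
    if (runif(1) < min(1, ratio)) x[j] <- 1 - x[j]
    focal[t] <- x[1]
  }
  mean(focal)
}

set.seed(7)
short <- mean(replicate(500, mh_focal_mean(20)))
long  <- mean(replicate(50,  mh_focal_mean(2000)))
c(n20 = short, n2000 = long, stationary = 0.1)
```

In the short chains, the focal cause stays on noticeably more often than its stationary probability of 0.1, because the chain has not yet forgotten its starting point. In the long chains, the share approaches 0.1. This is the sense in which a small \(n\) builds anchoring into the sampling process.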

With the number of samples fixed, we will, in a second step, fit the learning rate for each update-timing strategy to the data collected in our second experiment (Konuk et al., under review).

############################################################################
# One-stop pipeline to load data for Exp1 & Exp2, exclude comprehension
# failures, parse questions into letter-based sets, scale responses, and
# combine the data frames into df_judgments.
############################################################################
# HELPER FUNCTIONS
############################################################################

# For Exp1, we only expect boxes A, B, C.
# We'll ignore the actual number of colored balls. We'll just parse the
# question text to see if it references box A, box B, box C, or combos.
extract_letters_exp1 <- function(question_string) {
  if (is.na(question_string)) return(NA_character_)
  # Gather all occurrences of A/B/C
  letters_found <- unlist(regmatches(question_string, gregexpr("[A-C]", question_string)))
  if (length(letters_found) == 0) return(NA_character_)
  letters_found <- sort(unique(letters_found))
  paste(letters_found, collapse = ",")
}

# For Exp2, the question might involve any subset of A/B/C/D.
extract_letters_exp2 <- function(question_string) {
  if (is.na(question_string)) return(NA_character_)
  # Gather all occurrences of A/B/C/D
  letters_found <- unlist(regmatches(question_string, gregexpr("[A-D]", question_string)))
  if (length(letters_found) == 0) return(NA_character_)
  letters_found <- sort(unique(letters_found))
  paste(letters_found, collapse = ",")
}

############################################################################
# A) Experiment 1
############################################################################

# 1) Load data
exp1_raw <- read_csv("/Users/cankonuk/Documents/LAB/plural-actual-causes/analyses/quillien-2a-plural-version/experiment1-plurals.csv")
## Rows: 13101 Columns: 17
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (13): trial_type, trial_index, internal_node_id, subject_ID, response, a...
## dbl  (3): rt, time_elapsed, group
## lgl  (1): scoring
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# 2) Exclude comprehension failures
incorrect_subjects_exp1 <- exp1_raw %>%
  filter(
    (trial_index == "6.WhitePoints" & response != "0") |
      (trial_index == "6.WinningCondition" & response != "You must get at least 2 points")
  ) %>%
  distinct(subject_ID)

exp1_clean <- exp1_raw %>%
  anti_join(incorrect_subjects_exp1, by = "subject_ID")

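# 3) Filter to target trials (20..26). Non-numeric trial indices such as
#    "6.WhitePoints" become NA under as.numeric(), which triggers the
#    coercion warning below; filter() then drops those rows.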
exp1_clean <- exp1_clean %>%
  mutate(trial_index_num = as.numeric(trial_index)) %>%
  filter(trial_index_num >= 20, trial_index_num <= 26)
## Warning: There was 1 warning in `mutate()`.
## ℹ In argument: `trial_index_num = as.numeric(trial_index)`.
## Caused by warning:
## ! NAs introduced by coercion
# 4) Create a cause_set column with the letter(s) referenced by the question
#    (instead of "lo"/"mid"/"hi"/"all three"), and scale the response.
#    rowwise() is needed because extract_letters_exp1() is not vectorized.
exp1_clean <- exp1_clean %>%
  rowwise() %>%
  mutate(
    cause_set  = extract_letters_exp1(question),  # "A", "B", "C", or combos like "A,B"
    # Assumes `response` holds the 0-based index (0..8) of the 9-point scale,
    # as in Exp2; dividing by 8 maps 0 -> 0 and 8 -> 1.
    # If responses were stored as 1..9, use (as.numeric(response) - 1) / 8.
    scaled_val = as.numeric(response) / 8
  )

# 5) Summarize by cause_set
df_exp1 <- exp1_clean %>%
  filter(!is.na(cause_set)) %>%
  group_by(cause_set) %>%
  summarize(
    mean_judgment = mean(scaled_val, na.rm = TRUE),
    .groups       = "drop"
  ) %>%
  mutate(
    experiment    = "Exp1",
    scenario_name = "All-Pos"  # must match a scenario name handled by get_scenario_info()
  ) %>%
  select(experiment, scenario_name, cause_set, mean_judgment)

############################################################################
# B) Experiment 2
############################################################################

# 1) Load data
exp2_raw <- read_csv("/Users/cankonuk/Documents/LAB/plural-actual-causes/analyses/conj-disj-analysis/data-conj-disj.csv")
## Rows: 24656 Columns: 34
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (15): trial_type, trial_index, internal_node_id, subject_ID, response, a...
## dbl (13): rt, time_elapsed, group, urn--n-balls, urn--n-colored-balls, urn-A...
## lgl  (6): urn--is-results-color, urn-A-is-results-color, urn-C-is-results-co...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# 2) Exclude comprehension failures
failed_subjects_whiteballs <- exp2_raw %>%
  filter(trial_index == "17.WhiteBalls", response != "2") %>%
  pull(subject_ID) %>%
  unique()

failed_subjects_winningcondition <- exp2_raw %>%
  filter(
    trial_index == "17.WinningCondition",
    response != "You must draw two yellow, or two purple balls"
  ) %>%
  pull(subject_ID) %>%
  unique()

failed_subjects_exp2 <- unique(c(failed_subjects_whiteballs, failed_subjects_winningcondition))

exp2_clean <- exp2_raw %>%
  filter(!(subject_ID %in% failed_subjects_exp2))

# 3) Identify & filter to the relevant "trial_kind" conditions
exp2_clean <- exp2_clean %>%
  rowwise() %>%
  mutate(
    trial_kind = case_when(
      trial_index %in% 6:15 ~ "habituation",
      # triple 0
      (`urn-A-n-colored-balls` == 14 & `urn-A-is-results-color` == FALSE &
       `urn-B-n-colored-balls` == 2  & `urn-B-is-results-color` == FALSE &
       `urn-C-n-colored-balls` == 4  & `urn-C-is-results-color` == TRUE  &
       `urn-D-n-colored-balls` == 18 & `urn-D-is-results-color` == FALSE) ~ "triple 0",
      # overdetermined negative
      (`urn-A-n-colored-balls` == 14 & `urn-A-is-results-color` == FALSE &
       `urn-B-n-colored-balls` == 2  & `urn-B-is-results-color` == FALSE &
       `urn-C-n-colored-balls` == 4  & `urn-C-is-results-color` == FALSE &
       `urn-D-n-colored-balls` == 18 & `urn-D-is-results-color` == FALSE) ~ "overdetermined negative",
      # overdetermined positive
      (`urn-A-n-colored-balls` == 14 & `urn-A-is-results-color` == TRUE  &
       `urn-B-n-colored-balls` == 2  & `urn-B-is-results-color` == TRUE  &
       `urn-C-n-colored-balls` == 4  & `urn-C-is-results-color` == TRUE  &
       `urn-D-n-colored-balls` == 18 & `urn-D-is-results-color` == TRUE)  ~ "overdetermined positive",
      # triple 1
      (`urn-A-n-colored-balls` == 14 & `urn-A-is-results-color` == TRUE  &
       `urn-B-n-colored-balls` == 2  & `urn-B-is-results-color` == TRUE  &
       `urn-C-n-colored-balls` == 4  & `urn-C-is-results-color` == FALSE &
       `urn-D-n-colored-balls` == 18 & `urn-D-is-results-color` == TRUE)  ~ "triple 1",
      TRUE ~ NA_character_
    )
  ) %>%
  filter(trial_kind %in% c("overdetermined positive", "triple 1",
                           "overdetermined negative", "triple 0"))

# 4) Extract the letter-based sets (A, B, C, D) from the question and scale response
exp2_clean <- exp2_clean %>%
  mutate(
    cause_set  = extract_letters_exp2(question),
    scaled_val = as.numeric(response) / 8  # scale 0..8 => [0..1]
  )

# 5) Summarize
df_exp2 <- exp2_clean %>%
  filter(!is.na(cause_set)) %>%
  group_by(trial_kind, cause_set) %>%
  summarize(
    mean_judgment = mean(scaled_val, na.rm = TRUE),
    .groups       = "drop"
  ) %>%
  rename(scenario_name = trial_kind) %>%
  mutate(experiment = "Exp2") %>%
  select(experiment, scenario_name, cause_set, mean_judgment)

############################################################################
# C) Combine both experiments
############################################################################

# Only Exp2 is used for model fitting below; uncomment the next line to
# include Exp1 as well.
#df_judgments <- bind_rows(df_exp1, df_exp2)
df_judgments <- df_exp2
#print(df_judgments)

############################################################################
# D) Example quick plot
############################################################################

cause_levels <- c("A","B","C","D","A,B","A,C","A,D", "B,C","B,D", "C,D", "A,B,C", "A,B,D", "A,C,D", "B,C,D")
df_judgments$cause_set <- factor(df_judgments$cause_set, levels=cause_levels)

p <- ggplot(df_judgments, aes(x = cause_set, y = mean_judgment, fill = experiment)) +
  geom_col(position = position_dodge(width = 0.7), width = 0.6) +
  facet_wrap(~ scenario_name, scales = "free_x") +
  labs(
    title = "Participants’ Mean Judgments Across Conditions",
    x     = "Cause Set",
    y     = "Mean Judgment (scaled 0–1)",
    fill  = "Experiment"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    axis.text.x         = element_text(angle = 45, hjust = 1),
    panel.grid.major.x  = element_blank(),
    panel.grid.minor    = element_blank()
  )

print(p)

##############################################################################
## 1) A function that runs the model for a given scenario and parameters,
##    returning a named vector of predictions, then optionally raises them
##    to gamma. (Same approach as before.)
##############################################################################

run_model_for_scenario <- function(clauses, input_probabilities, initial_state,
                                   n_samples, learning_rate, w,
                                   update_timing="Cumulative",
                                   gamma=1) {
  # Build the sampling network
  sampling_network <- generate_network_from_logic_program(
    clauses              = clauses,
    input_probabilities  = input_probabilities,
    plot_network         = FALSE
  )
  net_ent <- sampling_network$network
  init_state <- generate_initial_hidden_states(initial_state, net_ent)

  # Run the simulation
  result <- run_simulation(
    network               = sampling_network,
    initial_state         = init_state,
    input_values          = init_state,
    num_simulations       = 500,
    n_samples             = n_samples,
    sampling_method       = "MH",
    weight_update_timing  = update_timing,
    learning_rate         = learning_rate,
    weight_update_rule    = "Additive",
    w                     = w,
    return_samples        = FALSE
  )

  # Extract raw kappa values
  preds <- result$impact_scores   # data.frame: C, kappa
  # Raise them by gamma
  preds$kappa <- preds$kappa ^ gamma
  # Return named vector
  setNames(preds$kappa, preds$C)
}


##############################################################################
## 2) A function that obtains scenario info (clauses, probabilities, init_state)
##    given a scenario_name.
##############################################################################

get_scenario_info <- function(scenario_name) {
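  # Each scenario specifies the causal structure (clauses), the prior
  # probability of each cause, and the actual state (+1 = occurred, -1 = did not).
  if (scenario_name=="All-Pos") {
    clauses <- list(
      list(head="E", body=c("A","B")),
      list(head="E", body=c("A","C")),
      list(head="E", body=c("B","C"))
    )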
    input_probabilities <- list(A=0.05, B=0.5, C=0.95)
    initial_state <- list(A=1,B=1,C=1,E=1)
  } else if (scenario_name=="overdetermined positive") {
    clauses <- list(
      list(head="E", body=c("A","B")),
      list(head="E", body=c("C","D"))
    )
    input_probabilities <- list(A=0.7,B=0.1,C=0.3,D=0.9)
    initial_state <- list(A=1,B=1,C=1,D=1,E=1)
  } else if (scenario_name=="triple 1") {
    clauses <- list(
      list(head="E", body=c("A","B")),
      list(head="E", body=c("C","D"))
    )
    input_probabilities <- list(A=0.7,B=0.1,C=0.3,D=0.9)
    initial_state <- list(A=1,B=1,C=1,D=-1,E=1)
  } else if (scenario_name=="overdetermined negative") {
    clauses <- list(
      list(head="E", body=c("A","B")),
      list(head="E", body=c("C","D"))
    )
    input_probabilities <- list(A=0.7,B=0.1,C=0.3,D=0.9)
    initial_state <- list(A=-1,B=-1,C=-1,D=-1,E=-1)
  } else if (scenario_name=="triple 0") {
    clauses <- list(
      list(head="E", body=c("A","B")),
      list(head="E", body=c("C","D"))
    )
    input_probabilities <- list(A=0.7,B=0.1,C=0.3,D=0.9)
    initial_state <- list(A=-1,B=-1,C=1,D=-1,E=-1)
  } else {
    stop("Unknown scenario: ", scenario_name)
  }
  list(clauses=clauses, input_probabilities=input_probabilities,
       initial_state=initial_state)
}


##############################################################################
## 3) A new measure of fit: item-level correlation between data and model
##############################################################################

compute_correlation_for_dataset <- function(df_judgments, param) {
  # param is a list with (n_samples, learning_rate, w, update_timing, gamma)
  # We'll collect predicted kappa vs. actual means for each item across scenarios

  predictions_vec <- c()
  data_vec        <- c()

  scenario_list <- unique(df_judgments$scenario_name)
  for (scn in scenario_list) {
    # gather scenario info
    info <- get_scenario_info(scn)
    # run model
    pred_kappa <- run_model_for_scenario(
      clauses             = info$clauses,
      input_probabilities = info$input_probabilities,
      initial_state       = info$initial_state,
      n_samples           = param$n_samples,
      learning_rate       = param$learning_rate,
      w                   = param$w,
      update_timing       = param$update_timing,
      gamma               = param$gamma
    )

    df_sub <- df_judgments[df_judgments$scenario_name==scn, ]
    for (i in seq_len(nrow(df_sub))) {
      cause <- as.character(df_sub$cause_set[i])  # look up by label; a factor index would use its integer code
      obs   <- df_sub$mean_judgment[i]  # data
      if (!is.na(pred_kappa[cause])) {
        # store predicted vs. data
        predictions_vec <- c(predictions_vec, pred_kappa[cause])
        data_vec        <- c(data_vec, obs)
      }
    }
  }

  # With fewer than 2 items the correlation is undefined: return NA (which.max() ignores NA)
  if (length(predictions_vec) < 2) {
    return(NA_real_)
  }

  # compute Pearson correlation
  cor_val <- cor(predictions_vec, data_vec, use="complete.obs", method="pearson")
  cor_val
}

##############################################################################
## 4) A small wrapper that the parallel workers call for each param combo,
##    returning a 1-row data.frame with correlation
##############################################################################

calc_correlation <- function(df_judgments, ns, lr, w, ut, gm) {
  param_list <- list(n_samples=ns, learning_rate=lr, w=w,
                     update_timing=ut, gamma=gm)
  cor_val <- compute_correlation_for_dataset(df_judgments, param_list)
  data.frame(n_samples=ns, learning_rate=lr, w=w, update_timing=ut,
             gamma=gm, correlation=cor_val)
}


##############################################################################
## 5) Parallel search for param combos => we want the maximum correlation
##############################################################################

parallel_search_over_params_correlation <- function(df_judgments,
                                                    n_samples_set,
                                                    lr_set,
                                                    w_set,
                                                    update_timing_set,
                                                    gamma_set,
                                                    n_cores=2) {
  library(doParallel)
  library(foreach)

  param_df <- expand.grid(
    n_samples      = n_samples_set,
    learning_rate  = lr_set,
    w              = w_set,
    update_timing  = update_timing_set,
    gamma          = gamma_set,
    stringsAsFactors=FALSE
  )

  cl <- makeCluster(n_cores)
  registerDoParallel(cl)

  # each iteration => .combine=rbind => a data.frame row with correlation
  results <- foreach(i=1:nrow(param_df), .combine=rbind,
                     .packages=c("dplyr"),
                     .export=c(
                       # all subfunctions
                       "calc_correlation",
                       "compute_correlation_for_dataset",
                       "run_model_for_scenario",
                       "get_scenario_info",
                       "generate_network_from_logic_program",
                       "generate_initial_hidden_states",
                       "run_simulation",
                       "perform_sampling",
                       "mh_layered_sampler",
                       "layer_gibbs_sampler",
                       "compute_lrp_importance",
                       "compute_causal_impact_scores",
                       "sigmoid",
                       "compute_loss_NLL",
                       "activation_function",
                       "compute_weight_updates",
                       "compute_outputs"
                     )) %dopar% {

    ns <- param_df$n_samples[i]
    lr <- param_df$learning_rate[i]
    ww <- param_df$w[i]
    ut <- param_df$update_timing[i]
    gm <- param_df$gamma[i]

    row_df <- calc_correlation(df_judgments, ns, lr, ww, ut, gm)
    row_df
  }

  stopCluster(cl)

  # find best row => max correlation
  idx_best <- which.max(results$correlation)
  best_row <- results[idx_best, ]

  list(
    best_correlation = best_row$correlation,
    best_params      = best_row[, c("n_samples","learning_rate","w","update_timing","gamma")],
    results          = results
  )
}
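Because `df_judgments$cause_set` is turned into a factor for plotting, any lookup of model predictions by cause set (such as `pred_kappa[cause]` in `compute_correlation_for_dataset()`) should convert the cause set to character first: R's `[` treats a factor index as its integer codes, not its labels (see `?Extract`). The toy values below are hypothetical and only illustrate the pitfall.

```r
# Hypothetical model scores, named by cause set, in an arbitrary order
preds <- c("A,B" = 0.8, "A" = 0.3, "B" = 0.5)

# A cause set stored as a factor ("A,B" is the 3rd level, so its code is 3)
cause <- factor("A,B", levels = c("A", "B", "A,B"))

by_code  <- preds[cause]                # indexes by code 3: the entry named "B"
by_label <- preds[as.character(cause)]  # indexes by label: the entry "A,B"
```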

2.1.3 Fitting the Model in the Cumulative and AsWeGo Conditions

Below, we present two separate code cells, one for the Cumulative condition and one for the AsWeGo condition, each performing a grid search over \(\eta\) (the learning rate) and \(\gamma\) (a scaling exponent). Recall that \(\gamma\) rescales the model’s raw predictions \(\kappa\) via \(\kappa^{\gamma}\). This accounts for the fact that participants’ numerical judgments may relate nonlinearly to the model’s raw outputs while preserving their order. Because \(\kappa^{\gamma}\) is monotonic for \(\kappa \ge 0\), it never changes the ranking of predictions; it affects the fit only because the Pearson correlation measures linear association.
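The claim that \(\kappa^{\gamma}\) preserves the ordering of predictions but still affects a Pearson-based fit can be checked on toy numbers. The scores and ratings below are hypothetical, not data from the experiments; \(\gamma = 0.58\) is used only as an example value.

```r
# Hypothetical raw model scores (kappa in [0, 1]) and mean ratings
kappa     <- c(0.05, 0.10, 0.20, 0.40, 0.60, 0.90)
ratings   <- c(0.30, 0.42, 0.50, 0.62, 0.70, 0.85)
gamma_exp <- 0.58
transformed <- kappa^gamma_exp  # monotone increasing for kappa >= 0

# Ranks, and hence Spearman's rho, are unchanged by the transform...
same_ranks   <- identical(rank(kappa), rank(transformed))
spearman_raw <- cor(kappa, ratings, method = "spearman")
spearman_tr  <- cor(transformed, ratings, method = "spearman")

# ...but Pearson's r, which measures linear association, changes
pearson_raw <- cor(kappa, ratings)
pearson_tr  <- cor(transformed, ratings)
c(pearson_raw = pearson_raw, pearson_tr = pearson_tr)
```

This is why \(\gamma\) is fitted jointly with \(\eta\): a rank-based measure would be blind to it.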

2.1.3.1 Fitting in the Cumulative Condition

In the first cell, updates are applied only at the end of sampling (the “Cumulative” strategy). We vary \(\eta\) from 0.5 to 1.5 in steps of 0.05 and \(\gamma\) from 0.4 to 0.7 in steps of 0.02 (with \(n=20\) and \(w=1\) fixed), and compute the Pearson correlation between the model’s predictions and participants’ mean judgments. The best-fitting parameters are:

  • \(\eta \approx 1.15\)
  • \(\gamma \approx 0.58\)
  • Update timing = "Cumulative"

These values indicate that participants’ judgments align most closely with the model when the final updates are relatively strong (\(\eta \approx 1.15\)) and the raw predictions are transformed by \(\gamma \approx 0.58\). Because raw predictions lie in [0, 1], an exponent below 1 raises every value toward 1 (e.g., \(0.25^{0.58} \approx 0.45\)). The full search procedure and the final plots are shown below.
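To make the effect of an exponent below 1 precise: for \(\kappa \in (0,1)\) and \(0 < \gamma < 1\), the transformed score always exceeds the raw one, and the transform stretches differences between small scores while compressing differences between large ones, with the crossover where its slope equals 1:

\[
\kappa^{\gamma} - \kappa = \kappa^{\gamma}\left(1 - \kappa^{1-\gamma}\right) > 0,
\qquad
\frac{d}{d\kappa}\,\kappa^{\gamma} = \gamma\,\kappa^{\gamma-1} > 1
\iff \kappa < \gamma^{1/(1-\gamma)}.
\]

With \(\gamma = 0.58\), the crossover is at \(0.58^{1/0.42} \approx 0.27\): differences between raw scores below this value are amplified, and differences above it are shrunk.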

##############################################################################
## 6) Usage example: searching for max correlation, then plotting
##############################################################################

n_samples_vec <- 20
lr_vec        <- seq(0.5, 1.5, by = 0.05)
w_vec         <- 1
update_vec    <- c("Cumulative") 
gamma_vec     <- seq(0.4, 0.7, by = 0.02)

res_parallel_corr_cumulative <- parallel_search_over_params_correlation(
  df_judgments,
  n_samples_set   = n_samples_vec,
  lr_set          = lr_vec,
  w_set           = w_vec,
  update_timing_set=update_vec,
  gamma_set       = gamma_vec,
  n_cores         = 4
)

cat("Best param:\n")
print(res_parallel_corr_cumulative$best_params)
cat("Best correlation =", res_parallel_corr_cumulative$best_correlation, "\n")

df_res_corr_cumulative <- res_parallel_corr_cumulative$results

# Quick visualization: correlation as a function of learning rate, coloured by gamma

df_res_corr_cumulative_plot <- df_res_corr_cumulative #%>% filter(n_samples == 20, gamma < 0.81, w == 1)
p_corr <- ggplot(df_res_corr_cumulative_plot, aes(x=learning_rate, y=correlation, color=gamma, shape=as.factor(w))) +
  geom_point(size=3) +
  #facet_wrap(~n_samples) +
  #facet_grid(n_samples ~ gamma, labeller=label_both) +
  labs(title="Grid Search: correlation measure",
       subtitle="Higher correlation is better") +
  scale_color_viridis_c(option="plasma", direction=-1) +
  theme_minimal(base_size=14)
print(p_corr)
##############################################################################
## 7) Final step: for each scenario, produce a side-by-side plot:
##    - data means vs. model predictions for the best param.
##############################################################################

## a) Extract the best param
#best_p_cumulative <- as.list(res_parallel_corr_cumulative$best_params[1, ])
# e.g. best_p$n_samples, best_p$learning_rate, best_p$w, best_p$update_timing, best_p$gamma

## b) Or hardcode the best params:

best_p_cumulative <- list(
  n_samples     = 20,        #  chosen sample size
  learning_rate = 1.15,      # approximate best-fit η
  w             = 1,         # typically we fix w=1
  update_timing = "Cumulative",
  gamma         = 0.58       # approximate best-fit γ
)
# define a function that merges the data and model predictions for each scenario
get_comparison_df <- function(df_judgments, best_p) {
  # We'll produce a data frame with columns:
  #   scenario_name, cause_set, data_mean, model_pred
  # for each scenario_name

  scenario_list <- unique(df_judgments$scenario_name)
  out_list <- list()

  for (scn in scenario_list) {
    info <- get_scenario_info(scn)
    # run model
    preds <- run_model_for_scenario(
      clauses             = info$clauses,
      input_probabilities = info$input_probabilities,
      initial_state       = info$initial_state,
      n_samples           = best_p$n_samples,
      learning_rate       = best_p$learning_rate,
      w                   = best_p$w,
      update_timing       = best_p$update_timing,
      gamma               = best_p$gamma
    )

    df_sub <- df_judgments[df_judgments$scenario_name==scn, ]
    # for each cause_set in df_sub, retrieve model
    # build small df with columns scenario_name, cause_set, data_mean, model_pred
    model_vec <- numeric(nrow(df_sub))
    for (i in seq_len(nrow(df_sub))) {
      cs <- as.character(df_sub$cause_set[i])  # look up by label, not by factor code
      model_vec[i] <- if (!is.na(preds[cs])) preds[cs] else NA
    }
    tmp <- data.frame(
      scenario_name = scn,
      cause_set     = df_sub$cause_set,
      data_mean     = df_sub$mean_judgment,
      model_pred    = model_vec
    )
    out_list[[length(out_list)+1]] <- tmp
  }
  do.call(rbind, out_list)
}

comparison_df_cumulative <- get_comparison_df(df_judgments, best_p_cumulative)

# We now have scenario_name, cause_set, data_mean, model_pred

## c) Plot for each scenario: a bar for data vs a "line" or "point" for model
##    We'll do a side-by-side or overlaid approach.

#------------------------------------------------------------------
# 3) Plot approach #1: side-by-side bars
#------------------------------------------------------------------
comparison_long_cumulative <- comparison_df_cumulative %>%
  tidyr::gather(key = "type", value = "value", data_mean, model_pred)



#------------------------------------------------------------------
# 4) Plot approach #2: data bars + model points/line
#------------------------------------------------------------------
p_comparison_2_cumulative <- ggplot(comparison_df_cumulative, aes(x = cause_set)) +
  geom_col(
    aes(y = data_mean),
    fill  = "steelblue",
    alpha = 0.5,
    width = 0.6,
    position = position_nudge(x = -0.2)
  ) +
  geom_point(
    aes(y = model_pred),
    color = "darkred",
    size  = 3,
    position = position_nudge(x = 0.2)
  ) +
  geom_line(
    aes(y = model_pred, group = 1),
    color     = "darkred",
    linewidth = 1.0,
    position  = position_nudge(x = 0.2)
  ) +
  facet_wrap(~ scenario_name, scales = "free_x") +
  labs(
    title = "Data (bar) vs. Model (point+line)",
    subtitle = "Cumulative update timing, η=1.15, γ=0.58",
    x = "Cause Set",
    y = "Rating"
  ) +
  theme_minimal(base_size = 14) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

print(p_comparison_2_cumulative)

wanted_order <- c("A","B","C","D",            # singles
                  "A,B","A,C","B,C","C,D",    # pairs
                  "A,B,C","A,B,D")            # triples

## ───────────────────────────────────────────────────────────────
## a) participants: mean judgment per cause set (0–1 scale, no CI)
## ───────────────────────────────────────────────────────────────
participant_summary <- df_judgments %>% 
  filter(scenario_name == "overdetermined positive",
         cause_set      %in% wanted_order) %>% 
  transmute(
    cause_set,
    mean  = mean_judgment)

## ───────────────────────────────────────────────────────────────
## b) model predictions – we already have them in
##    ‘comparison_df_cumulative’
## ───────────────────────────────────────────────────────────────
model_df <- comparison_df_cumulative %>% 
  filter(scenario_name == "overdetermined positive",
         cause_set      %in% wanted_order) %>% 
  select(cause_set, model_pred)

## If your model is on [0,1] but the scale is 1–9, rescale like this:
## model_df <- model_df %>% mutate(model_pred = model_pred*8 + 1)

## ───────────────────────────────────────────────────────────────
## c) merge + lock the x-axis order
## ───────────────────────────────────────────────────────────────
plot_df <- participant_summary %>% 
  left_join(model_df, by = "cause_set") %>% 
  mutate(cause_set = factor(cause_set, levels = wanted_order))


p_overdet <- ggplot(plot_df, aes(x = cause_set)) +

  ## participants ────────────────────────────────────────────────
  geom_point(aes(y = mean), colour = "black", size = 2) +
  geom_line (aes(y = mean, group = 1),
             colour = "black", linewidth = .8) +

  ## model ───────────────────────────────────────────────────────
  geom_point(aes(y = model_pred), colour = "red",  size = 3) +
  geom_line (aes(y = model_pred, group = 1),
             colour = "red",  linewidth = 1.2) +

  ## cosmetics ───────────────────────────────────────────────────
  scale_y_continuous(limits = c(0, 1), breaks = 0:1) +
  labs(x = NULL, y = "Prediction / Mean Response",
       title = "Overdetermined positive: participants vs. model") +
  theme_minimal(base_size = 18) +
  theme(axis.text.x        = element_text(angle = 45, hjust = 1, size = 16),
        panel.grid.major.x = element_blank(),
        axis.line          = element_line(linewidth = .8),
        axis.ticks.length  = unit(0.15, "cm"),
        legend.position    = "none")

print(p_overdet)
# ------------------------------------------------------------------
# a) participant means   (multiply by 8, then add 1)
# ------------------------------------------------------------------
participant_summary <- df_judgments %>% 
  filter(scenario_name == "overdetermined positive",
         cause_set      %in% wanted_order) %>% 
  transmute(
    cause_set,
    mean = mean_judgment * 8 + 1      # ← now on 1-to-9 scale
  )

# ------------------------------------------------------------------
# b) model predictions   (same rescale)
# ------------------------------------------------------------------
model_df <- comparison_df_cumulative %>% 
  filter(scenario_name == "overdetermined positive",
         cause_set      %in% wanted_order) %>% 
  select(cause_set, model_pred) %>% 
  mutate(model_pred = model_pred * 8 + 1)   # ← 1-to-9

# ------------------------------------------------------------------
# c) merge + lock x-axis order (unchanged)
# ------------------------------------------------------------------
plot_df <- participant_summary %>% 
  left_join(model_df, by = "cause_set") %>% 
  mutate(cause_set = factor(cause_set, levels = wanted_order))

# ------------------------------------------------------------------
# d) plot
# ------------------------------------------------------------------
p_overdet <- ggplot(plot_df, aes(x = cause_set)) +
  geom_point(aes(y = mean),        colour = "black", size = 2) +
  geom_line (aes(y = mean, group = 1),
             colour = "black", linewidth = .8) +
  geom_point(aes(y = model_pred),  colour = "red",   size = 3) +
  geom_line (aes(y = model_pred, group = 1),
             colour = "red",  linewidth = 1.2) +
  scale_y_continuous(limits = c(1, 9), breaks = 1:9) +  # ← new range
  labs(x = NULL, y = "Prediction / Mean Response") +
  theme_minimal(base_size = 18) +
  theme(axis.text.x        = element_text(angle = 45, hjust = 1, size = 16),
        panel.grid.major   = element_blank(),
        panel.grid.minor   = element_blank(),
        axis.line          = element_line(linewidth = .8),
        axis.ticks.length  = unit(0.15, "cm"),
        legend.position    = "none")

print(p_overdet)
ggsave("overdet_positive_lines.png", p_overdet,
       width = 9, height = 5, dpi = 300)
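The `* 8 + 1` rescaling above inverts the `/ 8` used to build `scaled_val`, assuming (as the “0..8” comment in the Exp2 pipeline suggests) that `response` stores the 0-based index of the 9-point scale. A minimal sketch of the two mappings; `to_unit()` and `to_likert()` are hypothetical helper names, not functions from the original code:

```r
# Assumption: responses are stored as 0-based indices 0..8 of a 9-point scale
to_unit   <- function(response) as.numeric(response) / 8   # 0..8 -> [0, 1]
to_likert <- function(unit)     unit * 8 + 1                # [0, 1] -> 1..9

responses <- c("0", "4", "8")
unit_vals <- to_unit(responses)    # 0.0, 0.5, 1.0
likert    <- to_likert(unit_vals)  # 1, 5, 9
```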

2.1.3.2 Fitting in the AsWeGo Condition

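Next, we turn to the “AsWeGo” strategy, in which weight updates are applied as sampling proceeds rather than once at the end. The best-fitting parameters now become:

  • \(\eta \approx 0.13\)
  • \(\gamma \approx 0.56\)
  • Update timing = "AsWeGo"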
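Both searches fix \(n=20\) and \(w=1\), so each grid spans only \(\eta\) and \(\gamma\). Counting combinations, with the vectors copied from the two code cells, makes the cost of a search concrete: every combination calls `run_model_for_scenario()` once per scenario, each time with `num_simulations = 500`, so letting \(n\) vary over \(k\) values as well would multiply the cost by \(k\). The count also shows that the AsWeGo \(\gamma\) grid stops at 0.64 rather than 0.65.

```r
# Grids copied from the two searches (n = 20 and w = 1 fixed in both)
cumulative_grid <- expand.grid(
  learning_rate = seq(0.5, 1.5, by = 0.05),
  gamma         = seq(0.4, 0.7, by = 0.02)
)
aswego_grid <- expand.grid(
  learning_rate = seq(0.1, 0.25, by = 0.01),
  gamma         = seq(0.5, 0.65, by = 0.02)
)

nrow(cumulative_grid)   # 21 x 16 = 336 combinations
nrow(aswego_grid)       # 16 x 8  = 128 combinations
max(aswego_grid$gamma)  # 0.64: with step 0.02, seq() stops before 0.65
```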
##############################################################################
## 6) Usage example: searching for max correlation, then plotting
##############################################################################


# We'll define param sets:

n_samples_vec <- 20
lr_vec        <- seq(0.1,0.25, by = 0.01)
w_vec         <- 1
update_vec    <- c("AsWeGo")  # add "Cumulative" here to search both strategies
gamma_vec     <- seq(0.5, 0.65, by = 0.02)

res_parallel_corr <- parallel_search_over_params_correlation(
  df_judgments,
  n_samples_set   = n_samples_vec,
  lr_set          = lr_vec,
  w_set           = w_vec,
  update_timing_set=update_vec,
  gamma_set       = gamma_vec,
  n_cores         = 4
)

cat("Best param:\n")
print(res_parallel_corr$best_params)
cat("Best correlation =", res_parallel_corr$best_correlation, "\n")

df_res_corr <- res_parallel_corr$results

# Quick visualization: correlation as a function of learning rate, coloured by gamma

df_res_corr_plot <- df_res_corr %>% filter(n_samples == 20, gamma < 0.81, w == 1)
p_corr <- ggplot(df_res_corr_plot, aes(x=learning_rate, y=correlation, color=gamma, shape=as.factor(w))) +
  geom_point(size=3) +
  facet_wrap(~n_samples) +
  #facet_grid(n_samples ~ gamma, labeller=label_both) +
  labs(title="Grid Search: correlation measure",
       subtitle="Higher correlation is better") +
  scale_color_viridis_c(option="plasma", direction=-1) +
  theme_minimal(base_size=14)
print(p_corr)
##############################################################################
## 7) Final step: for each scenario, produce a side-by-side plot:
##    - data means vs. model predictions for the best param.
##############################################################################

## a) Extract the best param
#best_p <- as.list(res_parallel_corr$best_params[1, ])
# e.g. best_p$n_samples, best_p$learning_rate, best_p$w, best_p$update_timing, best_p$gamma
best_p <- list(
  n_samples     = 20,        # e.g., your chosen sample size
  learning_rate = 0.13,      # approximate best-fit η
  w             = 1,         # typically we fix w=1
  update_timing = "AsWeGo",
  gamma         = 0.56       # approximate best-fit γ
)
## b) We'll re-run the model for each scenario to get predicted kappa
##    and compare with the data's mean_judgment.


# define a function that merges the data and model predictions for each scenario
get_comparison_df <- function(df_judgments, best_p) {
  # We'll produce a data frame with columns:
  #   scenario_name, cause_set, data_mean, model_pred
  # for each scenario_name

  scenario_list <- unique(df_judgments$scenario_name)
  out_list <- list()

  for (scn in scenario_list) {
    info <- get_scenario_info(scn)
    # run model
    preds <- run_model_for_scenario(
      clauses             = info$clauses,
      input_probabilities = info$input_probabilities,
      initial_state       = info$initial_state,
      n_samples           = best_p$n_samples,
      learning_rate       = best_p$learning_rate,
      w                   = best_p$w,
      update_timing       = best_p$update_timing,
      gamma               = best_p$gamma
    )

    df_sub <- df_judgments[df_judgments$scenario_name==scn, ]
    # for each cause_set in df_sub, retrieve model
    # build small df with columns scenario_name, cause_set, data_mean, model_pred
    model_vec <- numeric(nrow(df_sub))
    for (i in seq_len(nrow(df_sub))) {
      cs <- df_sub$cause_set[i]
      model_vec[i] <- if (!is.na(preds[cs])) preds[cs] else NA
    }
    tmp <- data.frame(
      scenario_name = scn,
      cause_set     = df_sub$cause_set,
      data_mean     = df_sub$mean_judgment,
      model_pred    = model_vec
    )
    out_list[[length(out_list)+1]] <- tmp
  }
  do.call(rbind, out_list)
}

comparison_df <- get_comparison_df(df_judgments, best_p)

# We now have scenario_name, cause_set, data_mean, model_pred

## c) Plot for each scenario: a bar for data vs a "line" or "point" for model
##    We'll do a side-by-side or overlaid approach.

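# Approach 1: reshape data_mean and model_pred into long format for side-by-side bars
comparison_long <- comparison_df %>%
  tidyr::pivot_longer(cols = c(data_mean, model_pred),
                      names_to = "type", values_to = "value")

# Grouped bar chart: x = cause_set, fill = type (data_mean vs. model_pred),
# faceted by scenario_name
p_comparison_1 <- ggplot(comparison_long, aes(x = cause_set, y = value, fill = type)) +
  geom_col(position = position_dodge(width = 0.8), width = 0.7) +
  facet_wrap(~ scenario_name, scales = "free_x") +
  labs(title = "Data vs. Model (grouped bars)", x = "Cause Set", y = "Rating") +
  theme_minimal(base_size = 14) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

print(p_comparison_1)
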
## approach 2: 
##   - a bar for data, 
##   - a line or point for model in the same plot

p_comparison_2 <- ggplot(comparison_df, aes(x=cause_set)) +
  geom_col(aes(y=data_mean), fill="steelblue", alpha=0.5,
           width=0.6, position=position_nudge(x=-0.2)) +
  geom_point(aes(y=model_pred), color="darkred", size=3,
             position=position_nudge(x=+0.2)) +
  geom_line(aes(y=model_pred, group=1), color="darkred",
            position=position_nudge(x=+0.2)) +
  facet_wrap(~ scenario_name, scales="free_x") +
  labs(
    title = "Data (bar) vs. Model (point+line)",
    subtitle = "AsWeGo update timing, η=0.13, γ=0.56",
    x = "Cause Set",
    y = "Rating"
  ) +
  theme_minimal(base_size=14) +
  theme(axis.text.x=element_text(angle=45, hjust=1))

print(p_comparison_2)

##############################################################################
## In this code cell, we demonstrate two things:
## (1) How to visualize the model's predictions with the color-coded stacked-bar approach
##     (plot_causal_impact_scores()), using the best-fit parameters found for each 
##     update timing (Cumulative vs. AsWeGo).
## (2) How to compute a linear-model fit on *individual-participant* data—rather 
##     than on means alone—and then extract log-likelihood, AIC, and BIC for 
##     that model. We illustrate it for the “Cumulative” best-fit, and you 
##     can adapt the same procedure to “AsWeGo”.
##############################################################################

#------------------------------------------------------------------
# 1) Suppose we have best-fit parameters for Cumulative and AsWeGo
#    found via our correlation-based grid search:
#------------------------------------------------------------------
best_param_cumulative <- list(
  n_samples      = 20,
  learning_rate  = 1.15,
  w              = 1,
  update_timing  = "Cumulative",
  gamma          = 0.58
)

best_param_aswego <- list(
  n_samples      = 20,
  learning_rate  = 0.13,
  w              = 1,
  update_timing  = "AsWeGo",
  gamma          = 0.56
)


runent_model_for_scenario <- function(clauses, input_probabilities, initial_state,
                                   n_samples, learning_rate, w,
                                   update_timing="Cumulative",
                                   gamma=1) {
  # 1) Generate the sampling network
  sampling_network <- generate_network_from_logic_program(
    clauses              = clauses,
    input_probabilities  = input_probabilities,
    plot_network         = FALSE
  )
  net_ent <- sampling_network$network
  init_state <- generate_initial_hidden_states(initial_state, net_ent)

  # 2) Run the simulation
  result <- run_simulation(
    network               = sampling_network,
    initial_state         = init_state,
    input_values          = init_state,
    num_simulations       = 500,
    n_samples             = n_samples,
    sampling_method       = "MH",
    weight_update_timing  = update_timing,
    learning_rate         = learning_rate,
    weight_update_rule    = "Additive",
    w                     = w,
    return_samples        = FALSE
  )

  # 3) Extract and transform the raw kappa values
  preds <- result$impact_scores        # data.frame with columns: C, Z, sum_Rc, kappa
  preds$kappa <- preds$kappa ^ gamma
  
  # 4) Convert from "A+B" notation to "A,B" notation
  #    so that merging with df_individual is direct.
  #    The simplest approach is a global gsub on the C column:
  preds$C <- gsub("\\+", ",", preds$C)

  # Return a named vector, e.g. "A,B" => kappa
  setNames(preds$kappa, preds$C)
}

#------------------------------------------------------------------
# 2) We'll define a small helper to run the model for a 
#    scenario_name using these best-fit parameters. This 
#    function returns a "result_list"-like object so we can 
#    call plot_causal_impact_scores() on it.
#------------------------------------------------------------------
create_result_list_for_plot <- function(
  scenario_name,
  n_samples,
  learning_rate,
  w,
  update_timing,
  gamma
) {
  info <- get_scenario_info(scenario_name)  # must be defined in your environment
  # Re-run the model via runent_model_for_scenario(), which returns a named
  # vector (cause_set => kappa). We then reshape it into a data.frame that mimics
  # run_simulation()$impact_scores (columns C, kappa, Z, sum_Rc). The wrapper does
  # not return Z or sum_Rc, so both are set to 1 as placeholders for plotting.
  preds_vec <- runent_model_for_scenario(
    clauses             = info$clauses,
    input_probabilities = info$input_probabilities,
    initial_state       = info$initial_state,
    n_samples           = n_samples,
    learning_rate       = learning_rate,
    w                   = w,
    update_timing       = update_timing,
    gamma               = gamma
  )

  # Convert preds_vec => data.frame(C, kappa). We'll guess Z=1, sum_Rc=1 for plotting.
  df_impact <- data.frame(
    C     = names(preds_vec),
    kappa = as.numeric(preds_vec),
    Z     = 1,       # for a stacked-bar approach
    sum_Rc= 1        # placeholder
  )

  # Return a minimal "result_list" containing what plot_causal_impact_scores()
  # needs:
  #   - $impact_scores     => df_impact
  #   - $importance_scores => empty placeholder (no partial decomposition here)
  #   - $num_samples       => number of samples used

  result_list <- list(
    impact_scores    = df_impact,
    importance_scores= numeric(0),  # we won't do partial decomposition
    num_samples      = n_samples
  )
  result_list
}

#------------------------------------------------------------------
# 3) Example: produce color-coded stacked-bar plots for 
#    "overdetermined positive" in Exp2, comparing 
#    best_param_cumulative vs best_param_aswego.
#------------------------------------------------------------------
# We'll define a function to plot each scenario with a custom label.
plot_scenario_with_best_params <- function(scenario_name, 
                                           param, 
                                           logic_program_label="Model Predictions") {
  result_list <- create_result_list_for_plot(
    scenario_name = scenario_name,
    n_samples     = param$n_samples,
    learning_rate = param$learning_rate,
    w             = param$w,
    update_timing = param$update_timing,
    gamma         = param$gamma
  )
  # Retrieve the scenario's observed input values and probabilities so that
  # the color-coded legend matches the scenario. Exp2 uses input nodes
  # A, B, C, D and output node E, which are hard-coded below.
  info <- get_scenario_info(scenario_name)
  
  p <- plot_causal_impact_scores(
    result_list           = result_list,
    input_values          = info$initial_state,
    input_probabilities   = info$input_probabilities,
    logic_program_label   = logic_program_label,
    num_samples           = param$n_samples,
    w                     = param$w,
    initial_state         = info$initial_state,
    input_nodes           = c("A","B","C","D"),  # in Exp2
    output_nodes          = c("E")
  )
  p
}

# For demonstration: let's choose one scenario from Exp2,
# e.g. "overdetermined positive"
p_cum <- plot_scenario_with_best_params(
  scenario_name = "overdetermined positive",
  param         = best_param_cumulative,
  logic_program_label = "Cumulative Fit Predictions"
)
p_asw <- plot_scenario_with_best_params(
  scenario_name = "overdetermined positive",
  param         = best_param_aswego,
  logic_program_label = "AsWeGo Fit Predictions"
)

# Arrange the two plots side by side (the `+` operator for ggplots comes from patchwork):
library(patchwork)
p_cum + p_asw


#------------------------------------------------------------------
# 4) Now we compute the linear-model fits on *individual* data 
#    for that same scenario, using the best-fit param (Cumulative
#    or AsWeGo). We'll illustrate with "Cumulative" here, 
#    but you can replicate for "AsWeGo".
#------------------------------------------------------------------

# a) We'll build a new data frame that has:
#    - subject_ID
#    - cause_set
#    - scaled_val (the individual's response)
#    - model_pred (the model's predicted rating for that cause_set)
#    from the best-param setting

## We'll define a function that returns "model_pred" for each cause_set, 
## then we can left_join participant data.

get_predictions_for_individuals <- function(df_raw, param) {
  # df_raw is the individual's data with columns:
  #   subject_ID, cause_set, scenario_name, scaled_val, etc.
  # param => best_param_xxx
  # We'll run the scenario model for each scenario_name, store predictions in a 
  # small data frame, then merge with df_raw by (scenario_name, cause_set).
  
  scenario_list <- unique(df_raw$scenario_name)
  
  pred_list <- list()
  for (scn in scenario_list) {
    info <- get_scenario_info(scn)
    preds_vec <- runent_model_for_scenario(
      clauses             = info$clauses,
      input_probabilities = info$input_probabilities,
      initial_state       = info$initial_state,
      n_samples           = param$n_samples,
      learning_rate       = param$learning_rate,
      w                   = param$w,
      update_timing       = param$update_timing,
      gamma               = param$gamma
    )
    # Turn that into a df: scenario_name, cause_set, model_pred
    df_tmp <- data.frame(
      scenario_name = scn,
      cause_set     = names(preds_vec),
      model_pred    = as.numeric(preds_vec)
    )
    pred_list[[length(pred_list)+1]] <- df_tmp
  }
  df_pred <- bind_rows(pred_list)
  
  # Now merge with df_raw
  df_merged <- df_raw %>%
    left_join(df_pred, by=c("scenario_name","cause_set"))
  
  df_merged
}


# b) Define df_individual (one row per participant response in Exp2):

df_individual <- exp2_clean %>%
  rename(
    scenario_name = trial_kind
  ) %>%
  select(subject_ID, scenario_name, cause_set, scaled_val) %>%
  filter(!is.na(cause_set))

# Now we get the predictions:
df_individual_pred <- get_predictions_for_individuals(df_individual, best_param_cumulative)

# c) Fit a linear model without intercept: scaled_val ~ 0 + model_pred,
#    then extract its log-likelihood for the AIC/BIC computation below.
model_cumulative <- lm(scaled_val ~ 0 + model_pred, data=df_individual_pred)

summary(model_cumulative)
logLik_cumulative <- as.numeric(logLik(model_cumulative))
cat("Log-likelihood (Cumulative Model) =", logLik_cumulative,"\n")

# d) Suppose we define K=2 as we have 2 free parameters: (learning_rate, gamma)
#    Then manual AIC = 2*K -2* logLik, BIC= log(N)*K -2* logLik
n_obs    <- nrow(df_individual_pred)
K        <- 2
AIC_cum   <- 2*K - 2*logLik_cumulative
BIC_cum   <- log(n_obs)*K - 2*logLik_cumulative

cat("AIC (Cumulative) =", AIC_cum,"\n")
cat("BIC (Cumulative) =", BIC_cum,"\n")

# e) If you want to do the same for "AsWeGo", just define:
# df_individual_pred_asw <- get_predictions_for_individuals(df_individual, best_param_aswego)
# model_aswego <- lm(scaled_val ~ 0 + model_pred, data=df_individual_pred_asw)
# etc.
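As a quick sanity check on the manual formulas: lm() with no intercept estimates one slope plus the residual standard deviation, so stats::logLik() reports df = 2 for such a fit. With K = 2, the manual AIC and BIC therefore coincide numerically with stats::AIC() and stats::BIC(), even though K is meant to count learning_rate and gamma. The sketch below uses simulated stand-ins (hypothetical model_pred and scaled_val values) rather than df_individual_pred.

```r
set.seed(42)
model_pred <- runif(50)                                # hypothetical model predictions
scaled_val <- 0.9 * model_pred + rnorm(50, sd = 0.1)   # hypothetical responses
fit <- lm(scaled_val ~ 0 + model_pred)                 # no-intercept fit, as above

ll <- as.numeric(logLik(fit))
K  <- 2
n  <- length(scaled_val)
aic_manual <- 2 * K - 2 * ll
bic_manual <- log(n) * K - 2 * ll

attr(logLik(fit), "df")                       # 2: slope + residual SD
c(manual = aic_manual, builtin = AIC(fit))    # identical values
c(manual = bic_manual, builtin = BIC(fit))    # identical values
```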

2.3 Intuitive Explanations of Model Predictions

2.3.1 1. General Observations

Our model achieves a strong fit to the data. With the best-fit parameters (AsWeGo: \(\eta \approx 0.13\), \(\gamma \approx 0.56\); Cumulative: \(\eta \approx 1.15\), \(\gamma \approx 0.58\)), the model attains correlations with human judgments exceeding \(r = 0.85\) across experimental conditions. The correlations and AIC/BIC values are computed and displayed in the code output sections above.

This performance constitutes a strong argument in favor of our program-based approach. Below we give an intuitive explanation of the main patterns in the model's predictions.


2.3.2 2. Positive conditions

2.3.2.1 2.1 Overdetermined positive condition

2.3.2.1.1 2.1.1 Singular Variables: Abnormal Inflation & Deflation

In the Overdetermined positive condition, the player drew a colored ball from all four urns (\(A=1, B=1, C=1, D = 1\)). As a result, they won the round (\(E = 1\)).

  • Variables with fewer colored balls (\(B\), \(C\)) often receive higher causal importance than their partners (\(A\), \(D\)) within the same conjunct of \((A \land B) \lor (C \land D)\). Intuitively, the model “inflates” the importance of the more “abnormal” event on each side of the disjunction.
  • However, variable \(C\) (the less probable, i.e. more abnormal, of \(C\) and \(D\)) can paradoxically end up rated lower than \(A\). This is because \(A\) connects to the hidden neuron \(H_{ab}\), which corresponds to the clause that is slightly more often “activated” given the overall probabilities. More relevance therefore flows to \(H_{ab}\) than to \(H_{cd}\), and \(A\) inherits part of that credit.
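The way \(A\) inherits credit from \(H_{ab}\) can be illustrated with the standard LRP-\(\epsilon\) rule, which redistributes each unit's relevance to its inputs in proportion to their contributions \(a_j w_{jk}\). The sketch below is a toy, not the dissertation's network: the weights are hypothetical, units are linear/ReLU without biases, and the larger weight from \(H_{ab}\) to \(E\) simply stands in for \(H_{ab}\) being activated more often.

```r
# LRP-epsilon rule for one dense layer without bias:
#   R_j = a_j * sum_k w_jk * R_k / (z_k + eps * sign(z_k)),  z_k = sum_j a_j w_jk
lrp_epsilon <- function(a, W, R_upper, eps = 1e-9) {
  z <- as.vector(t(W) %*% a)                         # upper-layer pre-activations
  s <- R_upper / (z + eps * ifelse(z >= 0, 1, -1))   # stabilized relevance ratios
  R <- as.vector(a * (W %*% s))                      # redistribute to lower layer
  names(R) <- rownames(W)
  R
}

# Hypothetical weights for E = (A AND B) OR (C AND D); all four inputs observed = 1
x  <- c(A = 1, B = 1, C = 1, D = 1)
W1 <- matrix(c(0.6, 0.9, 0.0, 0.0,    # weights into H_ab
               0.0, 0.0, 0.7, 0.5),   # weights into H_cd
             nrow = 4, dimnames = list(names(x), c("H_ab", "H_cd")))
W2 <- matrix(c(1.3, 0.7), nrow = 2, dimnames = list(c("H_ab", "H_cd"), "E"))

h <- pmax(0, as.vector(t(W1) %*% x))   # hidden activations (ReLU): 1.5, 1.2
E <- sum(h * W2[, "E"])                 # output activation: 2.79

R_h <- lrp_epsilon(h, W2, R_upper = E)      # relevance of H_ab, H_cd
R_x <- lrp_epsilon(x, W1, R_upper = R_h)    # relevance of A, B, C, D
round(R_x, 3)                               # A 0.78, B 1.17, C 0.49, D 0.35
```

Even though \(C\)'s weight into \(H_{cd}\) (0.7) exceeds \(A\)'s weight into \(H_{ab}\) (0.6), \(A\) ends up with more relevance because \(H_{ab}\) receives the larger share at the output. The input relevances also sum to the output activation (up to the small \(\epsilon\) stabilizer), as LRP's conservation property requires.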
2.3.2.1.2 2.1.2 Plural Causes

Several groupings of variables also stand out:

  • Cross-disjunction sets, such as \(A \land C\) or \(B \land D\), receive lower ratings than \(A \land B\) or \(C \land D\). This is a direct consequence of procedural complexity: cross-disjunction pairs exert their impact on the outcome through two different branches, whereas same-branch pairs contribute through a single branch of the network.

  • Triples (e.g., \(A \land B \land C\)) also incur an extra penalty: they are more complex than the same-branch pairs they contain, though no more complex than the cross-disjunction pairs they contain. As a result, they land at an intermediate rating in the model, consistent with participants’ own judgments.
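The branch-counting intuition can be made concrete with a simple proxy: count how many disjuncts of \((A \land B) \lor (C \land D)\) a cause set draws on. This count is a hypothetical illustration, not the model's own measure of procedural complexity; cause sets use the "A,B" notation from the code above.

```r
# Hypothetical complexity proxy: number of branches (disjuncts) a cause set touches
branches <- list(H_ab = c("A", "B"), H_cd = c("C", "D"))

branches_touched <- function(cause_set, branches) {
  causes <- strsplit(cause_set, ",", fixed = TRUE)[[1]]
  sum(vapply(branches, function(b) any(causes %in% b), logical(1)))
}

cause_sets <- c("A,B", "C,D", "A,C", "B,D", "A,B,C")
res <- sapply(cause_sets, branches_touched, branches = branches)
res   # A,B 1; C,D 1; A,C 2; B,D 2; A,B,C 2
```

Same-branch pairs touch one branch, while cross-disjunction pairs and the triple touch two, so the triple is no more complex than the cross-disjunction pairs it contains.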


2.3.3 3. Triple 1 Condition

In the \(\textbf{Triple 1}\) condition, three of the four urns are colored (\(A=1, B=1, C=1\)), but \(D=-1\). The outcome is still positive.

  • Singulars: Our model again highlights the more “abnormal” event \(B\). Meanwhile, \(D\) is essentially irrelevant to the observed outcome (since \(D=-1\), it does not help produce the win), so it is assigned a near-zero rating. Crucially, the model does not exclude it by any “first-phase filter”: the credit assigned through an inactive path is simply negligible.
  • Plural causes:
    1. \(A \land B\) can reach a near-ceiling rating because it is fully responsible for the “successful” branch in the network (the hidden neuron for \(A,B\) saturates).
    2. Adding \(D\) to any cause set lowers the set’s rating, because \(D\) provides no additional path to the outcome but increases procedural complexity. The model treats “\(A \land B \land D\)” or “\(B \land C \land D\)” as more complex for no net gain, which lowers their final scores.
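The clause structure behind these observations can be checked directly by evaluating \(E = (A \land B) \lor (C \land D)\) under the Triple 1 observation. The sketch uses the ±1 coding from the text; path_profile() is a hypothetical helper that counts how many branches a cause set touches and how many of those actually produced the win.

```r
# Clauses of E = (A AND B) OR (C AND D); +1 = colored ball drawn, -1 = not drawn
clauses <- list(H_ab = c("A", "B"), H_cd = c("C", "D"))
obs     <- c(A = 1, B = 1, C = 1, D = -1)          # Triple 1 observation

# A clause is active when all of its literals are observed as +1
clause_active <- vapply(clauses, function(lits) all(obs[lits] == 1), logical(1))
E <- if (any(clause_active)) 1 else -1

# Hypothetical helper: branches a cause set touches, and how many are active
path_profile <- function(causes) {
  touched <- vapply(clauses, function(lits) any(causes %in% lits), logical(1))
  c(touched = sum(touched), active = sum(touched & clause_active))
}

clause_active                   # H_ab TRUE, H_cd FALSE
E                               # 1
path_profile(c("A", "B"))       # touched 1, active 1
path_profile(c("A", "B", "D"))  # touched 2, active 1
```

Adding \(D\) to \(\{A, B\}\) raises the number of branches touched from one to two without adding an active path, which is the sense in which it adds complexity for no net gain.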

2.3.4 4. Negative Conditions & the Role of Negation

For the Overdetermined Negative and Triple 0 conditions, the outcome is absent \((E\approx -1)\). This changes how the procedural complexity of explanations for that outcome should be assessed. As explained in Section 4.5 of the dissertation, deriving \(\neg E\) in a general logic program means (1) failing to derive \(E\) from any of the clauses of the program, and (2) applying a special operator in our model that takes us from this failure to derive \(E\) to a derivation of its negation \(\neg E\). This perspective changes how judgments are computed in two ways: (1) it shifts credit attribution from positive to negative, so that events that contribute to the occurrence of the outcome receive negative rather than positive credit for it; (2) it removes the complexity penalty that applied in the positive case: since failing to derive \(E\) requires trying every path through which \(E\) could in principle have been derived, no explanation is more straightforward than any other.
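The second point can be illustrated with a textbook sketch of negation as failure over the clauses of \(E = (A \land B) \lor (C \land D)\): a success can stop at the first satisfied clause, whereas \(\neg E\) is only licensed once every clause has been tried and has failed. prove_E() and the all-negative observation are hypothetical illustrations, not the operator implemented in the model.

```r
clauses <- list(H_ab = c("A", "B"), H_cd = c("C", "D"))

# Try the clauses for E in order. Success stops at the first satisfied clause;
# failure (which licenses not-E under negation as failure) is only established
# after every clause has been tried.
prove_E <- function(obs, clauses) {
  tried <- 0
  for (lits in clauses) {
    tried <- tried + 1
    if (all(obs[lits] == 1)) return(list(E = TRUE, not_E = FALSE, clauses_tried = tried))
  }
  list(E = FALSE, not_E = TRUE, clauses_tried = tried)
}

pos <- prove_E(c(A = 1, B = 1, C = 1, D = 1), clauses)      # E after 1 clause
neg <- prove_E(c(A = -1, B = -1, C = -1, D = -1), clauses)  # not-E after both clauses
```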

2.3.4.1 4.1 Overdetermined Negative

  • We see a reversal of abnormal inflation and deflation across variables. If credit were assigned as in the positive conditions, urns \(B\) and \(C\) would have received the highest scores; instead, our model favors urns \(A\) and \(D\), in line with participants’ judgments.

  • Similarly, consistent with the removal of the complexity penalty, the predictions in this condition follow a “the more variables, the better” trend, which again aligns with participants’ judgments.

2.3.4.2 4.2 Triple 0

We observe the same dynamics here as in the Overdetermined Negative case, for the same reason. Two points deserve particular attention, however:

  • Our model explains why the plural cause \((A\land B)\) can receive more credit than either \(A\) or \(B\) alone, even though either one of \(A\) or \(B\) would have sufficed to explain the outcome here.
  • The variable \(D=-1\) receives a particularly high score here: it is single-handedly responsible for making the \((C\land D)\) path inactive (since \(C = 1\)), and therefore inherits all of the relevance flowing through that path. This shows that removing the complexity penalty in the negative conditions does not make all effects of causal structure disappear.
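Both points follow from which literals block each clause under the Triple 0 observation implied above (\(A=-1, B=-1, C=1, D=-1\)). The sketch lists the blockers of each clause and then applies a hypothetical equal-split rule (not the model's relevance propagation) in which each blocked clause passes one unit of credit for \(\neg E\), shared equally among its blockers.

```r
clauses <- list(H_ab = c("A", "B"), H_cd = c("C", "D"))
obs     <- c(A = -1, B = -1, C = 1, D = -1)   # Triple 0, as implied by the text

# Literals that block each clause (observed as -1)
blockers <- lapply(clauses, function(lits) lits[obs[lits] == -1])

# Hypothetical equal-split rule: each blocked clause passes one unit of
# credit for not-E, shared equally among that clause's blockers
share <- unlist(unname(lapply(blockers, function(b)
  setNames(rep(1 / length(b), length(b)), b))))

blockers   # H_ab: "A" "B"; H_cd: "D"
share      # A 0.5, B 0.5, D 1
```

Under this toy rule, \(D\) collects the whole share of the \((C \land D)\) path, while \(A\) and \(B\) each get half of the \((A \land B)\) share, so the plural cause \(A \land B\) outscores either of its members.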
