Next, we investigate how our model’s predictions shift as we vary (in
both the “Cumulative” and “AsWeGo” version) its two main free
parameters: 1. The number of samples (\(n\)) collected in each simulation (shown on
a log scale). 2. The learning rate (\(\eta\)), which controls the magnitude of
each weight update.
We do so by fixing the probabilities of the focal and alternate cause
to 0.1 and 0.9, respectively, and then tracking how \(\kappa(A,E)\) varies with changes in those
parameters:
# ================================================================
# Exploring variation with learning rate and number of samples
# Each grid cell runs num_sim = 10000 / n_samples simulations (set in
# the loop below), keeping the total sampling budget constant per cell.
# Results cached via knitr cache to avoid re-running
# ================================================================
structure_choices <- c("conjunctive","disjunctive") # or pick one
pA_fixed <- 0.1  # probability of the focal (abnormal) cause A
pB_fixed <- 0.9  # probability of the alternate (normal) cause B
update_timing_vec2 <- c("Cumulative","AsWeGo")
learning_rate_vec2 <- c(0.1, 0.25, 0.5, 0.75, 1.0)
# Roughly geometric spacing so the values are evenly spread on a log axis
n_samples_vec2 <- c(10,20,50,100,200,500)
#num_sim_fixed2 <- 100 # or we can vary it
#### We define a new function that loops over structure, lr, n_samples, etc.
#### Run one parameter-grid cell: build the minimal logic program for
#### `strct`, simulate, and return a one-row data.frame with kappa(A,E).
explore_lr_nsamps_case <- function(strct, pA, pB, ut, lr, nsamp, num_sim) {
  program_clauses <- build_minimal_program(strct)
  input_probs <- list(A = pA, B = pB)
  net <- generate_network_from_logic_program(
    clauses = program_clauses,
    input_probabilities = input_probs,
    plot_network = FALSE
  )
  # Actual world: both causes present and the effect occurred
  observed_values <- list(A = 1, B = 1, E = 1)
  hidden_state <- generate_initial_hidden_states(observed_values, net$network)
  sim_result <- run_simulation(
    network = net,
    initial_state = hidden_state,
    input_values = observed_values,
    num_simulations = num_sim,
    n_samples = nsamp,
    sampling_method = 'MH',
    weight_update_timing = ut,
    learning_rate = lr,
    weight_update_rule = 'Additive',
    return_samples = FALSE
  )
  # kappa(A,E): impact score of the focal cause A (NA if the row is missing)
  focal_row <- sim_result$impact_scores[sim_result$impact_scores$C == "A", ]
  kappa_focal <- if (nrow(focal_row) == 1) focal_row$kappa else NA_real_
  data.frame(
    structure = strct,
    pA = pA, pB = pB,
    update_timing = ut,
    learning_rate = lr,
    n_samples = nsamp,
    num_sim = num_sim,
    kappaA = kappa_focal
  )
}
#### We'll gather results for each structure & updateTiming,
#### while scanning over the learning_rate_vec2 & n_samples_vec2
#### Collect one row per (structure, update timing, learning rate,
#### n_samples) combination over learning_rate_vec2 x n_samples_vec2.
df_part2 <- list()
slot <- 1
for (structure_kind in structure_choices) {
  print(structure_kind)
  for (timing in update_timing_vec2) {
    print(timing)
    for (eta in learning_rate_vec2) {
      for (chain_length in n_samples_vec2) {
        # Keep the total sampling budget constant across cells
        sims_per_cell <- 10000/chain_length
        df_part2[[slot]] <- explore_lr_nsamps_case(
          strct = structure_kind,
          pA = pA_fixed,
          pB = pB_fixed,
          ut = timing,
          lr = eta,
          nsamp = chain_length,
          num_sim = sims_per_cell
        )
        slot <- slot + 1
      }
    }
  }
}
df_part2 <- bind_rows(df_part2)
#### STEP 2: Plot them
#### We'll produce, for each structure, separate plots for update_timing maybe.
#### x-axis= n_samples, y= kappaA, color= factor(learning_rate).
#### Possibly use a log scale on x-axis.
#### Overview plot: x = n_samples (log scale), y = kappa(A,E), one curve
#### per learning rate, faceted by structure (rows) x update timing (cols).
p_part2 <- ggplot(
  df_part2,
  aes(x = n_samples, y = kappaA, color = factor(learning_rate))
) +
  geom_line(size = 1) +
  geom_point(size = 2) +
  scale_x_log10() + # geometric n_samples values become evenly spaced
  facet_grid(structure ~ update_timing, labeller = label_both) +
  labs(
    title = "Effect of n_samples & learning_rate on kappa(A,E)",
    subtitle = paste("pA=", pA_fixed, " pB=", pB_fixed,
                     " (A=1,B=1 => E=1)"),
    x = "Number of samples (log scale)",
    y = expression(kappa(A,E)),
    color = "Learning Rate"
  ) +
  theme_minimal(base_size = 14)
print(p_part2)
# Split the combined results by causal structure for the per-structure plots
df_conjunctive <- filter(df_part2, structure == "conjunctive")
df_disjunctive <- filter(df_part2, structure == "disjunctive")
# Disjunctive-program results: kappa(A,E) vs n_samples, one curve per
# learning rate, one panel per update timing.
p_part2_disjunctive <- ggplot(
  df_disjunctive,
  aes(x = n_samples, y = kappaA, color = factor(learning_rate))
) +
  geom_line(size = 1) +
  geom_point(size = 2) +
  scale_x_log10() +
  # label_value keeps the facet strips to just the timing value
  # (e.g. "Cumulative"), without the "update_timing: " prefix
  facet_grid(. ~ update_timing, labeller = labeller(update_timing = label_value)) +
  labs(
    title = "Effect of n_samples & learning_rate on kappa(A,E)",
    subtitle = "Logic Program:\n{ E \u2190 A ;\n E \u2190 B }",
    x = "Number of samples (log scale)",
    y = expression(kappa(A,E)),
    color = "Learning Rate"
  ) +
  theme_minimal(base_size = 14)
# Conjunctive-program results: same layout as the disjunctive figure.
p_part2_conjunctive <- ggplot(
  df_conjunctive,
  aes(x = n_samples, y = kappaA, color = factor(learning_rate))
) +
  geom_line(size = 1) +
  geom_point(size = 2) +
  scale_x_log10() +
  # Show only the timing value in the facet strips
  facet_grid(. ~ update_timing, labeller = labeller(update_timing = label_value)) +
  labs(
    title = "Effect of n_samples & learning_rate on kappa(A,E)",
    subtitle = "Logic Program:\n{ E \u2190 A, B }",
    x = "Number of samples (log scale)",
    y = expression(kappa(A,E)),
    color = "Learning Rate"
  ) +
  theme_minimal(base_size = 14)
# Render the per-structure figures
print(p_part2_disjunctive)
print(p_part2_conjunctive)
Main Observations
Cumulative Strategy
In the right panel of each figure, we see that, for a given learning
rate \(\eta\), the model’s causal
score for \(A\) remains nearly constant
as we vary \(n\). This is because the
Cumulative method only applies weight updates once at the end.
Changing the number of samples has little effect on the network’s final
parameter updates, except for the fact that it makes the samples
generated less anchored to the initial state, which doesn’t have a great
importance for these simple abnormal inflation and deflation patterns.
By contrast, the effect of \(\eta\) is
both stable and monotonic, in that a larger \(\eta\) intensifies the updates. This
magnifies whichever pattern of abnormal inflation or deflation is
emergent, but does so in a fairly linear way.
AsWeGo Strategy
The left panels of the figures depict the predictions of the
AsWeGo approach, with the same parameter ranges. Here
we observe stronger, and sometimes also nonlinear trends,
especially for the conjunctive logic program \(E \leftarrow A,B\):
For low numbers of samples (up to 100), the impact of the learning
rate seems similar to the one it has in the “Cumulative” model: as we
move from a learning rate of 0.1 to 0.75, the pattern of abnormal inflation
is amplified, in that \(A\), here the
low-probability variable, gets a higher score. The dynamic gets
overturned, however, when we consider the even greater learning rate of
1, which gives a lower score than any of the other learning rates
(!)
In a similar vein, when we consider number of samples of 100 and
above, we see a change in the ranking between the ratings obtained with
a learning rate of 0.5, and those obtained with 0.75: although the
latter generates higher scores for small numbers of samples, we see the
dynamic switching so that a learning rate of 0.5 now becomes more
advantageous.
The “Runaway” Effect in the AsWeGo Strategy
Under AsWeGo, each sample modifies the network’s weights before the
next sample is drawn. Hence, if the initial states or initial updates
push the system in a certain direction (for instance, reinforcing the
weight from \(A\) to the outcome),
subsequent samples become more likely to favor that same cause.
This positive feedback loop can “run away” in the sense that it heavily
biases the sampling distribution toward states that keep boosting \(A\).
This dynamic explains why contrasts are larger (keeping learning rate
constant) in that version of the model, than in the “Cumulative”
version.
When pushed past a certain limit, it can also have even more dramatic
effects by changing the nature of the relationship between \(A,B\) and the outcome: in the
Conjunctive Example \(E \leftarrow
A,B\), the initial sampling episodes see \((A=1, B=1)\), since that is our actual
world’s starting point. These first steps reward the connections between
\(A\) and \(B\) and their common hidden node \(H_{ab}\), and decrease the size of the
(negative) bias on this latter. If updates are modest (low
\(\eta\)), the gate effectively remains
conjunctive, preserving “abnormal inflation” dynamic in favor of \(A\). But, if \(\eta\) is large or if \(n\) is high enough, the repeated updates
can eventually push the weights to approximate a “disjunctive” gate
instead. The activation of \(A\) and
\(B\) become such that the outcome node
might saturate whenever either occurs. In effect, the system
transitions from a conjunctive to a nearly disjunctive gate, causing
\(\kappa(A,E)\) to drop
sharply, because \(A\) is the abnormal
variable.
Rationale for Fixing n
Because of the runaway effect, letting both \(\eta\) and \(n\) vary freely introduces a great deal of
complexity. Small differences in \(\eta\) can, at high \(n\), result in qualitatively different
prediction patterns, especially when updates are performed “As we
go”.
Hence, we choose to fix the number of samples to \(n=20\) for subsequent analyses. This
choice has several advantages:
Reducing model complexity
Leaving both \(n\) and \(\eta\) as free parameters would considerably
increase the model complexity, without any substantial upshot, as
the impact of each parameter is redundant with that of the other in the
best case scenario. On top of that, their interdependence makes it
harder to properly fit either one to data, the moment we allow both to
vary. By holding \(n\) constant, we
reduce the dimensionality of the parameter space and can more reliably
fit (\(\eta\)).
Relation to anchoring As discussed in the
dedicated chapter, using a small number of samples guarantees that the
sampling process is anchored to the initial state to an extent. It
effectively derives the anchoring phenomenon from more fundamental facts
about how counterfactual states are generated. What is more, choosing
\(n=20\) puts the chain length of the
MCMC process in the same range as the kinds of values in other models of
causal judgments grounded in such sampling strategies (e.g., Davies
& Rehder, 2020).
Psychological Plausibility.
An additional element of plausibility for fixing \(n\) to a low value comes from a
consideration of the cognitive costs of sampling. Engaging in
counterfactual simulation is cognitively costly. Adopting a moderate
chain length like \(n=20\) aligns with
the idea that humans only consider enough alternative states to track
“normal” or “plausible” alternative scenarios, but do not necessarily
exhaustively search the space of possibilities.
Robustness of the model in the “Cumulative”
version
Although the precise number \(n=20\) we
choose is arbitrary (beyond the fact it sits in a
plausible range of small values, as per the arguments in (2)-(3) above),
the model dynamics are not hurt by this arbitrariness. Indeed, as shown
above, the predictions barely vary with the number of samples in the
“Cumulative” version, and still relatively little in the “AsWeGo”
version, provided that the number we picked is small.
Given this fixed number of samples, we will, in a second step, fit
the learning rate parameter for each of the update timing strategy to
data collected from our second experiment in Konuk et al. (under
review).
############################################################################
# One-stop pipeline to load data for Exp1 & Exp2, exclude comprehension
# failures, parse questions into letter-based sets, scale responses, and
# combine the data frames into df_judgments.
############################################################################
# HELPER FUNCTIONS
############################################################################
# For Exp1, we only expect boxes A, B, C.
# We'll ignore the actual number of colored balls. We'll just parse the
# question text to see if it references box A, box B, box C, or combos.
# For Exp1, we only expect boxes A, B, C.
# Parse the question text for references to A/B/C and return the sorted,
# de-duplicated letters as a comma-separated string (e.g. "A,C"), or
# NA_character_ when the question is NA or contains none of them.
#
# Vectorized over `question_string` (returns one string per element), so it
# also works inside a plain, non-rowwise mutate(); a length-1 input still
# returns a length-1 result, preserving the original scalar behavior.
#
# NOTE(review): "[A-C]" matches *any* capital A/B/C, including letters
# inside other capitalized words ("Box", "Consider", ...) — assumed the
# question text only capitalizes the box letters; confirm against stimuli.
extract_letters_exp1 <- function(question_string) {
  vapply(question_string, function(q) {
    if (is.na(q)) return(NA_character_)
    letters_found <- unlist(regmatches(q, gregexpr("[A-C]", q)))
    if (length(letters_found) == 0) return(NA_character_)
    paste(sort(unique(letters_found)), collapse = ",")
  }, character(1), USE.NAMES = FALSE)
}
# For Exp2, the question might involve any subset of A/B/C/D.
# For Exp2, the question might involve any subset of A/B/C/D.
# Same contract as extract_letters_exp1, with the candidate set extended
# to D: returns sorted unique letters joined by commas, or NA_character_.
#
# Vectorized over `question_string` for the same reason as the Exp1
# helper; scalar inputs behave exactly as before.
#
# NOTE(review): "[A-D]" matches any capital A-D anywhere in the text —
# assumed only the urn letters are capitalized; confirm against stimuli.
extract_letters_exp2 <- function(question_string) {
  vapply(question_string, function(q) {
    if (is.na(q)) return(NA_character_)
    letters_found <- unlist(regmatches(q, gregexpr("[A-D]", q)))
    if (length(letters_found) == 0) return(NA_character_)
    paste(sort(unique(letters_found)), collapse = ",")
  }, character(1), USE.NAMES = FALSE)
}
############################################################################
# A) Experiment 1
############################################################################
# 1) Load data
# NOTE(review): hard-coded absolute path — this only runs on the original
# author's machine; consider a project-relative path instead.
exp1_raw <- read_csv("/Users/cankonuk/Documents/LAB/plural-actual-causes/analyses/quillien-2a-plural-version/experiment1-plurals.csv")
## Rows: 13101 Columns: 17
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (13): trial_type, trial_index, internal_node_id, subject_ID, response, a...
## dbl (3): rt, time_elapsed, group
## lgl (1): scoring
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# 2) Exclude comprehension failures
incorrect_subjects_exp1 <- exp1_raw %>%
filter(
(trial_index == "6.WhitePoints" & response != "0") |
(trial_index == "6.WinningCondition" & response != "You must get at least 2 points")
) %>%
distinct(subject_ID)
exp1_clean <- exp1_raw %>%
anti_join(incorrect_subjects_exp1, by = "subject_ID")
# 3) Filter to target trials (20..26). Convert response to numeric if needed.
# as.numeric() on the character trial_index turns non-numeric labels
# (e.g. "6.WhitePoints") into NA — the coercion warning below is expected,
# and those NA rows are dropped by the range filter.
exp1_clean <- exp1_clean %>%
  mutate(trial_index_num = as.numeric(trial_index)) %>%
  filter(trial_index_num >= 20, trial_index_num <= 26)
## Warning: There was 1 warning in `mutate()`.
## ℹ In argument: `trial_index_num = as.numeric(trial_index)`.
## Caused by warning:
## ! NAs introduced by coercion
# 4) Create a cause_set column with the letter(s) referenced by the question
# (instead of "lo"/"mid"/"hi"/"all three").
exp1_clean <- exp1_clean %>%
  rowwise() %>%  # apply extract_letters_exp1 one question at a time
  mutate(
    cause_set = extract_letters_exp1(question), # A, B, C, or combos
    # NOTE(review): /8 maps a 0..8 response onto [0,1] (same convention as
    # the Exp2 code below). If the response scale is actually 1..9, this
    # should be (response - 1) / 8 — confirm against the experiment design.
    scaled_val = as.numeric(response) / 8
  )
# 5) Summarize by cause_set
df_exp1 <- exp1_clean %>%
filter(!is.na(cause_set)) %>%
group_by(cause_set) %>%
summarize(
mean_judgment = mean(scaled_val, na.rm = TRUE),
.groups = "drop"
) %>%
mutate(
experiment = "Exp1",
scenario_name = "All-Pos" # or whichever label you prefer
) %>%
select(experiment, scenario_name, cause_set, mean_judgment)
############################################################################
# B) Experiment 2
############################################################################
# 1) Load data
# NOTE(review): hard-coded absolute path — consider a project-relative path.
exp2_raw <- read_csv("/Users/cankonuk/Documents/LAB/plural-actual-causes/analyses/conj-disj-analysis/data-conj-disj.csv")
## Rows: 24656 Columns: 34
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (15): trial_type, trial_index, internal_node_id, subject_ID, response, a...
## dbl (13): rt, time_elapsed, group, urn--n-balls, urn--n-colored-balls, urn-A...
## lgl (6): urn--is-results-color, urn-A-is-results-color, urn-C-is-results-co...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# 2) Exclude comprehension failures
failed_subjects_whiteballs <- exp2_raw %>%
filter(trial_index == "17.WhiteBalls", response != "2") %>%
pull(subject_ID) %>%
unique()
failed_subjects_winningcondition <- exp2_raw %>%
filter(
trial_index == "17.WinningCondition",
response != "You must draw two yellow, or two purple balls"
) %>%
pull(subject_ID) %>%
unique()
failed_subjects_exp2 <- unique(c(failed_subjects_whiteballs, failed_subjects_winningcondition))
exp2_clean <- exp2_raw %>%
filter(!(subject_ID %in% failed_subjects_exp2))
# 3) Identify & filter to the relevant "trial_kind" conditions.
# Each target condition is identified by the exact urn configuration:
# the colored-ball counts in urns A-D (14/2/4/18) plus whether each urn's
# color matches the result color.
# NOTE(review): `trial_index %in% 6:15` compares a character column to
# integers; it only matches rows whose trial_index is exactly "6".."15",
# not labels like "6.SomeName" — confirm this is the intended encoding.
exp2_clean <- exp2_clean %>%
  rowwise() %>%  # NOTE(review): case_when is vectorized; rowwise() looks unnecessary here
  mutate(
    trial_kind = case_when(
      trial_index %in% 6:15 ~ "habituation",
      # triple 0: only urn C matches the result color
      (`urn-A-n-colored-balls` == 14 & `urn-A-is-results-color` == FALSE &
        `urn-B-n-colored-balls` == 2 & `urn-B-is-results-color` == FALSE &
        `urn-C-n-colored-balls` == 4 & `urn-C-is-results-color` == TRUE &
        `urn-D-n-colored-balls` == 18 & `urn-D-is-results-color` == FALSE) ~ "triple 0",
      # overdetermined negative: no urn matches the result color
      (`urn-A-n-colored-balls` == 14 & `urn-A-is-results-color` == FALSE &
        `urn-B-n-colored-balls` == 2 & `urn-B-is-results-color` == FALSE &
        `urn-C-n-colored-balls` == 4 & `urn-C-is-results-color` == FALSE &
        `urn-D-n-colored-balls` == 18 & `urn-D-is-results-color` == FALSE) ~ "overdetermined negative",
      # overdetermined positive: all four urns match the result color
      (`urn-A-n-colored-balls` == 14 & `urn-A-is-results-color` == TRUE &
        `urn-B-n-colored-balls` == 2 & `urn-B-is-results-color` == TRUE &
        `urn-C-n-colored-balls` == 4 & `urn-C-is-results-color` == TRUE &
        `urn-D-n-colored-balls` == 18 & `urn-D-is-results-color` == TRUE) ~ "overdetermined positive",
      # triple 1: all urns except C match the result color
      (`urn-A-n-colored-balls` == 14 & `urn-A-is-results-color` == TRUE &
        `urn-B-n-colored-balls` == 2 & `urn-B-is-results-color` == TRUE &
        `urn-C-n-colored-balls` == 4 & `urn-C-is-results-color` == FALSE &
        `urn-D-n-colored-balls` == 18 & `urn-D-is-results-color` == TRUE) ~ "triple 1",
      TRUE ~ NA_character_
    )
  ) %>%
  filter(trial_kind %in% c("overdetermined positive", "triple 1",
                           "overdetermined negative", "triple 0"))
# 4) Extract the letter-based sets (A, B, C, D) from the question and scale response
# NOTE(review): this mutate has no rowwise() of its own — presumably the
# data frame is still rowwise from the previous step, which is what
# extract_letters_exp2 needs (it handles one question at a time); verify.
exp2_clean <- exp2_clean %>%
  mutate(
    cause_set = extract_letters_exp2(question),
    scaled_val = as.numeric(response) / 8 # scale 0..8 => [0..1]
  )
# 5) Mean judgment per (condition, cause set), renamed/labelled to align
# with the Exp1 summary frame.
df_exp2 <- exp2_clean %>%
  filter(!is.na(cause_set)) %>%
  group_by(trial_kind, cause_set) %>%
  summarize(mean_judgment = mean(scaled_val, na.rm = TRUE), .groups = "drop") %>%
  rename(scenario_name = trial_kind) %>%
  mutate(experiment = "Exp2") %>%
  select(experiment, scenario_name, cause_set, mean_judgment)
############################################################################
# C) Combine both experiments
############################################################################
# Exp1 is deliberately excluded for now; switch to bind_rows to include it.
#df_judgments <- bind_rows(df_exp1, df_exp2)
df_judgments <- df_exp2
#print(df_judgments)
############################################################################
# D) Example quick plot
############################################################################
# Fix the display order of cause sets: singletons, then pairs, then triples
cause_levels <- c("A","B","C","D","A,B","A,C","A,D", "B,C","B,D", "C,D", "A,B,C", "A,B,D", "A,C,D", "B,C,D")
df_judgments$cause_set <- factor(df_judgments$cause_set, levels=cause_levels)
# Bar chart of mean judgments per cause set, one panel per scenario
p <- ggplot(
  df_judgments,
  aes(x = cause_set, y = mean_judgment, fill = experiment)
) +
  geom_col(width = 0.6, position = position_dodge(width = 0.7)) +
  facet_wrap(~ scenario_name, scales = "free_x") +
  labs(
    title = "Participants’ Mean Judgments Across Conditions",
    x = "Cause Set",
    y = "Mean Judgment (scaled 0–1)",
    fill = "Experiment"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    # Slanted labels keep multi-letter cause sets readable
    axis.text.x = element_text(angle = 45, hjust = 1),
    panel.grid.major.x = element_blank(),
    panel.grid.minor = element_blank()
  )
print(p)

##############################################################################
## 1) A function that runs the model for a given scenario and parameters,
## returning a named vector of predictions, then optionally raises them
## to gamma. (Same approach as before.)
##############################################################################
run_model_for_scenario <- function(clauses, input_probabilities, initial_state,
                                   n_samples, learning_rate, w,
                                   update_timing="Cumulative",
                                   gamma=1,
                                   num_simulations=500) {
  # Run the model for one scenario and return a named vector of kappa
  # predictions (names = cause labels), each raised to `gamma`.
  #
  # clauses / input_probabilities / initial_state: scenario definition
  #   (see get_scenario_info).
  # num_simulations: previously hard-coded to 500; now a parameter with
  #   the same default, so existing callers are unaffected.
  sampling_network <- generate_network_from_logic_program(
    clauses = clauses,
    input_probabilities = input_probabilities,
    plot_network = FALSE
  )
  init_state <- generate_initial_hidden_states(initial_state, sampling_network$network)
  result <- run_simulation(
    network = sampling_network,
    initial_state = init_state,
    # BUGFIX: pass the raw observed values as input_values — the earlier
    # exploration code passes the plain A/B/E list here, not the
    # hidden-state-augmented init_state this function previously passed.
    input_values = initial_state,
    num_simulations = num_simulations,
    n_samples = n_samples,
    sampling_method = "MH",
    weight_update_timing = update_timing,
    learning_rate = learning_rate,
    weight_update_rule = "Additive",
    w = w,
    return_samples = FALSE
  )
  # Extract raw kappa values and apply the response-scaling exponent
  preds <- result$impact_scores # data.frame: C, kappa
  preds$kappa <- preds$kappa ^ gamma
  setNames(preds$kappa, preds$C)
}
##############################################################################
## 2) A function that obtains scenario info (clauses, probabilities, init_state)
## given a scenario_name.
##############################################################################
# Return the model specification (clauses, input probabilities, initial
# state) for a named scenario.
#
# The four Exp2 scenarios share the same logic program and probabilities
# and differ only in the initial (actual-world) state, so they are stored
# in a lookup table instead of four near-identical branches.
#
# Raises an error for unknown scenario names (same message as before).
get_scenario_info <- function(scenario_name) {
  if (scenario_name == "All-Pos") {
    return(list(
      clauses = list(
        list(head="E", body=c("A","B")),
        list(head="E", body=c("A","C")),
        list(head="E", body=c("B","C"))
      ),
      input_probabilities = list(A=0.05, B=0.5, C=0.95),
      initial_state = list(A=1,B=1,C=1,E=1)
    ))
  }
  # Exp2 scenarios: only the initial state distinguishes them
  exp2_states <- list(
    "overdetermined positive" = list(A=1,B=1,C=1,D=1,E=1),
    "triple 1"                = list(A=1,B=1,C=1,D=-1,E=1),
    "overdetermined negative" = list(A=-1,B=-1,C=-1,D=-1,E=-1),
    "triple 0"                = list(A=-1,B=-1,C=1,D=-1,E=-1)
  )
  if (!scenario_name %in% names(exp2_states)) {
    stop("Unknown scenario: ", scenario_name)
  }
  list(
    clauses = list(
      list(head="E", body=c("A","B")),
      list(head="E", body=c("C","D"))
    ),
    input_probabilities = list(A=0.7,B=0.1,C=0.3,D=0.9),
    initial_state = exp2_states[[scenario_name]]
  )
}
##############################################################################
## 3) A new measure of fit: item-level correlation between data and model
##############################################################################
# Item-level Pearson correlation between mean human judgments and
# model-predicted kappa values, pooled across all scenarios present in
# df_judgments.
#
# param: list with n_samples, learning_rate, w, update_timing, gamma.
# Returns NA_real_ when fewer than 2 (prediction, data) pairs exist.
compute_correlation_for_dataset <- function(df_judgments, param) {
  scenario_list <- unique(df_judgments$scenario_name)
  # One (pred, obs) pair set per scenario; collected via lapply instead of
  # growing vectors with c() inside nested loops.
  pieces <- lapply(scenario_list, function(scn) {
    info <- get_scenario_info(scn)
    pred_kappa <- run_model_for_scenario(
      clauses = info$clauses,
      input_probabilities = info$input_probabilities,
      initial_state = info$initial_state,
      n_samples = param$n_samples,
      learning_rate = param$learning_rate,
      w = param$w,
      update_timing = param$update_timing,
      gamma = param$gamma
    )
    df_sub <- df_judgments[df_judgments$scenario_name == scn, ]
    # BUGFIX: cause_set is converted to a factor upstream; indexing a named
    # vector with a factor uses the factor's *integer codes*, not the
    # labels, silently matching the wrong predictions. Coerce to character
    # so we index by name. Missing names yield NA and are dropped, as in
    # the original loop.
    preds <- pred_kappa[as.character(df_sub$cause_set)]
    keep <- !is.na(preds)
    list(pred = unname(preds[keep]), obs = df_sub$mean_judgment[keep])
  })
  predictions_vec <- unlist(lapply(pieces, `[[`, "pred"))
  data_vec <- unlist(lapply(pieces, `[[`, "obs"))
  # Correlation is undefined for fewer than 2 items
  if (length(predictions_vec) < 2) {
    return(NA_real_)
  }
  cor(predictions_vec, data_vec, use = "complete.obs", method = "pearson")
}
##############################################################################
## 4) A small wrapper that the parallel workers call for each param combo,
## returning a 1-row data.frame with correlation
##############################################################################
# Wrapper called by the parallel workers: one parameter combination in,
# one-row data.frame (parameters + correlation) out.
calc_correlation <- function(df_judgments, ns, lr, w, ut, gm) {
  params <- list(
    n_samples = ns, learning_rate = lr, w = w,
    update_timing = ut, gamma = gm
  )
  data.frame(
    n_samples = ns, learning_rate = lr, w = w, update_timing = ut,
    gamma = gm,
    correlation = compute_correlation_for_dataset(df_judgments, params)
  )
}
##############################################################################
## 5) Parallel search for param combos => we want the maximum correlation
##############################################################################
# Grid-search all parameter combinations in parallel, returning the
# combination that maximizes the item-level correlation with the data.
#
# df_judgments: data frame of mean judgments (scenario_name, cause_set,
#   mean_judgment). The *_set arguments are the candidate values for each
#   parameter; n_cores is the cluster size.
# Returns a list: best_correlation, best_params (1-row data.frame), and
# the full results grid.
parallel_search_over_params_correlation <- function(df_judgments,
                                                    n_samples_set,
                                                    lr_set,
                                                    w_set,
                                                    update_timing_set,
                                                    gamma_set,
                                                    n_cores=2) {
  library(doParallel)
  library(foreach)
  param_df <- expand.grid(
    n_samples = n_samples_set,
    learning_rate = lr_set,
    w = w_set,
    update_timing = update_timing_set,
    gamma = gamma_set,
    stringsAsFactors = FALSE
  )
  cl <- makeCluster(n_cores)
  # Ensure the cluster is shut down even if a worker errors out
  on.exit(stopCluster(cl), add = TRUE)
  registerDoParallel(cl)
  # Each iteration returns one data.frame row; .combine=rbind stacks them.
  # The .export list ships all model subfunctions to the workers.
  results <- foreach(i = seq_len(nrow(param_df)), .combine = rbind,
                     .packages = c("dplyr"),
                     .export = c(
                       "calc_correlation",
                       "compute_correlation_for_dataset",
                       "run_model_for_scenario",
                       "get_scenario_info",
                       "generate_network_from_logic_program",
                       "generate_initial_hidden_states",
                       "run_simulation",
                       "perform_sampling",
                       "mh_layered_sampler",
                       "layer_gibbs_sampler",
                       "compute_lrp_importance",
                       "compute_causal_impact_scores",
                       "sigmoid",
                       "compute_loss_NLL",
                       "activation_function",
                       "compute_weight_updates",
                       "compute_outputs"
                     )) %dopar% {
    calc_correlation(
      df_judgments,
      param_df$n_samples[i],
      param_df$learning_rate[i],
      param_df$w[i],
      param_df$update_timing[i],
      param_df$gamma[i]
    )
  }
  # Guard the all-NA case: which.max would return integer(0) and silently
  # produce an empty best_row.
  if (all(is.na(results$correlation))) {
    warning("All parameter combinations produced NA correlations.", call. = FALSE)
    return(list(best_correlation = NA_real_, best_params = NULL, results = results))
  }
  idx_best <- which.max(results$correlation)
  best_row <- results[idx_best, ]
  list(
    best_correlation = best_row$correlation,
    best_params = best_row[, c("n_samples","learning_rate","w","update_timing","gamma")],
    results = results
  )
}