Ladislas Nalborczyk
  • HOME
  • RESEARCH
  • PUBLICATIONS
  • TEAM
  • POSITIONS
  • CV
  • BLOG
  • INTERNAL

Drift Diffusion Model

#| '!! shinylive warning !!': |
#|   shinylive does not work in self-contained HTML documents.
#|   Please set `embed-resources: false` in your metadata.
#| standalone: true
#| viewerHeight: 1200

# app.R
# Shiny visualiser for 4- and 7-parameter drift diffusion models.
# Run with: shiny::runApp("path/to/folder_containing_app.R")

required_packages <- c(
    "shiny", "bslib", "ggplot2", "dplyr", "tidyr", "tibble", "scales"
)

# Keep only the packages whose namespace cannot be loaded; requireNamespace()
# returns FALSE (quietly) instead of erroring for a missing package.
missing_packages <- Filter(
    function(pkg) !requireNamespace(pkg, quietly = TRUE),
    required_packages
)

# Fail early with one message listing every missing dependency.
if (length(missing_packages) > 0L) {
    stop(
        "Please install the following packages before running the app: ",
        paste(missing_packages, collapse = ", "),
        call. = FALSE
    )
}

library(shiny)
library(bslib)
library(ggplot2)
library(dplyr)
library(tidyr)
library(tibble)
library(scales)

# -----------------------------------------------------------------------------
# Simulation helpers
# -----------------------------------------------------------------------------

clip <- function(x, lower, upper) {
    # Element-wise clamp: raise values below `lower` first, then cap the
    # result at `upper` (so `upper` wins if the bounds are inverted).
    floored <- pmax(x, lower)
    pmin(floored, upper)
}

sample_trial_parameters <- function(n, a, z, v, tau, eta, sz, stau) {
    # Draws trial-wise DDM parameters for `n` trials.
    #   drift: Normal(v, eta) when eta > 0, otherwise constant v.
    #   start: Uniform(z +/- sz / 2) when sz > 0, otherwise constant z;
    #          z and sz are proportions of the boundary separation `a`.
    #   ndt:   Uniform(tau +/- stau / 2) when stau > 0, otherwise constant tau;
    #          tau and stau are in seconds.
    # Returns a tibble with columns trial, drift, z_prop (clipped to
    # [0.02, 0.98]), z_abs (= z_prop * a) and ndt (floored at 0).
    # Draw order (drift, start, ndt) is fixed so a given seed is reproducible.
    uniform_around <- function(centre, width) {
        if (width > 0) {
            runif(n, min = centre - width / 2, max = centre + width / 2)
        } else {
            rep(centre, n)
        }
    }

    drift <- if (eta > 0) rnorm(n, mean = v, sd = eta) else rep(v, n)
    # Keep the start strictly inside (0, a) so no trial begins on a boundary.
    start_prop <- clip(uniform_around(z, sz), lower = 0.02, upper = 0.98)
    ndt <- uniform_around(tau, stau)

    tibble(
        trial = seq_len(n),
        drift = drift,
        z_prop = start_prop,
        z_abs = start_prop * a,
        ndt = pmax(ndt, 0)
    )
}

simulate_ddm <- function(
    n_trials = 2000,
    a = 1.5,
    z = 0.5,
    v = 1.0,
    tau = 0.3,
    eta = 0,
    sz = 0,
    stau = 0,
    dt = 0.002,
    max_decision_time = 5,
    seed = NULL
) {
    # Simulates `n_trials` independent DDM trials with an Euler scheme and
    # unit diffusion noise (each increment has sd = sqrt(dt)).
    #
    # a                  boundary separation (lower boundary 0, upper a)
    # z, sz              starting point / its uniform range, as proportions of a
    # v, eta             mean drift rate / SD of the trial-wise drift
    # tau, stau          mean non-decision time / its uniform range (s)
    # dt                 Euler step (s)
    # max_decision_time  trials still running after this long are censored
    # seed               optional RNG seed; ignored when NULL or NA
    #
    # Returns the columns of sample_trial_parameters() plus `choice`
    # (factor: lower / upper / censored), `decision_time`, `rt`
    # (decision_time + ndt, NA when censored) and `evidence_final`
    # (evidence after the last step, not clamped to the boundary).
    if (!is.null(seed) && !is.na(seed)) {
        set.seed(seed)
    }

    trials <- sample_trial_parameters(
        n = n_trials,
        a = a,
        z = z,
        v = v,
        tau = tau,
        eta = eta,
        sz = sz,
        stau = stau
    )

    position <- trials$z_abs
    # Outcome codes: -1 = lower boundary, 0 = still running / censored,
    # 1 = upper boundary.
    outcome <- integer(n_trials)
    t_decision <- rep(NA_real_, n_trials)
    noise_sd <- sqrt(dt)

    for (k in seq_len(ceiling(max_decision_time / dt))) {
        running <- which(outcome == 0L)
        if (length(running) == 0L) {
            break
        }

        # Only unfinished trials draw noise, so finished trials stay frozen.
        position[running] <- position[running] +
            trials$drift[running] * dt +
            noise_sd * rnorm(length(running))

        pos_now <- position[running]
        went_up <- pos_now >= a
        crossed <- went_up | pos_now <= 0

        if (any(crossed)) {
            finished <- running[crossed]
            outcome[finished] <- ifelse(went_up[crossed], 1L, -1L)
            t_decision[finished] <- k * dt
        }
    }

    # Map codes -1 / 0 / 1 onto their labels via a lookup table.
    choice_label <- c("lower", "censored", "upper")[outcome + 2L]

    trials |>
        mutate(
            choice = factor(
                choice_label,
                levels = c("lower", "upper", "censored")
            ),
            decision_time = t_decision,
            rt = decision_time + ndt,
            evidence_final = position
        )
}

simulate_paths <- function(
    n_paths = 25,
    a = 1.5,
    z = 0.5,
    v = 1.0,
    tau = 0.3,
    eta = 0,
    sz = 0,
    stau = 0,
    dt = 0.005,
    max_decision_time = 4,
    seed = NULL
) {
    # Simulates `n_paths` full evidence trajectories for plotting.
    # The seed is offset by 1 so these paths use a different random stream
    # than simulate_ddm() called with the same seed.
    #
    # Returns a long tibble with one row per time point per path:
    #   path      path index
    #   time      time from stimulus onset (s); each path starts at its ndt
    #   evidence  accumulated evidence, set exactly to 0 or a on the
    #             step that crosses a boundary
    #   outcome   factor (lower / upper / censored), constant within a path
    #   ndt, drift, z_abs  the path's trial-wise parameters
    if (!is.null(seed) && !is.na(seed)) {
        set.seed(seed + 1L)
    }

    trials <- sample_trial_parameters(
        n = n_paths,
        a = a,
        z = z,
        v = v,
        tau = tau,
        eta = eta,
        sz = sz,
        stau = stau
    )

    max_steps <- ceiling(max_decision_time / dt)
    outcome_levels <- c("lower", "upper", "censored")

    # Walks one path step by step until a boundary is hit or time runs out.
    trace_one <- function(i) {
        x <- numeric(max_steps + 1L)
        t <- numeric(max_steps + 1L)
        x[1L] <- trials$z_abs[i]
        t[1L] <- trials$ndt[i]

        result <- "censored"
        n_rows <- 1L

        for (k in seq_len(max_steps)) {
            proposal <- x[k] +
                trials$drift[i] * dt +
                sqrt(dt) * rnorm(1)

            n_rows <- k + 1L
            t[n_rows] <- trials$ndt[i] + k * dt

            if (proposal >= a) {
                result <- "upper"
                x[n_rows] <- a
                break
            }

            if (proposal <= 0) {
                result <- "lower"
                x[n_rows] <- 0
                break
            }

            x[n_rows] <- proposal
        }

        keep <- seq_len(n_rows)

        tibble(
            path = i,
            time = t[keep],
            evidence = x[keep],
            outcome = factor(rep(result, n_rows), levels = outcome_levels),
            ndt = trials$ndt[i],
            drift = trials$drift[i],
            z_abs = trials$z_abs[i]
        )
    }

    # Paths are generated sequentially, so the RNG stream is consumed in
    # path order.
    bind_rows(lapply(seq_len(n_paths), trace_one))
}

empty_boundary_density <- function() {
    # Zero-row placeholder with the columns the boundary-ribbon layer uses
    # (rt, ymin, ymax, choice), returned when there is too little data to
    # estimate RT densities.
    boundary_levels <- c("lower", "upper")

    tibble(
        rt = double(0),
        ymin = double(0),
        ymax = double(0),
        choice = factor(character(0), levels = boundary_levels)
    )
}

make_boundary_density_df <- function(sim, a, density_height = 0.30) {
    # Builds ribbon coordinates for RT densities drawn on the boundaries:
    # upper-boundary densities grow upward from y = a, lower-boundary
    # densities grow downward from y = 0.
    #
    # sim             output of simulate_ddm()
    # a               boundary separation
    # density_height  height of the tallest ribbon, as a fraction of a
    #
    # Each density is weighted by (trials at that boundary) / (all trials
    # in `sim`, censored included), so ribbon areas reflect choice shares.
    # Returns empty_boundary_density() when there are fewer than 10 finished
    # trials or no usable density.
    finished <- sim |>
        filter(!is.na(rt), choice %in% c("lower", "upper"))

    if (nrow(finished) < 10) {
        return(empty_boundary_density())
    }

    rt_upper_limit <- max(finished$rt, na.rm = TRUE)
    trial_count <- nrow(sim)

    density_for <- function(boundary) {
        times <- finished$rt[finished$choice == boundary]

        # Skip boundaries with too few (or too few distinct) RTs for a KDE.
        too_sparse <- length(times) < 3 ||
            length(unique(round(times, 4))) < 3
        if (too_sparse) {
            return(
                tibble(
                    rt = numeric(),
                    weighted_density = numeric(),
                    choice = character()
                )
            )
        }

        kde <- density(
            times,
            from = 0,
            to = rt_upper_limit,
            n = 512,
            adjust = 1.1
        )

        tibble(
            rt = kde$x,
            weighted_density = kde$y * (length(times) / trial_count),
            choice = boundary
        )
    }

    stacked <- bind_rows(density_for("lower"), density_for("upper"))

    if (nrow(stacked) == 0) {
        return(empty_boundary_density())
    }

    peak <- max(stacked$weighted_density, na.rm = TRUE)
    if (peak <= 0) {
        return(empty_boundary_density())
    }

    # One shared scale for both boundaries keeps their relative areas.
    scale_factor <- density_height * a / peak

    # Inside mutate(), `density_height` refers to the new column.
    stacked |>
        mutate(
            density_height = weighted_density * scale_factor,
            ymin = if_else(choice == "upper", a, -density_height),
            ymax = if_else(choice == "upper", a + density_height, 0),
            choice = factor(choice, levels = c("lower", "upper"))
        )
}

summarise_simulation <- function(sim) {
    # Builds the two-column summary table (statistic, value) shown in the
    # "Simulation summary" card.
    #
    # sim  output of simulate_ddm()
    #
    # Choice and censoring proportions use every trial in `sim`. RT statistics
    # use only trials that reached a boundary. When no trial reached a given
    # boundary, its RT statistics are reported as "n/a" rather than the
    # "NA s" / "NaN s" that sprintf() produces for median() / mean() of an
    # empty vector.
    valid <- sim |>
        filter(!is.na(rt), choice %in% c("lower", "upper"))

    format_share <- function(label) {
        percent(mean(sim$choice == label, na.rm = TRUE), accuracy = 0.1)
    }

    format_rt_stat <- function(stat_fun, boundary) {
        rt_values <- valid$rt[valid$choice == boundary]
        if (length(rt_values) == 0L) {
            return("n/a")
        }
        sprintf("%.3f s", stat_fun(rt_values, na.rm = TRUE))
    }

    tibble(
        statistic = c(
            "P(upper boundary)",
            "P(lower boundary)",
            "Median RT | upper",
            "Median RT | lower",
            "Mean RT | upper",
            "Mean RT | lower",
            "Censored trials"
        ),
        value = c(
            format_share("upper"),
            format_share("lower"),
            format_rt_stat(median, "upper"),
            format_rt_stat(median, "lower"),
            format_rt_stat(mean, "upper"),
            format_rt_stat(mean, "lower"),
            format_share("censored")
        )
    )
}

# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------

app_theme <- local({
    # One system font stack shared by body text and headings.
    system_font_stack <- "system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif"

    # Bootstrap 5 "flatly" theme with slate primary/secondary and a teal
    # success colour.
    bs_theme(
        version = 5,
        bootswatch = "flatly",
        primary = "#334155",
        secondary = "#64748b",
        success = "#0f766e",
        base_font = system_font_stack,
        heading_font = system_font_stack
    )
})

# Page-level CSS: colour variables, gradient background, the dark "hero"
# banner and rounded card / sidebar styling. Kept as a readable multi-line
# string rather than one escaped line.
app_css <- "
:root {
    --ddm-ink: #0f172a;
    --ddm-muted: #64748b;
    --ddm-card: rgba(255, 255, 255, 0.92);
    --ddm-line: rgba(148, 163, 184, 0.35);
}

body {
    background:
        radial-gradient(circle at top left, rgba(45, 212, 191, 0.20), transparent 35%),
        radial-gradient(circle at top right, rgba(129, 140, 248, 0.16), transparent 30%),
        linear-gradient(180deg, #f8fafc 0%, #eef2f7 100%);
    color: var(--ddm-ink);
}

.hero {
    border-radius: 28px;
    padding: 28px 32px;
    margin: 20px 0 18px 0;
    background: linear-gradient(135deg, rgba(15, 23, 42, 0.96), rgba(51, 65, 85, 0.92));
    color: white;
    box-shadow: 0 22px 60px rgba(15, 23, 42, 0.16);
}

.hero h1 {
    font-weight: 800;
    letter-spacing: -0.04em;
    margin-bottom: 6px;
}

.hero p {
    max-width: 920px;
    color: rgba(255, 255, 255, 0.78);
    margin-bottom: 0;
    font-size: 1.02rem;
}

.card {
    border: 1px solid var(--ddm-line);
    border-radius: 22px !important;
    box-shadow: 0 12px 35px rgba(15, 23, 42, 0.08);
    background: var(--ddm-card);
}

.card-header {
    background: transparent;
    border-bottom: 1px solid var(--ddm-line);
    font-weight: 750;
    letter-spacing: -0.02em;
}

.bslib-sidebar-layout > .sidebar {
    border-radius: 22px;
    border: 1px solid var(--ddm-line);
    box-shadow: 0 12px 35px rgba(15, 23, 42, 0.08);
}

.form-label, label {
    font-weight: 650;
}

.help-text {
    color: var(--ddm-muted);
    font-size: 0.92rem;
}

.btn-primary {
    border-radius: 999px;
    font-weight: 750;
    padding: 0.70rem 1rem;
}

.small-note {
    color: #64748b;
    font-size: 0.88rem;
}
"

# Page layout: a hero banner, then a sidebar of model / simulation controls
# beside the main panel (value boxes, trajectory plot, parameter histograms
# and summary table). The input and output IDs defined here are the ones
# `server` reads and fills.
ui <- page_fluid(
    theme = app_theme,
    tags$head(
        tags$style(HTML(app_css))
    ),
    div(
        class = "hero",
        h1("Drift diffusion model visualiser"),
        p(
            "Explore how boundary separation, starting point, drift rate, ",
            "non-decision time, and across-trial variability shape response-time ",
            "distributions and decision trajectories."
        )
    ),
    layout_sidebar(
        sidebar = sidebar(
            width = 360,
            h4("Model settings"),
            radioButtons(
                inputId = "model_type",
                label = NULL,
                choices = c(
                    "4-parameter DDM" = "4",
                    "7-parameter DDM" = "7"
                ),
                selected = "4"
            ),
            sliderInput(
                "a",
                "Boundary separation, a",
                min = 0.40,
                max = 3.50,
                value = 1.50,
                step = 0.05
            ),
            sliderInput(
                "z",
                "Starting point, z/a",
                min = 0.05,
                max = 0.95,
                value = 0.50,
                step = 0.01
            ),
            sliderInput(
                "v",
                "Drift rate, v",
                min = -4,
                max = 4,
                value = 1.00,
                step = 0.05
            ),
            sliderInput(
                "tau",
                "Non-decision time, tau",
                min = 0,
                max = 1.20,
                value = 0.30,
                step = 0.01
            ),
            # Variability controls are only shown for the 7-parameter model.
            conditionalPanel(
                condition = "input.model_type == '7'",
                hr(),
                h4("Across-trial variability"),
                sliderInput(
                    "eta",
                    "Drift variability, eta",
                    min = 0,
                    max = 3,
                    value = 0.50,
                    step = 0.05
                ),
                sliderInput(
                    "sz",
                    "Starting-point range, sz",
                    min = 0,
                    max = 0.90,
                    value = 0.20,
                    step = 0.01
                ),
                sliderInput(
                    "stau",
                    "Non-decision range, stau",
                    min = 0,
                    max = 1.00,
                    value = 0.15,
                    step = 0.01
                )
            ),
            hr(),
            h4("Simulation"),
            sliderInput(
                "n_trials",
                "Trials",
                min = 300,
                max = 4000,
                value = 1200,
                step = 100
            ),
            sliderInput(
                "n_paths",
                "Displayed trajectories",
                min = 5,
                max = 60,
                value = 20,
                step = 5
            ),
            # Choice values reach the server as strings; server converts
            # them with as.numeric().
            selectInput(
                "dt",
                "Time step",
                choices = c(
                    "1 ms" = 0.001,
                    "2 ms" = 0.002,
                    "5 ms" = 0.005
                ),
                selected = 0.005
            ),
            sliderInput(
                "max_t",
                "Maximum decision time",
                min = 1,
                max = 6,
                value = 4,
                step = 0.5
            ),
            numericInput(
                "seed",
                "Random seed",
                value = 1234,
                min = 1,
                step = 1
            ),
            actionButton(
                "simulate",
                "Generate simulation",
                class = "btn-primary w-100"
            ),
            p(
                class = "small-note mt-3",
                "The simulation uses a simple Euler approximation with diffusion noise fixed to 1."
            )
        ),
        layout_columns(
            value_box(
                title = "Upper choices",
                value = textOutput("vb_upper", container = span),
                showcase = div(class = "display-6", "↑")
            ),
            value_box(
                title = "Lower choices",
                value = textOutput("vb_lower", container = span),
                showcase = div(class = "display-6", "↓")
            ),
            value_box(
                title = "Median RT",
                value = textOutput("vb_median_rt", container = span),
                showcase = div(class = "display-6", "t")
            ),
            value_box(
                title = "Censored",
                value = textOutput("vb_censored", container = span),
                showcase = div(class = "display-6", "∅")
            ),
            col_widths = c(3, 3, 3, 3)
        ),
        card(
            card_header("Evidence trajectories with boundary-attached RT distributions"),
            plotOutput("trajectory_plot", height = 520),
            p(
                class = "help-text px-3 pb-3",
                "The red and blue boundary ribbons show the RT distributions for ",
                "upper and lower choices. Their area is weighted by the proportion ",
                "of trials reaching each boundary."
            )
        ),
        layout_columns(
            card(
                card_header("Trial-wise parameter distribution"),
                plotOutput("parameter_plot", height = 360)
            ),
            card(
                card_header("Simulation summary"),
                tableOutput("summary_table"),
                p(
                    class = "help-text",
                    "Upper and lower boundaries can be interpreted as two response ",
                    "alternatives. RT includes non-decision time. Censored trials ",
                    "did not hit a boundary before the maximum decision time."
                )
            ),
            col_widths = c(7, 5)
        )
    )
)

# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------

server <- function(input, output, session) {
    # Live view of the sidebar model parameters. Across-trial variability
    # (eta, sz, stau) is forced to 0 unless the 7-parameter model is chosen.
    current_parameters <- reactive({
        is_7p <- identical(input$model_type, "7")

        list(
            a = input$a,
            z = input$z,
            v = input$v,
            tau = input$tau,
            eta = if (is_7p) input$eta else 0,
            sz = if (is_7p) input$sz else 0,
            stau = if (is_7p) input$stau else 0,
            dt = as.numeric(input$dt),
            max_decision_time = input$max_t,
            seed = input$seed
        )
    })

    # Snapshot of every simulation setting, refreshed only when the
    # "Generate simulation" button is pressed (and once at start-up, because
    # ignoreNULL = FALSE). eventReactive() evaluates its value expression in
    # an isolate() scope, so moving a slider does not change the snapshot.
    # All outputs that draw simulated data read this snapshot rather than
    # current_parameters(). Otherwise the plotted boundaries and starting
    # point would follow the sliders while the trajectories and RT densities
    # still came from the previous simulation.
    simulation_settings <- eventReactive(
        input$simulate,
        {
            c(
                current_parameters(),
                list(n_trials = input$n_trials, n_paths = input$n_paths)
            )
        },
        ignoreNULL = FALSE
    )

    # Trial-level simulation used for the value boxes, densities and table.
    sim_data <- reactive({
        s <- simulation_settings()

        simulate_ddm(
            n_trials = s$n_trials,
            a = s$a,
            z = s$z,
            v = s$v,
            tau = s$tau,
            eta = s$eta,
            sz = s$sz,
            stau = s$stau,
            dt = s$dt,
            max_decision_time = s$max_decision_time,
            seed = s$seed
        )
    })

    # A small number of full trajectories for the trajectory plot.
    path_data <- reactive({
        s <- simulation_settings()

        simulate_paths(
            n_paths = s$n_paths,
            a = s$a,
            z = s$z,
            v = s$v,
            tau = s$tau,
            eta = s$eta,
            sz = s$sz,
            stau = s$stau,
            dt = s$dt,
            max_decision_time = s$max_decision_time,
            seed = s$seed
        )
    })

    output$vb_upper <- renderText({
        sim <- sim_data()
        percent(mean(sim$choice == "upper", na.rm = TRUE), accuracy = 0.1)
    })

    output$vb_lower <- renderText({
        sim <- sim_data()
        percent(mean(sim$choice == "lower", na.rm = TRUE), accuracy = 0.1)
    })

    output$vb_median_rt <- renderText({
        sim <- sim_data() |>
            filter(!is.na(rt), choice %in% c("lower", "upper"))

        # If every trial is censored there is no RT to summarise; show "n/a"
        # instead of the "NA s" that sprintf() would produce.
        if (nrow(sim) == 0L) {
            "n/a"
        } else {
            sprintf("%.3f s", median(sim$rt, na.rm = TRUE))
        }
    })

    output$vb_censored <- renderText({
        sim <- sim_data()
        percent(mean(sim$choice == "censored", na.rm = TRUE), accuracy = 0.1)
    })

    output$trajectory_plot <- renderPlot({
        paths <- path_data()
        sim <- sim_data()
        # Same parameter snapshot that produced `paths` and `sim`.
        p <- simulation_settings()

        mean_ndt <- mean(paths$ndt, na.rm = TRUE)
        boundary_density <- make_boundary_density_df(
            sim = sim,
            a = p$a,
            density_height = 0.30
        )

        ggplot() +
            # Shaded band covering the mean non-decision time.
            annotate(
                "rect",
                xmin = 0,
                xmax = mean_ndt,
                ymin = -Inf,
                ymax = Inf,
                fill = "#e2e8f0",
                alpha = 0.80
            ) +
            geom_ribbon(
                data = boundary_density,
                aes(x = rt, ymin = ymin, ymax = ymax, fill = choice),
                alpha = 0.66,
                color = "white",
                linewidth = 0.25,
                show.legend = FALSE
            ) +
            geom_hline(
                yintercept = c(0, p$a),
                linewidth = 0.8,
                color = "#0f172a"
            ) +
            geom_hline(
                yintercept = p$z * p$a,
                linewidth = 0.5,
                linetype = "dashed",
                color = "#475569"
            ) +
            geom_line(
                data = paths,
                aes(x = time, y = evidence, group = path, color = outcome),
                alpha = 0.58,
                linewidth = 0.55
            ) +
            annotate(
                "text",
                x = mean_ndt / 2,
                y = p$a * 1.16,
                label = "non-decision",
                size = 3.5,
                color = "#64748b"
            ) +
            scale_color_manual(
                values = c(
                    lower = "#2563eb",
                    upper = "#dc2626",
                    censored = "#64748b"
                ),
                drop = FALSE
            ) +
            scale_fill_manual(
                values = c(lower = "#2563eb", upper = "#dc2626"),
                drop = FALSE
            ) +
            scale_y_continuous(
                breaks = c(0, p$z * p$a, p$a),
                labels = c(
                    "0\nlower boundary",
                    sprintf("%.2f\nstart", p$z * p$a),
                    sprintf("%.2f\nupper boundary", p$a)
                )
            ) +
            # Leave room below 0 and above a for the density ribbons.
            coord_cartesian(
                ylim = c(-0.36 * p$a, 1.36 * p$a),
                expand = FALSE
            ) +
            labs(
                x = "Time from stimulus onset (s)",
                y = "Evidence",
                color = "Trajectory outcome",
                fill = "RT density"
            ) +
            guides(
                color = guide_legend(order = 1),
                fill = guide_legend(order = 2)
            ) +
            theme_minimal(base_size = 13) +
            theme(
                panel.grid.minor = element_blank(),
                legend.position = "top",
                plot.margin = margin(8, 12, 8, 8)
            )
    })

    output$parameter_plot <- renderPlot({
        sim <- sim_data()

        pars_long <- sim |>
            transmute(
                `drift rate` = drift,
                `starting point` = z_abs,
                `non-decision time` = ndt
            ) |>
            pivot_longer(
                cols = everything(),
                names_to = "parameter",
                values_to = "value"
            )

        ggplot(pars_long, aes(x = value)) +
            geom_histogram(
                bins = 35,
                fill = "#0f766e",
                color = "white",
                alpha = 0.82
            ) +
            facet_wrap(~ parameter, scales = "free", nrow = 1) +
            labs(x = NULL, y = "Trials") +
            theme_minimal(base_size = 13) +
            theme(
                panel.grid.minor = element_blank(),
                strip.text = element_text(face = "bold"),
                plot.margin = margin(8, 12, 8, 8)
            )
    })

    output$summary_table <- renderTable(
        {
            summarise_simulation(sim_data())
        },
        striped = TRUE,
        bordered = FALSE,
        spacing = "s"
    )
}

# Final value of app.R: the Shiny app object built from `ui` and `server`.
shinyApp(server = server, ui = ui)

© 2017-2026, Ladislas Nalborczyk ∙ Made with Quarto

Cookie Preferences