Expander Nets - Project Structure and Implementation Notes
Overview
This is an interactive visualization for training “expander nets” - neural networks with a specific architecture designed to learn Hermite polynomial target functions. The project includes both experimental simulation (gradient descent) and theoretical predictions (ODE solver).
Architecture
Network Structure
- Input dimension: d (user configurable, 1-5)
- Hidden dimension: k (large, typically 1000)
- Layers:
- W₁: d×d diagonal matrix (trainable)
- W_froz: d×k fixed random Gaussian matrix (frozen)
- W₂: k×1 trainable weights
- Forward pass:
f(x) = W₂ᵀ · σ(W_froz · W₁ · x), where σ is the ReLU activation
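As a concrete sketch of the forward pass above (function and argument names are illustrative, not the actual model.js API; W₁ is taken as a full d×d matrix since cross-terms are tracked later):

```javascript
// Sketch of f(x) = W2^T · ReLU(W_froz · W1 · x).
// W1: d×d (initialized diagonal), wFroz: k×d frozen Gaussian rows, w2: length-k vector.
function forward(x, W1, wFroz, w2) {
  const d = x.length, k = w2.length;
  // h = W1 · x
  const h = new Array(d).fill(0);
  for (let i = 0; i < d; i++)
    for (let j = 0; j < d; j++) h[i] += W1[i][j] * x[j];
  // f = w2^T · ReLU(wFroz · h)
  let out = 0;
  for (let r = 0; r < k; r++) {
    let pre = 0;
    for (let i = 0; i < d; i++) pre += wFroz[r][i] * h[i];
    out += w2[r] * Math.max(0, pre); // ReLU
  }
  return out;
}
```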
Target Function
- Product of normalized Hermite polynomials:
f*(x) = ∏ᵢ h_αᵢ(xᵢ / √γᵢ)
- Uses normalized probabilist's Hermite polynomials: h_α(z) = He_α(z) / √(α!)
- h₀(z) = 1
- h₁(z) = z
- h₂(z) = (z² - 1) / √2
- h₃(z) = (z³ - 3z) / √6
- Parameters:
- αᵢ (alpha): polynomial degree for dimension i
- γᵢ (gamma): normalization scale for dimension i
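A minimal sketch of the target function above, using the standard recurrence He_{n+1}(z) = z·He_n(z) − n·He_{n−1}(z) (function names are illustrative, not the actual target.js API):

```javascript
// Normalized probabilist's Hermite polynomial h_α(z) = He_α(z) / √(α!)
function normHermite(alpha, z) {
  if (alpha === 0) return 1;
  let prev = 1, cur = z; // He_0, He_1
  for (let n = 1; n < alpha; n++) {
    [prev, cur] = [cur, z * cur - n * prev]; // He_{n+1} = z·He_n − n·He_{n−1}
  }
  let fact = 1;
  for (let n = 2; n <= alpha; n++) fact *= n; // α!
  return cur / Math.sqrt(fact);
}

// Target f*(x) = ∏ᵢ h_{αᵢ}(xᵢ / √γᵢ)
function targetFn(x, alphas, gammas) {
  return x.reduce(
    (p, xi, i) => p * normHermite(alphas[i], xi / Math.sqrt(gammas[i])),
    1
  );
}
```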
File Structure
Core Logic
app.js: Main application entry point
- Initializes AppState, managers, and simulation
- Sets up network visualization
- Handles start/pause/reset buttons
- CRITICAL THEORY LOGIC: Only recomputes theory in “preview mode” (before simulation has data)
- Once simulation starts, theory is “locked” to captured parameters
model.js: ExpanderNet neural network implementation
- W₁ initialization: diagonal with small values (0.01)
- W_froz: fixed Gaussian random matrix
- W₂: initialized small (0.01)
- Forward pass with ReLU activation
target.js: Hermite polynomial target functions
- Normalized Hermite polynomials h_α(z) = He_α(z) / √(α!)
- Derivatives for gradient computation
- Target function evaluation: product over dimensions
training.js: SGD trainer
- Batch sampling from Gaussian distribution
- Gradient computation via backprop
- Parameter updates (only W₁ diagonals and W₂)
simulation.js: Training loop controller
- Adaptive stepping: fits as many SGD steps as possible per frame (~40 Hz target)
- Tracks loss and parameter norms (W₁ diagonals, W₁ cross-term [0,1], √k·‖W₂‖)
- Runs ODE theory prediction via runTheoryODE()
- Adaptive dt scaling: for tMax > 2000, scales dt proportionally to keep ~200k ODE steps
state.js: Application state with localStorage persistence
- Stores all UI parameters (d, k, γ, α, η, batch size, etc.)
- Stores UI settings (logScale, showTheory, xAxisMode, emaWindow)
UI Components (ui/ directory)
charts.js: Chart initialization and controls
- Creates LossChart and NormChart instances
- Sets up log scale, x-axis mode (step vs t_eff), EMA slider, and the theory checkbox
- computeAndShowTheoryOnly(): shows the theory preview before the simulation runs
controls.js: Parameter sliders and displays
- d slider, k slider
- 5 gamma sliders (one per possible dimension)
- 5 alpha sliders (one per possible dimension)
- η (learning rate) and batch size sliders
- Updates theoretical predictions display
display.js: Theory predictions display panel
- Shows c, β_eff, t_rise, etc. computed from theory
Visualization
visualization.js: Chart.js wrapper classes
- LossChart: 3 datasets (raw, EMA-smoothed exp, theory)
- NormChart: Multiple datasets for W₁ diagonals, W₁ cross-term, W₂ norm
- Dataset structure (for d dimensions):
- Indices 0 to d-1: W₁ diagonals (exp, thick lines)
- Index d: W₁ cross-term [0,1] (exp, thick, dark gray) - only if d ≥ 2
- Index d + crossOffset: W₂ norm (exp, thick, slate gray)
- Next d datasets: W₁ diagonals (theory, thin, dashed)
- Final dataset: W₂ norm (theory, thin, dashed)
- Theory datasets excluded from legend via a generateLabels filter
- Incremental caching with downsampling (MAX_PLOT_POINTS = 2000)
- EMA smoothing for loss plot
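The NormChart index arithmetic above can be made explicit with a hypothetical helper (names are illustrative; the real visualization.js may organize this differently):

```javascript
// Dataset index layout for NormChart, given dimension d.
// The cross-term dataset exists only when d >= 2, which shifts every later index.
function datasetLayout(d) {
  const crossOffset = d >= 2 ? 1 : 0;
  return {
    w1DiagExp: [...Array(d).keys()],                          // indices 0 .. d-1
    crossTermExp: d >= 2 ? d : null,                          // index d, only if d >= 2
    w2NormExp: d + crossOffset,                               // experiment W2 norm
    w1DiagTheory: [...Array(d).keys()].map(i => d + crossOffset + 1 + i),
    w2NormTheory: d + crossOffset + 1 + d,                    // final dataset
  };
}
```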
Theory Computation (core/ directory)
theory.js: Analytical theory calculations
- Computes c, β_eff, t_rise from parameters
- Used for predictions display and ODE initialization
Configuration
config.js: Constants and network visualization settings
Key Design Patterns
Theory Curve State Management
Two modes:
- Preview mode (before simulation starts):
- Theory updates live as sliders change
- Uses current slider values to compute theory
- Implemented in the app.js parameter callback and the charts.js checkbox handler
- Locked mode (after simulation has data):
- Theory stays frozen to captured parameters from simulation start
- Slider changes don’t affect theory curves
- Only show/hide existing theory data
Implementation check in app.js:

```javascript
const simulationHasData = simulation.model !== null &&
  simulation.lossHistory.length > 0;
if (!simulationHasData && appState.showTheory) {
  chartsManager.computeAndShowTheoryOnly();
}
```
Parameter Capture Pattern
When simulation starts (app.js):
- User clicks "start"
- simulation.captureParams(...) saves current slider values
- simulation.start() initializes the model and runs the ODE with captured params
- Theory is computed once and stored in theoryLossHistory and theoryNormHistory
- Future slider changes don't affect theory (until reset)
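The capture pattern above might be sketched as follows (class and field names beyond captureParams are illustrative); the deep copy is the key point, so later mutations of the live state cannot leak into the frozen theory parameters:

```javascript
// Sketch of the parameter-capture pattern: theory params freeze at start time.
class Simulation {
  constructor() {
    this.capturedParams = null;
    this.theoryLossHistory = [];
  }
  captureParams(state) {
    // Deep-copy plain data so later slider changes don't mutate the snapshot
    this.capturedParams = JSON.parse(JSON.stringify({
      d: state.d, k: state.k, alphas: state.alphas,
      gammas: state.gammas, eta: state.eta,
    }));
  }
}
```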
Adaptive Performance
SGD stepping (simulation.js):
- Target 25ms per frame (40 Hz)
- Estimates time per step with exponential moving average
- Dynamically adjusts steps per frame to meet budget
- Typical: 100-500 steps/frame depending on d, k
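The adaptive stepping described above could be sketched like this (the budget constant and EMA factor are illustrative values, not necessarily what simulation.js uses):

```javascript
// Fill a ~25 ms frame budget using a running estimate of per-step cost.
const FRAME_BUDGET_MS = 25;

function stepsForFrame(emaStepMs) {
  // Always take at least 1 step; otherwise fit as many as the budget allows
  return Math.max(1, Math.floor(FRAME_BUDGET_MS / emaStepMs));
}

function updateEma(emaStepMs, measuredMs, beta = 0.9) {
  // Exponential moving average of the measured per-step time
  return beta * emaStepMs + (1 - beta) * measuredMs;
}
```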
ODE integration (simulation.js, charts.js):
- Base dt = 0.01 for tMax ≤ 2000
- For tMax > 2000: dt = 0.01 × (tMax / 2000)
- Keeps computation constant (~200k steps regardless of tMax)
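The dt rule above amounts to a one-liner (constants taken from the text):

```javascript
// Scale dt with tMax above 2000 so the ODE step count stays ~200k.
function odeDt(tMax) {
  const baseDt = 0.01, baseTMax = 2000;
  return tMax <= baseTMax ? baseDt : baseDt * (tMax / baseTMax);
}
```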
Incremental Plotting Cache
IncrementalCache in visualization.js:
- Tracks last processed index to avoid reprocessing entire history
- Computes EMA incrementally (O(new points) not O(n))
- Downsamples to MAX_PLOT_POINTS for chart rendering
- Stores running max values for y-axis scaling
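The incremental-EMA part of the cache can be sketched as follows (a simplified stand-in, not the actual IncrementalCache class); only points past lastIndex are touched, giving the O(new points) cost noted above:

```javascript
// Incremental EMA over a growing history array.
class IncrementalEma {
  constructor(alpha = 0.05) {
    this.alpha = alpha;     // smoothing factor
    this.lastIndex = 0;     // first unprocessed history index
    this.value = null;      // current EMA value
    this.smoothed = [];     // EMA series, parallel to history
  }
  update(history) {
    // Process only the points appended since the last call
    for (; this.lastIndex < history.length; this.lastIndex++) {
      const x = history[this.lastIndex];
      this.value = this.value === null
        ? x
        : this.alpha * x + (1 - this.alpha) * this.value;
      this.smoothed.push(this.value);
    }
    return this.smoothed;
  }
}
```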
Integration with Low-Dim-Flows
The project uses the low-dim-flows ODE solver (sibling directory) to compute theory predictions:
Import: import { solveODE } from '../low-dim-flows/api.js'
Usage in simulation.js and charts.js:

```javascript
const result = solveODE(a0Vec, kVec, tMax, fStar, c, dt);
// Returns: { times, aTrajectories, lossValues }
```
Mapping:
- a0Vec = [1, 1, ..., 1, 1e-10]: initial values for the W₁ diagonals plus the W₂ bias
- kVec = [...alphas, 1]: exponents for each parameter
- tMax = 2 × t_rise: simulation time
- Theory trajectories map to the experiment:
- aTrajectories[i] (i < d) → W₁ diagonal element i
- aTrajectories[d] → W₂ norm (bias term b, shown as b)
Recent change: Removed HARD_CAP override in low-dim-flows so caller’s tMax is respected
X-Axis Modes
- Step mode: x-axis shows iteration count (discrete steps)
- t_eff mode: x-axis shows effective time = η × step (continuous time)
Both experiment and theory use the same x-axis transformation for consistency.
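The shared transform is trivial but worth pinning down, since both experiment and theory series must go through it (function name is illustrative):

```javascript
// Map a step index to the chart's x value for the current axis mode.
function xValue(step, eta, mode) {
  return mode === "t_eff" ? eta * step : step; // t_eff = η × step
}
```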
Important Gotchas
- Dataset indexing in NormChart: Must account for crossTermOffset when d < 2
- Theory curve visibility: Controlled by both
showTheoryflag ANDsimulationHasDatastate - Legend filtering: Theory datasets have “theory” in label → excluded via
generateLabels - Y-axis max: Must include theory range when
showTheoryis true - EMA initialization: Must match initial parameter values (W₁ = 1, W₂ = 0)
- Unicode subscripts: Use
String.fromCharCode(0x2080 + digit)for legend labels - W₁ cross-term: Only tracked/displayed when d ≥ 2, stored as
w1_cross -
W₂ norm display: Shows √k· W₂ (scaled), but theory shows just b (unscaled)
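The Unicode-subscript gotcha in action (helper name is illustrative; U+2080 is SUBSCRIPT ZERO, so adding the digit yields ₀–₉):

```javascript
// Build a legend label with a Unicode subscript digit, e.g. ("a", 2) → "a₂".
function labelWithSubscript(name, digit) {
  return name + String.fromCharCode(0x2080 + digit);
}
```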
Performance Characteristics
Typical performance (d=3, k=1000, batch=100):
- ~300 steps/sec
- ~1.5ms per SGD step
- ~0.3ms per norm computation
- ~0.5ms per chart update
- Total: ~25ms per frame (40 FPS)
Bottlenecks:
- Forward/backward pass scales with k
- Norm computation (Frobenius norm of W₂)
- History array appends (becomes slow after ~100k points)
Optimizations:
- Incremental cache reduces chart update from O(n) to O(new points)
- Downsampling keeps chart at 2000 points max
- No intermediate allocations in hot loops
- Adaptive stepping maximizes training per frame