Neuroplastic Transfer Learning: LNN / Transformer Hybrids

This is an original experiment of mine, and I wanted to share the architecture. It's not my latest work, but it shows the direction I'm heading: teaching a Liquid Time-Constant (LTC) neuron to take the place of the perceptron, the attention heads, and the final hidden layers, and eventually to learn to "be" a transformer, with a twist: it continues to learn.

This experiment doesn't use deep layers, so the learning is surface-level. Much of my subsequent work shows that learning slowly, or allowing internal reflection on what has been learned, is important with LTC neurons.

Some of the work here turned out to be prescient: the human brain relies on localized synchronization in time and works in groups. My concept was to take recent inferences from the transformer side, cache the tokens in a container, and hold them there until the liquid neurons could replicate what was learned. Since the embeddings are laid out in predictable blocks, I envisioned a shared memory manager and threads with hashed identities holding pointers to a "cube" of embeddings.
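
To make that concrete, here is a minimal sketch of what such a container could look like. The class names, fields, and hashing scheme are illustrative assumptions, not code from the actual experiment:

import hashlib
from dataclasses import dataclass

import torch

@dataclass
class EmbeddingCube:
    """A cached block of recent transformer activations awaiting replication by the LNN."""
    block_id: str              # hashed identity used by worker threads
    embeddings: torch.Tensor   # [num_tokens, embed_dim] block of token embeddings
    outputs: torch.Tensor      # the transformer outputs the liquid neurons should learn to reproduce
    replicated: bool = False   # flipped once the liquid neurons match the cached outputs

class SharedCubeManager:
    """Shared memory manager mapping hashed identities to cubes of embeddings."""
    def __init__(self):
        self.cubes: dict[str, EmbeddingCube] = {}

    def cache(self, embeddings: torch.Tensor, outputs: torch.Tensor) -> str:
        block_id = hashlib.sha1(embeddings.cpu().numpy().tobytes()).hexdigest()[:16]
        self.cubes[block_id] = EmbeddingCube(block_id, embeddings, outputs)
        return block_id

    def pending(self) -> list:
        # cubes the liquid neurons still need to learn from
        return [c for c in self.cubes.values() if not c.replicated]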

The cubes of data would, in theory, do transfer learning from the transformer over time and predict the outputs based on past activations. In this hybrid architecture, the LNN gradually influences the transformer's outputs: it leverages the transformer's strengths with attention while its neuroplastic properties let it adapt over time and make new predictions. My experiments on Hebbian layers, and on the high-precision floating point required for deep learning, came later.

Future plans

The original design was for the liquid neurons to take over inference for a cube's block: once mentored, the gate would switch off and the block would become a fully mentored LNN matrix.
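
A sketch of how that gated handover could be wired. The class, the distillation loss, and the annealing schedule below are hypothetical, not taken from the experiment:

import torch
import torch.nn as nn

class MentoredBlock(nn.Module):
    """Blends transformer and LNN outputs for one block; alpha -> 0 hands the block to the LNN."""
    def __init__(self, transformer_block: nn.Module, lnn_block: nn.Module):
        super().__init__()
        self.transformer_block = transformer_block
        self.lnn_block = lnn_block
        self.register_buffer("alpha", torch.tensor(1.0))  # 1.0 = pure transformer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t_out = self.transformer_block(x)
        l_out = self.lnn_block(x)
        # distillation target: the LNN is "mentored" toward the transformer's behaviour
        self.distill_loss = nn.functional.mse_loss(l_out, t_out.detach())
        return self.alpha * t_out + (1 - self.alpha) * l_out

    def anneal(self, step_size: float = 0.01):
        """Call when distill_loss is low enough; eventually switches the gate off entirely."""
        self.alpha = torch.clamp(self.alpha - step_size, min=0.0)

The gate starts fully open to the transformer and is annealed toward zero as the liquid block learns to reproduce the cached outputs.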

Cubes would also address a limitation of LNNs: the whole block of neurons normally has to be computed together. By breaking inference into chunks and processing the chunks in parallel, you essentially get a distributed LNN that can take advantage of multiprocessing and threading. This was refined later on, but it makes for an interesting demonstration of how that might happen.
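
A rough illustration of the chunking idea, with the hidden state split into independent slices that could each be stepped by its own worker. The chunking scheme is my assumption; a real version would still have to handle coupling between chunks:

import torch

def step_chunked_lnn(ltc_cells, x_chunks, h_chunks, t):
    """Step several independent LTC chunks; each could run in its own thread or process."""
    new_states = []
    for cell, x_c, h_c in zip(ltc_cells, x_chunks, h_chunks):
        h_new, _ = cell(x_c, h_c, t)   # each chunk only sees its own slice of the state
        new_states.append(h_new)
    return new_states

# e.g. a 64-dim hidden state split into 4 chunks, each handled by its own LiquidTimeConstant cell:
# cells = [LiquidTimeConstant(16, 16) for _ in range(4)]
# x_chunks = torch.randn(8, 64).chunk(4, dim=-1)
# h_chunks = torch.zeros(8, 64).chunk(4, dim=-1)
# states = step_chunked_lnn(cells, x_chunks, h_chunks, torch.zeros(8))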

This Python implementation is more or less a proof of concept leading to the Kotlin port, which will use Kotlin's actor and channel system to allow a LangChain-like interface design with "virtual dendrites" and virtual cortical connections between blocks.

The blocks would then feed hybridized LNN/Transformer outputs forward to other blocks, and so on, similar to the layering built up in the human brain, leveraging a sort of phonological loop of reflection that strengthens the learning. This part of the design is still pure theory.

Introduction

This paper explores the integration of Transformer and Liquid Neural Network (LNN) architectures, building upon foundational work by Hasani et al. (2020). The key innovation lies in combining the parallel processing capabilities of Transformers with the adaptive dynamics of LNNs.

Mathematical Framework

1. Core LNN Dynamics

Building on recent advances in closed-form continuous-time neural networks (Hasani et al., 2022), we start with the fundamental LNN equation (apologies for the lack of mathematical equation support on LinkedIn):

dx(t)/dt = −[1/τ + f(x(t), I(t), t, θ)] x(t) + f(x(t), I(t), t, θ) A

Which has the approximate closed-form solution:

x(t) ≈ (x0 − A) e^(−(w_τ + f(I(t), θ)) t) f(−I(t)) + A

This formulation provides several key advantages:

  1. Eliminates the need for ODE solvers while maintaining continuous-time properties.
  2. Enables explicit time-dependent gating mechanisms.
  3. Allows direct optimization of temporal dynamics.

Where:

  • x(t): the neuron's hidden state at time t
  • I(t): the input signal
  • τ: the time constant controlling how quickly the state decays
  • f(·, θ): a learned nonlinearity with parameters θ
  • A: a learnable bias that the state relaxes toward
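
To make the closed-form concrete, here is a small PyTorch sketch of that approximation. The tensor shapes, the sigmoid choice for f, and the function name are my assumptions, not code from the experiment:

import torch

def cfc_state(x0, A, w_tau, t, I, f=torch.sigmoid):
    """Approximate closed-form LTC state at time t, with no ODE solver in the loop."""
    decay = torch.exp(-(w_tau + f(I)) * t)    # input-dependent exponential decay
    return (x0 - A) * decay * f(-I) + A       # relax from x0 toward the bias A

# e.g. a batch of 8 states of width 16:
# x_t = cfc_state(x0=torch.zeros(8, 16), A=torch.randn(16),
#                 w_tau=torch.ones(16), t=torch.tensor(0.5), I=torch.randn(8, 16))
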
2. Bounded Dynamics Proof

Key theorem from Hasani et al. (2020): the neuron states and effective time constants of an LTC network remain bounded for all time.

This guarantees stable system dynamics even with unbounded inputs.

Proposed Architecture Enhancement

Integration with Transformer Blocks

My approach involves implementing this through:

  1. Continuous-Time Gating: dx(t)/dt = −[1/τ + f(x(t), I(t), t, θ)] x(t) + f(x(t), I(t), t, θ) A
  2. Memory Cubes Architecture:

Where:

  • P: perceptron layer
  • F: feed-forward network
  • O: output transformation

Cube-to-Cube Communication:

C_ij = f(W_i C_i + W_j C_j)
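
In code, that cube-to-cube update could look something like the following sketch, assuming each cube state C is a flat tensor and the W matrices are learned linear maps:

import torch
import torch.nn as nn

class CubeLink(nn.Module):
    """Combines the states of two neighbouring cubes: C_ij = f(W_i C_i + W_j C_j)."""
    def __init__(self, dim: int):
        super().__init__()
        self.w_i = nn.Linear(dim, dim, bias=False)
        self.w_j = nn.Linear(dim, dim, bias=False)

    def forward(self, c_i: torch.Tensor, c_j: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.w_i(c_i) + self.w_j(c_j))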

Dynamic Time Constants

Building on Hasani's work, we introduce variable time constants:

τ_eff = τ_0 + τ · f(x(t), I(t), t, θ)

This allows for adaptive computation based on input complexity.
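
A minimal PyTorch rendering of that adaptive time constant (the names τ_0 and τ_eff, and the use of a sigmoidal f, are assumptions):

import torch

def effective_tau(tau_0: torch.Tensor, tau: torch.Tensor, f_out: torch.Tensor) -> torch.Tensor:
    """τ_eff = τ_0 + τ·f(·): inputs that drive f higher slow the state's dynamics."""
    return tau_0 + tau * f_out

# with a sigmoidal f, f_out lies in (0, 1), so τ_0 < τ_eff < τ_0 + τ
# tau_eff = effective_tau(torch.full((64,), 0.5), torch.ones(64), torch.sigmoid(torch.randn(8, 64)))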

Experimental Validation

Our implementation builds on the performance characteristics reported by Hasani et al. for time-series prediction and computational efficiency. On the efficiency side, the key ingredient is their fused ODE solver step:

fused_solver_step(x, I, Δt, θ) = (x + Δt f(x, I, t, θ) A) / (1 + Δt (1/τ + f(x, I, t, θ)))
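
A minimal PyTorch rendering of that update. The argument names and shapes are my assumptions; f_val stands for the precomputed value of f(x, I, t, θ):

import torch

def fused_solver_step(x: torch.Tensor, f_val: torch.Tensor,
                      A: torch.Tensor, tau: torch.Tensor, dt: float) -> torch.Tensor:
    """One fused (semi-implicit Euler) step of the LTC ODE, avoiding an explicit ODE solver."""
    return (x + dt * f_val * A) / (1 + dt * (1.0 / tau + f_val))

# e.g. x, f_val of shape [batch, hidden]; A, tau of shape [hidden]
# x_next = fused_solver_step(x, torch.sigmoid(backbone_out), A, tau, dt=0.1)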

Theoretical Foundations

Building on established theory from Hasani et al., we prove:

Theorem 1: The combined Transformer-LNN system maintains bounded stability under the conditions:

  1. ∀i: τ_i > 0
  2. ∥W_i∥ ≤ K for some constant K
  3. f is Lipschitz continuous

Proof sketch:

Given the Lyapunov function V(x) = ∥x∥², we show: dV/dt ≤ 0 under the above conditions.
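
Filling in the elided step, written compactly for the elementwise dynamics and assuming f is a bounded, non-negative nonlinearity (as with the sigmoidal gating used here):

dV/dt = 2 xᵀ (dx/dt) = −2 [1/τ + f] ∥x∥² + 2 f xᵀ A

The first term is negative and scales with ∥x∥², while the second scales only with ∥x∥, so dV/dt < 0 whenever ∥x∥ > τ · f_max · ∥A∥. Every trajectory therefore stays within a bounded region, which is the claimed bounded stability.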

Appendix: Implementation Guidelines

Stability Constraints

For system stability, maintain:

  1. 0 < τ_min ≤ τ_i ≤ τ_max
  2. ∥W_i∥ ≤ √(2/n), where n is the input dimension
  3. ∥A∥ ≤ A_max
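
One hypothetical way to enforce these bounds in training is to project the parameters back into the feasible region after each optimizer step. The constants and the projection below are illustrative, not part of the notebook:

import torch

TAU_MIN, TAU_MAX, A_MAX = 0.1, 10.0, 5.0   # assumed bounds

def enforce_stability(ltc, input_dim: int):
    """Project LTC parameters back into the stable region after an optimizer step."""
    with torch.no_grad():
        ltc.tau.clamp_(TAU_MIN, TAU_MAX)        # 0 < tau_min <= tau_i <= tau_max
        ltc.A.clamp_(-A_MAX, A_MAX)             # bound the bias term (elementwise here)
        max_norm = (2.0 / input_dim) ** 0.5     # ||W_i|| <= sqrt(2/n)
        for lin in (ltc.state_net_g, ltc.state_net_h):
            norm = lin.weight.norm()
            if norm > max_norm:
                lin.weight.mul_(max_norm / norm)

# typically called right after optimizer.step():
# enforce_stability(model.ltc, input_dim=HIDDEN_SIZE)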

Implementation

I am primarily using a Transformer architecture combined with a novel liquid neural network approach. Specifically, the model uses several key Transformer components:

self.attention = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout)        

The innovative aspect is that I'm combining these Transformer components with Liquid Neural Networks (from Hasani et al.'s work), yielding a novel architecture that uses continuous-time neural dynamics.

Key components from our implementation:

class TransformerLNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads=4):
        super().__init__()
        # Transformer components
        self.attention = nn.MultiheadAttention(...)

        # LNN components
        self.ltc = LiquidTimeConstant(...)        

This is a hybrid architecture that combines:

  • Traditional Transformer attention mechanisms
  • Continuous-time neural dynamics from LNNs
  • Time-dependent gating mechanisms

So while "Transformer" is the base architecture, I'm extending it significantly with the liquid neural network approach for handling temporal dynamics.

I'll break down each section of my experiment and explain what it's doing and how:

Data Generation & Processing

Generate synthetic data with multiple patterns

def generate_synthetic_data(num_samples=1000, seq_length=50, input_dim=10):
    # Create time steps
    t = torch.linspace(0, 10, seq_length)

    # Generate diverse patterns (sine waves with different frequencies)
    patterns = []
    for i in range(input_dim):
        freq1, freq2 = (i + 1) * 0.5, (i + 1) * 0.25
        pattern = torch.sin(2 * np.pi * freq1 * t) + \
                 0.5 * torch.sin(2 * np.pi * freq2 * t)
        patterns.append(pattern)        

What: Creates synthetic sequential data with multiple overlapping patterns.
How: Combines sine waves of different frequencies to create complex temporal patterns.
Why: Provides controlled data to test the model's ability to learn temporal dependencies.

Model Architecture

class TransformerLNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads):
        super().__init__()
        # Transformer component for parallel processing
        self.attention = nn.MultiheadAttention(...)

        # LNN component for temporal dynamics
        self.ltc = LiquidTimeConstant(...)
        

What: Hybrid architecture combining Transformer attention with liquid neural networks.
How:

  • Uses attention to process sequences in parallel
  • Uses LNN for continuous-time modeling of temporal dependencies
  • Combines both through a shared backbone network

Training Loop

for epoch in range(NUM_EPOCHS):
    for batch_idx, (x, y) in enumerate(train_loader):
        # Forward pass
        y_pred = model(x)
        loss = criterion(y_pred, y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Visualization updates
        if batch_idx % visualizer.save_intervals == 0:
            update_visualizations()        

What: Trains the model while tracking multiple metrics.
How:

  • Processes batches of sequences
  • Updates model parameters using gradient descent
  • Collects performance metrics and model states
  • Updates visualizations in real-time

The key innovation here is how it combines:

  • Transformer's parallel processing (attention)
  • LTC's continuous time dynamics (liquid state)
  • Time-dependent gating to balance both

LiquidTimeConstant (LTC) Class

Core components

  • Backbone network: Takes combined input+hidden state, processes through 2 linear layers with tanh
  • Three specialized networks:
      • time_net: Controls temporal dynamics
      • state_net_g: Transforms state for short-term memory
      • state_net_h: Transforms state for long-term memory
  • Learnable parameters:
      • tau: Time constants (initialized to ones)
      • A: Bias terms (initialized randomly)

Forward pass:

  1. Combines input and hidden state
  2. Processes through backbone network
  3. Computes time-constant factor using sigmoid
  4. Applies state transformations
  5. Uses time-dependent gating to blend states

import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Tuple

class LiquidTimeConstant(nn.Module):
    """
    Implements time-aware processing with dynamic memory
    Like a smart system that knows how to balance past and present information
    """
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # BACKBONE NETWORK
        # Core feature extraction - like a smart filter
        self.backbone = nn.Sequential(
            nn.Linear(input_size + hidden_size, hidden_size),
            nn.Tanh(),  # Smooth activation for stability
            nn.Linear(hidden_size, hidden_size)
        )
        
        # TEMPORAL PROCESSING NETWORKS
        # Time-aware processing mechanisms
        
        # Controls how time influences processing
        self.time_net = nn.Linear(hidden_size, hidden_size)
        
        # Transform state for short-term memory
        self.state_net_g = nn.Linear(hidden_size, hidden_size)
        
        # Transform state for long-term memory
        self.state_net_h = nn.Linear(hidden_size, hidden_size)
        
        # LEARNABLE PARAMETERS
        # Time constants - control information flow speed
        self.tau = nn.Parameter(torch.ones(hidden_size))
        # Bias terms - help network learn better
        self.A = nn.Parameter(torch.randn(hidden_size))

    def forward(self, 
                x: torch.Tensor,  # Current input
                h: torch.Tensor,  # Memory state
                t: torch.Tensor   # Time information
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates memory state based on new information and time
        Like dynamically updating notes as you get new information
        """
        # Combine current input with memory
        combined = torch.cat([x, h], dim=-1)
        # Extract features
        features = self.backbone(combined)
        
        # TEMPORAL GATING
        # Compute how much time should influence the update
        f_t = torch.sigmoid(self.time_net(features))
        
        # STATE TRANSFORMATIONS
        # Process for different timescales
        g_x = self.state_net_g(features)  # Short-term
        h_x = self.state_net_h(features)  # Long-term
        
        # TIME-DEPENDENT GATING
        # Decide how to blend different timescales
        gate = torch.sigmoid(-f_t * t.view(-1, 1))
        
        # MEMORY UPDATE
        # Blend short and long-term information
        h_new = gate * g_x + (1 - gate) * h_x
        
        return h_new, h_new        

TransformerLNN Class

Core components

  • input_proj: Projects input to higher dimension
  • attention: Multi-head attention for parallel processing
  • ltc: Liquid time-constant layer
  • output_proj: Projects back to input dimension
  • Two layer normalizations for stability

Forward pass:

  1. Projects input to higher dimension
  2. Applies self-attention and normalizes
  3. Processes the sequence through the LTC: maintains a running state, updates it at each time step, and collects the outputs
  4. Combines and normalizes outputs
  5. Projects back to original dimension
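
The full TransformerLNN class never appears verbatim in this article, so here is a minimal sketch that is consistent with the description above and with how the model is used later (input_proj, attention, ltc, output_proj, get_attention_weights, hidden_size). The residual connections, the batch_first setting, and the per-timestep loop details are my assumptions:

import torch
import torch.nn as nn

class TransformerLNN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.hidden_size = hidden_size

        # Project raw features up into the model dimension
        self.input_proj = nn.Linear(input_size, hidden_size)

        # Transformer component: parallel mixing across the sequence
        self.attention = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout, batch_first=True)

        # LNN component: the LiquidTimeConstant cell defined above
        self.ltc = LiquidTimeConstant(hidden_size, hidden_size)

        # Project back down to the original feature dimension
        self.output_proj = nn.Linear(hidden_size, input_size)

        # Two layer normalizations for stability
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        h = self.input_proj(x)

        # Self-attention over the whole sequence, residual connection, then norm
        attn_out, _ = self.attention(h, h, h)
        h = self.norm1(h + attn_out)

        # Step the LTC through the sequence with a running state
        state = torch.zeros(batch, self.hidden_size, device=x.device)
        outputs = []
        for t in range(seq_len):
            t_val = torch.full((batch,), float(t), device=x.device)
            state, _ = self.ltc(h[:, t], state, t_val)
            outputs.append(state)
        ltc_out = torch.stack(outputs, dim=1)

        # Combine, normalize, and project back to the input dimension
        return self.output_proj(self.norm2(ltc_out))

    def get_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Expose the attention map for the visualizations used later."""
        h = self.input_proj(x)
        _, weights = self.attention(h, h, h, need_weights=True)
        return weights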

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import seaborn as sns
from typing import List, Dict, Optional
from datetime import datetime

class TrainingVisualizer:

    def __init__(self, save_intervals: int = 10, plot_intervals: int = 5):
        self.save_intervals = save_intervals
        self.plot_intervals = plot_intervals

        # History tracking
        self.losses: List[float] = []
        self.accuracies: List[float] = []
        self.ltc_states: List[torch.Tensor] = []
        self.attn_weights: List[torch.Tensor] = []
        self.timestamps: List[str] = []

    def update_metrics(self, loss: float, accuracy: float, 
                      ltc_state: torch.Tensor, attn_weights: torch.Tensor):

        """Update training metrics"""

        self.losses.append(loss)
        self.accuracies.append(accuracy)
        self.ltc_states.append(ltc_state.detach().cpu())
        self.attn_weights.append(attn_weights.detach().cpu())
        self.timestamps.append(datetime.now().strftime("%H:%M:%S"))

    def plot_training_progress(self, epoch: int, batch_idx: int, x: torch.Tensor, y: torch.Tensor, y_pred: torch.Tensor):
        """Plot comprehensive training visualizations"""
        # Convert tensors to 2D by taking first channel/feature
        x_plot = x[:, :, 0].cpu().detach().numpy()  # Take first feature dimension
        y_plot = y[:, :, 0].cpu().detach().numpy()
        y_pred_plot = y_pred[:, :, 0].cpu().detach().numpy()
     
        clear_output(wait=True)
        plt.figure(figsize=(20, 12))
      
        # Plot actual visualizations using 2D data
        plt.subplot(231)
        plt.plot(self.losses[-100:], label='Training Loss', color='blue')
        plt.title(f'Loss Curve (Current: {self.losses[-1]:.4f})')
        plt.xlabel('Last 100 Steps')
        plt.ylabel('Loss')
        plt.grid(True)

        # Use the 2D versions for prediction plots
        plt.subplot(235)
        idx = 0  # Plot first sequence in batch
        plt.plot(y_plot[idx], label='Ground Truth', color='blue', alpha=0.7)
        plt.plot(y_pred_plot[idx], label='Prediction', color='red', alpha=0.7)
        plt.title('Prediction vs Ground Truth')
        plt.xlabel('Sequence Step')
        plt.ylabel('Value')
        plt.legend()
        plt.grid(True)

    @staticmethod
    def generate_synthetic_data(num_samples: int = 1000,
                            seq_length: int = 50,
                            input_dim: int = 10,
                            device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
                            ) -> tuple[torch.Tensor, torch.Tensor]:

        """Generate synthetic sequence data with multiple patterns"""

        # Time steps

        t = torch.linspace(0, 10, seq_length, device=device)
        t = t.view(1, -1, 1).repeat(num_samples, 1, input_dim)     

        # Generate diverse patterns

        patterns = []

        for i in range(input_dim):
 
            # Combine sine waves with different frequencies

            freq1 = (i + 1) * 0.5
            freq2 = (i + 1) * 0.25

            pattern = torch.sin(2 * np.pi * freq1 * t[..., i]) + \
                    0.5 * torch.sin(2 * np.pi * freq2 * t[..., i])

            patterns.append(pattern)      

        # Combine patterns

        x = torch.stack(patterns, dim=-1)

        # Create target with transformations

        y = torch.roll(x, shifts=-1, dims=1) * 1.5 + 0.5
        return x, y

    
    @staticmethod
    def train_epoch(model: nn.Module,
                    train_loader: torch.utils.data.DataLoader,
                    criterion: nn.Module,
                    optimizer: torch.optim.Optimizer,
                    device: str,
                    visualizer,
                    epoch: int) -> float:

        """Train for one epoch"""

        model.train()
        total_loss = 0

        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()           

            # Forward pass

            y_pred = model(x)
            loss = criterion(y_pred, y)       

            # Backward pass

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Update visualizations

            if batch_idx % visualizer.save_intervals == 0:
                with torch.no_grad():
                    
                    # Get model states

                    ltc_state, _ = model.ltc(
                        model.input_proj(x[0:1])[:, 0, :],  # first timestep only: shape [1, hidden_size]
                        torch.zeros(1, model.hidden_size, device=device),
                        torch.zeros(1, device=device)
                    )

                    attn_weights = model.get_attention_weights(x[0:1])

                    # Compute accuracy

                    mse = ((y_pred - y) ** 2).mean().item()
                    accuracy = 100 * (1 - min(mse, 1))                  

                    # Update visualizer

                    visualizer.update_metrics(loss.item(), accuracy, ltc_state, attn_weights)
     
            # Plot progress

            if batch_idx % visualizer.plot_intervals == 0:
                visualizer.plot_training_progress(epoch, batch_idx, x, y, y_pred)

        return total_loss / len(train_loader)        


Setup Training Parameters

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline


# Model parameters

INPUT_SIZE = 10
HIDDEN_SIZE = 64
NUM_HEADS = 4



# Training parameters

BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_EPOCHS = 50


# Data parameters

NUM_SAMPLES = 1000
SEQ_LENGTH = 50

# Device configuration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {device}")        



Generate Training Data

# Generate training data

x_train, y_train = TrainingVisualizer.generate_synthetic_data(

    num_samples=NUM_SAMPLES,
    seq_length=SEQ_LENGTH,
    input_dim=INPUT_SIZE,
    device=device

)

# Create dataloader

train_dataset = torch.utils.data.TensorDataset(x_train, y_train)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Plot sample sequence

plt.figure(figsize=(12, 4))
plt.plot(x_train[0, :, 0].cpu().numpy(), label='Input')
plt.plot(y_train[0, :, 0].cpu().numpy(), label='Target')
plt.title('Sample Sequence')
plt.legend()
plt.grid(True)
plt.show()        

Create and Initialize Model

# Initialize model

model = TransformerLNN(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_heads=NUM_HEADS

).to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Initialize visualizer

visualizer = TrainingVisualizer()        

Training Loop with Visualization

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0

    for batch_idx, (x, y) in enumerate(train_loader):
        
        # Forward pass
        y_pred = model(x)
        loss = criterion(y_pred, y)
    
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Update visualization
        if batch_idx % visualizer.save_intervals == 0:
            
            with torch.no_grad():
                # Get model states
                batch_input = x[0:1]  # Shape: [1, seq_len, input_dim]
                projected = model.input_proj(batch_input)  # Shape: [1, seq_len, hidden_size]

                # Take first timestep for LTC state
                ltc_input = projected[:, 0, :]  # Shape: [1, hidden_size]
                ltc_state, _ = model.ltc(
                    ltc_input,
                    torch.zeros(1, HIDDEN_SIZE, device=device),
                    torch.zeros(1, device=device)
                )

                attn_weights = model.get_attention_weights(batch_input)

                # Compute accuracy
                mse = ((y_pred - y) ** 2).mean().item()
                accuracy = 100 * (1 - min(mse, 1))

                # Update visualizer
                visualizer.update_metrics(loss.item(), accuracy, ltc_state, attn_weights)

        # Plot progress

        if batch_idx % visualizer.plot_intervals == 0:
            visualizer.plot_training_progress(epoch, batch_idx, x, y, y_pred)

    # End of epoch summary
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Average Loss: {avg_loss:.4f}")        

Final Evaluation

# Generate test data

x_test, y_test = TrainingVisualizer.generate_synthetic_data(
    num_samples=10,
    seq_length=SEQ_LENGTH,
    input_dim=INPUT_SIZE,
    device=device

)

# Model evaluation

model.eval()

with torch.no_grad():
    y_pred = model(x_test)
    test_loss = criterion(y_pred, y_test)

    # Plot predictions
    plt.figure(figsize=(15, 5))

    for i in range(3):  # Plot first 3 sequences
        plt.subplot(1, 3, i+1)
        plt.plot(y_test[i, :, 0].cpu().numpy(), label='Ground Truth', alpha=0.7)
        plt.plot(y_pred[i, :, 0].cpu().numpy(), label='Prediction', alpha=0.7)
        plt.title(f'Sequence {i+1}')
        plt.legend()
        plt.grid(True)
    plt.tight_layout()

    plt.show()

print(f"Final Test Loss: {test_loss:.4f}")        

Analyze Learning Progression

# Plot training history
plt.figure(figsize=(15, 5))

# Loss progression
plt.subplot(131)
plt.plot(visualizer.losses, label='Training Loss')
plt.title('Loss Progression')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.grid(True)

# Accuracy progression
plt.subplot(132)
plt.plot(visualizer.accuracies, label='Accuracy')
plt.title('Accuracy Progression')
plt.xlabel('Step')
plt.ylabel('Accuracy (%)')
plt.grid(True)

# Final attention pattern
plt.subplot(133)

# Extract a 2D slice from the 3D tensor
attn_weights_2d = visualizer.attn_weights[-1][0].cpu().numpy()

# Plot the 2D attention weights
sns.heatmap(attn_weights_2d, cmap='viridis')
plt.title('Final Attention Pattern')

plt.tight_layout()
plt.show()        

View the whole project on Kaggle - this is open source, so feel free to use, modify, sell or whatever (MIT license). There are no warranties here. This is just lab work.

