Disentangler
The Disentangler Model is a specialized audio source separation system designed for AGI-compatible self-supervised learning with optional supervised guidance. It uses dedicated voice and noise encoders to extract rich, domain-specific features for high-quality audio separation.
- Primary: Strong self-supervised learning (NTXENT + BYOL) for AGI compatibility
- Secondary: Supervised guidance using clean voice/noise targets for audio quality
- Specialized Processing: Voice-specific and noise-specific feature extraction
- Multiple Views: Batch vs streaming variants for different use cases
- Sample Rate: 16,000 Hz
- Hop Length: 80 samples (~5ms frames)
- Encoder Output Frame Rate: ~400 Hz (temporal dimension after encoding)
- Input Format: [B, 1, T] (batch, mono, time)
- Temporal Resolution: ~2.5ms per embedding frame
```
Voice Embeddings: 160-dim total
├── Pitch Features: 32-dim (F0, harmonic tracking)
├── Harmonic Features: 32-dim (harmonic template matching)
├── Formant Features: 32-dim (vocal tract resonances)
├── VAD Features: 16-dim (voice activity detection)
└── Spectral Features: 48-dim (multi-scale voice patterns)

Noise Embeddings: 144-dim total
├── Broadband Features: 32-dim (non-harmonic spectral)
├── Transient Features: 24-dim (onset/impact detection)
├── Environmental Features: 28-dim (environmental sounds)
├── Texture Features: 20-dim (noise pattern analysis)
├── Non-harmonic Features: 24-dim (non-periodic content)
└── Statistical Features: 16-dim (temporal statistics)
```
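Concretely, each embedding is the channel-wise concatenation of these component features. A minimal sketch with dummy tensors (batch size and frame count are illustrative, not tied to the actual encoder API):

```python
import torch

# Illustrative only: per-component feature maps of shape [B, C, frames],
# concatenated along the channel axis to form the full embeddings.
B, frames = 4, 400  # hypothetical batch size and frame count

voice_parts = {
    "pitch": torch.randn(B, 32, frames),
    "harmonic": torch.randn(B, 32, frames),
    "formant": torch.randn(B, 32, frames),
    "vad": torch.randn(B, 16, frames),
    "spectral": torch.randn(B, 48, frames),
}
noise_parts = {
    "broadband": torch.randn(B, 32, frames),
    "transient": torch.randn(B, 24, frames),
    "environmental": torch.randn(B, 28, frames),
    "texture": torch.randn(B, 20, frames),
    "non_harmonic": torch.randn(B, 24, frames),
    "statistical": torch.randn(B, 16, frames),
}

voice_embeddings = torch.cat(list(voice_parts.values()), dim=1)  # [B, 160, frames]
noise_embeddings = torch.cat(list(noise_parts.values()), dim=1)  # [B, 144, frames]
assert voice_embeddings.shape[1] == 160 and noise_embeddings.shape[1] == 144
```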
Batch Encoders (Default)
- Purpose: Training and high-quality inference
- Parameters: ~8.4M total parameters
- Features:
- Higher temporal resolution (ultra-high-res with hop_length//2)
- Better spectral resolution (1024+ FFT vs 512 FFT)
- Enhanced temporal modeling (kernels: 7-11 vs 3-5)
- More sophisticated analysis (12 harmonics vs 8)
Streaming Encoders
- Purpose: Real-time deployment
- Parameters: ~4.3M total parameters
- Features:
- Lower latency processing
- Smaller memory footprint
- Compatible output dimensions
- Optimized for real-time inference
Pitch Features (32-dim)
- Multi-resolution pitch analysis: 4 scales (201-1201 sample kernels)
- Frequency range: 65-400 Hz (human voice fundamentals)
- Temporal modeling: Enhanced continuity tracking
- Output: Pitch confidence and F0 trajectory features
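A hedged sketch of what the multi-resolution analysis above could look like: parallel Conv1d branches with kernels in the 201-1201 sample range whose outputs concatenate to the 32-dim pitch features. The module name, per-branch channel count, and stride are assumptions, not the actual encoder code:

```python
import torch
import torch.nn as nn

class MultiResolutionPitchBank(nn.Module):
    """Illustrative multi-scale Conv1d bank; kernel sizes follow the 201-1201 sample range."""
    def __init__(self, out_channels_per_scale=8, kernel_sizes=(201, 401, 801, 1201), hop=80):
        super().__init__()
        self.branches = nn.ModuleList([
            nn.Conv1d(1, out_channels_per_scale, kernel_size=k, stride=hop, padding=k // 2)
            for k in kernel_sizes
        ])

    def forward(self, x):  # x: [B, 1, T]
        feats = [torch.relu(branch(x)) for branch in self.branches]
        min_frames = min(f.shape[-1] for f in feats)   # align frame counts across scales
        feats = [f[..., :min_frames] for f in feats]
        return torch.cat(feats, dim=1)                 # [B, 32, frames] pitch features

pitch_feats = MultiResolutionPitchBank()(torch.randn(2, 1, 16000))
```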
Harmonic Features (32-dim)
- Harmonic templates: 8-12 learnable harmonic patterns
- STFT resolution: 512-1024 FFT
- Template matching: Spectral correlation with harmonic structures
- Output: Harmonic strength and voice characterization
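A rough sketch of harmonic template matching under these settings: learnable spectral templates are correlated with STFT magnitude frames, then projected to the 32-dim harmonic features. Class and parameter names are assumptions:

```python
import torch
import torch.nn as nn

class HarmonicTemplateMatcher(nn.Module):
    """Illustrative: correlate STFT magnitude frames against learnable harmonic templates."""
    def __init__(self, n_fft=1024, hop=80, n_templates=12, out_dim=32):
        super().__init__()
        self.n_fft, self.hop = n_fft, hop
        self.templates = nn.Parameter(torch.randn(n_templates, n_fft // 2 + 1) * 0.01)
        self.proj = nn.Conv1d(n_templates, out_dim, kernel_size=1)

    def forward(self, x):  # x: [B, 1, T]
        window = torch.hann_window(self.n_fft, device=x.device)
        mag = torch.stft(x.squeeze(1), self.n_fft, hop_length=self.hop,
                         window=window, return_complex=True).abs()  # [B, bins, frames]
        # Per-frame correlation with each harmonic template
        scores = torch.einsum('tk,bkf->btf', self.templates, mag)   # [B, n_templates, frames]
        return self.proj(torch.relu(scores))                        # [B, out_dim, frames]

harmonic_feats = HarmonicTemplateMatcher()(torch.randn(2, 1, 16000))  # -> [2, 32, n_frames]
```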
Formant Features (32-dim)
- Mel-scale analysis: Vocal tract modeling (80-4000 Hz)
- Resolution: 80+ mel bins for formant detail
- Tracking: Temporal formant trajectory analysis
- Output: F1, F2, F3 formant features and vocal tract shape
VAD Features (16-dim)
- Spectral VAD: Multi-scale mel-spectrogram analysis
- Frequency range: 80-8000 Hz
- Temporal context: Large kernels for voice/silence discrimination
- Output: Voice probability and activity patterns
Spectral Features (48-dim)
- Resolutions: 256/512/1024/2048 FFT windows
- Mel bins: 32-40 per resolution
- Temporal detail: Ultra-high-res (hop_length//2) for articulation
- Output: Multi-resolution voice spectral characteristics
Broadband Features (32-dim)
- Full spectrum: Complete frequency range analysis
- FFT resolution: 1024+ for detailed spectral content
- Focus: Non-harmonic, broadband vs narrow-band discrimination
- Output: Broadband spectral patterns and energy distribution
Transient Features (24-dim)
- Multi-resolution: 5 scales (32-1024 sample kernels)
- Detection types:
- Ultra-high-freq: Clicks, pops
- High-freq: Impacts, hits
- Mid-freq: Mechanical sounds
- Low-freq: Rumbles, thuds
- Ultra-low-freq: Large environmental events
- Output: Onset patterns and transient characteristics
Environmental Features (28-dim)
- Wide frequency range: 20 Hz - 8 kHz
- Mel bins: 96+ for environmental detail
- Modeling: Natural and urban sound environments
- Output: Environmental sound classification and patterns
Texture Features (20-dim)
- Multi-scale spectrograms: 256/512/1024 FFT
- Pattern analysis: Short/medium/long-term texture patterns
- Combination: 899 channels → compressed texture features
- Output: Noise texture characteristics and roughness
Non-harmonic Features (24-dim)
- Spectral analysis: 80 mel bins, wide frequency range
- Detection: Non-periodic, irregular frequency content
- Temporal tracking: Non-harmonic pattern evolution
- Output: Aperiodic sound characteristics
Statistical Features (16-dim)
- Temporal windows: Large analysis windows (64+ samples)
- Statistics: Amplitude, spectral, and temporal statistics
- Modeling: Long-term statistical properties
- Output: Noise statistical fingerprints
Reconstruction Decoders
- Purpose: Supervised training with clean voice/noise targets
- Architecture: Memory-efficient 2-stage learned upsampling
- Stage 1: 8x upsampling (ConvTranspose1d + BatchNorm + ReLU)
- Stage 2: 10x upsampling (ConvTranspose1d + BatchNorm + ReLU + Tanh)
- Total: 80x upsampling (~400 frames → 32000 samples)
- Channel Progression: embedding_dim → 64 → 32 → 16 → 1
- Output Range: [-1, 1] with final Tanh activation
- Loss: Multi-resolution STFT with spectral convergence and HF emphasis
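A minimal sketch of this decoder under the stated channel progression and 8x/10x upsampling factors; the kernel sizes and the placement of the intermediate channel-reduction convolutions are assumptions:

```python
import torch
import torch.nn as nn

class ReconstructionDecoder(nn.Module):
    """Illustrative 2-stage upsampling decoder: 8x then 10x = 80x total.
    Channel progression embedding_dim -> 64 -> 32 -> 16 -> 1; kernel sizes are assumptions."""
    def __init__(self, embedding_dim=160):
        super().__init__()
        self.stage1 = nn.Sequential(  # 8x: ~400 frames -> ~3200 samples
            nn.ConvTranspose1d(embedding_dim, 64, kernel_size=16, stride=8, padding=4),
            nn.BatchNorm1d(64), nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32), nn.ReLU(),
        )
        self.stage2 = nn.Sequential(  # 10x: ~3200 -> ~32000 samples
            nn.ConvTranspose1d(32, 16, kernel_size=20, stride=10, padding=5),
            nn.BatchNorm1d(16), nn.ReLU(),
            nn.Conv1d(16, 1, kernel_size=3, padding=1),
            nn.Tanh(),                # output bounded to [-1, 1]
        )

    def forward(self, z):             # z: [B, embedding_dim, frames]
        return self.stage2(self.stage1(z))  # [B, 1, frames * 80]

waveform = ReconstructionDecoder(160)(torch.randn(2, 160, 400))  # -> [2, 1, 32000]
```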
Decoder Outputs:
voice_reconstruction: [B, 1, T] - Reconstructed clean voice
noise_reconstruction: [B, 1, T] - Reconstructed clean noise
mixed_reconstruction: [B, 1, T] - Sum of voice + noise (should equal input)
```
Total Loss Weights:
├── NTXENT (Contrastive): 50% - Primary disentanglement signal
├── BYOL (Self-supervised): 30% - Robust representation learning
└── Supervised Guidance: 20% - Audio quality bootstrap
```
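In training code this weighting reduces to a plain weighted sum; a minimal sketch (the component losses are assumed to be computed elsewhere, and the function name is illustrative):

```python
import torch

def hybrid_loss(ntxent_loss, byol_loss, supervised_loss,
                ntxent_weight=0.5, byol_weight=0.3, supervised_weight=0.2):
    """Weighted sum matching the 50/30/20 split above."""
    return (ntxent_weight * ntxent_loss
            + byol_weight * byol_loss
            + supervised_weight * supervised_loss)

total = hybrid_loss(torch.tensor(1.2), torch.tensor(0.8), torch.tensor(0.5))  # dummy values
```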
- NTXENT Contrastive Loss: TRUE positive embeddings (optimized path)
- Separate voice/noise NT-Xent instances with memory banks
- Pre-computed positive embeddings with gradient flow
- BYOL Self-supervised: EMA target encoders (prevents collapse)
- Voice/noise target encoders with τ=0.999 momentum
- Stop-gradient on target side for proper BYOL learning
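The τ = 0.999 target update is the standard BYOL momentum rule; a minimal sketch (the encoders can be any pair of architecturally identical nn.Modules):

```python
import torch

@torch.no_grad()
def update_ema_target(online_encoder, target_encoder, tau=0.999):
    """Momentum update: target <- tau * target + (1 - tau) * online.
    Target parameters receive no gradients (stop-gradient on the target side)."""
    for online_p, target_p in zip(online_encoder.parameters(), target_encoder.parameters()):
        target_p.mul_(tau).add_(online_p, alpha=1.0 - tau)
```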
- Supervised Reconstruction: Multi-component reconstruction loss
- Time-domain MSE loss for accuracy
- Multi-resolution STFT loss (5 scales: 256-4096 FFT)
- Spectral convergence loss for magnitude consistency
- High-frequency emphasis (4-8kHz weighted 4-6x)
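A condensed sketch of such a multi-resolution STFT loss, assuming an equal mix of spectral-convergence and log-magnitude terms per scale and a single HF emphasis weight (the actual implementation in enhanced_loss_functions.py may differ):

```python
import torch
import torch.nn.functional as F

def multi_resolution_stft_loss(pred, target, sample_rate=16000,
                               fft_sizes=(256, 512, 1024, 2048, 4096),
                               hf_band=(4000, 8000), hf_weight=4.0):
    """Illustrative: log-magnitude L1 + spectral convergence per scale,
    with extra weight on the 4-8 kHz band. pred/target: [B, 1, T]."""
    pred, target = pred.squeeze(1), target.squeeze(1)
    loss = 0.0
    for n_fft in fft_sizes:
        window = torch.hann_window(n_fft, device=pred.device)
        P = torch.stft(pred, n_fft, hop_length=n_fft // 4, window=window, return_complex=True).abs()
        T = torch.stft(target, n_fft, hop_length=n_fft // 4, window=window, return_complex=True).abs()
        sc = (T - P).pow(2).sum().sqrt() / (T.pow(2).sum().sqrt() + 1e-8)   # spectral convergence
        mag = F.l1_loss(torch.log(P + 1e-7), torch.log(T + 1e-7))           # log-magnitude L1
        freqs = torch.linspace(0, sample_rate / 2, n_fft // 2 + 1, device=pred.device)
        hf = (freqs >= hf_band[0]) & (freqs <= hf_band[1])                  # 4-8 kHz bins
        loss = loss + sc + mag + hf_weight * F.l1_loss(P[:, hf, :], T[:, hf, :])
    return loss / len(fft_sizes)
```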
- Separation Loss: Barlow Twins style decorrelation
- Cross-correlation minimization between voice/noise embeddings
- Bounded variance maintenance (prevents collapse)
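A rough sketch of this separation term: cross-correlation between standardized voice/noise dimensions is driven toward zero, and a hinge keeps per-dimension variance above a floor. The function name and the variance floor are assumptions:

```python
import torch
import torch.nn.functional as F

def separation_loss(voice_emb, noise_emb, var_floor=1.0, eps=1e-6):
    """Illustrative Barlow Twins-style separation term.
    voice_emb: [B, 160, frames], noise_emb: [B, 144, frames]."""
    # Flatten time into the sample axis: [N, D]
    v = voice_emb.permute(0, 2, 1).reshape(-1, voice_emb.shape[1])
    n = noise_emb.permute(0, 2, 1).reshape(-1, noise_emb.shape[1])

    # Standardize each dimension, then measure cross-correlation between the two spaces
    v_std = (v - v.mean(0)) / (v.std(0) + eps)
    n_std = (n - n.mean(0)) / (n.std(0) + eps)
    cross_corr = v_std.T @ n_std / v.shape[0]   # [160, 144]
    decorrelation = cross_corr.pow(2).mean()    # drive all cross-correlations toward 0

    # Hinge on per-dimension std to keep variance bounded away from zero (anti-collapse)
    variance = (F.relu(var_floor - v.std(0)).mean()
                + F.relu(var_floor - n.std(0)).mean())
    return decorrelation + variance
```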
```
disentangler/
├── disentangler_model.py                  # Main model class
├── encoders/
│   ├── __init__.py                        # Encoder exports
│   ├── voice_encoder.py                   # Streaming voice encoder
│   ├── noise_encoder.py                   # Streaming noise encoder
│   ├── batch_voice_encoder.py             # Batch voice encoder (default)
│   └── batch_noise_encoder.py             # Batch noise encoder (default)
├── ../../training/disentangler/
│   ├── enhanced_loss_functions.py         # Hybrid loss implementation
│   └── train_disentangler_sequential.py   # Training script
└── configs/
    └── phase_hybrid_supervised.yaml       # Training configuration
```
```python
from disentangler_model import DisentanglerModel

# Default: Batch encoders for best quality
model = DisentanglerModel(
    streaming_mode=False,        # Use batch encoders (default)
    voice_embedding_dim=160,
    noise_embedding_dim=144,
    enable_reconstruction=True
)

# Streaming: For real-time processing
streaming_model = DisentanglerModel(
    streaming_mode=True,         # Use streaming encoders
    voice_embedding_dim=160,
    noise_embedding_dim=144,
    enable_reconstruction=False  # Typically disabled for streaming
)
```
```python
import torch

# Input: [batch_size, 1, time_samples]
mixed_audio = torch.randn(4, 1, 16000)  # 4 samples, 1 second each

# Forward pass
outputs = model(mixed_audio, return_reconstructions=True)

# Outputs:
# - voice_embeddings: [4, 160, 100]       # 160-dim voice features, 100 frames
# - noise_embeddings: [4, 144, 100]       # 144-dim noise features, 100 frames
# - voice_reconstruction: [4, 1, 16000]   # Reconstructed clean voice
# - noise_reconstruction: [4, 1, 16000]   # Reconstructed clean noise
# - Individual feature components for analysis
```
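Because the mixed reconstruction is defined as voice + noise, a quick sanity check on the forward-pass outputs above (assuming they are returned as a dict keyed by the names in the comments) might look like:

```python
# Hypothetical check: the two reconstructions should sum back to (approximately) the input
recon_sum = outputs["voice_reconstruction"] + outputs["noise_reconstruction"]
print("mean |reconstruction - input|:", (recon_sum - mixed_audio).abs().mean().item())
```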
```python
# Training uses batch encoders automatically
trainer = DisentanglerTrainer(
    model=model,
    hybrid_loss_weights={
        'ntxent_weight': 0.5,       # Primary: Contrastive learning
        'byol_weight': 0.3,         # Primary: Self-supervised
        'supervised_weight': 0.2    # Secondary: Clean target guidance
    }
)
```
- Batch Model: ~8.4M parameters, higher memory usage, optimized batch size: 144-224
- Streaming Model: ~4.3M parameters, lower memory usage
- Training: Requires clean voice/noise targets for supervised guidance
- Inference: Works with mixed audio only
- Memory Optimizations: BatchNorm momentum=0.01, gradient clipping=0.5
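The two optimizations above could be applied in training code roughly as follows (illustrative, reusing the model instance from the usage example):

```python
import torch
import torch.nn as nn

# Lower BatchNorm momentum so running statistics adapt slowly (more stable on small datasets)
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        module.momentum = 0.01

# Clip gradient norm to 0.5 after backward(), before optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
```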
Monitor these for disentanglement quality:
- disentanglement/voice_noise_similarity: Should decrease to <0.02 (lower = better separation)
- embeddings/voice_variance, noise_variance: Should stay >0.01 (prevent collapse)
- embeddings/voice_effective_rank, noise_effective_rank: Target >50 (voice), >40 (noise)
- reconstructions/voice_snr_db, noise_snr_db: Target >10 dB improvement
- spectral/voice_hf_preservation, noise_hf_preservation: Target >50% (HF content retained)
- contrastive/voice_temperature, noise_temperature: Should stabilize at 10-20
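Two of these metrics can be derived directly from the embeddings; a hedged sketch (the exact definitions used by the training code may differ):

```python
import torch
import torch.nn.functional as F

def voice_noise_similarity(voice_emb, noise_emb):
    """Mean |cosine similarity| between time-aligned voice/noise frames.
    Illustrative: truncates to the shared channel count since the dims differ (160 vs 144)."""
    d = min(voice_emb.shape[1], noise_emb.shape[1])
    v = F.normalize(voice_emb[:, :d], dim=1)
    n = F.normalize(noise_emb[:, :d], dim=1)
    return (v * n).sum(dim=1).abs().mean()

def effective_rank(emb):
    """Exponential of the spectral entropy of the centered [N, D] embedding matrix."""
    x = emb.permute(0, 2, 1).reshape(-1, emb.shape[1])
    s = torch.linalg.svdvals(x - x.mean(dim=0))
    p = s / (s.sum() + 1e-8)
    return torch.exp(-(p * torch.log(p + 1e-8)).sum())
```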
- Batch: Best quality, slower processing, enhanced multi-resolution spectral loss
- Streaming: Good quality, real-time capable
- Compatibility: Near-identical output dimensions (±1 frame)
- NaN Prevention: Multi-layer safeguards, automatic LR reduction on NaN detection
- Small Dataset Support: LR warmup (200 steps), reduced BatchNorm momentum
- Checkpoint Resumption: Preserves NT-Xent queues, BYOL EMA targets, all training state
✅ AGI Compatibility: Strong self-supervised learning (80% of loss)
✅ Voice-Specific Processing: Dedicated pitch, harmonic, formant analysis
✅ Noise-Specific Processing: Environmental, transient, texture analysis
✅ Multiple Views: Batch vs streaming for different deployment needs
✅ Rich Features: No dimension compression, preserve all information
✅ Supervised Bootstrap: Clean targets improve quality without compromising self-learning
✅ Temporal Modeling: Enhanced batch encoders with better temporal detail