Decoder
The Flexible Multi-Decoder System leverages rich embeddings from the specialized disentangler model to reconstruct high-quality separated audio for any number of signal types. The architecture supports voice/noise separation out of the box and extends easily to additional sources such as musical instruments (trumpet, violin, etc.).
- Signal-Type Agnostic: Supports any number and type of audio sources (voice, noise, trumpet, violin, etc.)
- Flexible Configuration: Easy to add new signal types without architectural changes
- Specialized Feature Processing: Dedicated conditioners per signal type with appropriate features
- Scalable Design: Parameter count scales linearly with number of signal types
- Dual Mode Support: Batch processing for training, streaming for real-time deployment
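As a rough illustration of the signal-type-agnostic idea, each source can be thought of as a plain configuration entry. This is a hypothetical sketch (the key names and the trumpet dimensions are invented; the voice/noise dimensions match the embedding structure below):

```python
# Hypothetical per-source registry (key names and trumpet dims are illustrative)
SIGNAL_TYPES = {
    'voice':   {'embedding_dim': 160, 'conditioning_dim': 512},
    'noise':   {'embedding_dim': 144, 'conditioning_dim': 512},
    # A new source is just another entry; no architectural change needed
    'trumpet': {'embedding_dim': 128, 'conditioning_dim': 512},
}
```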
```
voice_embeddings: 160-dim (main aggregated features)
├── voice_pitch_features:    32-dim (F0 and harmonic tracking)
├── voice_harmonic_features: 32-dim (harmonic template matching)
├── voice_formant_features:  32-dim (vocal tract resonances)
├── voice_vad_features:      16-dim (voice activity detection)
└── voice_spectral_features: 48-dim (multi-scale voice patterns)

noise_embeddings: 144-dim (main aggregated features)
├── noise_broadband_features:     32-dim (non-harmonic spectral)
├── noise_transient_features:     24-dim (onset/impact detection)
├── noise_environmental_features: 28-dim (environmental sounds)
├── noise_texture_features:       20-dim (noise pattern analysis)
├── noise_nonharmonic_features:   24-dim (non-periodic content)
└── noise_statistical_features:   16-dim (temporal statistics)
```
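The per-feature dimensions concatenate exactly to the aggregated embedding sizes, which is a useful invariant to assert when loading data. A minimal check (dimension values are taken from the structure above; the dict names are just for illustration):

```python
# Per-feature dimensions as listed above (names are illustrative)
VOICE_FEATURE_DIMS = {
    'voice_pitch_features': 32,
    'voice_harmonic_features': 32,
    'voice_formant_features': 32,
    'voice_vad_features': 16,
    'voice_spectral_features': 48,
}
NOISE_FEATURE_DIMS = {
    'noise_broadband_features': 32,
    'noise_transient_features': 24,
    'noise_environmental_features': 28,
    'noise_texture_features': 20,
    'noise_nonharmonic_features': 24,
    'noise_statistical_features': 16,
}

# The individual features concatenate to the aggregated embeddings
assert sum(VOICE_FEATURE_DIMS.values()) == 160  # voice_embeddings dim
assert sum(NOISE_FEATURE_DIMS.values()) == 144  # noise_embeddings dim
```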
- Individual Feature Processing: Separate projections for each voice feature type
- Feature Enhancement: Specialized enhancers for pitch and harmonic features
- Attention Fusion: Multi-head attention to intelligently combine features (sketched after this list)
- Output: 512-dim conditioning vector for voice decoder
- Texture Modeling: Enhanced processing for noise coherence and texture
- Spectral Balance: Learnable frequency band weighting (32 bands)
- Transient Enhancement: Specialized processing for impact/onset detection
- Output: 512-dim conditioning vector + spectral balance weights
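A minimal sketch of the attention-fusion idea shared by both conditioners: project each feature type to a common width, fuse the projected features with multi-head attention, and emit a 512-dim conditioning vector. Layer sizes, names, and the pooling strategy are assumptions, not the actual implementation:

```python
import torch
import torch.nn as nn

class AttentionFusionConditioner(nn.Module):
    """Illustrative sketch: project each feature type, fuse with attention."""
    def __init__(self, feature_dims, d_model=128, out_dim=512, n_heads=4):
        super().__init__()
        # One projection per feature type (e.g. pitch, harmonic, formant, ...)
        self.projections = nn.ModuleDict({
            name: nn.Linear(dim, d_model) for name, dim in feature_dims.items()
        })
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.output = nn.Linear(d_model, out_dim)  # 512-dim conditioning

    def forward(self, features):  # features: dict of [B, T, dim] tensors
        # Stack projected features as attention tokens: [B, T, n_feats, d_model]
        tokens = torch.stack(
            [self.projections[n](features[n]) for n in self.projections], dim=2)
        B, T, F, D = tokens.shape
        tokens = tokens.reshape(B * T, F, D)
        fused, _ = self.attention(tokens, tokens, tokens)
        # Mean-pool over feature tokens, then project to conditioning width
        return self.output(fused.mean(dim=1)).reshape(B, T, -1)
```

For the voice conditioner this would be instantiated with the voice feature dimensions, e.g. `AttentionFusionConditioner(VOICE_FEATURE_DIMS)` from the check above.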
- Temporal Processing: 6-layer causal TCN with dilations [1,2,4,8]
- Streaming Compatible: No future information dependencies
- Input: 512-dim conditioning vectors at 200Hz (5ms hop)
- Output: 512-dim hidden states for rendering
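The temporal model is causal by construction: every convolution pads on the left only, so no future frames are consulted. A minimal sketch, assuming the four listed dilations cycle to fill the six layers (the cycling pattern, kernel size, and activation are assumptions; residual connections are omitted for brevity):

```python
import torch.nn as nn

class CausalTCN(nn.Module):
    """Sketch of a causal TCN: left-pad only, so no future samples are used."""
    def __init__(self, channels=512, kernel_size=3, dilations=(1, 2, 4, 8, 1, 2)):
        super().__init__()
        layers = []
        for d in dilations:  # 6 layers; dilation pattern assumed to cycle
            layers += [
                # Left-only padding keeps the convolution strictly causal
                nn.ConstantPad1d(((kernel_size - 1) * d, 0), 0.0),
                nn.Conv1d(channels, channels, kernel_size, dilation=d),
                nn.ReLU(),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):  # x: [B, 512, T] conditioning at 200 Hz
        return self.net(x)  # [B, 512, T] hidden states, same length
```

Under these assumed dilations the left-context receptive field is 1 + 2·(1+2+4+8+1+2) = 37 frames, roughly 185 ms at the 200 Hz frame rate.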
- Learned Basis: 160 basis atoms optimized for voice vs noise
- Voice Basis: Harmonic patterns (40%), formant resonances (30%), transients (30%)
- Noise Basis: White noise (25%), colored noise (25%), transients (25%), bandlimited (25%)
- Output: 400-sample frames (25ms) with proper windowing
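Conceptually, the renderer maps each 512-dim hidden state to coefficients over the 160 learned basis atoms and sums them into a windowed 400-sample frame. An illustrative sketch (the coefficient projection and random initialization are assumptions; in the real system the basis is initialized with the harmonic/noise proportions above):

```python
import torch
import torch.nn as nn

class LearnedBasisRenderer(nn.Module):
    """Sketch: hidden state -> basis coefficients -> windowed 400-sample frame."""
    def __init__(self, hidden_dim=512, n_atoms=160, frame_len=400):
        super().__init__()
        self.coeffs = nn.Linear(hidden_dim, n_atoms)
        # Learned basis atoms; in practice initialized with structured templates
        self.basis = nn.Parameter(torch.randn(n_atoms, frame_len) * 0.01)
        self.register_buffer('window', torch.hann_window(frame_len))

    def forward(self, h):            # h: [B, T, 512] hidden states
        c = self.coeffs(h)           # [B, T, 160] coefficients per frame
        frames = c @ self.basis      # [B, T, 400] raw frames
        return frames * self.window  # windowed, ready for overlap-add
```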
- Frame Length: 400 samples (25ms at 16kHz)
- Hop Size: 80 samples (5ms at 16kHz)
- Window: Hann window for COLA reconstruction
- Output Rate: Constant latency streaming at 200 Hz frame rate
- Streaming OLA: Proper COLA normalization with circular buffer
- DC Removal: 1-pole high-pass filter (0.995 pole coefficient, ≈13 Hz cutoff at 16 kHz); see the sketch after this list
- Quality: Artifact-free reconstruction with proper windowing
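A sketch of the streaming overlap-add stage under these parameters, assuming frames arrive pre-windowed from the renderer and a one-pole DC-removal filter with pole 0.995 (the pole value and the exact buffer handling are assumptions):

```python
import numpy as np

FRAME_LEN, HOP = 400, 80  # 25 ms frames, 5 ms hop at 16 kHz

class StreamingOLA:
    """Sketch: streaming overlap-add with COLA normalization and DC removal."""
    def __init__(self, dc_pole=0.995):
        window = np.hanning(FRAME_LEN)
        # With 400/80 = 5x overlap, the Hann window overlap sum is ~constant;
        # dividing by it gives COLA-normalized output
        self.cola_gain = window.reshape(-1, HOP).sum(axis=0).mean()
        self.ola_buf = np.zeros(FRAME_LEN)  # accumulation (circular-style) buffer
        self.dc_pole = dc_pole
        self.prev_in = 0.0
        self.prev_out = 0.0

    def push_frame(self, frame):
        """Add one windowed 400-sample frame; emit the next 80 output samples."""
        self.ola_buf += frame  # frames arrive already windowed
        out = self.ola_buf[:HOP] / self.cola_gain
        # Shift the buffer left by one hop for the next frame
        self.ola_buf = np.concatenate([self.ola_buf[HOP:], np.zeros(HOP)])
        return self._dc_remove(out)

    def _dc_remove(self, x):
        """1-pole high-pass: y[n] = x[n] - x[n-1] + pole * y[n-1]."""
        y = np.empty_like(x)
        for n in range(len(x)):
            self.prev_out = x[n] - self.prev_in + self.dc_pole * self.prev_out
            self.prev_in = x[n]
            y[n] = self.prev_out
        return y
```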
```
models/decoder/
├── enhanced_decoder.py      # Main decoder architecture
├── streaming_decoder.py     # Streaming interfaces
├── README.md                # This file
└── conditioners/            # Existing conditioners (legacy)
    ├── hf_aware_conditioner.py
    ├── noise_conditioner.py
    └── ...
```

```
dataset_prep/decoder/
└── create_enhanced_decoder_dataset.py   # Enhanced dataset creation
```
```python
from enhanced_decoder import create_enhanced_decoder_system
from streaming_decoder import create_streaming_decoder

# Create lightweight system (~6.7M parameters)
system = create_enhanced_decoder_system(lightweight=True)

# Create unified interface (batch + streaming)
decoder = create_streaming_decoder(system, mode='unified')
```
```python
# Set to batch mode for training
decoder.set_mode('batch')

# Process full sequences
outputs = decoder.process(voice_embeddings, noise_embeddings)
# Returns: voice_waveform, noise_waveform, voice_f0, voice_loudness, etc.
```
```python
# Set to streaming mode
decoder.set_mode('streaming')

# Process frame-by-frame
for chunk in embedding_stream:
    voice_chunk = chunk['voice_embeddings']  # [1, chunk_size, dims]
    noise_chunk = chunk['noise_embeddings']  # [1, chunk_size, dims]
    outputs = decoder.process(voice_chunk, noise_chunk, return_auxiliary=True)
    # Returns: voice_audio, noise_audio, mixture_audio + auxiliary features

    # Real-time audio output
    play_audio(outputs['voice_audio'])  # [chunk_size * 80] samples
```
```bash
# Create enhanced dataset with specialized embeddings
python dataset_prep/decoder/create_enhanced_decoder_dataset.py \
    --data_root /path/to/triplet/data \
    --checkpoint_path /path/to/disentangler/checkpoint.pt \
    --output_dir /path/to/enhanced_decoder_dataset \
    --batch_size 144 \
    --max_samples 100000
```

- Voice Decoder: 3.22M parameters
- Noise Decoder: 3.46M parameters
- Total: 6.68M parameters
- Memory: ~27MB model weights + ~200MB activation memory (batch=16)
- Theoretical: 5ms per frame (chunk_size=1)
- Practical: ~10-15ms including GPU processing
- Buffer: Minimal buffering required due to causal design
- Voice Reconstruction: Enhanced pitch accuracy, formant preservation
- Noise Reconstruction: Improved texture coherence, transient preservation
- Mixture Consistency: Perfect reconstruction when voice + noise = mixture
```python
# Each sample contains rich embeddings
sample = {
    # Main embeddings
    'voice_embeddings': torch.Tensor,             # [T, 160]
    'noise_embeddings': torch.Tensor,             # [T, 144]

    # Individual voice features
    'voice_pitch_features': torch.Tensor,         # [T, 32]
    'voice_harmonic_features': torch.Tensor,      # [T, 32]
    'voice_formant_features': torch.Tensor,       # [T, 32]
    'voice_vad_features': torch.Tensor,           # [T, 16]
    'voice_spectral_features': torch.Tensor,      # [T, 48]

    # Individual noise features
    'noise_broadband_features': torch.Tensor,     # [T, 32]
    'noise_transient_features': torch.Tensor,     # [T, 24]
    'noise_environmental_features': torch.Tensor, # [T, 28]
    'noise_texture_features': torch.Tensor,       # [T, 20]
    'noise_nonharmonic_features': torch.Tensor,   # [T, 24]
    'noise_statistical_features': torch.Tensor,   # [T, 16]

    # Target audio
    'voice': torch.Tensor,    # [32000] samples (2 seconds)
    'noise': torch.Tensor,    # [32000] samples
    'mixture': torch.Tensor,  # [32000] samples

    # Auxiliary features
    'f0': torch.Tensor,        # [T] F0 values
    'loudness': torch.Tensor,  # [T] RMS loudness
}
```
```python
# Recommended training config
config = {
    'batch_size': 16,        # Fits in RTX 4090 24GB
    'learning_rate': 1e-4,   # Conservative for stability
    'scheduler': 'cosine',   # Smooth LR decay
    'max_epochs': 100,
    'patience': 15,

    # Loss weights
    'reconstruction_weight': 1.0,  # Main reconstruction loss
    'spectral_weight': 0.5,        # Multi-scale STFT loss
    'auxiliary_weight': 0.1,       # F0/loudness prediction
    'consistency_weight': 0.2,     # voice + noise = mixture
}
```
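The loss weights above suggest a combined objective along these lines. This is a sketch, not the actual training code: the loss functions and output key names (taken from the usage examples above) are assumptions:

```python
import torch
import torch.nn.functional as F

def multiscale_stft_loss(pred, target, fft_sizes=(256, 512, 1024)):
    """Simple multi-scale STFT magnitude loss (sketch)."""
    loss = 0.0
    for n_fft in fft_sizes:
        window = torch.hann_window(n_fft, device=pred.device)
        P = torch.stft(pred, n_fft, hop_length=n_fft // 4,
                       window=window, return_complex=True).abs()
        T = torch.stft(target, n_fft, hop_length=n_fft // 4,
                       window=window, return_complex=True).abs()
        loss = loss + F.l1_loss(P, T)
    return loss

def decoder_loss(out, batch, cfg):
    """Combined objective using the weights from the config above (sketch)."""
    rec = F.l1_loss(out['voice_audio'], batch['voice']) \
        + F.l1_loss(out['noise_audio'], batch['noise'])
    spec = multiscale_stft_loss(out['voice_audio'], batch['voice']) \
         + multiscale_stft_loss(out['noise_audio'], batch['noise'])
    aux = F.l1_loss(out['voice_f0'], batch['f0']) \
        + F.l1_loss(out['voice_loudness'], batch['loudness'])
    # Mixture consistency: separated sources should sum to the mixture
    cons = F.l1_loss(out['voice_audio'] + out['noise_audio'], batch['mixture'])
    return (cfg['reconstruction_weight'] * rec
            + cfg['spectral_weight'] * spec
            + cfg['auxiliary_weight'] * aux
            + cfg['consistency_weight'] * cons)
```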
The architecture is designed to easily add new heads for different sound sources:
```python
# Example: Add instrument decoder
class InstrumentDecoder(EnhancedVoiceDecoder):
    def __init__(self, instrument_embeddings_config):
        # Configure the base decoder for instrument-specific features
        super().__init__(**instrument_embeddings_config)

# Add to system
system.instrument_decoder = InstrumentDecoder(config)
```

- Causal Processing: No future dependencies in any layer
- Minimal State: Only OLA buffers need persistence
- GPU Efficient: Optimized for real-time GPU inference
- Warmup: Automatic model warmup prevents initial artifacts
- Basis Regularization: L2 + decorrelation penalties on the learned basis (sketched after this list)
- Spectral Smoothing: Optional noise coherence enhancement
- DC Removal: Streaming high-pass filter for clean output
- Numerical Stability: Extensive clamping and normalization
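The L2 + decorrelation penalty on the learned basis could be written as follows. A sketch only: the penalty weights and the exact decorrelation form (off-diagonal entries of the normalized atom Gram matrix) are assumptions:

```python
import torch

def basis_regularization(basis, l2_weight=1e-4, decor_weight=1e-3):
    """L2 + decorrelation penalties on learned basis atoms (sketch).

    basis: [n_atoms, frame_len] learned waveform atoms.
    """
    # L2 penalty keeps atom magnitudes bounded
    l2 = basis.pow(2).mean()
    # Decorrelation: penalize off-diagonal entries of the atom Gram matrix
    normed = basis / (basis.norm(dim=1, keepdim=True) + 1e-8)
    gram = normed @ normed.t()  # [n_atoms, n_atoms] cosine similarities
    eye = torch.eye(gram.size(0), device=basis.device)
    decor = (gram * (1 - eye)).pow(2).mean()
    return l2_weight * l2 + decor_weight * decor
```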
```
Raw Audio [16 kHz, mono]
          ↓
Enhanced Disentangler (specialized encoders)
          ↓
Rich Embeddings (voice: 160-dim, noise: 144-dim + individual features)
          ↓
Enhanced Decoder System
          ↓
Reconstructed Audio [voice, noise, mixture]
```
- Disentangler: Must use enhanced disentangler with specialized encoders
- Embeddings: Expects all individual feature components
- Frame Rate: 200Hz (5ms hop) for enhanced temporal resolution
- Normalization: Handles both normalized and raw embeddings
- Dataset Creation: Use `create_enhanced_decoder_dataset.py` with a trained disentangler
- Training Setup: Implement a training loop with multi-component losses
- Evaluation: Compare against previous decoder on separation metrics
- Optimization: Fine-tune basis initialization and conditioning strategies
- Deployment: Test streaming performance on target hardware