Overview
The trainer.py module extends the Hugging Face Trainer with custom implementations optimized for Qwen-VL models. It includes a custom Flash Attention 2 forward pass, monkey-patched forward methods for the supported Qwen-VL model versions, and an optimizer creation method that supports separate learning rates for the vision tower and multimodal projector components.
Custom Optimizer Creation
create_optimizer()
Creates an optimizer with separate learning rate configurations for different model components.
Learning rate for the multimodal projector (merger) module, set via mm_projector_lr. If set, the projector parameters use this learning rate instead of the base learning rate.
Learning rate for the vision tower (visual encoder) module, set via vision_tower_lr. If set, the vision tower parameters use this learning rate instead of the base learning rate.
Weight decay coefficient applied to parameters with decay (excludes bias terms).
Parameter groups created:
- Base model parameters (with and without weight decay)
- Vision tower parameters (with and without weight decay, if vision_tower_lr is set)
- Multimodal projector parameters (with and without weight decay, if mm_projector_lr is set)
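The grouping described above can be sketched in plain Python. This is a minimal illustration, not the module's actual code: `build_param_groups` is a hypothetical helper, the "merger"-before-"visual" check order is an assumption (projector parameter names may contain both substrings), and only bias terms are excluded from weight decay, as stated in this document.

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple


def build_param_groups(
    named_params: Iterable[Tuple[str, Any]],
    base_lr: float,
    weight_decay: float,
    vision_tower_lr: Optional[float] = None,
    mm_projector_lr: Optional[float] = None,
) -> List[Dict[str, Any]]:
    """Split parameters into up to six optimizer groups (hypothetical helper)."""

    def component(name: str) -> str:
        # Assumption: "merger" is checked first because the projector may sit
        # under the visual module, so its names can match both substrings.
        if mm_projector_lr is not None and "merger" in name:
            return "mm_projector"
        if vision_tower_lr is not None and "visual" in name:
            return "vision_tower"
        return "base"

    lrs = {
        "base": base_lr,
        "vision_tower": vision_tower_lr,
        "mm_projector": mm_projector_lr,
    }
    buckets: Dict[Tuple[str, bool], List[Any]] = {}
    for name, param in named_params:
        # Assumption: only bias terms are excluded from weight decay.
        decays = not name.endswith("bias")
        buckets.setdefault((component(name), decays), []).append(param)

    return [
        {
            "params": params,
            "lr": lrs[comp],
            "weight_decay": weight_decay if decays else 0.0,
        }
        for (comp, decays), params in buckets.items()
    ]


# Parameter names stand in for tensors so the sketch runs without a model.
names = [
    "model.layers.0.mlp.weight",
    "model.layers.0.mlp.bias",
    "visual.blocks.0.attn.qkv.weight",
    "visual.merger.mlp.0.weight",
]
groups = build_param_groups(
    [(n, n) for n in names],
    base_lr=1e-5,
    weight_decay=0.01,
    vision_tower_lr=2e-6,
    mm_projector_lr=1e-5,
)
for group in groups:
    print(group["lr"], group["weight_decay"], group["params"])
```

With real tensors in place of the name strings, a list of dicts in this shape is what PyTorch optimizers such as `torch.optim.AdamW` accept as per-group settings. Empty groups are never emitted, so the count ranges from one to six.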
Flash Attention Methods
flash_attention_forward()
Custom Flash Attention 2 forward pass for efficient attention computation.
The attention module.
Query tensor with shape (batch, head, seq_len, dim).
Key tensor with shape (batch, head, seq_len, dim).
Value tensor with shape (batch, head, seq_len, dim).
Cumulative sequence lengths tensor for variable-length attention.
Dropout probability for attention weights.
Attention scaling factor.
Sliding window size for local attention.
Softcap value for attention logits.
Returns (attn_output, None), where attn_output is the attention output tensor.
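To illustrate what a cumulative sequence lengths tensor encodes, the sketch below builds the boundaries for three sequences packed into one batch. Variable-length Flash Attention kernels conventionally take these offsets as an integer tensor starting at 0. `build_cu_seqlens` is a hypothetical helper; how this module constructs the tensor is not shown in the source, so plain Python lists are used here.

```python
from itertools import accumulate
from typing import List, Tuple


def build_cu_seqlens(seq_lens: List[int]) -> Tuple[List[int], int]:
    """Return cumulative boundaries and the longest sequence length."""
    cu_seqlens = [0] + list(accumulate(seq_lens))
    return cu_seqlens, max(seq_lens)


cu_seqlens, max_seqlen = build_cu_seqlens([3, 5, 2])

# Sequence i occupies tokens cu_seqlens[i]:cu_seqlens[i + 1] of the packed batch.
spans = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
print(cu_seqlens)  # [0, 3, 8, 10]
print(max_seqlen)  # 5
print(spans)       # [(0, 3), (3, 8), (8, 10)]
```

Because each sequence is delimited by its offsets, no padding tokens are needed and attention never crosses a sequence boundary.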
qwen2vl_forward()
Custom forward method for Qwen2-VL and Qwen2.5-VL attention layers.
Input hidden states with shape (batch_size, seq_len, hidden_dim).
Attention mask tensor.
Position indices for positional embeddings.
Cached key/value states for efficient generation.
Whether to return attention weights.
Whether to use key/value caching.
Position indices for cache updates.
Precomputed rotary position embeddings (cos, sin).
Returns (attn_output, attn_weights), where attn_weights may be None.
qwen3vl_forward()
Custom forward method for Qwen3-VL and Qwen3-VL-MoE attention layers.
Input hidden states.
Rotary position embeddings (cos, sin).
Attention mask tensor.
Cached key/value states.
Position indices for cache updates.
Returns (attn_output, attn_weights).
Utility Functions
replace_qwen2_vl_attention_class()
Replaces the default attention forward methods in transformers with optimized Flash Attention implementations for all supported Qwen-VL model variants:
- Qwen2-VL
- Qwen2.5-VL
- Qwen3-VL
- Qwen3-VL-MoE
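The replacement works by monkey-patching: assigning a new function to the `forward` attribute of an attention class. The sketch below shows the general technique on a toy class. `ToyAttention`, `custom_forward`, and `replace_attention_forward` are hypothetical names for illustration; the real function patches attention classes inside transformers, whose exact names are not listed in this document.

```python
class ToyAttention:
    """Stand-in for a library attention class (hypothetical)."""

    def forward(self, hidden_states):
        return ("default", hidden_states)


def custom_forward(self, hidden_states):
    # Replacement forward; a real one would call the Flash Attention kernel.
    return ("custom", hidden_states)


def replace_attention_forward(*attention_classes):
    """Point every given class's forward method at the custom implementation."""
    for cls in attention_classes:
        cls.forward = custom_forward


layer = ToyAttention()  # created before the patch is applied
before = layer.forward([1, 2])
replace_attention_forward(ToyAttention)
after = layer.forward([1, 2])

print(before)  # ('default', [1, 2])
print(after)   # ('custom', [1, 2])
```

Since the method is looked up on the class at call time, instances created before the patch also pick up the new forward. Applying the patch before the model is built still keeps the ordering easy to reason about.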
print_trainable_parameters_visual()
Prints the trainable status of vision module components.
Output:
- Trainable/non-trainable attention block indices
- Merger module trainable status
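One way to produce such a report is to sort block indices by whether their parameters require gradients. The sketch below is an assumption about the approach rather than the module's code: `ToyParam`, `ToyBlock`, and `split_trainable_blocks` are hypothetical, and a block is treated as trainable only when all of its parameters are.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyParam:
    """Stand-in for a tensor parameter; only requires_grad matters here."""
    requires_grad: bool


@dataclass
class ToyBlock:
    """Stand-in for a module that exposes parameters() (hypothetical)."""
    params: List[ToyParam] = field(default_factory=list)

    def parameters(self):
        return iter(self.params)


def split_trainable_blocks(blocks) -> Tuple[List[int], List[int]]:
    """Return (trainable_indices, non_trainable_indices) for a list of blocks."""
    trainable, frozen = [], []
    for idx, block in enumerate(blocks):
        # Assumption: a block counts as trainable only if every parameter is.
        is_trainable = all(p.requires_grad for p in block.parameters())
        (trainable if is_trainable else frozen).append(idx)
    return trainable, frozen


blocks = [
    ToyBlock([ToyParam(False)]),
    ToyBlock([ToyParam(True), ToyParam(True)]),
    ToyBlock([ToyParam(True), ToyParam(False)]),
]
trainable, frozen = split_trainable_blocks(blocks)
print("Trainable blocks:", trainable)      # [1]
print("Non-trainable blocks:", frozen)     # [0, 2]
```

The same loop applies to any sequence of objects with a `parameters()` method, so it would cover the decoder layers of a language model as well.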
print_trainable_parameters()
Prints the trainable status of language model components.
Output:
- Embed tokens trainable status
- Trainable/non-trainable decoder layer indices
Usage Example
Notes
- Flash Attention 2 is used by default for all attention computations
- The custom optimizer supports up to 6 parameter groups: base model, vision tower, and multimodal projector, each split into parameters with and without weight decay
- Vision tower parameters are identified by “visual” in their name
- Multimodal projector parameters are identified by “merger” in their name
- Weight decay is not applied to bias parameters