Qwen3-VL-30B-A3B-Instruct: 画像推論パス全容 [MEDIUM] [VERIFIED]
対象モデル:
Qwen3-VL-30B-A3B-Instruct(model_type:qwen3_vl_moe) 調査目的: ECConnector実装に必要なエンコーダキャッシュテンソルの正確なshape/size把握 調査日: 2026-03-21 参照:target/Qwen3-VL-30B-A3B-Instruct/config.json
1. モデル構成
クラス階層
Qwen3VLMoeForConditionalGeneration (qwen3_vl_moe.py:399)
├── Qwen3_VisionTransformer (qwen3_vl.py:312) ← Vision Encoder
│ ├── Qwen3_VisionPatchEmbed (qwen3_vl.py:142)
│ ├── Qwen3_VisionBlock × 27 (qwen3_vl.py:208)
│ │ ├── Qwen2_5_VisionAttention (qwen2_5_vl.py:300)
│ │ └── Qwen3_VisionMLP (qwen3_vl.py:171)
│ ├── Qwen3_VisionPatchMerger (qwen3_vl.py:260) ← main merger
│ └── Qwen3_VisionPatchMerger × 3 (qwen3_vl.py:260) ← deepstack mergers
└── Qwen3MoeLLMForCausalLM (qwen3_vl_moe.py) ← Text (MoE)
└── Qwen3MoeLLMModel (qwen3_vl_moe.py:84)
└── Qwen3MoeSparseMoeBlock × 48 (qwen3_moe.py)
参照: target/vllm/vllm/model_executor/models/qwen3_vl_moe.py:399(登録・初期化)
コンフィグパラメータ
Vision Encoder (config.json → vision_config):
| パラメータ | 値 | 備考 |
|---|---|---|
| depth | 27 | Transformer ブロック数 |
| hidden_size | 1152 | 内部特徴量次元 |
| num_heads | 16 | head_dim = 72 |
| intermediate_size | 4304 | MLP中間層 |
| out_hidden_size | 2048 | merger出力次元(投影先) |
| patch_size | 16 | ピクセル単位(Gemma3は14) |
| spatial_merge_size | 2 | 2×2パッチを1トークンに統合 |
| temporal_patch_size | 2 | 画像は2フレームに複製して入力 |
| num_position_embeddings | 2304 | 48×48 学習済みグリッド |
| deepstack_visual_indexes | [8, 16, 24] | 中間特徴量抽出レイヤー |
| in_channels | 3 | RGB |
| hidden_act | gelu_pytorch_tanh |
Text Model (config.json → text_config):
| パラメータ | 値 | 備考 |
|---|---|---|
| num_hidden_layers | 48 | |
| hidden_size | 2048 | |
| num_attention_heads | 32 | |
| num_key_value_heads | 4 | GQA (8:1) |
| head_dim | 128 | |
| num_experts | 128 | |
| num_experts_per_tok | 8 | アクティブエキスパート数 |
| moe_intermediate_size | 768 | エキスパートあたり |
| intermediate_size | 6144 | Dense層用 |
| max_position_embeddings | 262144 | |
| rope_theta | 5,000,000 | |
| mrope_section | [24, 20, 20] | 3D M-RoPE |
| vocab_size | 151936 |
特殊トークン:
| トークン | ID | テキスト表現 |
|---|---|---|
| vision_start | 151652 | <|vision_start|> |
| vision_end | 151653 | <|vision_end|> |
| image_token | 151655 | <|image_pad|> |
| video_token | 151656 | <|video_pad|> |
2. OpenAI API → 内部表現への変換パス
graph TD
A["OpenAI ChatCompletion API<br/>(image_url in message)"] --> B["OpenAIServingChat<br/>serving_chat.py"]
B --> C["parse_chat_inputs_to_harmony_messages<br/>chat_utils.py"]
C --> D["MediaConnector.fetch_image<br/>URL/base64 → PILイメージ"]
D --> E["AsyncLLM.add_request"]
E --> F["InputProcessor<br/>(tokenize + Qwen3VLMultiModalProcessor)"]
F --> G["Qwen3VLMultiModalProcessor.apply<br/>(qwen3_vl.py:920)"]
G --> H["HF Qwen3VLProcessor.__call__<br/>(smart_resize + normalize + tokenize)"]
H --> I["EngineCoreRequest<br/>(token_ids + mm_kwargs)"]
処理順序
- OpenAI API受信:
image_url付きChatCompletionRequest - 画像取得:
MediaConnector.fetch_image()でURL/base64からPILイメージをデコード - チャットテンプレート適用: Jinja2テンプレートで
<|vision_start|><|image_pad|><|vision_end|>を挿入 - HF Processor呼び出:
Qwen3VLProcessor→Qwen2VLImageProcessorFastでリサイズ・正規化 - プレースホルダー展開:
<|image_pad|>をnum_vision_tokens個の token_id=151655 に置換 - EngineCoreRequest構築:
token_ids+mm_kwargs(pixel_values, image_grid_thw等)
3. Chat Template & Placeholder展開
テンプレート構造
参照: target/Qwen3-VL-30B-A3B-Instruct/chat_template.json
画像を含むユーザーメッセージ:
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>この画像について説明してください<|im_end|>
<|im_start|>assistant
Placeholder展開ロジック
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:1024-1030
def get_image_replacement_qwen3vl(item_idx: int):
grid_thw = out_mm_kwargs["image"][item_idx]["image_grid_thw"].data
merge_length = merge_size ** 2 # = 4
num_tokens = int(grid_thw.prod()) // merge_length
return [hf_processor.image_token_id] * num_tokens # token_id=151655 × N
1枚の画像に対して <|image_pad|> が num_vision_tokens 個のトークン(全て token_id=151655)に展開される。
4. 画像前処理(Preprocessing)
Processor構成
参照: target/Qwen3-VL-30B-A3B-Instruct/preprocessor_config.json
| パラメータ | 値 |
|---|---|
| processor_class | Qwen3VLProcessor |
| image_processor_type | Qwen2VLImageProcessorFast |
| patch_size | 16 |
| merge_size | 2 |
| min_pixels (shortest_edge) | 65,536 (≈256²) |
| max_pixels (longest_edge) | 16,777,216 (=4096²) |
| image_mean | [0.5, 0.5, 0.5] |
| image_std | [0.5, 0.5, 0.5] |
smart_resize
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:670-678
factor = patch_size * merge_size # = 16 × 2 = 32
resized_height, resized_width = smart_resize(
height=image_height, width=image_width,
factor=32,
min_pixels=65536,
max_pixels=16777216,
)
- 画像の H, W を 32の倍数 にリサイズ
- 総ピクセル数が
min_pixels〜max_pixelsの範囲に収まるようアスペクト比を維持してスケール smart_resizeはtransformers.models.qwen2_vl.image_processing_qwen2_vlからインポート
Vision Token数の計算式
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:682-689
# 画像の場合(num_frames=2, temporal_patch_size=2)
padded_num_frames = round_up(2, 2) # = 2
grid_t = max(2 // 2, 1) # = 1
grid_h = resized_height // 16 # patch_size
grid_w = resized_width // 16 # patch_size
num_patches = grid_t × grid_h × grid_w # = (H'/16) × (W'/16)
num_vision_tokens = num_patches // 4 # merge_size² = 4
導出: num_vision_tokens = (H'/32) × (W'/32)
5. Vision Encoder アーキテクチャ詳細
5.1 PatchEmbed
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:142-168
Conv3dLayer(
in_channels=3,
out_channels=1152,
kernel_size=(2, 16, 16), # (temporal_patch_size, patch_size, patch_size)
stride=(2, 16, 16),
bias=True,
)
- 入力:
(num_patches, 1536)— 各パッチは3ch × 2frames × 16px × 16pxがflattened - 処理:
view(L, 3, 2, 16, 16)→ Conv3d →view(L, 1152) - 出力:
(num_patches, 1152)
5.2 Position Embedding
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:464-522
- 48×48 (=2304) の学習済み位置埋め込み (
nn.Embedding(2304, 1152)) - 双線形補間 (
fast_pos_embed_interpolate): 任意のgrid_h × grid_wに対して48×48グリッドから補間 - spatial_merge_size=2 による並べ替え(2×2ブロック単位でグループ化)
- 時間次元t > 1の場合はexpand/repeat
5.3 Rotary Position Embedding
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:419-462
- 2D RoPE:
(h_pos, w_pos)座標ペアから計算 partial_rotary_factor = 0.5(head_dim 72のうち36次元にのみ適用)- max_position=8192
- テキスト側は3D M-RoPE(
mrope_section=[24, 20, 20])で別管理
5.4 Transformer Blocks (27層)
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:208-258
各ブロック:
- LayerNorm (eps=1e-6)
- Multi-Head Attention (16 heads, head_dim=72)
- Residual connection
- LayerNorm
- MLP: Linear(1152→4304) → SiLU → Linear(4304→1152)
- Residual connection
形状は全27層で不変: (num_patches, 1, 1152)(unsqueeze(1) 済み)
5.5 Deepstack Feature Extraction
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:565-579
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(hidden_states, ...)
if layer_num in self.deepstack_visual_indexes: # [8, 16, 24]
deepstack_feature = self.deepstack_merger_list[idx](hidden_states)
deepstack_feature_lists.append(deepstack_feature)
- Layer 8, 16, 24 の出力を それぞれ独立のmerger で投影
- 各deepstack merger:
use_postshuffle_norm=Truenorm(x.view(-1, 4608))→ Linear(4608→4608) → GELU → Linear(4608→2048)
- 出力:
(num_vision_tokens, 2048)× 3本
5.6 Main Merger (最終投影)
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:300-309
# use_postshuffle_norm=False (main merger)
x = self.norm(x).view(-1, self.hidden_size) # (num_patches, 1, 1152) → (num_vision_tokens, 4608)
x = self.linear_fc1(x) # Linear(4608 → 4608)
x = self.act_fn(x) # GELU
out = self.linear_fc2(x) # Linear(4608 → 2048)
# 出力: (num_vision_tokens, 2048)
hidden_size = context_dim × spatial_merge_size² = 1152 × 4 = 4608 — 空間的に隣接する2×2パッチを結合してから投影。
5.7 最終出力(Deepstack連結)
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:580-584
hidden_states = self.merger(hidden_states) # (num_vision_tokens, 2048)
hidden_states = torch.cat(
[hidden_states] + deepstack_feature_lists, dim=1
) # (num_vision_tokens, 2048 + 2048 + 2048 + 2048) = (num_vision_tokens, 8192)
★ Vision Encoder最終出力: (num_vision_tokens, 8192)
out_hidden_size の定義もこれを反映:
# qwen3_vl.py:337
self.out_hidden_size = vision_config.out_hidden_size * (1 + len(self.deepstack_visual_indexes))
# = 2048 * (1 + 3) = 8192
6. Tensor Shape遷移の全体図
入力画像 (H, W, 3)
│
↓ smart_resize: H' = round(H, 32), W' = round(W, 32)
↓ HF Processor: normalize, 2フレーム複製, パッチ分割
│
pixel_values: (num_patches, 1536)
│ num_patches = (H'/16) × (W'/16)
│ 1536 = 3 × 2 × 16 × 16
│
↓ PatchEmbed (Conv3d)
│
(num_patches, 1152)
│
↓ + position_embedding (bilinear interpolation from 48×48)
↓ unsqueeze(1)
│
(num_patches, 1, 1152)
│
├─ Layer 0〜7: VisionBlock ×8 → (num_patches, 1, 1152)
├─ Layer 8: deepstack_merger[0] → ds_8: (num_vision_tokens, 2048)
├─ Layer 9〜15: VisionBlock ×7 → (num_patches, 1, 1152)
├─ Layer 16: deepstack_merger[1] → ds_16: (num_vision_tokens, 2048)
├─ Layer 17〜23: VisionBlock ×7 → (num_patches, 1, 1152)
├─ Layer 24: deepstack_merger[2] → ds_24: (num_vision_tokens, 2048)
├─ Layer 25〜26: VisionBlock ×2 → (num_patches, 1, 1152)
│
↓ main merger: view→(num_vision_tokens, 4608)→Linear→GELU→Linear
│
main: (num_vision_tokens, 2048)
│
↓ torch.cat([main, ds_8, ds_16, ds_24], dim=1)
│
★ ENCODER OUTPUT: (num_vision_tokens, 8192) ← encoder_cache に格納
│ num_vision_tokens = num_patches // 4 = (H'/32) × (W'/32)
│
↓ _process_image_input (qwen3_vl.py:1418-1438)
↓ split by image (grid_thw.prod(-1) // merge_size // merge_size)
│
Per-image tensor: (num_vision_tokens_i, 8192) ← encoder_cache[mm_hash]
7. Encoder Cache テンソル仕様(ECConnector用)
テンソル形式
| 項目 | 値 |
|---|---|
| ndim | 2 (sanity check: target/vllm/vllm/v1/worker/utils.py:62-89) |
| shape | (num_vision_tokens, 8192) |
| dtype | bfloat16 |
| device | CUDA GPU(計算時)→ CPU(ECConnector保存時) |
| key | mm_hash (SHA256) |
| hidden_dim内訳 | 2048 (main) + 2048 (ds_layer8) + 2048 (ds_layer16) + 2048 (ds_layer24) |
画像サイズ別テーブル
| 入力画像 | resize後 | grid (t,h,w) | num_patches | num_vision_tokens | tensor shape | bfloat16 サイズ |
|---|---|---|---|---|---|---|
| 256×256 | 256×256 | (1,16,16) | 256 | 64 | (64, 8192) | 1.0 MB |
| 512×384 | 512×384 | (1,32,24) | 768 | 192 | (192, 8192) | 3.1 MB |
| 512×512 | 512×512 | (1,32,32) | 1024 | 256 | (256, 8192) | 4.2 MB |
| 768×768 | 768×768 | (1,48,48) | 2304 | 576 | (576, 8192) | 9.4 MB |
| 1024×768 | 1024×768 | (1,64,48) | 3072 | 768 | (768, 8192) | 12.6 MB |
| 1024×1024 | 1024×1024 | (1,64,64) | 4096 | 1024 | (1024, 8192) | 16.8 MB |
| 1920×1080 | 1920×1088 | (1,120,68) | 8160 | 2040 | (2040, 8192) | 33.4 MB |
| 2048×2048 | 2048×2048 | (1,128,128) | 16384 | 4096 | (4096, 8192) | 67.1 MB |
| 4096×3072 | 4096×3072 | (1,256,192) | 49152 | 12288 | (12288, 8192) | 201.3 MB |
計算式: size_bytes = num_vision_tokens × 8192 × 2
ECConnector save/loadの呼び出し箇所
Save:
# target/vllm/vllm/v1/worker/gpu_model_runner.py:2442-2445
for mm_hash, output in zip(mm_hashes, encoder_outputs):
self.encoder_cache[mm_hash] = output # (N, 8192) tensor
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
ECConnectorBase.save_caches() シグネチャ:
# target/vllm/vllm/distributed/ec_transfer/ec_connector/base.py:150-165
def save_caches(
self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs
) -> None:
Load:
# ECConnectorBase.start_load_caches() — base.py:132-147
def start_load_caches(
self, encoder_cache: dict[str, torch.Tensor], **kwargs
) -> None:
ECExampleConnector の保存例
参照: target/vllm/vllm/distributed/ec_transfer/ec_connector/example_connector.py
# save: GPU → CPU → safetensors
ec_cache = encoder_cache[mm_hash] # (N, 8192), bfloat16, CUDA
tensors = {"ec_cache": ec_cache.detach().cpu()} # → CPU
safetensors.torch.save_file(tensors, filename)
# load: safetensors → GPU
ec_cache = safetensors.torch.load_file(filename, device="cuda")["ec_cache"]
encoder_cache[mm_hash] = ec_cache # (N, 8192), bfloat16, CUDA
8. Deepstackの言語モデルへの注入方式
embed_input_ids での分割
参照: target/vllm/vllm/model_executor/models/qwen3_vl.py:1930-1969
# encoder_cacheから取得した (N, 8192) テンソルを分割
multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim], # [2048, 6144]
dim=-1,
)
# main: (N, 2048) → テキスト埋め込みのplaceholder位置にmerge
# multiscale: (N, 6144) → reshape → (3, seq_len, 2048)
テキスト埋め込みへのマージ
- main embeddings
(N, 2048):_merge_multimodal_embeddings()でplaceholderトークン位置に配置 - multiscale embeddings: reshape →
(3, total_seq_len, 2048)→_set_deepstack_input_embeds()でバッファに書き込み
MoEテキストモデルへの注入
参照: target/vllm/vllm/model_executor/models/qwen3_vl_moe.py:117-123
# Qwen3MoeLLMModel.forward()
for layer_idx, layer in enumerate(self.layers):
hidden_states, residual = layer(positions, hidden_states, residual)
if deepstack_input_embeds is not None and layer_idx in range(0, len(deepstack_input_embeds)):
hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
- Layer 0:
hidden_states += deepstack_input_embeds_0(Layer 8由来) - Layer 1:
hidden_states += deepstack_input_embeds_1(Layer 16由来) - Layer 2:
hidden_states += deepstack_input_embeds_2(Layer 24由来)
deepstack は early layers に中間表現を直接加算する形で多スケール情報を注入する。
9. MoE テキストモデルの構造
Mixture of Experts 仕様
| パラメータ | 値 |
|---|---|
| 総エキスパート数 | 128 |
| アクティブエキスパート数/トークン | 8 |
| エキスパート中間層サイズ | 768 |
| Gate | ReplicatedLinear(2048, 128) |
| norm_topk_prob | true |
実効計算量: 各トークンで8エキスパート × 768中間 = 6144次元相当(Dense 6144と同等の計算)
パラメータ効率
- 全パラメータ: ~30B(128エキスパートの重み含む)
- アクティブパラメータ: ~3B(8/128 = 6.25%のエキスパートのみ活性化)
decoder_sparse_step = 1→ 全層がMoE(Dense層なし)
10. Gemma3 27B との比較
| 項目 | Gemma3-27B-IT | Qwen3-VL-30B-A3B |
|---|---|---|
| Vision Encoder | ||
| アーキテクチャ | SiglipVisionModel | Qwen3_VisionTransformer |
| patch_size | 14 | 16 |
| hidden_size | 1152 | 1152(同一) |
| depth | 27 | 27(同一) |
| 位置埋め込み | 2D learned (固定) | 2D learned + bilinear interpolation |
| RoPE | なし | 2D Partial RoPE (factor=0.5) |
| 投影 | ||
| 方式 | AvgPool2d(4) + Linear | Spatial Merge (2×2) + MLP |
| 出力次元 | 5376 (text hidden) | 2048 |
| Deepstack | なし | あり (layers 8,16,24) |
| Encoder出力dim | 5376 | 8192 (2048×4) |
| トークン数/画像 | 固定256 | 可変 (64〜12288+) |
| Temporal | なし | temporal_patch_size=2 |
| テキストモデル | ||
| アーキテクチャ | Dense Transformer | MoE (128 experts) |
| hidden_size | 5376 | 2048 |
| 層数 | 62 | 48 |
| 前処理 | ||
| リサイズ | 固定896×896 | smart_resize (32の倍数, 可変) |
| Pan-and-Scan | あり(オプション) | なし |
| 正規化 | ImageNet mean/std | mean=0.5, std=0.5 |
| キャッシュ | ||
| encoder_cache tensor | (256, 5376) 固定 | (N, 8192) 可変 |
| キャッシュサイズ/画像 | 2.6 MB 固定 | 1.0〜201+ MB 可変 |
ECConnector実装への影響
- 可変サイズテンソル: Gemma3は固定256トークンだが、Qwen3-VLは画像解像度に依存して大幅に変動(64〜12288+トークン)。ストレージ割り当てに注意が必要
- 大きなhidden_dim: 8192次元はGemma3の5376より52%大きい。deepstack情報を含むため圧縮不可
- メモリ使用量: 高解像度画像で100MB超のテンソルがありうる。ネットワーク転送コストに注意
- deepstack分割の透明性: ECConnectorは
(N, 8192)テンソルをそのまま保存/復元すればよい。分割はembed_input_ids内で行われるため、ECConnector側でのdim分割は不要
付録A: 数値計算例
例1: 1024×1024 画像
H=1024, W=1024
smart_resize → 1024×1024 (変更なし, 32の倍数)
grid_t=1, grid_h=64, grid_w=64
num_patches = 4096
num_vision_tokens = 1024
pixel_values: (4096, 1536)
PatchEmbed後: (4096, 1152)
VisionBlock後: (4096, 1, 1152)
main merger後: (1024, 2048)
deepstack × 3: (1024, 2048) × 3
最終出力: (1024, 8192)
encoder_cache: 1024 × 8192 × 2 = 16,777,216 bytes ≈ 16.8 MB
例2: 1920×1080 画像 (Full HD)
H=1080, W=1920
smart_resize → 1088×1920 (H:1080→1088, 32の倍数に切り上げ)
grid_t=1, grid_h=68, grid_w=120
num_patches = 8160
num_vision_tokens = 2040
pixel_values: (8160, 1536)
最終出力: (2040, 8192)
encoder_cache: 2040 × 8192 × 2 = 33,423,360 bytes ≈ 33.4 MB
例3: 256×256 画像 (最小クラス)
H=256, W=256
smart_resize → 256×256 (min_pixels=65536, 256²=65536 ちょうどOK)
grid_t=1, grid_h=16, grid_w=16
num_patches = 256
num_vision_tokens = 64
pixel_values: (256, 1536)
最終出力: (64, 8192)
encoder_cache: 64 × 8192 × 2 = 1,048,576 bytes ≈ 1.0 MB