It’s really important to be able to follow the dimensionality of tensors as they move through a neural network. I love using the Visual Studio Code debugger for this. But running an entire model just to debug one piece can be really cumbersome (setting up a bunch of libraries in a conda environment, etc.). And sometimes, as was the case for me today, I can’t even fit the model (Mixtral) on my machine.

Here’s my stupid tip, which applies if the model code you want to run is relatively decoupled from other dependencies. The idea is obvious enough that any experienced programmer is like ‘whatevs’, but I’ve never seen it shown before, so here you go…

I grab the code in question, drop it in a Python file (using some generic environment that has torch and transformers), and add whatever imports are necessary. Then I add code to instantiate the classes and supply a config and inputs. Basically, do whatever you need to do to get the code to run.

Here’s what this looked like for today’s exercise, where I wanted to make sure I could follow what was happening in the MoE-related classes in Mixtral. Here’s the relevant code, pasted straight out of HuggingFace (the details aren’t important, but I’m putting it all here for reference):

class MixtralBLockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size
 
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
 
        self.act_fn = ACT2FN[config.hidden_act]
 
    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states
 
 
class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is
    strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations
    in terms of block-sparse operations to accomodate imbalanced
    assignments of tokens to experts, whereas standard MoE either
    (1) drop tokens at the cost of reduced performance or (2) set
    capacity factor to number of experts and thus waste computation
    and memory on padding.
    """
 
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
 
        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
        self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
 
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)
 
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
 
        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])
 
            if top_x.shape[0] == 0:
                continue
 
            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()
 
            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
 
            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
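If the routing part of `forward` is what you want to follow, a toy-sized trace can help before you even open the debugger. Here’s a rough sketch that re-runs just the softmax / top-k / one-hot steps from the code above and prints the shapes. The sizes (2 sequences of 3 tokens, 4 experts) are my own assumptions, picked so the tensors are small enough to eyeball; they are not Mixtral’s real values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (assumptions for readability, not Mixtral's real dimensions)
batch_size, sequence_length, hidden_dim = 2, 3, 16
num_experts, top_k = 4, 2
num_tokens = batch_size * sequence_length

gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)

with torch.no_grad():
    # Flatten (batch, seq, hidden) -> (num_tokens, hidden), as the block does
    hidden_states = torch.randn(batch_size, sequence_length, hidden_dim).view(-1, hidden_dim)

    router_logits = gate(hidden_states)  # (num_tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    # Keep the top_k experts per token: both outputs are (num_tokens, top_k)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # Renormalize so each token's kept weights sum to 1
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    # one_hot gives (num_tokens, top_k, num_experts); permute -> (num_experts, top_k, num_tokens)
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

    # For expert 0: idx is the top-k slot, top_x is the token position
    idx, top_x = torch.where(expert_mask[0])

for name, tensor in [
    ("hidden_states", hidden_states),
    ("router_logits", router_logits),
    ("routing_weights", routing_weights),
    ("selected_experts", selected_experts),
    ("expert_mask", expert_mask),
]:
    print(f"{name:>16}: {tuple(tensor.shape)}")
print("tokens routed to expert 0:", top_x.tolist())
```

The thing to notice is that `expert_mask[expert_idx]` is a `(top_k, num_tokens)` slice, which is why `torch.where` hands back the slot index first and the token index second.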

In order to get it running in a standalone file, I had to make the following additions (here you are guided by the missing dependencies and vars highlighted in VS Code):

import torch
import torch.nn as nn
import torch.nn.functional as F
 
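class FakeConfig:
    # Minimal stand-in for MixtralConfig: only the fields the pasted classes read
    def __init__(
        self,
        hidden_size=4096,
        intermediate_size=14336,
        hidden_act=nn.SiLU(),
        num_experts_per_tok=2,
        num_local_experts=8
    ):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts

# Paste the two Mixtral classes from above here (after FakeConfig, which the type hint references)...

# Note: with these defaults the experts hold about 1.4B float32 parameters (roughly 5.6 GB);
# pass smaller hidden_size / intermediate_size if you only care about shapes.
moe_config = FakeConfig()
moe_block = MixtralSparseMoeBlock(moe_config)

# (batch_size, sequence_length, hidden_size); the last dim must match the config's hidden_size
hidden_states = torch.randn((2, 512, 4096))
final_hidden_states, router_logits = moe_block(hidden_states)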

and change two lines in the original snippet:

    # def __init__(self, config: MixtralConfig):
    def __init__(self, config: FakeConfig):

# ...

        # self.act_fn = ACT2FN[config.hidden_act]
        self.act_fn = config.hidden_act
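A variation on the same trick: instead of editing the pasted lines, you can define the missing names yourself so the HuggingFace code runs untouched. Here’s a minimal sketch of that idea using only the expert MLP class. The `MixtralConfig` dataclass and the `ACT2FN` dict below are hypothetical local stand-ins I made up for this file, not the real transformers objects, and the sizes are toy values (the fake config above uses 4096 and 14336).

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class MixtralConfig:
    """Hypothetical stand-in: only the fields the pasted code reads."""
    hidden_size: int = 64          # toy value to keep memory small
    intermediate_size: int = 128   # toy value to keep memory small
    hidden_act: str = "silu"
    num_experts_per_tok: int = 2
    num_local_experts: int = 8


# Local stand-in for the activation lookup table the pasted code expects
ACT2FN = {"silu": nn.SiLU(), "gelu": nn.GELU()}


# Pasted unchanged from the snippet above
class MixtralBLockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states


mlp = MixtralBLockSparseTop2MLP(MixtralConfig())
x = torch.randn(10, 64)   # (num_tokens, hidden_size)
out = mlp(x)
print(out.shape)          # same shape as the input: (num_tokens, hidden_size)
```

The trade-off is that you write a bit more scaffolding up front, but you can re-paste a newer version of the class later without redoing any edits.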

That was quick and easy, and now I can step through the code. We’re off to the races!
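And if you’d rather get a printout than click through breakpoints, forward hooks are one option once the code is in a standalone file. Below is a rough sketch, not part of what I ran today: `TinyGatedMLP` is a hypothetical stand-in with the same w1/w2/w3 layout as the expert MLP, and `log_shapes` is a helper name I made up. The hook API itself (`register_forward_hook`) is standard PyTorch.

```python
import torch
import torch.nn as nn


class TinyGatedMLP(nn.Module):
    """Hypothetical stand-in with the same w1/w2/w3 layout as the expert MLP."""

    def __init__(self, hidden_dim=16, ffn_dim=32):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        return self.w2(self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states))


def log_shapes(model):
    """Attach a forward hook to every leaf submodule.

    Returns (log, handles): log fills with (name, input_shape, output_shape)
    tuples in call order; handles let you remove the hooks afterwards.
    """
    log, handles = [], []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip containers, only hook leaf layers

        def hook(mod, inputs, output, name=name):
            log.append((name, tuple(inputs[0].shape), tuple(output.shape)))

        handles.append(module.register_forward_hook(hook))
    return log, handles


model = TinyGatedMLP()
log, handles = log_shapes(model)

with torch.no_grad():
    model(torch.randn(4, 16))  # (num_tokens, hidden_dim)

for name, in_shape, out_shape in log:
    print(f"{name:>8}: {in_shape} -> {out_shape}")

# Clean up so later forward passes are not logged
for handle in handles:
    handle.remove()
```

This only sees tensors at module boundaries, so it won’t show intermediate values like `expert_mask`; for those, stepping through in the debugger is still the way to go.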