When you spin up an LLM with SGLang, you get this satisfying sequence: progress bars flying, safetensors loading, and then — boom — the model starts generating tokens. But have you ever wondered what’s actually happening under the hood? From allocating GPU memory, to reading weights off disk, to wiring everything together so the model can actually run inference — there’s a lot going on.
Specifically, questions like:
- How does SGLang know which module to use to load a given model?
- How does it know which weight in a safetensors file maps to which module?
- And even if it knows the module, how does it know which parameter inside that module a weight belongs to?
- The same module class (an MLP layer, say) gets reused across many layers, so how does each instance know which weights are “its own”?
I recently had to dig into all of this while adding support for a new model in SGLang, so let’s walk through it together.
The Model Repo
First, let’s see what a model provider actually ships alongside the safetensors. Take Qwen/Qwen3-4B as an example. Beyond the weight files, there’s a config.json that configures the model for inference with the transformers library. Here’s what the Qwen3-4B config looks like:
{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": false,
"hidden_size": 2560,
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
...
}
You can ignore most of the fields for now, but architectures is the critical one. It names the top-level module class used to load this model. You can find the corresponding class in the transformers repo, and SGLang maintains its own version with the same name under sglang.srt.models.
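If you want to poke at this yourself, grabbing just the config from the Hub is enough to see the architectures field. A quick example using huggingface_hub:
import json
from huggingface_hub import hf_hub_download

# Download only config.json from the model repo and read the architecture name.
config_path = hf_hub_download("Qwen/Qwen3-4B", "config.json")
with open(config_path) as f:
    config = json.load(f)

print(config["architectures"])  # ['Qwen3ForCausalLM']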
More specifically, SGLang has a model registry that scans all files under sglang.srt.models, finds the entry class in each, registers them, and then looks up the right one when reading the config. That’s why every model file needs a line like this at the bottom:
EntryClass = Qwen3ForCausalLM
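Conceptually, the registry boils down to a dict from architecture name to entry class. Here’s a minimal sketch of the idea — the function and names below are illustrative, not SGLang’s actual internals:
import importlib
import pkgutil

def build_registry(models_package):
    """Illustrative sketch: scan a package for EntryClass definitions and build
    an {architecture name -> class} mapping. Not SGLang's actual code."""
    registry = {}
    for module_info in pkgutil.iter_modules(models_package.__path__):
        module = importlib.import_module(f"{models_package.__name__}.{module_info.name}")
        entry = getattr(module, "EntryClass", None)
        if entry is None:
            continue
        # An EntryClass can be a single class or a list of classes.
        for cls in entry if isinstance(entry, (list, tuple)) else [entry]:
            registry[cls.__name__] = cls
    return registry

# At load time, the class named in config.json's "architectures" is looked up:
# model_cls = registry["Qwen3ForCausalLM"]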
OK! So now we know how SGLang figures out which module to use. But let’s keep looking at what else is in the model repo.
For large models where a single safetensors file isn’t enough, there’s a model.safetensors.index.json. Here’s a snippet from Qwen3-4B’s:
{
"metadata": {
"total_size": 8044936192
},
"weight_map": {
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
...
}
}
This weight_map tells you exactly which shard each weight tensor lives in. If the model is small enough to fit in a single file (like my current company’s model LFM2.5-1.2B), there’s no index file. For a specific tensor like model.layers.0.self_attn.q_proj.weight, the shard file stores metadata like this:
{
"model.layers.0.self_attn.q_proj.weight": {
"dtype": "BF16",
"shape": [4096, 2560],
"data_offsets": [
953559552,
974531072
]
}
}
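With the index and the shard metadata in hand, you can pull out a single tensor yourself. Here’s a small sketch using the safetensors library, assuming the shard files are already downloaded next to the index:
import json
from safetensors import safe_open

# Find which shard a given weight lives in, then read just that tensor.
with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.0.self_attn.q_proj.weight"
shard_file = index["weight_map"][name]  # e.g. "model-00001-of-00003.safetensors"

with safe_open(shard_file, framework="pt", device="cpu") as f:
    tensor = f.get_tensor(name)

print(tensor.shape, tensor.dtype)  # torch.Size([4096, 2560]) torch.bfloat16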
So every weight has a unique string key, and we just need to find the matching parameter in our module and load it in, right? Easy!
Except — wait. When we write a module, do we ever manually specify a name for each tensor we declare? No. And when a layer class gets reused across many layers, how does each instance know which weights belong to it?
Tensor Naming
Let’s answer the first question: how does PyTorch decide on the key (name) for each parameter we declare?
Here’s the simplest possible example:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.i_am_linear = nn.Linear(10, 20)
model = MyModel()
for name, param in model.named_parameters():
print(name)
# Output:
# i_am_linear.weight
# i_am_linear.bias
The two parameters inside nn.Linear get automatically named i_am_linear.weight and i_am_linear.bias — which is {variable name of the layer}.{parameter name}. Let’s try nesting:
class Block(nn.Module):
def __init__(self):
super().__init__()
self.i_am_linear = nn.Linear(10, 10)
class BigModel(nn.Module):
def __init__(self):
super().__init__()
self.i_am_block = Block()
model = BigModel()
for name, _ in model.named_parameters():
print(name)
# Output:
# i_am_block.i_am_linear.weight
# i_am_block.i_am_linear.bias
The naming rule is clear: a tensor’s key is formed by chaining the variable names from the top-level module all the way down, separated by dots. So when you implement a model in SGLang, don’t casually rename your variables — those names aren’t just for readability, they’re how weights get matched! (Which is kind of wild when you think about it.)
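To make that concrete: if you rename the variable and then try to load a state_dict saved under the old name, PyTorch refuses to match the keys.
import torch.nn as nn

class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_am_linear = nn.Linear(10, 20)

class Renamed(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 20)  # same layer, different variable name

state = Old().state_dict()  # keys: i_am_linear.weight, i_am_linear.bias
Renamed().load_state_dict(state)
# Raises RuntimeError: missing keys "linear.weight"/"linear.bias",
# unexpected keys "i_am_linear.weight"/"i_am_linear.bias"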
Now that key model.layers.0.self_attn.q_proj.weight from earlier makes a lot more sense — model, layers, self_attn, q_proj are all just variable names at different levels of nesting. You can find self.model = Qwen3Model(...) right there in Qwen3ForCausalLM.
But what about that .0. in layers.0? That comes from nn.ModuleList:
class StackModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(10, 10),
nn.Linear(10, 10),
nn.Linear(10, 10),
])
model = StackModel()
for name, _ in model.named_parameters():
print(name)
# Output:
# layers.0.weight
# layers.0.bias
# layers.1.weight
# layers.1.bias
# layers.2.weight
# layers.2.bias
PyTorch automatically uses the list index to distinguish between layers. That’s the answer to our “how does each layer instance know which weights are its own” question — PyTorch handles the indexing for you.
Great, so we understand the naming. Now we can just let PyTorch load everything automatically, right?
…not quite.
Custom Weight Loaders
Both Lfm2ForCausalLM and Qwen3ForCausalLM implement their own load_weights function. Why not just let PyTorch do its thing?
Looking at Lfm2ForCausalLM’s load_weights, there are basically four reasons you’d want a custom loader:
1. Skip parameters you don’t need to load. For example, rotary_emb.inv_freq can be recomputed from the RoPE config, so there’s no need to load it from disk.
2. Handle key mismatches. Sometimes the weight file uses a different key than your module expects — your implementation might use different variable names or a different layer structure than whatever code originally saved those tensors. For instance, a weight called ?.conv.conv.weight in the file might correspond to ?.conv.conv_weight in your module — you need to map one to the other.
3. Transform weights on the fly. Sometimes you need to tweak the tensor before loading it:
if ".conv.conv.weight" in name:
name = name.replace(".conv.conv.weight", ".conv.conv_weight")
loaded_weight = loaded_weight.squeeze(1) # (D, 1, K) -> (D, K)
4. Handle fused weights with custom loaders. This one’s the most interesting.
For inference efficiency and easier parallelism, SGLang sometimes fuses weights together. A classic example is QKVParallelLinear, which fuses Q, K, and V into a single module:
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
The problem: in the weight file, Q, K, and V are stored separately. But in our module, they’re fused into one. How do we get three separate tensors into one fused parameter?
The answer is a custom weight_loader. Looking at QKVParallelLinear’s implementation, its signature looks like:
def weight_loader(
self,
param: Parameter, # the fused QKV parameter to fill
loaded_weight: torch.Tensor, # the weight tensor from the file
loaded_shard_id: Optional[str] = None, # "q", "k", or "v"
):
So the loader gets called three times to fully populate the fused parameter:
weight_loader(qkv_weight, q_tensor, 'q')
weight_loader(qkv_weight, k_tensor, 'k')
weight_loader(qkv_weight, v_tensor, 'v')
Note: This is simplified — in practice, the loader also has to handle quantization, tensor parallelism, different loading formats, and a bunch of other things. That’s why these functions tend to get… large.
So if you’re using a fused layer (or any layer with a custom weight loader) in your model, you need to handle these cases in your top-level load_weights function.
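Putting those pieces together, a stripped-down load_weights tends to look something like the sketch below. This is a simplified illustration of the general pattern, not any model’s actual implementation — real versions in SGLang also deal with tensor parallelism, quantization, and more. The numbered comments map back to the four reasons above.
def load_weights(self, weights):  # weights: iterable of (name, tensor) pairs
    # Map checkpoint weight names onto the fused parameter they belong to.
    stacked_params_mapping = [
        # (fused param name, name in checkpoint, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    params_dict = dict(self.named_parameters())

    for name, loaded_weight in weights:
        # 1. Skip weights we can recompute at runtime.
        if "rotary_emb.inv_freq" in name:
            continue

        for fused_name, ckpt_name, shard_id in stacked_params_mapping:
            if ckpt_name not in name:
                continue
            # 4. Fused weights: rewrite the key and let the fused layer's
            #    custom weight_loader place this shard inside the big parameter.
            param = params_dict[name.replace(ckpt_name, fused_name)]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # 2./3. Plain weights: fix up any key mismatches or tensor
            #       transforms here, then copy into the matching parameter.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is not None:
                weight_loader(param, loaded_weight)
            else:
                param.data.copy_(loaded_weight)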
Wrapping Up
If you’ve made it here, you now have a solid mental model of how weight loading works in SGLang! To recap:
- SGLang figures out which module class to use via config.json’s architectures field and its own model registry
- PyTorch automatically names all parameters by chaining variable names down the module hierarchy
- nn.ModuleList handles per-layer indexing automatically
- Custom load_weights functions handle key mismatches, weight transforms, and fused layers
Honestly, I wouldn’t have thought about any of this if I hadn’t needed to add a new model to SGLang. We all take weight loading for granted — and we haven’t even touched tensor parallelism or quantization yet, which add another whole layer of complexity.
So next time you watch SGLang churn through those safetensors progress bars, spare a thought for the infra engineers who made all that complexity invisible.
Show some respect for the infra folks!