Interpreting the glm-4-9b-chat-1m model architecture

Model code file download

Overall structure of the glm-4-9b-chat-1m model

ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (embedding): Embedding(
      (word_embeddings): Embedding(151552, 4096)
    )
    (rotary_pos_emb): RotaryEmbedding()
    (encoder): GLMTransformer(
      (layers): ModuleList(
        (0-39): 40 x GLMBlock(
          (input_layernorm): RMSNorm()
          (self_attention): SelfAttention(
            (query_key_value): Linear(in_features=4096, out_features=5120, bias=True)
            (core_attention): CoreAttention(
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (dense): Linear(in_features=4096, out_features=4096, bias=False)
          )
          (post_attention_layernorm): RMSNorm()
          (mlp): MLP(
            (dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)
            (dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False)
          )
        )
      )
      (final_layernorm): RMSNorm()
    )
    (output_layer): Linear(in_features=4096, out_features=151552, bias=False)
  )
)
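The shapes in the dump above are enough to verify the "9B" in the model's name. A quick sanity check, summing the listed tensors by hand (note that RMSNorm contributes only a weight vector, RotaryEmbedding has no learned parameters, and the output layer is counted separately from the word embedding, as the detailed listing below confirms):

```python
# Parameter shapes transcribed from the structure dump above
vocab, hidden, qkv_out, ffn, n_layers = 151552, 4096, 5120, 13696, 40

per_layer = (
    hidden                        # input_layernorm.weight
    + qkv_out * hidden + qkv_out  # query_key_value weight + bias
    + hidden * hidden             # dense (attention output projection)
    + hidden                      # post_attention_layernorm.weight
    + 27392 * hidden              # dense_h_to_4h (27392 = 2 x 13696)
    + hidden * ffn                # dense_4h_to_h
)

total = (
    vocab * hidden                # word_embeddings
    + n_layers * per_layer        # 40 GLMBlocks
    + hidden                      # final_layernorm.weight
    + vocab * hidden              # output_layer (untied from the embedding)
)
print(total)  # 9483857920, i.e. roughly 9.48B parameters
```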

Detailed structure of the glm-4-9b-chat-1m model (below, the parameter shapes of every layer are listed in order from input to output)

transformer.embedding.word_embeddings.weight: torch.Size([151552, 4096])
transformer.encoder.layers.0.input_layernorm.weight: torch.Size([4096])
transformer.encoder.layers.0.self_attention.query_key_value.weight: torch.Size([5120, 4096])
transformer.encoder.layers.0.self_attention.query_key_value.bias: torch.Size([5120])
transformer.encoder.layers.0.self_attention.dense.weight: torch.Size([4096, 4096])
transformer.encoder.layers.0.post_attention_layernorm.weight: torch.Size([4096])
transformer.encoder.layers.0.mlp.dense_h_to_4h.weight: torch.Size([27392, 4096])
transformer.encoder.layers.0.mlp.dense_4h_to_h.weight: torch.Size([4096, 13696])

...the model has 40 transformer.encoder.layers; transformer.encoder.layers.1 through transformer.encoder.layers.38 are omitted here

transformer.encoder.layers.39.input_layernorm.weight: torch.Size([4096])
transformer.encoder.layers.39.self_attention.query_key_value.weight: torch.Size([5120, 4096])
transformer.encoder.layers.39.self_attention.query_key_value.bias: torch.Size([5120])
transformer.encoder.layers.39.self_attention.dense.weight: torch.Size([4096, 4096])
transformer.encoder.layers.39.post_attention_layernorm.weight: torch.Size([4096])
transformer.encoder.layers.39.mlp.dense_h_to_4h.weight: torch.Size([27392, 4096])
transformer.encoder.layers.39.mlp.dense_4h_to_h.weight: torch.Size([4096, 13696])
transformer.encoder.final_layernorm.weight: torch.Size([4096])
transformer.output_layer.weight: torch.Size([151552, 4096])
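Two design choices can be read straight off these shapes. First, query_key_value is a fused projection: its 5120 outputs split into 4096 for Q plus 512 each for K and V, i.e. grouped-query attention with far fewer key/value channels than query channels (assuming a head size of 128, that would be 32 query heads against 4 KV heads — the head counts are an inference from the shapes, not stated in the dump). Second, dense_h_to_4h produces 27392 = 2 x 13696 outputs while dense_4h_to_h consumes only 13696, which is the signature of a fused gate/up projection feeding a SwiGLU-style activation. A minimal sketch of both splits:

```python
import numpy as np

q_size = 4096                    # query channels (hidden size)
kv_size = (5120 - q_size) // 2   # 512 channels each for K and V

# Fused QKV output for one token, as produced by query_key_value
qkv = np.zeros((1, 5120))
q, k, v = np.split(qkv, [q_size, q_size + kv_size], axis=-1)
print(q.shape, k.shape, v.shape)  # (1, 4096) (1, 512) (1, 512)

# Fused gate/up output of dense_h_to_4h, split before the gated activation
h4 = np.zeros((1, 27392))
gate, up = np.split(h4, 2, axis=-1)
print(gate.shape, up.shape)       # (1, 13696) (1, 13696)
```

The K/V tensors would then be broadcast (repeated) across the query heads inside CoreAttention; that detail lives in the model code rather than in this shape dump.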