llm_count_mults.py
#!/usr/bin/env python

def metrics(
    # Number of layers
    L,
    # Embedding dimension
    d_model,
    # Dimension of hidden layer of fully connected layer
    d_ff,
    # Number of attention heads
    h,
    # Dimension of K and Q
    d_head,
    # Context length
    n_ctx,
    vocab_size,
    # Number of grouped-query attention K/V heads. TODO implement; see the sketch after this function.
    kv_heads=None,
):
    return {
        'mults_per_token': 
            # My limited brain
            L * (
                h * (
                    # 1x K, 1x Q, and 2x for V rank decomposed
                    4 * d_model * d_head +
                    # Right-most column of KQ product
                    n_ctx * d_head +
                    # All values times newly calculated right-most column
                    n_ctx * d_model
                ) +
                # MLP for latest token only
                2 * d_model * d_ff
            ) +
            # Output projection
            d_model * vocab_size

            ## ChatGPT's count, kept commented out for comparison
            #(
            #    L * (
            #        4 * d_model**2 +
            #        h * d_head * n_ctx +
            #        # MLP for latest token only
            #        2 * d_model * d_ff
            #    ) +
            #    # Output projection
            #    d_model * vocab_size
            #)
        ,

        # I think that with KV caching we are basically just doing matrix-vector multiplications,
        # so the number of multiplications per token roughly equals the number of params,
        # and decoding is memory-bandwidth bound unless we do some query batching.
        # See the bandwidth sketch at the end of this script.
        'n_params': (
            L * (
                h * (
                    # 1x K, 1x Q, and 2x for V rank decomposed
                    4 * d_model * d_head
                ) +
                # Fully connected layer, rank decomposed
                2 * d_ff * d_model
            ) +
            # Output projection
            d_model * vocab_size
        )
    }
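
# A minimal sketch of how the kv_heads TODO above could be folded in, assuming the
# standard W_Q/W_K/W_V/W_O decomposition (an assumption; metrics() above uses its own
# accounting): with grouped-query attention only the K and V projections shrink to
# kv_heads heads, while Q and the attention output projection keep h heads. When
# kv_heads == h this reduces to the h * (4 * d_model * d_head) projection term above.
# Illustrative only, not wired into metrics().
def attn_proj_mults_per_token(d_model, d_head, h, kv_heads=None):
    if kv_heads is None:
        kv_heads = h
    # Q and output projections use h heads; K and V projections use kv_heads heads.
    return d_model * d_head * (2 * h + 2 * kv_heads)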

# https://stackoverflow.com/questions/579310/formatting-long-numbers-as-strings
def human_format(num):
    num = float('{:.3g}'.format(num))
    magnitude = 0
    while abs(num) >= 1000:
        magnitude += 1
        num /= 1000.0
    return '{} {}'.format('{:f}'.format(num).rstrip('0').rstrip('.'), ['', 'K', 'M', 'G', 'T'][magnitude])

models = {
    'gpt-2': {
        "L": 12,
        "d_model": 768,
        "d_ff": 3072,
        "h": 12,
        "d_head": 64,
        "n_ctx": 1024,
        "vocab_size": 50257,
    },
    'gpt-3': {
        "L": 96,
        "d_model": 12288,
        "d_ff": 49152,
        "h": 96,
        "d_head": 128,
        "n_ctx": 2048,
        "vocab_size": 50257,
    },
    # https://arxiv.org/pdf/2407.21783
    'llama-3-1-70b': {
        "L": 80,
        "d_model": 8192,
        "d_ff": 28672,
        "h": 64,
        # TODO source
        "d_head": 128,
        "kv_heads": 8,
        "n_ctx": 8192,
        "vocab_size": 128000,
    },
    #'deepseek-v2-67b': {
    #    "L": 80,
    #    "d_model": 8192,
    #    "d_ff": 28672,
    #    "h": 64,
    #    "n_ctx": 8192,
    #    "vocab_size": 130000,
    #},
}
for name, params in models.items():
    res = metrics(**params)
    print(name)
    print(f"mults_per_token: {res['mults_per_token']:,} (~{human_format(res['mults_per_token'])})")
    print(f"n_params: {res['n_params']:,} (~{human_format(res['n_params'])})")
    print()
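
# A rough sketch of the memory-bandwidth bound mentioned in metrics(): with KV caching,
# generating one token streams essentially every weight once (ignoring KV cache reads),
# so an upper bound on single-stream decoding speed is about
# bandwidth / (n_params * bytes_per_param). The 1 TB/s bandwidth and fp16 (2 bytes per
# param) figures below are illustrative assumptions, not from the original.
def max_decode_tokens_per_s(n_params, bytes_per_param=2, bandwidth_bytes_per_s=1e12):
    return bandwidth_bytes_per_s / (n_params * bytes_per_param)

for name, params in models.items():
    n_params = metrics(**params)['n_params']
    print(f"{name} decode upper bound: ~{max_decode_tokens_per_s(n_params):,.0f} tokens/s at 1 TB/s, fp16")

# As a rough sanity check, the n_params output above is about 124 M for gpt-2 and
# 175 G for gpt-3, in line with the commonly quoted sizes of those models.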