Building a Mamba-Transformer from Scratch


Combining the Mamba architecture with the Transformer architecture is an approach in deep learning that aims to keep the strengths of both while offsetting their limitations. Mamba processes sequential data efficiently and, because it is recurrent, encodes positional information without extra machinery. This complements the Transformer's self-attention mechanism, which excels at capturing long-range dependencies but becomes computationally expensive on long inputs. Hybrid models such as Jamba, MambaVision, and Mambaformer build on this complementarity and report gains in performance, efficiency, and scalability in language modeling, computer vision, and time series forecasting, including better compute and memory efficiency and, on some benchmarks, state-of-the-art results.

Mamba Architecture


Mamba is a sequence modeling architecture designed to address some of the limitations of traditional Transformers. Transformers use self-attention to capture dependencies across long sequences, at a cost that grows quadratically with sequence length. Mamba instead uses selective state space models (SSMs), which process sequential data efficiently and can capture long-range dependencies while scaling linearly with sequence length.

Mamba's architecture simplifies the design by removing attention and separate multi-layer perceptron blocks, relying instead on a recurrent state update. This lets it handle longer sequences with a smaller memory footprint and higher throughput, which makes Mamba a strong alternative for tasks that involve large volumes of sequential data.


Transformer Architecture


Transformers have revolutionized sequence modeling with their self-attention mechanism, which lets every token attend to every other token in the sequence regardless of distance. This architecture underpins many state-of-the-art models in NLP and other domains because it models complex relationships between tokens well.

Transformers do, however, face challenges in computation and memory as sequences grow longer. Self-attention compares every token with every other token, so its cost is quadratic in sequence length, and very long sequences require significant resources. Despite this limitation, Transformers remain a cornerstone of modern AI thanks to their flexibility and strong performance across many data types and tasks.
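To make the quadratic-versus-linear contrast concrete, the dependency-free sketch below simply counts operations: full self-attention computes one score per (query, key) pair, while a recurrent state space scan updates a fixed-size state once per token. The function names and the state size of 16 are illustrative assumptions, not part of any library.

```python
def attention_score_count(seq_len: int, heads: int = 1) -> int:
    # Full self-attention computes one score for every (query, key) pair in each head
    return heads * seq_len * seq_len


def scan_state_updates(seq_len: int, d_inner: int = 1, d_state: int = 16) -> int:
    # A recurrent SSM scan updates a (d_inner x d_state) state once per token
    return seq_len * d_inner * d_state


for n in (1_024, 2_048, 4_096):
    print(f"n={n:>5}: attention scores={attention_score_count(n):>12,}  "
          f"scan updates={scan_state_updates(n):>8,}")
```

Doubling the sequence length quadruples the attention count but only doubles the scan count.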

Transformer-Mamba Hybrid Architecture

The Transformer-Mamba hybrid architecture combines features of both models to offset their individual limitations. It uses the Transformer's self-attention to capture intricate long-range dependencies and Mamba's efficient sequential processing to reduce computational and memory overhead. By integrating Mamba's recurrent layers with self-attention layers, the hybrid model balances the ability to model complex dependencies against resource efficiency, yielding scalable, high-performance models that handle both short-term and long-term dependencies.

The Transformer-Mamba hybrid is particularly useful in scenarios where long sequences need to be processed efficiently without compromising on the model’s ability to understand complex relationships within the data.

Examples of Hybrid Architectures

Mambaformer

Overview: Mambaformer integrates the self-attention mechanism of Transformers with the efficient sequential processing of Mamba. It uses Mamba layers to handle long sequences efficiently and Transformer-style self-attention to capture complex dependencies.


Performance: Mambaformer is reported to improve on traditional Transformers of similar size in natural language processing tasks, maintaining high accuracy while reducing memory consumption and computational overhead, particularly on very long sequences.

Jamba

Overview: Jamba combines the Transformer's self-attention mechanism with Mamba's recurrence-based approach. The hybrid aims to harness the strengths of both, processing long sequences efficiently while still modeling complex dependencies.


Performance: Jamba is reported to achieve state-of-the-art results on various benchmarks, outperforming comparable models in efficiency and accuracy. It delivers higher throughput and handles long sequences better than pure Transformer models.

MambaVision

Overview: MambaVision blends Transformer components with Mamba's efficient sequence handling and is optimized for vision tasks. It uses self-attention to capture global features and Mamba-based mechanisms for local feature extraction.


Performance: MambaVision is reported to outperform traditional CNNs and Transformer-based vision models on image classification, object detection, and segmentation, with better accuracy and efficiency on high-resolution images and complex visual data.

Hybrid Time Series Models

Overview: Hybrid models designed for time series forecasting combine the Transformer's ability to model long-term dependencies with Mamba's efficiency on sequential data. They use Transformer-style mechanisms to capture complex patterns and Mamba's recurrent features for efficient processing.


Performance: Hybrid time series models are reported to achieve better forecasting accuracy than traditional time series models. They balance long-term dependency modeling with the efficiency needed for large-scale data, improving both short-term and long-term forecasts.


In this project guide, we build a custom hybrid architecture that draws on our previous work on Mamba and Transformer-based models.


Code Time

Install and Import Dependencies

! pip install einops datasets transformers tqdm
import torch
import math
import time
from einops import einsum, rearrange, repeat
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Optional
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

Train The Tokenizer

We start by defining a function that streams training data from a dataset of Algerian Darija: it loads the dataset with load_dataset and yields the text column in chunks of 1,000 samples.

Next, we load the tokenizer of a pre-trained Mamba model as a base and train a new tokenizer on those chunks with a vocabulary size of 3,000. We then set its padding token to the end-of-sequence token and add new tokens for fill-in-the-middle (FIM) training.

For these new tokens we look up their token IDs and, when the truncate_or_pad flag is set, also include the ID of the FIM padding token.

# Function to get training data
def get_training_corpus():
    dataset = load_dataset("ayoubkirouane/Algerian-Darija", split="v1")
    for i in range(0, len(dataset), 1000):
        yield dataset[i : i + 1000]["Text"]

# Initialize the base tokenizer 
base_tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

# Train the new tokenizer
new_tokenizer = base_tokenizer.train_new_from_iterator(get_training_corpus(), vocab_size=3000)


new_tokenizer.pad_token = new_tokenizer.eos_token

fim_prefix_token = "<fim_prefix>"
fim_middle_token = "<fim_middle_token>"
fim_suffix_token = "<fim_suffix_token>"
fim_pad_token = "<fim_pad>"

# Add the FIM-specific tokens and get their token ids
new_tokenizer.add_tokens(
    [
        fim_prefix_token,
        fim_middle_token,
        fim_suffix_token,
        fim_pad_token,
    ]
)
prefix_tok_id = new_tokenizer.convert_tokens_to_ids(fim_prefix_token)
middle_tok_id = new_tokenizer.convert_tokens_to_ids(fim_middle_token)
suffix_tok_id = new_tokenizer.convert_tokens_to_ids(fim_suffix_token)
pad_tok_id = None

fim_tokens = [prefix_tok_id, middle_tok_id, suffix_tok_id]


# If truncate_or_pad is on, also get pad token id
truncate_or_pad = True
if truncate_or_pad:
    pad_tok_id = new_tokenizer.convert_tokens_to_ids(fim_pad_token)
    fim_tokens.append(pad_tok_id)

# Save the new tokenizer
new_tokenizer.save_pretrained("new_tokenizer")

Finally, we save the new tokenizer to a directory for later use. The result is a tokenizer fitted to Algerian Darija text that also includes the special tokens our training setup needs.
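The guide registers the FIM tokens but does not show how they are used. One common recipe from the fill-in-the-middle training literature cuts a document at two random points and reorders it as prefix, suffix, middle, with a sentinel token before each part, so the model learns to generate the middle given both sides. The sketch below shows that prefix-suffix-middle (PSM) rearrangement on plain token-ID lists; the function name fim_permute and the toy IDs are hypothetical, and in practice you would pass prefix_tok_id, middle_tok_id, and suffix_tok_id from the tokenizer code.

```python
import random
from typing import List, Sequence


def fim_permute(
    token_ids: Sequence[int],
    prefix_id: int,
    middle_id: int,
    suffix_id: int,
    rng: random.Random,
) -> List[int]:
    """Rearrange a document into prefix-suffix-middle (PSM) order (hypothetical helper).

    Two distinct random cut points split the sequence into prefix / middle / suffix;
    the model sees the prefix and suffix first and learns to generate the middle.
    """
    ids = list(token_ids)
    lo, hi = sorted(rng.sample(range(len(ids) + 1), 2))
    prefix, middle, suffix = ids[:lo], ids[lo:hi], ids[hi:]
    return [prefix_id, *prefix, suffix_id, *suffix, middle_id, *middle]


print(fim_permute(list(range(10, 20)), prefix_id=1, middle_id=2, suffix_id=3, rng=random.Random(0)))
```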


Multi-Head Attention Mechanism

In this MultiHeadAttention module, we first create linear layers for the query, key, and value projections and record the number of heads and the dimensionality of each head. (The use_linear_attn argument is stored but not used here; the choice between the two attention types is made in TransformerBlock.) In the forward pass, we project the input tensor to queries, keys, and values and reshape them into separate attention heads. We then call PyTorch's F.scaled_dot_product_attention with is_causal=True. This function applies the 1/sqrt(dim_head) scaling itself, so the queries must not be scaled beforehand, and on supported GPUs and data types it can dispatch to a FlashAttention kernel.

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        dropout: float = 0.1,
        use_linear_attn: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.dropout = dropout
        self.use_linear_attn = use_linear_attn

        inner_dim = heads * dim_head
        self.scale = dim_head ** -0.5

        # Query, Key, and Value projection layers
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

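    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        # `mask` is accepted for interface compatibility but unused: causal masking is
        # handled by is_causal=True (SDPA rejects attn_mask combined with is_causal)
        b, n, _ = x.shape

        # Project and split into heads: (b, heads, n, dim_head)
        q = self.to_q(x).reshape(b, n, self.heads, self.dim_head).transpose(1, 2)
        k = self.to_k(x).reshape(b, n, self.heads, self.dim_head).transpose(1, 2)
        v = self.to_v(x).reshape(b, n, self.heads, self.dim_head).transpose(1, 2)

        # SDPA already scales scores by dim_head ** -0.5, so q is not pre-scaled here
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )

        # Merge heads back: (b, n, heads * dim_head)
        out = attn_output.transpose(1, 2).reshape(b, n, -1)

        return self.to_out(out)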

Finally, we merge the heads back together and pass the result through a linear layer followed by dropout.
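As a reference for what F.scaled_dot_product_attention(q, k, v, is_causal=True) computes for a single head, here is a dependency-free sketch of causal softmax attention. It applies the default 1/sqrt(d) scaling exactly once, which is why queries should not also be scaled by hand before calling SDPA. The helper name and the list-of-lists representation are illustrative assumptions, and attention dropout is omitted.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head causal attention: softmax(q k^T / sqrt(d)) v with future keys masked."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)  # SDPA's default scale, applied once
    out = []
    for i, qi in enumerate(q):
        # Causal mask: query i only sees keys 0..i
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j][c] for j, w in enumerate(weights)) / total
            for c in range(len(v[0]))
        ])
    return out


print(causal_attention([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
```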


Linear Attention

This module projects the input tensor into queries, keys, and values with a single linear layer that computes all three projections at once.

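The projections are split and reshaped into multiple heads. The queries are scaled and passed through a softmax over the feature dimension, and the keys through a softmax over the sequence dimension. Instead of forming the n × n attention matrix, we first multiply keys and values into a small per-head context matrix of shape (dim_head, dim_head) and then multiply the queries by that context. Because the sequence length only appears inside that sum, the cost grows linearly with sequence length rather than quadratically. Note that this formulation is not causal: every position's output depends on the whole sequence, including future tokens, so on its own it lets a next-token predictor see the tokens it is supposed to predict.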

class LinearAttention(nn.Module):
    def __init__(self, dim, *, heads=4, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head**-0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        h = self.heads
        # Get queries, keys, and values
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # Reshape for multi-head attention: (b, h, n, dim_head)
        q = q.reshape(q.shape[0], q.shape[1], h, -1).transpose(1, 2)
        k = k.reshape(k.shape[0], k.shape[1], h, -1).transpose(1, 2)
        v = v.reshape(v.shape[0], v.shape[1], h, -1).transpose(1, 2)

        # Optional padding mask of shape (b, n), True where a position should be ignored.
        # It is applied before the softmax so masked keys receive zero weight.
        if mask is not None:
            k = k.masked_fill(mask[:, None, :, None], float('-inf'))

        q = q * self.scale
        q = F.softmax(q, dim=-1)  # normalize each query over its features
        k = F.softmax(k, dim=-2)  # normalize each key feature over the sequence

        # Context = keys^T values: (b, h, d, e), its size does not depend on n
        context = torch.einsum('b h n d, b h n e -> b h d e', k, v)
        out = torch.einsum('b h d e, b h n d -> b h n e', context, q)

        # Reshape back to original dimensions and apply the final linear layer
        out = out.transpose(1, 2).reshape(x.shape[0], x.shape[1], -1)
        return self.to_out(out)

After computing the attention output, we reshape it back to the original dimensions and apply a final linear layer followed by dropout.
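The LinearAttention module above normalizes keys over the whole sequence, so each output can see future tokens. If a causal variant is needed for next-token prediction, one well-known alternative (swapped in here, not the module above) is the kernelized linear attention of Katharopoulos et al. (2020), "Transformers are RNNs", which uses the feature map elu(x) + 1 and keeps running sums of key-value products. The dependency-free sketch below processes one head token by token; each step costs O(d·e) regardless of sequence length. The function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def elu_plus_one(x: float) -> float:
    # Positive feature map phi(x) = elu(x) + 1
    return x + 1.0 if x > 0 else math.exp(x)


def causal_linear_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    d, e = len(q[0]), len(v[0])
    S = [[0.0] * e for _ in range(d)]  # running sum of phi(k_j) v_j^T, shape (d, e)
    z = [0.0] * d                       # running sum of phi(k_j), shape (d,)
    out = []
    for qi, ki, vi in zip(q, k, v):
        # Add the current key/value pair to the running state (positions 0..i only)
        fk = [elu_plus_one(x) for x in ki]
        for a in range(d):
            z[a] += fk[a]
            for c in range(e):
                S[a][c] += fk[a] * vi[c]
        # Read out with the current query, normalized by phi(q_i) . z_i
        fq = [elu_plus_one(x) for x in qi]
        denom = sum(fq[a] * z[a] for a in range(d))
        out.append([sum(fq[a] * S[a][c] for a in range(d)) / denom for c in range(e)])
    return out
```

The running state (S, z) plays the same role as the hidden state of a recurrent model, which is what allows the cost per token to stay constant.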


Mamba Block

The MambaBlock class defines a block that combines a short causal convolution with a selective state space model.

We initialize the block with the input dimensionality (dim), the state dimension (d_state), the expansion factor, and the convolution parameters (depth is stored but not used inside the block). The in_proj layer projects the input to twice the inner dimension, producing a main branch and a gating branch. The main branch passes through a depthwise 1D convolution that captures local patterns, and the x_proj layer then produces the input-dependent parameters of the selective state space model (SSM): Δ, B, and C.

class MambaBlock(nn.Module):
    def __init__(
        self,
        dim: int = None,
        depth: int = 5,
        d_state: int = 16,
        expand: int = 2,
        d_conv: int = 4,
        conv_bias: bool = True,
        bias: bool = False,
    ):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.d_state = d_state
        self.expand = expand
        self.d_conv = d_conv
        self.conv_bias = conv_bias
        self.bias = bias

        # dt_rank, the rank of the Δ projection, is set to ceil(dim / 16)
        dt_rank = math.ceil(self.dim / 16)
        self.dt_rank = dt_rank

        # Inner (expanded) dimension: dim * expand
        dim_inner = dim * expand
        self.dim_inner = dim_inner

        # in_proj produces both the SSM branch and the gating (residual) branch
        self.in_proj = nn.Linear(dim, dim_inner * 2, bias=bias)

        self.conv1d = nn.Conv1d(
            in_channels=dim_inner,
            out_channels=dim_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=dim_inner,
            padding=d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(
            dim_inner, dt_rank + self.d_state * 2, bias=False
        )

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(dt_rank, dim_inner, bias=True)

        A = repeat(torch.arange(1, self.d_state + 1), "n -> d n", d=dim_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(dim_inner))
        self.out_proj = nn.Linear(dim_inner, dim, bias=bias)

    def forward(self, x: Tensor):
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        x_and_res = rearrange(x_and_res, "b l x -> b x l")
        (x, res) = x_and_res.split(
            split_size=[self.dim_inner, self.dim_inner], dim=1
        )

        x = self.conv1d(x)[:, :, :l]
        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)

        output = self.out_proj(rearrange(y, "b dim l -> b l dim"))

        return output

    def ssm(self, x: Tensor):

        (d_in, n) = self.A_log.shape

        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = rearrange(x, "b d l -> b l d")
        x_dbl = self.x_proj(x_dbl)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(
            split_size=[self.dt_rank, n, n], dim=-1
        )  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(
            x, delta, A, B, C, D
        )  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y

    def selective_scan(self, u, delta, A, B, C, D):

        (b, d_in, l) = u.shape
        n = A.shape[1]


        deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b d_in l n"))
        deltaB_u = einsum(
            delta, B, u, "b l d_in, b l n, b d_in l -> b d_in l n"
        )

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        x = torch.zeros((b, d_in, n), device=next(self.parameters()).device)
        ys = []
        for i in range(l):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            y = einsum(x, C[:, i, :], "b d_in n , b n -> b d_in")
            ys.append(y)
        y = torch.stack(ys, dim=2)  # (b d_in l)

        if D is not None:
            y = y + u * rearrange(D, "d_in -> d_in 1")

        return y

In the forward pass, we project the input and split it into x and a gating branch res. The convolution processes x (padding by d_conv - 1 and keeping only the first l outputs makes it causal), a SiLU activation follows, and the result goes through the SSM, which performs a selective scan to capture sequential dependencies. The SSM output is multiplied by SiLU(res) and projected back to the original dimensionality.
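The causal behavior of the convolution comes from two details in MambaBlock: nn.Conv1d pads d_conv - 1 zeros on both sides, and forward keeps only the first l outputs. The dependency-free sketch below reproduces this for a single channel (PyTorch's Conv1d computes a cross-correlation, which is what the inner sum does) and shows that output t depends only on inputs up to position t. The function name is hypothetical.

```python
from typing import List, Sequence


def causal_depthwise_conv(x: Sequence[float], w: Sequence[float]) -> List[float]:
    """One channel of nn.Conv1d(padding=k-1) followed by [:l] truncation."""
    k, l = len(w), len(x)
    padded = [0.0] * (k - 1) + list(x) + [0.0] * (k - 1)  # Conv1d pads both sides
    # Cross-correlation over the padded input gives l + k - 1 outputs
    full = [sum(w[j] * padded[t + j] for j in range(k)) for t in range(l + k - 1)]
    return full[:l]  # keep the first l outputs; output t sees only x[t-k+1 .. t]


print(causal_depthwise_conv([1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0, 1.0]
```

With the weight on the first tap, the output is the input delayed by k - 1 steps, which makes the one-directional flow of information easy to see.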

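The ssm method implements the core selective state space model. A (stored as A_log) and D are learned parameters, while Δ, B, and C are computed from the input, which is what makes the scan selective. selective_scan then discretizes the system with Δ and runs the recurrence over the sequence one step at a time.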
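For a single channel, selective_scan reduces to a short recurrence: h_t = exp(Δ_t * A) * h_{t-1} + Δ_t * B_t * x_t and y_t = C_t · h_t + D * x_t, with Δ_t, B_t, and C_t varying per position. The dependency-free sketch below spells out that loop using the same simplified discretization of B (the Δ_t * B_t * x_t term, as in deltaB_u) that the class uses. The function name and plain-list inputs are assumptions made for readability.

```python
import math
from typing import List, Sequence


def selective_scan_1d(
    u: Sequence[float],            # input for one channel, length L
    delta: Sequence[float],        # step size Δ_t per position
    A: Sequence[float],            # diagonal state matrix (negative entries), length N
    B: Sequence[Sequence[float]],  # input-dependent B_t, shape (L, N)
    C: Sequence[Sequence[float]],  # input-dependent C_t, shape (L, N)
    D: float,                      # skip-connection weight
) -> List[float]:
    n = len(A)
    h = [0.0] * n
    ys = []
    for t, x in enumerate(u):
        for i in range(n):
            # h_t = exp(Δ_t A) * h_{t-1} + Δ_t * B_t * x_t
            h[i] = math.exp(delta[t] * A[i]) * h[i] + delta[t] * B[t][i] * x
        # y_t = C_t . h_t + D * x_t
        ys.append(sum(C[t][i] * h[i] for i in range(n)) + D * x)
    return ys


# An impulse decays by exp(Δ * A) = exp(-1) per step
print(selective_scan_1d([1.0, 0.0, 0.0], [1.0] * 3, [-1.0], [[1.0]] * 3, [[1.0]] * 3, 0.0))
```

A larger Δ_t makes the state forget faster and weight the current input more, which is how the model decides per token what to keep.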


Feed Forward

The constructor takes the input and output dimensions, a hidden-size multiplier (mult), the dropout rate, and whether to disable the biases. The hidden dimension is the input dimension multiplied by mult, and the project_in block linearly projects the input to this hidden dimension and applies a SiLU activation.

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: Optional[int] = None,
        dim_out: Optional[int] = None,
        mult: Optional[int] = 4,
        post_act_ln: Optional[bool] = False,
        dropout: Optional[float] = 0.0,
        no_bias: Optional[bool] = False,
        triton_kernels_on: bool = False,
    ):

        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.mult = mult
        self.post_act_ln = post_act_ln
        self.dropout = dropout
        self.no_bias = no_bias
        self.triton_kernels_on = triton_kernels_on

        inner_dim = int(dim * mult)
        dim_out = dim_out or dim  # Default to input dimension if not provided

        # Determine activation function
        activation = nn.SiLU()


        project_in = nn.Sequential(
                nn.Linear(dim, inner_dim, bias=not no_bias), activation
            )

        # Define feedforward network
        if post_act_ln:
            self.ff = nn.Sequential(
                project_in,
                nn.LayerNorm(inner_dim),
                nn.Dropout(dropout),
                nn.Linear(inner_dim, dim_out, bias=not no_bias),
            )
        else:
            self.ff = nn.Sequential(
                project_in,
                nn.Dropout(dropout),
                nn.Linear(inner_dim, dim_out, bias=not no_bias),
            )

    def forward(self, x):

        return self.ff(x)

The ff network is built differently depending on post_act_ln. If it is True, a LayerNorm is applied right after the activation, followed by dropout and a final linear layer that projects to the output dimension. If it is False, the LayerNorm is omitted and dropout is applied directly before the output projection.

The forward method simply passes the input through ff, applying all transformations in one call. Keeping the options in the constructor makes it easy to reuse the same module in different parts of the architecture.
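Because the layer sizes follow directly from dim and mult, the parameter count of FeedForward is easy to work out by hand: two linear layers (weights plus optional biases) and, with post_act_ln, a LayerNorm with one weight and one bias per hidden unit. The helper below is a hypothetical sketch of that arithmetic, not part of the guide's code.

```python
from typing import Optional


def feedforward_param_count(
    dim: int,
    mult: int = 4,
    dim_out: Optional[int] = None,
    post_act_ln: bool = False,
    no_bias: bool = False,
) -> int:
    inner = int(dim * mult)
    dim_out = dim_out or dim
    bias = 0 if no_bias else 1
    count = dim * inner + bias * inner          # project_in: Linear(dim, inner)
    count += inner * dim_out + bias * dim_out   # output Linear(inner, dim_out)
    if post_act_ln:
        count += 2 * inner                      # LayerNorm weight and bias
    return count


print(f"{feedforward_param_count(512, 4):,}")  # 2,099,712
```

With dim=512 and mult=4, each feedforward block therefore holds roughly 2.1 million parameters.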


RMSNorm

This class implements RMS normalization, which rescales each input vector to a fixed root mean square to stabilize training. During initialization the class stores a scale factor equal to the square root of the dimensionality (dim) and a learnable gain g, initialized to ones, which scales the normalized output.

class RMSNorm(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # ||x||_2 = sqrt(dim) * RMS(x), so L2-normalizing and multiplying by sqrt(dim) gives x / RMS(x)
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return F.normalize(x, dim=-1) * self.scale * self.g

In the forward pass, F.normalize divides x by its L2 norm along the last dimension. Because the L2 norm of a d-dimensional vector equals sqrt(d) times its root mean square, multiplying by the scale factor sqrt(dim) turns this into x / RMS(x). The learnable gain g then adjusts the magnitude of each feature, which keeps activations at a stable scale during training.
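The L2-normalize-and-rescale form used in the RMSNorm class matches textbook RMSNorm, x / sqrt(mean(x²)) * g, only when the rescale factor is sqrt(dim). The dependency-free sketch below checks that equivalence numerically; both function names are hypothetical, and F.normalize's tiny epsilon is ignored.

```python
import math
from typing import List, Sequence


def rms_norm(x: Sequence[float], g: Sequence[float]) -> List[float]:
    # Textbook RMSNorm: divide by the root mean square, then apply the learnable gain
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms * gi for v, gi in zip(x, g)]


def normalize_then_rescale(x: Sequence[float], g: Sequence[float]) -> List[float]:
    # L2-normalize (as F.normalize does) and rescale by sqrt(d)
    l2 = math.sqrt(sum(v * v for v in x))
    return [v / l2 * math.sqrt(len(x)) * gi for v, gi in zip(x, g)]


x, g = [3.0, 4.0], [1.0, 1.0]
print(rms_norm(x, g))                # both print roughly [0.8485, 1.1314]
print(normalize_then_rescale(x, g))
```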


TransformerBlock

Here we build a transformer block that can use either standard or linear attention. The constructor takes the dimensionality, the number of attention heads, the head dimension, and the use_linear_attn flag, and creates either a MultiHeadAttention or a LinearAttention module accordingly. It also creates a feedforward network (FeedForward) and a layer normalization (nn.LayerNorm) to stabilize training.

class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        dropout: float = 0.1,
        ff_mult: int = 4,
        use_linear_attn: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.dropout = dropout
        self.ff_mult = ff_mult
        self.use_linear_attn = use_linear_attn

        # Build only the attention variant this block will use
        if use_linear_attn:
            self.attn = LinearAttention(
                dim=dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
        else:
            self.attn = MultiHeadAttention(dim, heads, dim_head, dropout)

        self.ffn = FeedForward(dim, dim, ff_mult)

        # Normalization
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Performs a forward pass of the TransformerBlock.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.

        """
        x = self.attn(x)  # both attention modules return a single tensor
        x = self.norm(x)
        x = self.ffn(x)

        return x

In the forward method, the input tensor passes through the selected attention module (LinearAttention or MultiHeadAttention), then layer normalization, then the feedforward network. Residual connections are added by the enclosing MambaTransformerblock, so this block returns only the transformed tensor. Switching attention types with a single flag makes it easy to trade the exact causal attention for the cheaper linear variant.


Mamba-Transformer block

This class combines the Mamba and Transformer components into a hybrid block for sequence processing.

The block takes the model dimension, the number of attention heads, and separate depths for the Mamba blocks (mamba_depth), the transformer blocks (transformer_depth), and the feedforward networks (depth). Each group of layers is stored in an nn.ModuleList.

class MambaTransformerblock(nn.Module):

    def __init__(
        self,
        dim: int,
        heads: int,
        depth: int,
        dim_head: int,
        dropout: float = 0.1,
        ff_mult: int = 4,
        d_state: int = 16,
        transformer_depth: int = 1,
        mamba_depth: int = 1,
        use_linear_attn: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.dim_head = dim_head
        self.d_state = d_state
        self.dropout = dropout
        self.ff_mult = ff_mult
        self.transformer_depth = transformer_depth
        self.mamba_depth = mamba_depth

        # Mamba, Transformer, and ffn blocks. Arguments are passed by keyword because
        # the three classes have different signatures.
        self.mamba_blocks = nn.ModuleList([
            MambaBlock(dim=dim, depth=mamba_depth, d_state=d_state)
            for _ in range(mamba_depth)
        ])
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(
                dim=dim,
                heads=heads,
                dim_head=dim_head,
                dropout=dropout,
                ff_mult=ff_mult,
                use_linear_attn=use_linear_attn,
            ) for _ in range(transformer_depth)
        ])

        self.ffn_blocks = nn.ModuleList([
            FeedForward(dim, dim, ff_mult)
            for _ in range(depth)
        ])

        # Layernorm
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        for mamba, attn, ffn in zip(
            self.mamba_blocks,
            self.transformer_blocks,
            self.ffn_blocks,
        ):
            x = self.norm(x)
            x = mamba(x) + x
            x = self.norm(x)
            x = attn(x) + x
            x = self.norm(x)
            x = ffn(x) + x

        return x

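During the forward pass, each layer applies a Mamba block, a transformer block, and a feedforward network in turn. Because zip stops at the shortest of the three lists, the number of layers that actually run is min(mamba_depth, transformer_depth, depth).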

Before each sub-block, the running tensor is layer-normalized to stabilize training and improve convergence, and the sub-block's output is then added back to that tensor as a residual connection, which helps when training deeper networks. Note that a single LayerNorm instance is shared by all of these positions.

This hybrid approach pairs Mamba's efficient sequence processing with the Transformer's attention mechanism, followed by feedforward layers, to handle complex sequence tasks.
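Because the forward loop zips the three module lists, a mismatch in depths silently drops layers. The dependency-free sketch below mirrors that loop with labels instead of modules to show which sub-layers run, and in what order, for a given set of depths; the function name is hypothetical.

```python
from typing import List


def hybrid_schedule(mamba_depth: int, transformer_depth: int, ffn_depth: int) -> List[str]:
    """Mirror MambaTransformerblock.forward: zip stops at the shortest list."""
    schedule = []
    layers = zip(range(mamba_depth), range(transformer_depth), range(ffn_depth))
    for i, _ in enumerate(layers):
        # Each layer: norm -> Mamba -> residual, norm -> attention -> residual, norm -> FFN -> residual
        schedule += [f"mamba[{i}]", f"attn[{i}]", f"ffn[{i}]"]
    return schedule


print(hybrid_schedule(2, 2, 2))       # ['mamba[0]', 'attn[0]', 'ffn[0]', 'mamba[1]', 'attn[1]', 'ffn[1]']
print(len(hybrid_schedule(8, 2, 8)))  # 6: only two layers run
```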


Mamba-Transformer

This class wraps the hybrid block into a complete model for sequence processing tasks.

The constructor sets up three components: an embedding layer that maps token indices to dense vectors, a MambaTransformerblock that processes these embeddings with interleaved Mamba and transformer layers, and an output head (RMSNorm followed by a linear layer) that produces logits over the vocabulary.

class MambaTransformer(nn.Module):
    def __init__(
        self,
        num_tokens: int,
        dim: int,
        heads: int,
        depth: int,
        dim_head: int,
        dropout: float = 0.1,
        ff_mult: int = 4,
        d_state: int = 16,
        return_embeddings: bool = False,
        transformer_depth: int = 1,
        mamba_depth: int = 1,
        use_linear_attn=False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.dim_head = dim_head
        self.d_state = d_state
        self.dropout = dropout
        self.ff_mult = ff_mult
        self.return_embeddings = return_embeddings
        self.transformer_depth = transformer_depth
        self.mamba_depth = mamba_depth

        self.emb = nn.Embedding(num_tokens, dim)
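        # Pass arguments by keyword: return_embeddings is not a parameter of
        # MambaTransformerblock, so a positional call would shift every later argument
        self.mt_block = MambaTransformerblock(
            dim=dim,
            heads=heads,
            depth=depth,
            dim_head=dim_head,
            dropout=dropout,
            ff_mult=ff_mult,
            d_state=d_state,
            transformer_depth=transformer_depth,
            mamba_depth=mamba_depth,
            use_linear_attn=use_linear_attn,
        )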
        self.to_logits = nn.Sequential(
            RMSNorm(dim), nn.Linear(dim, num_tokens)
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the MambaTransformer model.

        Args:
            x (Tensor): Input tensor of shape (batch_size, sequence_length).

        Returns:
            Tensor: Output tensor of shape (batch_size, sequence_length, num_tokens).
        """
        x = self.emb(x)
        x = self.mt_block(x)

        if self.return_embeddings:
            return x

        else:
            return self.to_logits(x)

In the forward pass, the input tensor x of token indices is first embedded into dense vectors, and these embeddings are processed by the MambaTransformerblock. Depending on the return_embeddings flag, the model either returns the processed hidden states or applies the RMSNorm layer and a linear projection to produce logits for every token position, which is what next-token prediction needs.


Create The Model

We now create an instance of the MambaTransformer model using the vocabulary of new_tokenizer, a model dimension of 512, 16 attention heads of dimension 128, and a state dimension of 512. It stacks 8 Mamba blocks (mamba_depth), 8 transformer blocks (transformer_depth), and 8 feedforward blocks (depth).

# Create an instance of the MambaTransformer model
model = MambaTransformer(
    num_tokens=len(new_tokenizer),  # Vocabulary size; len() counts the added FIM tokens, vocab_size does not
    dim=512,  # Dimension of the model
    heads=16,  # Number of attention heads
    depth=8,  # Number of feedforward blocks
    dim_head=128,  # Dimension of each attention head
    d_state=512,  # Dimension of the SSM state
    dropout=0.1,  # Dropout rate
    ff_mult=4,  # Multiplier for the feed-forward layer dimension
    return_embeddings=False,  # Whether to return the embeddings instead of logits
    transformer_depth=8,  # Number of transformer blocks
    mamba_depth=8,  # Number of Mamba blocks
    use_linear_attn=True,  # Whether to use linear attention
)

The dropout rate is 0.1, the feedforward hidden size is 4 times the model dimension, and the transformer blocks use linear attention. return_embeddings is False, so the model outputs logits rather than embeddings.

This configuration describes a fairly large model that interleaves state space, attention, and feedforward layers in equal numbers.
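It helps to work out the sizes this configuration implies. The sketch below recomputes the inner dimensions from the formulas in MambaBlock and the attention modules, and estimates the memory taken by the two tensors of shape (batch, dim_inner, seq_len, d_state), deltaA and deltaB_u, that selective_scan materializes in each Mamba block. The sequence length of 256 and float32 storage are assumptions for illustration, and the helper names are hypothetical.

```python
import math


def derived_sizes(dim: int, heads: int, dim_head: int, d_state: int, expand: int = 2) -> dict:
    dim_inner = dim * expand             # MambaBlock: dim * expand
    dt_rank = math.ceil(dim / 16)        # MambaBlock: ceil(dim / 16)
    return {
        "mamba_dim_inner": dim_inner,
        "dt_rank": dt_rank,
        "x_proj_out": dt_rank + 2 * d_state,  # Δ, B and C
        "attn_inner_dim": heads * dim_head,   # width of the q, k, v projections
    }


def scan_buffer_bytes(batch: int, seq_len: int, dim_inner: int, d_state: int,
                      bytes_per_elem: int = 4) -> int:
    # deltaA and deltaB_u each have shape (batch, dim_inner, seq_len, d_state)
    return 2 * batch * dim_inner * seq_len * d_state * bytes_per_elem


sizes = derived_sizes(dim=512, heads=16, dim_head=128, d_state=512)
print(sizes)  # {'mamba_dim_inner': 1024, 'dt_rank': 32, 'x_proj_out': 1056, 'attn_inner_dim': 2048}
print(scan_buffer_bytes(1, 256, sizes["mamba_dim_inner"], 512) / 2**30, "GiB")  # 1.0 GiB
```

Under these assumptions each Mamba block needs about 1 GiB per sample for these two buffers alone, before autograd stores any activations for the backward pass, so a smaller d_state or shorter sequences may be necessary on a single GPU.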


Dataset class

Here we define a custom TextDataset class for use with PyTorch's DataLoader. The __init__ method loads the dataset with the given dataset_name and split and stores the tokenizer and the maximum sequence length, and the __len__ method returns the number of samples.

class TextDataset(Dataset):
    def __init__(self, dataset_name, split, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset = load_dataset(dataset_name, split=split)

    def __len__(self):
        return len(self.dataset)

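    def __getitem__(self, idx):
        text = self.dataset[idx]['Text']  # Adjust 'Text' if your dataset uses a different column name
        encoding = self.tokenizer(text, truncation=True,
                                  padding='max_length',
                                  max_length=self.max_length,
                                  return_tensors='pt')  # Ensure PyTorch tensor output

        input_ids = encoding['input_ids'].squeeze(0)            # (max_length,)
        attention_mask = encoding['attention_mask'].squeeze(0)  # 1 = real token, 0 = padding

        # Next-token labels: position t is trained to predict token t + 1.
        # -100 is the default ignore_index of PyTorch's cross-entropy loss, so the last
        # position and every position whose target is padding do not contribute to the loss.
        labels = torch.full_like(input_ids, -100)
        labels[:-1] = input_ids[1:].masked_fill(attention_mask[1:] == 0, -100)

        return input_ids, labels  # Return both input_ids and labels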

The __getitem__ method retrieves a text sample by index and tokenizes it with truncation and padding so that every sequence is exactly max_length tokens long. It then builds the language-modeling labels by shifting the input ids one position to the left, so the label at position t is the token at position t+1, and returns input_ids and labels as PyTorch tensors.
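To see what the shift does, the toy sketch below applies the same clone-and-shift to a hand-made sequence. The token ids and the pad id 0 are made up for illustration. The sketch then counts the positions a loss with ignore_index set to the pad id would actually train on.

```python
import torch

PAD_ID = 0  # hypothetical pad id for this toy example

# A "tokenized" sequence of length 8: five real tokens followed by padding
input_ids = torch.tensor([12, 7, 33, 5, 9, PAD_ID, PAD_ID, PAD_ID])

# Same clone-and-shift as TextDataset.__getitem__
labels = input_ids.clone()
labels[:-1] = input_ids[1:]
# The last position has no next token; mark it as padding
labels[-1] = PAD_ID

for t in range(len(input_ids)):
    print(f"position {t}: input {input_ids[t].item():>2} -> target {labels[t].item():>2}")

# Positions whose target is padding are skipped by a loss using ignore_index=PAD_ID
print("trained positions:", (labels != PAD_ID).sum().item())
```

Only the first four positions carry a real next-token target. Without ignore_index, most of a short sample's loss would come from predicting padding.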


Create the Trainer

The train function orchestrates training of a MambaTransformer model and handles both training and evaluation. It takes the model, an optimizer, and an optional learning rate scheduler, moves the model to the target device, and then iterates over the training dataset for the specified number of epochs.

For each batch it runs a forward pass, computes the cross-entropy loss, backpropagates it, clips the gradients to keep updates stable, and updates the model weights.

def train(model: MambaTransformer,
          train_data: DataLoader,
          optimizer: torch.optim.Optimizer,
          val_data: DataLoader = None,
          epochs: int = 10,
          tokenizer: AutoTokenizer = new_tokenizer,
          device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
          clip_grad_norm: float = 1.0,
          lr_scheduler=None):
    """Trains the MambaTransformer model.

    Args:
        model: The MambaTransformer model to train.
        train_data: A DataLoader for the training dataset.
        optimizer: The optimizer to use for training.
        val_data: An optional DataLoader for the validation dataset.
        epochs: The number of training epochs.
        tokenizer: The tokenizer; its pad_token_id is excluded from the loss.
        device: The device to use for training (e.g., 'cuda' or 'cpu').
        clip_grad_norm: The maximum norm of the gradients to clip.
        lr_scheduler: An optional learning rate scheduler, stepped once per epoch.
    """

    model = model.to(device)
    # Padding positions carry no signal, so they are excluded from the loss
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    print("Training...")
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        model.train()  # re-enable dropout (the evaluation phase switches to eval mode)
        total_loss = 0.0
        start_time = time.time()

        for input_ids, labels in tqdm(train_data, leave=False):
            input_ids, labels = input_ids.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass: logits of shape (batch, seq_len, vocab_size)
            outputs = model(input_ids)

            # Cross-entropy over every token position
            loss = loss_fn(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))

            # Backward pass
            loss.backward()

            # Clip gradients to keep updates stable
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

            # Update weights
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_data)
        elapsed_time = time.time() - start_time
        print(f"Average loss: {avg_loss:.4f} | Elapsed time: {elapsed_time:.2f}s")

        monitored_loss = avg_loss
        if val_data is not None:
            # Evaluation phase: no dropout, no gradient tracking
            model.eval()
            eval_loss = 0.0
            with torch.no_grad():
                for input_ids, labels in val_data:
                    input_ids, labels = input_ids.to(device), labels.to(device)
                    outputs = model(input_ids)
                    loss = loss_fn(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
                    eval_loss += loss.item()
            avg_eval_loss = eval_loss / len(val_data)
            print(f"Epoch: {epoch+1}, Evaluation Loss: {avg_eval_loss:.4f}")
            monitored_loss = avg_eval_loss

        if lr_scheduler is not None:
            # ReduceLROnPlateau needs the monitored loss; other schedulers take no argument
            if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                lr_scheduler.step(monitored_loss)
            else:
                lr_scheduler.step()

    model_save_path = "Hybrid.pt"
    torch.save(model.state_dict(), model_save_path)
    print("Training complete!")

If a validation DataLoader is provided, the function evaluates the model at the end of each epoch and prints the validation loss. After the last epoch it saves the trained weights (the model's state_dict) to Hybrid.pt, so they can be reloaded later for inference or further training.
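Only the state_dict is saved, so reloading requires building a model with the same constructor arguments first and then calling load_state_dict. The sketch below shows that round trip with a small stand-in nn.Sequential. This is an assumption for illustration: in practice you would rebuild MambaTransformer with the configuration above. The helper make_model and the file name stand_in.pt are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Stand-in for MambaTransformer(...): any module rebuilt with identical arguments works
    return nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 100))


trained = make_model()
path = os.path.join(tempfile.mkdtemp(), "stand_in.pt")
torch.save(trained.state_dict(), path)  # same kind of call as in train()

# Later, e.g. in an inference script: rebuild the architecture, then load the weights
restored = make_model()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # disable dropout before inference

same = all(torch.equal(a, b) for a, b in zip(trained.state_dict().values(),
                                             restored.state_dict().values()))
print("weights identical:", same)
```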

Overall, this is a basic but complete routine for training and evaluating the model, covering essentials such as gradient clipping and loss tracking.
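torch.nn.utils.clip_grad_norm_ rescales all gradients together whenever their combined L2 norm exceeds max_norm, which leaves the update direction unchanged. The sketch below checks this on a toy linear layer with deliberately inflated gradients (not the hybrid model), using max_norm=1.0, the clip_grad_norm value passed to train.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(32, 32)

# Produce deliberately large gradients with a scaled-up loss
loss = (layer(torch.randn(8, 32)) * 1000).pow(2).mean()
loss.backward()

# clip_grad_norm_ returns the total norm measured *before* clipping
norm_before = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)

# Recompute the global L2 norm over all parameter gradients after clipping
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in layer.parameters()))
print(f"before: {norm_before.item():.1f}  after: {norm_after.item():.4f}")
```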


Optimizer: torch.optim.AdamW is used with a learning rate of 1e-4 for the model parameters. AdamW is Adam with decoupled weight decay, a common default for training transformer-style models.

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

Learning Rate Scheduler: torch.optim.lr_scheduler.CosineAnnealingLR lowers the learning rate along a cosine curve over T_max = 10 scheduler steps, bringing it gradually toward zero so the model converges smoothly toward the end of training.
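With eta_min left at its default of 0, CosineAnnealingLR follows lr_t = base_lr · (1 + cos(π·t / T_max)) / 2. The sketch below steps the scheduler ten times on a throwaway parameter that stands in for the model, then compares each value with that closed form. T_max counts scheduler.step() calls, so whether the schedule spans 10 batches or 10 epochs depends on where the training loop calls step().

```python
import math

import torch

base_lr, T_max = 1e-4, 10
param = torch.nn.Parameter(torch.zeros(1))  # throwaway parameter standing in for the model
optimizer = torch.optim.AdamW([param], lr=base_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)

lrs = [scheduler.get_last_lr()[0]]
for _ in range(T_max):
    optimizer.step()      # in real training this follows loss.backward()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

# Closed form with eta_min = 0
expected = [base_lr * (1 + math.cos(math.pi * t / T_max)) / 2 for t in range(T_max + 1)]
for t, (got, want) in enumerate(zip(lrs, expected)):
    print(f"step {t:2d}: lr = {got:.2e} (closed form {want:.2e})")
```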


train_dataset = TextDataset("ayoubkirouane/Algerian-Darija",
                            "v1",
                            new_tokenizer,
                            max_length=128)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=False)

Dataset: TextDataset is initialized with the ayoubkirouane/Algerian-Darija dataset from the Hugging Face Hub (loaded through the datasets library), using the v1 split. It tokenizes with new_tokenizer and a maximum sequence length of 128 tokens.

DataLoader: the DataLoader groups train_dataset into batches of 4 samples without shuffling (shuffle=False). For longer training runs, shuffle=True is usually preferred so that the batches differ from epoch to epoch.
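Because __getitem__ returns an (input_ids, labels) pair of 1-D tensors of length max_length, the DataLoader's default collate function stacks each batch into two (batch_size, max_length) tensors, and the training loop unpacks exactly that pair. The sketch below checks this with a hypothetical stand-in dataset of random token ids instead of the Algerian-Darija data. The class name RandomTokenDataset and the vocabulary size of 1000 are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomTokenDataset(Dataset):
    """Hypothetical stand-in for TextDataset: same (input_ids, labels) output format."""

    def __init__(self, num_samples=10, max_length=128, vocab_size=1000):
        self.data = torch.randint(0, vocab_size, (num_samples, max_length))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_ids = self.data[idx]
        labels = input_ids.clone()
        labels[:-1] = input_ids[1:]  # next-token targets, as in TextDataset
        return input_ids, labels


loader = DataLoader(RandomTokenDataset(), batch_size=4, shuffle=False)
batch_shapes = [(x.shape, y.shape) for x, y in loader]
for x_shape, y_shape in batch_shapes:
    print(x_shape, y_shape)
print("batches per epoch:", len(loader))  # ceil(10 / 4)
```

With 10 samples and a batch size of 4, the last batch holds only 2 samples. This is why train divides by len(train_data) (the number of batches) rather than assuming full batches.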


Start The Training

train(model=model,
      train_data=train_loader,
      optimizer=optimizer,
      tokenizer=new_tokenizer,
      epochs=1,
      clip_grad_norm=1.0,
      lr_scheduler=lr_scheduler)

This will start the training loop and print the progress and average loss for the specified number of epochs.


Generate Text
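
The generation loop repeats the following steps until every sequence reaches max_length tokens:

  • Perform a forward pass through the model.
  • Extract the logits for the last token position and apply softmax to get probabilities.
  • Select the 50 most probable tokens (top-k sampling) and sample one of them.
  • Append the sampled token to the sequence.
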
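import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # disable dropout for generation

max_length = 30            # total length of each sequence, prompt included
num_return_sequences = 10  # number of samples drawn from the same prompt

# Encode the Darija prompt ("نتا", "you") and repeat it once per sequence
tokens = new_tokenizer.encode("نتا  ")
tokens = torch.tensor(tokens, dtype=torch.long)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)  # (B, prompt_len)

x = tokens.to(device)

while x.size(1) < max_length:
    with torch.no_grad():
        outputs = model(x)
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        logits = logits[:, -1, :]  # logits for the last position: (B, vocab_size)
        probs = F.softmax(logits, dim=-1)
        # Keep the 50 most likely tokens: both tensors are (B, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # Sample a position within the top-k, then map it back to a vocabulary id
        ix = torch.multinomial(topk_probs, 1)
        xcol = torch.gather(topk_indices, -1, ix)
        x = torch.cat((x, xcol), dim=1)

# Print the generated text
for i in range(num_return_sequences):
    tokens = x[i, :max_length].tolist()
    decoded = new_tokenizer.decode(tokens, skip_special_tokens=True)
    print(">", decoded)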

This code generates text with the trained MambaTransformer model. It encodes a prompt, repeatedly extends each sequence with a sampled token, and prints the decoded results.
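The sampling step inside the loop can be factored into a small reusable function. The sketch below, built around a hypothetical helper named sample_top_k, restates the same softmax, torch.topk, torch.multinomial and torch.gather sequence. It runs the helper on random logits for 10 sequences over a toy vocabulary of 1000 and checks that every sampled id belongs to that row's top-k set.

```python
import torch
import torch.nn.functional as F


def sample_top_k(logits: torch.Tensor, k: int = 50) -> torch.Tensor:
    """Sample one token id per row from the k most probable tokens (hypothetical helper).

    logits: (batch, vocab_size) scores for the last position.
    Returns: (batch, 1) tensor of sampled token ids.
    """
    probs = F.softmax(logits, dim=-1)
    topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # both (batch, k)
    # multinomial accepts unnormalized weights, so the k probabilities need not sum to 1
    ix = torch.multinomial(topk_probs, 1)                    # positions within the top-k
    return torch.gather(topk_indices, -1, ix)                # map back to vocabulary ids


torch.manual_seed(0)
logits = torch.randn(10, 1000)  # 10 sequences, toy vocabulary of 1000
next_ids = sample_top_k(logits, k=50)

# softmax is monotonic, so the top-k of the logits equals the top-k of the probabilities
allowed = torch.topk(logits, 50, dim=-1).indices
in_top_k = (allowed == next_ids).any(dim=-1)
print(next_ids.shape, bool(in_top_k.all()))
```

Restricting sampling to the top 50 tokens stops the model from occasionally picking very unlikely tokens, while still giving the 10 sequences different continuations.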