Transformer with PyTorch and Rust


I was exploring Rust crates for machine learning and found tch-rs, a crate of rustic bindings for (Py)Torch. I thought building a Transformer with it would be cool, and I wanted to test whether these bindings were a viable route for larger projects. My initial goals included training a GPT and loading Mistral 7B. I didn’t quite get weight loading working, but a Mistral-like model can be trained if you’ve got enough RAM. The source code is on GitHub and is released into the public domain.

One test model was trained on 118k lines of the Zig compiler concatenated together. After a few thousand batches, we can sample the model to find code that is almost coherent:

//!! Below is generated code from a random sample during training.
fn writeArrayFully(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
        const inst_data = self.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
        const extra = self.code.extraData(Zir.Inst.StructInit, inst_data.payload_index);

    const value = self.code.values.get(ip)[index];
        try self.writeBody(stream, body[0], decls_len);
    for (0..) |*decl_index, i| {
            const decl_index = struct_type.decl.unwrap() orelse {
            if (!decl.getName_values()) |decl_index| {
                assert(decl.has_decls.count() == null);
                    break :blk null;
            .fn_ret_ty_ies => |fn_ret_ty| {
                    const fn_info = fn_ty.getNamespace()[fn_info.total_params_len];
          [flag_index] = @intFromEnum(fn_info.return_type));
                    break :good;
            return call_info.child == position_type_target_index;
        .node_offset_params => |node_off| {
            const tree = try src_loc.file_scope.getTree(gpa);
            const node_tags = tree.nodes.items(.tag);
            for (node) |node_tags| - @singleError!{
                const node = src_loc.declRelativeToNodeIndex(node_off);
                const container_node = src_loc.declRelativeToNodeIndex(node_off);
                assert(src_loc.fullSrcLoc(node_decl_index, .{ .msg = test_node, .lazy = node_decl_index }).lazy;
                try transExpr(c, scope, expr_node, .used);
        .auto, .node_offset_func_type_extra_index => |nod

Even though this is nonsensical code, we can see it successfully picked up syntax patterns and types that appear throughout the compiler. Some observations:

  • It reproduces basic Zig structure: the function has parameters followed by a body of statements, the whitespace looks normal, and variable declarations are present.
  • It nailed the basic structure of Zig’s for loop syntax, but the loop is semantically incorrect: for (0..) |*decl_index, i| {} should iterate over a slice alongside the 0.. range, as in for (decls, 0..).
  • It reproduces a tree of switch cases quite well: .auto, .node_offset_func_type_extra_index =>. Notice how consistent the whitespace is before and after these lines.

Here’s a sample from another model: it was able to generate entire function(-looking thing)s after 79,500 batches.

/// Continue if and only happen unsigned into fit hold installation and then unexisting (first immediately)?
/// * For each of its comptime-known alloc, we cuse ready division by comptime-known
/// beson positive into element (or each vector with a well-defined layout
/// for integers, or backends which we are pre-resolving this feature.
fn resolveFnTypes(sema: *Sema, block: *Block, src: LazySrcLoc, parent_ty: Type, fn_ty: *const clang.QualType) bool {
    const scope = &c.global_scope.base;
    const opt_const = try transQualType(c, scope, const_qt, const_val_result, ty, sema.arena, sema.mod);
    return Air.internedToRef(result_val.toIntern());

fn zirValidateArrayInitTy(
    sema: *Sema,
    block: *Block,
    inst: Zir.Inst.Index,
    array_init: Air.Inst.Ref,
) CompileError!Air.Inst.Ref {
    const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
    const operand = try sema.resolveInst(inst_data.operand);
    const operand_ty = sema.typeOf(operand);
    const zir_datas = sema.air_instructions.items(.data);
    const operand_inst = zir_datas[@intFromEnum(operand)].pl_node;
    const extra = zir.extraData(Zir.Inst.ErrorSetDecl, pl_node.payload_index);
    const error_set_extra_index = extra.end + @typeInfo(ErrorSet).Struct.fields.len;

    var fields: []FieldErrorBundle.FieldErrorBundle.FieldIndex = undefined;
    try guide_file.clort_one == 0;

Although this is obviously overfitting the data, it makes me think we can achieve useful end-of-line autocomplete with local models using a small portion of resources (“small” relative to my Mac Studio with 128GB of RAM; memory usage for the models I trained ranged from 30 to 140GB). I am hoping to build a system for pretraining LLMs that are tiny compared to X billion parameter models. Rust is the language I want to use for this project, but I’m not convinced bundling PyTorch with tch-rs is a good route.
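Memory figures like these are easier to reason about with a back-of-envelope estimate. The sketch below assumes plain Adam with weights, gradients, and both moment buffers stored in the same dtype, and it ignores activations, KV caches, and allocator overhead; `training_bytes` and `gib` are hypothetical helpers, not part of any crate.

```rust
/// Rough training-memory estimate (hypothetical helper): weights, gradients,
/// and Adam's two moment buffers, each stored with `bytes_per_value` bytes.
/// Activations, KV caches, and allocator overhead are deliberately ignored.
fn training_bytes(params: u64, bytes_per_value: u64) -> u64 {
    // weights + gradients + Adam first moment + Adam second moment
    const COPIES: u64 = 4;
    params * bytes_per_value * COPIES
}

/// Convert bytes to GiB (2^30 bytes) for display.
fn gib(bytes: u64) -> f64 {
    bytes as f64 / (1u64 << 30) as f64
}
```

For a hypothetical 7 billion parameter model, `training_bytes(7_000_000_000, 2)` gives 56e9 bytes (about 52 GiB) with 2-byte bfloat16 values, and about 104 GiB with 4-byte floats, before any activations are counted.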

The tch-rs bindings for Tensor, Linear, Module, and some other standard types behave like normal PyTorch. As an example, here is a feed-forward layer built on the Sigmoid Linear Unit activation function (aka Swish).

use tch::{nn::{Module, Path}, Kind, Tensor};
use crate::{linear::Linear, transformer::FeedForward};

/// FFN using the Sigmoid Linear Unit (aka Swish) activation function
/// (tch's Module trait requires Debug, so Linear is assumed to derive it as well)
#[derive(Debug)]
pub struct Swish {
    /// aka w1
    gate_proj: Linear,
    /// aka w2
    down_proj: Linear,
    /// aka w3
    up_proj: Linear,
}

impl Module for Swish {
    fn forward(&self, xs: &Tensor) -> Tensor {
        // silu(x) = x * sigmoid(x), gated elementwise by the up projection
        let xs = xs.apply(&self.gate_proj).silu() * xs.apply(&self.up_proj);
        // project back down to the model dimension
        xs.apply(&self.down_proj)
    }
}

impl FeedForward for Swish {
    fn new(p: &Path, in_dim: i64, hidden_dim: i64, kind: Kind) -> Self {
        let gate_proj = Linear::new_no_bias(p / "gate_proj", in_dim, hidden_dim, kind);
        let down_proj = Linear::new_no_bias(p / "down_proj", hidden_dim, in_dim, kind);
        let up_proj = Linear::new_no_bias(p / "up_proj", in_dim, hidden_dim, kind);

        Self {
            gate_proj,
            down_proj,
            up_proj,
        }
    }
}
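To make the math behind this module concrete without tch, here is a dependency-free sketch of the same gated computation (often called SwiGLU) on plain `Vec<f32>` matrices stored as rows. The helper names `matvec` and `swiglu` are hypothetical and not part of tch or the author's crate; there is no batching and no bias, matching the `new_no_bias` layers above.

```rust
/// SiLU / Swish activation: x * sigmoid(x), written as x / (1 + e^-x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Matrix-vector product for a row-major matrix stored as a slice of rows (no bias).
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Gated feed-forward: down(silu(gate(x)) * up(x)), the same shape of
/// computation as the Swish module's forward pass.
fn swiglu(gate: &[Vec<f32>], up: &[Vec<f32>], down: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    let g = matvec(gate, x); // hidden_dim values
    let u = matvec(up, x); // hidden_dim values
    // elementwise gate: silu(gate) * up
    let h: Vec<f32> = g.iter().zip(&u).map(|(gi, ui)| silu(*gi) * ui).collect();
    matvec(down, &h) // back to in_dim values
}
```

The gate and up projections widen to `hidden_dim` and the down projection returns to `in_dim`, which is why `down_proj` is the only layer constructed with its dimensions swapped.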

And here’s an implementation of Root Mean Square Normalization.

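use tch::{Kind, Tensor, nn::{Module, Path}};
use crate::transformer::NormLayer;

/// Root Mean Square Layer Normalization
#[derive(Debug)]
pub struct RmsNorm {
    scale: Tensor,
    size: i64,
}

impl Module for RmsNorm {
    fn forward(&self, xs: &Tensor) -> Tensor {
        // mean of squares over the last (hidden) dimension, kept for broadcasting
        let variance = (xs*xs).mean_dim(-1, true, xs.kind());
        // divide by the root mean square (eps = 1e-5 for numerical stability)
        let hidden_states = xs * (variance + 1e-5).rsqrt();
        // learned per-channel gain, broadcast over batch and sequence
        let scale = self.scale.reshape([1, 1, self.size]);
        scale * hidden_states
    }
}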
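impl NormLayer for RmsNorm {
    fn new(p: &Path, size: i64, _kind: Kind) -> Self {
        // the RMSNorm gain is conventionally initialized to ones; zeros would zero every output at init
        let scale = p.ones("weight", &[size]);
        Self { scale, size }
    }
}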
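For comparison with the tensor version, here is the same normalization on a single hidden vector in plain Rust, y_i = g_i * x_i / sqrt(mean(x^2) + eps). It is a sketch for checking values by hand; `rms_norm` is a hypothetical helper, and the eps of 1e-5 used in the module is passed in explicitly.

```rust
/// RMSNorm over a single hidden vector: y_i = g_i * x_i / sqrt(mean(x^2) + eps).
fn rms_norm(x: &[f32], gain: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), gain.len(), "gain must match the hidden size");
    // mean of squares (what the tch version names `variance`)
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    // reciprocal root mean square, the scalar analogue of Tensor::rsqrt
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(gain).map(|(v, g)| g * v * inv_rms).collect()
}
```

Unlike LayerNorm, nothing is subtracted: with a unit gain the output has a root mean square of 1, but its mean is not forced to 0.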

We get full access to building modules with custom forward passes, and the Tensor API is simple to use with lots of built-in functions. The backward pass comes from PyTorch’s autodiff engine.
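To see what the autodiff engine saves us from writing, here is the local derivative of SiLU worked out by hand, d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))) where s is the sigmoid, checked against a central finite difference. This plain f64 sketch says nothing about how PyTorch implements autograd internally; `silu_grad` and `numeric_grad` are hypothetical names.

```rust
/// Logistic sigmoid.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// SiLU forward: x * sigmoid(x).
fn silu(x: f64) -> f64 {
    x * sigmoid(x)
}

/// Hand-derived SiLU derivative (product rule plus s' = s * (1 - s)):
/// s(x) + x * s(x) * (1 - s(x)) = s(x) * (1 + x * (1 - s(x))).
fn silu_grad(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 + x * (1.0 - s))
}

/// Central finite difference, used only to sanity-check the analytic gradient.
fn numeric_grad(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}
```

An autodiff engine stores one such local rule per op and chains them with the chain rule, which is exactly the bookkeeping the Swish and RmsNorm modules above get for free.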

The largest hurdle I ran into was the boilerplate necessary to fully utilize the PyTorch API, which adds complexity that I don’t think will scale well and isn’t justified for such a small client-side application (smol local LLMs). Features like

  • torch.compile (the JIT compiler introduced in PyTorch 2.0),
  • scaled_dot_product_attention (a fast attention kernel whose implementation is automatically chosen at runtime),
  • and torch.autograd.Function (integrating custom operations into PyTorch’s autodifferentiation engine)

are implemented in the Python layer and are not available from within base Torch. tch-rs has no APIs for these! torch.compile is completely unavailable to us as it only works with Python classes. To access the other Python APIs from Rust, we need to use another crate called PyO3 with tch-rs’ subcrate pyo3-tch. The latter is sparse for now (only including a conversion from Python Tensor objects to Rust Tensor objects), so we need to inject some extra Python to call the proper functions.

from torch import nn

def scaled_dot_product_attention(q, k, v):
    return nn.functional.scaled_dot_product_attention(q, k, v)

This can be called from Rust by building a PyModule and executing a function inside of it.

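use pyo3::prelude::*;
use pyo3_tch::PyTensor;
use tch::Tensor;

/// TODO: don't reparse the module on every call
pub fn scaled_dot_product_attention(q: Tensor, k: Tensor, v: Tensor) -> PyResult<Tensor> {
    Python::with_gil(|py| {
        // SDPA_PY_SRC is a hypothetical &str holding the Python snippet above (file/module names illustrative)
        let fun: Py<PyAny> = PyModule::from_code(py, SDPA_PY_SRC, "sdpa.py", "sdpa")?
            .getattr("scaled_dot_product_attention")? // get the function inside the module
            .into();
        // call the function with q, k, v
        let result = fun.call1(py, (PyTensor(q), PyTensor(k), PyTensor(v)))?.extract::<PyTensor>(py)?.0;
        Ok(result)
    })
}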
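Since the fused kernel is opaque from the Rust side, it can help to keep a plain reference for what it computes, softmax(Q K^T / sqrt(d)) V, to compare against in tests. This single-head sketch omits the mask, dropout, and `is_causal` options of the PyTorch function, assumes non-empty inputs, and uses the hypothetical name `attention`; it is a correctness check, not a substitute for the fast kernel.

```rust
/// Reference (unfused, single-head) scaled dot-product attention:
/// softmax(Q K^T / sqrt(d)) V with no mask and no dropout.
/// Shapes: q is [n_q][d], k is [n_k][d], v is [n_k][d_v]; inputs must be non-empty.
fn attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    q.iter()
        .map(|qi| {
            // scaled dot product of this query against every key
            let scores: Vec<f64> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale)
                .collect();
            // numerically stable softmax: subtract the max before exponentiating
            let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let total: f64 = exps.iter().sum();
            // output row = softmax-weighted sum of the value rows
            let mut out = vec![0.0; v[0].len()];
            for (w, vj) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w / total * x;
                }
            }
            out
        })
        .collect()
}
```

The fused kernel avoids materializing the full score matrix that this version builds row by row, which is where its speed and memory savings come from.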

To use torch.autograd.Function, we need a new Python class that overrides the forward and backward methods. We create it by instantiating another module, generated by filling in a template.

import torch
import NEW

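class NEWKernel(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = NEW.forward(x)
        # keep the output; NEW.backward receives it as its first argument
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad):
        grad = grad.clone()
        result, = ctx.saved_tensors
        return NEW.backward(result, grad)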

def NEW_wrapper(input):
    return NEWKernel.apply(input)

NEW is replaced with the name of our new function. This isn’t amazing code because we have zero static analysis for this weird template.
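Because the template is only checked when Python parses it, one cheap safeguard is to validate the op name and test the rendered source in Rust before handing it to Python. This is a sketch assuming a shortened inline copy of the template; `TEMPLATE` and `render` are hypothetical names, while the `.replace("NEW", name)` call mirrors what the module setup code does.

```rust
/// Abbreviated stand-in for the template (hypothetical; keeps only the lines that use NEW).
const TEMPLATE: &str = "import torch
import NEW

class NEWKernel(torch.autograd.Function):
    ...

def NEW_wrapper(input):
    return NEWKernel.apply(input)
";

/// Fill in the op name, rejecting names that would not be valid Python identifiers.
fn render(template: &str, name: &str) -> String {
    // the substitution is purely textual, so catch bad names before Python does
    let valid = !name.is_empty()
        && !name.starts_with(|c: char| c.is_ascii_digit())
        && name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_');
    assert!(valid, "op name {name:?} is not a valid Python identifier");
    // every occurrence of NEW is replaced, so the template must use NEW only as the placeholder
    template.replace("NEW", name)
}
```

For `render(TEMPLATE, "exp")`, the generated source imports `exp`, defines `expKernel`, and exposes `exp_wrapper`, which is the `{name}_wrapper` attribute the Rust side later looks up.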

use pyo3::{prelude::*, types::PyDict};

const KERNEL_AUTOGRAD_PY_TEMPLATE: &str = include_str!("");

/// Initialize a Python module with a given name and a function to register pyfunctions in it.
fn init_pymodule(name: impl AsRef<str>, fn_register: impl Fn(&PyModule) -> PyResult<()>) -> PyResult<PyObject> {
    let name = name.as_ref();
    Python::with_gil(|py| {
        // create a module and register the functions
        let module = PyModule::new(py, name)?;
        fn_register(module)?;

        // insert into sys.modules so the generated code can `import` it
        let sys = PyModule::import(py, "sys")?;
        let py_modules: &PyDict = sys.getattr("modules")?.downcast()?;
        py_modules.set_item(name, module)?;

        let function_wrapper_source = KERNEL_AUTOGRAD_PY_TEMPLATE.replace("NEW", name);

        // create another module that contains the generated torch.autograd.Function in it
        let function_module = PyModule::from_code(
            py,
            &function_wrapper_source,
            name,
            &format!("{}_function_wrapper", name),
        )?;

        // fetch the generated `<name>_wrapper` function
        let fun: Py<PyAny> = function_module
            .getattr(format!("{}_wrapper", name).as_str())?
            .into();
        Ok(fun)
    })
}

We can finally build a new function with a custom forward and backward pass:

#[pyfunction]
#[pyo3(name = "forward")]
fn exp_forward(x: PyTensor) -> PyTensor {
    // TODO call metal kernel
    PyTensor(x.0.exp())
}

#[pyfunction]
#[pyo3(name = "backward")]
fn exp_backward(x: PyTensor, grad: PyTensor) -> PyTensor {
    // TODO call metal kernel
    // x is the saved forward output exp(input), and d/dx exp(x) = exp(x)
    PyTensor(x.0 * grad.0)
}

fn exp_module() -> impl Fn(Tensor) -> Tensor {
    init_tensor_op1("exp", |m| {
        m.add_function(wrap_pyfunction!(exp_forward, m)?)?;
        m.add_function(wrap_pyfunction!(exp_backward, m)?)?;
        Ok(())
    })
}

/// Create a function that takes 1 tensor and returns 1 tensor.
/// The forward() and backward() functions are registered in the fn_register closure.
fn init_tensor_op1(
    name: impl AsRef<str>,
    fn_register: impl Fn(&PyModule) -> PyResult<()>,
) -> impl Fn(Tensor) -> Tensor {
    let fun = init_pymodule(name, fn_register)
        .expect("failed to initialize module");

    move |tensor: Tensor| {
        Python::with_gil(|py| {
            fun.call1(py, (PyTensor(tensor),))
                .expect("failed to apply")
                .extract::<PyTensor>(py)
                .expect("it wasn't a PyTensor??")
                .0
        })
    }
}
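The exp example works because d/dx exp(x) = exp(x): the forward output is all the backward pass needs, which is why `exp_backward` multiplies its first argument by the incoming gradient. That forward/backward contract can be checked in plain Rust with no Python in the loop. This is a sketch with a hypothetical `UnaryOp` trait, not an API from tch, PyO3, or PyTorch.

```rust
/// A one-input op with an explicit saved-for-backward value, mirroring the
/// forward/backward split of the exp example (hypothetical trait).
trait UnaryOp {
    /// Returns the output and whatever the backward pass needs to keep.
    fn forward(&self, x: &[f64]) -> (Vec<f64>, Vec<f64>);
    /// Maps the upstream gradient to the input gradient using the saved values.
    fn backward(&self, saved: &[f64], grad: &[f64]) -> Vec<f64>;
}

struct Exp;

impl UnaryOp for Exp {
    fn forward(&self, x: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let y: Vec<f64> = x.iter().map(|v| v.exp()).collect();
        // d/dx exp(x) = exp(x), so saving the output is enough for backward
        (y.clone(), y)
    }

    fn backward(&self, saved: &[f64], grad: &[f64]) -> Vec<f64> {
        // chain rule: dL/dx = dL/dy * exp(x), where exp(x) is the saved output
        saved.iter().zip(grad).map(|(y, g)| y * g).collect()
    }
}
```

Saving the output rather than the input is a small design choice that avoids recomputing exp in the backward pass; ops whose derivative depends on the input would save the input instead.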

As my TODOs note, ideally we would build up larger tensor ops out of many operations. Adding GPU acceleration would require jumping to yet another language, a Metal shader on my system (or trying to hook into the unfinished Metal Performance Shaders backend; see #86076 for info on the future of the Metal backends). After writing all that boilerplate, I decided PyTorch was a bit too bulky for my liking. I’m having to route calls through

  • Modules implemented in Rust to
  • Tensor ops written in Python which may contain
  • forward and backward passes written in Rust which will call
  • kernels written in Metal / GLSL / HLSL.

That’s a lot of indirection and wasted CPU cycles in a client-side application that is already stretching the limits of the system. If you already have a large PyTorch application and want to add a small amount of Rust, tch-rs may be a fine approach if you’re willing to juggle the extra complexity. I didn’t even mention installation, which required a fork of tch-rs and building nightly PyTorch from source for bfloat16 support on macOS.

Thankfully, there are alternatives in the ecosystem for a fully Rust program. I have some extra requirements, and although the existing frameworks look great, each is missing something.

  • Burn is an ML framework that includes kernel fusion and multiple backends for GPU acceleration. On macOS, fusion requires their WebGPU backend, but wgpu doesn’t support (b)float16.
  • Candle from Hugging Face is an ML framework with a Tensor API backed by custom kernels for its ops. It supports Metal 3’s bfloat16 and model training, and comes with lots of LLM examples.

Burn also includes a backend for Candle, so we can use Burn’s Autodiff<B> backend decorator to add autodiff to Candle’s Metal kernels. More investigation needs to be done on these, but they look more flexible than bundling PyTorch.
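The backend decorator idea is easier to see in miniature: a wrapper type implements the same trait as the backend it wraps, delegates every op to it, and records what a backward pass will need. The toy below only records op names; `Backend`, `Cpu`, and `Autodiff` here are hypothetical stand-ins, not Burn’s actual traits or types, which are far richer.

```rust
use std::cell::RefCell;

/// Minimal stand-in for a tensor backend (hypothetical; not Burn's Backend trait).
trait Backend {
    fn exp(&self, x: f64) -> f64;
    fn mul(&self, a: f64, b: f64) -> f64;
}

/// A plain backend with no gradient tracking.
struct Cpu;

impl Backend for Cpu {
    fn exp(&self, x: f64) -> f64 {
        x.exp()
    }
    fn mul(&self, a: f64, b: f64) -> f64 {
        a * b
    }
}

/// Decorator that wraps any backend, delegates the math, and records each op.
/// A real autodiff decorator would also store the values each backward rule needs.
struct Autodiff<B: Backend> {
    inner: B,
    tape: RefCell<Vec<&'static str>>,
}

impl<B: Backend> Autodiff<B> {
    fn new(inner: B) -> Self {
        Self { inner, tape: RefCell::new(Vec::new()) }
    }
}

impl<B: Backend> Backend for Autodiff<B> {
    fn exp(&self, x: f64) -> f64 {
        self.tape.borrow_mut().push("exp");
        self.inner.exp(x)
    }
    fn mul(&self, a: f64, b: f64) -> f64 {
        self.tape.borrow_mut().push("mul");
        self.inner.mul(a, b)
    }
}
```

Because `Autodiff<B>` is itself a `Backend`, code written against the trait runs unchanged with or without gradient tracking; swapping `Cpu` for a GPU backend would not touch the decorator.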