Matrix Multiplication

Mar 25, 2023

Matrix multiplication is a foundational element in machine learning. It is the basis for propagating an input through a network of weights.

This article gives an overview of the properties of a matrix multiplication and provides an example function in Go.

The Formula

The formula for a matrix multiplication can be summarised as follows:

1
2
AB = [n, k=1] Σ Aᵢₖ Bₖⱼ
    = (Aᵢ₁B₁ⱼ) + (Aᵢ₂B₂ⱼ) + ... + (AᵢₖBₖⱼ)

A 2d matrix multiplication receives 2 matrices as an input and outputs a new matrix equal to the row height of the first matrix and the column width of the second. The columns of the first matrices and must be equal to the rows of the second.

Example

1
2
3
4
5
6
7
8
9
10
11
12
13
14
A = [[1, 2],
     [3, 4],
     [5, 6]]

B = [[7, 8, 9],
     [10, 11, 12]]

AB = [[1x7 + 2x10, 1x8 + 2x11, 1x9 + 2x12],
      [3x7 + 4x10, 3x8 + 4x11, 3x9 + 4x12],
      [5x7 + 6x10, 5x8 + 6x11, 5x9 + 6x12]]

   = [[27, 30, 33],
      [61, 68, 75],
      [95, 106, 117]]

Code Sample

The following is a code sample using Go generics.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package main

import (
	"errors"
	"fmt"
)

func main() {
	a := New([]int{1, 2,
		3, 4,
		5, 6}, 3, 2)

	b := New([]int{7, 8, 9,
		10, 11, 12}, 2, 3)
	prod := make([]int, a.Rows()*b.Cols())
	Product(a, b, prod)
	fmt.Printf("%v\n", prod)
}

func Product[T Numeric](a, b *Dense[T], sum []T) error {
	if a.Cols() != b.Rows() {
		return errors.New("unaligned matrices")
	}

	p := b.Cols()

	for i := 0; i < a.Rows(); i++ {
		for k, c := range a.Row(i) {
			for j, r := range b.Row(k) {
				sum[i*p+j] += c * r
			}
		}
	}
	return nil
}

func New[T Numeric](cells []T, rows, cols int) *Dense[T] {
	return &Dense[T]{rows: rows, cols: cols, cells: cells}
}

type Numeric interface {
	~int | ~float32 | ~float64
}

type Dense[T Numeric] struct {
	rows  int
	cols  int
	cells []T
}

func (d *Dense[T]) Rows() int {
	return d.rows
}

func (d *Dense[T]) Cols() int {
	return d.cols
}

func (d *Dense[T]) Row(i int) []T {
	start := i * d.Cols()
	return d.cells[start : start+d.Cols()]
}

tags: [ machinelearning algebra golang ]