Polymorphic programming

What is polymorphism and what problem does it solve? Let’s describe the situation first.

  1. We have multiple different struct types.

  2. They share a common interface.

  3. We want to write code common to multiple types.

  4. We may need to work with a list of polymorphic objects.

Simple generics

When all we need is common code that works for many different types, as with containers, a generic struct or enum with a type parameter is enough.

pub mod board {
    pub struct Board<Piece> {
        pub fields: [[Option<Piece>; 8]; 8],
    }

    impl<Piece> Board<Piece> {
        pub fn new() -> Self {
            Self {
                fields: Default::default(),
            }
        }

        pub fn replace(&mut self, row: usize, col: usize, piece: Option<Piece>) -> Option<Piece> {
            let field = &mut self.fields[row][col];
            match piece {
                Some(piece) => field.replace(piece),
                None => field.take(),
            }
        }
    }
}

pub mod chess {
    pub enum Piece {
        King,
        Queen,
        Bishop,
        Knight,
        Rook,
        Pawn,
    }
}

use board::Board;
use chess::Piece;

fn main() {
    let mut board = Board::<Piece>::new();

    board.replace(0, 0, Some(Piece::King));
}

Now you can combine various board implementations with various piece sets. Note that Board knows nothing about the Piece type except its size: it can only move a piece into .fields or take it out.

The implementations of Board<_> and Piece are completely independent in the code, but the combined type Board<Piece> is resolved at compile time.
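To illustrate combining one board with several piece sets, here is a minimal sketch. The checkers module and the occupied method are hypothetical additions made only for this illustration; the Board itself is trimmed down from the example above.

```rust
pub mod board {
    pub struct Board<Piece> {
        pub fields: [[Option<Piece>; 8]; 8],
    }

    impl<Piece> Board<Piece> {
        pub fn new() -> Self {
            Self { fields: Default::default() }
        }

        /// Counts occupied fields (hypothetical helper).
        /// It needs no knowledge of what a `Piece` actually is.
        pub fn occupied(&self) -> usize {
            self.fields.iter().flatten().filter(|f| f.is_some()).count()
        }
    }
}

pub mod chess {
    pub enum Piece {
        King,
        Pawn,
    }
}

// Hypothetical second piece set, introduced only for this sketch.
pub mod checkers {
    pub enum Piece {
        Man,
        King,
    }
}

use board::Board;

fn main() {
    // The same generic Board code is instantiated for two unrelated types.
    let mut chess_board = Board::<chess::Piece>::new();
    chess_board.fields[0][4] = Some(chess::Piece::King);
    chess_board.fields[1][4] = Some(chess::Piece::Pawn);

    let mut checkers_board = Board::<checkers::Piece>::new();
    checkers_board.fields[0][1] = Some(checkers::Piece::Man);
    checkers_board.fields[7][0] = Some(checkers::Piece::King);
    checkers_board.fields[2][3] = Some(checkers::Piece::Man);

    println!("chess: {}", chess_board.occupied());
    println!("checkers: {}", checkers_board.occupied());
}
```

Board<chess::Piece> and Board<checkers::Piece> are distinct types, so a chess piece cannot be placed on the checkers board by accident.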

Enum polymorphism

Rust enumerations can be used to represent multiple different objects with a shared set of operations.

enum Shape {
    Square { side: i32 },
    Rectangle { width: i32, height: i32 },
    Circle { radius: i32 },
}

use Shape::*;

impl Shape {
    fn name(&self) -> &'static str {
        match self {
            Square { .. } => "square",
            Rectangle { .. } => "rectangle",
            Circle { .. } => "circle",
        }
    }

    fn area(&self) -> Option<i32> {
        match self {
            Square { side } => Some(side * side),
            Rectangle { width, height } => Some(width * height),
            // π·r² is not an integer, so there is no exact i32 area to return.
            Circle { .. } => None,
        }
    }
}

fn main() {
    let shapes = [
        Square { side: 5 },
        Rectangle { width: 6, height: 8 },
        Circle { radius: 3 },
    ];

    for shape in &shapes {
        println!("name={} area={}", shape.name(), shape.area().unwrap_or(0));
    }
}

The Shape type effectively combines different kinds of shapes into a single data type. The set of shapes can only be extended by adding variants to the enumeration and updating every operation that matches on it.
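As a sketch of what such an extension involves, the following variation adds a Triangle variant. The variant and the doubled_area method are assumptions made for this illustration and do not appear in the example above.

```rust
enum Shape {
    Square { side: i32 },
    Rectangle { width: i32, height: i32 },
    // Hypothetical new variant added for this illustration.
    Triangle { base: i32, height: i32 },
}

impl Shape {
    fn name(&self) -> &'static str {
        // Leaving out the `Triangle` arm would be a compile error
        // (non-exhaustive patterns), so no operation can be forgotten.
        match self {
            Shape::Square { .. } => "square",
            Shape::Rectangle { .. } => "rectangle",
            Shape::Triangle { .. } => "triangle",
        }
    }

    // Twice the area, so that the triangle stays in integer arithmetic.
    fn doubled_area(&self) -> i32 {
        match self {
            Shape::Square { side } => 2 * side * side,
            Shape::Rectangle { width, height } => 2 * width * height,
            Shape::Triangle { base, height } => base * height,
        }
    }
}

fn main() {
    let shapes = [
        Shape::Square { side: 5 },
        Shape::Rectangle { width: 6, height: 8 },
        Shape::Triangle { base: 3, height: 4 },
    ];

    for shape in &shapes {
        println!("name={} doubled_area={}", shape.name(), shape.doubled_area());
    }
}
```

The compiler checks every match for exhaustiveness, which is what makes this closed style of polymorphism safe to extend.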

Trait polymorphism

Traits describe sets of available operations. You can create generic functions and types that depend on a trait rather than on a specific type, yet are still resolved at compile time.

pub mod ops {
    pub trait Area {
        fn area(&self) -> f64;
    }
}

pub mod shapes {
    pub struct Square {
        pub side: f64,
    }

    impl super::ops::Area for Square {
        fn area(&self) -> f64 {
            self.side * self.side
        }
    }
}

use ops::Area;
use shapes::Square;

fn main() {
    let shape = Square { side: 5. };

    println!("{}", shape.area())
}

This example doesn’t show the power of traits; it merely shows basic usage. The next, artificial example introduces a generic function that takes impl Shape.

pub mod shapes {
    pub trait Shape {
        fn shape(&self) -> &'static str;
        fn area(&self) -> f64;
    }

    pub struct Square {
        pub side: f64,
    }

    impl Shape for Square {
        fn shape(&self) -> &'static str {
            "square"
        }
        fn area(&self) -> f64 {
            self.side * self.side
        }
    }

    pub struct Rectangle {
        pub width: f64,
        pub height: f64,
    }

    impl Shape for Rectangle {
        fn shape(&self) -> &'static str {
            "rectangle"
        }
        fn area(&self) -> f64 {
            self.width * self.height
        }
    }
}

use shapes::*;

pub fn consume(shape: impl Shape) {
    println!(
        "Shape is a {}. Its area is {}.",
        shape.shape(),
        shape.area()
    );
}

pub fn main() {
    consume(Square { side: 5. });
    consume(Rectangle { width: 3., height: 5. });
}
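An impl Shape argument is largely shorthand for a generic function with a named type parameter. The sketch below spells that form out; the function name describe and its String return value are choices made for this illustration only.

```rust
pub trait Shape {
    fn shape(&self) -> &'static str;
    fn area(&self) -> f64;
}

pub struct Square {
    pub side: f64,
}

impl Shape for Square {
    fn shape(&self) -> &'static str {
        "square"
    }
    fn area(&self) -> f64 {
        self.side * self.side
    }
}

// Same idea as `fn describe(shape: impl Shape)`, but the type parameter
// has a name, so callers may also specify it explicitly.
pub fn describe<S: Shape>(shape: S) -> String {
    format!("Shape is a {}. Its area is {}.", shape.shape(), shape.area())
}

fn main() {
    // The compiler infers S = Square and generates a specialized copy.
    println!("{}", describe(Square { side: 5. }));
    // The named parameter can also be given explicitly.
    println!("{}", describe::<Square>(Square { side: 2. }));
}
```

In both forms the call is resolved at compile time for each concrete type, so no run-time dispatch is involved.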

Dynamic polymorphism

If you want to work with multiple objects of different types in one container, your code cannot know their concrete types at compile time. To build a list of trait objects, you need to dispatch method calls at run time and pass dynamic trait object references around.

pub mod shapes {
    pub trait Shape {
        fn shape(&self) -> &'static str;
        fn area(&self) -> f64;
    }

    pub struct Square {
        pub side: f64,
    }

    impl Shape for Square {
        fn shape(&self) -> &'static str {
            "square"
        }
        fn area(&self) -> f64 {
            self.side * self.side
        }
    }

    pub struct Rectangle {
        pub width: f64,
        pub height: f64,
    }

    impl Shape for Rectangle {
        fn shape(&self) -> &'static str {
            "rectangle"
        }
        fn area(&self) -> f64 {
            self.width * self.height
        }
    }
}

use shapes::*;

pub fn main() {
    let shapes: Vec<Box<dyn Shape>> = vec![
        Box::new(Square { side: 5. }),
        Box::new(Rectangle { width: 3., height: 5. }),
        Box::new(Rectangle { width: 3., height: 5. }),
    ];

    for shape in &shapes {
        println!("{}", shape.shape());
    }
}

So what is stored in the vector? Each element is a fat pointer: one pointer to the Square or Rectangle structure and another to the vtable generated for the corresponding impl Shape for Square or impl Shape for Rectangle.
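The two-pointer layout can be observed by comparing pointer sizes. This is a minimal sketch with a cut-down Shape trait; the exact representation of trait object pointers is a compiler implementation detail, so treat the printed sizes as what current compilers do rather than as a language guarantee.

```rust
use std::mem::size_of;

trait Shape {
    fn area(&self) -> f64;
}

struct Square {
    side: f64,
}

impl Shape for Square {
    fn area(&self) -> f64 {
        self.side * self.side
    }
}

fn main() {
    // A reference to a concrete type is a single (thin) pointer.
    println!("&Square:        {} bytes", size_of::<&Square>());
    // A trait object reference carries a data pointer and a vtable pointer.
    println!("&dyn Shape:     {} bytes", size_of::<&dyn Shape>());
    println!("Box<dyn Shape>: {} bytes", size_of::<Box<dyn Shape>>());

    // The method call goes through the vtable at run time.
    let shape: Box<dyn Shape> = Box::new(Square { side: 5. });
    println!("area = {}", shape.area());
}
```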

Some theory

This is a different model from that of object-oriented languages, and it is more flexible in two ways. First, you can use the trait system for both generic programming and dynamic polymorphism. Second, you can combine traits and then pass any type that implements all of them.
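Both points can be sketched with two small traits. The names Area, Named, NamedArea, report and total_area are hypothetical and chosen only for this illustration.

```rust
pub trait Area {
    fn area(&self) -> f64;
}

pub trait Named {
    fn name(&self) -> &'static str;
}

pub struct Square {
    pub side: f64,
}

pub struct Rectangle {
    pub width: f64,
    pub height: f64,
}

impl Area for Square {
    fn area(&self) -> f64 {
        self.side * self.side
    }
}

impl Named for Square {
    fn name(&self) -> &'static str {
        "square"
    }
}

impl Area for Rectangle {
    fn area(&self) -> f64 {
        self.width * self.height
    }
}

impl Named for Rectangle {
    fn name(&self) -> &'static str {
        "rectangle"
    }
}

// Generic programming: accepts any type implementing both traits.
pub fn report<T: Area + Named>(item: &T) -> String {
    format!("{} with area {}", item.name(), item.area())
}

// Dynamic polymorphism: a trait object names one principal trait,
// so the combination is expressed as a supertrait.
pub trait NamedArea: Area + Named {}
impl<T: Area + Named> NamedArea for T {}

pub fn total_area(items: &[Box<dyn NamedArea>]) -> f64 {
    items.iter().map(|item| item.area()).sum()
}

fn main() {
    println!("{}", report(&Square { side: 2. }));

    let items: Vec<Box<dyn NamedArea>> = vec![
        Box::new(Square { side: 2. }),
        Box::new(Rectangle { width: 3., height: 5. }),
    ];
    println!("total area = {}", total_area(&items));
}
```

The same two traits serve the statically dispatched report function and the dynamically dispatched total_area function without any change to Square or Rectangle.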

No class hierarchy needed. No data inheritance included. All data inclusion is done via composition; dynamic dispatch is done via Box<dyn Trait>, &dyn Trait and other pointer types that can work with unsized data.
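A minimal sketch of composition and of dynamic dispatch through different pointer types follows. The Labeled wrapper and the area_of function are hypothetical names introduced for this illustration.

```rust
use std::rc::Rc;

trait Shape {
    fn area(&self) -> f64;
}

struct Square {
    side: f64,
}

impl Shape for Square {
    fn area(&self) -> f64 {
        self.side * self.side
    }
}

// Composition instead of inheritance: `Labeled` contains a shape
// and forwards the operation to it.
struct Labeled<S> {
    label: String,
    shape: S,
}

impl<S: Shape> Shape for Labeled<S> {
    fn area(&self) -> f64 {
        self.shape.area()
    }
}

// A borrowed trait object: dynamic dispatch without a heap allocation.
fn area_of(shape: &dyn Shape) -> f64 {
    shape.area()
}

fn main() {
    let square = Square { side: 2. };
    // `&Square` coerces to `&dyn Shape` at the call site.
    println!("plain: {}", area_of(&square));

    let labeled = Labeled { label: "big".to_string(), shape: Square { side: 3. } };
    println!("{}: {}", labeled.label, area_of(&labeled));

    // Other pointer types, such as Rc, can hold trait objects as well.
    let shared: Rc<dyn Shape> = Rc::new(Square { side: 4. });
    println!("shared: {}", area_of(&*shared));
}
```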

You can combine all of the above techniques together and explore more advanced features of the Rust type system.

Example: Expression tree

You can use dynamic polymorphism wherever you would use interfaces, abstract classes or vtables in other programming languages. The following example builds a tree structure using dynamic trait objects.

mod exp {
    struct Constant {
        value: f64,
    }

    struct Variable {
        name: String,
    }

    struct Sum {
        left: Box<dyn Exp>,
        right: Box<dyn Exp>,
    }

    #[derive(Debug)]
    pub enum Error {
        NotImplemented,
        UnboundVariable,
    }

    use std::collections::HashMap;

    pub trait Exp {
        fn evaluate(&self, map: &HashMap<String, f64>) -> Result<f64, Error>;
    }

    impl Exp for Constant {
        fn evaluate(&self, _map: &HashMap<String, f64>) -> Result<f64, Error> {
            Ok(self.value)
        }
    }

    impl Exp for Variable {
        fn evaluate(&self, map: &HashMap<String, f64>) -> Result<f64, Error> {
            map.get(&self.name).cloned().ok_or(Error::UnboundVariable)
        }
    }

    impl Exp for Sum {
        fn evaluate(&self, map: &HashMap<String, f64>) -> Result<f64, Error> {
            Ok(self.left.evaluate(map)? + self.right.evaluate(map)?)
        }
    }

    type BoxedExp = Box<dyn Exp>;

    pub fn var(name: &str) -> BoxedExp {
        Box::new(Variable { name: name.to_string() })
    }

    pub fn cst(value: f64) -> BoxedExp {
        Box::new(Constant { value })
    }

    pub fn sum(left: BoxedExp, right: BoxedExp) -> BoxedExp {
        Box::new(Sum { left, right })
    }

    use std::str::FromStr;

    // Parsing is left as a stub: it always reports `NotImplemented`.
    impl FromStr for Box<dyn Exp> {
        type Err = Error;
        fn from_str(_s: &str) -> Result<Self, Self::Err> {
            Err(Error::NotImplemented)
        }
    }
}

use std::collections::HashMap;
use std::iter::FromIterator;
use exp::*;

fn main() {
    let e = sum(sum(cst(5.), var("a")), var("b"));
    let vars = HashMap::from_iter([("a".to_string(), 10.), ("b".to_string(), 11.)]);

    println!("Result: {}", e.evaluate(&vars).unwrap());
}