This page becomes unresponsive for a few seconds after loading to evaluate the code samples below. Don't worry, the page will be responsive after that.



Home
⚠️ Oops! This page doesn't appear to define a type called _.

This post's purpose is two-fold. First, it's a hat-tip to this beautiful paper. Second, it can serve as a very concrete example of a monad. An idea (that others have articulated better than I can) is that when trying to teach an abstract concept, it's better to start with concrete examples and gradually move towards abstraction and generalization rather than the other way around.

One thing to point out up-front is that the code samples below are written in OCaml, and are editable/running (thanks to redoc), so if you feel the urge to get your hands dirty and change some code in order to see the effect - go for it!

If you do want to learn more about monads, I would recommend the following posts:

With that in mind, let's immediately break the rule of starting concrete and lay out the formal definition of a monad. If it means nothing or very little to you at this point in the post, don't worry. I will explain what return, map, and join mean intuitively, but only later, in the context of our concrete example: the probability distribution.

module Monad = struct
  (** The monad interface used throughout this post. *)
  module type S = sig
    (** The container type, e.g. a distribution over ['a] outcomes. *)
    type 'a t

    (** [join tt] collapses two nested layers into one. *)
    val join : 'a t t -> 'a t

    (** [map t ~f] applies [f] to every value held in [t]. *)
    val map : 'a t -> f:('a -> 'b) -> 'b t

    (** [return a] wraps a single value. *)
    val return : 'a -> 'a t
  end
end
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

Let's define a probability type. This is unnecessary, but it allows me to refer to probabilities as values of type Probability.t later on, rather than just floats, which hopefully adds a small bit of readability.

module Probability = struct
  (** Probabilities are plain floats. *)
  type t = float

  (** Probability of an impossible event. *)
  let zero : t = 0.0

  (* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. *)
  let ( + ) (a : t) (b : t) : t = a +. b
  let ( - ) (a : t) (b : t) : t = a -. b
end
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

/** An abelian group: an identity [zero] with addition and subtraction.
    The group laws are expected but not checked. */
module type Commutative_group = {
  type t;
  /** Subtraction, intended as the inverse of [(+)]. */
  let (-): (t, t) => t;
  /** Addition, intended to be associative and commutative. */
  let (+): (t, t) => t;
  /** The identity element for [(+)]. */
  let zero: t;
};

module List = {
  include List;
  /** Alias for [hd]: first element; raises [Failure] on []. */
  let hd_exn = hd;
  /** Labeled, list-first [map], convenient with [|>]. */
  let map = (t, ~f) => map(f, t);
  /** Map each element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /** Labeled, list-first [filter]. */
  let filter = (t, ~f) => filter(f, t);
  /** Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /** [range(start, stop)] is [start, start + 1, ..., stop - 1].
      Empty when [start >= stop] (a plain [==] test would recurse forever
      when [start > stop]). */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /** Split [t] into maximal runs of adjacent elements, starting a new run
      between [last] and [next] whenever [break(last, next)] is true.
      Element order is preserved inside and across groups, and the empty
      list yields no groups. */
  let group = (t, ~break) => {
    /* [curr_group] is built in reverse and flipped when it is closed. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [rev(curr_group), ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };
  /** Sum [f(x)] over [t] using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /** Probabilities are plain floats. */
  type t = float;
  /** Probability of an impossible event. */
  let zero: t = 0.0;
  /* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. */
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};
module Distribution_implementation = struct
  (** Each possible outcome paired with its probability. *)
  type 'a t = (Probability.t * 'a) list

  (** [return a] is the distribution where [a] occurs with probability 1. *)
  let return a = [ (1., a) ]

  (* compare just the 'a states, not the probabilities *)
  let compare_states (_, a1) (_, a2) = compare a1 a2

  (** Merge entries with equal states, summing their probabilities; the
      result is sorted by state. The empty distribution maps to [[]].
      The guard matters because the [List.group] helper may return a
      single empty group for [[]], which would make [List.hd] raise. *)
  let group_states : 'a t -> 'a t = function
    | [] -> []
    | t ->
      List.sort compare_states t
      |> List.group ~break:(fun t1 t2 -> compare_states t1 t2 <> 0)
      |> List.map ~f:(fun group ->
          (* non-empty: [t] is non-empty, so every group holds an entry *)
          let (_, a) = List.hd group in
          let p =
            List.fold_left (fun acc (p, _) -> Probability.(acc + p))
              Probability.zero group
          in
          (p, a))

  (** [map t ~f] applies [f] to every outcome, merging outcomes that
      become equal. *)
  let map t ~f =
    List.map t ~f:(fun (p, a) -> (p, f a))
    |> group_states

  (** [join t] flattens a distribution of distributions by multiplying
      each inner probability by that of its enclosing distribution. *)
  let join t =
    List.bind t ~f:(fun (p, t) -> List.map t ~f:(fun (p1, a) -> (p *. p1, a)))
    |> group_states

  (** One "outcome:\tprobability " line per entry; probabilities that do
      not round-trip at [decimals] places are prefixed with "~". *)
  let to_string ?(decimals = 3) f t =
    let format_probability p =
      let s = Printf.sprintf "%0.*f" decimals p in
      if float_of_string s = p then s else "~" ^ s
    in
    String.concat "\n"
      (List.map t ~f:(fun (p, a) -> f a ^ ":\t" ^ format_probability p ^ " "))
end
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

/** An abelian group: an identity [zero] with addition and subtraction.
    The group laws are expected but not checked. */
module type Commutative_group = {
  type t;
  /** Subtraction, intended as the inverse of [(+)]. */
  let (-): (t, t) => t;
  /** Addition, intended to be associative and commutative. */
  let (+): (t, t) => t;
  /** The identity element for [(+)]. */
  let zero: t;
};

module List = {
  include List;
  /** Alias for [hd]: first element; raises [Failure] on []. */
  let hd_exn = hd;
  /** Labeled, list-first [map], convenient with [|>]. */
  let map = (t, ~f) => map(f, t);
  /** Map each element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /** Labeled, list-first [filter]. */
  let filter = (t, ~f) => filter(f, t);
  /** Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /** [range(start, stop)] is [start, start + 1, ..., stop - 1].
      Empty when [start >= stop] (a plain [==] test would recurse forever
      when [start > stop]). */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /** Split [t] into maximal runs of adjacent elements, starting a new run
      between [last] and [next] whenever [break(last, next)] is true.
      Element order is preserved inside and across groups, and the empty
      list yields no groups. */
  let group = (t, ~break) => {
    /* [curr_group] is built in reverse and flipped when it is closed. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [rev(curr_group), ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };
  /** Sum [f(x)] over [t] using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /** Probabilities are plain floats. */
  type t = float;
  /** Probability of an impossible event. */
  let zero: t = 0.0;
  /* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. */
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};

module Distribution_implementation = {
  /** Each possible outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  /** The distribution where [a] occurs with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /** Merge entries with equal states, summing their probabilities; the
      result is sorted by state. The empty distribution maps to [].
      The guard matters because the [List.group] helper may return a
      single empty group for [], which would make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* non-empty: [t] is non-empty, so every group holds an entry */
           let (_, a) = List.hd(group);
           let p =
             List.fold_left(
               (acc, (p, _)) => Probability.(+)(acc, p),
               Probability.zero,
               group,
             );
           (p, a);
         });
  /** Apply [f] to every outcome, merging outcomes that become equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /** Flatten a distribution of distributions by multiplying each inner
      probability by the probability of its enclosing distribution. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /** One "outcome:\tprobability " line per entry; probabilities that do
      not round-trip at [decimals] places are prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      float_of_string(s) == p ? s : "~" ++ s;
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

Now back to the good part - let's define the main type that we'll be dealing with: a Distribution.t.

module Distribution : sig
  (** A distribution: each outcome paired with its probability. *)
  type 'a t = (Probability.t * 'a) list

  include Monad.S with type 'a t := 'a t

  (** Render one line per outcome; [?decimals] sets the precision. *)
  val to_string : ?decimals:int -> ('a -> string) -> 'a t -> string
end = Distribution_implementation
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

/** An abelian group: an identity [zero] with addition and subtraction.
    The group laws are expected but not checked. */
module type Commutative_group = {
  type t;
  /** Subtraction, intended as the inverse of [(+)]. */
  let (-): (t, t) => t;
  /** Addition, intended to be associative and commutative. */
  let (+): (t, t) => t;
  /** The identity element for [(+)]. */
  let zero: t;
};

module List = {
  include List;
  /** Alias for [hd]: first element; raises [Failure] on []. */
  let hd_exn = hd;
  /** Labeled, list-first [map], convenient with [|>]. */
  let map = (t, ~f) => map(f, t);
  /** Map each element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /** Labeled, list-first [filter]. */
  let filter = (t, ~f) => filter(f, t);
  /** Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /** [range(start, stop)] is [start, start + 1, ..., stop - 1].
      Empty when [start >= stop] (a plain [==] test would recurse forever
      when [start > stop]). */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /** Split [t] into maximal runs of adjacent elements, starting a new run
      between [last] and [next] whenever [break(last, next)] is true.
      Element order is preserved inside and across groups, and the empty
      list yields no groups. */
  let group = (t, ~break) => {
    /* [curr_group] is built in reverse and flipped when it is closed. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [rev(curr_group), ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };
  /** Sum [f(x)] over [t] using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /** Probabilities are plain floats. */
  type t = float;
  /** Probability of an impossible event. */
  let zero: t = 0.0;
  /* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. */
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};

module Distribution_implementation = {
  /** Each possible outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  /** The distribution where [a] occurs with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /** Merge entries with equal states, summing their probabilities; the
      result is sorted by state. The empty distribution maps to [].
      The guard matters because the [List.group] helper may return a
      single empty group for [], which would make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* non-empty: [t] is non-empty, so every group holds an entry */
           let (_, a) = List.hd(group);
           let p =
             List.fold_left(
               (acc, (p, _)) => Probability.(+)(acc, p),
               Probability.zero,
               group,
             );
           (p, a);
         });
  /** Apply [f] to every outcome, merging outcomes that become equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /** Flatten a distribution of distributions by multiplying each inner
      probability by the probability of its enclosing distribution. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /** One "outcome:\tprobability " line per entry; probabilities that do
      not round-trip at [decimals] places are prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      float_of_string(s) == p ? s : "~" ++ s;
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /** A distribution: each outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  include Monad.S with type t('a) := t('a);
  /** Render one line per outcome; [~decimals] sets the precision. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

What is Distribution_implementation and where is it defined? Well, I hid it. If you really want to see it, fine - click here to show it, but it's a big chunk of code that's not super interesting or important. What is important is the signature which Distribution exposes. Let's explain each line:

type 'a t = (Probability.t * 'a) list: This says that a Distribution.t of 'as (read: possible outcomes) is really a list of 'as, each paired with an associated probability. Again, let's make this concrete. Let's say our distribution represents the weather tomorrow, and to keep it simple, let's say there are two possible outcomes: Rain or Shine. If we wanted to create a Distribution.t that represents an 80% chance of Shine and a 20% chance of Rain, ... actually, let's just do that:

(** Tomorrow's weather has one of two outcomes. *)
type rain_or_shine =
  | Rain
  | Shine;;

(* 80% chance of Shine, 20% chance of Rain. *)
let tomorrow's_weather : rain_or_shine Distribution.t =
  [ (0.8, Shine); (0.2, Rain) ]
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

/** An abelian group: an identity [zero] with addition and subtraction.
    The group laws are expected but not checked. */
module type Commutative_group = {
  type t;
  /** Subtraction, intended as the inverse of [(+)]. */
  let (-): (t, t) => t;
  /** Addition, intended to be associative and commutative. */
  let (+): (t, t) => t;
  /** The identity element for [(+)]. */
  let zero: t;
};

module List = {
  include List;
  /** Alias for [hd]: first element; raises [Failure] on []. */
  let hd_exn = hd;
  /** Labeled, list-first [map], convenient with [|>]. */
  let map = (t, ~f) => map(f, t);
  /** Map each element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /** Labeled, list-first [filter]. */
  let filter = (t, ~f) => filter(f, t);
  /** Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /** [range(start, stop)] is [start, start + 1, ..., stop - 1].
      Empty when [start >= stop] (a plain [==] test would recurse forever
      when [start > stop]). */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /** Split [t] into maximal runs of adjacent elements, starting a new run
      between [last] and [next] whenever [break(last, next)] is true.
      Element order is preserved inside and across groups, and the empty
      list yields no groups. */
  let group = (t, ~break) => {
    /* [curr_group] is built in reverse and flipped when it is closed. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [rev(curr_group), ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };
  /** Sum [f(x)] over [t] using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /** Probabilities are plain floats. */
  type t = float;
  /** Probability of an impossible event. */
  let zero: t = 0.0;
  /* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. */
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};

module Distribution_implementation = {
  /** Each possible outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  /** The distribution where [a] occurs with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /** Merge entries with equal states, summing their probabilities; the
      result is sorted by state. The empty distribution maps to [].
      The guard matters because the [List.group] helper may return a
      single empty group for [], which would make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* non-empty: [t] is non-empty, so every group holds an entry */
           let (_, a) = List.hd(group);
           let p =
             List.fold_left(
               (acc, (p, _)) => Probability.(+)(acc, p),
               Probability.zero,
               group,
             );
           (p, a);
         });
  /** Apply [f] to every outcome, merging outcomes that become equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /** Flatten a distribution of distributions by multiplying each inner
      probability by the probability of its enclosing distribution. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /** One "outcome:\tprobability " line per entry; probabilities that do
      not round-trip at [decimals] places are prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      float_of_string(s) == p ? s : "~" ++ s;
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /** A distribution: each outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  include Monad.S with type t('a) := t('a);
  /** Render one line per outcome; [~decimals] sets the precision. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

/** Tomorrow's weather has one of two outcomes. Keep the constructor order:
    outcomes are sorted with polymorphic [compare] when a distribution is
    mapped or joined. */
type rain_or_shine = Rain | Shine;

/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [(0.8, Shine), (0.2, Rain)];

You can see that the type of tomorrow's_weather is a rain_or_shine Distribution.t, but really it's just a list of possible outcomes along with their associated probabilities. And, just to explicitly connect the dots, the 'a in this case is rain_or_shine.

val to_string : ?decimals:int -> ('a -> string) -> 'a t -> string: This line probably needs the least explaining. It says that there exists a function Distribution.to_string that, given a function of type 'a -> string, can produce a string from an 'a distribution. Although probably unnecessary, let's take it for a spin.

(** Human-readable name of a weather outcome. *)
let rain_or_shine_to_string (weather : rain_or_shine) =
  match weather with
  | Rain -> "Rain"
  | Shine -> "Shine";;

print_endline
  (Distribution.to_string rain_or_shine_to_string tomorrow's_weather);;
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

/** An abelian group: an identity [zero] with addition and subtraction.
    The group laws are expected but not checked. */
module type Commutative_group = {
  type t;
  /** Subtraction, intended as the inverse of [(+)]. */
  let (-): (t, t) => t;
  /** Addition, intended to be associative and commutative. */
  let (+): (t, t) => t;
  /** The identity element for [(+)]. */
  let zero: t;
};

module List = {
  include List;
  /** Alias for [hd]: first element; raises [Failure] on []. */
  let hd_exn = hd;
  /** Labeled, list-first [map], convenient with [|>]. */
  let map = (t, ~f) => map(f, t);
  /** Map each element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /** Labeled, list-first [filter]. */
  let filter = (t, ~f) => filter(f, t);
  /** Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /** [range(start, stop)] is [start, start + 1, ..., stop - 1].
      Empty when [start >= stop] (a plain [==] test would recurse forever
      when [start > stop]). */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /** Split [t] into maximal runs of adjacent elements, starting a new run
      between [last] and [next] whenever [break(last, next)] is true.
      Element order is preserved inside and across groups, and the empty
      list yields no groups. */
  let group = (t, ~break) => {
    /* [curr_group] is built in reverse and flipped when it is closed. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [rev(curr_group), ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };
  /** Sum [f(x)] over [t] using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /** Probabilities are plain floats. */
  type t = float;
  /** Probability of an impossible event. */
  let zero: t = 0.0;
  /* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. */
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};

module Distribution_implementation = {
  /** Each possible outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  /** The distribution where [a] occurs with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /** Merge entries with equal states, summing their probabilities; the
      result is sorted by state. The empty distribution maps to [].
      The guard matters because the [List.group] helper may return a
      single empty group for [], which would make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* non-empty: [t] is non-empty, so every group holds an entry */
           let (_, a) = List.hd(group);
           let p =
             List.fold_left(
               (acc, (p, _)) => Probability.(+)(acc, p),
               Probability.zero,
               group,
             );
           (p, a);
         });
  /** Apply [f] to every outcome, merging outcomes that become equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /** Flatten a distribution of distributions by multiplying each inner
      probability by the probability of its enclosing distribution. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /** One "outcome:\tprobability " line per entry; probabilities that do
      not round-trip at [decimals] places are prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      float_of_string(s) == p ? s : "~" ++ s;
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /** A distribution: each outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  include Monad.S with type t('a) := t('a);
  /** Render one line per outcome; [~decimals] sets the precision. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

/** Tomorrow's weather has one of two outcomes. Keep the constructor order:
    outcomes are sorted with polymorphic [compare] when a distribution is
    mapped or joined. */
type rain_or_shine = Rain | Shine;

/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [(0.8, Shine), (0.2, Rain)];

/** Human-readable name of a weather outcome. */
let rain_or_shine_to_string = weather =>
  switch (weather) {
  | Rain => "Rain"
  | Shine => "Shine"
  };

print_endline(
  Distribution.to_string(rain_or_shine_to_string, tomorrow's_weather),
);

Simple enough.

include Monad.S with type 'a t := 'a t: Ok, there's kind of a lot to unpack in this one line. include Monad.S really just means paste the signature Monad.S here, so let's do that and explain each of those lines:

module Monad = struct
  (** The monad interface used throughout this post. *)
  module type S = sig
    (** The container type, e.g. a distribution over ['a] outcomes. *)
    type 'a t

    (** [join tt] collapses two nested layers into one. *)
    val join : 'a t t -> 'a t

    (** [map t ~f] applies [f] to every value held in [t]. *)
    val map : 'a t -> f:('a -> 'b) -> 'b t

    (** [return a] wraps a single value. *)
    val return : 'a -> 'a t
  end
end
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

Great, here's where we can explain what these very abstract functions mean in the context of a distribution.

val return : 'a -> 'a t: return says give me an outcome, and I'll give you a distribution on that outcome. Continuing with our weather example, we could say Distribution.return Shine. The probabilities of a distribution have to add up to 100%, so what's the only reasonable thing this could mean? Well, return will create a "Distribution" where the outcome that you passed in happens with a 100% probability.

val map : 'a t -> f:('a -> 'b) -> 'b t: map says give me a distribution of outcomes of type 'a and a way to map 'as to 'bs, and I'll give you a distribution of 'bs. Let's imagine a way we might use this function on tomorrow's_weather. Let's say we're trying to figure out whether or not we can play soccer in the park tomorrow. Let's use map!

(** Soccer in the park is possible only when it does not rain. *)
let can_play_soccer_in_park = function
  | Rain -> false
  | Shine -> true;;

(* Push tomorrow's weather through [can_play_soccer_in_park]. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t =
  tomorrow's_weather |> Distribution.map ~f:can_play_soccer_in_park
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

/** An abelian group: an identity [zero] with addition and subtraction.
    The group laws are expected but not checked. */
module type Commutative_group = {
  type t;
  /** Subtraction, intended as the inverse of [(+)]. */
  let (-): (t, t) => t;
  /** Addition, intended to be associative and commutative. */
  let (+): (t, t) => t;
  /** The identity element for [(+)]. */
  let zero: t;
};

module List = {
  include List;
  /** Alias for [hd]: first element; raises [Failure] on []. */
  let hd_exn = hd;
  /** Labeled, list-first [map], convenient with [|>]. */
  let map = (t, ~f) => map(f, t);
  /** Map each element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /** Labeled, list-first [filter]. */
  let filter = (t, ~f) => filter(f, t);
  /** Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /** [range(start, stop)] is [start, start + 1, ..., stop - 1].
      Empty when [start >= stop] (a plain [==] test would recurse forever
      when [start > stop]). */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /** Split [t] into maximal runs of adjacent elements, starting a new run
      between [last] and [next] whenever [break(last, next)] is true.
      Element order is preserved inside and across groups, and the empty
      list yields no groups. */
  let group = (t, ~break) => {
    /* [curr_group] is built in reverse and flipped when it is closed. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [rev(curr_group), ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };
  /** Sum [f(x)] over [t] using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /** Probabilities are plain floats. */
  type t = float;
  /** Probability of an impossible event. */
  let zero: t = 0.0;
  /* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. */
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};

module Distribution_implementation = {
  /** Each possible outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  /** The distribution where [a] occurs with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /** Merge entries with equal states, summing their probabilities; the
      result is sorted by state. The empty distribution maps to [].
      The guard matters because the [List.group] helper may return a
      single empty group for [], which would make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* non-empty: [t] is non-empty, so every group holds an entry */
           let (_, a) = List.hd(group);
           let p =
             List.fold_left(
               (acc, (p, _)) => Probability.(+)(acc, p),
               Probability.zero,
               group,
             );
           (p, a);
         });
  /** Apply [f] to every outcome, merging outcomes that become equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /** Flatten a distribution of distributions by multiplying each inner
      probability by the probability of its enclosing distribution. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /** One "outcome:\tprobability " line per entry; probabilities that do
      not round-trip at [decimals] places are prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      float_of_string(s) == p ? s : "~" ++ s;
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /** A distribution: each outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  include Monad.S with type t('a) := t('a);
  /** Render one line per outcome; [~decimals] sets the precision. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

/** Tomorrow's weather has one of two outcomes. Keep the constructor order:
    outcomes are sorted with polymorphic [compare] when a distribution is
    mapped or joined. */
type rain_or_shine = Rain | Shine;

/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [(0.8, Shine), (0.2, Rain)];

/** Soccer in the park is possible only when it does not rain. */
let can_play_soccer_in_park =
  fun
  | Rain => false
  | Shine => true;

/* Push tomorrow's weather through [can_play_soccer_in_park]. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
  Distribution.map(tomorrow's_weather, ~f=can_play_soccer_in_park);

So, we created a boolean distribution (which represents whether or not we can play soccer in the park tomorrow) by mapping an existing distribution (tomorrow's weather) with a function that maps weather to whether or not you can play soccer in the park. A bit of a mouthful, but conceptually pretty straightforward.

val join : 'a t t -> 'a t: Saving the best for last. join is what really differentiates a monad. In our case, it means that a distribution of distributions is still just a distribution. No wonder people find this confusing... Let's carry on with our strategy of making this more and more concrete until it's actually intelligible.

But what if our function can_play_soccer_in_park is a bit too simplistic? Even if it's sunny tomorrow (i.e. Shine), there's still some chance, albeit small, that you can't play soccer in the park. What if the park is closed? What if other people are using the park all day long? So, let's account for this and update can_play_soccer_in_park to return a boolean distribution, as opposed to a boolean.

(** Whether we can play soccer, now as a distribution: even on a sunny day
    the park might be unavailable. *)
let can_play_soccer_in_park = function
  (* we definitely can't play if it's raining *)
  | Rain -> Distribution.return false
  (* let's say we can play with a 95% probability *)
  | Shine -> [ (0.95, true); (0.05, false) ]
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

/** An abelian group: an identity [zero] with addition and subtraction.
    The group laws are expected but not checked. */
module type Commutative_group = {
  type t;
  /** Subtraction, intended as the inverse of [(+)]. */
  let (-): (t, t) => t;
  /** Addition, intended to be associative and commutative. */
  let (+): (t, t) => t;
  /** The identity element for [(+)]. */
  let zero: t;
};

module List = {
  include List;
  /** Alias for [hd]: first element; raises [Failure] on []. */
  let hd_exn = hd;
  /** Labeled, list-first [map], convenient with [|>]. */
  let map = (t, ~f) => map(f, t);
  /** Map each element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /** Labeled, list-first [filter]. */
  let filter = (t, ~f) => filter(f, t);
  /** Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /** [range(start, stop)] is [start, start + 1, ..., stop - 1].
      Empty when [start >= stop] (a plain [==] test would recurse forever
      when [start > stop]). */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /** Split [t] into maximal runs of adjacent elements, starting a new run
      between [last] and [next] whenever [break(last, next)] is true.
      Element order is preserved inside and across groups, and the empty
      list yields no groups. */
  let group = (t, ~break) => {
    /* [curr_group] is built in reverse and flipped when it is closed. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [rev(curr_group), ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };
  /** Sum [f(x)] over [t] using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /** Probabilities are plain floats. */
  type t = float;
  /** Probability of an impossible event. */
  let zero: t = 0.0;
  /* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. */
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};

module Distribution_implementation = {
  /** Each possible outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  /** The distribution where [a] occurs with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /** Merge entries with equal states, summing their probabilities; the
      result is sorted by state. The empty distribution maps to [].
      The guard matters because the [List.group] helper may return a
      single empty group for [], which would make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* non-empty: [t] is non-empty, so every group holds an entry */
           let (_, a) = List.hd(group);
           let p =
             List.fold_left(
               (acc, (p, _)) => Probability.(+)(acc, p),
               Probability.zero,
               group,
             );
           (p, a);
         });
  /** Apply [f] to every outcome, merging outcomes that become equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /** Flatten a distribution of distributions by multiplying each inner
      probability by the probability of its enclosing distribution. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /** One "outcome:\tprobability " line per entry; probabilities that do
      not round-trip at [decimals] places are prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      float_of_string(s) == p ? s : "~" ++ s;
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /** A distribution: each outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  include Monad.S with type t('a) := t('a);
  /** Render one line per outcome; [~decimals] sets the precision. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

/** Tomorrow's weather has one of two outcomes. Keep the constructor order:
    outcomes are sorted with polymorphic [compare] when a distribution is
    mapped or joined. */
type rain_or_shine = Rain | Shine;

/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [(0.8, Shine), (0.2, Rain)];

/** Whether we can play soccer, now as a distribution: even on a sunny day
    the park might be unavailable. */
let can_play_soccer_in_park =
  fun
  /* we definitely can't play if it's raining */
  | Rain => Distribution.return(false)
  /* let's say we can play with a 95% probability */
  | Shine => [(0.95, true), (0.05, false)];

Ok, and let's do what we did before and map tomorrow's_weather with our new and improved function.

(* Intentionally ill-typed: [can_play_soccer_in_park] now returns a
   [bool Distribution.t], so [Distribution.map] yields a
   [bool Distribution.t Distribution.t], which does not match the
   annotation. The resulting type error is the point of this example. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t = 
  Distribution.map tomorrow's_weather ~f:can_play_soccer_in_park
Type Error: File "", line 102, characters 2-64: Error: This expression has type bool Distribution.t Distribution.t = (Probability.t * bool Distribution.t) list but an expression was expected of type bool Distribution.t = (Probability.t * bool) list Type bool Distribution.t = (Probability.t * bool) list is not compatible with type bool
module Monad = {
  /** The monad interface used throughout this post. */
  module type S = {
    /** The container type, e.g. a distribution over ['a] outcomes. */
    type t('a);
    /** [join(tt)] collapses two nested layers into one. */
    let join: t(t('a)) => t('a);
    /** [map(t, ~f)] applies [f] to every value held in [t]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /** [return(a)] wraps a single value. */
    let return: 'a => t('a);
  };
};

/** An abelian group: an identity [zero] with addition and subtraction.
    The group laws are expected but not checked. */
module type Commutative_group = {
  type t;
  /** Subtraction, intended as the inverse of [(+)]. */
  let (-): (t, t) => t;
  /** Addition, intended to be associative and commutative. */
  let (+): (t, t) => t;
  /** The identity element for [(+)]. */
  let zero: t;
};

module List = {
  include List;
  /** Alias for [hd]: first element; raises [Failure] on []. */
  let hd_exn = hd;
  /** Labeled, list-first [map], convenient with [|>]. */
  let map = (t, ~f) => map(f, t);
  /** Map each element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /** Labeled, list-first [filter]. */
  let filter = (t, ~f) => filter(f, t);
  /** Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /** [range(start, stop)] is [start, start + 1, ..., stop - 1].
      Empty when [start >= stop] (a plain [==] test would recurse forever
      when [start > stop]). */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /** Split [t] into maximal runs of adjacent elements, starting a new run
      between [last] and [next] whenever [break(last, next)] is true.
      Element order is preserved inside and across groups, and the empty
      list yields no groups. */
  let group = (t, ~break) => {
    /* [curr_group] is built in reverse and flipped when it is closed. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [rev(curr_group), ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };
  /** Sum [f(x)] over [t] using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /** Probabilities are plain floats. */
  type t = float;
  /** Probability of an impossible event. */
  let zero: t = 0.0;
  /* Float arithmetic under the usual operator names, giving this module
     the same shape as a [Commutative_group]. */
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};

module Distribution_implementation = {
  /** Each possible outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  /** The distribution where [a] occurs with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /** Merge entries with equal states, summing their probabilities; the
      result is sorted by state. The empty distribution maps to [].
      The guard matters because the [List.group] helper may return a
      single empty group for [], which would make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* non-empty: [t] is non-empty, so every group holds an entry */
           let (_, a) = List.hd(group);
           let p =
             List.fold_left(
               (acc, (p, _)) => Probability.(+)(acc, p),
               Probability.zero,
               group,
             );
           (p, a);
         });
  /** Apply [f] to every outcome, merging outcomes that become equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /** Flatten a distribution of distributions by multiplying each inner
      probability by the probability of its enclosing distribution. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /** One "outcome:\tprobability " line per entry; probabilities that do
      not round-trip at [decimals] places are prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      float_of_string(s) == p ? s : "~" ++ s;
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /** A distribution: each outcome paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  include Monad.S with type t('a) := t('a);
  /** Render one line per outcome; [~decimals] sets the precision. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

/** Tomorrow's weather has one of two outcomes. Keep the constructor order:
    outcomes are sorted with polymorphic [compare] when a distribution is
    mapped or joined. */
type rain_or_shine = Rain | Shine;

/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [(0.8, Shine), (0.2, Rain)];

/** Whether we can play soccer, now as a distribution: even on a sunny day
    the park might be unavailable. */
let can_play_soccer_in_park =
  fun
  /* we definitely can't play if it's raining */
  | Rain => Distribution.return(false)
  /* let's say we can play with a 95% probability */
  | Shine => [(0.95, true), (0.05, false)];

/* NOTE: intentionally ill-typed, to demonstrate the need for join. Mapping a
   function that itself returns a distribution produces
   Distribution.t(Distribution.t(bool)), not the annotated Distribution.t(bool),
   so the type checker rejects this binding. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
  Distribution.map(tomorrow's_weather, ~f=can_play_soccer_in_park);
Type Error: File "", line 124, characters 2-66: Error: This expression has type bool Distribution.t Distribution.t = (Probability.t * bool Distribution.t) list but an expression was expected of type bool Distribution.t = (Probability.t * bool) list Type bool Distribution.t = (Probability.t * bool) list is not compatible with type bool

Wait - what happened? Why didn't that work like before? Well, ignoring the confusing line numbers, the error says that you expected can_play_soccer_in_park_tomorrow to be of type bool Distribution.t, but it's actually bool Distribution.t Distribution.t, i.e. it's a distribution of boolean distributions!

If that's still confusing, think about it this way. With 80% chance, the weather tomorrow is Shine, and if the weather is Shine then there's a 95% chance you can play soccer in the park and a 5% chance you can't. So, you can see how there are two layers of probability distributions. But, intuitively, we know we can collapse these into just one. If I ask you "what's the chance you can play in the park tomorrow?" - there's a single answer: 80% * 95% + 20% * 0% = 76%.

That's exactly what join does. It flattens two levels of a monad - a distribution in our case - into one level. Let's see it in action.

(* map alone yields a distribution of distributions; join flattens it. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t =
  Distribution.join
    (Distribution.map tomorrow's_weather ~f:can_play_soccer_in_park)
;;

print_endline
  (Distribution.to_string string_of_bool can_play_soccer_in_park_tomorrow);;
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element (zero) and addition/subtraction, i.e. the
   operations of an abelian group; the group laws are assumed, not checked.
   List.sum takes an implementation of it as a first-class module. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style wrappers over the stdlib: the list comes first and the
     callback is passed with the ~f label. */
  let hd_exn = hd;
  let map = (t, ~f) => map(f, t);
  /* map, then flatten the resulting list of lists */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* number of elements satisfying f */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] (stop is
     exclusive) and [] whenever start >= stop. Tail-recursive. */
  let range = (start, stop) => {
    /* i runs from stop - 1 down to start, consing onto acc, so the result
       comes out ascending without a final reverse */
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* [group(t, ~break)] splits t into maximal runs of adjacent elements,
     starting a new run whenever break(previous, next) is true. Runs, and the
     elements inside each run, keep their input order; an empty input yields
     [], so every returned group is non-empty. */
  let group = (t, ~break) =>
    switch (t) {
    | [] => []
    | [first, ...rest] =>
      /* acc and curr_group are built in reverse and flipped when closed */
      let rec loop = (acc, curr_group, last, rest) =>
        switch (rest) {
        | [] => rev([rev(curr_group), ...acc])
        | [next, ...rest] =>
          if (break(last, next)) {
            loop([rev(curr_group), ...acc], [next], next, rest);
          } else {
            loop(acc, [next, ...curr_group], next, rest);
          }
        };
      loop([], [first], first, rest);
    };
  /* Sum f(x) over the list using the zero and (+) of the given group. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module names the type and provides
   zero, (+) and (-) on it. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sort by state, then merge adjacent entries with equal states into one
     entry whose probability is the sum of theirs. The empty distribution is
     matched explicitly and returned as-is, so every group handed to List.hd
     below comes from a non-empty input and is itself non-empty. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Apply f to every state, then merge states that f sends to equal values. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability by
     its outer probability, then merge equal states. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3) and prefixed with "~" when
     that text does not parse back to exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = prob => {
      let s = Printf.sprintf("%0.*f", decimals, prob);
      let is_exact = float_of_string(s) == prob;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /* Exposed representation: a list of (probability, state) pairs, so callers
     can write distributions directly as list literals. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to this t */
  include Monad.S with type t('a) := t('a);
  /* Pretty-printer; [decimals] is optional. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

/* The two possible weather outcomes. */
type rain_or_shine =
  | Rain
  | Shine;

/* 80% chance of sunshine, 20% chance of rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [
  (0.8, Shine),
  (0.2, Rain),
];

/* Given the weather, the distribution of whether we can play soccer. */
let can_play_soccer_in_park =
  fun
  /* we definitely can't play if it's raining */
  | Rain => Distribution.return(false)
  /* let's say we can play with a 95% probability */
  | Shine => [(0.95, true), (0.05, false)];

/* map alone would give a distribution of distributions; join flattens it. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
  Distribution.join(
    Distribution.map(tomorrow's_weather, ~f=can_play_soccer_in_park),
  );

print_endline(
  Distribution.to_string(string_of_bool, can_play_soccer_in_park_tomorrow),
);

At this point, it's probably smart to introduce a new operation, bind. Calling map and then immediately calling join is such a common pattern that it gets its own operation to make it more convenient. Let's define bind so that we can use it later on.

open Distribution
(* bind = map, then flatten the resulting distribution of distributions *)
let bind t ~f = join (map t ~f)
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element (zero) and addition/subtraction, i.e. the
   operations of an abelian group; the group laws are assumed, not checked.
   List.sum takes an implementation of it as a first-class module. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style wrappers over the stdlib: the list comes first and the
     callback is passed with the ~f label. */
  let hd_exn = hd;
  let map = (t, ~f) => map(f, t);
  /* map, then flatten the resulting list of lists */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* number of elements satisfying f */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] (stop is
     exclusive) and [] whenever start >= stop. Tail-recursive. */
  let range = (start, stop) => {
    /* i runs from stop - 1 down to start, consing onto acc, so the result
       comes out ascending without a final reverse */
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* [group(t, ~break)] splits t into maximal runs of adjacent elements,
     starting a new run whenever break(previous, next) is true. Runs, and the
     elements inside each run, keep their input order; an empty input yields
     [], so every returned group is non-empty. */
  let group = (t, ~break) =>
    switch (t) {
    | [] => []
    | [first, ...rest] =>
      /* acc and curr_group are built in reverse and flipped when closed */
      let rec loop = (acc, curr_group, last, rest) =>
        switch (rest) {
        | [] => rev([rev(curr_group), ...acc])
        | [next, ...rest] =>
          if (break(last, next)) {
            loop([rev(curr_group), ...acc], [next], next, rest);
          } else {
            loop(acc, [next, ...curr_group], next, rest);
          }
        };
      loop([], [first], first, rest);
    };
  /* Sum f(x) over the list using the zero and (+) of the given group. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module names the type and provides
   zero, (+) and (-) on it. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sort by state, then merge adjacent entries with equal states into one
     entry whose probability is the sum of theirs. The empty distribution is
     matched explicitly and returned as-is, so every group handed to List.hd
     below comes from a non-empty input and is itself non-empty. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Apply f to every state, then merge states that f sends to equal values. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability by
     its outer probability, then merge equal states. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3) and prefixed with "~" when
     that text does not parse back to exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = prob => {
      let s = Printf.sprintf("%0.*f", decimals, prob);
      let is_exact = float_of_string(s) == prob;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /* Exposed representation: a list of (probability, state) pairs, so callers
     can write distributions directly as list literals. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to this t */
  include Monad.S with type t('a) := t('a);
  /* Pretty-printer; [decimals] is optional. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

open Distribution;

/* bind: map, then flatten the distribution of distributions with join */
let bind = (dist, ~f) => join(map(dist, ~f));

When actually writing code dealing with monads, it's often more convenient to use bind than map and join. In fact, many people introduce bind as part of the definition of a monad instead of join, but the two definitions are equivalent, since given one you can define the other (e.g. let join tt = bind tt ~f:(fun t -> t)).

Alright! We've now explained every line of the signature which Distribution exposes. You probably thought this day would never come.

Ok, so what? We've defined a bunch of really abstract stuff in the context of a concrete example, a probability distribution, and we've seen some extremely trivial manipulations which are not nearly cool enough to warrant all this hullabaloo.

I agree. Let's start the interesting part.

First off, let me define infix operators for map and bind so that my code later on can be slightly less indented. And while we're at it, let's define a uniform function which takes a list of outcomes and creates a distribution where each outcome is equally weighted. One last thing - I'm going to open Distribution so that I can write the functions we've been talking about in an unqualified way, e.g. map instead of Distribution.map.

open Distribution
(* infix map and bind, so chains of steps read left to right *)
let ( >>| ) dist f = map dist ~f
let ( >>= ) dist f = bind dist ~f

(* Each of the n outcomes gets weight 1/n; duplicates are not merged. *)
let uniform : 'a list -> 'a t = fun outcomes ->
  let weight = 1. /. float (List.length outcomes) in
  List.map outcomes ~f:(fun outcome -> (weight, outcome))
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element (zero) and addition/subtraction, i.e. the
   operations of an abelian group; the group laws are assumed, not checked.
   List.sum takes an implementation of it as a first-class module. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style wrappers over the stdlib: the list comes first and the
     callback is passed with the ~f label. */
  let hd_exn = hd;
  let map = (t, ~f) => map(f, t);
  /* map, then flatten the resulting list of lists */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* number of elements satisfying f */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] (stop is
     exclusive) and [] whenever start >= stop. Tail-recursive. */
  let range = (start, stop) => {
    /* i runs from stop - 1 down to start, consing onto acc, so the result
       comes out ascending without a final reverse */
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* [group(t, ~break)] splits t into maximal runs of adjacent elements,
     starting a new run whenever break(previous, next) is true. Runs, and the
     elements inside each run, keep their input order; an empty input yields
     [], so every returned group is non-empty. */
  let group = (t, ~break) =>
    switch (t) {
    | [] => []
    | [first, ...rest] =>
      /* acc and curr_group are built in reverse and flipped when closed */
      let rec loop = (acc, curr_group, last, rest) =>
        switch (rest) {
        | [] => rev([rev(curr_group), ...acc])
        | [next, ...rest] =>
          if (break(last, next)) {
            loop([rev(curr_group), ...acc], [next], next, rest);
          } else {
            loop(acc, [next, ...curr_group], next, rest);
          }
        };
      loop([], [first], first, rest);
    };
  /* Sum f(x) over the list using the zero and (+) of the given group. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module names the type and provides
   zero, (+) and (-) on it. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sort by state, then merge adjacent entries with equal states into one
     entry whose probability is the sum of theirs. The empty distribution is
     matched explicitly and returned as-is, so every group handed to List.hd
     below comes from a non-empty input and is itself non-empty. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Apply f to every state, then merge states that f sends to equal values. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability by
     its outer probability, then merge equal states. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3) and prefixed with "~" when
     that text does not parse back to exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = prob => {
      let s = Printf.sprintf("%0.*f", decimals, prob);
      let is_exact = float_of_string(s) == prob;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /* Exposed representation: a list of (probability, state) pairs, so callers
     can write distributions directly as list literals. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to this t */
  include Monad.S with type t('a) := t('a);
  /* Pretty-printer; [decimals] is optional. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

open Distribution;

/* bind: map, then flatten the distribution of distributions with join */
let bind = (dist, ~f) => join(map(dist, ~f));

open Distribution;

/* infix map and bind, so chains of steps read left to right */
let (>>|) = (dist, f) => map(dist, ~f);

let (>>=) = (dist, f) => bind(dist, ~f);

/* Each of the n elements of the list gets weight 1/n. Duplicates are not
   merged, and an empty list gives the empty distribution. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };

With those handy tools, let's define a staple of probability questions.

The die.

(* the faces of a die; List.range excludes its upper bound *)
let one_through_six = List.range 1 7;;
let die : int t = one_through_six |> uniform;;
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element (zero) and addition/subtraction, i.e. the
   operations of an abelian group; the group laws are assumed, not checked.
   List.sum takes an implementation of it as a first-class module. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style wrappers over the stdlib: the list comes first and the
     callback is passed with the ~f label. */
  let hd_exn = hd;
  let map = (t, ~f) => map(f, t);
  /* map, then flatten the resulting list of lists */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* number of elements satisfying f */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] (stop is
     exclusive) and [] whenever start >= stop. Tail-recursive. */
  let range = (start, stop) => {
    /* i runs from stop - 1 down to start, consing onto acc, so the result
       comes out ascending without a final reverse */
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* [group(t, ~break)] splits t into maximal runs of adjacent elements,
     starting a new run whenever break(previous, next) is true. Runs, and the
     elements inside each run, keep their input order; an empty input yields
     [], so every returned group is non-empty. */
  let group = (t, ~break) =>
    switch (t) {
    | [] => []
    | [first, ...rest] =>
      /* acc and curr_group are built in reverse and flipped when closed */
      let rec loop = (acc, curr_group, last, rest) =>
        switch (rest) {
        | [] => rev([rev(curr_group), ...acc])
        | [next, ...rest] =>
          if (break(last, next)) {
            loop([rev(curr_group), ...acc], [next], next, rest);
          } else {
            loop(acc, [next, ...curr_group], next, rest);
          }
        };
      loop([], [first], first, rest);
    };
  /* Sum f(x) over the list using the zero and (+) of the given group. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module names the type and provides
   zero, (+) and (-) on it. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sort by state, then merge adjacent entries with equal states into one
     entry whose probability is the sum of theirs. The empty distribution is
     matched explicitly and returned as-is, so every group handed to List.hd
     below comes from a non-empty input and is itself non-empty. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Apply f to every state, then merge states that f sends to equal values. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability by
     its outer probability, then merge equal states. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3) and prefixed with "~" when
     that text does not parse back to exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = prob => {
      let s = Printf.sprintf("%0.*f", decimals, prob);
      let is_exact = float_of_string(s) == prob;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /* Exposed representation: a list of (probability, state) pairs, so callers
     can write distributions directly as list literals. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to this t */
  include Monad.S with type t('a) := t('a);
  /* Pretty-printer; [decimals] is optional. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

open Distribution;

/* bind: map, then flatten the distribution of distributions with join */
let bind = (dist, ~f) => join(map(dist, ~f));

open Distribution;

/* infix map and bind, so chains of steps read left to right */
let (>>|) = (dist, f) => map(dist, ~f);

let (>>=) = (dist, f) => bind(dist, ~f);

/* Each of the n elements of the list gets weight 1/n. Duplicates are not
   merged, and an empty list gives the empty distribution. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };

/* the faces of a die; List.range excludes its upper bound */
let one_through_six = List.range(1, 7);

let die: t(int) = one_through_six |> uniform;

If we want to print our die as a string, we can use to_string:

print_endline (to_string string_of_int die)
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element (zero) and addition/subtraction, i.e. the
   operations of an abelian group; the group laws are assumed, not checked.
   List.sum takes an implementation of it as a first-class module. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style wrappers over the stdlib: the list comes first and the
     callback is passed with the ~f label. */
  let hd_exn = hd;
  let map = (t, ~f) => map(f, t);
  /* map, then flatten the resulting list of lists */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* number of elements satisfying f */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] (stop is
     exclusive) and [] whenever start >= stop. Tail-recursive. */
  let range = (start, stop) => {
    /* i runs from stop - 1 down to start, consing onto acc, so the result
       comes out ascending without a final reverse */
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* [group(t, ~break)] splits t into maximal runs of adjacent elements,
     starting a new run whenever break(previous, next) is true. Runs, and the
     elements inside each run, keep their input order; an empty input yields
     [], so every returned group is non-empty. */
  let group = (t, ~break) =>
    switch (t) {
    | [] => []
    | [first, ...rest] =>
      /* acc and curr_group are built in reverse and flipped when closed */
      let rec loop = (acc, curr_group, last, rest) =>
        switch (rest) {
        | [] => rev([rev(curr_group), ...acc])
        | [next, ...rest] =>
          if (break(last, next)) {
            loop([rev(curr_group), ...acc], [next], next, rest);
          } else {
            loop(acc, [next, ...curr_group], next, rest);
          }
        };
      loop([], [first], first, rest);
    };
  /* Sum f(x) over the list using the zero and (+) of the given group. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module names the type and provides
   zero, (+) and (-) on it. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sort by state, then merge adjacent entries with equal states into one
     entry whose probability is the sum of theirs. The empty distribution is
     matched explicitly and returned as-is, so every group handed to List.hd
     below comes from a non-empty input and is itself non-empty. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Apply f to every state, then merge states that f sends to equal values. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability by
     its outer probability, then merge equal states. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3) and prefixed with "~" when
     that text does not parse back to exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = prob => {
      let s = Printf.sprintf("%0.*f", decimals, prob);
      let is_exact = float_of_string(s) == prob;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /* Exposed representation: a list of (probability, state) pairs, so callers
     can write distributions directly as list literals. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to this t */
  include Monad.S with type t('a) := t('a);
  /* Pretty-printer; [decimals] is optional. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

open Distribution;

/* bind: map, then flatten the distribution of distributions with join */
let bind = (dist, ~f) => join(map(dist, ~f));

open Distribution;

/* infix map and bind, so chains of steps read left to right */
let (>>|) = (dist, f) => map(dist, ~f);

let (>>=) = (dist, f) => bind(dist, ~f);

/* Each of the n elements of the list gets weight 1/n. Duplicates are not
   merged, and an empty list gives the empty distribution. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };

/* the faces of a die; List.range excludes its upper bound */
let one_through_six = List.range(1, 7);

let die: t(int) = one_through_six |> uniform;

print_endline(to_string(string_of_int, die));

What's the probability our die is >= 5?

die
|> map ~f:(fun face -> face >= 5)
|> to_string string_of_bool
|> print_endline
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element (zero) and addition/subtraction, i.e. the
   operations of an abelian group; the group laws are assumed, not checked.
   List.sum takes an implementation of it as a first-class module. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style wrappers over the stdlib: the list comes first and the
     callback is passed with the ~f label. */
  let hd_exn = hd;
  let map = (t, ~f) => map(f, t);
  /* map, then flatten the resulting list of lists */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* number of elements satisfying f */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] (stop is
     exclusive) and [] whenever start >= stop. Tail-recursive. */
  let range = (start, stop) => {
    /* i runs from stop - 1 down to start, consing onto acc, so the result
       comes out ascending without a final reverse */
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* [group(t, ~break)] splits t into maximal runs of adjacent elements,
     starting a new run whenever break(previous, next) is true. Runs, and the
     elements inside each run, keep their input order; an empty input yields
     [], so every returned group is non-empty. */
  let group = (t, ~break) =>
    switch (t) {
    | [] => []
    | [first, ...rest] =>
      /* acc and curr_group are built in reverse and flipped when closed */
      let rec loop = (acc, curr_group, last, rest) =>
        switch (rest) {
        | [] => rev([rev(curr_group), ...acc])
        | [next, ...rest] =>
          if (break(last, next)) {
            loop([rev(curr_group), ...acc], [next], next, rest);
          } else {
            loop(acc, [next, ...curr_group], next, rest);
          }
        };
      loop([], [first], first, rest);
    };
  /* Sum f(x) over the list using the zero and (+) of the given group. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module names the type and provides
   zero, (+) and (-) on it. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sort by state, then merge adjacent entries with equal states into one
     entry whose probability is the sum of theirs. The empty distribution is
     matched explicitly and returned as-is, so every group handed to List.hd
     below comes from a non-empty input and is itself non-empty. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Apply f to every state, then merge states that f sends to equal values. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability by
     its outer probability, then merge equal states. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3) and prefixed with "~" when
     that text does not parse back to exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = prob => {
      let s = Printf.sprintf("%0.*f", decimals, prob);
      let is_exact = float_of_string(s) == prob;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /* Exposed representation: a list of (probability, state) pairs, so callers
     can write distributions directly as list literals. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to this t */
  include Monad.S with type t('a) := t('a);
  /* Pretty-printer; [decimals] is optional. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

open Distribution;

/* bind: map, then flatten the distribution of distributions with join */
let bind = (dist, ~f) => join(map(dist, ~f));

open Distribution;

/* infix map and bind, so chains of steps read left to right */
let (>>|) = (dist, f) => map(dist, ~f);

let (>>=) = (dist, f) => bind(dist, ~f);

/* Each of the n elements of the list gets weight 1/n. Duplicates are not
   merged, and an empty list gives the empty distribution. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };

/* the faces of a die; List.range excludes its upper bound */
let one_through_six = List.range(1, 7);

let die: t(int) = one_through_six |> uniform;

die
|> map(~f=face => face >= 5)
|> to_string(string_of_bool)
|> print_endline;

Too easy. I'm not impressed.

How about this? Let's define a distribution on two dice.

(* Joint distribution of two independent dice, as (first, second) pairs. *)
let two_dice =
  die >>= fun first ->
  die >>= fun second ->
  return (first, second)
;;
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element (zero) and addition/subtraction, i.e. the
   operations of an abelian group; the group laws are assumed, not checked.
   List.sum takes an implementation of it as a first-class module. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style wrappers over the stdlib: the list comes first and the
     callback is passed with the ~f label. */
  let hd_exn = hd;
  let map = (t, ~f) => map(f, t);
  /* map, then flatten the resulting list of lists */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* number of elements satisfying f */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] (stop is
     exclusive) and [] whenever start >= stop. Tail-recursive. */
  let range = (start, stop) => {
    /* i runs from stop - 1 down to start, consing onto acc, so the result
       comes out ascending without a final reverse */
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* [group(t, ~break)] splits t into maximal runs of adjacent elements,
     starting a new run whenever break(previous, next) is true. Runs, and the
     elements inside each run, keep their input order; an empty input yields
     [], so every returned group is non-empty. */
  let group = (t, ~break) =>
    switch (t) {
    | [] => []
    | [first, ...rest] =>
      /* acc and curr_group are built in reverse and flipped when closed */
      let rec loop = (acc, curr_group, last, rest) =>
        switch (rest) {
        | [] => rev([rev(curr_group), ...acc])
        | [next, ...rest] =>
          if (break(last, next)) {
            loop([rev(curr_group), ...acc], [next], next, rest);
          } else {
            loop(acc, [next, ...curr_group], next, rest);
          }
        };
      loop([], [first], first, rest);
    };
  /* Sum f(x) over the list using the zero and (+) of the given group. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module names the type and provides
   zero, (+) and (-) on it. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sort by state, then merge adjacent entries with equal states into one
     entry whose probability is the sum of theirs. The empty distribution is
     matched explicitly and returned as-is, so every group handed to List.hd
     below comes from a non-empty input and is itself non-empty. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Apply f to every state, then merge states that f sends to equal values. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability by
     its outer probability, then merge equal states. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3) and prefixed with "~" when
     that text does not parse back to exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = prob => {
      let s = Printf.sprintf("%0.*f", decimals, prob);
      let is_exact = float_of_string(s) == prob;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /* Exposed representation: a list of (probability, state) pairs, so callers
     can write distributions directly as list literals. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to this t */
  include Monad.S with type t('a) := t('a);
  /* Pretty-printer; [decimals] is optional. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

open Distribution;

/* bind: map, then flatten the distribution of distributions with join */
let bind = (dist, ~f) => join(map(dist, ~f));

open Distribution;

/* infix map and bind, so chains of steps read left to right */
let (>>|) = (dist, f) => map(dist, ~f);

let (>>=) = (dist, f) => bind(dist, ~f);

/* Each of the n elements of the list gets weight 1/n. Duplicates are not
   merged, and an empty list gives the empty distribution. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };

/* the faces of a die; List.range excludes its upper bound */
let one_through_six = List.range(1, 7);

let die: t(int) = one_through_six |> uniform;

/* joint distribution of two independent dice, as (first, second) pairs */
let two_dice =
  die >>= (first => die >>= (second => return((first, second))));

What's the probability that the sum of two dice is >= 10?

two_dice
|> map ~f:(fun (first, second) -> first + second)
|> map ~f:(fun total -> total >= 10)
|> to_string string_of_bool
|> print_endline
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element (zero) and addition/subtraction, i.e. the
   operations of an abelian group; the group laws are assumed, not checked.
   List.sum takes an implementation of it as a first-class module. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style wrappers over the stdlib: the list comes first and the
     callback is passed with the ~f label. */
  let hd_exn = hd;
  let map = (t, ~f) => map(f, t);
  /* map, then flatten the resulting list of lists */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* number of elements satisfying f */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] (stop is
     exclusive) and [] whenever start >= stop. Tail-recursive. */
  let range = (start, stop) => {
    /* i runs from stop - 1 down to start, consing onto acc, so the result
       comes out ascending without a final reverse */
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* [group(t, ~break)] splits t into maximal runs of adjacent elements,
     starting a new run whenever break(previous, next) is true. Runs, and the
     elements inside each run, keep their input order; an empty input yields
     [], so every returned group is non-empty. */
  let group = (t, ~break) =>
    switch (t) {
    | [] => []
    | [first, ...rest] =>
      /* acc and curr_group are built in reverse and flipped when closed */
      let rec loop = (acc, curr_group, last, rest) =>
        switch (rest) {
        | [] => rev([rev(curr_group), ...acc])
        | [next, ...rest] =>
          if (break(last, next)) {
            loop([rev(curr_group), ...acc], [next], next, rest);
          } else {
            loop(acc, [next, ...curr_group], next, rest);
          }
        };
      loop([], [first], first, rest);
    };
  /* Sum f(x) over the list using the zero and (+) of the given group. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module names the type and provides
   zero, (+) and (-) on it. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sort by state, then merge adjacent entries with equal states into one
     entry whose probability is the sum of theirs. The empty distribution is
     matched explicitly and returned as-is, so every group handed to List.hd
     below comes from a non-empty input and is itself non-empty. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Apply f to every state, then merge states that f sends to equal values. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability by
     its outer probability, then merge equal states. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3) and prefixed with "~" when
     that text does not parse back to exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = prob => {
      let s = Printf.sprintf("%0.*f", decimals, prob);
      let is_exact = float_of_string(s) == prob;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

module Distribution: {
  /* Exposed representation: a list of (probability, state) pairs, so callers
     can write distributions directly as list literals. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to this t */
  include Monad.S with type t('a) := t('a);
  /* Pretty-printer; [decimals] is optional. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

open Distribution;

/* bind: map, then flatten the distribution of distributions with join */
let bind = (dist, ~f) => join(map(dist, ~f));

open Distribution;

/* infix map and bind, so chains of steps read left to right */
let (>>|) = (dist, f) => map(dist, ~f);

let (>>=) = (dist, f) => bind(dist, ~f);

/* Each of the n elements of the list gets weight 1/n. Duplicates are not
   merged, and an empty list gives the empty distribution. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };

/* the faces of a die; List.range excludes its upper bound */
let one_through_six = List.range(1, 7);

let die: t(int) = one_through_six |> uniform;

/* joint distribution of two independent dice, as (first, second) pairs */
let two_dice =
  die >>= (first => die >>= (second => return((first, second))));

two_dice
|> map(~f=((first, second)) => first + second)
|> map(~f=total => total >= 10)
|> to_string(string_of_bool)
|> print_endline;

Finally something that you might not be able to immediately solve in your head. But let's find something that you definitely can't solve in your head.

The Newton–Pepys problem

In 1693 Samuel Pepys and Isaac Newton corresponded over a problem posed by Pepys in relation to a wager he planned to make. The problem was:

Which of the following three propositions has the greatest chance of success?

A. Six fair dice are tossed independently and at least one “6” appears.

B. Twelve fair dice are tossed independently and at least two “6”s appear.

C. Eighteen fair dice are tossed independently and at least three “6”s appear.

(* Whether a single die comes up 6. *)
let single_die_is_six : bool t = die >>| fun face -> face = 6

(* Number of 6s (0 or 1) shown by a single die. *)
let number_of_sixes_in_a_single_die : int t =
  single_die_is_six >>| fun is_six -> if is_six then 1 else 0

(* Number of 6s among six independent dice, one bind per die. *)
let num_sixes_in_six_dice : int t =
  let one_die = number_of_sixes_in_a_single_die in
  one_die >>= fun a ->
  one_die >>= fun b ->
  one_die >>= fun c ->
  one_die >>= fun d ->
  one_die >>= fun e ->
  one_die >>= fun g ->
  return (a + b + c + d + e + g)

(* Twelve dice = two independent batches of six. *)
let num_sixes_in_twelve_dice =
  num_sixes_in_six_dice >>= fun first_batch ->
  num_sixes_in_six_dice >>= fun second_batch ->
  return (first_batch + second_batch)

(* Eighteen dice = twelve dice plus one more batch of six. *)
let num_sixes_in_eighteen_dice =
  num_sixes_in_twelve_dice >>= fun first_twelve ->
  num_sixes_in_six_dice >>= fun last_six ->
  return (first_twelve + last_six)
;;
/* Interface shared by the monads in this post. */
module Monad = {
  module type S = {
    /* the container type, e.g. a distribution over values of type 'a */
    type t('a);
    /* lift a plain value into t */
    let return: 'a => t('a);
    /* apply f to the values held inside t */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* collapse one level of nesting: t(t('a)) becomes t('a) */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element [zero] and operations (+) and (-);
   [List.sum] is parameterised over a module of this type. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* [hd] under the _exn naming convention: raises on an empty list. */
  let hd_exn = hd;
  /* Labelled, list-first wrappers over the stdlib functions, so callers can
     write [List.map(t, ~f)]. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* The half-open interval [start, stop), e.g. range(1, 7) is [1, ..., 6].
     Empty whenever start >= stop; testing only start == stop would recurse
     forever on a reversed range. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into maximal runs of adjacent elements, starting a new run
     whenever [break(previous, next)] holds. Runs keep their input order, but
     the elements inside each run come out reversed. An empty list has no
     runs. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here; emitting it
           would produce a bogus empty run. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with M's (+), starting from
     M.zero. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /* Probabilities are plain floats; [zero], (+) and (-) let this module be
     used as a Commutative_group. */
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that is [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merges entries whose states compare equal by summing their
     probabilities; the result is sorted by state. An empty distribution is
     returned as is: grouping it would otherwise yield an empty group and
     make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* a non-empty input only produces non-empty groups */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Applies [f] to every state, then merges states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: every inner probability is
     multiplied by the probability of the outer entry that holds it. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One "state:\tprobability " line per entry; a probability that changes
     when rounded to [decimals] places is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Seals Distribution_implementation: callers see Monad.S plus [to_string].
   [t('a)] stays a transparent list of (probability, state) pairs, while
   helpers such as [compare_states] and [group_states] are hidden. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: run [f] on every state of [t] and flatten the resulting
   distribution of distributions with [join]. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix shorthands: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives every element of [l] probability 1 / length(l). Duplicates are kept
   as separate entries; an empty list yields an empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* The faces of a six-sided die. */
let one_through_six = [1, 2, 3, 4, 5, 6];

/* A fair die: each face with probability 1/6. */
let die: t(int) = uniform(one_through_six);

/* The probability distribution representing whether a single die comes up 6. */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);

/* The probability distribution representing the number of 6s in a single die. */
let number_of_sixes_in_a_single_die: t(int) =
  map(single_die_is_six, ~f=is_six => is_six ? 1 : 0);

/* The probability distribution describing how many sixes come up in 6 dice:
   one bind per die, summing the per-die counts. */
let num_sixes_in_six_dice: t(int) = {
  let roll = k => number_of_sixes_in_a_single_die >>= k;
  roll(s1 =>
    roll(s2 =>
      roll(s3 =>
        roll(s4 =>
          roll(s5 => roll(s6 => return(s1 + s2 + s3 + s4 + s5 + s6)))
        )
      )
    )
  );
};

/* The count across two independent groups of dice: every pair of counts is
   added, weighted by the product of their probabilities. */
let add_independent = (x, y) => x >>= (a => y >>= (b => return(a + b)));

let num_sixes_in_twelve_dice =
  add_independent(num_sixes_in_six_dice, num_sixes_in_six_dice);

let num_sixes_in_eighteen_dice =
  add_independent(num_sixes_in_twelve_dice, num_sixes_in_six_dice);

Choice A: Six fair dice are tossed independently and at least one “6” appears.

(* Choice A: probability of at least one 6 among six dice. *)
let at_least_one = map num_sixes_in_six_dice ~f:(fun sixes -> sixes >= 1) in
print_endline (to_string string_of_bool at_least_one)
/* Signature of a monad, expressed with return, map and join. */
module Monad = {
  module type S = {
    /* The container type, parameterised by the type of value it holds. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms the values held in a [t] with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element [zero] and operations (+) and (-);
   [List.sum] is parameterised over a module of this type. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* [hd] under the _exn naming convention: raises on an empty list. */
  let hd_exn = hd;
  /* Labelled, list-first wrappers over the stdlib functions, so callers can
     write [List.map(t, ~f)]. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* The half-open interval [start, stop), e.g. range(1, 7) is [1, ..., 6].
     Empty whenever start >= stop; testing only start == stop would recurse
     forever on a reversed range. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into maximal runs of adjacent elements, starting a new run
     whenever [break(previous, next)] holds. Runs keep their input order, but
     the elements inside each run come out reversed. An empty list has no
     runs. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here; emitting it
           would produce a bogus empty run. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with M's (+), starting from
     M.zero. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /* Probabilities are plain floats; [zero], (+) and (-) let this module be
     used as a Commutative_group. */
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that is [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merges entries whose states compare equal by summing their
     probabilities; the result is sorted by state. An empty distribution is
     returned as is: grouping it would otherwise yield an empty group and
     make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* a non-empty input only produces non-empty groups */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Applies [f] to every state, then merges states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: every inner probability is
     multiplied by the probability of the outer entry that holds it. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One "state:\tprobability " line per entry; a probability that changes
     when rounded to [decimals] places is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Seals Distribution_implementation: callers see Monad.S plus [to_string].
   [t('a)] stays a transparent list of (probability, state) pairs, while
   helpers such as [compare_states] and [group_states] are hidden. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: run [f] on every state of [t] and flatten the resulting
   distribution of distributions with [join]. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix shorthands: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives every element of [l] probability 1 / length(l). Duplicates are kept
   as separate entries; an empty list yields an empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* The faces of a six-sided die. */
let one_through_six = [1, 2, 3, 4, 5, 6];

/* A fair die: each face with probability 1/6. */
let die: t(int) = uniform(one_through_six);

/* The probability distribution representing whether a single die comes up 6. */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);

/* The probability distribution representing the number of 6s in a single die. */
let number_of_sixes_in_a_single_die: t(int) =
  map(single_die_is_six, ~f=is_six => is_six ? 1 : 0);

/* The probability distribution describing how many sixes come up in 6 dice:
   one bind per die, summing the per-die counts. */
let num_sixes_in_six_dice: t(int) = {
  let roll = k => number_of_sixes_in_a_single_die >>= k;
  roll(s1 =>
    roll(s2 =>
      roll(s3 =>
        roll(s4 =>
          roll(s5 => roll(s6 => return(s1 + s2 + s3 + s4 + s5 + s6)))
        )
      )
    )
  );
};

/* The count across two independent groups of dice: every pair of counts is
   added, weighted by the product of their probabilities. */
let add_independent = (x, y) => x >>= (a => y >>= (b => return(a + b)));

let num_sixes_in_twelve_dice =
  add_independent(num_sixes_in_six_dice, num_sixes_in_six_dice);

let num_sixes_in_eighteen_dice =
  add_independent(num_sixes_in_twelve_dice, num_sixes_in_six_dice);

/* Choice A: probability of at least one 6 among six dice. */
print_endline(
  to_string(string_of_bool, map(num_sixes_in_six_dice, ~f=sixes => sixes >= 1)),
);

Choice B: Twelve fair dice are tossed independently and at least two “6”s appear.

(* Choice B: probability of at least two 6s among twelve dice. *)
let at_least_two = map num_sixes_in_twelve_dice ~f:(fun sixes -> sixes >= 2) in
print_endline (to_string string_of_bool at_least_two)
/* Signature of a monad, expressed with return, map and join. */
module Monad = {
  module type S = {
    /* The container type, parameterised by the type of value it holds. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms the values held in a [t] with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element [zero] and operations (+) and (-);
   [List.sum] is parameterised over a module of this type. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* [hd] under the _exn naming convention: raises on an empty list. */
  let hd_exn = hd;
  /* Labelled, list-first wrappers over the stdlib functions, so callers can
     write [List.map(t, ~f)]. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* The half-open interval [start, stop), e.g. range(1, 7) is [1, ..., 6].
     Empty whenever start >= stop; testing only start == stop would recurse
     forever on a reversed range. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into maximal runs of adjacent elements, starting a new run
     whenever [break(previous, next)] holds. Runs keep their input order, but
     the elements inside each run come out reversed. An empty list has no
     runs. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here; emitting it
           would produce a bogus empty run. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with M's (+), starting from
     M.zero. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /* Probabilities are plain floats; [zero], (+) and (-) let this module be
     used as a Commutative_group. */
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that is [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merges entries whose states compare equal by summing their
     probabilities; the result is sorted by state. An empty distribution is
     returned as is: grouping it would otherwise yield an empty group and
     make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* a non-empty input only produces non-empty groups */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Applies [f] to every state, then merges states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: every inner probability is
     multiplied by the probability of the outer entry that holds it. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One "state:\tprobability " line per entry; a probability that changes
     when rounded to [decimals] places is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Seals Distribution_implementation: callers see Monad.S plus [to_string].
   [t('a)] stays a transparent list of (probability, state) pairs, while
   helpers such as [compare_states] and [group_states] are hidden. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: run [f] on every state of [t] and flatten the resulting
   distribution of distributions with [join]. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix shorthands: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives every element of [l] probability 1 / length(l). Duplicates are kept
   as separate entries; an empty list yields an empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* The faces of a six-sided die. */
let one_through_six = [1, 2, 3, 4, 5, 6];

/* A fair die: each face with probability 1/6. */
let die: t(int) = uniform(one_through_six);

/* The probability distribution representing whether a single die comes up 6. */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);

/* The probability distribution representing the number of 6s in a single die. */
let number_of_sixes_in_a_single_die: t(int) =
  map(single_die_is_six, ~f=is_six => is_six ? 1 : 0);

/* The probability distribution describing how many sixes come up in 6 dice:
   one bind per die, summing the per-die counts. */
let num_sixes_in_six_dice: t(int) = {
  let roll = k => number_of_sixes_in_a_single_die >>= k;
  roll(s1 =>
    roll(s2 =>
      roll(s3 =>
        roll(s4 =>
          roll(s5 => roll(s6 => return(s1 + s2 + s3 + s4 + s5 + s6)))
        )
      )
    )
  );
};

/* The count across two independent groups of dice: every pair of counts is
   added, weighted by the product of their probabilities. */
let add_independent = (x, y) => x >>= (a => y >>= (b => return(a + b)));

let num_sixes_in_twelve_dice =
  add_independent(num_sixes_in_six_dice, num_sixes_in_six_dice);

let num_sixes_in_eighteen_dice =
  add_independent(num_sixes_in_twelve_dice, num_sixes_in_six_dice);

/* Choice B: probability of at least two 6s among twelve dice. */
print_endline(
  to_string(
    string_of_bool,
    map(num_sixes_in_twelve_dice, ~f=sixes => sixes >= 2),
  ),
);

Choice C: Eighteen fair dice are tossed independently and at least three “6”s appear.

(* Choice C: probability of at least three 6s among eighteen dice. *)
let at_least_three =
  map num_sixes_in_eighteen_dice ~f:(fun sixes -> sixes >= 3)
in
print_endline (to_string string_of_bool at_least_three)
/* Signature of a monad, expressed with return, map and join. */
module Monad = {
  module type S = {
    /* The container type, parameterised by the type of value it holds. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms the values held in a [t] with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element [zero] and operations (+) and (-);
   [List.sum] is parameterised over a module of this type. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* [hd] under the _exn naming convention: raises on an empty list. */
  let hd_exn = hd;
  /* Labelled, list-first wrappers over the stdlib functions, so callers can
     write [List.map(t, ~f)]. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* The half-open interval [start, stop), e.g. range(1, 7) is [1, ..., 6].
     Empty whenever start >= stop; testing only start == stop would recurse
     forever on a reversed range. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into maximal runs of adjacent elements, starting a new run
     whenever [break(previous, next)] holds. Runs keep their input order, but
     the elements inside each run come out reversed. An empty list has no
     runs. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here; emitting it
           would produce a bogus empty run. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with M's (+), starting from
     M.zero. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /* Probabilities are plain floats; [zero], (+) and (-) let this module be
     used as a Commutative_group. */
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that is [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merges entries whose states compare equal by summing their
     probabilities; the result is sorted by state. An empty distribution is
     returned as is: grouping it would otherwise yield an empty group and
     make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* a non-empty input only produces non-empty groups */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Applies [f] to every state, then merges states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: every inner probability is
     multiplied by the probability of the outer entry that holds it. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One "state:\tprobability " line per entry; a probability that changes
     when rounded to [decimals] places is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Seals Distribution_implementation: callers see Monad.S plus [to_string].
   [t('a)] stays a transparent list of (probability, state) pairs, while
   helpers such as [compare_states] and [group_states] are hidden. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: run [f] on every state of [t] and flatten the resulting
   distribution of distributions with [join]. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix shorthands: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives every element of [l] probability 1 / length(l). Duplicates are kept
   as separate entries; an empty list yields an empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* The faces of a six-sided die. */
let one_through_six = [1, 2, 3, 4, 5, 6];

/* A fair die: each face with probability 1/6. */
let die: t(int) = uniform(one_through_six);

/* The probability distribution representing whether a single die comes up 6. */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);

/* The probability distribution representing the number of 6s in a single die. */
let number_of_sixes_in_a_single_die: t(int) =
  map(single_die_is_six, ~f=is_six => is_six ? 1 : 0);

/* The probability distribution describing how many sixes come up in 6 dice:
   one bind per die, summing the per-die counts. */
let num_sixes_in_six_dice: t(int) = {
  let roll = k => number_of_sixes_in_a_single_die >>= k;
  roll(s1 =>
    roll(s2 =>
      roll(s3 =>
        roll(s4 =>
          roll(s5 => roll(s6 => return(s1 + s2 + s3 + s4 + s5 + s6)))
        )
      )
    )
  );
};

/* The count across two independent groups of dice: every pair of counts is
   added, weighted by the product of their probabilities. */
let add_independent = (x, y) => x >>= (a => y >>= (b => return(a + b)));

let num_sixes_in_twelve_dice =
  add_independent(num_sixes_in_six_dice, num_sixes_in_six_dice);

let num_sixes_in_eighteen_dice =
  add_independent(num_sixes_in_twelve_dice, num_sixes_in_six_dice);

/* Choice C: probability of at least three 6s among eighteen dice. */
print_endline(
  to_string(
    string_of_bool,
    map(num_sixes_in_eighteen_dice, ~f=sixes => sixes >= 3),
  ),
);

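As you can see, choice A has the greatest chance of success (about 0.665, versus about 0.619 for B and about 0.597 for C), which is exactly the answer Newton gave Pepys.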

Hopefully at this point you can see that we have some real power here. Using bind, map, and return (and join, although not explicitly), we can start to answer some very real questions.

Let's tackle a more recent question, one that stumped quite a few mathematicians, including Paul Erdős!

The Monty Hall problem

Here's Wikipedia's formulation:

Suppose you're on a game show, and you're given the choice of three doors: Behind one door is a car; behind the others, goats. You pick a door, say No. 1, and the host, who knows what's behind the doors, opens another door, say No. 3, which has a goat. He then says to you, "Do you want to pick door No. 2?" Is it to your advantage to switch your choice?

If you haven't encountered this problem before, I highly recommend taking a minute to come up with an answer yourself before continuing.

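First off, let's define another utility function that turns out to be very handy in probability questions. condition takes a distribution and conditions it on a particular event happening. It keeps only the outcomes in which the event occurs, then rescales their probabilities so that they sum to 1.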

(* Scales every probability in [t] by 1 / (total probability of [t]), so the
   result sums to 1 (up to rounding). *)
let renormalize t =
  let scale_by = 1. /. List.sum (module Probability) t ~f:fst in
  List.map t ~f:(fun (p, a) -> (p *. scale_by, a))
;;

(* Restricts [t] to the states satisfying [on], then renormalizes. *)
let condition : 'a t -> on:('a -> bool) -> 'a t =
  fun t ~on ->
  let kept = List.filter t ~f:(fun (_, state) -> on state) in
  renormalize kept
;;
/* Signature of a monad, expressed with return, map and join. */
module Monad = {
  module type S = {
    /* The container type, parameterised by the type of value it holds. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms the values held in a [t] with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element [zero] and operations (+) and (-);
   [List.sum] is parameterised over a module of this type. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* [hd] under the _exn naming convention: raises on an empty list. */
  let hd_exn = hd;
  /* Labelled, list-first wrappers over the stdlib functions, so callers can
     write [List.map(t, ~f)]. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* The half-open interval [start, stop), e.g. range(1, 7) is [1, ..., 6].
     Empty whenever start >= stop; testing only start == stop would recurse
     forever on a reversed range. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into maximal runs of adjacent elements, starting a new run
     whenever [break(previous, next)] holds. Runs keep their input order, but
     the elements inside each run come out reversed. An empty list has no
     runs. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here; emitting it
           would produce a bogus empty run. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with M's (+), starting from
     M.zero. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /* Probabilities are plain floats; [zero], (+) and (-) let this module be
     used as a Commutative_group. */
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that is [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merges entries whose states compare equal by summing their
     probabilities; the result is sorted by state. An empty distribution is
     returned as is: grouping it would otherwise yield an empty group and
     make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* a non-empty input only produces non-empty groups */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Applies [f] to every state, then merges states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: every inner probability is
     multiplied by the probability of the outer entry that holds it. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One "state:\tprobability " line per entry; a probability that changes
     when rounded to [decimals] places is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Seals Distribution_implementation: callers see Monad.S plus [to_string].
   [t('a)] stays a transparent list of (probability, state) pairs, while
   helpers such as [compare_states] and [group_states] are hidden. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: run [f] on every state of [t] and flatten the resulting
   distribution of distributions with [join]. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix shorthands: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives every element of [l] probability 1 / length(l). Duplicates are kept
   as separate entries; an empty list yields an empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* The faces of a six-sided die. */
let one_through_six = [1, 2, 3, 4, 5, 6];

/* A fair die: each face with probability 1/6. */
let die: t(int) = uniform(one_through_six);

/* Two independent dice, as (first, second) pairs. */
let two_dice =
  bind(die, ~f=first => bind(die, ~f=second => return((first, second))));

/* Scales every probability in [t] by 1 / (total probability of [t]), so the
   result sums to 1 (up to rounding). */
let renormalize = t => {
  let scale_by = 1. /. List.sum((module Probability), t, ~f=fst);
  List.map(t, ~f=((p, a)) => (p *. scale_by, a));
};

/* Restricts [t] to the states satisfying [on], then renormalizes. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, state)) => on(state));
    renormalize(kept);
  };

To gain a bit of intuition for what this function is doing, let's print out the distribution of two dice, conditioned on the first die being either a one or a six.

(* Two dice, conditioned on the first die showing 1 or 6; the second die is
   irrelevant to the condition, so it is matched with [_]. *)
two_dice
|> condition ~on:(fun (d1, _) -> d1 = 1 || d1 = 6)
|> to_string (fun (d1, d2) -> Printf.sprintf "(%d, %d)" d1 d2)
|> print_endline
/* Signature of a monad, expressed with return, map and join. */
module Monad = {
  module type S = {
    /* The container type, parameterised by the type of value it holds. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms the values held in a [t] with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element [zero] and operations (+) and (-);
   [List.sum] is parameterised over a module of this type. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* [hd] under the _exn naming convention: raises on an empty list. */
  let hd_exn = hd;
  /* Labelled, list-first wrappers over the stdlib functions, so callers can
     write [List.map(t, ~f)]. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* The half-open interval [start, stop), e.g. range(1, 7) is [1, ..., 6].
     Empty whenever start >= stop; testing only start == stop would recurse
     forever on a reversed range. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into maximal runs of adjacent elements, starting a new run
     whenever [break(previous, next)] holds. Runs keep their input order, but
     the elements inside each run come out reversed. An empty list has no
     runs. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here; emitting it
           would produce a bogus empty run. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with M's (+), starting from
     M.zero. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /* Probabilities are plain floats; [zero], (+) and (-) let this module be
     used as a Commutative_group. */
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: a list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that is [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merges entries whose states compare equal by summing their
     probabilities; the result is sorted by state. An empty distribution is
     returned as is: grouping it would otherwise yield an empty group and
     make [List.hd] raise. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* a non-empty input only produces non-empty groups */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });
  /* Applies [f] to every state, then merges states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: every inner probability is
     multiplied by the probability of the outer entry that holds it. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* One "state:\tprobability " line per entry; a probability that changes
     when rounded to [decimals] places is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Seals Distribution_implementation: callers see Monad.S plus [to_string].
   [t('a)] stays a transparent list of (probability, state) pairs, while
   helpers such as [compare_states] and [group_states] are hidden. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: run [f] on every state of [t] and flatten the resulting
   distribution of distributions with [join]. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix shorthands: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives every element of [l] probability 1 / length(l). Duplicates are kept
   as separate entries; an empty list yields an empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* The faces of a six-sided die. */
let one_through_six = [1, 2, 3, 4, 5, 6];

/* A fair die: each face with probability 1/6. */
let die: t(int) = uniform(one_through_six);

/* Two independent dice, as (first, second) pairs. */
let two_dice =
  bind(die, ~f=first => bind(die, ~f=second => return((first, second))));

/* Scales every probability in [t] by 1 / (total probability of [t]), so the
   result sums to 1 (up to rounding). */
let renormalize = t => {
  let scale_by = 1. /. List.sum((module Probability), t, ~f=fst);
  List.map(t, ~f=((p, a)) => (p *. scale_by, a));
};

/* Restricts [t] to the states satisfying [on], then renormalizes. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, state)) => on(state));
    renormalize(kept);
  };

/* Two dice, conditioned on the first die showing 1 or 6; the second die is
   irrelevant to the condition, so it is matched with [_]. */
two_dice
|> condition(~on=((d1, _)) => d1 == 1 || d1 == 6)
|> to_string(((d1, d2)) => Printf.sprintf("(%d, %d)", d1, d2))
|> print_endline;

With that in our tool belt, let's work on the problem. I'll go a little faster this time.

module Monty_hall = struct
  type door = A | B | C

  let string_of_door door =
    match door with
    | A -> "A"
    | B -> "B"
    | C -> "C"

  (* Every door, in the order A, B, C. *)
  let all_doors : door list = [ A; B; C ]

  type outcome = Win | Lose

  let string_of_outcome res =
    match res with
    | Win -> "Win"
    | Lose -> "Lose"

  (* Everything that happened in one play of the game. *)
  type state =
    { initial_pick : door
    ; prize : door
    ; opened : door
    ; final_pick : door
    }

  let string_of_state state =
    Printf.sprintf "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }"
      (string_of_door state.initial_pick)
      (string_of_door state.prize)
      (string_of_door state.opened)
      (string_of_door state.final_pick)

  (* A game is won exactly when the final pick is the prize door. *)
  let outcome : state -> outcome = fun state ->
    if state.final_pick = state.prize then Win else Lose

  (* defines the type of a strategy that could be used in the game *)
  type strategy = initial_pick:door -> opened:door -> door

  (* A strategy: never switch *)
  let never_switch : strategy = fun ~initial_pick ~opened:_ -> initial_pick

  (* A strategy: always switch, to the first door (in [all_doors] order) that
     is neither the initial pick nor the opened door. *)
  let always_switch : strategy = fun ~initial_pick ~opened ->
    let candidates =
      List.filter all_doors ~f:(fun door ->
        not (door = initial_pick || door = opened))
    in
    List.hd_exn candidates

  (* takes a strategy and returns you a distribution of states of the game *)
  let states : strategy -> state t = fun strat ->
    uniform all_doors >>= fun initial_pick ->
    uniform all_doors >>= fun prize ->
    (* The host will never open your door or the door with the prize; when
       those are the same door he chooses uniformly between the other two. *)
    let host_options =
      List.filter all_doors ~f:(fun door -> door <> initial_pick && door <> prize)
    in
    uniform host_options >>= fun opened ->
    (* the strategy defines what door will be the final choice *)
    let final_pick = strat ~initial_pick ~opened in
    return { initial_pick; prize; opened; final_pick }

  (* The win/lose distribution for a strategy. *)
  let outcomes : strategy -> outcome t = fun strat ->
    map (states strat) ~f:outcome
end;;
/* Signature of a monad, expressed with return, map and join. */
module Monad = {
  module type S = {
    /* The container type, parameterised by the type of value it holds. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms the values held in a [t] with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* A type with an identity element [zero] and operations (+) and (-);
   [List.sum] is parameterised over a module of this type. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* [hd] under the _exn naming convention: raises on an empty list. */
  let hd_exn = hd;
  /* Labelled, list-first wrappers over the stdlib functions, so callers can
     write [List.map(t, ~f)]. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* The half-open interval [start, stop), e.g. range(1, 7) is [1, ..., 6].
     Empty whenever start >= stop; testing only start == stop would recurse
     forever on a reversed range. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into maximal runs of adjacent elements, starting a new run
     whenever [break(previous, next)] holds. Runs keep their input order, but
     the elements inside each run come out reversed. An empty list has no
     runs. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here; emitting it
           would produce a bogus empty run. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with M's (+), starting from
     M.zero. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
  /* Probabilities are plain floats; [zero], (+) and (-) let this module be
     used as a Commutative_group. */
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finitely-supported probability distributions, represented as a list of
   (probability, state) pairs. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The point distribution: [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sorts by state and merges entries whose states compare equal, summing
     their probabilities. The empty distribution is returned directly so it
     can never reach [List.hd] below, which raises [Failure] on an empty
     group (depending on how [List.group] treats an empty list, an empty
     input could otherwise produce one). */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* [t] is non-empty, so every group is non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };
  /* Applies [f] to every state; states that become equal are merged. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: each inner probability is
     scaled by its outer probability, then equal states are merged. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* Renders one line per entry: the state via [f], a tab, then the
     probability with [decimals] digits after the point and a trailing space.
     The probability gets a "~" prefix when its printed text does not parse
     back to the exact value. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Public view of [Distribution_implementation]: the list representation of
   [t] stays visible, but only [to_string] and the [Monad.S] operations are
   exported, so helpers such as [group_states] stay private. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  /* [decimals] defaults to 3 in the implementation. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind, derived from the [map] and [join] of [Distribution]. The
   single [open] above keeps those names in scope for everything below;
   opening [Distribution] again would only rebind the same names. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix operators: [t >>| f] is [map], [t >>= f] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives each element of [l] probability 1 / length(l). Repeated elements
   stay as separate entries; an empty [l] yields the empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, 2, 3, 4, 5, 6]: [List.range] excludes its upper bound. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent rolls of [die]. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Scales every probability by 1 / total so that they sum to 1. An empty [t]
   stays empty; a non-empty [t] whose probabilities sum to zero would get
   non-finite values. */
let renormalize = t => {
  let total_prob = t => List.sum((module Probability), t, ~f=fst);
  let scale_by = 1. /. total_prob(t);
  List.map(t, ~f=((p, a)) => (p *. scale_by, a));
};

/* Restricts [t] to the states satisfying [on] and renormalizes the rest. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* Model of the Monty Hall game: three doors, one prize, and a host who
   opens a door that is neither the initial pick nor the prize. */
module Monty_hall = {
  type door =
    | A
    | B
    | C;
  let string_of_door =
    fun
    | A => "A"
    | B => "B"
    | C => "C";
  /* Every door; used as the choice set for all uniform draws below. */
  let all_doors: list(door) = [A, B, C];
  type outcome =
    | Win
    | Lose;
  let string_of_outcome =
    fun
    | Win => "Win"
    | Lose => "Lose";
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = ({initial_pick, prize, opened, final_pick}) =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(initial_pick),
      string_of_door(prize),
      string_of_door(opened),
      string_of_door(final_pick),
    );
  /* Win exactly when the final pick is the prize door. */
  let outcome: state => outcome =
    ({final_pick, prize, _}) =>
      if (prize == final_pick) {
        Win;
      } else {
        Lose;
      };
  /* defines the type of a strategy that could be used in the game */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* A strategy: never switch */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* A strategy: always switch */
  /* With three doors at least one door survives the filter; when
     [initial_pick] and [opened] differ it is the single remaining door. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      List.filter(all_doors, ~f=door => door != initial_pick && door != opened)
      |> List.hd_exn;
  /* takes a strategy and returns you a distribution of states of the game */
  /* Initial pick and prize are independent uniform draws over [all_doors];
     the host then opens uniformly among the remaining allowed doors. The
     nesting of the [>>=] calls fixes the order in which probabilities are
     multiplied, so restructuring it can change the low bits of the floats. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let opened =
                /* The host will never open your door or the door with the prize */
                List.filter(all_doors, ~f=door =>
                  door != initial_pick && door != prize
                )
                |> uniform;
              opened
              >>= (
                opened => {
                  /* the strategy defines what door will be the final choice */
                  let final_pick = strat(~initial_pick, ~opened);
                  let state = {initial_pick, prize, opened, final_pick};
                  return(state);
                }
              );
            }
          )
      );
  /* Win/Lose distribution under [strat]; [>>|] merges equal outcomes. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

With that in place, let's try to answer the question. If we never switch, here's the distribution of outcomes:

(* Win/Lose distribution when the player never switches. *)
print_endline
  (to_string Monty_hall.string_of_outcome
     (Monty_hall.outcomes Monty_hall.never_switch))
/* Home of the monad interface implemented by [Distribution]. */
module Monad = {
  /* A monad described by [return], [map] and [join]; bind is derived from
     [map] and [join] further down the file. */
  module type S = {
    /* The container type. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms every contained value with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* Minimal additive structure required by [List.sum]: an identity [zero]
   plus the operations [(+)] and [(-)]. Commutativity is implied by the name
   but is not enforced by the types. */
module type Commutative_group = {
  type t;
  /* Identity element; [List.sum] starts its fold from it. */
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Stdlib [List] extended with labeled, list-first (Base-style) wrappers and
   a few helpers used by the distribution code. */
module List = {
  include List;
  /* Alias of [hd]; the [_exn] suffix flags that it raises on an empty list. */
  let hd_exn = hd;
  /* Labeled, list-first versions of stdlib functions; they shadow the
     included ones for the rest of this module. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only for equality would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into runs of adjacent elements, starting a new run whenever
     [break(previous, next)] is true. Runs come back in input order; the
     elements inside each run are in reverse input order. An empty [t]
     yields no runs at all. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        switch (curr_group) {
        /* only reachable when [t] itself was empty */
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with [M.(+)], from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. The module packages them as a
   [Commutative_group] so they can be handed to [List.sum]. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finitely-supported probability distributions, represented as a list of
   (probability, state) pairs. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The point distribution: [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sorts by state and merges entries whose states compare equal, summing
     their probabilities. The empty distribution is returned directly so it
     can never reach [List.hd] below, which raises [Failure] on an empty
     group (depending on how [List.group] treats an empty list, an empty
     input could otherwise produce one). */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* [t] is non-empty, so every group is non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };
  /* Applies [f] to every state; states that become equal are merged. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: each inner probability is
     scaled by its outer probability, then equal states are merged. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* Renders one line per entry: the state via [f], a tab, then the
     probability with [decimals] digits after the point and a trailing space.
     The probability gets a "~" prefix when its printed text does not parse
     back to the exact value. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Public view of [Distribution_implementation]: the list representation of
   [t] stays visible, but only [to_string] and the [Monad.S] operations are
   exported, so helpers such as [group_states] stay private. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  /* [decimals] defaults to 3 in the implementation. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind, derived from the [map] and [join] of [Distribution]. The
   single [open] above keeps those names in scope for everything below;
   opening [Distribution] again would only rebind the same names. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix operators: [t >>| f] is [map], [t >>= f] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives each element of [l] probability 1 / length(l). Repeated elements
   stay as separate entries; an empty [l] yields the empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, 2, 3, 4, 5, 6]: [List.range] excludes its upper bound. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent rolls of [die]. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Scales every probability by 1 / total so that they sum to 1. An empty [t]
   stays empty; a non-empty [t] whose probabilities sum to zero would get
   non-finite values. */
let renormalize = t => {
  let total_prob = t => List.sum((module Probability), t, ~f=fst);
  let scale_by = 1. /. total_prob(t);
  List.map(t, ~f=((p, a)) => (p *. scale_by, a));
};

/* Restricts [t] to the states satisfying [on] and renormalizes the rest. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* Model of the Monty Hall game: three doors, one prize, and a host who
   opens a door that is neither the initial pick nor the prize. */
module Monty_hall = {
  type door =
    | A
    | B
    | C;
  let string_of_door =
    fun
    | A => "A"
    | B => "B"
    | C => "C";
  /* Every door; used as the choice set for all uniform draws below. */
  let all_doors: list(door) = [A, B, C];
  type outcome =
    | Win
    | Lose;
  let string_of_outcome =
    fun
    | Win => "Win"
    | Lose => "Lose";
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = ({initial_pick, prize, opened, final_pick}) =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(initial_pick),
      string_of_door(prize),
      string_of_door(opened),
      string_of_door(final_pick),
    );
  /* Win exactly when the final pick is the prize door. */
  let outcome: state => outcome =
    ({final_pick, prize, _}) =>
      if (prize == final_pick) {
        Win;
      } else {
        Lose;
      };
  /* defines the type of a strategy that could be used in the game */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* A strategy: never switch */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* A strategy: always switch */
  /* With three doors at least one door survives the filter; when
     [initial_pick] and [opened] differ it is the single remaining door. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      List.filter(all_doors, ~f=door => door != initial_pick && door != opened)
      |> List.hd_exn;
  /* takes a strategy and returns you a distribution of states of the game */
  /* Initial pick and prize are independent uniform draws over [all_doors];
     the host then opens uniformly among the remaining allowed doors. The
     nesting of the [>>=] calls fixes the order in which probabilities are
     multiplied, so restructuring it can change the low bits of the floats. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let opened =
                /* The host will never open your door or the door with the prize */
                List.filter(all_doors, ~f=door =>
                  door != initial_pick && door != prize
                )
                |> uniform;
              opened
              >>= (
                opened => {
                  /* the strategy defines what door will be the final choice */
                  let final_pick = strat(~initial_pick, ~opened);
                  let state = {initial_pick, prize, opened, final_pick};
                  return(state);
                }
              );
            }
          )
      );
  /* Win/Lose distribution under [strat]; [>>|] merges equal outcomes. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

/* Win/Lose distribution when the player never switches. */
print_endline(
  to_string(
    Monty_hall.string_of_outcome,
    Monty_hall.outcomes(Monty_hall.never_switch),
  ),
);

And here it is if we always switch:

(* Win/Lose distribution when the player always switches. *)
print_endline
  (to_string Monty_hall.string_of_outcome
     (Monty_hall.outcomes Monty_hall.always_switch))
/* Home of the monad interface implemented by [Distribution]. */
module Monad = {
  /* A monad described by [return], [map] and [join]; bind is derived from
     [map] and [join] further down the file. */
  module type S = {
    /* The container type. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms every contained value with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* Minimal additive structure required by [List.sum]: an identity [zero]
   plus the operations [(+)] and [(-)]. Commutativity is implied by the name
   but is not enforced by the types. */
module type Commutative_group = {
  type t;
  /* Identity element; [List.sum] starts its fold from it. */
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Stdlib [List] extended with labeled, list-first (Base-style) wrappers and
   a few helpers used by the distribution code. */
module List = {
  include List;
  /* Alias of [hd]; the [_exn] suffix flags that it raises on an empty list. */
  let hd_exn = hd;
  /* Labeled, list-first versions of stdlib functions; they shadow the
     included ones for the rest of this module. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only for equality would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into runs of adjacent elements, starting a new run whenever
     [break(previous, next)] is true. Runs come back in input order; the
     elements inside each run are in reverse input order. An empty [t]
     yields no runs at all. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        switch (curr_group) {
        /* only reachable when [t] itself was empty */
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with [M.(+)], from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. The module packages them as a
   [Commutative_group] so they can be handed to [List.sum]. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finitely-supported probability distributions, represented as a list of
   (probability, state) pairs. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The point distribution: [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sorts by state and merges entries whose states compare equal, summing
     their probabilities. The empty distribution is returned directly so it
     can never reach [List.hd] below, which raises [Failure] on an empty
     group (depending on how [List.group] treats an empty list, an empty
     input could otherwise produce one). */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* [t] is non-empty, so every group is non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };
  /* Applies [f] to every state; states that become equal are merged. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: each inner probability is
     scaled by its outer probability, then equal states are merged. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* Renders one line per entry: the state via [f], a tab, then the
     probability with [decimals] digits after the point and a trailing space.
     The probability gets a "~" prefix when its printed text does not parse
     back to the exact value. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Public view of [Distribution_implementation]: the list representation of
   [t] stays visible, but only [to_string] and the [Monad.S] operations are
   exported, so helpers such as [group_states] stay private. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  /* [decimals] defaults to 3 in the implementation. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind, derived from the [map] and [join] of [Distribution]. The
   single [open] above keeps those names in scope for everything below;
   opening [Distribution] again would only rebind the same names. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix operators: [t >>| f] is [map], [t >>= f] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives each element of [l] probability 1 / length(l). Repeated elements
   stay as separate entries; an empty [l] yields the empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, 2, 3, 4, 5, 6]: [List.range] excludes its upper bound. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent rolls of [die]. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Scales every probability by 1 / total so that they sum to 1. An empty [t]
   stays empty; a non-empty [t] whose probabilities sum to zero would get
   non-finite values. */
let renormalize = t => {
  let total_prob = t => List.sum((module Probability), t, ~f=fst);
  let scale_by = 1. /. total_prob(t);
  List.map(t, ~f=((p, a)) => (p *. scale_by, a));
};

/* Restricts [t] to the states satisfying [on] and renormalizes the rest. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* Model of the Monty Hall game: three doors, one prize, and a host who
   opens a door that is neither the initial pick nor the prize. */
module Monty_hall = {
  type door =
    | A
    | B
    | C;
  let string_of_door =
    fun
    | A => "A"
    | B => "B"
    | C => "C";
  /* Every door; used as the choice set for all uniform draws below. */
  let all_doors: list(door) = [A, B, C];
  type outcome =
    | Win
    | Lose;
  let string_of_outcome =
    fun
    | Win => "Win"
    | Lose => "Lose";
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = ({initial_pick, prize, opened, final_pick}) =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(initial_pick),
      string_of_door(prize),
      string_of_door(opened),
      string_of_door(final_pick),
    );
  /* Win exactly when the final pick is the prize door. */
  let outcome: state => outcome =
    ({final_pick, prize, _}) =>
      if (prize == final_pick) {
        Win;
      } else {
        Lose;
      };
  /* defines the type of a strategy that could be used in the game */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* A strategy: never switch */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* A strategy: always switch */
  /* With three doors at least one door survives the filter; when
     [initial_pick] and [opened] differ it is the single remaining door. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      List.filter(all_doors, ~f=door => door != initial_pick && door != opened)
      |> List.hd_exn;
  /* takes a strategy and returns you a distribution of states of the game */
  /* Initial pick and prize are independent uniform draws over [all_doors];
     the host then opens uniformly among the remaining allowed doors. The
     nesting of the [>>=] calls fixes the order in which probabilities are
     multiplied, so restructuring it can change the low bits of the floats. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let opened =
                /* The host will never open your door or the door with the prize */
                List.filter(all_doors, ~f=door =>
                  door != initial_pick && door != prize
                )
                |> uniform;
              opened
              >>= (
                opened => {
                  /* the strategy defines what door will be the final choice */
                  let final_pick = strat(~initial_pick, ~opened);
                  let state = {initial_pick, prize, opened, final_pick};
                  return(state);
                }
              );
            }
          )
      );
  /* Win/Lose distribution under [strat]; [>>|] merges equal outcomes. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

/* Win/Lose distribution when the player always switches. */
print_endline(
  to_string(
    Monty_hall.string_of_outcome,
    Monty_hall.outcomes(Monty_hall.always_switch),
  ),
);

So you should always switch! Not satisfied? We can print out more than just the outcome distribution for each strategy: here's the full distribution of game states, each paired with its outcome.

(* Prints each game state in [states] alongside its Win/Lose outcome and its
   probability, one entry per line. *)
let print_states_with_outcome states =
  let describe (state, outcome) =
    Printf.sprintf "%s: %s"
      (Monty_hall.string_of_state state)
      (Monty_hall.string_of_outcome outcome)
  in
  let tagged = map states ~f:(fun state -> (state, Monty_hall.outcome state)) in
  print_endline (to_string describe tagged)
;;
/* Home of the monad interface implemented by [Distribution]. */
module Monad = {
  /* A monad described by [return], [map] and [join]; bind is derived from
     [map] and [join] further down the file. */
  module type S = {
    /* The container type. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms every contained value with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* Minimal additive structure required by [List.sum]: an identity [zero]
   plus the operations [(+)] and [(-)]. Commutativity is implied by the name
   but is not enforced by the types. */
module type Commutative_group = {
  type t;
  /* Identity element; [List.sum] starts its fold from it. */
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Stdlib [List] extended with labeled, list-first (Base-style) wrappers and
   a few helpers used by the distribution code. */
module List = {
  include List;
  /* Alias of [hd]; the [_exn] suffix flags that it raises on an empty list. */
  let hd_exn = hd;
  /* Labeled, list-first versions of stdlib functions; they shadow the
     included ones for the rest of this module. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only for equality would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into runs of adjacent elements, starting a new run whenever
     [break(previous, next)] is true. Runs come back in input order; the
     elements inside each run are in reverse input order. An empty [t]
     yields no runs at all. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        switch (curr_group) {
        /* only reachable when [t] itself was empty */
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with [M.(+)], from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. The module packages them as a
   [Commutative_group] so they can be handed to [List.sum]. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finitely-supported probability distributions, represented as a list of
   (probability, state) pairs. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The point distribution: [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sorts by state and merges entries whose states compare equal, summing
     their probabilities. The empty distribution is returned directly so it
     can never reach [List.hd] below, which raises [Failure] on an empty
     group (depending on how [List.group] treats an empty list, an empty
     input could otherwise produce one). */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* [t] is non-empty, so every group is non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };
  /* Applies [f] to every state; states that become equal are merged. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: each inner probability is
     scaled by its outer probability, then equal states are merged. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* Renders one line per entry: the state via [f], a tab, then the
     probability with [decimals] digits after the point and a trailing space.
     The probability gets a "~" prefix when its printed text does not parse
     back to the exact value. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Public view of [Distribution_implementation]: the list representation of
   [t] stays visible, but only [to_string] and the [Monad.S] operations are
   exported, so helpers such as [group_states] stay private. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  /* [decimals] defaults to 3 in the implementation. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind, derived from the [map] and [join] of [Distribution]. The
   single [open] above keeps those names in scope for everything below;
   opening [Distribution] again would only rebind the same names. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix operators: [t >>| f] is [map], [t >>= f] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives each element of [l] probability 1 / length(l). Repeated elements
   stay as separate entries; an empty [l] yields the empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, 2, 3, 4, 5, 6]: [List.range] excludes its upper bound. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent rolls of [die]. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Scales every probability by 1 / total so that they sum to 1. An empty [t]
   stays empty; a non-empty [t] whose probabilities sum to zero would get
   non-finite values. */
let renormalize = t => {
  let total_prob = t => List.sum((module Probability), t, ~f=fst);
  let scale_by = 1. /. total_prob(t);
  List.map(t, ~f=((p, a)) => (p *. scale_by, a));
};

/* Restricts [t] to the states satisfying [on] and renormalizes the rest. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* Model of the Monty Hall game: three doors, one prize, and a host who
   opens a door that is neither the initial pick nor the prize. */
module Monty_hall = {
  type door =
    | A
    | B
    | C;
  let string_of_door =
    fun
    | A => "A"
    | B => "B"
    | C => "C";
  /* Every door; used as the choice set for all uniform draws below. */
  let all_doors: list(door) = [A, B, C];
  type outcome =
    | Win
    | Lose;
  let string_of_outcome =
    fun
    | Win => "Win"
    | Lose => "Lose";
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = ({initial_pick, prize, opened, final_pick}) =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(initial_pick),
      string_of_door(prize),
      string_of_door(opened),
      string_of_door(final_pick),
    );
  /* Win exactly when the final pick is the prize door. */
  let outcome: state => outcome =
    ({final_pick, prize, _}) =>
      if (prize == final_pick) {
        Win;
      } else {
        Lose;
      };
  /* defines the type of a strategy that could be used in the game */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* A strategy: never switch */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* A strategy: always switch */
  /* With three doors at least one door survives the filter; when
     [initial_pick] and [opened] differ it is the single remaining door. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      List.filter(all_doors, ~f=door => door != initial_pick && door != opened)
      |> List.hd_exn;
  /* takes a strategy and returns you a distribution of states of the game */
  /* Initial pick and prize are independent uniform draws over [all_doors];
     the host then opens uniformly among the remaining allowed doors. The
     nesting of the [>>=] calls fixes the order in which probabilities are
     multiplied, so restructuring it can change the low bits of the floats. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let opened =
                /* The host will never open your door or the door with the prize */
                List.filter(all_doors, ~f=door =>
                  door != initial_pick && door != prize
                )
                |> uniform;
              opened
              >>= (
                opened => {
                  /* the strategy defines what door will be the final choice */
                  let final_pick = strat(~initial_pick, ~opened);
                  let state = {initial_pick, prize, opened, final_pick};
                  return(state);
                }
              );
            }
          )
      );
  /* Win/Lose distribution under [strat]; [>>|] merges equal outcomes. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

/* Prints each game state in [states] alongside its Win/Lose outcome and its
   probability, one entry per line. */
let print_states_with_outcome = states => {
  let describe = ((state, outcome)) =>
    Printf.sprintf(
      "%s: %s",
      Monty_hall.string_of_state(state),
      Monty_hall.string_of_outcome(outcome),
    );
  let tagged = map(states, ~f=state => (state, Monty_hall.outcome(state)));
  print_endline(to_string(describe, tagged));
};

If you never switch:

(* Full state distribution when the player never switches. *)
print_states_with_outcome (Monty_hall.states Monty_hall.never_switch)
/* Home of the monad interface implemented by [Distribution]. */
module Monad = {
  /* A monad described by [return], [map] and [join]; bind is derived from
     [map] and [join] further down the file. */
  module type S = {
    /* The container type. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms every contained value with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* Minimal additive structure required by [List.sum]: an identity [zero]
   plus the operations [(+)] and [(-)]. Commutativity is implied by the name
   but is not enforced by the types. */
module type Commutative_group = {
  type t;
  /* Identity element; [List.sum] starts its fold from it. */
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Stdlib [List] extended with labeled, list-first (Base-style) wrappers and
   a few helpers used by the distribution code. */
module List = {
  include List;
  /* Alias of [hd]; the [_exn] suffix flags that it raises on an empty list. */
  let hd_exn = hd;
  /* Labeled, list-first versions of stdlib functions; they shadow the
     included ones for the rest of this module. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only for equality would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into runs of adjacent elements, starting a new run whenever
     [break(previous, next)] is true. Runs come back in input order; the
     elements inside each run are in reverse input order. An empty [t]
     yields no runs at all. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        switch (curr_group) {
        /* only reachable when [t] itself was empty */
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with [M.(+)], from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. The module packages them as a
   [Commutative_group] so they can be handed to [List.sum]. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finitely-supported probability distributions, represented as a list of
   (probability, state) pairs. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The point distribution: [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Sorts by state and merges entries whose states compare equal, summing
     their probabilities. The empty distribution is returned directly so it
     can never reach [List.hd] below, which raises [Failure] on an empty
     group (depending on how [List.group] treats an empty list, an empty
     input could otherwise produce one). */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* [t] is non-empty, so every group is non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };
  /* Applies [f] to every state; states that become equal are merged. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution of distributions: each inner probability is
     scaled by its outer probability, then equal states are merged. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* Renders one line per entry: the state via [f], a tab, then the
     probability with [decimals] digits after the point and a trailing space.
     The probability gets a "~" prefix when its printed text does not parse
     back to the exact value. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Public view of [Distribution_implementation]: the list representation of
   [t] stays visible, but only [to_string] and the [Monad.S] operations are
   exported, so helpers such as [group_states] stay private. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  /* [decimals] defaults to 3 in the implementation. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind, derived from the [map] and [join] of [Distribution]. The
   single [open] above keeps those names in scope for everything below;
   opening [Distribution] again would only rebind the same names. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix operators: [t >>| f] is [map], [t >>= f] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Gives each element of [l] probability 1 / length(l). Repeated elements
   stay as separate entries; an empty [l] yields the empty distribution. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, 2, 3, 4, 5, 6]: [List.range] excludes its upper bound. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent rolls of [die]. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Scales every probability by 1 / total so that they sum to 1. An empty [t]
   stays empty; a non-empty [t] whose probabilities sum to zero would get
   non-finite values. */
let renormalize = t => {
  let total_prob = t => List.sum((module Probability), t, ~f=fst);
  let scale_by = 1. /. total_prob(t);
  List.map(t, ~f=((p, a)) => (p *. scale_by, a));
};

/* Restricts [t] to the states satisfying [on] and renormalizes the rest. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* Model of the Monty Hall game: three doors, one prize, and a host who
   opens a door that is neither the initial pick nor the prize. */
module Monty_hall = {
  type door =
    | A
    | B
    | C;
  let string_of_door =
    fun
    | A => "A"
    | B => "B"
    | C => "C";
  /* Every door; used as the choice set for all uniform draws below. */
  let all_doors: list(door) = [A, B, C];
  type outcome =
    | Win
    | Lose;
  let string_of_outcome =
    fun
    | Win => "Win"
    | Lose => "Lose";
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = ({initial_pick, prize, opened, final_pick}) =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(initial_pick),
      string_of_door(prize),
      string_of_door(opened),
      string_of_door(final_pick),
    );
  /* Win exactly when the final pick is the prize door. */
  let outcome: state => outcome =
    ({final_pick, prize, _}) =>
      if (prize == final_pick) {
        Win;
      } else {
        Lose;
      };
  /* defines the type of a strategy that could be used in the game */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* A strategy: never switch */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* A strategy: always switch */
  /* With three doors at least one door survives the filter; when
     [initial_pick] and [opened] differ it is the single remaining door. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      List.filter(all_doors, ~f=door => door != initial_pick && door != opened)
      |> List.hd_exn;
  /* takes a strategy and returns you a distribution of states of the game */
  /* Initial pick and prize are independent uniform draws over [all_doors];
     the host then opens uniformly among the remaining allowed doors. The
     nesting of the [>>=] calls fixes the order in which probabilities are
     multiplied, so restructuring it can change the low bits of the floats. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let opened =
                /* The host will never open your door or the door with the prize */
                List.filter(all_doors, ~f=door =>
                  door != initial_pick && door != prize
                )
                |> uniform;
              opened
              >>= (
                opened => {
                  /* the strategy defines what door will be the final choice */
                  let final_pick = strat(~initial_pick, ~opened);
                  let state = {initial_pick, prize, opened, final_pick};
                  return(state);
                }
              );
            }
          )
      );
  /* Win/Lose distribution under [strat]; [>>|] merges equal outcomes. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

/* Prints each game state in [states] alongside its Win/Lose outcome and its
   probability, one entry per line. */
let print_states_with_outcome = states => {
  let describe = ((state, outcome)) =>
    Printf.sprintf(
      "%s: %s",
      Monty_hall.string_of_state(state),
      Monty_hall.string_of_outcome(outcome),
    );
  let tagged = map(states, ~f=state => (state, Monty_hall.outcome(state)));
  print_endline(to_string(describe, tagged));
};

/* Full state distribution when the player never switches. */
print_states_with_outcome(Monty_hall.states(Monty_hall.never_switch));

And if you always switch:

(* Full state distribution when the player always switches. *)
print_states_with_outcome (Monty_hall.states Monty_hall.always_switch)
/* Home of the monad interface implemented by [Distribution]. */
module Monad = {
  /* A monad described by [return], [map] and [join]; bind is derived from
     [map] and [join] further down the file. */
  module type S = {
    /* The container type. */
    type t('a);
    /* Lifts a plain value into [t]. */
    let return: 'a => t('a);
    /* Transforms every contained value with [f]. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapses one level of nesting. */
    let join: t(t('a)) => t('a);
  };
};

/* Minimal additive structure required by [List.sum]: an identity [zero]
   plus the operations [(+)] and [(-)]. Commutativity is implied by the name
   but is not enforced by the types. */
module type Commutative_group = {
  type t;
  /* Identity element; [List.sum] starts its fold from it. */
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Stdlib [List] extended with labeled, list-first (Base-style) wrappers and
   a few helpers used by the distribution code. */
module List = {
  include List;
  /* Alias of [hd]; the [_exn] suffix flags that it raises on an empty list. */
  let hd_exn = hd;
  /* Labeled, list-first versions of stdlib functions; they shadow the
     included ones for the rest of this module. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only for equality would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits [t] into runs of adjacent elements, starting a new run whenever
     [break(previous, next)] is true. Runs come back in input order; the
     elements inside each run are in reverse input order. An empty [t]
     yields no runs at all. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        switch (curr_group) {
        /* only reachable when [t] itself was empty */
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Maps [f] over [t] and folds the results with [M.(+)], from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. The module packages them as a
   [Commutative_group] so they can be handed to [List.sum]. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: (probability, state) pairs. [map] and [join]
     return each distinct state once, sorted by [compare]. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with certainty. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Collapse entries with equal states (per [compare_states]) into a single
     entry whose probability is their sum. Sorting makes equal states
     adjacent, so one fold merges them. This is total: the empty
     distribution maps to [] instead of hitting [List.hd] on an empty group. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.fold_left(
           (acc, (p, a)) =>
             switch (acc) {
             | [(p_acc, a_acc), ...rest]
                 when compare_states((p_acc, a_acc), (p, a)) == 0 =>
               [(p_acc +. p, a_acc), ...rest]
             | _ => [(p, a), ...acc]
             },
           [],
         )
      |> List.rev;
  /* Apply [f] to every state, then merge states that [f] made equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the outer entry holding it. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3), prefixed with "~" when that
     rounding is not exact. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public face of [Distribution_implementation]. [t] stays concrete (a
   list of (probability, value) pairs), so distributions can be written as
   list literals. [to_string] and the [Monad.S] operations are exported;
   helpers such as [group_states] are hidden by this signature. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind built from [map] and [join]: apply [f] to every outcome of
   [t], then flatten the resulting distribution of distributions. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* The uniform distribution over the elements of [l]. Duplicates stay as
   separate entries, and [uniform([])] is []. */
let uniform: list('a) => t('a) =
  l => {
    let count = List.length(l);
    let each = 1. /. float_of_int(count);
    List.map(l, ~f=x => (each, x));
  };

/* Faces 1 to 6 ([List.range] excludes its upper bound). */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Two independent rolls, as (first, second) pairs. */
let two_dice =
  die >>= (first => die >>= (second => return((first, second))));

/* Rescale the probabilities of [t] so that they sum to 1. */
let renormalize = t => {
  let total = List.sum((module Probability), t, ~f=fst);
  let factor = 1. /. total;
  List.map(t, ~f=((prob, value)) => (prob *. factor, value));
};

/* Condition [t] on the event [on]: drop outcomes failing [on], then
   renormalize what is left. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, outcome)) => on(outcome));
    renormalize(kept);
  };

module Monty_hall = {
  /* The three doors on stage. */
  type door =
    | A
    | B
    | C;
  let string_of_door = door =>
    switch (door) {
    | A => "A"
    | B => "B"
    | C => "C"
    };
  let all_doors: list(door) = [A, B, C];
  /* Whether the contestant ends up with the prize. */
  type outcome =
    | Win
    | Lose;
  let string_of_outcome = result =>
    switch (result) {
    | Win => "Win"
    | Lose => "Lose"
    };
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = s =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(s.initial_pick),
      string_of_door(s.prize),
      string_of_door(s.opened),
      string_of_door(s.final_pick),
    );
  /* You win exactly when your final pick is the prize door. */
  let outcome: state => outcome = s => s.final_pick == s.prize ? Win : Lose;
  /* A strategy picks the final door from your first pick and the door the
     host opened. */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* Keep the first door, ignoring the host. */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* Move to the one door that is neither your pick nor the opened one. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      all_doors
      |> List.filter(~f=door => !(door == initial_pick || door == opened))
      |> List.hd_exn;
  /* Distribution over complete games played with [strat]. Your first pick and
     the prize door are independent and uniform. The host then opens,
     uniformly at random, a door that is neither of them. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let host_options =
                List.filter(all_doors, ~f=door =>
                  !(door == initial_pick || door == prize)
                );
              uniform(host_options)
              >>= (
                opened => {
                  let final_pick = strat(~initial_pick, ~opened);
                  return({initial_pick, prize, opened, final_pick});
                }
              );
            }
          )
      );
  /* Just the Win/Lose distribution for [strat]. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

/* Print every state in [states] together with its Win/Lose outcome and its
   probability, one line per state. */
let print_states_with_outcome = states => {
  let describe = ((state, result)) =>
    Printf.sprintf(
      "%s: %s",
      Monty_hall.string_of_state(state),
      Monty_hall.string_of_outcome(result),
    );
  let with_outcomes =
    map(states, ~f=state => (state, Monty_hall.outcome(state)));
  print_endline(to_string(describe, with_outcomes));
};

/* Show the full state distribution when you always switch. */
let () =
  print_states_with_outcome(Monty_hall.states(Monty_hall.always_switch));

The outcomes are symmetric with respect to the door you initially choose. So it is easier to compare the two strategies by conditioning on the cases where you initially pick door A.

(* Print every game state (with its outcome) for [strat]. Only games where
   the first pick was door A are kept, and they are renormalized. *)
let outcomes_and_states_conditioned_on_initially_picking_a strat =
  let picked_a (state : Monty_hall.state) =
    state.Monty_hall.initial_pick = Monty_hall.A
  in
  print_states_with_outcome (condition (Monty_hall.states strat) ~on:picked_a)
;;
/* Signatures for the monad abstraction used in this post. */
module Monad = {
  /* A monad over the type constructor [t]: [return] lifts a value into [t],
     [map] transforms what a [t] contains, and [join] removes one level of
     nesting. The monad laws are assumed, not checked by the type system. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* The operations [List.sum] needs to add up values of type [t]: an identity
   [zero] plus [(+)] and [(-)]. The group/commutativity laws are assumed,
   not enforced. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style name for [hd]; raises [Failure("hd")] on []. */
  let hd_exn = hd;
  /* Labeled, list-first versions of the stdlib functions. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only [start == stop] would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* [group(t, ~break)] splits [t] into maximal runs of adjacent elements,
     starting a new run whenever [break(previous, next)] holds. Runs keep
     their input order, and an empty [t] yields no groups at all (rather than
     a single empty group). */
  let group = (t, ~break) => {
    /* [curr_group] is the run being built, in reverse order. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest, last) {
      | ([], None) => acc
      | ([], Some(_)) => [rev(curr_group), ...acc]
      | ([next, ...rest], None) => loop(acc, [next], Some(next), rest)
      | ([next, ...rest], Some(last)) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], Some(next), rest);
        } else {
          loop(acc, [next, ...curr_group], Some(next), rest);
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* [sum((module M), t, ~f)] folds [M.(+)] over [f(x)] for every [x] in [t],
     starting from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module only gives them a name and
   the group operations [List.sum] expects. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: (probability, state) pairs. [map] and [join]
     return each distinct state once, sorted by [compare]. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with certainty. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Collapse entries with equal states (per [compare_states]) into a single
     entry whose probability is their sum. Sorting makes equal states
     adjacent, so one fold merges them. This is total: the empty
     distribution maps to [] instead of hitting [List.hd] on an empty group. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.fold_left(
           (acc, (p, a)) =>
             switch (acc) {
             | [(p_acc, a_acc), ...rest]
                 when compare_states((p_acc, a_acc), (p, a)) == 0 =>
               [(p_acc +. p, a_acc), ...rest]
             | _ => [(p, a), ...acc]
             },
           [],
         )
      |> List.rev;
  /* Apply [f] to every state, then merge states that [f] made equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the outer entry holding it. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3), prefixed with "~" when that
     rounding is not exact. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public face of [Distribution_implementation]. [t] stays concrete (a
   list of (probability, value) pairs), so distributions can be written as
   list literals. [to_string] and the [Monad.S] operations are exported;
   helpers such as [group_states] are hidden by this signature. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind built from [map] and [join]: apply [f] to every outcome of
   [t], then flatten the resulting distribution of distributions. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* The uniform distribution over the elements of [l]. Duplicates stay as
   separate entries, and [uniform([])] is []. */
let uniform: list('a) => t('a) =
  l => {
    let count = List.length(l);
    let each = 1. /. float_of_int(count);
    List.map(l, ~f=x => (each, x));
  };

/* Faces 1 to 6 ([List.range] excludes its upper bound). */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Two independent rolls, as (first, second) pairs. */
let two_dice =
  die >>= (first => die >>= (second => return((first, second))));

/* Rescale the probabilities of [t] so that they sum to 1. */
let renormalize = t => {
  let total = List.sum((module Probability), t, ~f=fst);
  let factor = 1. /. total;
  List.map(t, ~f=((prob, value)) => (prob *. factor, value));
};

/* Condition [t] on the event [on]: drop outcomes failing [on], then
   renormalize what is left. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, outcome)) => on(outcome));
    renormalize(kept);
  };

module Monty_hall = {
  /* The three doors on stage. */
  type door =
    | A
    | B
    | C;
  let string_of_door = door =>
    switch (door) {
    | A => "A"
    | B => "B"
    | C => "C"
    };
  let all_doors: list(door) = [A, B, C];
  /* Whether the contestant ends up with the prize. */
  type outcome =
    | Win
    | Lose;
  let string_of_outcome = result =>
    switch (result) {
    | Win => "Win"
    | Lose => "Lose"
    };
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = s =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(s.initial_pick),
      string_of_door(s.prize),
      string_of_door(s.opened),
      string_of_door(s.final_pick),
    );
  /* You win exactly when your final pick is the prize door. */
  let outcome: state => outcome = s => s.final_pick == s.prize ? Win : Lose;
  /* A strategy picks the final door from your first pick and the door the
     host opened. */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* Keep the first door, ignoring the host. */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* Move to the one door that is neither your pick nor the opened one. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      all_doors
      |> List.filter(~f=door => !(door == initial_pick || door == opened))
      |> List.hd_exn;
  /* Distribution over complete games played with [strat]. Your first pick and
     the prize door are independent and uniform. The host then opens,
     uniformly at random, a door that is neither of them. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let host_options =
                List.filter(all_doors, ~f=door =>
                  !(door == initial_pick || door == prize)
                );
              uniform(host_options)
              >>= (
                opened => {
                  let final_pick = strat(~initial_pick, ~opened);
                  return({initial_pick, prize, opened, final_pick});
                }
              );
            }
          )
      );
  /* Just the Win/Lose distribution for [strat]. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

/* Print every state in [states] together with its Win/Lose outcome and its
   probability, one line per state. */
let print_states_with_outcome = states => {
  let describe = ((state, result)) =>
    Printf.sprintf(
      "%s: %s",
      Monty_hall.string_of_state(state),
      Monty_hall.string_of_outcome(result),
    );
  let with_outcomes =
    map(states, ~f=state => (state, Monty_hall.outcome(state)));
  print_endline(to_string(describe, with_outcomes));
};

/* Print every game state (with its outcome) for [strat]. Only games where
   the first pick was door A are kept, and they are renormalized. */
let outcomes_and_states_conditioned_on_initially_picking_a = strat => {
  let picked_a = (state: Monty_hall.state) =>
    state.Monty_hall.initial_pick == Monty_hall.A;
  print_states_with_outcome(
    condition(Monty_hall.states(strat), ~on=picked_a),
  );
};

Outcomes and states of the game if you never switch, conditioned on your initially picking door A:

let () =
  outcomes_and_states_conditioned_on_initially_picking_a Monty_hall.never_switch
/* Signatures for the monad abstraction used in this post. */
module Monad = {
  /* A monad over the type constructor [t]: [return] lifts a value into [t],
     [map] transforms what a [t] contains, and [join] removes one level of
     nesting. The monad laws are assumed, not checked by the type system. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* The operations [List.sum] needs to add up values of type [t]: an identity
   [zero] plus [(+)] and [(-)]. The group/commutativity laws are assumed,
   not enforced. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style name for [hd]; raises [Failure("hd")] on []. */
  let hd_exn = hd;
  /* Labeled, list-first versions of the stdlib functions. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only [start == stop] would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* [group(t, ~break)] splits [t] into maximal runs of adjacent elements,
     starting a new run whenever [break(previous, next)] holds. Runs keep
     their input order, and an empty [t] yields no groups at all (rather than
     a single empty group). */
  let group = (t, ~break) => {
    /* [curr_group] is the run being built, in reverse order. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest, last) {
      | ([], None) => acc
      | ([], Some(_)) => [rev(curr_group), ...acc]
      | ([next, ...rest], None) => loop(acc, [next], Some(next), rest)
      | ([next, ...rest], Some(last)) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], Some(next), rest);
        } else {
          loop(acc, [next, ...curr_group], Some(next), rest);
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* [sum((module M), t, ~f)] folds [M.(+)] over [f(x)] for every [x] in [t],
     starting from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module only gives them a name and
   the group operations [List.sum] expects. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: (probability, state) pairs. [map] and [join]
     return each distinct state once, sorted by [compare]. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with certainty. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Collapse entries with equal states (per [compare_states]) into a single
     entry whose probability is their sum. Sorting makes equal states
     adjacent, so one fold merges them. This is total: the empty
     distribution maps to [] instead of hitting [List.hd] on an empty group. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.fold_left(
           (acc, (p, a)) =>
             switch (acc) {
             | [(p_acc, a_acc), ...rest]
                 when compare_states((p_acc, a_acc), (p, a)) == 0 =>
               [(p_acc +. p, a_acc), ...rest]
             | _ => [(p, a), ...acc]
             },
           [],
         )
      |> List.rev;
  /* Apply [f] to every state, then merge states that [f] made equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the outer entry holding it. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3), prefixed with "~" when that
     rounding is not exact. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public face of [Distribution_implementation]. [t] stays concrete (a
   list of (probability, value) pairs), so distributions can be written as
   list literals. [to_string] and the [Monad.S] operations are exported;
   helpers such as [group_states] are hidden by this signature. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind built from [map] and [join]: apply [f] to every outcome of
   [t], then flatten the resulting distribution of distributions. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* The uniform distribution over the elements of [l]. Duplicates stay as
   separate entries, and [uniform([])] is []. */
let uniform: list('a) => t('a) =
  l => {
    let count = List.length(l);
    let each = 1. /. float_of_int(count);
    List.map(l, ~f=x => (each, x));
  };

/* Faces 1 to 6 ([List.range] excludes its upper bound). */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Two independent rolls, as (first, second) pairs. */
let two_dice =
  die >>= (first => die >>= (second => return((first, second))));

/* Rescale the probabilities of [t] so that they sum to 1. */
let renormalize = t => {
  let total = List.sum((module Probability), t, ~f=fst);
  let factor = 1. /. total;
  List.map(t, ~f=((prob, value)) => (prob *. factor, value));
};

/* Condition [t] on the event [on]: drop outcomes failing [on], then
   renormalize what is left. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, outcome)) => on(outcome));
    renormalize(kept);
  };

module Monty_hall = {
  /* The three doors on stage. */
  type door =
    | A
    | B
    | C;
  let string_of_door = door =>
    switch (door) {
    | A => "A"
    | B => "B"
    | C => "C"
    };
  let all_doors: list(door) = [A, B, C];
  /* Whether the contestant ends up with the prize. */
  type outcome =
    | Win
    | Lose;
  let string_of_outcome = result =>
    switch (result) {
    | Win => "Win"
    | Lose => "Lose"
    };
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = s =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(s.initial_pick),
      string_of_door(s.prize),
      string_of_door(s.opened),
      string_of_door(s.final_pick),
    );
  /* You win exactly when your final pick is the prize door. */
  let outcome: state => outcome = s => s.final_pick == s.prize ? Win : Lose;
  /* A strategy picks the final door from your first pick and the door the
     host opened. */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* Keep the first door, ignoring the host. */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* Move to the one door that is neither your pick nor the opened one. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      all_doors
      |> List.filter(~f=door => !(door == initial_pick || door == opened))
      |> List.hd_exn;
  /* Distribution over complete games played with [strat]. Your first pick and
     the prize door are independent and uniform. The host then opens,
     uniformly at random, a door that is neither of them. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let host_options =
                List.filter(all_doors, ~f=door =>
                  !(door == initial_pick || door == prize)
                );
              uniform(host_options)
              >>= (
                opened => {
                  let final_pick = strat(~initial_pick, ~opened);
                  return({initial_pick, prize, opened, final_pick});
                }
              );
            }
          )
      );
  /* Just the Win/Lose distribution for [strat]. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

/* Print every state in [states] together with its Win/Lose outcome and its
   probability, one line per state. */
let print_states_with_outcome = states => {
  let describe = ((state, result)) =>
    Printf.sprintf(
      "%s: %s",
      Monty_hall.string_of_state(state),
      Monty_hall.string_of_outcome(result),
    );
  let with_outcomes =
    map(states, ~f=state => (state, Monty_hall.outcome(state)));
  print_endline(to_string(describe, with_outcomes));
};

/* Print every game state (with its outcome) for [strat]. Only games where
   the first pick was door A are kept, and they are renormalized. */
let outcomes_and_states_conditioned_on_initially_picking_a = strat => {
  let picked_a = (state: Monty_hall.state) =>
    state.Monty_hall.initial_pick == Monty_hall.A;
  print_states_with_outcome(
    condition(Monty_hall.states(strat), ~on=picked_a),
  );
};

/* Never switch, first pick is door A. */
let () =
  Monty_hall.never_switch
  |> outcomes_and_states_conditioned_on_initially_picking_a;

And if you always switch:

let () =
  outcomes_and_states_conditioned_on_initially_picking_a Monty_hall.always_switch
/* Signatures for the monad abstraction used in this post. */
module Monad = {
  /* A monad over the type constructor [t]: [return] lifts a value into [t],
     [map] transforms what a [t] contains, and [join] removes one level of
     nesting. The monad laws are assumed, not checked by the type system. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* The operations [List.sum] needs to add up values of type [t]: an identity
   [zero] plus [(+)] and [(-)]. The group/commutativity laws are assumed,
   not enforced. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style name for [hd]; raises [Failure("hd")] on []. */
  let hd_exn = hd;
  /* Labeled, list-first versions of the stdlib functions. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only [start == stop] would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* [group(t, ~break)] splits [t] into maximal runs of adjacent elements,
     starting a new run whenever [break(previous, next)] holds. Runs keep
     their input order, and an empty [t] yields no groups at all (rather than
     a single empty group). */
  let group = (t, ~break) => {
    /* [curr_group] is the run being built, in reverse order. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest, last) {
      | ([], None) => acc
      | ([], Some(_)) => [rev(curr_group), ...acc]
      | ([next, ...rest], None) => loop(acc, [next], Some(next), rest)
      | ([next, ...rest], Some(last)) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], Some(next), rest);
        } else {
          loop(acc, [next, ...curr_group], Some(next), rest);
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* [sum((module M), t, ~f)] folds [M.(+)] over [f(x)] for every [x] in [t],
     starting from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module only gives them a name and
   the group operations [List.sum] expects. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: (probability, state) pairs. [map] and [join]
     return each distinct state once, sorted by [compare]. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with certainty. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Collapse entries with equal states (per [compare_states]) into a single
     entry whose probability is their sum. Sorting makes equal states
     adjacent, so one fold merges them. This is total: the empty
     distribution maps to [] instead of hitting [List.hd] on an empty group. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.fold_left(
           (acc, (p, a)) =>
             switch (acc) {
             | [(p_acc, a_acc), ...rest]
                 when compare_states((p_acc, a_acc), (p, a)) == 0 =>
               [(p_acc +. p, a_acc), ...rest]
             | _ => [(p, a), ...acc]
             },
           [],
         )
      |> List.rev;
  /* Apply [f] to every state, then merge states that [f] made equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the outer entry holding it. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3), prefixed with "~" when that
     rounding is not exact. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public face of [Distribution_implementation]. [t] stays concrete (a
   list of (probability, value) pairs), so distributions can be written as
   list literals. [to_string] and the [Monad.S] operations are exported;
   helpers such as [group_states] are hidden by this signature. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind built from [map] and [join]: apply [f] to every outcome of
   [t], then flatten the resulting distribution of distributions. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* The uniform distribution over the elements of [l]. Duplicates stay as
   separate entries, and [uniform([])] is []. */
let uniform: list('a) => t('a) =
  l => {
    let count = List.length(l);
    let each = 1. /. float_of_int(count);
    List.map(l, ~f=x => (each, x));
  };

/* Faces 1 to 6 ([List.range] excludes its upper bound). */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Two independent rolls, as (first, second) pairs. */
let two_dice =
  die >>= (first => die >>= (second => return((first, second))));

/* Rescale the probabilities of [t] so that they sum to 1. */
let renormalize = t => {
  let total = List.sum((module Probability), t, ~f=fst);
  let factor = 1. /. total;
  List.map(t, ~f=((prob, value)) => (prob *. factor, value));
};

/* Condition [t] on the event [on]: drop outcomes failing [on], then
   renormalize what is left. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, outcome)) => on(outcome));
    renormalize(kept);
  };

module Monty_hall = {
  /* The three doors on stage. */
  type door =
    | A
    | B
    | C;
  let string_of_door = door =>
    switch (door) {
    | A => "A"
    | B => "B"
    | C => "C"
    };
  let all_doors: list(door) = [A, B, C];
  /* Whether the contestant ends up with the prize. */
  type outcome =
    | Win
    | Lose;
  let string_of_outcome = result =>
    switch (result) {
    | Win => "Win"
    | Lose => "Lose"
    };
  /* One complete play-through of the game. */
  type state = {
    initial_pick: door,
    prize: door,
    opened: door,
    final_pick: door,
  };
  let string_of_state = s =>
    Printf.sprintf(
      "{ initial_pick: %s, prize: %s, opened: %s, final_pick: %s }",
      string_of_door(s.initial_pick),
      string_of_door(s.prize),
      string_of_door(s.opened),
      string_of_door(s.final_pick),
    );
  /* You win exactly when your final pick is the prize door. */
  let outcome: state => outcome = s => s.final_pick == s.prize ? Win : Lose;
  /* A strategy picks the final door from your first pick and the door the
     host opened. */
  type strategy = (~initial_pick: door, ~opened: door) => door;
  /* Keep the first door, ignoring the host. */
  let never_switch: strategy = (~initial_pick, ~opened as _) => initial_pick;
  /* Move to the one door that is neither your pick nor the opened one. */
  let always_switch: strategy =
    (~initial_pick, ~opened) =>
      all_doors
      |> List.filter(~f=door => !(door == initial_pick || door == opened))
      |> List.hd_exn;
  /* Distribution over complete games played with [strat]. Your first pick and
     the prize door are independent and uniform. The host then opens,
     uniformly at random, a door that is neither of them. */
  let states: strategy => t(state) =
    strat =>
      uniform(all_doors)
      >>= (
        initial_pick =>
          uniform(all_doors)
          >>= (
            prize => {
              let host_options =
                List.filter(all_doors, ~f=door =>
                  !(door == initial_pick || door == prize)
                );
              uniform(host_options)
              >>= (
                opened => {
                  let final_pick = strat(~initial_pick, ~opened);
                  return({initial_pick, prize, opened, final_pick});
                }
              );
            }
          )
      );
  /* Just the Win/Lose distribution for [strat]. */
  let outcomes: strategy => t(outcome) = strat => states(strat) >>| outcome;
};

/* Print every state in [states] together with its Win/Lose outcome and its
   probability, one line per state. */
let print_states_with_outcome = states => {
  let describe = ((state, result)) =>
    Printf.sprintf(
      "%s: %s",
      Monty_hall.string_of_state(state),
      Monty_hall.string_of_outcome(result),
    );
  let with_outcomes =
    map(states, ~f=state => (state, Monty_hall.outcome(state)));
  print_endline(to_string(describe, with_outcomes));
};

/* Print every game state (with its outcome) for [strat]. Only games where
   the first pick was door A are kept, and they are renormalized. */
let outcomes_and_states_conditioned_on_initially_picking_a = strat => {
  let picked_a = (state: Monty_hall.state) =>
    state.Monty_hall.initial_pick == Monty_hall.A;
  print_states_with_outcome(
    condition(Monty_hall.states(strat), ~on=picked_a),
  );
};

/* Always switch, first pick is door A. */
let () =
  Monty_hall.always_switch
  |> outcomes_and_states_conditioned_on_initially_picking_a;

Tree growth

Here's another example from the same paper: a simple model of tree growth. Assume a tree can grow between one and five feet in height every year. Also assume it is possible, although less likely, that the tree falls down in a storm or is hit by lightning. We assume lightning kills the tree but leaves it standing. How can this be represented using probabilistic functions?

module Tree = struct
  type height = int

  (* A tree is either alive or dead; both cases record its height in feet. *)
  type state =
    | Alive of height
    | Dead of height
  
  let string_of_state = function
    | Alive h -> Printf.sprintf "Alive and %d feet tall" h
    | Dead h  -> Printf.sprintf "Dead and %d feet tall" h

  (* What, if anything, kills the tree in a given year. *)
  let maybe_killed : [ `Not_killed | `Storm | `Lightning ] t =
     [ 0.95, `Not_killed
     ; 0.04, `Storm
     ; 0.01, `Lightning
     ]
  
  (* Takes the initial state and turns it into a distribution over states
     one year later. Dead trees stay dead. Lightning kills the tree where it
     stands, and a storm knocks it down to height 0. Otherwise the tree
     grows by a uniformly chosen 1 to 5 feet. *)
  let one_year : state -> state t = fun state ->
    match state with
    | Dead h -> return (Dead h)
    | Alive h -> 
      maybe_killed >>= function
      | `Lightning -> return (Dead h)
      | `Storm -> return (Dead 0) (* fell down *)
      | `Not_killed ->
        uniform (List.range 1 6)
        >>= fun grew_by ->
        return (Alive (h + grew_by))

  (* Distribution of states after [n] years for a tree starting at
     [Alive 0]. Any [n <= 0] yields [Alive 0] with certainty. Testing only
     [n = 0] would recurse without bound on negative [n]. *)
  let n_years ~n =
    let rec loop s n =
      if n <= 0 then return s
      else
        one_year s >>= fun s ->
        loop s (n-1)
    in
    loop (Alive 0) n
end;;
/* Signatures for the monad abstraction used in this post. */
module Monad = {
  /* A monad over the type constructor [t]: [return] lifts a value into [t],
     [map] transforms what a [t] contains, and [join] removes one level of
     nesting. The monad laws are assumed, not checked by the type system. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* The operations [List.sum] needs to add up values of type [t]: an identity
   [zero] plus [(+)] and [(-)]. The group/commutativity laws are assumed,
   not enforced. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style name for [hd]; raises [Failure("hd")] on []. */
  let hd_exn = hd;
  /* Labeled, list-first versions of the stdlib functions. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only [start == stop] would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* [group(t, ~break)] splits [t] into maximal runs of adjacent elements,
     starting a new run whenever [break(previous, next)] holds. Runs keep
     their input order, and an empty [t] yields no groups at all (rather than
     a single empty group). */
  let group = (t, ~break) => {
    /* [curr_group] is the run being built, in reverse order. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest, last) {
      | ([], None) => acc
      | ([], Some(_)) => [rev(curr_group), ...acc]
      | ([next, ...rest], None) => loop(acc, [next], Some(next), rest)
      | ([next, ...rest], Some(last)) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], Some(next), rest);
        } else {
          loop(acc, [next, ...curr_group], Some(next), rest);
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* [sum((module M), t, ~f)] folds [M.(+)] over [f(x)] for every [x] in [t],
     starting from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; this module only gives them a name and
   the group operations [List.sum] expects. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+): (t, t) => t = (x, y) => x +. y;
  let (-): (t, t) => t = (x, y) => x -. y;
};

module Distribution_implementation = {
  /* A finite distribution: (probability, state) pairs. [map] and [join]
     return each distinct state once, sorted by [compare]. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with certainty. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Collapse entries with equal states (per [compare_states]) into a single
     entry whose probability is their sum. Sorting makes equal states
     adjacent, so one fold merges them. This is total: the empty
     distribution maps to [] instead of hitting [List.hd] on an empty group. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.fold_left(
           (acc, (p, a)) =>
             switch (acc) {
             | [(p_acc, a_acc), ...rest]
                 when compare_states((p_acc, a_acc), (p, a)) == 0 =>
               [(p_acc +. p, a_acc), ...rest]
             | _ => [(p, a), ...acc]
             },
           [],
         )
      |> List.rev;
  /* Apply [f] to every state, then merge states that [f] made equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the outer entry holding it. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One line per entry: "<f(state)>:\t<probability> ". The probability is
     printed with [decimals] digits (default 3), prefixed with "~" when that
     rounding is not exact. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public face of [Distribution_implementation]. [t] stays concrete (a
   list of (probability, value) pairs), so distributions can be written as
   list literals. [to_string] and the [Monad.S] operations are exported;
   helpers such as [group_states] are hidden by this signature. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind built from [map] and [join]: apply [f] to every outcome of
   [t], then flatten the resulting distribution of distributions. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [t >>| f] is [map(t, ~f)], [t >>= f] is [bind(t, ~f)]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* The uniform distribution over the elements of [l]. Duplicates stay as
   separate entries, and [uniform([])] is []. */
let uniform: list('a) => t('a) =
  l => {
    let count = List.length(l);
    let each = 1. /. float_of_int(count);
    List.map(l, ~f=x => (each, x));
  };

/* Faces 1 to 6 ([List.range] excludes its upper bound). */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Two independent rolls, as (first, second) pairs. */
let two_dice =
  die >>= (first => die >>= (second => return((first, second))));

/* Rescale the probabilities of [t] so that they sum to 1. */
let renormalize = t => {
  let total = List.sum((module Probability), t, ~f=fst);
  let factor = 1. /. total;
  List.map(t, ~f=((prob, value)) => (prob *. factor, value));
};

/* Condition [t] on the event [on]: drop outcomes failing [on], then
   renormalize what is left. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, outcome)) => on(outcome));
    renormalize(kept);
  };

module Tree = {
  type height = int;
  /* A tree is either alive or dead; both cases record its height in feet. */
  type state =
    | Alive(height)
    | Dead(height);
  let string_of_state =
    fun
    | Alive(h) => Printf.sprintf("Alive and %d feet tall", h)
    | Dead(h) => Printf.sprintf("Dead and %d feet tall", h);
  /* What, if anything, kills the tree in a given year. */
  let maybe_killed: t([ | `Not_killed | `Storm | `Lightning]) = [
    (0.95, `Not_killed),
    (0.04, `Storm),
    (0.01, `Lightning),
  ];
  /* Takes the initial state and turns it into a distribution over states
     one year later. Dead trees stay dead. Lightning kills the tree where it
     stands, and a storm knocks it down to height 0. Otherwise the tree
     grows by a uniformly chosen 1 to 5 feet. */
  let one_year: state => t(state) =
    state =>
      switch (state) {
      | Dead(h) => return(Dead(h))
      | Alive(h) =>
        maybe_killed
        >>= (
          fun
          | `Lightning => return(Dead(h))
          | `Storm => return(Dead(0)) /* fell down */
          | `Not_killed =>
            uniform(List.range(1, 6))
            >>= (grew_by => return(Alive(h + grew_by)))
        )
      };
  /* Distribution of states after [n] years for a tree starting at
     [Alive(0)]. Any [n <= 0] yields [Alive(0)] with certainty. Testing only
     [n == 0] would recurse without bound on negative [n]. */
  let n_years = (~n) => {
    let rec loop = (s, n) =>
      if (n <= 0) {
        return(s);
      } else {
        one_year(s) >>= (s => loop(s, n - 1));
      };
    loop(Alive(0), n);
  };
};

Now that we've modeled this scenario with our distribution monad, let's ask some questions. For example: what is the distribution of states after three years?

let () =
  print_endline (to_string Tree.string_of_state (Tree.n_years ~n:3))
/* Signatures for the monad abstraction used in this post. */
module Monad = {
  /* A monad over the type constructor [t]: [return] lifts a value into [t],
     [map] transforms what a [t] contains, and [join] removes one level of
     nesting. The monad laws are assumed, not checked by the type system. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* The operations [List.sum] needs to add up values of type [t]: an identity
   [zero] plus [(+)] and [(-)]. The group/commutativity laws are assumed,
   not enforced. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

module List = {
  include List;
  /* Base/Core-style name for [hd]; raises [Failure("hd")] on []. */
  let hd_exn = hd;
  /* Labeled, list-first versions of the stdlib functions. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1]. It is empty
     whenever [start >= stop]; testing only [start == stop] would recurse
     forever when [start > stop]. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* [group(t, ~break)] splits [t] into maximal runs of adjacent elements,
     starting a new run whenever [break(previous, next)] holds. Runs keep
     their input order, and an empty [t] yields no groups at all (rather than
     a single empty group). */
  let group = (t, ~break) => {
    /* [curr_group] is the run being built, in reverse order. */
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest, last) {
      | ([], None) => acc
      | ([], Some(_)) => [rev(curr_group), ...acc]
      | ([next, ...rest], None) => loop(acc, [next], Some(next), rest)
      | ([next, ...rest], Some(last)) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], Some(next), rest);
        } else {
          loop(acc, [next, ...curr_group], Some(next), rest);
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* [sum((module M), t, ~f)] folds [M.(+)] over [f(x)] for every [x] in [t],
     starting from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats, packaged with the operations a
   [Commutative_group] needs so that [List.sum] can add them up. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finite probability distributions represented as lists of
   (probability, outcome) pairs, plus the monad operations over them. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merge entries whose outcomes compare equal into a single entry carrying
     their summed probability. The result is sorted by outcome. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* Pattern-match rather than call [List.hd]: an empty group
              (e.g. from grouping an empty distribution) must contribute
              nothing instead of raising. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );
  /* Apply [f] to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability
     by the probability of its inner distribution, then merge equal
     outcomes. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One "outcome:\tprobability " line per entry, printing outcomes with [f].
     Probabilities use [decimals] digits (default 3) and are prefixed with
     "~" when the printed value does not parse back to the exact float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public distribution module. The list representation of [t('a)] stays
   visible; otherwise only [to_string] and the [Monad.S] operations
   ([return], [map], [join]) are exported, hiding helpers such as
   [compare_states] and [group_states]. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* [bind] derived from the two primitive monad operations: map each outcome
   to a distribution, then flatten. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [>>|] is [map] and [>>=] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Equal probability for every element of [l]. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, ..., 6]: the upper bound of [List.range] is exclusive. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent die rolls. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Rescale probabilities so they sum to 1. A distribution whose total
   probability is not positive (including the empty one) is returned
   unchanged instead of being filled with nan/inf. */
let renormalize = t => {
  let total_prob = List.sum((module Probability), t, ~f=fst);
  if (total_prob > 0.) {
    let scale_by = 1. /. total_prob;
    List.map(t, ~f=((p, a)) => (p *. scale_by, a));
  } else {
    t;
  };
};

/* Condition [t] on the event [on]: keep outcomes satisfying [on] and
   renormalize what remains. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* A tree that each year either dies or grows, modelled one year at a time
   with the distribution monad. */
module Tree = {
  type height = int;
  type state =
    | Alive(height)
    | Dead(height);
  /* Human-readable description, e.g. "Alive and 3 feet tall". */
  let string_of_state =
    fun
    | Alive(h) => Printf.sprintf("Alive and %d feet tall", h)
    | Dead(h) => Printf.sprintf("Dead and %d feet tall", h);
  /* What happens to a living tree in a year: 95% nothing, 4% a storm,
     1% lightning. */
  let maybe_killed: t([ | `Not_killed | `Storm | `Lightning]) = [
    (0.95, `Not_killed),
    (0.04, `Storm),
    (0.01, `Lightning),
  ];
  /* Takes the initial state and turns it into a distribution over states
     one year later. Dead trees stay dead; lightning kills a tree at its
     current height; a storm fells it (height 0); otherwise it grows by 1
     to 5 feet uniformly ([List.range]'s upper bound is exclusive). */
  let one_year: state => t(state) =
    state =>
      switch (state) {
      | Dead(h) => return(Dead(h))
      | Alive(h) =>
        maybe_killed
        >>= (
          fun
          | `Lightning => return(Dead(h))
          | `Storm => return(Dead(0)) /* fell down */
          | `Not_killed =>
            uniform(List.range(1, 6))
            >>= (grew_by => return(Alive(h + grew_by)))
        )
      };
  /* Distribution over states after [n] years for a tree starting alive at
     height 0. Any [n <= 0] yields that starting state with certainty, so a
     negative [n] terminates instead of recursing without end. */
  let n_years = (~n) => {
    let rec loop = (s, n) =>
      if (n <= 0) {
        return(s);
      } else {
        one_year(s) >>= (s => loop(s, n - 1));
      };
    loop(Alive(0), n);
  };
};

/* Print the distribution over tree states after three years. */
print_endline(to_string(Tree.string_of_state, Tree.n_years(~n=3)));

And another: conditioned on the tree dying within the first 5 years, how tall is it when it dies?

(* Height distribution of trees that die within five years. *)
let () =
  let died = function Tree.Dead _ -> true | Tree.Alive _ -> false in
  let height_at_death = function
    | Tree.Dead h -> h
    | Tree.Alive _ -> assert false
  in
  Tree.n_years ~n:5
  |> condition ~on:died
  |> map ~f:height_at_death
  |> to_string string_of_int
  |> print_endline
/* The minimal monad interface used throughout this post: a type
   constructor [t('a)] together with
   - [return]: wrap a plain value as a [t('a)];
   - [map]: apply a function to the value(s) inside a [t('a)];
   - [join]: flatten one level of nesting, [t(t('a))] -> [t('a)].
   [bind] is not part of the signature; it is derived separately as
   [map] followed by [join]. */
module Monad = {
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* A type with a [zero] element and [(+)] / [(-)] operations, intended to
   form a commutative group (the laws are not enforced by the type).
   [List.sum] takes a first-class module of this type to add values up;
   it only uses [zero] and [(+)]. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Extends the standard [List] module with list-first helpers that take
   their function as a labelled [~f] argument. */
module List = {
  include List;
  let hd_exn = hd;
  /* Stdlib [map] with the list first and the function labelled. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] ([stop] is
     exclusive). Returns [] whenever [start >= stop], so an inverted range
     terminates instead of counting upward forever. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Split [t] into maximal runs of consecutive elements, starting a new run
     whenever [break(previous, next)] is true. Runs come back in input order;
     elements inside each run are in reverse order. Every returned run is
     non-empty: an empty input yields [] rather than [[]]. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Sum [f(x)] over the list using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats, packaged with the operations a
   [Commutative_group] needs so that [List.sum] can add them up. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finite probability distributions represented as lists of
   (probability, outcome) pairs, plus the monad operations over them. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merge entries whose outcomes compare equal into a single entry carrying
     their summed probability. The result is sorted by outcome. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* Pattern-match rather than call [List.hd]: an empty group
              (e.g. from grouping an empty distribution) must contribute
              nothing instead of raising. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );
  /* Apply [f] to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability
     by the probability of its inner distribution, then merge equal
     outcomes. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One "outcome:\tprobability " line per entry, printing outcomes with [f].
     Probabilities use [decimals] digits (default 3) and are prefixed with
     "~" when the printed value does not parse back to the exact float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public distribution module. The list representation of [t('a)] stays
   visible; otherwise only [to_string] and the [Monad.S] operations
   ([return], [map], [join]) are exported, hiding helpers such as
   [compare_states] and [group_states]. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* [bind] derived from the two primitive monad operations: map each outcome
   to a distribution, then flatten. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [>>|] is [map] and [>>=] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Equal probability for every element of [l]. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, ..., 6]: the upper bound of [List.range] is exclusive. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent die rolls. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Rescale probabilities so they sum to 1. A distribution whose total
   probability is not positive (including the empty one) is returned
   unchanged instead of being filled with nan/inf. */
let renormalize = t => {
  let total_prob = List.sum((module Probability), t, ~f=fst);
  if (total_prob > 0.) {
    let scale_by = 1. /. total_prob;
    List.map(t, ~f=((p, a)) => (p *. scale_by, a));
  } else {
    t;
  };
};

/* Condition [t] on the event [on]: keep outcomes satisfying [on] and
   renormalize what remains. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* A tree that each year either dies or grows, modelled one year at a time
   with the distribution monad. */
module Tree = {
  type height = int;
  type state =
    | Alive(height)
    | Dead(height);
  /* Human-readable description, e.g. "Alive and 3 feet tall". */
  let string_of_state =
    fun
    | Alive(h) => Printf.sprintf("Alive and %d feet tall", h)
    | Dead(h) => Printf.sprintf("Dead and %d feet tall", h);
  /* What happens to a living tree in a year: 95% nothing, 4% a storm,
     1% lightning. */
  let maybe_killed: t([ | `Not_killed | `Storm | `Lightning]) = [
    (0.95, `Not_killed),
    (0.04, `Storm),
    (0.01, `Lightning),
  ];
  /* Takes the initial state and turns it into a distribution over states
     one year later. Dead trees stay dead; lightning kills a tree at its
     current height; a storm fells it (height 0); otherwise it grows by 1
     to 5 feet uniformly ([List.range]'s upper bound is exclusive). */
  let one_year: state => t(state) =
    state =>
      switch (state) {
      | Dead(h) => return(Dead(h))
      | Alive(h) =>
        maybe_killed
        >>= (
          fun
          | `Lightning => return(Dead(h))
          | `Storm => return(Dead(0)) /* fell down */
          | `Not_killed =>
            uniform(List.range(1, 6))
            >>= (grew_by => return(Alive(h + grew_by)))
        )
      };
  /* Distribution over states after [n] years for a tree starting alive at
     height 0. Any [n <= 0] yields that starting state with certainty, so a
     negative [n] terminates instead of recursing without end. */
  let n_years = (~n) => {
    let rec loop = (s, n) =>
      if (n <= 0) {
        return(s);
      } else {
        one_year(s) >>= (s => loop(s, n - 1));
      };
    loop(Alive(0), n);
  };
};

/* Height distribution of trees that die within five years. */
{
  let died = s =>
    switch (s) {
    | Tree.Dead(_) => true
    | Tree.Alive(_) => false
    };
  let height_at_death = s =>
    switch (s) {
    | Tree.Dead(h) => h
    | Tree.Alive(_) => assert false
    };
  Tree.n_years(~n=5)
  |> condition(~on=died)
  |> map(~f=height_at_death)
  |> to_string(string_of_int)
  |> print_endline;
};

Tennis

OK, one last example. Let's model a tennis game in which I win each point with probability p.

(* A tennis model in which every point is an independent draw from a
   distribution over [outcome]s. Games and sets are both "first to n, win
   by 2" contests, so one function builds both. *)
module Tennis = struct
  type outcome = Win | Lose

  let string_of_outcome outcome =
    match outcome with
    | Win -> "Win"
    | Lose -> "Lose"

  (* [point ~p] is the distribution of a single point that I win with
     probability [p]. *)
  let point ~p = [ (p, Win); (1. -. p, Lose) ]

  (* Repeatedly draw from [point] until one side has at least [n] points and
     leads by at least 2; the result is the distribution of who wins. *)
  let play_to_n_and_win_by_2 ~n ~point =
    let rec play ~my_points ~other_points =
      (* Score one drawn outcome, then either settle the contest or recurse. *)
      let after outcome =
        let mine, theirs =
          match outcome with
          | Win -> (my_points + 1, other_points)
          | Lose -> (my_points, other_points + 1)
        in
        if mine >= n && mine - theirs >= 2 then return Win
        else if theirs >= n && theirs - mine >= 2 then return Lose
        else if mine + theirs > 30 then
          (* A win-by-2 contest could go on forever. Once more than 30
             points have been played without a winner, call it 50/50: that
             state is rare enough that it should barely affect the
             distribution, and 50/50 is likely close to the truth anyway. *)
          [ (0.5, Win); (0.5, Lose) ]
        else play ~my_points:mine ~other_points:theirs
      in
      point >>= after
    in
    play ~my_points:0 ~other_points:0

  (* A game: first to 4 points, win by 2. *)
  let game ~point = play_to_n_and_win_by_2 ~n:4 ~point

  (* A set: first to 6 games, win by 2, with [game] as the per-game
     distribution. *)
  let set ~game = play_to_n_and_win_by_2 ~n:6 ~point:game
end;;
/* The minimal monad interface used throughout this post: a type
   constructor [t('a)] together with
   - [return]: wrap a plain value as a [t('a)];
   - [map]: apply a function to the value(s) inside a [t('a)];
   - [join]: flatten one level of nesting, [t(t('a))] -> [t('a)].
   [bind] is not part of the signature; it is derived separately as
   [map] followed by [join]. */
module Monad = {
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* A type with a [zero] element and [(+)] / [(-)] operations, intended to
   form a commutative group (the laws are not enforced by the type).
   [List.sum] takes a first-class module of this type to add values up;
   it only uses [zero] and [(+)]. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Extends the standard [List] module with list-first helpers that take
   their function as a labelled [~f] argument. */
module List = {
  include List;
  let hd_exn = hd;
  /* Stdlib [map] with the list first and the function labelled. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] ([stop] is
     exclusive). Returns [] whenever [start >= stop], so an inverted range
     terminates instead of counting upward forever. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Split [t] into maximal runs of consecutive elements, starting a new run
     whenever [break(previous, next)] is true. Runs come back in input order;
     elements inside each run are in reverse order. Every returned run is
     non-empty: an empty input yields [] rather than [[]]. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Sum [f(x)] over the list using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats, packaged with the operations a
   [Commutative_group] needs so that [List.sum] can add them up. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finite probability distributions represented as lists of
   (probability, outcome) pairs, plus the monad operations over them. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merge entries whose outcomes compare equal into a single entry carrying
     their summed probability. The result is sorted by outcome. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* Pattern-match rather than call [List.hd]: an empty group
              (e.g. from grouping an empty distribution) must contribute
              nothing instead of raising. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );
  /* Apply [f] to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability
     by the probability of its inner distribution, then merge equal
     outcomes. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One "outcome:\tprobability " line per entry, printing outcomes with [f].
     Probabilities use [decimals] digits (default 3) and are prefixed with
     "~" when the printed value does not parse back to the exact float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public distribution module. The list representation of [t('a)] stays
   visible; otherwise only [to_string] and the [Monad.S] operations
   ([return], [map], [join]) are exported, hiding helpers such as
   [compare_states] and [group_states]. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* [bind] derived from the two primitive monad operations: map each outcome
   to a distribution, then flatten. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [>>|] is [map] and [>>=] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Equal probability for every element of [l]. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, ..., 6]: the upper bound of [List.range] is exclusive. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent die rolls. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Rescale probabilities so they sum to 1. A distribution whose total
   probability is not positive (including the empty one) is returned
   unchanged instead of being filled with nan/inf. */
let renormalize = t => {
  let total_prob = List.sum((module Probability), t, ~f=fst);
  if (total_prob > 0.) {
    let scale_by = 1. /. total_prob;
    List.map(t, ~f=((p, a)) => (p *. scale_by, a));
  } else {
    t;
  };
};

/* Condition [t] on the event [on]: keep outcomes satisfying [on] and
   renormalize what remains. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* A tennis model in which every point is an independent draw from a
   distribution over [outcome]s. Games and sets are both "first to n, win
   by 2" contests, so one function builds both. */
module Tennis = {
  type outcome =
    | Win
    | Lose;
  let string_of_outcome = outcome =>
    switch (outcome) {
    | Win => "Win"
    | Lose => "Lose"
    };
  /* [point(~p)] is the distribution of a single point that I win with
     probability [p]. */
  let point = (~p) => [(p, Win), (1. -. p, Lose)];
  /* Repeatedly draw from [point] until one side has at least [n] points and
     leads by at least 2; the result is the distribution of who wins. */
  let play_to_n_and_win_by_2 = (~n, ~point) => {
    let rec play = (~my_points, ~other_points) => {
      /* Score one drawn outcome, then either settle the contest or recurse. */
      let after = outcome => {
        let (mine, theirs) =
          switch (outcome) {
          | Win => (my_points + 1, other_points)
          | Lose => (my_points, other_points + 1)
          };
        if (mine >= n && mine - theirs >= 2) {
          return(Win);
        } else if (theirs >= n && theirs - mine >= 2) {
          return(Lose);
        } else if (mine + theirs > 30) {
          /* A win-by-2 contest could go on forever. Once more than 30
             points have been played without a winner, call it 50/50: that
             state is rare enough that it should barely affect the
             distribution, and 50/50 is likely close to the truth anyway. */
          [(0.5, Win), (0.5, Lose)];
        } else {
          play(~my_points=mine, ~other_points=theirs);
        };
      };
      point >>= after;
    };
    play(~my_points=0, ~other_points=0);
  };
  /* A game: first to 4 points, win by 2. */
  let game = (~point) => play_to_n_and_win_by_2(~n=4, ~point);
  /* A set: first to 6 games, win by 2, with [game] as the per-game
     distribution. */
  let set = (~game) => play_to_n_and_win_by_2(~n=6, ~point=game);
};

As a sanity check, let's compute the probability that I win a game when I win each point with probability 50%. By symmetry, it should be 50%.

(* Probability of winning a game when each point is a coin flip. *)
let () =
  let point = Tennis.point ~p:0.5 in
  print_endline (to_string Tennis.string_of_outcome (Tennis.game ~point))
/* The minimal monad interface used throughout this post: a type
   constructor [t('a)] together with
   - [return]: wrap a plain value as a [t('a)];
   - [map]: apply a function to the value(s) inside a [t('a)];
   - [join]: flatten one level of nesting, [t(t('a))] -> [t('a)].
   [bind] is not part of the signature; it is derived separately as
   [map] followed by [join]. */
module Monad = {
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* A type with a [zero] element and [(+)] / [(-)] operations, intended to
   form a commutative group (the laws are not enforced by the type).
   [List.sum] takes a first-class module of this type to add values up;
   it only uses [zero] and [(+)]. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Extends the standard [List] module with list-first helpers that take
   their function as a labelled [~f] argument. */
module List = {
  include List;
  let hd_exn = hd;
  /* Stdlib [map] with the list first and the function labelled. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] ([stop] is
     exclusive). Returns [] whenever [start >= stop], so an inverted range
     terminates instead of counting upward forever. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Split [t] into maximal runs of consecutive elements, starting a new run
     whenever [break(previous, next)] is true. Runs come back in input order;
     elements inside each run are in reverse order. Every returned run is
     non-empty: an empty input yields [] rather than [[]]. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Sum [f(x)] over the list using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats, packaged with the operations a
   [Commutative_group] needs so that [List.sum] can add them up. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finite probability distributions represented as lists of
   (probability, outcome) pairs, plus the monad operations over them. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merge entries whose outcomes compare equal into a single entry carrying
     their summed probability. The result is sorted by outcome. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* Pattern-match rather than call [List.hd]: an empty group
              (e.g. from grouping an empty distribution) must contribute
              nothing instead of raising. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );
  /* Apply [f] to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability
     by the probability of its inner distribution, then merge equal
     outcomes. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One "outcome:\tprobability " line per entry, printing outcomes with [f].
     Probabilities use [decimals] digits (default 3) and are prefixed with
     "~" when the printed value does not parse back to the exact float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public distribution module. The list representation of [t('a)] stays
   visible; otherwise only [to_string] and the [Monad.S] operations
   ([return], [map], [join]) are exported, hiding helpers such as
   [compare_states] and [group_states]. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* [bind] derived from the two primitive monad operations: map each outcome
   to a distribution, then flatten. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [>>|] is [map] and [>>=] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Equal probability for every element of [l]. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, ..., 6]: the upper bound of [List.range] is exclusive. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent die rolls. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Rescale probabilities so they sum to 1. A distribution whose total
   probability is not positive (including the empty one) is returned
   unchanged instead of being filled with nan/inf. */
let renormalize = t => {
  let total_prob = List.sum((module Probability), t, ~f=fst);
  if (total_prob > 0.) {
    let scale_by = 1. /. total_prob;
    List.map(t, ~f=((p, a)) => (p *. scale_by, a));
  } else {
    t;
  };
};

/* Condition [t] on the event [on]: keep outcomes satisfying [on] and
   renormalize what remains. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* A tennis model in which every point is an independent draw from a
   distribution over [outcome]s. Games and sets are both "first to n, win
   by 2" contests, so one function builds both. */
module Tennis = {
  type outcome =
    | Win
    | Lose;
  let string_of_outcome = outcome =>
    switch (outcome) {
    | Win => "Win"
    | Lose => "Lose"
    };
  /* [point(~p)] is the distribution of a single point that I win with
     probability [p]. */
  let point = (~p) => [(p, Win), (1. -. p, Lose)];
  /* Repeatedly draw from [point] until one side has at least [n] points and
     leads by at least 2; the result is the distribution of who wins. */
  let play_to_n_and_win_by_2 = (~n, ~point) => {
    let rec play = (~my_points, ~other_points) => {
      /* Score one drawn outcome, then either settle the contest or recurse. */
      let after = outcome => {
        let (mine, theirs) =
          switch (outcome) {
          | Win => (my_points + 1, other_points)
          | Lose => (my_points, other_points + 1)
          };
        if (mine >= n && mine - theirs >= 2) {
          return(Win);
        } else if (theirs >= n && theirs - mine >= 2) {
          return(Lose);
        } else if (mine + theirs > 30) {
          /* A win-by-2 contest could go on forever. Once more than 30
             points have been played without a winner, call it 50/50: that
             state is rare enough that it should barely affect the
             distribution, and 50/50 is likely close to the truth anyway. */
          [(0.5, Win), (0.5, Lose)];
        } else {
          play(~my_points=mine, ~other_points=theirs);
        };
      };
      point >>= after;
    };
    play(~my_points=0, ~other_points=0);
  };
  /* A game: first to 4 points, win by 2. */
  let game = (~point) => play_to_n_and_win_by_2(~n=4, ~point);
  /* A set: first to 6 games, win by 2, with [game] as the per-game
     distribution. */
  let set = (~game) => play_to_n_and_win_by_2(~n=6, ~point=game);
};

/* Probability of winning a game when each point is a coin flip. */
{
  let point = Tennis.point(~p=0.5);
  print_endline(to_string(Tennis.string_of_outcome, Tennis.game(~point)));
};

How about if I win each point with probability 55%? How likely am I to win the game?

(* Probability of winning a game when I win each point 55% of the time. *)
let () =
  let point = Tennis.point ~p:0.55 in
  print_endline (to_string Tennis.string_of_outcome (Tennis.game ~point))
/* The minimal monad interface used throughout this post: a type
   constructor [t('a)] together with
   - [return]: wrap a plain value as a [t('a)];
   - [map]: apply a function to the value(s) inside a [t('a)];
   - [join]: flatten one level of nesting, [t(t('a))] -> [t('a)].
   [bind] is not part of the signature; it is derived separately as
   [map] followed by [join]. */
module Monad = {
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* A type with a [zero] element and [(+)] / [(-)] operations, intended to
   form a commutative group (the laws are not enforced by the type).
   [List.sum] takes a first-class module of this type to add values up;
   it only uses [zero] and [(+)]. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Extends the standard [List] module with list-first helpers that take
   their function as a labelled [~f] argument. */
module List = {
  include List;
  let hd_exn = hd;
  /* Stdlib [map] with the list first and the function labelled. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] ([stop] is
     exclusive). Returns [] whenever [start >= stop], so an inverted range
     terminates instead of counting upward forever. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Split [t] into maximal runs of consecutive elements, starting a new run
     whenever [break(previous, next)] is true. Runs come back in input order;
     elements inside each run are in reverse order. Every returned run is
     non-empty: an empty input yields [] rather than [[]]. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Sum [f(x)] over the list using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats, packaged with the operations a
   [Commutative_group] needs so that [List.sum] can add them up. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finite probability distributions represented as lists of
   (probability, outcome) pairs, plus the monad operations over them. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merge entries whose outcomes compare equal into a single entry carrying
     their summed probability. The result is sorted by outcome. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* Pattern-match rather than call [List.hd]: an empty group
              (e.g. from grouping an empty distribution) must contribute
              nothing instead of raising. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );
  /* Apply [f] to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability
     by the probability of its inner distribution, then merge equal
     outcomes. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One "outcome:\tprobability " line per entry, printing outcomes with [f].
     Probabilities use [decimals] digits (default 3) and are prefixed with
     "~" when the printed value does not parse back to the exact float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public distribution module. The list representation of [t('a)] stays
   visible; otherwise only [to_string] and the [Monad.S] operations
   ([return], [map], [join]) are exported, hiding helpers such as
   [compare_states] and [group_states]. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* [bind] derived from the two primitive monad operations: map each outcome
   to a distribution, then flatten. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [>>|] is [map] and [>>=] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Equal probability for every element of [l]. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, ..., 6]: the upper bound of [List.range] is exclusive. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent die rolls. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Rescale probabilities so they sum to 1. A distribution whose total
   probability is not positive (including the empty one) is returned
   unchanged instead of being filled with nan/inf. */
let renormalize = t => {
  let total_prob = List.sum((module Probability), t, ~f=fst);
  if (total_prob > 0.) {
    let scale_by = 1. /. total_prob;
    List.map(t, ~f=((p, a)) => (p *. scale_by, a));
  } else {
    t;
  };
};

/* Condition [t] on the event [on]: keep outcomes satisfying [on] and
   renormalize what remains. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* A tennis model in which every point is an independent draw from a
   distribution over [outcome]s. Games and sets are both "first to n, win
   by 2" contests, so one function builds both. */
module Tennis = {
  type outcome =
    | Win
    | Lose;
  let string_of_outcome = outcome =>
    switch (outcome) {
    | Win => "Win"
    | Lose => "Lose"
    };
  /* [point(~p)] is the distribution of a single point that I win with
     probability [p]. */
  let point = (~p) => [(p, Win), (1. -. p, Lose)];
  /* Repeatedly draw from [point] until one side has at least [n] points and
     leads by at least 2; the result is the distribution of who wins. */
  let play_to_n_and_win_by_2 = (~n, ~point) => {
    let rec play = (~my_points, ~other_points) => {
      /* Score one drawn outcome, then either settle the contest or recurse. */
      let after = outcome => {
        let (mine, theirs) =
          switch (outcome) {
          | Win => (my_points + 1, other_points)
          | Lose => (my_points, other_points + 1)
          };
        if (mine >= n && mine - theirs >= 2) {
          return(Win);
        } else if (theirs >= n && theirs - mine >= 2) {
          return(Lose);
        } else if (mine + theirs > 30) {
          /* A win-by-2 contest could go on forever. Once more than 30
             points have been played without a winner, call it 50/50: that
             state is rare enough that it should barely affect the
             distribution, and 50/50 is likely close to the truth anyway. */
          [(0.5, Win), (0.5, Lose)];
        } else {
          play(~my_points=mine, ~other_points=theirs);
        };
      };
      point >>= after;
    };
    play(~my_points=0, ~other_points=0);
  };
  /* A game: first to 4 points, win by 2. */
  let game = (~point) => play_to_n_and_win_by_2(~n=4, ~point);
  /* A set: first to 6 games, win by 2, with [game] as the per-game
     distribution. */
  let set = (~game) => play_to_n_and_win_by_2(~n=6, ~point=game);
};

/* Probability of winning a game when I win each point 55% of the time. */
{
  let point = Tennis.point(~p=0.55);
  print_endline(to_string(Tennis.string_of_outcome, Tennis.game(~point)));
};

How about the set?

(* Probability of winning a set when I win each point 55% of the time. *)
let () =
  let game = Tennis.game ~point:(Tennis.point ~p:0.55) in
  print_endline (to_string Tennis.string_of_outcome (Tennis.set ~game))
/* The minimal monad interface used throughout this post: a type
   constructor [t('a)] together with
   - [return]: wrap a plain value as a [t('a)];
   - [map]: apply a function to the value(s) inside a [t('a)];
   - [join]: flatten one level of nesting, [t(t('a))] -> [t('a)].
   [bind] is not part of the signature; it is derived separately as
   [map] followed by [join]. */
module Monad = {
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* A type with a [zero] element and [(+)] / [(-)] operations, intended to
   form a commutative group (the laws are not enforced by the type).
   [List.sum] takes a first-class module of this type to add values up;
   it only uses [zero] and [(+)]. */
module type Commutative_group = {
  type t;
  let zero: t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
};

/* Extends the standard [List] module with list-first helpers that take
   their function as a labelled [~f] argument. */
module List = {
  include List;
  let hd_exn = hd;
  /* Stdlib [map] with the list first and the function labelled. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying [f]. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* [range(start, stop)] is [start, start + 1, ..., stop - 1] ([stop] is
     exclusive). Returns [] whenever [start >= stop], so an inverted range
     terminates instead of counting upward forever. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Split [t] into maximal runs of consecutive elements, starting a new run
     whenever [break(previous, next)] is true. Runs come back in input order;
     elements inside each run are in reverse order. Every returned run is
     non-empty: an empty input yields [] rather than [[]]. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* Only an empty input leaves [curr_group] empty here. */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };
  /* Sum [f(x)] over the list using the group operations of [M]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats, packaged with the operations a
   [Commutative_group] needs so that [List.sum] can add them up. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Finite probability distributions represented as lists of
   (probability, outcome) pairs, plus the monad operations over them. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields [a] with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merge entries whose outcomes compare equal into a single entry carrying
     their summed probability. The result is sorted by outcome. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* Pattern-match rather than call [List.hd]: an empty group
              (e.g. from grouping an empty distribution) must contribute
              nothing instead of raising. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );
  /* Apply [f] to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flatten a distribution of distributions: scale each inner probability
     by the probability of its inner distribution, then merge equal
     outcomes. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* One "outcome:\tprobability " line per entry, printing outcomes with [f].
     Probabilities use [decimals] digits (default 3) and are prefixed with
     "~" when the printed value does not parse back to the exact float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* The public distribution module. The list representation of [t('a)] stays
   visible; otherwise only [to_string] and the [Monad.S] operations
   ([return], [map], [join]) are exported, hiding helpers such as
   [compare_states] and [group_states]. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
  include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* [bind] derived from the two primitive monad operations: map each outcome
   to a distribution, then flatten. */
let bind = (t, ~f) => map(t, ~f) |> join;

/* Infix aliases: [>>|] is [map] and [>>=] is [bind]. */
let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

/* Equal probability for every element of [l]. */
let uniform: list('a) => t('a) =
  l => {
    let p = 1. /. float(List.length(l));
    List.map(l, ~f=a => (p, a));
  };

/* [1, ..., 6]: the upper bound of [List.range] is exclusive. */
let one_through_six = List.range(1, 7);

/* A fair six-sided die. */
let die: t(int) = uniform(one_through_six);

/* Joint distribution of two independent die rolls. */
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Rescale probabilities so they sum to 1. A distribution whose total
   probability is not positive (including the empty one) is returned
   unchanged instead of being filled with nan/inf. */
let renormalize = t => {
  let total_prob = List.sum((module Probability), t, ~f=fst);
  if (total_prob > 0.) {
    let scale_by = 1. /. total_prob;
    List.map(t, ~f=((p, a)) => (p *. scale_by, a));
  } else {
    t;
  };
};

/* Condition [t] on the event [on]: keep outcomes satisfying [on] and
   renormalize what remains. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => List.filter(t, ~f=((_, a)) => on(a)) |> renormalize;

/* A tennis model in which every point is an independent draw from a
   distribution over [outcome]s. Games and sets are both "first to n, win
   by 2" contests, so one function builds both. */
module Tennis = {
  type outcome =
    | Win
    | Lose;
  let string_of_outcome = outcome =>
    switch (outcome) {
    | Win => "Win"
    | Lose => "Lose"
    };
  /* [point(~p)] is the distribution of a single point that I win with
     probability [p]. */
  let point = (~p) => [(p, Win), (1. -. p, Lose)];
  /* Repeatedly draw from [point] until one side has at least [n] points and
     leads by at least 2; the result is the distribution of who wins. */
  let play_to_n_and_win_by_2 = (~n, ~point) => {
    let rec play = (~my_points, ~other_points) => {
      /* Score one drawn outcome, then either settle the contest or recurse. */
      let after = outcome => {
        let (mine, theirs) =
          switch (outcome) {
          | Win => (my_points + 1, other_points)
          | Lose => (my_points, other_points + 1)
          };
        if (mine >= n && mine - theirs >= 2) {
          return(Win);
        } else if (theirs >= n && theirs - mine >= 2) {
          return(Lose);
        } else if (mine + theirs > 30) {
          /* A win-by-2 contest could go on forever. Once more than 30
             points have been played without a winner, call it 50/50: that
             state is rare enough that it should barely affect the
             distribution, and 50/50 is likely close to the truth anyway. */
          [(0.5, Win), (0.5, Lose)];
        } else {
          play(~my_points=mine, ~other_points=theirs);
        };
      };
      point >>= after;
    };
    play(~my_points=0, ~other_points=0);
  };
  /* A game: first to 4 points, win by 2. */
  let game = (~point) => play_to_n_and_win_by_2(~n=4, ~point);
  /* A set: first to 6 games, win by 2, with [game] as the per-game
     distribution. */
  let set = (~game) => play_to_n_and_win_by_2(~n=6, ~point=game);
};

/* Probability of winning a set when I win each point 55% of the time. */
{
  let game = Tennis.game(~point=Tennis.point(~p=0.55));
  print_endline(to_string(Tennis.string_of_outcome, Tennis.set(~game)));
};

And that's all I've got. I hope this serves as a useful example of a monad with practical applications and at least partially demonstrates the beauty and complexity that you can create starting from such simple combinator functions as return, map, join, and bind. Thanks for reading!