This page becomes unresponsive for a few seconds after loading to evaluate the code samples below. Don't worry, the page will be responsive after that.

Home
⚠️ Oops! This page doesn't appear to define a type called _.

This post's purpose is two-fold. First, it's a hat-tip to this beautiful paper. Second, it can serve as a very concrete example of a monad. An idea (that others have articulated better than I can) is that when trying to teach an abstract concept, it's better to start with concrete examples and gradually move towards abstraction and generalization rather than the other way around.

One thing to point out up-front is that the code samples below are written in OCaml, and are editable/running (thanks to redoc), so if you feel the urge to get your hands dirty and change some code in order to see the effect - go for it!

With that in mind, let's immediately break the rule of starting concrete and lay out the formal definition of a monad. If it means nothing or very little to you at this point in the post, don't worry. I will explain what return, map, and join mean intuitively, but only later, in the context of our concrete example: the probability distribution.

(** Namespace for the monad signature; the rest of the post refers to it as
    [Monad.S]. *)
module Monad = struct
  (** The minimal monad interface used throughout this post. *)
  module type S = sig
    (** A value of type ['a] wrapped in the monad. *)
    type 'a t

    (** [return a] wraps the plain value [a]. *)
    val return : 'a -> 'a t

    (** [map t ~f] applies [f] to the contents of [t]. *)
    val map : 'a t -> f:('a -> 'b) -> 'b t

    (** [join t] flattens two layers of the monad into one. *)
    val join : 'a t t -> 'a t
  end
end
/* Namespace for the monad signature; the rest of the post refers to it as
   Monad.S. */
module Monad = {
  /* The minimal monad interface used throughout this post. */
  module type S = {
    /* A value of type 'a wrapped in the monad. */
    type t('a);
    /* return(a) wraps the plain value a. */
    let return: 'a => t('a);
    /* map(t, ~f) applies f to the contents of t. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* join(t) flattens two layers of the monad into one. */
    let join: t(t('a)) => t('a);
  };
};

Let's define a probability type. This is unnecessary, but it allows me to refer to probabilities as Probability.ts later on, rather than just floats, which hopefully adds a small bit of readability.

(** Probabilities are plain floats; the alias exists purely for readability. *)
module Probability = struct
  type t = float

  (** The zero probability. *)
  let zero : t = 0.0

  (** Sum of two probabilities (float addition). *)
  let ( + ) : t -> t -> t = fun p q -> p +. q

  (** Difference of two probabilities (float subtraction). *)
  let ( - ) : t -> t -> t = fun p q -> p -. q
end
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

/* Signature for a type equipped with a zero value and binary (+) and (-)
   operators. List.sum takes a first-class module of this type. */
module type Commutative_group = {
/* The carrier type the operations work on. */
type t;
/* The value List.sum starts its fold from. */
let zero: t;
/* Binary addition on t. */
let (+): (t, t) => t;
/* Binary subtraction on t. */
let (-): (t, t) => t;
};

/* Extends the standard List module with labelled, list-first helpers used by
   the rest of the page. */
module List = {
  include List;
  /* Alias for the standard hd. */
  let hd_exn = hd;
  /* map(t, ~f): labelled, list-first wrapper around the standard map. */
  let map = (t, ~f) => map(f, t);
  /* bind(t, ~f): map every element to a list and concatenate the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* filter(t, ~f): labelled, list-first wrapper around the standard filter. */
  let filter = (t, ~f) => filter(f, t);
  /* count(t, ~f): number of elements for which f holds. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* range(start, stop): the integers start, start + 1, ..., stop - 1.
     Returns [] when start >= stop (the previous version recursed forever when
     start > stop). Built from the end with an accumulator so it is
     tail-recursive. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };
  /* group(t, ~break): split t into runs of consecutive elements, starting a
     new run whenever break(previous, next) is true.
     - An empty input yields no groups (previously it yielded one empty
       group, which made callers that take the head of each group raise).
     - Elements inside each group keep their input order (previously each
       group came back reversed). */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };
  /* sum((module M), t, ~f): map every element through f and add the results
     with M.(+), starting from M.zero. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats; the alias exists purely for readability. */
module Probability = {
  type t = float;
  /* The zero probability. */
  let zero: t = 0.0;
  /* Sum of two probabilities (float addition). */
  let (+): (t, t) => t = (p, q) => p +. q;
  /* Difference of two probabilities (float subtraction). */
  let (-): (t, t) => t = (p, q) => p -. q;
};
(** List-backed implementation of a finite probability distribution. *)
module Distribution_implementation = struct
  (** Each entry pairs a probability with the outcome it belongs to. *)
  type 'a t = (Probability.t * 'a) list

  (** [return a] is the single-outcome distribution giving [a] probability 1. *)
  let return a = [ 1., a ]

  (* compare just the 'a states, not the probabilities *)
  let compare_states (_, a1) (_, a2) = compare a1 a2

  (** [group_states t] sorts [t] by outcome and merges entries whose outcomes
      compare equal, summing their probabilities. An empty distribution maps
      to an empty distribution. *)
  let group_states : 'a t -> 'a t =
   fun t ->
    List.sort compare_states t
    |> List.group ~break:(fun t1 t2 -> compare_states t1 t2 <> 0)
    |> List.bind ~f:(function
         (* The page's [List.group] (shown in its Reason form) returns one
            empty group for an empty input. Skip it instead of calling
            [List.hd] on it, which would raise. *)
         | [] -> []
         | (_, a) :: _ as group ->
           let p = List.fold_left (fun acc (p, _) -> acc +. p) 0. group in
           [ (p, a) ])

  (** [map t ~f] applies [f] to every outcome, then merges outcomes that have
      become equal. *)
  let map t ~f = List.map t ~f:(fun (p, a) -> (p, f a)) |> group_states

  (** [join t] flattens a distribution of distributions: each inner
      probability is multiplied by the probability of its enclosing entry,
      then equal outcomes are merged. *)
  let join t =
    List.bind t ~f:(fun (p, inner) ->
        List.map inner ~f:(fun (p1, a) -> (p *. p1, a)))
    |> group_states

  (** [to_string ?decimals f t] renders one ["outcome:\tprobability "] line
      per entry, joined with newlines. [f] renders an outcome; [decimals]
      (default 3) is the number of digits after the decimal point. *)
  let to_string ?(decimals = 3) f t =
    let string_of_float f =
      let s = Printf.sprintf "%0.*f" decimals f in
      (* Mark the text with "~" when rounding to [decimals] digits lost
         information, i.e. the text does not parse back to the same float. *)
      let is_exact = float_of_string s = f in
      if is_exact then s else "~" ^ s
    in
    String.concat "\n"
      (List.map t ~f:(fun (p, a) -> f a ^ ":\t" ^ string_of_float p ^ " "))
end
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

/* List-backed implementation of a finite probability distribution. */
module Distribution_implementation = {
  /* Each entry pairs a probability with the outcome it belongs to. */
  type t('a) = list((Probability.t, 'a));
  /* return(a) is the single-outcome distribution giving a probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* group_states(t) sorts t by outcome and merges entries whose outcomes
     compare equal, summing their probabilities. An empty distribution maps to
     an empty distribution. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* List.group returns one empty group for an empty input. Skip it
              instead of calling List.hd on it, which would raise. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );
  /* map(t, ~f) applies f to every outcome, then merges outcomes that have
     become equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* join(t) flattens a distribution of distributions: each inner probability
     is multiplied by the probability of its enclosing entry, then equal
     outcomes are merged. */
  let join = t =>
    List.bind(t, ~f=((p, inner)) =>
      List.map(inner, ~f=((p1, a)) => (p *. p1, a))
    )
    |> group_states;
  /* to_string(~decimals=?, f, t) renders one "outcome:\tprobability " line per
     entry, joined with newlines. f renders an outcome; decimals (default 3)
     is the number of digits after the decimal point. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      /* Mark the text with "~" when rounding lost information, i.e. the text
         does not parse back to the same float. */
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

Now back to the good part - let's define the main type that we'll be dealing with: a Distribution.t.

(** Public face of the distribution monad: the list representation stays
    visible, and the implementation is sealed to [Monad.S] plus [to_string]. *)
module Distribution : sig
  (** A list of outcomes, each paired with its probability. *)
  type 'a t = (Probability.t * 'a) list

  (** [return], [map] and [join], specialised to ['a t] above. *)
  include Monad.S with type 'a t := 'a t

  (** [to_string ?decimals f t] renders [t] as text, using [f] to render each
      outcome. *)
  val to_string : ?decimals:int -> ('a -> string) -> 'a t -> string
end =
  Distribution_implementation
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

/* Public face of the distribution monad: the list representation stays
   visible, and the implementation is sealed to Monad.S plus to_string. */
module Distribution: {
  /* A list of outcomes, each paired with its probability. */
  type t('a) = list((Probability.t, 'a));
  /* return, map and join, specialised to t('a) above. */
  include Monad.S with type t('a) := t('a);
  /* to_string(~decimals=?, f, t) renders t as text, using f to render each
     outcome. */
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

What is Distribution_implementation and where is it defined? Well, I hid it. If you really want to see it, fine - click here to show it, but it's a big chunk of code that's not super interesting or important. What is important is the signature which Distribution exposes. Let's explain each line:

type 'a t = (Probability.t * 'a) list: This says that a Distribution.t of 'as (read: possible outcomes) is really a list of 'as with an associated probability. Again, let's make this concrete. Let's say our distribution represents the weather tomorrow, and to keep it simple, let's say there are two possible outcomes: Rain or Shine. If we wanted to create a Distribution.t that represents an 80% chance of Shine and a 20% chance of Rain, ... actually, let's just do that:

(** The two weather outcomes considered in this example. *)
type rain_or_shine =
  | Rain
  | Shine

(** Tomorrow's forecast: probability 0.8 of [Shine] and 0.2 of [Rain]. *)
let tomorrow's_weather : rain_or_shine Distribution.t =
  (0.8, Shine) :: (0.2, Rain) :: []
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

/* The two weather outcomes considered in this example. */
type rain_or_shine =
  | Rain
  | Shine;

/* Tomorrow's forecast: probability 0.8 of Shine and 0.2 of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = {
  let shine = (0.8, Shine);
  let rain = (0.2, Rain);
  [shine, rain];
};

You can see that the type of tomorrow's_weather is a rain_or_shine Distribution.t, but really it's just a list of possible outcomes along with their associated probabilities. And, just to explicitly connect the dots, the 'a in this case is rain_or_shine.

val to_string : ?decimals:int -> ('a -> string) -> 'a t -> string: This line probably needs the least explaining. It says that there exists a function Distribution.to_string that, given a function from 'a -> string can produce a string from a 'a distribution. Although probably unnecessary, let's take it for a spin.

(** Render a weather outcome as its constructor name. *)
let rain_or_shine_to_string weather =
  match weather with
  | Rain -> "Rain"
  | Shine -> "Shine"

(* Print tomorrow's forecast using the renderer above. *)
let () =
  print_endline
    (Distribution.to_string rain_or_shine_to_string tomorrow's_weather)
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

type rain_or_shine =
| Rain
| Shine;

let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];

/* Render a weather outcome as its constructor name. */
let rain_or_shine_to_string = weather =>
  switch (weather) {
  | Rain => "Rain"
  | Shine => "Shine"
  };

/* Print tomorrow's forecast using the renderer above. */
print_endline(
  Distribution.to_string(rain_or_shine_to_string, tomorrow's_weather),
);

Simple enough.

include Monad.S with type 'a t := 'a t: Ok, there's kind of a lot to unpack in this one line. include Monad.S really just means paste the signature Monad.S here, so let's do that and explain each of those lines:

(** Namespace for the monad signature; the rest of the post refers to it as
    [Monad.S]. *)
module Monad = struct
  (** The minimal monad interface used throughout this post. *)
  module type S = sig
    (** A value of type ['a] wrapped in the monad. *)
    type 'a t

    (** [return a] wraps the plain value [a]. *)
    val return : 'a -> 'a t

    (** [map t ~f] applies [f] to the contents of [t]. *)
    val map : 'a t -> f:('a -> 'b) -> 'b t

    (** [join t] flattens two layers of the monad into one. *)
    val join : 'a t t -> 'a t
  end
end
/* Namespace for the monad signature; the rest of the post refers to it as
   Monad.S. */
module Monad = {
  /* The minimal monad interface used throughout this post. */
  module type S = {
    /* A value of type 'a wrapped in the monad. */
    type t('a);
    /* return(a) wraps the plain value a. */
    let return: 'a => t('a);
    /* map(t, ~f) applies f to the contents of t. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* join(t) flattens two layers of the monad into one. */
    let join: t(t('a)) => t('a);
  };
};

Great, here's where we can explain what these very abstract functions mean in the context of a distribution.

val return : 'a -> 'a t: return says give me an outcome, and I'll give you a distribution on that outcome. Continuing with our weather example, we could say Distribution.return Shine. The probabilities of a distribution have to add up to 100%, so what's the only reasonable thing this could mean? Well, return will create a "Distribution" where the outcome that you passed in happens with a 100% probability.

val map : 'a t -> f:('a -> 'b) -> 'b t: map says give me a distribution of outcomes of type 'a and a way to map 'as to 'bs, and I'll give you a distribution of 'bs. Let's imagine a way we might use this function on tomorrow's_weather. Let's say we're trying to figure out whether or not we can play soccer in the park tomorrow. Let's use map!

(** Whether soccer in the park is possible under the given weather: only when
    it is [Shine]. *)
let can_play_soccer_in_park = function
  | Shine -> true
  | Rain -> false

(** Distribution of "can we play tomorrow?", obtained by mapping the forecast
    through [can_play_soccer_in_park]. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t =
  tomorrow's_weather |> Distribution.map ~f:can_play_soccer_in_park
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

type rain_or_shine =
| Rain
| Shine;

let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];

/* Whether soccer in the park is possible under the given weather: only when
   it is Shine. */
let can_play_soccer_in_park =
  fun
  | Shine => true
  | Rain => false;

/* Distribution of "can we play tomorrow?", obtained by mapping the forecast
   through can_play_soccer_in_park. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
  tomorrow's_weather |> Distribution.map(~f=can_play_soccer_in_park);

So, we created a boolean distribution (which represents whether or not we can play soccer in the park tomorrow) by mapping an existing distribution (tomorrow's weather) with a function that maps weather to whether or not you can play soccer in the park. A bit of a mouthful, but conceptually pretty straightforward.

val join : 'a t t -> 'a t: Saving the best for last. join is what really differentiates a monad. In our case, it means that a distribution of distributions is still just a distribution. No wonder people find this confusing... Let's carry on with our strategy of making this more and more concrete until it's actually intelligible.

What if our function can_play_soccer_in_park was a bit too simplistic? Even if it's sunny tomorrow (i.e. Shine), there's still some chance, albeit small, that you can't play soccer in the park. What if the park is closed? What if other people are using the park all day long? So, let's account for this and update can_play_soccer_in_park to return a boolean distribution, as opposed to a boolean.

(** Distribution over "can we play soccer in the park?" for the given
    weather. *)
let can_play_soccer_in_park = function
  | Rain ->
    (* rain rules playing out entirely *)
    Distribution.return false
  | Shine ->
    (* on a sunny day we can play with probability 0.95 *)
    (0.95, true) :: (0.05, false) :: []
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

type rain_or_shine =
| Rain
| Shine;

let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];

/* Distribution over "can we play soccer in the park?" for the given
   weather. */
let can_play_soccer_in_park =
  fun
  | Rain =>
    /* rain rules playing out entirely */
    Distribution.return(false)
  | Shine => {
      /* on a sunny day we can play with probability 0.95 */
      let can_play = (0.95, true);
      let cannot_play = (0.05, false);
      [can_play, cannot_play];
    };

Ok, and let's do what we did before and map tomorrow's_weather with our new and improved function.

(* Deliberately ill-typed demonstration: [can_play_soccer_in_park] now returns
   a distribution, so [Distribution.map] produces a
   [bool Distribution.t Distribution.t], which the [bool Distribution.t]
   annotation rejects. The compiler output that follows this sample shows the
   resulting type error. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t =
Distribution.map tomorrow's_weather ~f:can_play_soccer_in_park
Type Error: File "", line 102, characters 2-64: Error: This expression has type bool Distribution.t Distribution.t = (Probability.t * bool Distribution.t) list but an expression was expected of type bool Distribution.t = (Probability.t * bool) list Type bool Distribution.t = (Probability.t * bool) list is not compatible with type bool
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

type rain_or_shine =
| Rain
| Shine;

let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];

let can_play_soccer_in_park = rain_or_shine =>
switch (rain_or_shine) {
| Rain =>
/* we definitely can't play if it's raining */
Distribution.return(false)
| Shine =>
/* let's say we can play with a 95% probability */
[(0.95, true), (0.05, false)]
};

/* Deliberately ill-typed demonstration: can_play_soccer_in_park now returns a
   distribution, so Distribution.map produces a distribution of boolean
   distributions, which the Distribution.t(bool) annotation rejects. The
   compiler output that follows this sample shows the resulting type error. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
Distribution.map(tomorrow's_weather, ~f=can_play_soccer_in_park);
Type Error: File "", line 124, characters 2-66: Error: This expression has type bool Distribution.t Distribution.t = (Probability.t * bool Distribution.t) list but an expression was expected of type bool Distribution.t = (Probability.t * bool) list Type bool Distribution.t = (Probability.t * bool) list is not compatible with type bool

Wait - what happened? Why didn't that work like before? Well, ignoring the confusing line numbers, it says that you expected can_play_soccer_in_park_tomorrow to be of type bool Distribution.t but it's actually bool Distribution.t Distribution.t, i.e. it's a distribution of boolean distributions!

If that's still confusing, think about it this way. With 80% chance, the weather tomorrow is Shine, and if the weather is Shine then there's a 95% chance you can play soccer in the park and a 5% chance you can't. So, you can see how there are two layers of probability distributions. But, intuitively, we know we can collapse these into just one. If I ask you "what's the chance you can play in the park tomorrow?" - there's a single answer: 80% * 95% + 20% * 0% = 76%.

That's exactly what join does. It flattens two levels of a monad - a distribution in our case - into one level. Let's see it in action.

(** Distribution of "can we play tomorrow?": map the forecast through
    [can_play_soccer_in_park], then flatten the two layers with [join]. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t =
  let nested =
    Distribution.map tomorrow's_weather ~f:can_play_soccer_in_park
  in
  Distribution.join nested

(* Print the flattened distribution. *)
let () =
  print_endline
    (Distribution.to_string string_of_bool can_play_soccer_in_park_tomorrow)
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

type rain_or_shine =
| Rain
| Shine;

let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];

let can_play_soccer_in_park = rain_or_shine =>
switch (rain_or_shine) {
| Rain =>
/* we definitely can't play if it's raining */
Distribution.return(false)
| Shine =>
/* let's say we can play with a 95% probability */
[(0.95, true), (0.05, false)]
};

/* Map every weather outcome to its own play/no-play distribution, then
   flatten the resulting distribution of distributions. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
  Distribution.join(
    Distribution.map(tomorrow's_weather, ~f=can_play_soccer_in_park),
  );

print_endline(
  Distribution.to_string(string_of_bool, can_play_soccer_in_park_tomorrow),
);

At this point, it's probably smart to introduce a new operation bind. This pattern of calling map and then join immediately afterwards is so common that there's an operator to make it more convenient. Let's define bind so that we can use it later on.

open Distribution

(* [bind t ~f] is [map] followed by [join]: apply [f] to every outcome of [t]
   and flatten the resulting distribution of distributions. *)
let bind t ~f = join (map t ~f)
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* A carrier type with addition, subtraction and a distinguished [zero].
   [zero] is meant to be the identity of [(+)]; the signature itself cannot
   enforce that. */
module type Commutative_group = {
  type t;
  let (+): (t, t) => t;
  let (-): (t, t) => t;
  let zero: t;
};

/* The standard [List] extended with list-first, labelled helpers. */
module List = {
  include List;

  /* Alias of [hd]; raises [Failure] on an empty list. */
  let hd_exn = hd;

  /* List-first, labelled wrappers around the included functions. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* [range(start, stop)] is the ascending list start, ..., stop - 1. It is
     empty whenever start >= stop; an equality test alone would recurse
     forever for start > stop. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };

  /* [group(t, ~break)] splits [t] into runs of consecutive elements, starting
     a new run whenever [break(previous, next)] is true. Runs come back in
     input order; the elements inside each run are in reverse input order
     because they are accumulated by consing. An empty input yields [] rather
     than a single empty run. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] =>
        /* curr_group is empty only when the whole input was empty */
        switch (curr_group) {
        | [] => acc
        | _ => [curr_group, ...acc]
        }
      | [next, ...rest] =>
        switch (last) {
        | None => loop(acc, [next, ...curr_group], Some(next), rest)
        | Some(last) =>
          if (break(last, next)) {
            loop([curr_group, ...acc], [next], Some(next), rest);
          } else {
            loop(acc, [next, ...curr_group], Some(next), rest);
          }
        }
      };
    loop([], [], None, t) |> rev;
  };

  /* [sum((module M), t, ~f)] maps [f] over [t] and folds the results with
     [M.(+)], starting from [M.zero]. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. Inside this module [+] and [-] are the
   float operators, and [zero] is the float 0. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (p, q) => p +. q;
  let (-) = (p, q) => p -. q;
};

module Distribution_implementation = {
  /* A distribution is a list of (probability, outcome) pairs. */
  type t('a) = list((Probability.t, 'a));

  /* The distribution whose only outcome is [a], with probability 1. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sorts the entries by outcome and merges those whose outcomes compare
     equal, summing their probabilities. An empty distribution is returned
     as-is, so [List.hd] below is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* each group is going to be non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };

  /* Applies [f] to every outcome, then merges outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flattens a distribution of distributions: every inner probability is
     multiplied by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Renders one line per entry as "<f(outcome)>:\t<probability> ". The
     probability is printed with [decimals] fractional digits (default 3) and
     is prefixed with "~" when the printed text does not read back as exactly
     the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};

/* Public face of the distribution type. The list representation stays
   visible, so literal lists can be used as distributions; only [to_string]
   and the [Monad.S] operations are exported. */
module Distribution: {
  type t('a) = list((Probability.t, 'a));
  include Monad.S with type t('a) := t('a);
  let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
} = Distribution_implementation;

open Distribution;

/* [bind(t, ~f)] is [map] followed by [join]: apply [f] to every outcome of
   [t] and flatten the resulting distribution of distributions. */
let bind = (t, ~f) => join(map(t, ~f));

When actually writing code dealing with monads, it's often more convenient to use bind than map and join. In fact, many people introduce bind as part of the definition of a monad instead of join, but the two definitions are equivalent, since given one you can define the other (e.g. let join tt = bind tt ~f:(fun t -> t)).

Alright! We've now explained every line of the signature which Distribution exposes. You probably thought this day would never come.

Ok, so what? We've defined a bunch of really abstract stuff in the context of a concrete example, a probability distribution, and we've seen some extremely trivial manipulations which are not nearly cool enough to warrant all this hullabaloo.

I agree. Let's start the interesting part.

First off, let me define infix operators for map and bind so that my code later on can be slightly less indented. And while we're at it, let's define a uniform function which takes a list of outcomes and creates a distribution where each outcome is equally weighted. One last thing - I'm going to open Distribution so that I can write the functions we've been talking about in an unqualified way, e.g. map instead of Distribution.map.

open Distribution

(* Infix [map]: [dist >>| f] transforms every outcome of [dist] with [f]. *)
let ( >>| ) dist f = map dist ~f

(* Infix [bind]: [dist >>= f] feeds every outcome of [dist] to [f] and
   flattens the result. *)
let ( >>= ) dist f = bind dist ~f

(* [uniform l] gives every element of [l] the same probability,
   1 / length of [l]. *)
let uniform : 'a list -> 'a t = fun outcomes ->
  let weight = 1. /. float_of_int (List.length outcomes) in
  List.map outcomes ~f:(fun outcome -> (weight, outcome))
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

let bind = (t, ~f) => map(t, ~f) |> join;

open Distribution;

/* Infix [map]: [dist >>| f] transforms every outcome of [dist] with [f]. */
let (>>|) = (dist, f) => map(dist, ~f);

/* Infix [bind]: [dist >>= f] feeds every outcome of [dist] to [f] and
   flattens the result. */
let (>>=) = (dist, f) => bind(dist, ~f);

/* [uniform(l)] gives every element of [l] the same probability,
   1 / List.length(l). */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float_of_int(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };

With those handy tools, let's define a staple of probability questions.

## The die.

(* The faces 1..6; the upper bound passed to [List.range] is exclusive. *)
let one_through_six = List.range 1 7;;

(* A die whose faces all carry the same probability. *)
let die : int t = one_through_six |> uniform;;
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

let bind = (t, ~f) => map(t, ~f) |> join;

open Distribution;

let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};

/* The faces 1..6; the upper bound passed to [List.range] is exclusive. */
let one_through_six = List.range(1, 7);

/* A die whose faces all carry the same probability. */
let die: t(int) = one_through_six |> uniform;

If we want to print our die as a string, we can use to_string:

(* Print one "face: probability" line per outcome of the die. *)
let () = print_endline (to_string string_of_int die)
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

let bind = (t, ~f) => map(t, ~f) |> join;

open Distribution;

let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};

let one_through_six = List.range(1, 7);

let die: t(int) = uniform(one_through_six);

/* Print one "face: probability" line per outcome of the die. */
let () = print_endline(to_string(string_of_int, die));

What's the probability our die is >= 5?

(* Collapse the die to the yes/no question "is the face >= 5?" and print the
   probability of each answer. *)
let () =
  let at_least_five = map die ~f:(fun face -> face >= 5) in
  print_endline (to_string string_of_bool at_least_five)
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

let bind = (t, ~f) => map(t, ~f) |> join;

open Distribution;

let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};

let one_through_six = List.range(1, 7);

let die: t(int) = uniform(one_through_six);

/* Collapse the die to the yes/no question "is the face >= 5?" and print the
   probability of each answer. */
let () = {
  let at_least_five = map(die, ~f=face => face >= 5);
  print_endline(to_string(string_of_bool, at_least_five));
};

Too easy. I'm not impressed.

(* The distribution of ordered pairs obtained by drawing from [die] twice;
   the second draw does not depend on the first. *)
let two_dice =
  bind die ~f:(fun first ->
    bind die ~f:(fun second ->
      return (first, second)))
;;
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

let bind = (t, ~f) => map(t, ~f) |> join;

open Distribution;

let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};

let one_through_six = List.range(1, 7);

let die: t(int) = uniform(one_through_six);

/* The distribution of ordered pairs obtained by drawing from [die] twice;
   the second draw does not depend on the first. */
let two_dice =
  bind(die, ~f=first => bind(die, ~f=second => return((first, second))));

What's the probability that the sum of two dice is >= 10?

(* Reduce each pair to its sum, reduce each sum to "is it >= 10?", and print
   the probability of each answer. The two [map] steps are kept separate. *)
let () =
  let totals = map two_dice ~f:(fun (first, second) -> first + second) in
  let at_least_ten = map totals ~f:(fun total -> total >= 10) in
  print_endline (to_string string_of_bool at_least_ten)
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

let bind = (t, ~f) => map(t, ~f) |> join;

open Distribution;

let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};

let one_through_six = List.range(1, 7);

let die: t(int) = uniform(one_through_six);

let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));

/* Reduce each pair to its sum, reduce each sum to "is it >= 10?", and print
   the probability of each answer. The two [map] steps are kept separate. */
let () = {
  let totals = map(two_dice, ~f=((first, second)) => first + second);
  let at_least_ten = map(totals, ~f=total => total >= 10);
  print_endline(to_string(string_of_bool, at_least_ten));
};

Finally something that you might not be able to immediately solve in your head. But let's find something that you definitely can't solve in your head.

## The Newton–Pepys problem

In 1693 Samuel Pepys and Isaac Newton corresponded over a problem posed by Pepys in relation to a wager he planned to make. The problem was:

Which of the following three propositions has the greatest chance of success?

A. Six fair dice are tossed independently and at least one “6” appears.

B. Twelve fair dice are tossed independently and at least two “6”s appear.

C. Eighteen fair dice are tossed independently and at least three “6”s appear.

(* Distribution over "did this one die show a 6?". *)
let single_die_is_six : bool t = map die ~f:(fun face -> face = 6)

(* Distribution over how many 6s one die shows: 1 if it is a six, 0 if not. *)
let number_of_sixes_in_a_single_die : int t =
  map single_die_is_six ~f:(fun is_six -> if is_six then 1 else 0)

(* Distribution over the number of 6s among six dice: six draws of the
   single-die count, added together. *)
let num_sixes_in_six_dice : int t =
  let one_die = number_of_sixes_in_a_single_die in
  bind one_die ~f:(fun s1 ->
    bind one_die ~f:(fun s2 ->
      bind one_die ~f:(fun s3 ->
        bind one_die ~f:(fun s4 ->
          bind one_die ~f:(fun s5 ->
            bind one_die ~f:(fun s6 ->
              return (s1 + s2 + s3 + s4 + s5 + s6)))))))

(* Twelve dice are two batches of six: add the two six-dice counts. *)
let num_sixes_in_twelve_dice =
  bind num_sixes_in_six_dice ~f:(fun first_half ->
    bind num_sixes_in_six_dice ~f:(fun second_half ->
      return (first_half + second_half)))

(* Eighteen dice are a batch of twelve plus a batch of six. *)
let num_sixes_in_eighteen_dice =
  bind num_sixes_in_twelve_dice ~f:(fun from_twelve ->
    bind num_sixes_in_six_dice ~f:(fun from_six ->
      return (from_twelve + from_six)))
;;
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

let bind = (t, ~f) => map(t, ~f) |> join;

open Distribution;

let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};

let one_through_six = List.range(1, 7);

let die: t(int) = uniform(one_through_six);

/* Distribution over "did this one die show a 6?". */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);

/* Distribution over how many 6s one die shows: 1 if it is a six, 0 if not. */
let number_of_sixes_in_a_single_die: t(int) =
  map(single_die_is_six, ~f=is_six => is_six ? 1 : 0);

/* Distribution over the number of 6s among six dice: six draws of the
   single-die count, added together. */
let num_sixes_in_six_dice: t(int) = {
  let one_die = number_of_sixes_in_a_single_die;
  bind(one_die, ~f=s1 =>
    bind(one_die, ~f=s2 =>
      bind(one_die, ~f=s3 =>
        bind(one_die, ~f=s4 =>
          bind(one_die, ~f=s5 =>
            bind(one_die, ~f=s6 => return(s1 + s2 + s3 + s4 + s5 + s6))
          )
        )
      )
    )
  );
};

/* Twelve dice are two batches of six: add the two six-dice counts. */
let num_sixes_in_twelve_dice =
  bind(num_sixes_in_six_dice, ~f=first_half =>
    bind(num_sixes_in_six_dice, ~f=second_half =>
      return(first_half + second_half)
    )
  );

/* Eighteen dice are a batch of twelve plus a batch of six. */
let num_sixes_in_eighteen_dice =
  bind(num_sixes_in_twelve_dice, ~f=from_twelve =>
    bind(num_sixes_in_six_dice, ~f=from_six =>
      return(from_twelve + from_six)
    )
  );

Choice A: Six fair dice are tossed independently and at least one “6” appears.

(* Choice A: probability that six dice show at least one 6. *)
let () =
  let at_least_one = map num_sixes_in_six_dice ~f:(fun sixes -> sixes >= 1) in
  print_endline (to_string string_of_bool at_least_one)
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};

module List = {
include List;
let hd_exn = hd;
let map = (t, ~f) => map(f, t);
let bind = (t, ~f) => map(t, ~f) |> concat;
let filter = (t, ~f) => filter(f, t);
let count = (t, ~f) => filter(t, ~f) |> List.length;
let rec range = (start, stop) =>
if (start == stop) {
[];
} else {
[start, ...range(start + 1, stop)];
};
let group = (t, ~break) => {
let rec loop = (acc, curr_group, last, rest) =>
switch (rest) {
| [] => [curr_group, ...acc]
| [next, ...rest] =>
switch (last) {
| None => loop(acc, [next, ...curr_group], Some(next), rest)
| Some(last) =>
if (break(last, next)) {
loop([curr_group, ...acc], [next], Some(next), rest);
} else {
loop(acc, [next, ...curr_group], Some(next), rest);
}
}
};
loop([], [], None, t) |> rev;
};
let sum =
(type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};

module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};

module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

let bind = (t, ~f) => map(t, ~f) |> join;

open Distribution;

let (>>|) = (t, f) => map(t, ~f);

let (>>=) = (t, f) => bind(t, ~f);

let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};

let one_through_six = List.range(1, 7);

let die: t(int) = uniform(one_through_six);

/* The probability distribution representing whether a single die comes up 6. */
let single_die_is_six: t(bool) = die >>| (d => d == 6);

/* The probability distribution representing the number of 6s in a single die. */
let number_of_sixes_in_a_single_die: t(int) =
single_die_is_six
>>| (
fun
| false => 0
| true => 1
);

/* The probability distribution describing how many sixes come up in 6 dice. */
let num_sixes_in_six_dice: t(int) =
number_of_sixes_in_a_single_die
>>= (
n1 =>
number_of_sixes_in_a_single_die
>>= (
n2 =>
number_of_sixes_in_a_single_die
>>= (
n3 =>
number_of_sixes_in_a_single_die
>>= (
n4 =>
number_of_sixes_in_a_single_die
>>= (
n5 =>
number_of_sixes_in_a_single_die
>>= (n6 => return(n1 + n2 + n3 + n4 + n5 + n6))
)
)
)
)
);

let num_sixes_in_twelve_dice =
num_sixes_in_six_dice
>>= (
first_six =>
num_sixes_in_six_dice >>= (second_six => return(first_six + second_six))
);

let num_sixes_in_eighteen_dice =
num_sixes_in_twelve_dice
>>= (
first_twelve =>
num_sixes_in_six_dice
>>= (third_six => return(first_twelve + third_six))
);

/* Choice A: probability that six dice show at least one 6. */
let () = {
  let at_least_one = map(num_sixes_in_six_dice, ~f=sixes => sixes >= 1);
  print_endline(to_string(string_of_bool, at_least_one));
};

Choice B: Twelve fair dice are tossed independently and at least two “6”s appear.

(* Choice B: probability that twelve dice show at least two 6s. *)
let () =
  let at_least_two =
    map num_sixes_in_twelve_dice ~f:(fun sixes -> sixes >= 2)
  in
  print_endline (to_string string_of_bool at_least_two)
module Monad = {
  /* The monad interface used throughout the post, referred to as [Monad.S]:
     [return] wraps a value, [map] transforms the wrapped values, and [join]
     flattens one level of nesting. */
  module type S = {
    type t('a);
    let return: 'a => t('a);
    let map: (t('a), ~f: 'a => 'b) => t('b);
    let join: t(t('a)) => t('a);
  };
};

/* Interface for a type with a zero, an addition and a subtraction.
   List.sum takes a first-class module of this type, starting its fold from
   `zero` and combining elements with `(+)`. */
module type Commutative_group = {
/* The type of the group's elements. */
type t;
/* Starting value for sums; presumably the identity of (+), though the
   signature itself cannot enforce that. */
let zero: t;
/* Combine two elements. */
let (+): (t, t) => t;
/* Presumably the inverse of (+); not used by List.sum. */
let (-): (t, t) => t;
};

/* The standard `List`, extended with labelled, list-first helpers. */
module List = {
  include List;

  /* First element of a list; an alias of the included `hd`. */
  let hd_exn = hd;

  /* Labelled, list-first versions of the included `map` and `filter`. */
  let map = (t, ~f) => map(f, t);
  /* Map `f` over the list and flatten the resulting lists by one level. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);

  /* Number of elements for which `f` holds. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* Integers from `start` (inclusive) to `stop` (exclusive), ascending.
     Returns [] whenever `start >= stop`, so a reversed pair of bounds no
     longer recurses without end. The list is built from the top end with
     an accumulator, so long ranges do not grow the stack. */
  let range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      let rec loop = (i, acc) => {
        let acc = [i, ...acc];
        if (i == start) {
          acc;
        } else {
          loop(i - 1, acc);
        };
      };
      loop(stop - 1, []);
    };

  /* Split `t` into runs of consecutive elements, starting a new run each
     time `break(previous, next)` is true. Runs come back in input order;
     inside a run the elements are in reverse input order, because each
     element is consed onto the front of the current run. An empty input
     gives [] rather than a single empty run. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [curr_group, ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([curr_group, ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    /* seeding with the first element means a run is never empty */
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };

  /* Sum `f(x)` over the list, using the `zero` and `(+)` of the supplied
     first-class module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. The module gives them a name and
   supplies `zero`, `(+)` and `(-)` so it can be passed where a
   Commutative_group is expected. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Concrete representation of a probability distribution together with its
   monad operations and a printer. */
module Distribution_implementation = {
  /* A distribution is a list of (probability, outcome) pairs. */
  type t('a) = list((Probability.t, 'a));

  /* The distribution that yields `a` with probability 1. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Merge entries whose outcomes compare equal, adding their
     probabilities; the result is ordered by outcome. An empty distribution
     is returned as-is, so `List.hd` is never applied to an empty group no
     matter what `List.group` returns for an empty input. */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* the input is non-empty here, so every group is non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };

  /* Apply `f` to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: every inner probability is
     multiplied by the probability of the distribution that contains it,
     and equal outcomes are then merged. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one line per entry: the outcome (via `f`), a tab, then the
     probability printed with `decimals` digits (3 by default). */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      /* prefix with "~" when the printed digits do not read back as p */
      if (float_of_string(s) == p) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

/* Public view of the distribution monad: the implementation module is
   sealed with a signature that exposes the representation type, a printer,
   and the operations declared by Monad.S. */
module Distribution: {
/* A distribution is a list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution as text. Takes an optional number of decimals, a
   function that renders one outcome, and the distribution itself. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Pulls in the items of Monad.S, with its type t replaced by the t above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: map `f` over the distribution, then flatten the nested
   result with `join`. */
let bind = (t, ~f) => join(map(t, ~f));

open Distribution;

/* Infix form of `map`: `t >>| f` is `map(t, ~f)`. */
let (>>|) = (dist, fn) => map(dist, ~f=fn);

/* Infix form of `bind`: `t >>= f` is `bind(t, ~f)`. */
let (>>=) = (dist, fn) => bind(dist, ~f=fn);

/* Build a distribution that gives every element of the list the same
   weight, 1 / (length of the list). Repeated elements stay as separate
   entries. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float_of_int(List.length(outcomes));
    outcomes |> List.map(~f=outcome => (weight, outcome));
  };

/* The faces of the die: List.range stops just before its second argument,
   so this is 1 to 6. */
let one_through_six: list(int) = List.range(1, 7);

/* A die whose faces all carry the same probability. */
let die: t(int) = one_through_six |> uniform;

/* Distribution over "did this one die show a 6?". */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);

/* Distribution of how many 6s one die contributes: 1 if it shows a 6,
   otherwise 0. */
let number_of_sixes_in_a_single_die: t(int) =
  map(single_die_is_six, ~f=is_six =>
    if (is_six) {
      1;
    } else {
      0;
    }
  );

/* Distribution of the total number of 6s across six dice. Each nested
   `bind` draws one more die; the innermost step adds the six counts. */
let num_sixes_in_six_dice: t(int) = {
  let roll = number_of_sixes_in_a_single_die;
  bind(roll, ~f=first =>
    bind(roll, ~f=second =>
      bind(roll, ~f=third =>
        bind(roll, ~f=fourth =>
          bind(roll, ~f=fifth =>
            bind(roll, ~f=sixth =>
              return(first + second + third + fourth + fifth + sixth)
            )
          )
        )
      )
    )
  );
};

/* Number of 6s in twelve dice: the sum of two draws from the six-dice
   distribution. */
let num_sixes_in_twelve_dice =
  bind(num_sixes_in_six_dice, ~f=first_batch =>
    bind(num_sixes_in_six_dice, ~f=second_batch =>
      return(first_batch + second_batch)
    )
  );

/* Number of 6s in eighteen dice: a draw from the twelve-dice distribution
   plus a draw from the six-dice distribution. */
let num_sixes_in_eighteen_dice =
  bind(num_sixes_in_twelve_dice, ~f=twelve_batch =>
    bind(num_sixes_in_six_dice, ~f=six_batch =>
      return(twelve_batch + six_batch)
    )
  );

/* Print the distribution of "twelve dice show at least two 6s". */
print_endline(
  to_string(string_of_bool, map(num_sixes_in_twelve_dice, ~f=n => n >= 2)),
);

Choice C: Eighteen fair dice are tossed independently and at least three “6”s appear.

(* Print the distribution of "eighteen dice show at least three 6s". *)
print_endline
  (to_string string_of_bool
     (map num_sixes_in_eighteen_dice ~f:(fun n -> n >= 3)))
/* Namespace for the monad interface; the rest of the page refers to the
   signature as `Monad.S`. The enclosing module header was missing, which
   left the final `};` unmatched and `Monad.S` unbound. */
module Monad = {
  /* Signature of a monad over a one-parameter container type. */
  module type S = {
    /* The container holding values of type 'a. */
    type t('a);
    /* Wrap a plain value in the container. */
    let return: 'a => t('a);
    /* Transform the contained values with `f`. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapse two layers of container into one. */
    let join: t(t('a)) => t('a);
  };
};

/* Interface for a type with a zero, an addition and a subtraction.
   List.sum takes a first-class module of this type, starting its fold from
   `zero` and combining elements with `(+)`. */
module type Commutative_group = {
/* The type of the group's elements. */
type t;
/* Starting value for sums; presumably the identity of (+), though the
   signature itself cannot enforce that. */
let zero: t;
/* Combine two elements. */
let (+): (t, t) => t;
/* Presumably the inverse of (+); not used by List.sum. */
let (-): (t, t) => t;
};

/* The standard `List`, extended with labelled, list-first helpers. */
module List = {
  include List;

  /* First element of a list; an alias of the included `hd`. */
  let hd_exn = hd;

  /* Labelled, list-first versions of the included `map` and `filter`. */
  let map = (t, ~f) => map(f, t);
  /* Map `f` over the list and flatten the resulting lists by one level. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);

  /* Number of elements for which `f` holds. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* Integers from `start` (inclusive) to `stop` (exclusive), ascending.
     Returns [] whenever `start >= stop`, so a reversed pair of bounds no
     longer recurses without end. The list is built from the top end with
     an accumulator, so long ranges do not grow the stack. */
  let range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      let rec loop = (i, acc) => {
        let acc = [i, ...acc];
        if (i == start) {
          acc;
        } else {
          loop(i - 1, acc);
        };
      };
      loop(stop - 1, []);
    };

  /* Split `t` into runs of consecutive elements, starting a new run each
     time `break(previous, next)` is true. Runs come back in input order;
     inside a run the elements are in reverse input order, because each
     element is consed onto the front of the current run. An empty input
     gives [] rather than a single empty run. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [curr_group, ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([curr_group, ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    /* seeding with the first element means a run is never empty */
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };

  /* Sum `f(x)` over the list, using the `zero` and `(+)` of the supplied
     first-class module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. The module gives them a name and
   supplies `zero`, `(+)` and `(-)` so it can be passed where a
   Commutative_group is expected. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Concrete representation of a probability distribution together with its
   monad operations and a printer. */
module Distribution_implementation = {
  /* A distribution is a list of (probability, outcome) pairs. */
  type t('a) = list((Probability.t, 'a));

  /* The distribution that yields `a` with probability 1. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Merge entries whose outcomes compare equal, adding their
     probabilities; the result is ordered by outcome. An empty distribution
     is returned as-is, so `List.hd` is never applied to an empty group no
     matter what `List.group` returns for an empty input. */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* the input is non-empty here, so every group is non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };

  /* Apply `f` to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: every inner probability is
     multiplied by the probability of the distribution that contains it,
     and equal outcomes are then merged. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one line per entry: the outcome (via `f`), a tab, then the
     probability printed with `decimals` digits (3 by default). */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      /* prefix with "~" when the printed digits do not read back as p */
      if (float_of_string(s) == p) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

/* Public view of the distribution monad: the implementation module is
   sealed with a signature that exposes the representation type, a printer,
   and the operations declared by Monad.S. */
module Distribution: {
/* A distribution is a list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution as text. Takes an optional number of decimals, a
   function that renders one outcome, and the distribution itself. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Pulls in the items of Monad.S, with its type t replaced by the t above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: map `f` over the distribution, then flatten the nested
   result with `join`. */
let bind = (t, ~f) => join(map(t, ~f));

open Distribution;

/* Infix form of `map`: `t >>| f` is `map(t, ~f)`. */
let (>>|) = (dist, fn) => map(dist, ~f=fn);

/* Infix form of `bind`: `t >>= f` is `bind(t, ~f)`. */
let (>>=) = (dist, fn) => bind(dist, ~f=fn);

/* Build a distribution that gives every element of the list the same
   weight, 1 / (length of the list). Repeated elements stay as separate
   entries. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float_of_int(List.length(outcomes));
    outcomes |> List.map(~f=outcome => (weight, outcome));
  };

/* The faces of the die: List.range stops just before its second argument,
   so this is 1 to 6. */
let one_through_six: list(int) = List.range(1, 7);

/* A die whose faces all carry the same probability. */
let die: t(int) = one_through_six |> uniform;

/* Distribution over "did this one die show a 6?". */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);

/* Distribution of how many 6s one die contributes: 1 if it shows a 6,
   otherwise 0. */
let number_of_sixes_in_a_single_die: t(int) =
  map(single_die_is_six, ~f=is_six =>
    if (is_six) {
      1;
    } else {
      0;
    }
  );

/* Distribution of the total number of 6s across six dice. Each nested
   `bind` draws one more die; the innermost step adds the six counts. */
let num_sixes_in_six_dice: t(int) = {
  let roll = number_of_sixes_in_a_single_die;
  bind(roll, ~f=first =>
    bind(roll, ~f=second =>
      bind(roll, ~f=third =>
        bind(roll, ~f=fourth =>
          bind(roll, ~f=fifth =>
            bind(roll, ~f=sixth =>
              return(first + second + third + fourth + fifth + sixth)
            )
          )
        )
      )
    )
  );
};

/* Number of 6s in twelve dice: the sum of two draws from the six-dice
   distribution. */
let num_sixes_in_twelve_dice =
  bind(num_sixes_in_six_dice, ~f=first_batch =>
    bind(num_sixes_in_six_dice, ~f=second_batch =>
      return(first_batch + second_batch)
    )
  );

/* Number of 6s in eighteen dice: a draw from the twelve-dice distribution
   plus a draw from the six-dice distribution. */
let num_sixes_in_eighteen_dice =
  bind(num_sixes_in_twelve_dice, ~f=twelve_batch =>
    bind(num_sixes_in_six_dice, ~f=six_batch =>
      return(twelve_batch + six_batch)
    )
  );

/* Print the distribution of "eighteen dice show at least three 6s". */
print_endline(
  to_string(string_of_bool, map(num_sixes_in_eighteen_dice, ~f=n => n >= 3)),
);

As you can see here, that's correct!

Hopefully at this point you can see that we have some real power here. Using bind, map, return (and join, although not explicitly) we can start to answer some very real questions.

Let's tackle a more recent question, one that stumped quite a few mathematicians, including Paul Erdős!

## The Monty Hall Problem.

Here's Wikipedia's formulation:

Suppose you're on a game show, and you're given the choice of three doors: Behind one door is a car; behind the others, goats. You pick a door, say No. 1, and the host, who knows what's behind the doors, opens another door, say No. 3, which has a goat. He then says to you, "Do you want to pick door No. 2?" Is it to your advantage to switch your choice?

If you haven't encountered this problem before, I highly recommend taking a minute to come up with an answer yourself before continuing.

First off, let's define another utility function that turns out to be very handy in probability questions. `condition` takes a distribution and conditions it on a particular event happening.

(* Rescale the probabilities in [t] by 1 / (their sum), so that they add up
   to one. If the probabilities sum to zero the scale factor is not finite. *)
let renormalize t =
  let total = List.sum (module Probability) t ~f:fst in
  let factor = 1. /. total in
  List.map t ~f:(fun (weight, outcome) -> (weight *. factor, outcome))
;;

(* Condition a distribution on an event: keep only the outcomes for which
   [on] holds, then renormalize what is left. *)
let condition : 'a t -> on:('a -> bool) -> 'a t =
 fun t ~on ->
  let kept = List.filter t ~f:(fun (_, outcome) -> on outcome) in
  renormalize kept
;;
/* Namespace for the monad interface; the rest of the page refers to the
   signature as `Monad.S`. The enclosing module header was missing, which
   left the final `};` unmatched and `Monad.S` unbound. */
module Monad = {
  /* Signature of a monad over a one-parameter container type. */
  module type S = {
    /* The container holding values of type 'a. */
    type t('a);
    /* Wrap a plain value in the container. */
    let return: 'a => t('a);
    /* Transform the contained values with `f`. */
    let map: (t('a), ~f: 'a => 'b) => t('b);
    /* Collapse two layers of container into one. */
    let join: t(t('a)) => t('a);
  };
};

/* Interface for a type with a zero, an addition and a subtraction.
   List.sum takes a first-class module of this type, starting its fold from
   `zero` and combining elements with `(+)`. */
module type Commutative_group = {
/* The type of the group's elements. */
type t;
/* Starting value for sums; presumably the identity of (+), though the
   signature itself cannot enforce that. */
let zero: t;
/* Combine two elements. */
let (+): (t, t) => t;
/* Presumably the inverse of (+); not used by List.sum. */
let (-): (t, t) => t;
};

/* The standard `List`, extended with labelled, list-first helpers. */
module List = {
  include List;

  /* First element of a list; an alias of the included `hd`. */
  let hd_exn = hd;

  /* Labelled, list-first versions of the included `map` and `filter`. */
  let map = (t, ~f) => map(f, t);
  /* Map `f` over the list and flatten the resulting lists by one level. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);

  /* Number of elements for which `f` holds. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* Integers from `start` (inclusive) to `stop` (exclusive), ascending.
     Returns [] whenever `start >= stop`, so a reversed pair of bounds no
     longer recurses without end. The list is built from the top end with
     an accumulator, so long ranges do not grow the stack. */
  let range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      let rec loop = (i, acc) => {
        let acc = [i, ...acc];
        if (i == start) {
          acc;
        } else {
          loop(i - 1, acc);
        };
      };
      loop(stop - 1, []);
    };

  /* Split `t` into runs of consecutive elements, starting a new run each
     time `break(previous, next)` is true. Runs come back in input order;
     inside a run the elements are in reverse input order, because each
     element is consed onto the front of the current run. An empty input
     gives [] rather than a single empty run. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [curr_group, ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([curr_group, ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    /* seeding with the first element means a run is never empty */
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };

  /* Sum `f(x)` over the list, using the `zero` and `(+)` of the supplied
     first-class module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};

/* Probabilities are plain floats. The module gives them a name and
   supplies `zero`, `(+)` and `(-)` so it can be passed where a
   Commutative_group is expected. */
module Probability = {
  type t = float;
  let zero: t = 0.;
  let (+) = (x: t, y: t): t => x +. y;
  let (-) = (x: t, y: t): t => x -. y;
};

/* Concrete representation of a probability distribution together with its
   monad operations and a printer. */
module Distribution_implementation = {
  /* A distribution is a list of (probability, outcome) pairs. */
  type t('a) = list((Probability.t, 'a));

  /* The distribution that yields `a` with probability 1. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Merge entries whose outcomes compare equal, adding their
     probabilities; the result is ordered by outcome. An empty distribution
     is returned as-is, so `List.hd` is never applied to an empty group no
     matter what `List.group` returns for an empty input. */
  let group_states: t('a) => t('a) =
    t =>
      switch (t) {
      | [] => []
      | _ =>
        List.sort(compare_states, t)
        |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
        |> List.map(~f=group => {
             /* the input is non-empty here, so every group is non-empty */
             let (_, a) = List.hd(group);
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             (p, a);
           })
      };

  /* Apply `f` to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: every inner probability is
     multiplied by the probability of the distribution that contains it,
     and equal outcomes are then merged. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one line per entry: the outcome (via `f`), a tab, then the
     probability printed with `decimals` digits (3 by default). */
  let to_string = (~decimals=3, f, t) => {
    let format_probability = p => {
      let s = Printf.sprintf("%0.*f", decimals, p);
      /* prefix with "~" when the printed digits do not read back as p */
      if (float_of_string(s) == p) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ format_probability(p) ++ " "
      ),
    );
  };
};

/* Public view of the distribution monad: the implementation module is
   sealed with a signature that exposes the representation type, a printer,
   and the operations declared by Monad.S. */
module Distribution: {
/* A distribution is a list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution as text. Takes an optional number of decimals, a
   function that renders one outcome, and the distribution itself. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Pulls in the items of Monad.S, with its type t replaced by the t above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;

open Distribution;

/* Monadic bind: map `f` over the distribution, then flatten the nested
   result with `join`. */
let bind = (t, ~f) => join(map(t, ~f));

open Distribution;

/* Infix form of `map`: `t >>| f` is `map(t, ~f)`. */
let (>>|) = (dist, fn) => map(dist, ~f=fn);

/* Infix form of `bind`: `t >>= f` is `bind(t, ~f)`. */
let (>>=) = (dist, fn) => bind(dist, ~f=fn);

/* Build a distribution that gives every element of the list the same
   weight, 1 / (length of the list). Repeated elements stay as separate
   entries. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float_of_int(List.length(outcomes));
    outcomes |> List.map(~f=outcome => (weight, outcome));
  };

/* The faces of the die: List.range stops just before its second argument,
   so this is 1 to 6. */
let one_through_six: list(int) = List.range(1, 7);

/* A die whose faces all carry the same probability. */
let die: t(int) = one_through_six |> uniform;

/* Distribution over (first, second) pairs obtained by drawing from `die`
   twice. */
let two_dice =
  bind(die, ~f=first => bind(die, ~f=second => return((first, second))));

/* Rescale the probabilities in `t` by 1 / (their sum), so that they add up
   to one. If the probabilities sum to zero the scale factor is not
   finite. */
let renormalize = t => {
  let total = List.sum((module Probability), t, ~f=fst);
  let factor = 1. /. total;
  t |> List.map(~f=((weight, outcome)) => (weight *. factor, outcome));
};

/* Condition a distribution on an event: keep only the outcomes for which
   `on` holds, then renormalize what is left. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, outcome)) => on(outcome));
    renormalize(kept);
  };