The Distribution Monad
_
.
This post's purpose is two-fold. First, it's a hat-tip to this beautiful paper. Second, it can serve as a very concrete example of a monad. An idea (that others have articulated better than I can) is that when trying to teach an abstract concept, it's better to start with concrete examples and gradually move towards abstraction and generalization rather than the other way around.
One thing to point out up-front is that the code samples below are written in OCaml, and are editable/running (thanks to redoc), so if you feel the urge to get your hands dirty and change some code in order to see the effect - go for it!
If you do want to learn more about monads, I would recommend the following posts:
With that in mind, let's immediately break the rule of starting concrete and lay out the formal definition of a monad. If it means nothing or very little to you at this point in the post, don't worry. I will explain what return
, map
, and join
mean intuitively, but only later, in the context of our concrete example: the probability distribution.
(* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. *)
module Monad = struct
module type S = sig
(* A monadic value carrying results of type 'a. *)
type 'a t
(* Wrap a plain value in the monad. *)
val return : 'a -> 'a t
(* Apply f to the value(s) inside the monad. *)
val map : 'a t -> f:('a -> 'b) -> 'b t
(* Collapse two nested monadic layers into one. *)
val join : 'a t t -> 'a t
end
end
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
Let's define a probability type. This is unnecessary, but it allows me to refer to probabilities as Probability.t
s later on, rather than just float
s, which hopefully adds a small bit of readability.
(* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. *)
module Probability = struct
type t = float
let zero = 0.
(* Float addition and subtraction under the usual operator names. *)
let (+) = (+.)
let (-) = (-.)
end
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
(* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. *)
module Distribution_implementation = struct
  type 'a t = (Probability.t * 'a) list

  (* The distribution in which [a] is certain. *)
  let return a = [ 1., a ]

  (* compare just the 'a states, not the probabilities *)
  let compare_states (_, a1) (_, a2) =
    compare a1 a2

  (* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. *)
  let group_states : 'a t -> 'a t = function
    | [] -> []
    | t ->
      List.sort compare_states t
      |> List.group ~break:(fun t1 t2 -> compare_states t1 t2 <> 0)
      |> List.map ~f:(fun group ->
          (* each group is going to be non-empty *)
          let (_, a) = List.hd group in
          let p = List.fold_left (fun acc (p, _) -> acc +. p) 0. group in
          p, a)

  (* Apply f to every outcome, then merge outcomes that became equal. *)
  let map t ~f =
    List.map t ~f:(fun (p, a) -> p, f a)
    |> group_states

  (* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. *)
  let join t =
    List.bind t ~f:(fun (p, t) -> List.map t ~f:(fun (p1, a) -> p *. p1, a))
    |> group_states

  (* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with [decimals] digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. *)
  let to_string ?(decimals = 3) f t =
    let string_of_float x =
      let s = Printf.sprintf "%0.*f" decimals x in
      let is_exact = float_of_string s = x in
      if is_exact then s else "~" ^ s
    in
    String.concat "\n"
      (List.map t ~f:(fun (p, a) -> f a ^ ":\t" ^ string_of_float p ^ " "))
end
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));

  /* The distribution in which a is certain. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* each group is going to be non-empty */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });

  /* Apply f to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with ~decimals digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
Now back to the good part - let's define the main type that we'll be dealing with: a Distribution.t
.
(* Public face of the distribution monad: the list representation stays
   visible, and only to_string plus the Monad.S operations are exported. *)
module Distribution : sig
(* A list of (probability, outcome) pairs. *)
type 'a t = (Probability.t * 'a) list
(* Render a distribution, given a printer for outcomes; decimals is optional. *)
val to_string : ?decimals:int -> ('a -> string) -> 'a t -> string
(* Brings in return, map and join, specialised to the 'a t above. *)
include Monad.S with type 'a t := 'a t
end = Distribution_implementation
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));

  /* The distribution in which a is certain. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* each group is going to be non-empty */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });

  /* Apply f to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with ~decimals digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public face of the distribution monad: the list representation stays
   visible, and only to_string plus the Monad.S operations are exported. */
module Distribution: {
/* A list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution, given a printer for outcomes; ~decimals is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Brings in return, map and join, specialised to the t('a) above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
What is Distribution_implementation
and where is it defined? Well, I hid it. If you really want to see it, fine - click here to show it, but it's a big chunk of code that's not super interesting or important. What is important is the signature which Distribution
exposes. Let's explain each line:
type 'a t = (Probability.t * 'a) list
: This says that a Distribution.t
of 'a
s (read: possible outcomes) is really a list of 'a
s with an associated probability. Again, let's make this concrete. Let's say our distribution represents the weather tomorrow, and to keep it simple, let's say there are two possible outcomes: Rain
or Shine
. If we wanted to create a Distribution.t
that represents an 80% chance of Shine
and a 20% chance of Rain
, ... actually, let's just do that:
(* The two possible weather outcomes. *)
type rain_or_shine = Rain | Shine;;
(* 80% chance of Shine, 20% chance of Rain. *)
let tomorrow's_weather : rain_or_shine Distribution.t =
[ 0.8, Shine; 0.2, Rain ]
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));

  /* The distribution in which a is certain. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* each group is going to be non-empty */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });

  /* Apply f to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with ~decimals digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public face of the distribution monad: the list representation stays
   visible, and only to_string plus the Monad.S operations are exported. */
module Distribution: {
/* A list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution, given a printer for outcomes; ~decimals is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Brings in return, map and join, specialised to the t('a) above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
/* The two possible weather outcomes. */
type rain_or_shine =
| Rain
| Shine;
/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];
You can see that the type of tomorrow's_weather
is a rain_or_shine Distribution.t
, but really it's just a list of possible outcomes along with their associated probabilities. And, just to explicitly connect the dots, the 'a
in this case is rain_or_shine
.
val to_string : ?decimals:int -> ('a -> string) -> 'a t -> string
: This line probably needs the least explaining. It says that there exists a function Distribution.to_string
that, given a function from 'a -> string
can produce a string from a 'a
distribution. Although probably unnecessary, let's take it for a spin.
(* Printer for the weather outcomes. *)
let rain_or_shine_to_string = function
| Rain -> "Rain"
| Shine -> "Shine";;
(* Print the weather distribution, one outcome per line. *)
Distribution.to_string rain_or_shine_to_string tomorrow's_weather
|> print_endline;;
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));

  /* The distribution in which a is certain. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* each group is going to be non-empty */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });

  /* Apply f to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with ~decimals digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public face of the distribution monad: the list representation stays
   visible, and only to_string plus the Monad.S operations are exported. */
module Distribution: {
/* A list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution, given a printer for outcomes; ~decimals is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Brings in return, map and join, specialised to the t('a) above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
/* The two possible weather outcomes. */
type rain_or_shine =
| Rain
| Shine;
/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];
/* Printer for the weather outcomes. */
let rain_or_shine_to_string =
fun
| Rain => "Rain"
| Shine => "Shine";
/* Print the weather distribution, one outcome per line. */
Distribution.to_string(rain_or_shine_to_string, tomorrow's_weather)
|> print_endline;
Simple enough.
include Monad.S with type 'a t := 'a t
: Ok, there's kind of a lot to unpack in this one line. include Monad.S
really just means paste the signature Monad.S
here, so let's do that and explain each of those lines:
(* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. *)
module Monad = struct
module type S = sig
(* A monadic value carrying results of type 'a. *)
type 'a t
(* Wrap a plain value in the monad. *)
val return : 'a -> 'a t
(* Apply f to the value(s) inside the monad. *)
val map : 'a t -> f:('a -> 'b) -> 'b t
(* Collapse two nested monadic layers into one. *)
val join : 'a t t -> 'a t
end
end
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
Great, here's where we can explain what these very abstract functions mean in the context of a distribution.
val return : 'a -> 'a t
: return
says give me an outcome, and I'll give you a distribution on that outcome. Continuing with our weather example, we could say Distribution.return Shine
. The probabilities of a distribution have to add up to 100%, so what's the only reasonable thing this could mean? Well, return
will create a "Distribution" where the outcome that you passed in happens with a 100% probability.
val map : 'a t -> f:('a -> 'b) -> 'b t
: map
says give me a distribution of outcomes of type 'a
and a way to map 'a
s to 'b
s, and I'll give you a distribution of 'b
s. Let's imagine a way we might use this function on tomorrow's_weather
. Let's say we're trying to figure out whether or not we can play soccer in the park tomorrow. Let's use map!
(* Soccer is possible exactly when the weather is Shine. *)
let can_play_soccer_in_park rain_or_shine =
match rain_or_shine with
| Rain -> false
| Shine -> true;;
(* Map the weather distribution through the function above to obtain a
   distribution over whether we can play. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t =
Distribution.map tomorrow's_weather ~f:can_play_soccer_in_park
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));

  /* The distribution in which a is certain. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* each group is going to be non-empty */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });

  /* Apply f to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with ~decimals digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public face of the distribution monad: the list representation stays
   visible, and only to_string plus the Monad.S operations are exported. */
module Distribution: {
/* A list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution, given a printer for outcomes; ~decimals is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Brings in return, map and join, specialised to the t('a) above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
/* The two possible weather outcomes. */
type rain_or_shine =
| Rain
| Shine;
/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];
/* Soccer is possible exactly when the weather is Shine. */
let can_play_soccer_in_park = rain_or_shine =>
switch (rain_or_shine) {
| Rain => false
| Shine => true
};
/* Map the weather distribution through the function above to obtain a
   distribution over whether we can play. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
Distribution.map(tomorrow's_weather, ~f=can_play_soccer_in_park);
So, we created a boolean distribution (which represents whether or not we can play soccer in the park tomorrow) by mapping an existing distribution (tomorrow's weather) with a function that maps weather to whether or not you can play soccer in the park. A bit of a mouthful, but conceptually pretty straightforward.
val join : 'a t t -> 'a t
: Saving the best for last. join
is what really differentiates a monad. In our case, it means that a distribution of distributions is still just a distribution. No wonder people find this confusing... Let's carry on with our strategy of making this more and more concrete until it's actually intelligible.
What if our function can_play_soccer_in_park
was a bit too simplistic. Even if it's sunny tomorrow (i.e. Shine
), there's still some chance, albeit small, that you can't play soccer in the park. What if the park is closed? What if other people are using the park all day long? So, let's account for this and update can_play_soccer_in_park
to return a boolean distribution, as opposed to a boolean.
(* Refined version: returns a distribution over booleans instead of a
   plain boolean. *)
let can_play_soccer_in_park rain_or_shine =
match rain_or_shine with
| Rain ->
(* we definitely can't play if it's raining *)
Distribution.return false
| Shine ->
(* let's say we can play with a 95% probability *)
[ 0.95, true; 0.05, false ]
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));

  /* The distribution in which a is certain. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* each group is going to be non-empty */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });

  /* Apply f to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with ~decimals digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public face of the distribution monad: the list representation stays
   visible, and only to_string plus the Monad.S operations are exported. */
module Distribution: {
/* A list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution, given a printer for outcomes; ~decimals is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Brings in return, map and join, specialised to the t('a) above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
/* The two possible weather outcomes. */
type rain_or_shine =
| Rain
| Shine;
/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];
/* Refined version: returns a distribution over booleans instead of a
   plain boolean. */
let can_play_soccer_in_park = rain_or_shine =>
switch (rain_or_shine) {
| Rain =>
/* we definitely can't play if it's raining */
Distribution.return(false)
| Shine =>
/* let's say we can play with a 95% probability */
[(0.95, true), (0.05, false)]
};
Ok, and let's do what we did before and map tomorrow's_weather
with our new and improved function.
(* NOTE: intentionally ill-typed. can_play_soccer_in_park now returns a
   bool Distribution.t, so map yields a bool Distribution.t Distribution.t
   and the annotation below is rejected by the type checker. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t =
Distribution.map tomorrow's_weather ~f:can_play_soccer_in_park
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));

  /* The distribution in which a is certain. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* each group is going to be non-empty */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });

  /* Apply f to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with ~decimals digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public face of the distribution monad: the list representation stays
   visible, and only to_string plus the Monad.S operations are exported. */
module Distribution: {
/* A list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution, given a printer for outcomes; ~decimals is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Brings in return, map and join, specialised to the t('a) above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
/* The two possible weather outcomes. */
type rain_or_shine =
| Rain
| Shine;
/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];
/* Refined version: returns a distribution over booleans instead of a
   plain boolean. */
let can_play_soccer_in_park = rain_or_shine =>
switch (rain_or_shine) {
| Rain =>
/* we definitely can't play if it's raining */
Distribution.return(false)
| Shine =>
/* let's say we can play with a 95% probability */
[(0.95, true), (0.05, false)]
};
/* NOTE: intentionally ill-typed. can_play_soccer_in_park returns a
   Distribution.t(bool), so map yields a distribution of distributions and
   the annotation below is rejected by the type checker. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
Distribution.map(tomorrow's_weather, ~f=can_play_soccer_in_park);
Wait - what happened? Why didn't that work like before? Well,
ignoring the confusing line numbers, it says that you expected
can_play_soccer_in_park_tomorrow
to be of type bool Distribution.t
but it's actually bool Distribution.t Distribution.t
, i.e. it's a
distribution of boolean distributions!
If that's still confusing, think about it this way. With 80% chance,
the weather tomorrow is Shine
, and if the weather is Shine
then
there's a 95% chance you can play soccer in the park and a 5% chance
you can't. So, you can see how there are two layers of probability
distributions. But, intuitively, we know we can collapse these into
just one. If I ask you "what's the chance you can play in the park
tomorrow?" - there's a single answer: 80% * 95% + 20% * 0% = 76%.
That's exactly what join does. It flattens two levels of a monad - a distribution in our case - into one level. Let's see it in action.
(* map produces a distribution of distributions; join flattens it back
   into a single bool distribution. *)
let can_play_soccer_in_park_tomorrow : bool Distribution.t =
Distribution.map tomorrow's_weather ~f:can_play_soccer_in_park
|> Distribution.join
;;
(* Print the flattened distribution. *)
Distribution.to_string string_of_bool can_play_soccer_in_park_tomorrow
|> print_endline;;
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* A type with a zero element, addition and subtraction. List.sum takes a
   first-class module of this signature to know how to add elements. */
module type Commutative_group = {
type t;
/* Starting value of a sum. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Stdlib List extended with labelled, list-first variants of a few
   functions plus the helpers (range, group, sum) used by the examples. */
module List = {
  include List;

  /* Head of a list; raises on the empty list, exactly like Stdlib's hd. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the Stdlib functions. */
  let map = (t, ~f) => map(f, t);
  /* Map every element to a list and flatten the results. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying ~f. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* range(start, stop) is [start, ..., stop - 1]. It is empty whenever
     start >= stop; testing only for equality would recurse forever when
     start > stop. The list is built back-to-front with an accumulator so
     stack use does not grow with the length. */
  let range = (start, stop) => {
    let rec loop = (acc, i) => {
      let acc = [i, ...acc];
      if (i == start) {
        acc;
      } else {
        loop(acc, i - 1);
      };
    };
    if (start >= stop) {
      [];
    } else {
      loop([], stop - 1);
    };
  };

  /* Split t into runs of consecutive elements, starting a new run whenever
     break(previous, next) is true. An empty list yields no groups, and
     both the groups and the elements inside each group keep their input
     order. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => rev([rev(curr_group), ...acc])
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest)
    };
  };

  /* Sum f(x) over the list, using the zero and (+) of the given
     Commutative_group module. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; the module supplies zero, (+) and (-)
   as the corresponding float operations. */
module Probability = {
type t = float;
let zero = 0.;
/* Float addition and subtraction under the usual operator names. */
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the distribution monad. A distribution is
   a list of (probability, outcome) pairs; map and join merge entries with
   equal outcomes by summing their probabilities. */
module Distribution_implementation = {
  type t('a) = list((Probability.t, 'a));

  /* The distribution in which a is certain. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Sort by outcome, then merge each run of equal outcomes into one entry
     whose probability is the sum over the run. The empty distribution is
     returned as-is so that List.hd is never applied to an empty group. */
  let group_states: t('a) => t('a) =
    fun
    | [] => []
    | t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.map(~f=group => {
           /* each group is going to be non-empty */
           let (_, a) = List.hd(group);
           let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
           (p, a);
         });

  /* Apply f to every outcome, then merge outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: each inner probability is
     scaled by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one "outcome:<tab>probability " line per entry. Probabilities
     are printed with ~decimals digits and prefixed with "~" when the
     printed text does not read back as exactly the same float. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public face of the distribution monad: the list representation stays
   visible, and only to_string plus the Monad.S operations are exported. */
module Distribution: {
/* A list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution, given a printer for outcomes; ~decimals is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* Brings in return, map and join, specialised to the t('a) above. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
/* The two possible weather outcomes. */
type rain_or_shine =
| Rain
| Shine;
/* 80% chance of Shine, 20% chance of Rain. */
let tomorrow's_weather: Distribution.t(rain_or_shine) = [
(0.8, Shine),
(0.2, Rain),
];
/* Refined version: returns a distribution over booleans instead of a
   plain boolean. */
let can_play_soccer_in_park = rain_or_shine =>
switch (rain_or_shine) {
| Rain =>
/* we definitely can't play if it's raining */
Distribution.return(false)
| Shine =>
/* let's say we can play with a 95% probability */
[(0.95, true), (0.05, false)]
};
/* map produces a distribution of distributions; join flattens it back
   into a single bool distribution. */
let can_play_soccer_in_park_tomorrow: Distribution.t(bool) =
Distribution.map(tomorrow's_weather, ~f=can_play_soccer_in_park)
|> Distribution.join;
/* Print the flattened distribution. */
Distribution.to_string(string_of_bool, can_play_soccer_in_park_tomorrow)
|> print_endline;
At this point, it's probably smart to introduce a new operation
bind
. This pattern of calling map and then join immediately
afterwards is so common that there's an operator to make it more
convenient. Let's define bind
so that we can use it later on.
(* bind is map followed by join. Opening Distribution brings map, join
   and return into scope unqualified. *)
open Distribution
let bind t ~f = map t ~f |> join
/* The minimal interface every monad exposes: a parameterised type plus
   return, map and join. */
module Monad = {
module type S = {
/* A monadic value carrying results of type 'a. */
type t('a);
/* Wrap a plain value in the monad. */
let return: 'a => t('a);
/* Apply ~f to the value(s) inside the monad. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested monadic layers into one. */
let join: t(t('a)) => t('a);
};
};
/* Signature of a type with a zero element, an addition and a subtraction.
   `List.sum` takes a module of this type to know how to add elements up. */
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; this module only gives them a name and
   exposes float addition/subtraction as `(+)` and `(-)`. */
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
/* List-backed implementation of the probability distribution monad. */
module Distribution_implementation = {
  /* A distribution is a list of (probability, outcome) pairs. */
  type t('a) = list((Probability.t, 'a));
  /* The distribution that yields `a` with probability 1. */
  let return = a => [(1., a)];
  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
  /* Merges entries that carry the same state into a single entry whose
     probability is the sum of the merged probabilities; the result is sorted
     by state. `List.group` may hand back an empty group (it does so for an
     empty input list); such a group contributes nothing instead of making
     `List.hd` raise, so an empty distribution stays empty. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           | [] => []
           | [(_, a), ..._] =>
             /* every entry in the group has the same state */
             let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
             [(p, a)];
           }
         );
  /* Applies `f` to every outcome, then merges outcomes that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
  /* Flattens a distribution over distributions, multiplying each inner
     probability by the probability of the distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;
  /* Renders one "outcome:<TAB>probability " row per entry, joined by
     newlines. `f` prints an outcome; `decimals` (default 3) is the number of
     digits after the point. */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = f => {
      let s = Printf.sprintf("%0.*f", decimals, f);
      /* flag values that do not survive the round trip through text */
      let is_exact = float_of_string(s) == f;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public face of the distribution: the list representation is exposed, and
   only `to_string` plus the `Monad.S` operations (`return`, `map`, `join`)
   are visible; helpers such as `group_states` stay hidden. */
module Distribution: {
/* A distribution is a list of (probability, outcome) pairs. */
type t('a) = list((Probability.t, 'a));
/* Renders a distribution; the second argument prints a single outcome. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
/* `bind` is `map` followed by `join`: apply `f`, then flatten the nesting. */
let bind = (t, ~f) => join(map(t, ~f));
When actually writing code dealing with monads, it's often more convenient to use `bind` than `map` and `join`. In fact, many people introduce `bind` as part of the definition of a monad instead of `join`, but the two definitions are equivalent since, given one, you can define the other (e.g. `let join tt = bind tt ~f:(fun t -> t)`).
Alright! We've now explained every line of the signature which Distribution
exposes. You probably thought this day would never come.
Ok, so what? We've defined a bunch of really abstract stuff in the context of a concrete example, a probability distribution, and we've seen some extremely trivial manipulations which are not nearly cool enough to warrant all this hullabaloo.
I agree. Let's start the interesting part.
First off, let me define infix operators for `map` and `bind` so that my code later on can be slightly less indented. And while we're at it, let's define a `uniform` function which takes a list of outcomes and creates a distribution where each outcome is equally weighted. One last thing - I'm going to open `Distribution` so that I can write the functions we've been talking about in an unqualified way, e.g. `map` instead of `Distribution.map`.
open Distribution
let ( >>| ) t f = map t ~f
let ( >>= ) t f = bind t ~f
(* [uniform outcomes] gives every element of [outcomes] the same probability,
   1 / (length of [outcomes]). *)
let uniform : 'a list -> 'a t = fun outcomes ->
  let weight = 1. /. float_of_int (List.length outcomes) in
  List.map outcomes ~f:(fun outcome -> (weight, outcome))
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};
module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
let bind = (t, ~f) => map(t, ~f) |> join;
open Distribution;
let (>>|) = (t, f) => map(t, ~f);
let (>>=) = (t, f) => bind(t, ~f);
/* Distribution that gives every element of the list the same probability,
   1 / (length of the list). */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float_of_int(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };
With those handy tools, let's define a staple of probability questions.
The die.
let one_through_six = List.range 1 7;;
let die : int t = uniform one_through_six;;
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};
module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
let bind = (t, ~f) => map(t, ~f) |> join;
open Distribution;
let (>>|) = (t, f) => map(t, ~f);
let (>>=) = (t, f) => bind(t, ~f);
let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};
let one_through_six = List.range(1, 7);
let die: t(int) = uniform(one_through_six);
If we want to print our die as a string, we can use to_string
:
to_string string_of_int die |> print_endline
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};
module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
let bind = (t, ~f) => map(t, ~f) |> join;
open Distribution;
let (>>|) = (t, f) => map(t, ~f);
let (>>=) = (t, f) => bind(t, ~f);
let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};
let one_through_six = List.range(1, 7);
let die: t(int) = uniform(one_through_six);
to_string(string_of_int, die) |> print_endline;
What's the probability our die is >= 5?
map die ~f:(fun d -> d >= 5)
|> to_string string_of_bool
|> print_endline
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};
module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
let bind = (t, ~f) => map(t, ~f) |> join;
open Distribution;
let (>>|) = (t, f) => map(t, ~f);
let (>>=) = (t, f) => bind(t, ~f);
let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};
let one_through_six = List.range(1, 7);
let die: t(int) = uniform(one_through_six);
map(die, ~f=d => d >= 5) |> to_string(string_of_bool) |> print_endline;
Too easy. I'm not impressed.
How about this? Let's define a distribution on two dice.
let two_dice =
die >>= fun d1 ->
die >>= fun d2 ->
return (d1, d2)
;;
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};
module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
let bind = (t, ~f) => map(t, ~f) |> join;
open Distribution;
let (>>|) = (t, f) => map(t, ~f);
let (>>=) = (t, f) => bind(t, ~f);
let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};
let one_through_six = List.range(1, 7);
let die: t(int) = uniform(one_through_six);
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));
What's the probability that the sum of two dice is >= 10?
two_dice
|> map ~f:(fun (d1, d2) -> d1 + d2)
|> map ~f:(fun sum_of_dice -> sum_of_dice >= 10)
|> to_string string_of_bool
|> print_endline
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};
module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
let bind = (t, ~f) => map(t, ~f) |> join;
open Distribution;
let (>>|) = (t, f) => map(t, ~f);
let (>>=) = (t, f) => bind(t, ~f);
let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};
let one_through_six = List.range(1, 7);
let die: t(int) = uniform(one_through_six);
let two_dice = die >>= (d1 => die >>= (d2 => return((d1, d2))));
two_dice
|> map(~f=((d1, d2)) => d1 + d2)
|> map(~f=sum_of_dice => sum_of_dice >= 10)
|> to_string(string_of_bool)
|> print_endline;
Finally something that you might not be able to immediately solve in your head. But let's find something that you definitely can't solve in your head.
The Newton–Pepys problem
In 1693 Samuel Pepys and Isaac Newton corresponded over a problem posed by Pepys in relation to a wager he planned to make. The problem was:
Which of the following three propositions has the greatest chance of success?
A. Six fair dice are tossed independently and at least one “6” appears.
B. Twelve fair dice are tossed independently and at least two “6”s appear.
C. Eighteen fair dice are tossed independently and at least three “6”s appear.
(* The probability distribution representing whether a single die comes up 6. *)
let single_die_is_six : bool t = die >>| fun d -> d = 6
(* The probability distribution representing the number of 6s in a single die. *)
let number_of_sixes_in_a_single_die : int t =
single_die_is_six
>>| function | false -> 0 | true -> 1
(* The probability distribution describing how many sixes come up in 6 dice. *)
let num_sixes_in_six_dice : int t =
number_of_sixes_in_a_single_die >>= fun n1 ->
number_of_sixes_in_a_single_die >>= fun n2 ->
number_of_sixes_in_a_single_die >>= fun n3 ->
number_of_sixes_in_a_single_die >>= fun n4 ->
number_of_sixes_in_a_single_die >>= fun n5 ->
number_of_sixes_in_a_single_die >>= fun n6 ->
return (n1 + n2 + n3 + n4 + n5 + n6)
let num_sixes_in_twelve_dice =
num_sixes_in_six_dice >>= fun first_six ->
num_sixes_in_six_dice >>= fun second_six ->
return (first_six + second_six)
let num_sixes_in_eighteen_dice =
num_sixes_in_twelve_dice >>= fun first_twelve ->
num_sixes_in_six_dice >>= fun third_six ->
return (first_twelve + third_six)
;;
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};
module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
let bind = (t, ~f) => map(t, ~f) |> join;
open Distribution;
let (>>|) = (t, f) => map(t, ~f);
let (>>=) = (t, f) => bind(t, ~f);
let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};
let one_through_six = List.range(1, 7);
let die: t(int) = uniform(one_through_six);
/* Distribution over whether one roll of the die shows a 6. */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);
/* Distribution over how many 6s one roll of the die shows: 1 when the roll
   is a six, 0 otherwise. */
let number_of_sixes_in_a_single_die: t(int) =
  single_die_is_six >>| (is_six => is_six ? 1 : 0);
/* The probability distribution describing how many sixes come up in 6 dice.
   Six copies of the single-die distribution are chained with `>>=`, one per
   die (n1 .. n6); the innermost step returns the total of the six counts. */
let num_sixes_in_six_dice: t(int) =
number_of_sixes_in_a_single_die
>>= (
n1 =>
number_of_sixes_in_a_single_die
>>= (
n2 =>
number_of_sixes_in_a_single_die
>>= (
n3 =>
number_of_sixes_in_a_single_die
>>= (
n4 =>
number_of_sixes_in_a_single_die
>>= (
n5 =>
number_of_sixes_in_a_single_die
>>= (n6 => return(n1 + n2 + n3 + n4 + n5 + n6))
)
)
)
)
);
/* Number of sixes in twelve dice: the sum of the counts from two
   six-dice throws. */
let num_sixes_in_twelve_dice =
  bind(num_sixes_in_six_dice, ~f=first_six =>
    bind(num_sixes_in_six_dice, ~f=second_six =>
      return(first_six + second_six)
    )
  );
/* Number of sixes in eighteen dice: the count from a twelve-dice throw plus
   the count from one more six-dice throw. */
let num_sixes_in_eighteen_dice =
  bind(num_sixes_in_twelve_dice, ~f=first_twelve =>
    bind(num_sixes_in_six_dice, ~f=third_six =>
      return(first_twelve + third_six)
    )
  );
Choice A: Six fair dice are tossed independently and at least one “6” appears.
num_sixes_in_six_dice
|> map ~f:(fun n -> n >= 1)
|> to_string string_of_bool
|> print_endline
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
module Probability = {
type t = float;
let zero = 0.;
let (+) = (+.);
let (-) = (-.);
};
module Distribution_implementation = {
type t('a) = list((Probability.t, 'a));
let return = a => [(1., a)];
/* compare just the 'a states, not the probabilities */
let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);
let group_states: t('a) => t('a) =
t =>
List.sort(compare_states, t)
|> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
|> List.map(~f=group => {
/* each group is going to be non-empty */
let (_, a) = List.hd(group);
let p = List.fold_left((acc, (p, _)) => acc +. p, 0., group);
(p, a);
});
let map = (t, ~f) =>
List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;
let join = t =>
List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
|> group_states;
let to_string = (~decimals=3, f, t) => {
let string_of_float = f => {
let s = Printf.sprintf("%0.*f", decimals, f);
let is_exact = float_of_string(s) == f;
if (is_exact) {
s;
} else {
"~" ++ s;
};
};
String.concat(
"\n",
List.map(t, ~f=((p, a)) =>
f(a) ++ ":\t" ++ string_of_float(p) ++ " "
),
);
};
};
module Distribution: {
type t('a) = list((Probability.t, 'a));
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
let bind = (t, ~f) => map(t, ~f) |> join;
open Distribution;
let (>>|) = (t, f) => map(t, ~f);
let (>>=) = (t, f) => bind(t, ~f);
let uniform: list('a) => t('a) =
l => {
let p = 1. /. float(List.length(l));
List.map(l, ~f=a => (p, a));
};
let one_through_six = List.range(1, 7);
let die: t(int) = uniform(one_through_six);
/* The probability distribution representing whether a single die comes up 6. */
let single_die_is_six: t(bool) = die >>| (d => d == 6);
/* The probability distribution representing the number of 6s in a single die. */
let number_of_sixes_in_a_single_die: t(int) =
single_die_is_six
>>| (
fun
| false => 0
| true => 1
);
/* The probability distribution describing how many sixes come up in 6 dice. */
let num_sixes_in_six_dice: t(int) =
number_of_sixes_in_a_single_die
>>= (
n1 =>
number_of_sixes_in_a_single_die
>>= (
n2 =>
number_of_sixes_in_a_single_die
>>= (
n3 =>
number_of_sixes_in_a_single_die
>>= (
n4 =>
number_of_sixes_in_a_single_die
>>= (
n5 =>
number_of_sixes_in_a_single_die
>>= (n6 => return(n1 + n2 + n3 + n4 + n5 + n6))
)
)
)
)
);
let num_sixes_in_twelve_dice =
num_sixes_in_six_dice
>>= (
first_six =>
num_sixes_in_six_dice >>= (second_six => return(first_six + second_six))
);
let num_sixes_in_eighteen_dice =
num_sixes_in_twelve_dice
>>= (
first_twelve =>
num_sixes_in_six_dice
>>= (third_six => return(first_twelve + third_six))
);
num_sixes_in_six_dice
|> map(~f=n => n >= 1)
|> to_string(string_of_bool)
|> print_endline;
Choice B: Twelve fair dice are tossed independently and at least two “6”s appear.
num_sixes_in_twelve_dice
|> map ~f:(fun n -> n >= 2)
|> to_string string_of_bool
|> print_endline
module Monad = {
module type S = {
type t('a);
let return: 'a => t('a);
let map: (t('a), ~f: 'a => 'b) => t('b);
let join: t(t('a)) => t('a);
};
};
module type Commutative_group = {
type t;
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* The standard `List`, extended with list-first, labelled-argument helpers
   used by the snippets in this post. */
module List = {
  include List;
  /* Alias for the standard `hd`. */
  let hd_exn = hd;
  /* List-first, labelled version of the standard `map`. */
  let map = (t, ~f) => map(f, t);
  /* Maps every element to a list and concatenates the results in order. */
  let bind = (t, ~f) => map(t, ~f) |> concat;
  /* List-first, labelled version of the standard `filter`. */
  let filter = (t, ~f) => filter(f, t);
  /* Number of elements satisfying `f`. */
  let count = (t, ~f) => filter(t, ~f) |> List.length;
  /* Integers from `start` (inclusive) up to `stop` (exclusive). An empty or
     inverted interval (`start >= stop`) yields `[]`; the earlier
     `start == stop` test never terminated when `start > stop`. */
  let rec range = (start, stop) =>
    if (start >= stop) {
      [];
    } else {
      [start, ...range(start + 1, stop)];
    };
  /* Splits `t` into runs of consecutive elements, starting a new run whenever
     `break(previous, next)` is true. Elements keep their original order both
     across and within runs. The empty list has no runs, so it yields `[]`
     (not `[[]]`); every returned group is therefore non-empty. */
  let group = (t, ~break) => {
    /* `acc` holds the finished runs, newest first; `curr_group` holds the run
       under construction, newest element first. Both are reversed on exit. */
    let rec loop = (acc, curr_group, rest) =>
      switch (curr_group, rest) {
      | ([], []) => acc
      | (_, []) => [rev(curr_group), ...acc]
      | ([], [next, ...rest]) => loop(acc, [next], rest)
      | ([last, ..._], [next, ...rest]) =>
        if (break(last, next)) {
          loop([rev(curr_group), ...acc], [next], rest);
        } else {
          loop(acc, [next, ...curr_group], rest);
        }
      };
    loop([], [], t) |> rev;
  };
  /* Adds up `f(el)` over the elements of `t`, using the `zero` and `(+)` of
     the group module `M` passed as first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; `zero`, `+` and `-` are the float
   versions of those operations. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};
/* Finite probability distributions represented as association lists. */
module Distribution_implementation = {
  /* A list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));

  /* The distribution that yields `a` with probability 1. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Merge entries whose states compare equal, summing their probabilities.
     The result is ordered by state. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* An empty group (which `List.group` can produce for empty
              input) carries no state; skip it instead of calling `List.hd`,
              which would raise. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );

  /* Apply `f` to every state, then merge states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: every inner probability is
     multiplied by the probability of the inner distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one line per entry: the state (via `f`), a colon and tab, then
     the probability with `decimals` digits. A probability that is not
     exactly representable at that precision is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public view of `Distribution_implementation`: only `t`, `to_string` and
   the `Monad.S` operations (`return`, `map`, `join`) are exposed. */
module Distribution: {
/* A distribution is a list of (probability, state) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution as text; `decimals` is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* `:=` substitutes the `t` declared above for the `t` of `Monad.S`. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
/* `bind` is `map` followed by `join`: run `f` on every state, then flatten. */
let bind = (t, ~f) => join(map(t, ~f));
open Distribution;
/* Infix forms: `>>|` is `map` and `>>=` is `bind`, with the function on the right. */
let (>>|) = (dist, fn) => map(dist, ~f=fn);
let (>>=) = (dist, fn) => bind(dist, ~f=fn);
/* Give every element of the list the same probability, 1 / length. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float_of_int(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };
/* The faces of a six-sided die: the integers from 1 up to (but excluding) 7. */
let one_through_six = List.range(1, 7);
/* A fair die: every face in `one_through_six` gets the same probability. */
let die: t(int) = one_through_six |> uniform;
/* Distribution over booleans: does one roll of `die` show a six? */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);
/* Distribution of how many sixes (0 or 1) a single roll contributes. */
let number_of_sixes_in_a_single_die: t(int) =
  single_die_is_six >>| (is_six => is_six ? 1 : 0);
/* Distribution of the total number of sixes in six dice: the single-die
   count is bound six times and the six bound values are added together. */
let num_sixes_in_six_dice: t(int) = {
  let roll = number_of_sixes_in_a_single_die;
  bind(roll, ~f=s1 =>
    bind(roll, ~f=s2 =>
      bind(roll, ~f=s3 =>
        bind(roll, ~f=s4 =>
          bind(roll, ~f=s5 =>
            bind(roll, ~f=s6 => return(s1 + s2 + s3 + s4 + s5 + s6))
          )
        )
      )
    )
  );
};
/* Sixes in twelve dice: the sum of two draws from the six-dice distribution. */
let num_sixes_in_twelve_dice =
  bind(num_sixes_in_six_dice, ~f=first_half =>
    bind(num_sixes_in_six_dice, ~f=second_half =>
      return(first_half + second_half)
    )
  );
/* Sixes in eighteen dice: a twelve-dice draw plus one more six-dice draw. */
let num_sixes_in_eighteen_dice =
  bind(num_sixes_in_twelve_dice, ~f=twelve_count =>
    bind(num_sixes_in_six_dice, ~f=six_count =>
      return(twelve_count + six_count)
    )
  );
/* Print the distribution of "at least two sixes among twelve dice". */
print_endline(
  to_string(string_of_bool, map(num_sixes_in_twelve_dice, ~f=n => n >= 2)),
);
Choice C: Eighteen fair dice are tossed independently and at least three “6”s appear.
(* Print the distribution of "at least three sixes among eighteen dice". *)
let () =
  print_endline
    (to_string string_of_bool
       (map num_sixes_in_eighteen_dice ~f:(fun n -> n >= 3)))
/* Namespace holding the monad signature used by the distribution code. */
module Monad = {
/* A monad described by `return`, `map` and `join`. */
module type S = {
/* The monadic container, parameterised by the type of value it carries. */
type t('a);
/* Wrap a plain value in the container. */
let return: 'a => t('a);
/* Transform the carried values with the labelled function `f`. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested layers of `t` into one. */
let join: t(t('a)) => t('a);
};
};
/* Signature of a type that has a zero element, an addition and a
   subtraction. `List.sum` takes a module of this type to know how to add. */
module type Commutative_group = {
type t;
/* The starting value for sums. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Shadows the standard `List` with labelled, list-first variants of a few
   functions, plus the extra helpers `range`, `group` and `sum`. */
module List = {
  include List;

  /* Head of a list; raises on `[]`, exactly like the standard `hd`. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the standard functions. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* `range(start, stop)` is `[start, start + 1, ..., stop - 1]`.
     Returns `[]` whenever `start >= stop`; the previous `start == stop` test
     recursed forever when `start > stop`. The list is built back-to-front
     so the helper is tail-recursive. */
  let range = (start, stop) => {
    let rec build = (acc, i) => {
      let acc = [i, ...acc];
      i == start ? acc : build(acc, i - 1);
    };
    if (start >= stop) {
      [];
    } else {
      build([], stop - 1);
    };
  };

  /* Split `t` into runs of consecutive elements, starting a new run each
     time `break(previous, next)` is true. Runs are returned in input order;
     the elements inside each run are in reverse input order (unchanged from
     the earlier implementation, which callers may depend on).
     An empty input yields `[]` rather than a single empty run. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [curr_group, ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([curr_group, ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };

  /* Sum `f(x)` over the list using the `zero` and `+` of the group module
     passed as the first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; `zero`, `+` and `-` are the float
   versions of those operations. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};
/* Finite probability distributions represented as association lists. */
module Distribution_implementation = {
  /* A list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));

  /* The distribution that yields `a` with probability 1. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Merge entries whose states compare equal, summing their probabilities.
     The result is ordered by state. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* An empty group (which `List.group` can produce for empty
              input) carries no state; skip it instead of calling `List.hd`,
              which would raise. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );

  /* Apply `f` to every state, then merge states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: every inner probability is
     multiplied by the probability of the inner distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one line per entry: the state (via `f`), a colon and tab, then
     the probability with `decimals` digits. A probability that is not
     exactly representable at that precision is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public view of `Distribution_implementation`: only `t`, `to_string` and
   the `Monad.S` operations (`return`, `map`, `join`) are exposed. */
module Distribution: {
/* A distribution is a list of (probability, state) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution as text; `decimals` is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* `:=` substitutes the `t` declared above for the `t` of `Monad.S`. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
/* `bind` is `map` followed by `join`: run `f` on every state, then flatten. */
let bind = (t, ~f) => join(map(t, ~f));
open Distribution;
/* Infix forms: `>>|` is `map` and `>>=` is `bind`, with the function on the right. */
let (>>|) = (dist, fn) => map(dist, ~f=fn);
let (>>=) = (dist, fn) => bind(dist, ~f=fn);
/* Give every element of the list the same probability, 1 / length. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float_of_int(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };
/* The faces of a six-sided die: the integers from 1 up to (but excluding) 7. */
let one_through_six = List.range(1, 7);
/* A fair die: every face in `one_through_six` gets the same probability. */
let die: t(int) = one_through_six |> uniform;
/* Distribution over booleans: does one roll of `die` show a six? */
let single_die_is_six: t(bool) = map(die, ~f=face => face == 6);
/* Distribution of how many sixes (0 or 1) a single roll contributes. */
let number_of_sixes_in_a_single_die: t(int) =
  single_die_is_six >>| (is_six => is_six ? 1 : 0);
/* Distribution of the total number of sixes in six dice: the single-die
   count is bound six times and the six bound values are added together. */
let num_sixes_in_six_dice: t(int) = {
  let roll = number_of_sixes_in_a_single_die;
  bind(roll, ~f=s1 =>
    bind(roll, ~f=s2 =>
      bind(roll, ~f=s3 =>
        bind(roll, ~f=s4 =>
          bind(roll, ~f=s5 =>
            bind(roll, ~f=s6 => return(s1 + s2 + s3 + s4 + s5 + s6))
          )
        )
      )
    )
  );
};
/* Sixes in twelve dice: the sum of two draws from the six-dice distribution. */
let num_sixes_in_twelve_dice =
  bind(num_sixes_in_six_dice, ~f=first_half =>
    bind(num_sixes_in_six_dice, ~f=second_half =>
      return(first_half + second_half)
    )
  );
/* Sixes in eighteen dice: a twelve-dice draw plus one more six-dice draw. */
let num_sixes_in_eighteen_dice =
  bind(num_sixes_in_twelve_dice, ~f=twelve_count =>
    bind(num_sixes_in_six_dice, ~f=six_count =>
      return(twelve_count + six_count)
    )
  );
/* Print the distribution of "at least three sixes among eighteen dice". */
print_endline(
  to_string(string_of_bool, map(num_sixes_in_eighteen_dice, ~f=n => n >= 3)),
);
As you can see here, that's correct!
Hopefully at this point you can see that we have some real power here.
Using `bind`, `map`, `return` (and `join`, although not explicitly) we can start to answer some very real questions.
Let's tackle a more recent question, one that stumped quite a few mathematicians, including Paul Erdős!
The Monty Hall Problem.
Here's Wikipedia's formulation:
Suppose you're on a game show, and you're given the choice of three doors: Behind one door is a car; behind the others, goats. You pick a door, say No. 1, and the host, who knows what's behind the doors, opens another door, say No. 3, which has a goat. He then says to you, "Do you want to pick door No. 2?" Is it to your advantage to switch your choice?
If you haven't encountered this problem before, I highly recommend taking a minute to come up with an answer yourself before continuing.
First off, let's define another utility function that turns out to be very handy in probability questions. `condition` will take a distribution and condition on a particular event happening.
(* Multiply every probability in [t] by [1 / total], where [total] is the
   sum of all probabilities currently in [t]. *)
let renormalize t =
  let total = List.sum (module Probability) t ~f:fst in
  let scale_by = 1. /. total in
  let rescale (p, a) = p *. scale_by, a in
  List.map t ~f:rescale
;;
(* Keep only the states for which [on] holds, then renormalize what is left. *)
let condition : 'a t -> on:('a -> bool) -> 'a t =
 fun t ~on ->
  let satisfies (_, a) = on a in
  renormalize (List.filter t ~f:satisfies)
;;
/* Namespace holding the monad signature used by the distribution code. */
module Monad = {
/* A monad described by `return`, `map` and `join`. */
module type S = {
/* The monadic container, parameterised by the type of value it carries. */
type t('a);
/* Wrap a plain value in the container. */
let return: 'a => t('a);
/* Transform the carried values with the labelled function `f`. */
let map: (t('a), ~f: 'a => 'b) => t('b);
/* Collapse two nested layers of `t` into one. */
let join: t(t('a)) => t('a);
};
};
/* Signature of a type that has a zero element, an addition and a
   subtraction. `List.sum` takes a module of this type to know how to add. */
module type Commutative_group = {
type t;
/* The starting value for sums. */
let zero: t;
let (+): (t, t) => t;
let (-): (t, t) => t;
};
/* Shadows the standard `List` with labelled, list-first variants of a few
   functions, plus the extra helpers `range`, `group` and `sum`. */
module List = {
  include List;

  /* Head of a list; raises on `[]`, exactly like the standard `hd`. */
  let hd_exn = hd;

  /* Labelled, list-first wrappers around the standard functions. */
  let map = (t, ~f) => map(f, t);
  let bind = (t, ~f) => map(t, ~f) |> concat;
  let filter = (t, ~f) => filter(f, t);
  let count = (t, ~f) => filter(t, ~f) |> List.length;

  /* `range(start, stop)` is `[start, start + 1, ..., stop - 1]`.
     Returns `[]` whenever `start >= stop`; the previous `start == stop` test
     recursed forever when `start > stop`. The list is built back-to-front
     so the helper is tail-recursive. */
  let range = (start, stop) => {
    let rec build = (acc, i) => {
      let acc = [i, ...acc];
      i == start ? acc : build(acc, i - 1);
    };
    if (start >= stop) {
      [];
    } else {
      build([], stop - 1);
    };
  };

  /* Split `t` into runs of consecutive elements, starting a new run each
     time `break(previous, next)` is true. Runs are returned in input order;
     the elements inside each run are in reverse input order (unchanged from
     the earlier implementation, which callers may depend on).
     An empty input yields `[]` rather than a single empty run. */
  let group = (t, ~break) => {
    let rec loop = (acc, curr_group, last, rest) =>
      switch (rest) {
      | [] => [curr_group, ...acc]
      | [next, ...rest] =>
        if (break(last, next)) {
          loop([curr_group, ...acc], [next], next, rest);
        } else {
          loop(acc, [next, ...curr_group], next, rest);
        }
      };
    switch (t) {
    | [] => []
    | [first, ...rest] => loop([], [first], first, rest) |> rev
    };
  };

  /* Sum `f(x)` over the list using the `zero` and `+` of the group module
     passed as the first argument. */
  let sum =
      (type a, (module M): (module Commutative_group with type t = a), t, ~f) =>
    map(t, ~f) |> List.fold_left((acc, el) => M.(+)(acc, el), M.zero);
};
/* Probabilities are plain floats; `zero`, `+` and `-` are the float
   versions of those operations. */
module Probability = {
  type t = float;
  let zero: t = 0.0;
  let (+) = (a: t, b: t): t => a +. b;
  let (-) = (a: t, b: t): t => a -. b;
};
/* Finite probability distributions represented as association lists. */
module Distribution_implementation = {
  /* A list of (probability, state) pairs. */
  type t('a) = list((Probability.t, 'a));

  /* The distribution that yields `a` with probability 1. */
  let return = a => [(1., a)];

  /* compare just the 'a states, not the probabilities */
  let compare_states = ((_, a1), (_, a2)) => compare(a1, a2);

  /* Merge entries whose states compare equal, summing their probabilities.
     The result is ordered by state. */
  let group_states: t('a) => t('a) =
    t =>
      List.sort(compare_states, t)
      |> List.group(~break=(t1, t2) => compare_states(t1, t2) != 0)
      |> List.bind(~f=group =>
           switch (group) {
           /* An empty group (which `List.group` can produce for empty
              input) carries no state; skip it instead of calling `List.hd`,
              which would raise. */
           | [] => []
           | [(_, a), ..._] => {
               let p =
                 List.fold_left((acc, (p, _)) => acc +. p, 0., group);
               [(p, a)];
             }
           }
         );

  /* Apply `f` to every state, then merge states that became equal. */
  let map = (t, ~f) =>
    List.map(t, ~f=((p, a)) => (p, f(a))) |> group_states;

  /* Flatten a distribution of distributions: every inner probability is
     multiplied by the probability of the inner distribution it came from. */
  let join = t =>
    List.bind(t, ~f=((p, t)) => List.map(t, ~f=((p1, a)) => (p *. p1, a)))
    |> group_states;

  /* Render one line per entry: the state (via `f`), a colon and tab, then
     the probability with `decimals` digits. A probability that is not
     exactly representable at that precision is prefixed with "~". */
  let to_string = (~decimals=3, f, t) => {
    let string_of_float = x => {
      let s = Printf.sprintf("%0.*f", decimals, x);
      let is_exact = float_of_string(s) == x;
      if (is_exact) {
        s;
      } else {
        "~" ++ s;
      };
    };
    String.concat(
      "\n",
      List.map(t, ~f=((p, a)) =>
        f(a) ++ ":\t" ++ string_of_float(p) ++ " "
      ),
    );
  };
};
/* Public view of `Distribution_implementation`: only `t`, `to_string` and
   the `Monad.S` operations (`return`, `map`, `join`) are exposed. */
module Distribution: {
/* A distribution is a list of (probability, state) pairs. */
type t('a) = list((Probability.t, 'a));
/* Render a distribution as text; `decimals` is optional. */
let to_string: (~decimals: int=?, 'a => string, t('a)) => string;
/* `:=` substitutes the `t` declared above for the `t` of `Monad.S`. */
include Monad.S with type t('a) := t('a);
} = Distribution_implementation;
open Distribution;
/* `bind` is `map` followed by `join`: run `f` on every state, then flatten. */
let bind = (t, ~f) => join(map(t, ~f));
open Distribution;
/* Infix forms: `>>|` is `map` and `>>=` is `bind`, with the function on the right. */
let (>>|) = (dist, fn) => map(dist, ~f=fn);
let (>>=) = (dist, fn) => bind(dist, ~f=fn);
/* Give every element of the list the same probability, 1 / length. */
let uniform: list('a) => t('a) =
  outcomes => {
    let weight = 1. /. float_of_int(List.length(outcomes));
    List.map(outcomes, ~f=outcome => (weight, outcome));
  };
/* The faces of a six-sided die: the integers from 1 up to (but excluding) 7. */
let one_through_six = List.range(1, 7);
/* A fair die: every face in `one_through_six` gets the same probability. */
let die: t(int) = one_through_six |> uniform;
/* Distribution over pairs `(d1, d2)` obtained by drawing from `die` twice. */
let two_dice =
  bind(die, ~f=d1 => bind(die, ~f=d2 => return((d1, d2))));
/* Multiply every probability in `t` by `1 / total`, where `total` is the
   sum of all probabilities currently in `t`. */
let renormalize = t => {
  let total = List.sum((module Probability), t, ~f=fst);
  let scale_by = 1. /. total;
  let rescale = ((p, a)) => (p *. scale_by, a);
  List.map(t, ~f=rescale);
};
/* Keep only the states for which `on` holds, then renormalize what is left. */
let condition: (t('a), ~on: 'a => bool) => t('a) =
  (t, ~on) => {
    let kept = List.filter(t, ~f=((_, a)) => on(a));
    renormalize(kept);
  };