## The problem

This post is about the equation $(a+b)^3=a^3+3a^2b+3ab^2+b^3$. This special case of the binomial theorem is the sort of thing that we would expect incoming students at my university to know, or at least to be able to compute very quickly. But what actually goes into an axiomatic proof? Whether you think of $a$ and $b$ as real numbers or as elements of an arbitrary commutative ring, we humans are very good at multiplying out the brackets on the left hand side and then collecting up terms and putting them together to form the right hand side.

One reason humans are so good at this is that we have some really good notation available to us. This becomes apparent when we try and verify $(a+b)^3=a^3+3a^2b+3ab^2+b^3$ using Lean. The ring axiom left_distrib says a * (b + c) = a * b + a * c, and this axiom has a second, more user-friendly, name of mul_add. There is also add_mul for expanding out (a + b) * c. Putting these together we can expand out all the brackets on the left hand side. Well, first we have to expand out the ^3 notation, and because Lean defines ^ recursively with x^0=1 and x^(n+1)=x*x^n we see that x^3 expands out to x*(x*(x*1)). Fortunately mul_one is the theorem saying x*1=x so let’s use mul_one and mul_add and add_mul to expand out our brackets.


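import algebra.group_power -- for power notation in rings

example (R : Type) [comm_ring R] (a b : R) :
  (a+b)^3 = a^3+3*a^2*b+3*a*b^2+b^3 :=
begin
  unfold pow monoid.pow, -- LHS now (a + b) * ((a + b) * ((a + b) * 1))
  repeat {rw mul_one},
  repeat {rw add_mul}, -- expand every (x + y) * z
  repeat {rw mul_add}, -- expand every x * (y + z)
  sorry -- not finished: the goal at this point is shown below
end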


What do we have after repeatedly expanding out brackets? We must be close to proving the goal, right? Well, not really. Our goal is now


⊢ a * (a * a) + a * (a * b) + (a * (b * a) + a * (b * b)) +
(b * (a * a) + b * (a * b) + (b * (b * a) + b * (b * b))) =
a * (a * a) + 3 * (a * a) * b + 3 * a * (b * b) + b * (b * b)


This mess indicates exactly what humans are good at. We have terms a*(a*b) and a*(b*a) and b*(a*a) on the left hand side, and (a*a)*b on the right hand side: humans instantly simplify all of these to $a^2b$ but to Lean these variants are all different, and we’ll have to invoke theorems (associativity and commutativity of multiplication) to prove that they’re equal. Moreover, we have random other brackets all over the place on the left hand side because the terms are being added up in quite a random order. Humans would not even write any of these brackets, they would happily invoke the fact that p+q+r+s+t+u+v+w=u+(t+p)+(r+(w+q)+(v+s)) without stopping to think how many times one would have to apply associativity and commutativity rules — these things are “obvious to humans”. Humans would take one look at the terms on the left and right hand sides, immediately remove all the brackets, tidy up, and observe that the result was trivial. But asking general purpose automation tools like Lean’s simp to do this is a bit of a tall order. Perhaps one way to start is to remind Lean that 3=1+1+1 and then expand out those new brackets:


example (R : Type) [comm_ring R] (a b : R) :
  (a+b)^3 = a^3+3*a^2*b+3*a*b^2+b^3 :=
begin
  unfold pow monoid.pow,
  repeat {rw mul_one},
  -- now expand out 3=1+1+1
  show _ = _ + (1+1+1) * _ * b + (1+1+1) * a * _ + _,
  -- and now more expanding out brackets
  repeat {rw add_mul},
  repeat {rw mul_add},
  repeat {rw one_mul},
  sorry -- still not finished: the goal at this point is shown below
end


The goal is now ⊢ a * (a * a) + a * (a * b) + (a * (b * a) + a * (b * b)) + (b * (a * a) + b * (a * b) + (b * (b * a) + b * (b * b))) = a * (a * a) + (a * a * b + a * a * b + a * a * b) + (a * (b * b) + a * (b * b) + a * (b * b)) + b * (b * b), which is trivial modulo a gazillion applications of associativity and commutativity of addition and multiplication. But even at this stage, simp [mul_comm,mul_assoc] doesn’t finish the job. What to do?
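To get a feel for the bookkeeping involved, here is a minimal sketch (Lean 3, same setup as the examples above) of what it looks like to prove just two of these "obvious" rearrangements by hand. The particular rewrites are my own illustration; the full goal would need dozens of steps like these.

```lean
import algebra.group_power -- as in the examples above

-- One of the monomials that a human would instantly call a²b.
example (R : Type) [comm_ring R] (a b : R) : a * (b * a) = a * a * b :=
by rw [mul_comm b a, mul_assoc] -- b * a becomes a * b, then the right hand side is re-bracketed

-- Reordering just three summands already takes two uses of commutativity.
example (R : Type) [comm_ring R] (p q r : R) : p + q + r = r + (q + p) :=
by rw [add_comm p q, add_comm (q + p) r]
```

Each rewrite has to be told exactly which subterm to act on, which is why doing this for eight summands at once is so painful.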

This problem became a real issue for Chris Hughes, an Imperial undergraduate, who was doing a bunch of algebra in Lean and would constantly run into such goals, maybe even with four unknown variables, and would constantly have to break his workflow to deal with these “trivial to human” proofs.

## The solution

Of course, computer scientists have thought about this problem, and there are algorithms present in theorem provers which can deal with these kinds of goals no problem. The most recent work I know of is the 2005 paper by Grégoire and Mahboubi, “Proving Equalities in a Commutative Ring Done Right in Coq”, where they explain how to write an efficient tactic which solves these sorts of problems. Their algorithm was implemented in Lean by Mario Carneiro as the tactic ring in Lean’s maths library, so we can now write


import tactic.ring

example (R : Type) [comm_ring R] (a b : R) :
(a+b)^3 = a^3+3*a^2*b+3*a*b^2+b^3 :=
by ring


and our goal is solved instantly! Experiments show that the tactic seems to prove any random true ring theory lemma you throw at it. You can look at Mario’s code here (permalink). This is tactic code, so it’s full of intimidating stuff like meta and do and expr, things which an “end-user” such as myself never uses: these keywords are for tactic-writers, not tactic-users.
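A couple more goals of the same flavour, as a quick sketch of the kind of thing the tactic handles (these particular identities are my own examples, chosen only for illustration):

```lean
import tactic.ring

-- Four unknowns in an arbitrary commutative ring.
example (R : Type) [comm_ring R] (a b c d : R) :
  (a + b) * (c + d) = a * c + a * d + b * c + b * d :=
by ring

-- Subtraction and powers over the integers.
example (x y : ℤ) : (x + y) * (x - y) = x^2 - y^2 :=
by ring
```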

I don’t know too much about tactic writing, and this is for the most part because I have never tried to write a tactic of my own. I’ve looked at Programming In Lean — people point out that this document is unfinished, but all of the important stuff is there: the “unfinished” part is basically all covered in Theorem Proving In Lean. However, as is usual in this game currently, the examples in Programming In Lean are most definitely of a more computer-sciencey flavour. So this leads us to the question: how should a mathematician learn how to write tactics? One method, which I’ve used to learn several other things in my life, is “supervise a Masters student for their project, and give them the job of writing some tactics for mathematicians”, and indeed I will be trying this method in October. The reason this method is so effective is that it forces me to learn the material myself so I don’t let the student down. In short, I had better learn something about how to write tactics. So how does tactic.ring work?

## The ring tactic

Looking at the code, I realised that it didn’t look too intimidating. I had looked at Mahboubi–Gregoire before, several months ago, and at that time (when I knew far far less about Lean) I could not see the wood for the trees. But when I revisited the paper this week it made a whole lot more sense to me. The main point of the Mahboubi–Gregoire paper is to explain how to write an efficient ring tactic, but clearly I am not at that stage yet — I just want to know how to write any ring tactic! I decided to start small, and to try and write a tactic that could prove basic stuff like $(d+1)^2=d^2+2d+1$ for $d\in\mathbb{Z}$ automatically, or at least to try and understand what went into writing such a tactic. And actually, it turns out to be relatively easy, especially if someone like Mario shows up on Zulip and writes some meta code for you.

Here is the basic idea. If d were actually an explicit fixed int, then proving (d+1)^2=d^2+2*d+1 would be easy: you just evaluate both sides and check that they’re equal. The problem is that d is an unknown integer, so we cannot evaluate both sides explicitly. However, we are lucky: the equality above is true because it can be deduced from the axioms of a ring, even if d is considered to be an unknown variable. More formally, I am saying that in the polynomial ring $\mathbb{Z}[X]$ we have the identity $(X+1)^2=X^2+2X+1$, and if we can prove this, then we can evaluate both sides at $X=d$ and deduce the result we need.
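For an explicit integer the check really is just a computation. As a small sketch (the choice of 7 and the use of norm_num are my own; any tactic that evaluates numerals would do):

```lean
import tactic.norm_num

-- With a concrete number in place of d, both sides are just numerals to be evaluated.
example : ((7 : ℤ) + 1) * (7 + 1) = 7 * 7 + 2 * 7 + 1 :=
by norm_num
```

With an unknown d there is nothing to evaluate, which is why we pass to polynomials instead.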

This gives us our strategy for proving (d+1)^2=d^2+2*d+1 using a tactic:

1. Make a type representing a polynomial ring $\mathbb{Z}[X]$ in one variable in Lean.

2. Prove that the “evaluation at $d$” map $ev(d) : \mathbb{Z}[X] \to \mathbb{Z}$ is a ring homomorphism.

3. Now given the input (d+1)^2=d^2+2*d+1, construct the equation (X+1)^2=X^2+2*X+1 as an equality in int[X], prove this, and deduce the result by evaluating at d.

This is the basic idea. There are all sorts of subtleties in turning this into an efficient implementation and I will not go into these now; I just want to talk about making this sort of thing work at all. So what about these three points above? The first two are the sort of thing I now find pretty straightforward in Lean. On the other hand, the third point is, I believe, perhaps impossible to do in Lean without using tactics. The int (d+1)^2 was built using addition and the power map, but we now want to write some sort of function which says “if the last thing you did to get this int was to square something, then square some corresponding polynomial” and so on — one needs more than the int (d+1)^2 — one actually needs access to the underlying expr. There, I said it. And expr is a weird type — if you try definition expr_id : expr → expr := id, Lean produces the error invalid definition, it uses untrusted declaration 'expr'. We have to go into meta-land to access the expr — the land of untrusted code.
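To see the difference concretely, here is a small sketch in Lean 3. The function is_sum is a hypothetical name of mine, included only to show that a meta definition can inspect how a term was built, which is exactly what an ordinary function on int cannot do.

```lean
-- Rejected by Lean:
-- def expr_id : expr → expr := id
-- error: invalid definition, it uses untrusted declaration 'expr'

-- Accepted once the definition is marked `meta`.
meta def expr_id : expr → expr := id

-- A meta function can pattern-match on the syntax of a term:
-- it returns tt exactly when the outermost operation is an addition.
meta def is_sum : expr → bool
| `(%%a + %%b) := tt
| _ := ff
```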

## Writing a baby ring tactic

Two thirds of this is now within our grasp (the first two points above) and we will borrow Mario for the third one. The approach we use is called “reflection”. Every operator comes with a cost, so (out of laziness) I will stick to + and * and not implement ^ or -. It’s also far easier to stick to one unknown variable, so instead of tackling $(a+b)^3$ I will stick to $(x+1)^3$ as our test case. The full code is here (permalink).

I should also heartily thank Mario Carneiro, who wrote all the meta code.

We start by creating a concrete polynomial type, represented as a list of coefficients. We use znum (which is just a more efficient version of int when it comes to calculations rather than proofs) for our coefficients, because it’s best practice.


def poly := list znum

 

For example [1,2,3] represents $1+2X+3X^2$. We define addition


def poly.add : poly → poly → poly
| [] g := g
| f [] := f
| (a :: f') (b :: g') := (a + b) :: poly.add f' g'

 

and scalar multiplication, multiplication, constants, and our variable X.

We have the wrong notion of equality: the polynomials [1,2,3] and [1,2,3,0] represent the same polynomial. Rather than defining an equivalence relation, we just define poly.is_eq, defined as "the lists are equal when you throw away the zeros on the end", and when I say two polynomials "are equal" I mean they're equal modulo this equivalence relation.
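The definitions just described are not shown in this post, so here is a rough sketch of how they might look. Everything sits in a hypothetical namespace poly_sketch so that it cannot clash with the real code; the import path, the helper poly.is_zero, and this particular way of writing poly.is_eq are my assumptions, and the linked code may well do it differently.

```lean
import data.num.basic -- assumed location of znum in Lean 3 mathlib

namespace poly_sketch -- hypothetical namespace, for illustration only

def poly := list znum

-- Repeated from above so that the sketch is self-contained.
def poly.add : poly → poly → poly
| [] g := g
| f [] := f
| (a :: f') (b :: g') := (a + b) :: poly.add f' g'

-- Multiply every coefficient by the scalar z.
def poly.smul (z : znum) : poly → poly
| [] := []
| (a :: f') := (z * a) :: poly.smul f'

-- (a + X * f') * g = a * g + X * (f' * g); putting 0 on the front multiplies by X.
def poly.mul : poly → poly → poly
| [] _ := []
| (a :: f') g := (poly.smul a g).add (0 :: poly.mul f' g)

def poly.const (z : znum) : poly := [z]

def poly.X : poly := [0, 1]

-- A list represents the zero polynomial when all of its entries are zero.
def poly.is_zero : poly → bool
| [] := tt
| (a :: f') := to_bool (a = 0) && poly.is_zero f'

-- Compare coefficient by coefficient; once one list runs out,
-- whatever is left of the other must be all zeros.
def poly.is_eq : poly → poly → bool
| [] g := poly.is_zero g
| f [] := poly.is_zero f
| (a :: f') (b :: g') := to_bool (a = b) && poly.is_eq f' g'

end poly_sketch
```

With this version, comparing [1,2,3] with [1,2,3,0] matches the first three coefficients and then checks that the leftover [0] is zero. Making is_eq return a bool is a deliberate choice: it means an equality of concrete polynomials can later be checked by pure computation.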

We define evaluation of a polynomial at an element of a ring:


def poly.eval {α} [comm_ring α] (X : α) : poly → α
| [] := 0
| (n::l) := n + X * poly.eval l


and then have to prove all the lemmas saying that polynomial evaluation respects equality (in the sense above) and addition and multiplication. For example


@[simp] theorem poly.eval_mul {α} [comm_ring α] (X : α) : ∀ p₁ p₂ : poly,
(p₁.mul p₂).eval X = p₁.eval X * p₂.eval X := ...


We are now 3/4 of the way through the code, everything has been relatively straightforward, and there hasn't even been any meta stuff. Now it starts. We make an "abstract" polynomial ring (with a hint that we're going to be using reflection later on in the attribute).


@[derive has_reflect]
inductive ring_expr : Type
| add : ring_expr → ring_expr → ring_expr
| mul : ring_expr → ring_expr → ring_expr
| const : znum → ring_expr
| X : ring_expr


The problem with this abstract ring is that $(X+1)*(X+1)$ and $X*X+2*X+1$ are different terms, whereas they are both the same list [1,2,1] in the concrete polynomial ring. We define an evaluation from the abstract to the concrete polynomial ring:


def to_poly : ring_expr → poly
| (ring_expr.add e₁ e₂) := (to_poly e₁).add (to_poly e₂)
| (ring_expr.mul e₁ e₂) := (to_poly e₁).mul (to_poly e₂)
| (ring_expr.const z) := poly.const z
| ring_expr.X := poly.X


and then define evaluation of a polynomial in the abstract ring and prove it agrees with evaluation on the concrete ring. This is all relatively straightforward (and probably a good exercise for learners!).
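For readers who want to try the exercise, here is a sketch of what evaluation on the abstract ring might look like. The namespace eval_sketch and the import are assumptions of mine, and ring_expr is repeated (without the derive attribute) so that the block stands alone; the real definition in the linked code may differ in detail.

```lean
import data.num.lemmas -- assumed import: znum together with its cast into a ring

namespace eval_sketch -- hypothetical namespace, for illustration only

-- Same constructors as the ring_expr type above.
inductive ring_expr : Type
| add : ring_expr → ring_expr → ring_expr
| mul : ring_expr → ring_expr → ring_expr
| const : znum → ring_expr
| X : ring_expr

-- Interpret an abstract expression in any commutative ring,
-- sending the formal variable to the given element X.
def ring_expr.eval {α} [comm_ring α] (X : α) : ring_expr → α
| (ring_expr.add e₁ e₂) := ring_expr.eval e₁ + ring_expr.eval e₂
| (ring_expr.mul e₁ e₂) := ring_expr.eval e₁ * ring_expr.eval e₂
| (ring_expr.const z) := (z : α) -- cast the znum coefficient into α
| ring_expr.X := X

end eval_sketch
```

The compatibility lemma (called to_poly_eval in the proof below) then says that evaluating to_poly e as a concrete polynomial at X gives the same answer as evaluating e directly, and it is proved by induction on e using the lemmas about poly.eval.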

The big theorem: if the concrete polynomials are equal, then the abstract ones, evaluated at the same value, return the same value.


theorem main_thm {α} [comm_ring α] (X : α) (e₁ e₂) {x₁ x₂}
(H : poly.is_eq (to_poly e₁) (to_poly e₂)) (R1 : e₁.eval X = x₁) (R2 : e₂.eval X = x₂) : x₁ = x₂ :=
by rw [← R1, ← R2, ← to_poly_eval,poly.eval_is_eq X H, to_poly_eval]



We are nearly there, and the only intimidating line of code so far was the @[derive has_reflect] line before the abstract definition. But now we go one more stage of abstraction: we define a reflection map, sending an expr (like (d+1)^2) to an abstract polynomial (like (X+1)^2). This is precisely the part we cannot do unless we're in meta-land. Note that this code will fail if the expr does not represent a polynomial, but that's OK because it's a meta definition.


meta def reflect_expr (X : expr) : expr → option ring_expr
| `(%%e₁ + %%e₂) := do
  p₁ ← reflect_expr e₁,
  p₂ ← reflect_expr e₂,
  return (ring_expr.add p₁ p₂)
| `(%%e₁ * %%e₂) := do
  p₁ ← reflect_expr e₁,
  p₂ ← reflect_expr e₂,
  return (ring_expr.mul p₁ p₂)
| e := if e = X then return ring_expr.X else
  do n ← expr.to_int e, -- anything else must be an integer numeral
     return (ring_expr.const (znum.of_int' n))


After those few lines of meta, we are actually ready to define the tactic!


meta def ring_tac (X : pexpr) : tactic unit := do
  X ← to_expr X,
  `(%%x₁ = %%x₂) ← target, -- the goal must be an equation
  r₁ ← reflect_expr X x₁,
  r₂ ← reflect_expr X x₂,
  let e₁ : expr := reflect r₁,
  let e₂ : expr := reflect r₂,
  `[refine main_thm %%X %%e₁ %%e₂ rfl _ _],
  all_goals `[simp only [ring_expr.eval,
    znum.cast_pos, znum.cast_neg, znum.cast_zero',
    pos_num.cast_bit0, pos_num.cast_bit1,
    pos_num.cast_one']]


The code takes in a variable x, then assumes that the goal is of the form f(x)=g(x), attempts to build two terms f(X) and g(X) in the abstract ring, proves the corresponding lists of coefficients in the concrete ring are equal using rfl, and deduces that the abstract polynomials evaluated at x are equal. If something doesn't work, it fails, and in practice this just means the tactic fails.

And it works!


example (x : ℤ) : (x + 1) * (x + 1) = x*x+2*x+1 := by do ring_tac ``(x)

example (x : ℤ) : (x + 1) * (x + 1) * (x + 1) = x*x*x+3*x*x+3*x+1 := by do ring_tac ``(x)

example (x : ℤ) : (x + 1) + ((-1)*x + 1) = 2 := by do ring_tac ``(x)


The last example would fail if we hadn't used the correct version of equality for the lists, because [2,0] and [2] are different lists -- but the corresponding polynomials are equal.

And that's how to write a ring tactic! Many thanks to Mario Carneiro for doing much of the planning and also writing the meta code.