Skip to content

Starter Exercises for Learning Coq

jadephilipoom edited this page Aug 29, 2019 · 1 revision

Exercises at two difficulty levels for trying out Coq. Each code block runs as a standalone Coq file (tested in versions 8.6 and 8.9.1). Both difficulty levels have two solutions: one using only a small set of simple features, and the other using proof automation.

These exercises were designed to go with the Quick Reference for Beginners.

Easier exercises:

Require Import Coq.ZArith.ZArith.
Require Import Coq.micromega.Lia.

(*** Easy exercises ***)

(* All of the exercises here are about writing equivalent functions for
   natural numbers and for integers, and proving that they're equivalent. *)

(*** Part 1 ***)

(* [foo] over natural numbers (nat): computes (x - y) + z. Subtraction on nat
   truncates at 0, so [x - y] is 0 whenever y > x. *)
Definition foo_nat (x y z : nat) : nat := (x - y) + z.

(* [foo] over integers (Z): computes (x - y) + z with ordinary integer
   subtraction, which may produce a negative intermediate value. *)
Definition foo_int (x y z : Z) : Z := (x - y) + z.

(* Now, prove that the two functions are equivalent; converting the result
   from [foo_nat] to an integer is the same as converting the input to integers
   and running [foo_int]. *)
Lemma foo_equiv :
  forall (x y z : nat),
    (* Precondition: nat subtraction truncates at 0, so [x - y] only agrees
       with integer subtraction when [y <= x]. *)
    y <= x ->
    Z.of_nat (foo_nat x y z) = foo_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
Proof.
  intros.
  (* Unfold both definitions so the goal mentions [+] and [-] directly. *)
  cbv [foo_nat foo_int].
  (* Exercise : Fill in the proof! *)

  (* Hint: try using these [Search] commands to find lemmas in the standard
     library. *)
  Search (Z.of_nat (_ + _)).
  Search (Z.of_nat (_ - _)).

Abort. (* [Abort] exits an incomplete proof; when you're done, write [Qed] instead. *)

(*** Part 2 ***)

(* [bar] over natural numbers (nat): computes (x + y) - z. The subtraction
   truncates at 0 when z exceeds x + y. *)
Definition bar_nat (x y z : nat) : nat := (x + y) - z.

(* [bar] over integers (Z): computes (x + y) - z. *)
Definition bar_int (x y z : Z) : Z := (x + y) - z.

(* Prove that [bar_nat] and [bar_int] agree, in the same sense as
   [foo_equiv]. *)
Lemma bar_equiv :
  forall (x y z : nat),
    (* Exercise : Like [foo_equiv], this lemma will need a precondition! Figure
       out what it is. *)
    Z.of_nat (bar_nat x y z) = bar_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
Proof.
  intros.
  (* Unfold both definitions so the goal mentions [+] and [-] directly. *)
  cbv [bar_nat bar_int].
  (* Exercise : Fill in the proof! *)

Abort. (* [Abort] exits an incomplete proof; when you're done, write [Qed] instead. *)

(*** Part 3 ***)

(* [baz] over natural numbers (nat): x + (y - x) / (z + x / y). Division binds
   tighter than addition, so no extra parentheses are needed. *)
Definition baz_nat (x y z : nat) : nat := x + (y - x) / (z + x / y).

(* [baz] over integers (Z), written with the same expression. *)
Definition baz_int (x y z : Z) : Z := x + (y - x) / (z + x / y).

(* Prove that [baz_nat] and [baz_int] agree. Besides truncating subtraction,
   this one also requires reasoning about division. *)
Lemma baz_equiv :
  forall (x y z : nat),
    (* Exercise : fill in the preconditions (hint: there are 3) *)
    Z.of_nat (baz_nat x y z) = baz_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
Proof.
  intros.
  (* Unfold both definitions so the goal mentions the arithmetic directly. *)
  cbv [baz_nat baz_int].
  (* Exercise : Fill in the proof! *)

Abort. (* [Abort] exits an incomplete proof; when you're done, write [Qed] instead. *)

Harder exercises:

Require Import Coq.Arith.PeanoNat.
Require Import Coq.micromega.Lia.

(**** Setup (feel free to skip this section) ****)

(* These proofs should probably be in the standard library, but they're not *)
Module Nat.
  (* [S n - 1] simplifies to [n]; no truncation happens because [S n >= 1]. *)
  Lemma succ_sub_1_r n : S n - 1 = n.
  Proof. lia. Qed.

  (* An even number leaves remainder 0 when divided by 2. *)
  Lemma Even_mod2 :
    forall n : nat,
      Nat.Even n ->
      n mod 2 = 0.
  Proof.
    (* Unfold [Nat.Even] to obtain some [k] with [n = 2 * k]. *)
    cbv [Nat.Even]; destruct 1; intros.
    subst.
    (* Reduce the remainder of [2 * k] step by step: split the remainder of
       the product, use [2 mod 2 = 0], and simplify the leftover [0 mod 2]. *)
    rewrite Nat.mul_mod by lia.
    rewrite Nat.mod_same by lia.
    rewrite Nat.mul_0_l.
    rewrite Nat.mod_0_l by lia.
    reflexivity.
  Qed.

  (* A product is even as soon as either factor is even. *)
  Lemma Even_mul :
    forall n m : nat,
      Nat.Even n \/ Nat.Even m ->
      Nat.Even (n * m).
  Proof.
    (* Destruct the disjunction and the evenness witness, then give the
       witness for the product explicitly; [lia] checks the equation. *)
    repeat match goal with
           | _ => progress (intros; subst)
           | H : _ \/ _ |- _ => destruct H
           | H : Nat.Even _ |- _ => destruct H
           | |- Nat.Even (2 * ?a * ?b) => exists (a * b); lia
           | |- Nat.Even (?a * (2 * ?b)) => exists (a * b); lia
           end.
  Qed.

  (* If [n] is odd, then [n - 1] is even. *)
  Lemma Odd_minus1_Even :
    forall n : nat,
      Nat.Odd n ->
      Nat.Even (n - 1).
  Proof.
    (* Case [0]: [0 - 1] truncates to [0], which is even with witness [0].
       Case [S n]: [S n - 1] is [n], which is even because [S n] is odd. *)
    destruct n; intros; [ exists 0; reflexivity | ].
    rewrite succ_sub_1_r.
    apply Nat.Odd_succ; auto.
  Qed.

  (* When [m] divides [n] exactly, multiplying the quotient back by [m]
     recovers [n]. *)
  Lemma mul_div_eq :
    forall n m : nat,
      m <> 0 ->
      n mod m = 0 ->
      m * (n / m) = n.
  Proof.
    intros.
    (* [Nat.div_exact] relates exact division to a zero remainder; [proj2]
       takes the direction that starts from [n mod m = 0]. *)
    pose proof proj2 (Nat.div_exact n m ltac:(lia)).
    lia.
  Qed.
End Nat.

(* [push_mod]: rewrite lemmas that simplify [mod] expressions. Their side
   conditions (such as a nonzero divisor) are discharged with [lia]. *)
Hint Rewrite Nat.mod_add Nat.mod_mul Nat.mod_same Nat.mod_mod Nat.mod_0_l
     using lia : push_mod.
(* [push_mul]: distribute multiplication over addition and drop
   multiplications by 0 and 1. *)
Hint Rewrite Nat.mul_add_distr_l Nat.mul_add_distr_r Nat.mul_0_l Nat.mul_0_r
     Nat.mul_1_l Nat.mul_1_r : push_mul.

(* Reduces goals of the form [_ mod 2 = 0] and [Nat.Even _] using the helper
   lemmas in [Nat]. On each iteration the branches are tried in order, and
   [repeat] stops once none of them succeeds. *)
Ltac simplify_Even :=
    repeat match goal with
           | _ => apply Nat.Even_mod2
           | _ => apply Nat.Even_mul
           | _ => apply Nat.Odd_minus1_Even
           (* [n] is odd, so it is the other factor that must be even. *)
           | H : Nat.Odd ?n |- Nat.Even ?n \/ _ => right
           (* [n] is odd, so [n - 1] is the even factor. *)
           | H : Nat.Odd ?n |- Nat.Even (?n - 1) \/ _ => left
           | _ => tauto
           end.

(**** End of setup ***)

(*** Hard exercises ***)

(* These exercises involve proving formulas for the sums of series. *)

(*** Part 1 ***)

(* In part 1, we'll prove that the sum of the first n odd numbers is n ^ 2. *)

(* First, we write a function that adds together the first n odd numbers. *)
Fixpoint sum_of_first_n_odds (n : nat) : nat :=
  match n with
  | O => 0
  | S m =>
    (* nth odd # is (2 * n - 1); we add it to the sum of the first (n-1) odds *)
    (* In this branch [n = S m >= 1], so [2 * n - 1] never truncates. *)
    sum_of_first_n_odds m + (2 * n - 1)
  end.

(* Let's look at some output! *)
Eval compute in (sum_of_first_n_odds 3). (* Expected result: 9 (since 1 + 3 + 5 = 9) *)
Eval compute in (sum_of_first_n_odds 6). (* Expected result: 36 *)
Eval compute in (sum_of_first_n_odds 20). (* Expected result: 400 *)
Eval compute in (sum_of_first_n_odds 0). (* Expected result: 0 *)

(* Now, we prove that the sum of the first n odd numbers is equal to n^2 *)
Lemma sum_of_first_n_odds_square :
  forall n : nat,
    sum_of_first_n_odds n = n * n.
Proof.
  (* Induction on [n] produces two goals, handled in the two braces below. *)
  induction n.
  { (* inductive base case; n = 0 *)
    (* Exercise: Fill in the proof! *)

    admit. (* [admit] gives up on a subgoal; remove it to prove this case. *) }
  { (* inductive case; n = S m *)

    (* [cbn] inlines functions where they can be simplified; see the reference
       for details. In this case, it performs one round of the match statement
       inside [sum_of_first_n_odds]. *)
    cbn [sum_of_first_n_odds].
    (* Exercise: Fill in the proof! *)

Abort. (* [Abort] exits an incomplete proof; when you're done, write [Qed] instead. *)

(*** Part 2 ***)

(* In part 2, we'll prove the formula for an arithmetic series (that is, the sum
   of a sequence of numbers that increase by a constant amount between each pair
   of adjacent numbers, like [5 + 7 + 9] or [12 + 16 + 20 + 24]). For "informal"
   proofs, see the following:

   MathWorld : http://mathworld.wolfram.com/ArithmeticSeries.html
   Wikipedia : https://en.wikipedia.org/wiki/Arithmetic_progression#Sum
*)

(* represents \sum_{i=0}^{length - 1} (start + (i * inc)), i.e. the sum of the
   first [length] terms of the sequence; therefore

       arithmetic_series 1 0 3 = 0 + 1 + 2 = 3
       arithmetic_series 2 5 4 = 5 + 7 + 9 + 11 = 32 *)
Fixpoint arithmetic_series (inc start length : nat) : nat :=
  match length with
  | O => 0
  | S length' =>
    (* the current term, plus the remaining [length'] terms, which begin one
       increment later *)
    start + arithmetic_series inc (start + inc) length'
  end.

(* Check the output. *)
Eval compute in (arithmetic_series 1 0 3). (* Expected result : 3 (= 0 + 1 + 2) *)
Eval compute in (arithmetic_series 2 5 4). (* Expected result : 32 (= 5 + 7 + 9 + 11) *)
Eval compute in (arithmetic_series 0 1 3). (* Expected result : 3 (= 1 + 1 + 1) *)
Eval compute in (arithmetic_series 5 6 0). (* Expected result : 0 *)

(* The last term of an arithmetic sequence with [length] terms: the first term
   [start] plus [length - 1] increments of [inc]. Because nat subtraction
   truncates, a length of 0 yields [start]. Naming it keeps later formulas
   readable. *)
Definition last_element (inc start length : nat) : nat :=
  start + ((length - 1) * inc).

(* Hint: This lemma just might come in handy; given that Coq's / is a floor
   division, you'll have to prove along the way that the division by 2 in the
   arithmetic series formula is okay to replace with a floor division. *)
Lemma length_times_start_plus_last_even :
  forall inc start length,
    (length * (start + last_element inc start length)) mod 2 = 0.
Proof.
  (* This proof uses some advanced techniques because I don't like annoying
     algebra; don't worry about understanding it at this point. *)
  cbv [last_element]; intros.
  (* Normalize the goal: distribute multiplication, simplify [mod], and
     regroup sums so that multiples of the modulus become visible to the
     [push_mod] lemmas. *)
  repeat match goal with
         | _ => progress autorewrite with push_mul push_mod
         (* regroup [x + (x + y)] as [2 * x + y] *)
         | |- context [?x + (?x + ?y)] =>
           replace (x + (x + y)) with (2 * x + y) by lia
         (* move the multiple of [x] to the right, presumably so that a
            [push_mod] lemma such as [Nat.mod_add] can cancel it *)
         | |- context [(?x * ?y + ?z) mod ?x] =>
           replace (x * y + z) with (z + y * x) by lia
         end.
  (* Finish by case analysis on the parity of [length]. *)
  destruct (Nat.Even_or_Odd length); simplify_Even.
Qed.

(* Hint: This lemma might also be useful; it allows you to multiply both sides of
   an equation by a divisor (c) to eliminate divisions. *)
Lemma eliminate_div_by_mul :
  forall a b c d,
    c <> 0 ->
    b mod c = 0 -> (* required because / is a floor division*)
    d mod c = 0 -> (* required because / is a floor division*)
    c * a + b = d ->
    a + b / c = d / c.
Proof.
  intros.
  (* Multiply both sides by [c]; this step is reversible because [c <> 0]. *)
  apply Nat.mul_cancel_l with (p:=c); [ lia | ].
  autorewrite with push_mul.
  (* Each [c * (_ / c)] cancels because its dividend is a multiple of [c]. *)
  rewrite !Nat.mul_div_eq by lia.
  (* What remains follows from the hypothesis [c * a + b = d]. *)
  lia.
Qed.

(* Prove the formula correct. *)
Lemma arithmetic_series_equiv :
  forall inc start length : nat,
    (* closed form: number of terms times (first term + last term), halved *)
    arithmetic_series inc start length = length * (start + last_element inc start length) / 2.
Proof.
  cbv [last_element]. (* inline the formula for the last element *)
  induction length.
  (* Exercise: Fill in the proof! *)

  (* Hint : As written, this lemma's inductive hypothesis is too "weak" (meaning
     you won't be able to [rewrite] with it in the inductive case). Work until
     you get stuck, and then read the reference entry on [induction] for guidance
     on how to fix it. *)

Abort. (* [Abort] exits an incomplete proof; when you're done, write [Qed] instead. *)

Easy exercises solution #1 (non-automated)

Require Import Coq.ZArith.ZArith.
Require Import Coq.micromega.Lia.

(*** Easy exercises ***)

(* All of the exercises here are about writing equivalent functions for
   natural numbers and for integers, and proving that they're equivalent. *)

(*** Part 1 ***)

(* [foo] over natural numbers (nat): computes (x - y) + z. Subtraction on nat
   truncates at 0, so [x - y] is 0 whenever y > x. *)
Definition foo_nat (x y z : nat) : nat := (x - y) + z.

(* [foo] over integers (Z): computes (x - y) + z with ordinary integer
   subtraction, which may produce a negative intermediate value. *)
Definition foo_int (x y z : Z) : Z := (x - y) + z.

(* Now, prove that the two functions are equivalent; converting the result
   from [foo_nat] to an integer is the same as converting the input to integers
   and running [foo_int]. *)
Lemma foo_equiv :
  forall (x y z : nat),
    (* nat subtraction truncates at 0, so [x - y] needs [y <= x] to match Z *)
    y <= x ->
    Z.of_nat (foo_nat x y z) = foo_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
Proof.
  intros.
  cbv [foo_nat foo_int].
  (* Push [Z.of_nat] through the outer addition... *)
  rewrite Nat2Z.inj_add.
  (* ...and then through the subtraction. Its side condition is the
     hypothesis [y <= x], which [lia] finds in the context. *)
  rewrite Nat2Z.inj_sub by lia.
  (* Both sides are now syntactically equal. *)
  reflexivity.
Qed.

(*** Part 2 ***)

(* [bar] over natural numbers (nat): computes (x + y) - z. The subtraction
   truncates at 0 when z exceeds x + y. *)
Definition bar_nat (x y z : nat) : nat := (x + y) - z.

(* [bar] over integers (Z): computes (x + y) - z. *)
Definition bar_int (x y z : Z) : Z := (x + y) - z.

(* [bar_nat] and [bar_int] agree whenever the nat subtraction does not
   truncate. *)
Lemma bar_equiv :
  forall (x y z : nat),
    z <= x + y ->
    Z.of_nat (bar_nat x y z) = bar_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
Proof.
  intros.
  cbv [bar_nat bar_int].
  (* The outermost operation is the subtraction; its side condition
     [z <= x + y] is the hypothesis, found by [lia]. *)
  rewrite Nat2Z.inj_sub by lia.
  (* Then push [Z.of_nat] through the inner addition. *)
  rewrite Nat2Z.inj_add.
  reflexivity.
Qed.

(*** Part 3 ***)

(* [baz] over natural numbers (nat): x + (y - x) / (z + x / y). Division binds
   tighter than addition, so no extra parentheses are needed. *)
Definition baz_nat (x y z : nat) : nat := x + (y - x) / (z + x / y).

(* [baz] over integers (Z), written with the same expression. *)
Definition baz_int (x y z : Z) : Z := x + (y - x) / (z + x / y).

(* [baz_nat] and [baz_int] agree under three preconditions. [y <> 0] and
   [z <> 0] make both divisors nonzero, as required by the [div_Zdiv]
   rewrites. [x <= y] keeps [y - x] from truncating. *)
Lemma baz_equiv :
  forall (x y z : nat),
    y <> 0 -> z <> 0 -> x <= y ->
    Z.of_nat (baz_nat x y z) = baz_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
Proof.
  intros.
  cbv [baz_nat baz_int].
  (* Push [Z.of_nat] inward one operation at a time, from the outside in:
     the outer [x + _], the outer division, [y - x] in the numerator,
     [z + x / y] in the divisor, and finally the inner [x / y]. Each side
     condition is discharged by [lia] from the hypotheses. *)
  rewrite Nat2Z.inj_add.
  rewrite div_Zdiv by lia.
  rewrite Nat2Z.inj_sub by lia.
  rewrite Nat2Z.inj_add.
  rewrite div_Zdiv by lia.
  reflexivity.
Qed.

Easy exercises solution #2 (automated)


Require Import Coq.ZArith.ZArith.
Require Import Coq.micromega.Lia.

(*** Easy exercises ***)

(* All of the exercises here are about writing equivalent functions for
   natural numbers and for integers, and proving that they're equivalent. *)

(* After this command, writing [autorewrite with push_nat2z] will be the same as
   repeatedly rewriting with all the lemmas listed, as long as their
   preconditions can be solved by [lia]. *)
Hint Rewrite Nat2Z.inj_add Nat2Z.inj_sub div_Zdiv using lia : push_nat2z.

(* This tactic will try all three cases in order and keep going as long as at
   least one succeeds. The [progress] prefixes mean "fail if this doesn't make
   the goal simpler"; without them, the tactic could get stuck in a loop that
   doesn't make progress. *)
(* It is used after unfolding the definitions with [cbv]. Because [intros]
   runs first, the lemma's preconditions land in the context, where [lia] can
   use them to discharge the rewrite side conditions. *)
Ltac crush :=
  repeat first [ progress intros
               | progress autorewrite with push_nat2z
               | reflexivity ].

(*** Part 1 ***)

(* [foo] over natural numbers (nat): computes (x - y) + z. Subtraction on nat
   truncates at 0, so [x - y] is 0 whenever y > x. *)
Definition foo_nat (x y z : nat) : nat := (x - y) + z.

(* [foo] over integers (Z): computes (x - y) + z with ordinary integer
   subtraction, which may produce a negative intermediate value. *)
Definition foo_int (x y z : Z) : Z := (x - y) + z.

(* Now, prove that the two functions are equivalent; converting the result
   from [foo_nat] to an integer is the same as converting the input to integers
   and running [foo_int]. *)
Lemma foo_equiv :
  forall (x y z : nat),
    y <= x ->
    Z.of_nat (foo_nat x y z) = foo_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
(* Unfold both definitions, then let [crush] push [Z.of_nat] inward until both
   sides match. *)
Proof. cbv [foo_nat foo_int]. crush. Qed.

(*** Part 2 ***)

(* [bar] over natural numbers (nat): computes (x + y) - z. The subtraction
   truncates at 0 when z exceeds x + y. *)
Definition bar_nat (x y z : nat) : nat := (x + y) - z.

(* [bar] over integers (Z): computes (x + y) - z. *)
Definition bar_int (x y z : Z) : Z := (x + y) - z.

(* [bar_nat] and [bar_int] agree whenever the nat subtraction does not
   truncate. *)
Lemma bar_equiv :
  forall (x y z : nat),
    z <= x + y ->
    Z.of_nat (bar_nat x y z) = bar_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
(* Unfold both definitions, then let [crush] push [Z.of_nat] inward. *)
Proof. cbv [bar_nat bar_int]. crush. Qed.

(*** Part 3 ***)

(* [baz] over natural numbers (nat): x + (y - x) / (z + x / y). Division binds
   tighter than addition, so no extra parentheses are needed. *)
Definition baz_nat (x y z : nat) : nat := x + (y - x) / (z + x / y).

(* [baz] over integers (Z), written with the same expression. *)
Definition baz_int (x y z : Z) : Z := x + (y - x) / (z + x / y).

(* [baz_nat] and [baz_int] agree when both divisors are nonzero and [y - x]
   does not truncate. *)
Lemma baz_equiv :
  forall (x y z : nat),
    y <> 0 -> z <> 0 -> x <= y ->
    Z.of_nat (baz_nat x y z) = baz_int (Z.of_nat x) (Z.of_nat y) (Z.of_nat z).
(* Same proof as for [foo] and [bar]: [crush] also knows [div_Zdiv]. *)
Proof. cbv [baz_nat baz_int]. crush. Qed.

(* The really beautiful thing about automating proofs is that, unlike using a
   step-by-step strategy, automated proofs often don't depend on the exact
   interior structure of your definitions. So if you change things, your proofs
   won't break, as long as they use the same basic reasoning! *)

Hard exercises solution #1 (non-automated):

Require Import Coq.Arith.PeanoNat.
Require Import Coq.micromega.Lia.

(**** Setup (feel free to skip this section) ****)

(* These proofs should probably be in the standard library, but they're not *)
Module Nat.
  (* [S n - 1] simplifies to [n]; no truncation happens because [S n >= 1]. *)
  Lemma succ_sub_1_r n : S n - 1 = n.
  Proof. lia. Qed.

  (* An even number leaves remainder 0 when divided by 2. *)
  Lemma Even_mod2 :
    forall n : nat,
      Nat.Even n ->
      n mod 2 = 0.
  Proof.
    (* Unfold [Nat.Even] to obtain some [k] with [n = 2 * k]. *)
    cbv [Nat.Even]; destruct 1; intros.
    subst.
    (* Reduce the remainder of [2 * k] step by step: split the remainder of
       the product, use [2 mod 2 = 0], and simplify the leftover [0 mod 2]. *)
    rewrite Nat.mul_mod by lia.
    rewrite Nat.mod_same by lia.
    rewrite Nat.mul_0_l.
    rewrite Nat.mod_0_l by lia.
    reflexivity.
  Qed.

  (* A product is even as soon as either factor is even. *)
  Lemma Even_mul :
    forall n m : nat,
      Nat.Even n \/ Nat.Even m ->
      Nat.Even (n * m).
  Proof.
    (* Destruct the disjunction and the evenness witness, then give the
       witness for the product explicitly; [lia] checks the equation. *)
    repeat match goal with
           | _ => progress (intros; subst)
           | H : _ \/ _ |- _ => destruct H
           | H : Nat.Even _ |- _ => destruct H
           | |- Nat.Even (2 * ?a * ?b) => exists (a * b); lia
           | |- Nat.Even (?a * (2 * ?b)) => exists (a * b); lia
           end.
  Qed.

  (* If [n] is odd, then [n - 1] is even. *)
  Lemma Odd_minus1_Even :
    forall n : nat,
      Nat.Odd n ->
      Nat.Even (n - 1).
  Proof.
    (* Case [0]: [0 - 1] truncates to [0], which is even with witness [0].
       Case [S n]: [S n - 1] is [n], which is even because [S n] is odd. *)
    destruct n; intros; [ exists 0; reflexivity | ].
    rewrite succ_sub_1_r.
    apply Nat.Odd_succ; auto.
  Qed.

  (* When [m] divides [n] exactly, multiplying the quotient back by [m]
     recovers [n]. *)
  Lemma mul_div_eq :
    forall n m : nat,
      m <> 0 ->
      n mod m = 0 ->
      m * (n / m) = n.
  Proof.
    intros.
    (* [Nat.div_exact] relates exact division to a zero remainder; [proj2]
       takes the direction that starts from [n mod m = 0]. *)
    pose proof proj2 (Nat.div_exact n m ltac:(lia)).
    lia.
  Qed.
End Nat.

(* [push_mod]: rewrite lemmas that simplify [mod] expressions. Their side
   conditions (such as a nonzero divisor) are discharged with [lia]. *)
Hint Rewrite Nat.mod_add Nat.mod_mul Nat.mod_same Nat.mod_mod Nat.mod_0_l
     using lia : push_mod.
(* [push_mul]: distribute multiplication over addition and drop
   multiplications by 0 and 1. *)
Hint Rewrite Nat.mul_add_distr_l Nat.mul_add_distr_r Nat.mul_0_l Nat.mul_0_r
     Nat.mul_1_l Nat.mul_1_r : push_mul.

(* Reduces goals of the form [_ mod 2 = 0] and [Nat.Even _] using the helper
   lemmas in [Nat]. On each iteration the branches are tried in order, and
   [repeat] stops once none of them succeeds. *)
Ltac simplify_Even :=
    repeat match goal with
           | _ => apply Nat.Even_mod2
           | _ => apply Nat.Even_mul
           | _ => apply Nat.Odd_minus1_Even
           (* [n] is odd, so it is the other factor that must be even. *)
           | H : Nat.Odd ?n |- Nat.Even ?n \/ _ => right
           (* [n] is odd, so [n - 1] is the even factor. *)
           | H : Nat.Odd ?n |- Nat.Even (?n - 1) \/ _ => left
           | _ => tauto
           end.

(**** End of setup ***)

(*** Hard exercises ***)

(* These exercises involve proving formulas for the sums of series. *)

(*** Part 1 ***)

(* In part 1, we'll prove that the sum of the first n odd numbers is n ^ 2. *)

(* First, we write a function that adds together the first n odd numbers. *)
Fixpoint sum_of_first_n_odds (n : nat) : nat :=
  match n with
  | O => 0
  | S m =>
    (* nth odd # is (2 * n - 1); we add it to the sum of the first (n-1) odds *)
    (* In this branch [n = S m >= 1], so [2 * n - 1] never truncates. *)
    sum_of_first_n_odds m + (2 * n - 1)
  end.

(* Let's look at some output! *)
Eval compute in (sum_of_first_n_odds 3). (* Expected result: 9 (since 1 + 3 + 5 = 9) *)
Eval compute in (sum_of_first_n_odds 6). (* Expected result: 36 *)
Eval compute in (sum_of_first_n_odds 20). (* Expected result: 400 *)
Eval compute in (sum_of_first_n_odds 0). (* Expected result: 0 *)

(* Now, we prove that the sum of the first n odd numbers is equal to n^2 *)
Lemma sum_of_first_n_odds_square :
  forall n : nat,
    sum_of_first_n_odds n = n * n.
Proof.
  induction n.
  { (* base case: both sides compute to 0 *)
    reflexivity. }
  { (* inductive case: unfold one step of the definition, replace the sum of
       the first [n] odds using the induction hypothesis, and let [lia] check
       the remaining arithmetic *)
    cbn [sum_of_first_n_odds].
    rewrite IHn.
    lia. }
Qed.

(*** Part 2 ***)

(* In part 2, we'll prove the formula for an arithmetic series (that is, the sum
   of a sequence of numbers that increase by a constant amount between each pair
   of adjacent numbers, like [5 + 7 + 9] or [12 + 16 + 20 + 24]). For "informal"
   proofs, see the following:

   MathWorld : http://mathworld.wolfram.com/ArithmeticSeries.html
   Wikipedia : https://en.wikipedia.org/wiki/Arithmetic_progression#Sum
*)

(* represents \sum_{i=0}^{length - 1} (start + (i * inc)), i.e. the sum of the
   first [length] terms of the sequence; therefore

       arithmetic_series 1 0 3 = 0 + 1 + 2 = 3
       arithmetic_series 2 5 4 = 5 + 7 + 9 + 11 = 32 *)
Fixpoint arithmetic_series (inc start length : nat) : nat :=
  match length with
  | O => 0
  | S length' =>
    (* the current term, plus the remaining [length'] terms, which begin one
       increment later *)
    start + arithmetic_series inc (start + inc) length'
  end.

(* Check the output. *)
Eval compute in (arithmetic_series 1 0 3). (* Expected result : 3 (= 0 + 1 + 2) *)
Eval compute in (arithmetic_series 2 5 4). (* Expected result : 32 (= 5 + 7 + 9 + 11) *)
Eval compute in (arithmetic_series 0 1 3). (* Expected result : 3 (= 1 + 1 + 1) *)
Eval compute in (arithmetic_series 5 6 0). (* Expected result : 0 *)

(* The last term of an arithmetic sequence with [length] terms: the first term
   [start] plus [length - 1] increments of [inc]. Because nat subtraction
   truncates, a length of 0 yields [start]. Naming it keeps later formulas
   readable. *)
Definition last_element (inc start length : nat) : nat :=
  start + ((length - 1) * inc).

(* Hint: This lemma just might come in handy; given that Coq's / is a floor
   division, you'll have to prove along the way that the division by 2 in the
   arithmetic series formula is okay to replace with a floor division. *)
Lemma length_times_start_plus_last_even :
  forall inc start length,
    (length * (start + last_element inc start length)) mod 2 = 0.
Proof.
  (* This proof uses some advanced techniques because I don't like annoying
     algebra; don't worry about understanding it at this point. *)
  cbv [last_element]; intros.
  (* Normalize the goal: distribute multiplication, simplify [mod], and
     regroup sums so that multiples of the modulus become visible to the
     [push_mod] lemmas. *)
  repeat match goal with
         | _ => progress autorewrite with push_mul push_mod
         (* regroup [x + (x + y)] as [2 * x + y] *)
         | |- context [?x + (?x + ?y)] =>
           replace (x + (x + y)) with (2 * x + y) by lia
         (* move the multiple of [x] to the right, presumably so that a
            [push_mod] lemma such as [Nat.mod_add] can cancel it *)
         | |- context [(?x * ?y + ?z) mod ?x] =>
           replace (x * y + z) with (z + y * x) by lia
         end.
  (* Finish by case analysis on the parity of [length]. *)
  destruct (Nat.Even_or_Odd length); simplify_Even.
Qed.

(* Hint: This lemma might also be useful; it allows you to multiply both sides of
   an equation by a divisor (c) to eliminate divisions. *)
Lemma eliminate_div_by_mul :
  forall a b c d,
    c <> 0 ->
    b mod c = 0 -> (* required because / is a floor division*)
    d mod c = 0 -> (* required because / is a floor division*)
    c * a + b = d ->
    a + b / c = d / c.
Proof.
  intros.
  (* Multiply both sides by [c]; this step is reversible because [c <> 0]. *)
  apply Nat.mul_cancel_l with (p:=c); [ lia | ].
  autorewrite with push_mul.
  (* Each [c * (_ / c)] cancels because its dividend is a multiple of [c]. *)
  rewrite !Nat.mul_div_eq by lia.
  (* What remains follows from the hypothesis [c * a + b = d]. *)
  lia.
Qed.

(* Prove the formula correct. *)
Lemma arithmetic_series_equiv :
  forall inc length start : nat, (* Changing the order of [length] and [start]
                                    strengthens the inductive hypothesis. *)
    arithmetic_series inc start length = length * (start + last_element inc start length) / 2.
Proof.
  cbv [last_element]. (* inline the formula for the last element *)
  (* [start] comes after [length], so it stays universally quantified in the
     induction hypothesis; the IH then holds for every starting value. *)
  induction length.
  { (* base case: both sides compute to 0 *)
    reflexivity. }
  { intros.
    (* Unfold one step: the first term plus the series starting at
       [start + inc]. *)
    cbn [arithmetic_series].
    (* Use the IH at starting value [start + inc]. *)
    rewrite IHlength.
    (* Clear the divisions by 2. The four subgoals are: [2 <> 0], evenness of
       the rewritten numerator, evenness of the target numerator, and the
       remaining multiplied-out equation. *)
    apply eliminate_div_by_mul.
    { lia. }
    { apply length_times_start_plus_last_even. }
    { apply length_times_start_plus_last_even. }
    { nia. } }
Qed.

Hard exercises solution #2 (automated):

Require Import Coq.micromega.Lia.

(**** Setup (feel free to skip this section) ****)

(* These proofs should probably be in the standard library, but they're not *)
Module Nat.
  (* [S n - 1] simplifies to [n]; no truncation happens because [S n >= 1]. *)
  Lemma succ_sub_1_r n : S n - 1 = n.
  Proof. lia. Qed.

  (* An even number leaves remainder 0 when divided by 2. *)
  Lemma Even_mod2 :
    forall n : nat,
      Nat.Even n ->
      n mod 2 = 0.
  Proof.
    (* Unfold [Nat.Even] to obtain some [k] with [n = 2 * k]. *)
    cbv [Nat.Even]; destruct 1; intros.
    subst.
    (* Reduce the remainder of [2 * k] step by step: split the remainder of
       the product, use [2 mod 2 = 0], and simplify the leftover [0 mod 2]. *)
    rewrite Nat.mul_mod by lia.
    rewrite Nat.mod_same by lia.
    rewrite Nat.mul_0_l.
    rewrite Nat.mod_0_l by lia.
    reflexivity.
  Qed.

  (* A product is even as soon as either factor is even. *)
  Lemma Even_mul :
    forall n m : nat,
      Nat.Even n \/ Nat.Even m ->
      Nat.Even (n * m).
  Proof.
    (* Destruct the disjunction and the evenness witness, then give the
       witness for the product explicitly; [lia] checks the equation. *)
    repeat match goal with
           | _ => progress (intros; subst)
           | H : _ \/ _ |- _ => destruct H
           | H : Nat.Even _ |- _ => destruct H
           | |- Nat.Even (2 * ?a * ?b) => exists (a * b); lia
           | |- Nat.Even (?a * (2 * ?b)) => exists (a * b); lia
           end.
  Qed.

  (* If [n] is odd, then [n - 1] is even. *)
  Lemma Odd_minus1_Even :
    forall n : nat,
      Nat.Odd n ->
      Nat.Even (n - 1).
  Proof.
    (* Case [0]: [0 - 1] truncates to [0], which is even with witness [0].
       Case [S n]: [S n - 1] is [n], which is even because [S n] is odd. *)
    destruct n; intros; [ exists 0; reflexivity | ].
    rewrite succ_sub_1_r.
    apply Nat.Odd_succ; auto.
  Qed.

  (* When [m] divides [n] exactly, multiplying the quotient back by [m]
     recovers [n]. *)
  Lemma mul_div_eq :
    forall n m : nat,
      m <> 0 ->
      n mod m = 0 ->
      m * (n / m) = n.
  Proof.
    intros.
    (* [Nat.div_exact] relates exact division to a zero remainder; [proj2]
       takes the direction that starts from [n mod m = 0]. *)
    pose proof proj2 (Nat.div_exact n m ltac:(lia)).
    lia.
  Qed.
End Nat.

(* [push_mod]: rewrite lemmas that simplify [mod] expressions. Their side
   conditions (such as a nonzero divisor) are discharged with [lia]. *)
Hint Rewrite Nat.mod_add Nat.mod_mul Nat.mod_same Nat.mod_mod Nat.mod_0_l
     using lia : push_mod.
(* [push_mul]: distribute multiplication over addition and drop
   multiplications by 0 and 1. *)
Hint Rewrite Nat.mul_add_distr_l Nat.mul_add_distr_r Nat.mul_0_l Nat.mul_0_r
     Nat.mul_1_l Nat.mul_1_r : push_mul.

(* Reduces goals of the form [_ mod 2 = 0] and [Nat.Even _] using the helper
   lemmas in [Nat]. On each iteration the branches are tried in order, and
   [repeat] stops once none of them succeeds. *)
Ltac simplify_Even :=
    repeat match goal with
           | _ => apply Nat.Even_mod2
           | _ => apply Nat.Even_mul
           | _ => apply Nat.Odd_minus1_Even
           (* [n] is odd, so it is the other factor that must be even. *)
           | H : Nat.Odd ?n |- Nat.Even ?n \/ _ => right
           (* [n] is odd, so [n - 1] is the even factor. *)
           | H : Nat.Odd ?n |- Nat.Even (?n - 1) \/ _ => left
           | _ => tauto
           end.

(**** End of setup ***)

(*** Hard exercises ***)

(* These exercises involve proving formulas for the sums of series. *)

(* Define a tactic to solve goals in common ways *)
(* Each alternative either closes the goal completely or fails, so [first]
   uses the first one that works: computation ([reflexivity]), then linear
   arithmetic ([lia]), then nonlinear arithmetic ([nia]). *)
Ltac solver :=
  first [ reflexivity
        | lia
        | nia ].

(*** Part 1 ***)

(* In part 1, we'll prove that the sum of the first n odd numbers is n ^ 2. *)

(* First, we write a function that adds together the first n odd numbers. *)
Fixpoint sum_of_first_n_odds (n : nat) : nat :=
  match n with
  | O => 0
  | S m =>
    (* nth odd # is (2 * n - 1); we add it to the sum of the first (n-1) odds *)
    (* In this branch [n = S m >= 1], so [2 * n - 1] never truncates. *)
    sum_of_first_n_odds m + (2 * n - 1)
  end.

(* Let's look at some output! *)
Eval compute in (sum_of_first_n_odds 3). (* Expected result: 9 (since 1 + 3 + 5 = 9) *)
Eval compute in (sum_of_first_n_odds 6). (* Expected result: 36 *)
Eval compute in (sum_of_first_n_odds 20). (* Expected result: 400 *)
Eval compute in (sum_of_first_n_odds 0). (* Expected result: 0 *)

(* Now, we prove that the sum of the first n odd numbers is equal to n^2 *)
Lemma sum_of_first_n_odds_square :
  forall n : nat,
    sum_of_first_n_odds n = n * n.
Proof.
  (* In both cases, repeatedly unfold one step of the definition, rewrite
     with the induction hypothesis (unavailable, and so skipped, in the base
     case), or close the goal with [solver], until nothing applies. *)
  induction n;
    repeat first [ progress cbn [sum_of_first_n_odds]
                 | rewrite IHn
                 | solver ].
Qed.

(*** Part 2 ***)

(* In part 2, we'll prove the formula for an arithmetic series (that is, the sum
   of a sequence of numbers that increase by a constant amount between each pair
   of adjacent numbers, like [5 + 7 + 9] or [12 + 16 + 20 + 24]). For "informal"
   proofs, see the following:

   MathWorld : http://mathworld.wolfram.com/ArithmeticSeries.html
   Wikipedia : https://en.wikipedia.org/wiki/Arithmetic_progression#Sum
*)

(* represents \sum_{i=0}^{length - 1} (start + (i * inc)), i.e. the sum of the
   first [length] terms of the sequence; therefore

       arithmetic_series 1 0 3 = 0 + 1 + 2 = 3
       arithmetic_series 2 5 4 = 5 + 7 + 9 + 11 = 32 *)
Fixpoint arithmetic_series (inc start length : nat) : nat :=
  match length with
  | O => 0
  | S length' =>
    (* the current term, plus the remaining [length'] terms, which begin one
       increment later *)
    start + arithmetic_series inc (start + inc) length'
  end.

(* Check the output. *)
Eval compute in (arithmetic_series 1 0 3). (* Expected result : 3 (= 0 + 1 + 2) *)
Eval compute in (arithmetic_series 2 5 4). (* Expected result : 32 (= 5 + 7 + 9 + 11) *)
Eval compute in (arithmetic_series 0 1 3). (* Expected result : 3 (= 1 + 1 + 1) *)
Eval compute in (arithmetic_series 5 6 0). (* Expected result : 0 *)

(* The last term of an arithmetic sequence with [length] terms: the first term
   [start] plus [length - 1] increments of [inc]. Because nat subtraction
   truncates, a length of 0 yields [start]. Naming it keeps later formulas
   readable. *)
Definition last_element (inc start length : nat) : nat :=
  start + ((length - 1) * inc).

(* Hint: This lemma just might come in handy; given that Coq's / is a floor
   division, you'll have to prove along the way that the division by 2 in the
   arithmetic series formula is okay to replace with a floor division. *)
Lemma length_times_start_plus_last_even :
  forall inc start length,
    (length * (start + last_element inc start length)) mod 2 = 0.
Proof.
  (* This proof uses some advanced techniques because I don't like annoying
     algebra; don't worry about understanding it at this point. *)
  cbv [last_element]; intros.
  (* Normalize the goal: distribute multiplication, simplify [mod], and
     regroup sums so that multiples of the modulus become visible to the
     [push_mod] lemmas. *)
  repeat match goal with
         | _ => progress autorewrite with push_mul push_mod
         (* regroup [x + (x + y)] as [2 * x + y] *)
         | |- context [?x + (?x + ?y)] =>
           replace (x + (x + y)) with (2 * x + y) by lia
         (* move the multiple of [x] to the right, presumably so that a
            [push_mod] lemma such as [Nat.mod_add] can cancel it *)
         | |- context [(?x * ?y + ?z) mod ?x] =>
           replace (x * y + z) with (z + y * x) by lia
         end.
  (* Finish by case analysis on the parity of [length]. *)
  destruct (Nat.Even_or_Odd length); simplify_Even.
Qed.

(* Hint: This lemma might also be useful; it allows you to multiply both sides of
   an equation by a divisor (c) to eliminate divisions. *)
Lemma eliminate_div_by_mul :
  forall a b c d,
    c <> 0 ->
    b mod c = 0 -> (* required because / is a floor division*)
    d mod c = 0 -> (* required because / is a floor division*)
    c * a + b = d ->
    a + b / c = d / c.
Proof.
  intros.
  (* Multiply both sides by [c]; this step is reversible because [c <> 0]. *)
  apply Nat.mul_cancel_l with (p:=c); [ lia | ].
  autorewrite with push_mul.
  (* Each [c * (_ / c)] cancels because its dividend is a multiple of [c]. *)
  rewrite !Nat.mul_div_eq by lia.
  (* What remains follows from the hypothesis [c * a + b = d]. *)
  lia.
Qed.

(* Prove the formula correct. *)
Lemma arithmetic_series_equiv :
  forall inc length start : nat,
    arithmetic_series inc start length = length * (start + last_element inc start length) / 2.
Proof.
  cbv [last_element]. (* inline the formula for the last element *)
  (* Induct on [length] while [start] is still universally quantified. In each
     case: unfold one step, rewrite with the IH, clear the divisions by 2 with
     [eliminate_div_by_mul], and discharge the resulting side goals (evenness
     via [length_times_start_plus_last_even], arithmetic via [solver]). *)
  induction length;
    cbn [arithmetic_series];
    repeat first [ progress intros
                 | rewrite IHlength
                 | apply eliminate_div_by_mul
                 | apply length_times_start_plus_last_even
                 | solver ].
Qed.
Clone this wiki locally