(*****************************************************************)
(******************** Cases construct ****************************)
(*****************************************************************)

(* Analysing "match ... with" constructs to replace them with    *)
(* the application of the appropriate recursion operator.        *)

(* Author: Nicolas Magaud *)
(* $Id: case_analyser.ml,v 1.16 2006/01/22 15:08:59 magaud Exp $ *)

open Nametab;;
open Pp;;
open Printer;;
open Environ;;
open Names;;
open Declarations;;
open Term;;
open Evd;;
open Nameops;;
open Inductive;;
open Reductionops;;
open Libnames;; (* ConstRef, string_of_qualid *)
open Util;; (* pr_int *)
open Inductiveops;; (* mis_is_recursive *)
open Sign;; (* empty_named_context *)
open Typeops;;
open Termops;; (* print_sort_family *)
open Indrec;; (* lookup_eliminator *)

(* Global trace switch: when set to [true], the functions of this file *)
(* print intermediate terms and types through [msgnl].                 *)
let debug = ref false;;

(* Raised by [get_constant_body] when the located reference cannot be *)
(* looked up as a constant.                                           *)
exception Not_a_constant;;

(* Error raised by the local [failwith] below; carries the message. *)
exception Local of string;;

(* Shadows the standard [failwith] for the rest of this file: the     *)
(* message is wrapped in [Local] instead of the predefined [Failure]. *)
let failwith msg = raise (Local msg);;

(* borrowed from contrib/extraction/table.ml *)
(* [constant_of_gr r] returns the constant denoted by the global       *)
(* reference [r].  Any other kind of reference (inductive type,        *)
(* constructor, variable) is reported through the local [failwith],    *)
(* i.e. by raising [Local].  A variable reference used to hit          *)
(* [assert false]; it is now reported like the other non-constants.    *)
let constant_of_gr = function
  | ConstRef c -> c (*(constant_of_kn kn)*)
  | IndRef _ -> failwith "Not a constant; Inductive"
  | ConstructRef _ -> failwith "Not a constant; Constructor"
  | VarRef _ -> failwith "Not a constant; Variable";;

(* get_constant_body : Libnames.qualid -> Declarations.constant_body *)
(* Locates [qualid_exp], marks the constant it denotes as transparent  *)
(* in the conversion oracle and returns its body from the global       *)
(* environment.  Any exception raised after the reference has been     *)
(* located is reported on the output and turned into [Not_a_constant]. *)
let get_constant_body qualid_exp =
  let gr = locate qualid_exp in
  let report_failure () =
    msgnl (str "RemoveCases: " ++ pr_global gr ++ str " is not a constant.");
    raise Not_a_constant
  in
  try
    (* successful only if gr denotes a constant *)
    let c = constant_of_gr gr in
    Conv_oracle.set_transparent_const c;
    Global.lookup_constant c
  with _ -> report_failure ()
	
(* [get_statement qualid_exp] is the [const_type] field of the body of *)
(* the constant named [qualid_exp] (see [get_constant_body]).          *)
let get_statement qualid_exp =
  let body = get_constant_body qualid_exp in
  body.const_type;;

(* [get_proof_term qualid_exp] forces and returns the [const_body] of  *)
(* the constant named [qualid_exp]; raises [Local] when the constant   *)
(* has no body.                                                        *)
let get_proof_term qualid_exp =
  let cb = get_constant_body qualid_exp in
  match cb.const_body with
    | Some p -> force p
    | None ->
        failwith ("get_proof_term: No proof term for "^
                    (string_of_qualid qualid_exp));;

(* is "i" a recursive inductive ? *)
let check_rec i =
  (* The specification of [i] in the global environment is a pair:   *)
  (* the mutual block and the packet of [i] itself; both are handed  *)
  (* to [mis_is_recursive] together with [i].                        *)
  let (mut_i, one_i) = lookup_mind_specif (Global.env ()) i in
  mis_is_recursive (i, mut_i, one_i);;

(* [strip_prod t] drops every leading product of [t] and returns the  *)
(* remaining body; bound variables of the body are left as they are.  *)
let rec strip_prod t =
  match kind_of_term t with
    | Prod (_, _, codomain) -> strip_prod codomain
    | _ -> t;;

(* [head_app t] is the function part of the application [t]; raises *)
(* [Local] when [t] is not an application.                          *)
let head_app t =
  if isApp t
  then fst (destApplication t)
  else failwith "head_app";;

(* [tail_app t] is the argument array of the application [t]; raises *)
(* [Local] when [t] is not an application.                           *)
let tail_app t =
  if isApp t
  then snd (destApplication t)
  else failwith "tail_app";;

(* Term denoted by the qualified name [v]. *)
let global_qualified_reference v =
  let gr = locate v in
  constr_of_reference gr;;

(* Skips the leading products of [t], then returns the head of the    *)
(* application found there, or the term itself when it is an          *)
(* inductive type.  Any other term is printed and reported by raising *)
(* [Local].                                                           *)
let rec head_of_first_application t =
  match kind_of_term t with
    | Prod (_, _, body) -> head_of_first_application body
    | App (head, _) -> head
    | Ind _ -> t
    | _ ->
        msgnl (prterm t);
        failwith "head_of_first_application";;

(* Identifiers declared in the named context of [theenv], in the *)
(* order of that context.                                        *)
let list_idents theenv =
  let id_of_decl (id, _, _) = id in
  List.map id_of_decl (named_context theenv);;

(*let rec unfold_head c0 theenv = whd_betadeltaiota theenv (Evd.empty) c0*)

(* Term denoted by the identifier [v], looked up as a short *)
(* (unqualified) name.                                      *)
let global_reference v =
  let qid = make_short_qualid v in
  constr_of_reference (locate qid);;

(* [strip_dummy_abstractions tm theenv] removes from the leading       *)
(* abstractions of [tm] those whose bound variable is not used:        *)
(*  - an anonymous abstraction is always dropped;                      *)
(*  - a named abstraction is kept only if its variable occurs in the   *)
(*    body.                                                            *)
(* Each binder is opened with a fresh named variable.  That variable   *)
(* is now pushed into the environment used for the rest of the         *)
(* traversal, so two nested binders sharing a name get distinct        *)
(* identifiers; previously the same identifier could be picked twice   *)
(* and [mkNamedLambda] would then capture the outer variable.          *)
(* Traversal stops at the first term that is not an abstraction.       *)
let rec strip_dummy_abstractions tm theenv =
  match kind_of_term tm with
    | Lambda (binder, t, b) ->
        (* name hint for the fresh variable, and whether the binder   *)
        (* may survive at all (anonymous ones never do)               *)
        let (hint, may_keep) =
          match binder with
            | Anonymous -> (id_of_string "anon", false)
            | Name n -> (n, true)
        in
        let new_name = next_ident_away hint (list_idents theenv) in
        let opened = subst1 (mkVar new_name) b in
        let stripped =
          strip_dummy_abstractions
            opened
            (push_named (new_name,None,t) theenv)
        in
          if may_keep && occur_var theenv new_name opened
          then mkNamedLambda new_name t stripped
          else stripped
    | _ -> tm
;;

(* [db_type_of s the_env the_term] infers the type of [the_term] in    *)
(* [the_env].  If inference raises, the term is printed together with  *)
(* the tag [s] (naming the call site) and the exception is re-raised.  *)
let db_type_of s the_env the_term =
  try
    let (judgment, _) = infer the_env the_term in
    judgment.uj_type
  with e ->
    msgnl (str ("Type exception raised in "^s^" on term:\n"));
    msgnl (prterm the_term);
    raise e
;;

(* Type, in [context], of the expression matched by the case *)
(* construct [t]; raises [Local] on any other term.          *)
let cased_exp_type t context =
  match kind_of_term t with
    | Case (_, _, scrutinee, _) ->
        db_type_of "cased_exp_type" context scrutinee
    | _ -> failwith "cased_exp_type must be applied to a case construct";;

(* Sort family of the return predicate of the case construct [t]:     *)
(* the predicate is typed in [context], the products of its type are  *)
(* stripped, and what remains is read as a sort -- either directly,   *)
(* or, when it is not itself a sort, through its own type.            *)
(* Raises [Local] when [t] is not a case construct.                   *)
let case_into_sort t context =
  match kind_of_term t with
    | Case (_, p, _, _) ->
        let predicate_type = Typing.type_of context empty p in
        (* We should be careful about raw-stripping, it might affect *)
        (* type inference                                            *)
        let sp = strip_prod predicate_type in
        let the_sort =
          try destSort sp
          with _ -> destSort (db_type_of "stripping" context sp)
        in
          family_of_sort the_sort
    | _ -> failwith "case_into_sort must be applied to a case construct.";;

(* Number of leading products of [the_exp]. *)
let nb_Prod the_exp =
  let rec count acc c =
    match kind_of_term c with
      | Prod (_, _, body) -> count (acc + 1) body
      | _ -> acc
  in count 0 the_exp

(* [fix_exp_wrt_type nb_prods the_exp the_type context] adapts the     *)
(* case branch [the_exp] to [the_type], the type expected for the      *)
(* corresponding argument of the recursion operator.                   *)
(*                                                                     *)
(* At most [nb_prods] leading products of [the_type] are walked        *)
(* through.  For each product:                                         *)
(*  - if [the_exp] is an abstraction whose binder type is convertible  *)
(*    with the domain of the product, the abstraction is kept and the  *)
(*    traversal goes on in both bodies;                                *)
(*  - if [the_exp] is a named abstraction with a different binder      *)
(*    type, or is not an abstraction at all, an anonymous abstraction  *)
(*    over the domain of the product is inserted and [the_exp] is      *)
(*    left untouched below it;                                         *)
(*  - an anonymous abstraction with a non-convertible binder type is   *)
(*    not handled: a message is printed and [the_exp] is returned.     *)
(* When [the_type] is not a product any more, [the_exp] is returned.   *)
(*                                                                     *)
(* Every binder that is crossed is opened with a fresh variable which  *)
(* is declared in the context of the recursive call.  This now also    *)
(* holds for the two "binder types differ" situations, which used to   *)
(* recurse with the unextended context although they had just          *)
(* introduced a variable.                                              *)
let rec fix_exp_wrt_type nb_prods the_exp the_type context =
  if !debug then msgnl ((str "fix_exp_type: term= ")++(prterm the_exp));
  if !debug then msgnl ((str " type= ")++(prterm the_type));
  if nb_prods = 0
  then the_exp
  else
    let anon = id_of_string "anon" in
    (* fresh identifier built from [hint], away from the variables *)
    (* of [context]                                                *)
    let fresh hint = next_ident_away hint (list_idents context) in
    (* The binder of [the_exp] matches the product: go on in both  *)
    (* bodies, opened with the variable [v] of type [ty].          *)
    let descend v ty body_exp body_type =
      fix_exp_wrt_type (nb_prods-1)
        (subst1 (mkVar v) body_exp)
        (subst1 (mkVar v) body_type)
        (push_named (v,None,ty) context)
    in
    (* [the_exp] has no binder for this product: insert an         *)
    (* anonymous abstraction over [ty] and keep [the_exp] below.   *)
    let pad hint ty body_type =
      let v = fresh hint in
      let rest =
        fix_exp_wrt_type (nb_prods-1)
          the_exp
          (subst1 (mkVar v) body_type)
          (push_named (v,None,ty) context)
      in mkLambda (Anonymous, ty, rest)
    in
    match kind_of_term the_exp, kind_of_term the_type with
      | Lambda(Anonymous,type_var1,body1), Prod(_,type_var2,body2) ->
          if is_conv context empty type_var1 type_var2
          then
            let v = fresh anon in
              mkLambda
                (Anonymous, type_var1, descend v type_var1 body1 body2)
          else
            begin
              msgnl (str "tv1 and tv2 are not convertible; not handled yet");
              the_exp
            end
      | Lambda(Name name_var1,type_var1,body1),
        Prod(binder2,type_var2,body2) ->
          if is_conv context empty type_var1 type_var2
          then
            let v = fresh name_var1 in
              mkNamedLambda v type_var2 (descend v type_var2 body1 body2)
          else
            begin
              match binder2 with
                | Name name_var2 ->
                    if !debug
                    then msgnl ((str "extra arg is not the last one: ")++
                                  (prterm type_var1)++(str " ")++
                                  (prterm type_var2));
                    pad name_var2 type_var2 body2
                | Anonymous ->
                    if !debug
                    then msgnl
                      ((str "extra arg not the last one and Anonymous: ")++
                         (prterm type_var1)++(str " ")++
                         (prterm type_var2));
                    pad anon type_var2 body2
            end
      | _, Prod(Name name_var,type_var2,body2) ->
          if !debug then msgnl (str "Unspecified lambda, Named prod: ");
          pad name_var type_var2 body2
      | _, Prod(Anonymous,type_var2,body2) ->
          if !debug then msgnl (str "something weird here but Anonymous: ");
          pad anon type_var2 body2
      | _, _ ->
          msgnl (str "end of filling:");
          the_exp

(*
let second_arg type0 =
  match (kind_of_term type0) with
    | Prod(_,_,b) -> 
	(match (kind_of_term b) with 
	   | Prod(_,t,_) -> t
	   | _ -> failwith "second_arg, inner match failed!")
    | _ -> failwith "second_arg, outer match failed!"
*)

(* [map3_list f l1 l2 l3] applies [f] to the elements of the three *)
(* lists taken position by position.  The lists must have the same *)
(* length; otherwise an error is raised through [failwith].        *)
let rec map3_list f l1 l2 l3 =
  match l1, l2, l3 with
    | x :: xs, y :: ys, z :: zs -> (f x y z) :: (map3_list f xs ys zs)
    | [], [], [] -> []
    | _ -> failwith "Inputs for map3_list must have the same size."

(* Array counterpart of [map3_list]: the three arrays are turned into *)
(* lists, mapped together, and the result is turned back into an      *)
(* array.                                                             *)
let map3 f a1 a2 a3 =
  Array.of_list
    (map3_list f (Array.to_list a1) (Array.to_list a2) (Array.to_list a3))

(* extracts "how_many_prod" step hypotheses *)
(* from "the_constr" which is a recursor type *)
(* The result is the array of the domains of the first             *)
(* [how_many_prod] products of [the_constr], outermost first.      *)
(* Raises [Local] if [the_constr] has fewer products than that.    *)
let branches_types the_constr how_many_prod =
  let rec collect acc c remaining =
    if remaining = 0
    then List.rev acc
    else
      match kind_of_term c with
        | Prod (_, domain, codomain) ->
            collect (domain :: acc) codomain (remaining - 1)
        | _ -> failwith "Not enough products in branches_types."
  in Array.of_list (collect [] the_constr how_many_prod)


(* Applies [eliminator] to its parameters and then to [predicate].    *)
(* The parameters are the first [npar] arguments of                   *)
(* [type_of_cased_exp] once put in weak-head normal form in           *)
(* [context]; when [npar] is 0 that type is not inspected at all.     *)
let partial_app_of_elim_upto_P
    eliminator predicate type_of_cased_exp npar context =
  let elim_with_params =
    if npar = 0
    then eliminator
    else
      begin
        let whd_type = whd_betadeltaiota context empty type_of_cased_exp in
        if !debug then
          msgnl ((str "partial_app_upto_P whd_betadeltaiota: ")++
                   (prterm whd_type));
        let (_, args) = destApplication whd_type in
          mkApp (eliminator, Array.sub args 0 npar)
      end
  in mkApp (elim_with_params, [| predicate |])

(* Last arguments of the eliminator application: the arguments of     *)
(* [type_of_cased_exp] (weak-head normalised in [context]) that come  *)
(* after its first [npar] ones, followed by [cased_exp] itself.  When *)
(* the normalised type is not an application, [cased_exp] is the only *)
(* argument.                                                          *)
let final_args cased_exp type_of_cased_exp npar context =
  let whd_type = whd_betadeltaiota context empty type_of_cased_exp in
  if !debug then msgnl ((str "whd_betadelta of term: ")++(prterm whd_type));
  if isApp whd_type
  then
    begin
      let (_, all_args) = destApplication whd_type in
      let nb_after_params = Array.length all_args - npar in
        Array.append
          (Array.sub all_args npar nb_after_params)
          [| cased_exp |]
    end
  else [| cased_exp |]

(* [stripn_prod t n context] removes the first [n] products of [t],   *)
(* replacing each bound variable by a named variable; raises [Local]  *)
(* if [t] has fewer than [n] leading products.                        *)
(* NOTE(review): [context] is passed on unchanged, so the identifier  *)
(* computed from it is the same at every step and all the stripped    *)
(* binders are replaced by one and the same variable.                 *)
let rec stripn_prod t n context =
  if n = 0
  then t
  else
    match kind_of_term t with
      | Prod (_, _, body) ->
          let placeholder =
            mkVar (next_ident_away (id_of_string "anon") (list_idents context))
          in stripn_prod (subst1 placeholder body) (n - 1) context
      | _ -> failwith "Not enough prods for stripn_prod"

(* [case_into_elimination t context eliminator] rewrites the case       *)
(* construct [t] = Case(info,p,a,ba) as an application of [eliminator]. *)
(* The resulting term is [eliminator] applied, in this order, to:       *)
(*  - the first [info.ci_npar] arguments of the type of [a]             *)
(*    (see [partial_app_of_elim_upto_P]);                               *)
(*  - the predicate [p];                                                *)
(*  - the branches [ba], each one adapted by [fix_exp_wrt_type] to the  *)
(*    corresponding branch type read off the type of the eliminator;    *)
(*  - the remaining arguments of the type of [a], then [a] itself       *)
(*    (see [final_args]).                                               *)
(* Raises [Local] when [t] is not a case construct.                     *)
let case_into_elimination t context eliminator =
  let elim_type = db_type_of "case_into_elimination" context eliminator 
  in match (kind_of_term t) with
    | Case(info,p,a,ba) -> 
	(* type of the matched expression [a], traced when debugging *)
	let type_of_a = 
	  let t =
	  (*(strong whd_betadelta) 
	    context 
	    empty *)
 	    (cased_exp_type t context) 
	  in 
	    begin 
	      if !debug then msgnl ((str "type of a: ")++(prterm t));
	      t 
	    end
	in
	(* eliminator applied to the parameters and to the predicate *)
	let elim_applied_to_p = 
	  partial_app_of_elim_upto_P eliminator p type_of_a info.ci_npar context
	in
	(* its type, beta-normalised: the leading products of this type *)
	(* are read below as the types expected for the branches        *)
	let elim_type_with_p = 
	  (local_strong whd_beta)
	   (* context empty*)
	    (db_type_of "case_into_elimination2" 
	       context 
	       elim_applied_to_p)
	in 
	  begin
	    if !debug then msgnl (str "case_into_elimination");
	    if !debug then msgnl ((str "eliminator: ")++(prterm eliminator));
	    if !debug then msgnl ((str "type: ")++(prterm elim_type));
	    if !debug then msgnl ((str "type with P: ")++
				    (prterm elim_type_with_p));
	    (* one expected type per branch of the case construct *)
	    let array_branches_types = 
	      begin
		if !debug then 
		  msgnl ((str "#branches= ")++(pr_int (Array.length ba)));
		if !debug then 
		  msgnl ((str "before stripping (not required any more!): ")++
			   (prterm elim_type_with_p));
		branches_types elim_type_with_p (Array.length ba)
	      end
	    (* number of leading products of each branch type *)
	    and branches_rec_sizes =
	      Array.map
		nb_Prod
		(branches_types 
(* to find the number of prods in each branch of the induction principle,  *)
(* we prefer to use the type of the completed predicate "elim_type"        *)
(* instead of "elim_type_with_p" (this should avoid extra prods to appear) *)
		   (stripn_prod elim_type (info.ci_npar+1) context)
		   (Array.length ba))
	    in 
	    let updated_branches = 
	      begin
		(* debug traces only; the [try] guards the array accesses, *)
		(* which fail when there are fewer than 1 (resp. 2)        *)
		(* branches                                                *)
		(try 
		   if !debug 
		   then msgnl ((str "trying... #prods branch 0: ")++
				 (pr_int branches_rec_sizes.(0)))
		 with _ -> ());
		(try 
		   if !debug 
		   then msgnl ((str "trying... #prods branch 1: ")++
				 (pr_int branches_rec_sizes.(1)))
		 with _ -> ());
		(* i: branch, j: its expected type, k: number of products *)
		map3 
		  (fun i j k -> (fix_exp_wrt_type k i j context)) 
		  ba 
		  array_branches_types
		  branches_rec_sizes
	      end
	    in begin
		  if !debug 
		  then msgnl 
		    ((str "after update: ")++
		       (prterm 
			  (mkApp 
			     ((mkApp 
				 (elim_applied_to_p,updated_branches)),
			      (final_args a type_of_a info.ci_npar context)))));
		
		(* the complete application of the eliminator *)
		mkApp ((mkApp (elim_applied_to_p,updated_branches)),
		       (final_args a type_of_a info.ci_npar context))
	      end
	  end
    | _ -> 
	failwith 
	  "case_into_elimination can only be applied to case constructs."
	  
(* [search_and_replace_cases pf_term context] traverses [pf_term] and   *)
(* replaces the case constructs it meets by applications of an          *)
(* eliminator (see [case_into_elimination]).                            *)
(*                                                                      *)
(* Abstractions, local definitions, applications and case constructs    *)
(* are traversed; constants, inductive types and constructors are       *)
(* rebuilt as they are.  Any other kind of term is returned unchanged,  *)
(* so case constructs nested inside such a term are not rewritten.      *)
(*                                                                      *)
(* Binders of abstractions and of named local definitions are opened    *)
(* with a fresh named variable, which is declared in the context used   *)
(* for the body.                                                        *)
let rec search_and_replace_cases pf_term context =
  match (kind_of_term pf_term) with 
    | Lambda((Name n),t,b) -> 
        let new_name = (next_ident_away n (list_idents context)) 
        in mkNamedLambda 
             new_name
             (search_and_replace_cases t context)
             (search_and_replace_cases
		(subst1 (mkVar new_name) b)
		(push_named (new_name,None,t) context))
    | Lambda(Anonymous,t,b) ->
        (* an anonymous binder is given the name hint "anon" *)
        let new_name = 
          (next_ident_away (id_of_string "anon") (list_idents context)) 
        in mkNamedLambda
             new_name
             (search_and_replace_cases t context)
             (search_and_replace_cases
		(subst1 (mkVar new_name) b)
		(push_named (new_name,None,t) context))
    | LetIn(Anonymous,a,ta,b) -> (* let $[x:=b:t_1]t_2$ *)
        (* the body is traversed as it is, in the same context *)
        mkLetIn (Anonymous, 
                 (search_and_replace_cases a context), 
                 (search_and_replace_cases ta context), 
                 (search_and_replace_cases b context))
    | LetIn(Name(x),t1,t2,t3) -> (* let x = t1 : t2 in t3 *)
        let new_name = (next_ident_away x (list_idents context))
        in mkNamedLetIn
             new_name
             (search_and_replace_cases t1 context)
             (search_and_replace_cases t2 context)
             (search_and_replace_cases 
                (subst1 (mkVar new_name) t3)
                (push_named (new_name,Some t1,t2) context) )
    | App(c,c1n) -> 
        mkApp ((search_and_replace_cases c context), 
               (Array.map 
                  (function c -> search_and_replace_cases c context) 
                  c1n))
    | Const(c) -> mkConst c
    | Ind(i) -> mkInd i
    | Construct(c) -> mkConstruct c
    | Case(info,p',a',ba') -> 
	begin 
          if !debug then msgnl (str "Case construct: ");
	  (* the components of the case construct are processed first *)
	  let p = search_and_replace_cases p' context
	  and a = search_and_replace_cases a' context
	  and ba = (Array.map 
                      (function c -> search_and_replace_cases c context) ba')
	  in let pf_term= mkCase(info,p,a,ba)
	  in 
	  (* sort family of the predicate: it selects the eliminator *)
          let the_sort_of_type_p_stripped = 
	    (*try*) case_into_sort pf_term context (*with _ -> InSet*) in
	    
	  (* When the predicate lands in Prop and the type of [a] is   *)
	  (* itself in Prop, the unused abstractions of the predicate  *)
	  (* are removed; otherwise the predicate is kept as it is.    *)
	  let p = match the_sort_of_type_p_stripped with
	    | InProp -> 
		let the_sort_of_a =
		  let the_type_of_a =db_type_of "cased_exp_type" context a 
		  in db_type_of "cased_exp_type2" context the_type_of_a
		in (match 
		      begin
			if !debug 
			then msgnl ((str "sort of a: ")++
				      (prterm the_sort_of_a));
			(destSort the_sort_of_a)
		      end with
			| Prop(Null) -> 
			    (* this is Prop, thus non-dependent elim *)
			    strip_dummy_abstractions p context
			| _ -> p)
	    | InSet | InType -> p
	  and the_type_of_a = db_type_of "cased_exp_type" context a in
	  (* only printed below, but always computed: a typing error *)
	  (* on the predicate is raised from here                    *)
	  let the_type_of_p = 
	    begin 
	      if !debug then msgnl ((str "real p is : ")++(prterm p));
	      db_type_of "the_type_of_p" context p
	    end
              in let the_eliminator = 
                  lookup_eliminator info.ci_ind
                  (*(try (destInd the_type_of_a) 
                    with _ -> (destInd (head_app the_type_of_a)))*)
                    the_sort_of_type_p_stripped 
              in 
		begin
                  if !debug then msgnl ((str "p= ")++(prterm p));
                  
                  (* msgnl ((str "env1: ")++(pr_context_of (Global.env()))); *)
                  (* msgnl ((str "env2: ")++(pr_context_of context));        *)
                  if !debug 
		  then msgnl ((str "case_info: ci_ind= ")++
				(prterm (mkInd info.ci_ind))++
				(str ", ci_npar= ")++(pr_int info.ci_npar));
                  if !debug 
		  then msgnl ((str "type_of_a = ")++(prterm the_type_of_a));
                  if !debug 
		  then msgnl ((str "type_of_p = ")++(prterm the_type_of_p));
                  if !debug 
		  then msgnl ((str "candidate eliminator: ")++
				(prterm the_eliminator));
		  (* the case construct is rebuilt with the possibly *)
		  (* stripped predicate before being replaced        *)
		  (case_into_elimination 
		     (mkCase(info,p,a,ba)) (* was pf_term *)
		     context 
		     the_eliminator)
		end
	end
    | _ -> pf_term ;; (*failwith "Handler missing... to be added"*)

(* Entry point of the "RemoveCases" command.  The proof term of the    *)
(* constant [ident] is fetched, its case constructs are replaced by    *)
(* [search_and_replace_cases] in the global environment, and the       *)
(* result is type-checked.  The new term is only printed (in debug     *)
(* mode) and checked: nothing is stored.  Raises [Local] when the      *)
(* new term cannot be typed.                                           *)
let remove_cases ident =
  let qualid_exp = make_short_qualid ident in
  if !debug then msgnl (str "--- begin intocase ---");
  let pf_term = get_proof_term qualid_exp in
  let term = search_and_replace_cases pf_term (Global.env()) in
  if !debug
  then msgnl ((str "updated term after remove_cases: ")++(prterm term));
  (try
     let new_type = db_type_of "intocase" (Global.env()) term in
     if !debug then msgnl ((str "type: ")++(prterm new_type))
   with _ -> failwith "Ill-typed. Not happy!");
  if !debug then msgnl (str "--- end intocase ---");;
      

(* Grammar extension: declares the vernacular command                 *)
(* "RemoveCases <ident>", whose action is [remove_cases] applied to   *)
(* the given identifier.                                              *)
VERNAC COMMAND EXTEND IntoCase
  [ "RemoveCases" ident(i1) ]
   -> [ remove_cases i1 ]
END
  
