open Utils

(** Frequency store for a feature-based classifier.

    [Feature] is the ordered feature type used as map key.  [LabelNo] is the
    monoid in which per-label counts are accumulated; this functor uses its
    [t], [zero] and [plus].

    NOTE(review): [hashtbl_map0], [hashtbl_map1], [hashtbl_keys], [with_out]
    and [float_sum] come from [Utils]; their exact behaviour is not visible
    from this file, so remarks about them below are assumptions to confirm. *)
module Classifier (Feature : Map.OrderedType) (LabelNo : Monoid.Monoid) = struct

module Fm = Map.Make(Feature)

(** Learned data; the type parameter is the type of labels. *)
type 'l t =
    (* number of training examples *)
  { mutable te_num : int
    (* per-feature total over the training examples, accumulated with (+.)
       from the weights handed to [add_training_ex] *)
  ; ftr_no : (Feature.t, float) Hashtbl.t
    (* number of times label occurs in the training examples (tfreq) *)
  ; lbl_no : ('l, LabelNo.t) Hashtbl.t
    (* how often did label occur together with feature (sfreq) *)
  ; lbl_ftr_no : ('l, LabelNo.t Fm.t) Hashtbl.t
    (* feature IDF, recomputed by [add_training_exs] *)
  ; mutable ftr_idf : float Fm.t
  }

(** [empty ()] is a fresh store without any training data. *)
let empty () : 'l t =
  { te_num = 0
  ; ftr_no = Hashtbl.create 1
  ; lbl_no = Hashtbl.create 1
  ; lbl_ftr_no = Hashtbl.create 1
  ; ftr_idf = Fm.empty
  }


(* ****************************************************************************)
(* Map helper functions *)

(** [list_of_map m] lists the bindings of [m] as [(key, value)] pairs.
    [Fm.fold] visits keys in increasing order and every binding is consed
    onto the front, so the result is in decreasing key order. *)
let list_of_map m = Fm.fold (fun k v acc -> (k, v) :: acc) m []

(** [map_image m] lists the values of [m], in decreasing order of their
    keys (same traversal as [list_of_map]). *)
let map_image m = Fm.fold (fun _ v acc -> v :: acc) m []

(** [map_map1 m k f x] returns [m] with [k] rebound: to [f x old] when [k]
    was bound to [old], and to [x] when [k] was unbound.

    Only the lookup is guarded, so a [Not_found] raised by [f] itself
    propagates instead of being mistaken for a missing key. *)
let map_map1 m k f x =
  let previous = try Some (Fm.find k m) with Not_found -> None in
  let updated = match previous with
    | Some old -> f x old
    | None -> x in
  Fm.add k updated m

(** [freq_map_of_list m l] adds [1.] to the entry of every feature occurring
    in [l] (once per occurrence), starting from the map [m]. *)
let freq_map_of_list m l =
  List.fold_left (fun acc ftr -> map_map1 acc ftr (+.) 1.) m l

(** [set_map_of_list m l] binds every feature of [l] to [()] in [m]. *)
let set_map_of_list m l = List.fold_left (fun acc ftr -> Fm.add ftr () acc) m l


(* ****************************************************************************)
(* I/O *)

(** [load c] reads a store from input channel [c], expecting the five
    fields in the order in which [write] emits them.

    The reads are chained with nested [let ... in]: the evaluation order of
    [let ... and ...] bindings is unspecified, and the order of the
    [input_value] calls must match the order of the data in the channel.

    [input_value] is not type-safe: the channel must contain data produced
    by [write] for the same [Feature], [LabelNo] and label types. *)
let load c =
  let te_num = input_value c in
  let ftr_no = input_value c in
  let lbl_no = input_value c in
  let lbl_ftr_no = input_value c in
  let ftr_idf = input_value c in
  { te_num; ftr_no; lbl_no; lbl_ftr_no; ftr_idf }

(** [write d fp] marshals the five fields of [d], one after the other,
    through [with_out fp] (presumably opening [fp] for output and closing it
    afterwards; confirm in [Utils]).  The field order must stay in sync with
    [load]. *)
let write d fp = with_out fp (fun c ->
  output_value c d.te_num;
  output_value c d.ftr_no;
  output_value c d.lbl_no;
  output_value c d.lbl_ftr_no;
  output_value c d.ftr_idf)


(* ****************************************************************************)
(* obtain learned data *)

(** [is_empty d] holds when no training example has been added to [d]. *)
let is_empty d = d.te_num = 0

(** [get_stats d] is the triple (number of training examples, number of
    entries in the feature table, number of entries in the label table). *)
let get_stats d = (d.te_num, Hashtbl.length d.ftr_no, Hashtbl.length d.lbl_no)

(** [get_idf d ftr] is the stored IDF of [ftr], or [0.] when [ftr] has no
    IDF entry. *)
let get_idf d ftr = try Fm.find ftr d.ftr_idf with Not_found -> 0.

(*let get_lbl_freq d (tf, sf) = tf /. float_of_int (d.te_num)*)

(** [get_lbl_data d lbl] is the pair (label count, per-feature counts)
    recorded for [lbl].  When either table lacks [lbl] the result is
    [(LabelNo.zero, Fm.empty)]. *)
let get_lbl_data d lbl =
  try
    let total = Hashtbl.find d.lbl_no lbl in
    let per_ftr = Hashtbl.find d.lbl_ftr_no lbl in
    (total, per_ftr)
  with Not_found -> (LabelNo.zero, Fm.empty)


(* ****************************************************************************)
(* update learned data *)

(** [update_ftr_no d ftrs] feeds every (feature, weight) binding of [ftrs]
    into [d.ftr_no] with [hashtbl_map1] and [(+.)] (presumably adding the
    weight to the stored total; confirm in [Utils]). *)
let update_ftr_no d ftrs = Fm.iter (fun k w -> hashtbl_map1 d.ftr_no k (+.) w) ftrs

(** [update_lbl_no d (lbl, lbli)] feeds [lbli] into the entry of [lbl] in
    [d.lbl_no] with [hashtbl_map1] and [LabelNo.plus]. *)
let update_lbl_no d (lbl, lbli) =
  hashtbl_map1 d.lbl_no lbl LabelNo.plus lbli

(** [update_lbl_ftr_no d (lbl, lbli) ftrs] updates the per-feature map of
    [lbl] in [d.lbl_ftr_no]: for every feature bound in [ftrs] (the bound
    values are ignored) [lbli] is combined into that feature's entry with
    [LabelNo.plus].  [Fm.empty] is passed to [hashtbl_map0] as last
    argument, presumably as the map to start from when [lbl] is new. *)
let update_lbl_ftr_no d (lbl, lbli) ftrs =
  let fold_fun ftr _ fm = map_map1 fm ftr LabelNo.plus lbli in
  hashtbl_map0 d.lbl_ftr_no lbl (Fm.fold fold_fun ftrs) Fm.empty

(** [calc_idf d ftr] is [log te_num - log n], where [n] is the total stored
    for [ftr] in [d.ftr_no].
    @raise Not_found if [ftr] has no entry in [d.ftr_no]. *)
let calc_idf d ftr =
  log (float_of_int d.te_num) -. log (Hashtbl.find d.ftr_no ftr)

(** [update_ftr_idf d] builds a fresh IDF map with one entry per key
    returned by [Utils.hashtbl_keys d.ftr_no].  Despite its name it does not
    modify [d]; the caller stores the result. *)
let update_ftr_idf d =
  let ftrs = Utils.hashtbl_keys d.ftr_no in
  List.fold_left (fun m ftr -> Fm.add ftr (calc_idf d ftr) m) Fm.empty ftrs

(** [add_training_ex d (ftrs, lbl)] records one training example: [ftrs] is
    a map from features to float weights and [lbl] a pair of a label and
    its [LabelNo.t] contribution.  The IDF map is not refreshed here. *)
let add_training_ex d (ftrs, lbl) =
  d.te_num <- d.te_num + 1;
  update_ftr_no d ftrs;
  update_lbl_no d lbl;
  update_lbl_ftr_no d lbl ftrs

(** [add_training_exs d l] records every example of [l] and then recomputes
    the IDF map once for the whole batch.  Progress messages are printed
    on standard output. *)
let add_training_exs d l =
  List.iter (add_training_ex d) l;
  Printf.printf "pf, cn and cn_pf frequencies updated\n%!";
  d.ftr_idf <- update_ftr_idf d;
  Printf.printf "IDF information calculated\n%!"


(* ****************************************************************************)
(* Naive Bayes relevance *)

(** [relevance idf fl fi fr ftrs sfreq] scores the feature map [ftrs]
    against the per-feature data [sfreq].  Every feature bound in at least
    one of the two maps contributes one float:
    - [fl (idf ftr) l] when it is bound (to [l]) in [ftrs] only,
    - [fi (idf ftr) (l, r)] when it is bound in both,
    - [fr (idf ftr) r] when it is bound (to [r]) in [sfreq] only.
    The contributions are handed to [float_sum] (presumably their sum). *)
let relevance idf fl fi fr ftrs sfreq =
  let go ftr l r = match (l, r) with
    | (Some l, None) -> Some (fl (idf ftr) l)
    | (Some l, Some r) -> Some (fi (idf ftr) (l, r))
    | (None, Some r) -> Some (fr (idf ftr) r)
      (* not produced by [Fm.merge]; listed to keep the match exhaustive
         without a catch-all *)
    | (None, None) -> None in
  float_sum (map_image (Fm.merge go ftrs sfreq))

end

