open Syntax

(** Total order on classifier features, which are symbol names (strings).
    [String.compare] yields the same lexicographic byte order as the
    polymorphic [compare] on strings, but is monomorphic and avoids the
    generic structural comparison. *)
module FeatureOrdered = struct
  type t = string

  let compare = String.compare
end

(* Naive Bayes classifier keyed by string features ([FeatureOrdered]),
   instantiated with [Monoid.IntMonoid] (presumably for integer counts --
   confirm against Naivebayes). *)
module FClassifier = Naivebayes.Classifier (FeatureOrdered) (Monoid.IntMonoid)

(* Classifier type used by this module. The [string] parameter is presumably
   the label type: labels queried below are symbol names from [trm_syms]. *)
type classifier_t = string FClassifier.t

(* Global classifier used for relevance scoring. It starts empty and is
   replaced as a whole by [load_classifier]. *)
let classifier : classifier_t ref = ref (FClassifier.empty ())
(* Global feature map that [register_processed] extends with the symbols of
   each processed term, via [FClassifier.freq_map_of_list]. *)
let processed_ftrs : (float FClassifier.Fm.t) ref = ref (FClassifier.Fm.empty)


(******************************************************************************)
(* convert from training examples to classification data *)

(** [ftrs_of_list l] turns the symbol list [l] into a feature map. It first
    builds a map with [FClassifier.set_map_of_list] (presumably collapsing
    duplicates), then gives every resulting feature the weight [1.]. *)
let ftrs_of_list l =
  let as_set = FClassifier.set_map_of_list FClassifier.Fm.empty l in
  FClassifier.Fm.map (fun _ -> 1.) as_set

(** [trm_syms m] lists the names carried by the [Name] nodes of [m], in
    left-to-right order and with duplicates kept. It descends through
    applications and lambda bodies. Every other term form contributes
    nothing. *)
let trm_syms m =
  (* Build the list from the right into an accumulator. Each [Ap] node then
     costs O(1), instead of copying the left operand's list with [@], which
     is quadratic on left-nested application spines. *)
  let rec go acc = function
    | Name (n, _) -> n :: acc
    | Ap (m1, m2) -> go (go acc m2) m1
    | Lam (_, body) -> go acc body
    | _ -> acc
  in
  go [] m

(** [proc_syms p] concatenates, in order, the symbols of every processed
    term in [p]. *)
let proc_syms p = p |> List.map trm_syms |> List.flatten

(** [conj_syms cj] gives the symbols of the optional conjecture [cj]. It
    returns [[]] when there is no conjecture. *)
let conj_syms cj =
  match cj with
  | None -> []
  | Some conj -> trm_syms conj

(** [training_ftrs cj proc] is the feature map of one training example. It
    combines the symbols of the optional conjecture [cj] with those of the
    processed terms [proc] and passes them to [ftrs_of_list]. *)
let training_ftrs cj proc =
  let syms = List.append (conj_syms cj) (proc_syms proc) in
  ftrs_of_list syms

(** [training_to_classifier t] turns training examples into classifier input.

    Each example [(cj, axs, proc, m, n)] yields one entry
    [(training_ftrs cj proc, (lbl, 1))] per symbol [lbl] of the term [m],
    with duplicates kept. [axs] and [n] are not used. The [1] is presumably
    a unit count for [Monoid.IntMonoid] (confirm in Naivebayes). *)
let training_to_classifier t =
  let of_example (cj, _axs, proc, m, _n) =
    match trm_syms m with
    | [] -> []
    | lbls ->
      (* The features depend only on [cj] and [proc], not on the label.
         Compute them once per example rather than once per label. *)
      let ftrs = training_ftrs cj proc in
      List.map (fun lbl -> (ftrs, (lbl, 1))) lbls
  in
  List.concat (List.map of_example t)


(******************************************************************************)
(* processed feature cache *)

(** [register_processed m] folds the symbols of term [m] into the global
    [processed_ftrs] map using [FClassifier.freq_map_of_list]. That call
    presumably bumps a per-symbol frequency; confirm in Naivebayes. *)
let register_processed m =
  let syms = trm_syms m in
  let updated = FClassifier.freq_map_of_list !processed_ftrs syms in
  processed_ftrs := updated


(******************************************************************************)
(* Bayes classification *)

(** [load_classifier c] replaces the global [classifier] with the result of
    [FClassifier.load c]. The form of [c] is decided by [FClassifier.load],
    which is not visible here. *)
let load_classifier c =
  let loaded = FClassifier.load c in
  classifier := loaded

(** [get_lbl_data lbl] looks up [lbl] in the global classifier. It returns
    the label's integer statistic converted to a float, together with the
    accompanying per-feature data, unchanged. Presumably these are the
    label's total frequency and its feature counts; confirm in Naivebayes. *)
let get_lbl_data lbl =
  match FClassifier.get_lbl_data !classifier lbl with
  | (label_freq, ftr_data) -> (float_of_int label_freq, ftr_data)


(** [relevance ftrs lbl] scores label [lbl] against the feature map [ftrs]
    using the global classifier.

    The score is [!Flags2.bayes_fweight *. log tf] plus the value of
    [FClassifier.relevance], which receives the classifier's IDF data
    ([FClassifier.get_idf]), the three weighting callbacks below, [ftrs], and
    the label's per-feature data [sfs]. Which features each callback is
    applied to is decided by [FClassifier.relevance] (not visible here).

    NOTE(review): if [tf] is [0.] (e.g. for a label the classifier has never
    seen -- depends on [FClassifier.get_lbl_data]), then [log tf] is
    [neg_infinity] and [fi] divides by zero. Confirm callers only pass known
    labels. *)
let relevance ftrs lbl =
  let (tf, sfs) = get_lbl_data lbl in
  (* [fl]: contributes [idf] scaled by [bayes_lweight]; its second argument
     is ignored.
     [fi]: [idf] scaled by [bayes_iweight] and by the log of [sf /. tf].
     [fr]: [idf] scaled by [bayes_rweight] and by [log (1 - sf / (tf + 1))].
     The [+. 1.] keeps the log argument positive when [sf <= tf]. *)
  let fl idf w = idf *. !Flags2.bayes_lweight
  and fi idf (w, sf) = idf *. !Flags2.bayes_iweight *. log (float_of_int sf /. tf)
  and fr idf sf = idf *. !Flags2.bayes_rweight *. log (1. -. float_of_int sf /. (tf +. 1.)) in
  !Flags2.bayes_fweight *. log tf +.
  FClassifier.relevance (FClassifier.get_idf !classifier) fl fi fr ftrs sfs

(** [bayes_rel ftrs m] scores term [m] against the feature map [ftrs]. It
    applies [relevance ftrs] to every symbol of [m], duplicates included,
    and combines the results with [Utils.float_sum]. It returns [0.] when
    the global classifier is empty. *)
let bayes_rel ftrs m =
  if FClassifier.is_empty !classifier then 0.
  else trm_syms m |> List.map (relevance ftrs) |> Utils.float_sum

