(* arbre 2-dimensionnel *)

let d = 2;; 

type point = float array;; (* tableaux de longueur d *)

type kdtree = V | N of kdtree * point * kdtree;; 

let pt_test = [|
[| 3.3 ; 0.9 |];
[| 2.0 ; 4.5 |];
[| 3.7 ; 5.1 |];
[| 8.4 ; 6.0 |];
[| 2.7 ; 1.8 |];
[| 1.2 ; 4.4 |];
[| 5.1 ; 3.6 |];
[| 5.7 ; 1.8 |];
[| 4.0 ; 1.4 |];
[| 6.6 ; 2.4 |];
[| 6.4 ; 0.7 |];
[| 7.9 ; 5.5 |];
|];;

let rec inserer tree pt p =
  (* inserer dans l'abre le point pt, sachant qu'on est à la profondeur p *) 
  match tree with
  | V -> N (V, pt, V)
  | N (ag, v, ad) -> 
    if pt.(p mod d) <= v.(p mod d) then N (inserer ag pt (p+1), v, ad)
    else N (ag, v, inserer ad pt (p+1))

let rec hauteur tree =
  match tree with
  | V -> -1
  | N (ag, _, ad) -> 1 + max (hauteur ag) (hauteur ad)


let affiche_kdtree tree =
  (* afficher le kdtree *)
  let rec affiche_espace n = 
    if n > 0 then begin print_string " "; affiche_espace (n-1) end
  in
  let rec affiche_aux decalage tree =  match tree with
    | V -> ()
    | N (ag, pt, ad) -> begin
        affiche_aux (decalage + 3) ag;
        affiche_espace decalage; Printf.printf "(%f,%f) \n" pt.(0) pt.(1);
        affiche_aux (decalage + 3) ad;
    end
  in
  affiche_aux 0 tree
;;


let construire data =
  (* construction du kdtree *)
  let tree = ref V in
  for i = 0 to Array.length data - 1 do
    tree := inserer !tree data.(i) 0
  done;
  !tree
;;


(* distance entre deux points *)
let dist pt1 pt2 = 
  let n = Array.length pt1 in
  let n2 = Array.length pt2 in
  assert (n = n2);
  let s = ref 0. in
  for i = 0 to (n - 1) do
    s := !s +. (pt1.(i) -. pt2.(i)) ** 2.
  done;
  sqrt !s;;

let plus_proche_voisin tree pt =
  let pppt = ref [|0.; 0.|] in (* pt le plus proche rencontré *)
  let dpppt = ref infinity in (* distance au pt le plus proche rencontré *)

  let rec parcours t p = 
    (* parcours intelligent de l'arbre en gardant en mémoire le point le plus proche *)
    match t with
    | V -> ()
    | N (ag, v, ad) -> 
      (* ordre de visite : quel est le premier sous arbre à visiter ? *)
      let (a1, a2) = if pt.(p mod d) <= v.(p mod d) then (ag, ad) else (ad, ag)
      in
        parcours a1 (p + 1);
        (* est-ce nécessaire de visiter le 2e sous arbre ? *)
        if (!dpppt > abs_float (pt.(p mod d) -. v.(p mod d))) then begin
          (* on compare avec le point séparateur *)
          let dv = dist pt v in 
          if dv < !dpppt then begin
            pppt := v; dpppt := dv; 
            Printf.printf "pppt courant : %f,%f %f\n" !pppt.(0) !pppt.(1) !dpppt
          end
          ; parcours a2 (p + 1)
        end (* else : sous-arbre non exploré *)
  in
  parcours tree 0; !pppt
;;

(* Test sur l'exemple *)
let tree = construire pt_test;;

affiche_kdtree tree;;

let pppt = plus_proche_voisin tree [|3.2;4.4|] in 
Printf.printf "pppt : %f,%f\n" pppt.(0) pppt.(1)


(* structure de données pour stocker les k meilleurs *)

type k_meilleurs = { nmax : int; mutable nb : int; mutable voisins : (float array * float * int) list}
(* nmax : nombre maximal de point à conserver *)
(* nb : nombre de points conservés *)
(* voisins : liste des points conservés, triés par distance décroissante, accompagné de leur classe *)

let creerKM k = { nmax = k; nb = 0; voisins = [] };;

let ajouterKM km (pt,dist,c) =
  let rec inserer (pt, dist, c) l = match l with
  | [] -> [(pt, dist, c)]
  | (pt2, dist2, c2) :: ll -> if dist < dist2 then (pt2, dist2, c2) :: (inserer (pt, dist, c) ll) else (pt, dist, c) :: l
  in
  if km.nb < km.nmax then begin (* il reste de la place ! *)
    km.voisins <- inserer (pt, dist, c) km.voisins;
    km.nb <- km.nb + 1
  end else (* c'est plein ! *)
    (* faut-il garder le point ? qui prend alors la place de la tête de liste *)
    let _,dmax,_ = List.hd km.voisins in
    if dist <= dmax then
      km.voisins <- inserer (pt, dist,c) (List.tl km.voisins)
;;


let rayonKM km =
  if (km.nb < km.nmax) then infinity else let _,dmax,_ = List.hd km.voisins in dmax;;


(* meme principe que précédemment : on parcourt l'arbre en mettant à jour la structure km *)
let k_plus_proches_voisins k tree pt =
  let km = creerKM k in (* les k points les plus proches rencontrés *)

  (* parcours intelligent de l'arbre en gardant en mémoire les k points les plus proches *)
  let rec parcours t p = 
    match t with
    | V -> ()
    | N (ag, v, ad) -> 
      (* ordre de visite : quel est le premier sous arbre à visiter ? *)
      let (a1, a2) = if pt.(p mod d) <= v.(p mod d) then (ag, ad) else (ad, ag)
      in
        parcours a1 (p + 1);
        (* est-ce nécessaire de visiter le 2e sous arbre ? *)
        let r = rayonKM km in
        if (r > abs_float (pt.(p mod d) -. v.(p mod d))) then begin
          (* on tient compte du point séparateur, et du 2e sous-arbre *)
          ajouterKM km (v, dist pt v, 0) ; parcours a2 (p + 1)
        end (* else : sous-arbre non exploré *)
  in
  parcours tree 0; List.map (fun (x,y,z) -> x) km.voisins
;;

(* test *)
let kpppt = k_plus_proches_voisins 3 tree [|3.2;4.4|] in
List.iter (fun pt -> Printf.printf "pt : %f,%f\n" pt.(0) pt.(1)) kpppt;;