(* KNN version simple *)

let nb_cl = 3;;  (* nombre de classes ; de 0 à nb_cl - 1 *)

let data = [|  (* ensemble d'entrainement *)
([| 149. ; 304. |], 0);
([| 207. ; 346. |], 0);
([| 79. ; 155. |], 0);
([| 257. ; 287. |], 0);
([| 499. ; 86. |], 0);
([| 432. ; 309. |], 0);
([| 498. ; 343. |], 0);
([| 356. ; 178. |], 0);
([| 401. ; 224. |], 0);
([| 75. ; 119. |], 0);
([| 334. ; 338. |], 0);
([| 374. ; 242. |], 0);
([| 608. ; 378. |], 1);
([| 310. ; 496. |], 1);
([| 610. ; 485. |], 1);
([| 544. ; 298. |], 1);
([| 550. ; 483. |], 1);
([| 670. ; 227. |], 1);
([| 719. ; 237. |], 1);
([| 673. ; 386. |], 1);
([| 319. ; 359. |], 1);
([| 560. ; 252. |], 1);
([| 395. ; 250. |], 1);
([| 299. ; 529. |], 1);
([| 449. ; 533. |], 1);
([| 463. ; 387. |], 1);
([| 265. ; 324. |], 1);
([| 716. ; 429. |], 1);
([| 481. ; 371. |], 1);
([| 502. ; 280. |], 1);
([| 428. ; 623. |], 1);
([| 355. ; 424. |], 1);
([| 796. ; 301. |], 2);
([| 938. ; 430. |], 2);
([| 507. ; 320. |], 2);
([| 519. ; 336. |], 2);
([| 573. ; 269. |], 2);
([| 557. ; 198. |], 2);
([| 761. ; 285. |], 2);
([| 829. ; 434. |], 2);
([| 522. ; 182. |], 2);
([| 766. ; 105. |], 2);
([| 803. ; 124. |], 2);
([| 884. ; 420. |], 2);
([| 769. ; 235. |], 2);
([| 912. ; 151. |], 2);
([| 833. ; 387. |], 2);
|];;

(* distance entre deux points *)
let dist pt1 pt2 = 
  let n = Array.length pt1 in
  let n2 = Array.length pt2 in
  assert (n = n2);
  let s = ref 0. in
  for i = 0 to (n - 1) do
    s := !s +. (pt1.(i) -. pt2.(i)) ** 2.
  done;
  sqrt !s;;

(* classifieur knn *)
let classifie k data pt =
  (* k : nb de plus proches voisins à considérer *)
  (* data : ensemble d'entrainement *)
  (* pt : le point à classer *)
  assert (k <= Array.length data); (* programmation defensive *)
  let n = Array.length data in
  (* tableau des couples (distance, classe) *)
  let distances_cl = Array.make n (0., 0) in 
  for i = 0 to n - 1 do
    let pt2, cl = data.(i) in
    distances_cl.(i) <- (dist pt pt2, cl)
  done;
  (* tri par distance croissante *)
  Array.sort (fun (d1,c1) (d2,c2) -> if d1 < d2 then -1 else 1) distances_cl;
  (* tableau des compteurs *)
  let compteurs = Array.make nb_cl 0 in
  let maxi = ref 0 in 
  let cmaxi = ref 0 in 
  let egalite = ref true in
  for i = 0 to k - 1 do
    let dist, c = distances_cl.(i) in
    compteurs.(c) <- compteurs.(c) + 1;
    if compteurs.(c) > !maxi then begin
      maxi := compteurs.(c); cmaxi := c; egalite := false
    end else if compteurs.(c) = !maxi then 
      egalite := true
  done;
  if !egalite then -1 else !cmaxi
;;

(* K_Meilleurs *)

(* structure de données pour stocker les k meilleurs *)

type k_meilleurs = { nmax : int; mutable nb : int; mutable voisins : (float array * float * int) list}
(* nmax : nombre maximal de point à conserver *)
(* nb : nombre de points conservés *)
(* voisins : liste des points conservés, triés par distance décroissante, accompagnés de leur classe *)

let creerKM k = { nmax = k; nb = 0; voisins = [] };;

let ajouterKM km (pt,dist,c) =
  let rec inserer (pt, dist, c) l = match l with
  | [] -> [(pt, dist, c)]
  | (pt2, dist2, c2) :: ll -> if dist < dist2 then (pt2, dist2, c2) :: (inserer (pt, dist, c) ll) else (pt, dist, c) :: l
  in
  if km.nb < km.nmax then begin (* il reste de la place ! *)
    km.voisins <- inserer (pt, dist, c) km.voisins;
    km.nb <- km.nb + 1
  end else (* c'est plein ! *)
    (* faut-il garder le point ? qui prend alors la place de la tête de liste *)
    let _,dmax,_ = List.hd km.voisins in
    if dist <= dmax then
      km.voisins <- inserer (pt, dist,c) (List.tl km.voisins)
;;

(* classifieur knn optimisé  *)
let classifieKM k data pt =
  (* k : nb de plus proches voisins à considérer *)
  (* data : ensemble d'entrainement *)
  (* pt : le point à classer *)
  assert (k <= Array.length data); (* programmation defensive *)
  let km = creerKM k in
  let n = Array.length data in
  for i = 0 to n - 1 do
    let pt2, cl = data.(i) in
    ajouterKM km (pt2, dist pt pt2, cl)
  done;
  (* tableau des compteurs *)
  let compteurs = Array.make nb_cl 0 in
  let maxi = ref 0 in 
  let cmaxi = ref 0 in 
  let egalite = ref true in
  let traite_meilleur (pt, dist, c) =
    compteurs.(c) <- compteurs.(c) + 1;
    if compteurs.(c) > !maxi then begin
      maxi := compteurs.(c); cmaxi := c; egalite := false
    end else if compteurs.(c) = !maxi then 
      egalite := true
  in List.iter traite_meilleur km.voisins;
  if !egalite then -1 else !cmaxi
;;



let pts = [| [|500.;345.|];  [|550.;300.|] |] in
let ks = [|1;2;3;4;5|] in
for i = 0 to Array.length pts - 1 do
  for j = 0 to Array.length ks - 1 do
    Printf.printf "%d " (classifieKM ks.(j) data pts.(i))
  done; print_newline ()
done;;





(* couleurs associées aux classes / zones*)
let cz0 = Graphics.rgb 50 250 50;;
let cp0 = Graphics.rgb 35 175 35;;
let cz1 = Graphics.rgb 250 50 50;;
let cp1 = Graphics.rgb 175 35 35;;
let cz2 = Graphics.rgb 50 50 250;;
let cp2 = Graphics.rgb 35 35 175;;

let set_color_of_classe cl z =
  let couleur = 
    match cl with
    | 0 -> if z then cz0 else cp0
    | 1 -> if z then cz1 else cp1
    | 2 -> if z then cz2 else cp2
    | _ -> Graphics.white
  in  
  Graphics.set_color couleur;;

(* affichage d'un ensemble d'entraînement *)
let affichage_points data =
  Graphics.set_line_width 3;
  for i = 0 to Array.length data - 1  do
    let coord, cl = data.(i) in
    set_color_of_classe cl false;
    Graphics.draw_rect (int_of_float coord.(0)-1) (int_of_float coord.(1)-1) 3 3
  done;;

(* classifie tous les points de l'image *)
let classifie_zone k data =
  for x = 0 to 999  do
    for y = 0 to 699 do
      if x mod 10 = 0 && y mod 10 = 0 then begin 
        let cl = classifie k data [|float_of_int x; float_of_int y|] in
        set_color_of_classe cl true;
        Graphics.draw_rect (x-1) (y-1) 3 3
      end
    done
  done;;


(* Affichage *)

Graphics.open_graph " 1000x700" ; (* ouverture de la fenetre graphique *)

affichage_points data;;

Graphics.read_key ();; (* attente d'une saisie clavier *)

classifie_zone 3 data;;

affichage_points data;;

Graphics.read_key ();;