/*** TP Kmeans - Image
 *    Sans production de l'exécutable final
 *    Sans optimisations : 
 *       - calcul des barycentres inefficace
 *       - en tenant compte de tous les pixels
 ***/

#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>

#include "image.h"

/* calcule la distance au carré entre 2 pixels */
int dist2(pixel p1, pixel p2) {
    int dr = p1.r - p2.r;
    int dv = p1.v - p2.v;
    int db = p1.b - p2.b;
    return dr * dr + dv * dv + db * db;
}

/* renvoie l'indice du centroide le plus proche de p */
int classer(pixel p, pixel* centroides, int k){   
    // on cherche le centroide le plus proche
    int mc = 0; // meilleur_centroide
    int md2 = dist2(p, centroides[0]); // = meilleure distance au carré
    for (int m = 1; m < k; m += 1){
        pixel c = centroides[m];
        int d2 = dist2(p, c);
        if (d2 < md2){
            mc = m; md2 = d2;
        }
    }
    return mc;
}

bool egaux(pixel p1, pixel p2) {
    return (p1.r == p2.r) && (p1.v == p2.v) && (p1.b == p2.b);
}

/* applique l'algorithme de kmeans 
  - pix : les npix à clusteriser
  - classes : pour indiquer les classes finales des pixels
  - centroides : tableau des k centroides */
void kmeans(pixel* pix, int* classes, int npix, pixel* centroides, int k){
    // initialisation des classes
    for (int i = 0; i < npix; i +=1) {
        classes[i] = -1;
    }
    // initialisation des centroides 
    for (int i = 0; i < k; i += 1){
        centroides[i] = pix[rand() % npix];
    }
    
    bool change = true;
    while (change) {
        printf("Etape kmeans\n");
        change = false;
        // Etape 1 : calcul des nouvelles classes
        for (int i = 0; i < npix; i += 1) {
            int nc = classer(pix[i], centroides, k); // nouvelle classe
            if (nc != classes[i]) {
                change = true;
                classes[i] = nc;
            }
        }
        if (change) { // Etape 2 : nouveau calcul des centroides
            for (int c = 0; c < k; c += 1) { 
                int somme_r = 0;
                int somme_v = 0;
                int somme_b = 0;
                int count = 0;
                for (int i = 0; i < npix; i += 1){
                    if (classes[i] == c) {
                        count += 1;
                        somme_r += pix[i].r;
                        somme_v += pix[i].v;
                        somme_b += pix[i].b;
                    }
                }
                if (count > 0) {
                    centroides[c].r = somme_r / count;
                    centroides[c].v = somme_v / count;
                    centroides[c].b = somme_b / count;
                }
            }
        }
    }
}

int main() {
    int k = 4;
    
    printf("Lecture de l'image\n");
    Image* im0 = importerImage("johnR.ppm");
    int n0 = im0->haut;
    int p0 = im0->larg;

    int npix = n0 * p0;
    pixel* pix = malloc(npix * sizeof(pixel));

    int pos = 0;
    for (int i = 0; i < n0; i += 1) {
        for (int j = 0; j < p0; j += 1) {
            pix[pos] = getPix(im0, i, j);
            pos += 1;
        }
    }

    printf("Kmeans\n");

    pixel* centroides = malloc(k * sizeof(pixel));
    int* classes = malloc(npix * sizeof(int));

    kmeans(pix, classes, npix, centroides, k);

    printf("Modification de l'image\n");

    pos = 0;
    for (int i = 0; i < n0; i += 1) {
        for (int j = 0; j < p0; j += 1) {
            // on change le pixel
            int c = classes[pos];
            setPix(im0, i, j, centroides[c]);
            pos += 1;
        }
    }

    printf("Ecriture de l'image\n");
    exporterImage(im0, "johnR_4.ppm");

    free(pix);
    free(centroides);
    free(classes);
    detruireImage(im0);
   
    return 0;
}