/*** TP Kmeans - Image
 *    Avec production de l'exécutable final
 *    Avec optimisations : 
 *       - calcul des barycentres efficaces
 *       - ne tenir compte que d'un pixel sur 100
 ***/

#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>

#include "image.h"

/* calcule la distance au carré entre 2 pixels */
int dist2(pixel p1, pixel p2) {
    int dr = p1.r - p2.r;
    int dv = p1.v - p2.v;
    int db = p1.b - p2.b;
    return dr * dr + dv * dv + db * db;
}

/* renvoie l'indice du centroide le plus proche de p */
int classer(pixel p, pixel* centroides, int k){   
    // on cherche le centroide le plus proche
    int mc = 0; // meilleur_centroide
    int md2 = dist2(p, centroides[0]); // = meilleure distance au carré
    for (int m = 1; m < k; m += 1){
        pixel c = centroides[m];
        int d2 = dist2(p, c);
        if (d2 < md2){
            mc = m; md2 = d2;
        }
    }
    return mc;
}

bool egaux(pixel p1, pixel p2) {
    return (p1.r == p2.r) && (p1.v == p2.v) && (p1.b == p2.b);
}

/* applique l'algorithme de kmeans 
  - pix : les npix à clusteriser
  - classes : pour indiquer les classes finales des pixels
  - centroides : tableau des k centroides */
void kmeans(pixel* pix, int* classes, int npix, pixel* centroides, int k){
    // initialisation des classes
    for (int i = 0; i < npix; i +=1) {
        classes[i] = -1;
    }
    // initialisation des centroides 
    for (int i = 0; i < k; i += 1){
        centroides[i] = pix[rand() % npix];
    }

    // pour calculer les barycentres en un seul passage
    int* count = malloc(k * sizeof(int)); // compteurs
    int* somme_r = malloc(k * sizeof(int));
    int* somme_v = malloc(k * sizeof(int));
    int* somme_b = malloc(k * sizeof(int));

    bool change = true;
    while (change) {
        printf("Etape kmeans\n");
        change = false;
        // Etape 1 : calcul des nouvelles classes
        for (int i = 0; i < npix; i += 1) {
            int nc = classer(pix[i], centroides, k); // nouvelle classe
            if (nc != classes[i]) {
                change = true;
                classes[i] = nc;
            }
        }
        if (change) { // Etape 2 : nouveau calcul des centroides
            
            for (int i = 0; i < k; i += 1) { // remise à 0 des centroides
                count[i] = 0;
                somme_r[i] = 0;
                somme_v[i] = 0;
                somme_b[i] = 0;
            }
            for (int i = 0; i < npix; i += 1){ // somme des composantes
                int c  = classes[i];
                count[c] += 1;
                somme_r[c] += pix[i].r;
                somme_v[c] += pix[i].v;
                somme_b[c] += pix[i].b;
            }
            for (int i = 0; i < k; i += 1) { // division
                if (count[i] != 0) {
                    centroides[i].r = somme_r[i] / count[i];
                    centroides[i].v = somme_v[i] / count[i];
                    centroides[i].b = somme_b[i] / count[i];
                } else { // ???
                    centroides[i].r = 0;
                    centroides[i].v = 0;
                    centroides[i].b = 0;
                }
            }
        }
    }

    free(count);
    free(somme_r);
    free(somme_v);
    free(somme_b);
}

int main(int argc, char* argv[]) {
    if (argc != 4) {
        printf("usage : ./transform source.ppm dest.ppm k");
        return 0;
    } else {
        char* source = argv[1];
        char* dest = argv[2];
        int k = atoi(argv[3]);
    
        if (k == 0) {
            printf("usage : ./transform source.ppm dest.ppm k");
            return 0;
        }

        printf("Lecture de l'image\n");
        Image* im0 = importerImage(source);
        int n0 = im0->haut;
        int p0 = im0->larg;

        int pas = 10;

        int npix = (n0/pas + 1) * (p0/pas + 1);
        pixel* pix = malloc(npix * sizeof(pixel));

        int pos = 0;
        for (int i = 0; i < n0; i += pas) {
            for (int j = 0; j < p0; j += pas) {
                pix[pos] = getPix(im0, i, j);
                pos += 1;
            }
        }

        printf("Kmeans\n");

        pixel* centroides = malloc(k * sizeof(pixel));
        int* classes = malloc(npix * sizeof(int));

        kmeans(pix, classes, npix, centroides, k);

        printf("Modification de l'image\n");

        for (int i = 0; i < n0; i += 1) {
            for (int j = 0; j < p0; j += 1) {
                pixel p = getPix(im0, i, j);
                int c = classer(p, centroides, k);
                // on change le pixel
                setPix(im0, i, j, centroides[c]);
            }
        }

        printf("Ecriture de l'image\n");
        exporterImage(im0, dest);

        free(pix);
        free(centroides);
        free(classes);
        detruireImage(im0);
   
        return 0;
    }
}