#!/usr/bin/env python

#-------------------------------------------------------------------------------

# This file is part of Code_Saturne, a general-purpose CFD tool.
#
# Copyright (C) 1998-2018 EDF S.A.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51 Franklin
# Street, Fifth Floor, Boston, MA 02110-1301, USA.

#-------------------------------------------------------------------------------

import os
import torch
import torch.nn as nn
import torch.nn.functional as func
import pandas as pd
import numpy as np

#===============================================================================
# Local functions
#===============================================================================


#-------------------------------------------------------------------------------

class DNN(nn.Module):
    """
    Fully connected network used for the scalar ANN models.

    Layer widths are [input_dim] + hidden_size, followed by a final layer
    producing output_dim values; ReLU is applied after every layer except
    the last one.
    """

    def __init__(self, data, input_dim, hidden_size, output_dim, use_bias=True):
        super(DNN, self).__init__()
        # Copy so that inserting input_dim below does not mutate the caller's list.
        self.hidden_size = hidden_size.copy()
        self.output_dim = output_dim
        # Stored as a module attribute only; forward() does not use it.
        self.data =  data
        self.input_dim = input_dim
        self.hidden_size.insert(0, self.input_dim)

        self.linear_net = nn.ModuleList(
            [nn.Linear(self.hidden_size[i],
                       self.hidden_size[i+1] if i != len(self.hidden_size)-1
                                               else output_dim,
                       bias=use_bias)
             for i in range(len(self.hidden_size))]
        )

    def forward(self, x):
        for i, lin in enumerate(self.linear_net):
            x = lin(x)
            # No activation after the last (output) layer.
            if i != len(self.hidden_size)-1: x = func.relu(x)

        return x

#-------------------------------------------------------------------------------

def _write_float32(path, values):
    """
    Write values to path as raw float32 bytes (no header).

    Despite the ``.pt`` extension used by callers, the files written here are
    not torch.save archives: they contain only the flat float32 data.
    values may be a torch tensor or anything numpy can convert to an array.
    """
    if isinstance(values, torch.Tensor):
        # detach/cpu make the conversion safe for grad-tracking or GPU tensors.
        array = values.detach().cpu().to(torch.float32).numpy()
    else:
        array = np.asarray(values, dtype=np.float32)
    with open(path, 'wb') as f:
        f.write(array.tobytes())

#-------------------------------------------------------------------------------

def _export_clustering(*, centroid_file, stats_file, centroid_out, mean_out,
                       std_out, device):
    """
    Dump the centroids and the mean/std statistics of one clustering level.

    centroid_file -- torch file holding a dict with a 'centroids' tensor
    stats_file    -- torch file holding a dict with 'mean' and 'std' tensors
    centroid_out, mean_out, std_out -- raw float32 output files
    device        -- map_location used when loading the torch files
    """
    output_cluster = torch.load(centroid_file, map_location=device, weights_only=True)
    _write_float32(centroid_out, output_cluster['centroids'])

    stats = torch.load(stats_file, map_location=device, weights_only=True)
    _write_float32(mean_out, stats['mean'])
    _write_float32(std_out, stats['std'])

#-------------------------------------------------------------------------------

def _export_leaf(*, train_leaf_dir, target_csv, model_out_dir, tag,
                 save_folder, ann_norm, hidden_sizes, num_features, device):
    """
    Export everything needed for ANN inference in one leaf cluster.

    Into save_folder, writes (suffixed with ``-<tag>.pt``, raw float32):
      - centerANN / scalingANN: input normalization of the trained ANN,
      - centerANNtarget / scalingANNtarget / minANNtarget / maxANNtarget:
        target normalization and bounds of the trained ANN,
      - minAftReactor / maxAftReactor: column-wise min/max of target_csv,
        used to check whether bounds are verified.
    Then rebuilds the DNN, loads its trained weights and saves a TorchScript
    version (produced with torch.jit.script) to
    ``<model_out_dir>/traced_model.pt``.

    train_leaf_dir -- training directory of the leaf (holds ``stats/`` and
                      ``finalModel.tar``)
    target_csv     -- tab-separated file of reactor targets for this leaf
    model_out_dir  -- output directory for the scripted model (created)
    tag            -- file-name suffix, e.g. ``parent0-child3``
    """
    stats_dir = os.path.join(train_leaf_dir, "stats")

    # stats for ANN inference
    stats_x = torch.load(os.path.join(stats_dir, "statsX-{}.pt".format(ann_norm)),
                         map_location=device, weights_only=True)
    _write_float32(os.path.join(save_folder, 'centerANN-{}.pt'.format(tag)), stats_x['center'])
    _write_float32(os.path.join(save_folder, 'scalingANN-{}.pt'.format(tag)), stats_x['scaling'])

    stats_y = torch.load(os.path.join(stats_dir, "statsY-{}.pt".format(ann_norm)),
                         map_location=device, weights_only=True)
    for key, prefix in (('center', 'centerANNtarget'),
                        ('scaling', 'scalingANNtarget'),
                        ('min', 'minANNtarget'),
                        ('max', 'maxANNtarget')):
        _write_float32(os.path.join(save_folder, '{}-{}.pt'.format(prefix, tag)), stats_y[key])

    # saving values to check if bounds are verified
    aft_reactor = pd.read_csv(target_csv, sep="\t")
    _write_float32(os.path.join(save_folder, 'minAftReactor-{}.pt'.format(tag)),
                   aft_reactor.min().to_numpy())
    _write_float32(os.path.join(save_folder, 'maxAftReactor-{}.pt'.format(tag)),
                   aft_reactor.max().to_numpy())

    # rebuilding the model and loading trained weights
    buffer = torch.empty((1, num_features))   # buffer tensor (stored, unused by forward)
    model = DNN(buffer, hidden_size=hidden_sizes, input_dim=num_features,
                output_dim=num_features, use_bias=True).to(device).to(torch.float32)

    # map_location lets checkpoints saved from a GPU run load on `device`.
    net = torch.load(os.path.join(train_leaf_dir, "finalModel.tar"),
                     map_location=device, weights_only=True)
    model.load_state_dict(net['model_state_dict'])
    model.eval()

    # Scripted (not traced); the file keeps its historical name.
    os.makedirs(model_out_dir, exist_ok=True)
    scripted_module = torch.jit.script(model)
    scripted_module.save(os.path.join(model_out_dir, 'traced_model.pt'))

#-------------------------------------------------------------------------------

def domain_prepare_data_add(domain):
    """
    Additional steps to prepare data
    (called in data preparation stage, between copy of files
    in DATA and copy of link of restart files as defined by domain).

    Exports the clustering statistics, ANN normalization statistics and
    TorchScript models to save_folder / model_folder for the coupled
    solver. The whole export is skipped when both folders already exist
    (also after an interrupted run: delete them to force a new export).

    domain -- case domain object passed by Code_Saturne (not used here).
    """

    # setup #################################
    model_folder =  "/home/cenvinzf@coria.fr/workdir/UMONS.coupling/case/SRC/model"
    save_folder = "/home/cenvinzf@coria.fr/workdir/UMONS.coupling/case/SRC/stats"

    data_norm = "STD"
    ann_norm = "STD"
    numClusters = 2

    cluster_dir = "/home/cenvinzf@coria.fr/LSTM/SRC/ORCh/clusters/kMeans++/parentCluster-ANN"
    train_dir = "/home/cenvinzf@coria.fr/LSTM/SRC/ORCh/training/ANN/{}".format(data_norm)

    # setup model parameters
    hidden_sizes = [512, 256, 128, 64, 32, 16]
    num_features = 14
    device = 'cpu'
    ##########################################

    print('ANN preprocessing for scalars')
    print('-----------------------------\n')

    if os.path.exists(save_folder) and os.path.exists(model_folder):
        print(' ANN preprocessing already done, skipping...\n')
        return

    os.makedirs(save_folder, exist_ok=True)
    os.makedirs(model_folder, exist_ok=True)

    # options shared by every leaf-cluster export
    leaf_options = dict(save_folder=save_folder, ann_norm=ann_norm,
                        hidden_sizes=hidden_sizes, num_features=num_features,
                        device=device)

    # saving centroids and global stats for parent clusters
    _export_clustering(
        centroid_file=os.path.join(cluster_dir, "outputCluster.pt"),
        stats_file=os.path.join(cluster_dir, "stats-kMeans++-{}-#{}parentClusters.pt".format(data_norm, numClusters)),
        centroid_out=os.path.join(save_folder, 'centroidParents.pt'),
        mean_out=os.path.join(save_folder, 'meanGlobal.pt'),
        std_out=os.path.join(save_folder, 'stdGlobal.pt'),
        device=device)

    # looping into parent clusters
    # Hierarchy handled below: parent 0 has one ANN per child cluster;
    # parent 1 is split again into grandchild clusters, one ANN each.
    # NOTE(review): parents with index >= 2 only get their parent-level stats.
    for i in range(numClusters):
        print(' Tracing model parentCluster#{}...'.format(i+1))

        parent_dir = os.path.join(cluster_dir, "parentCluster{}".format(i))
        # number of child clusters = number of entries in plotCentroid
        # (assumes exactly one entry per child cluster -- confirm)
        num_children = len(os.listdir(os.path.join(parent_dir, "plotCentroid")))

        _export_clustering(
            centroid_file=os.path.join(parent_dir, "outputCluster-parent{}.pt".format(i)),
            stats_file=os.path.join(parent_dir, "stats-kMeans++-{}-parentCluster{}-#{}childClusters.pt".format(data_norm, i, num_children)),
            centroid_out=os.path.join(save_folder, 'centroidParent{}Child.pt'.format(i)),
            mean_out=os.path.join(save_folder, 'meanLocalParent{}.pt'.format(i)),
            std_out=os.path.join(save_folder, 'stdLocalParent{}.pt'.format(i)),
            device=device)

        # looping into child clusters
        for j in range(num_children):
            print('\r  - childCluster#{}...'.format(j+1), end=' ')

            if i == 0:
                _export_leaf(
                    train_leaf_dir=os.path.join(train_dir, "parentCluster#{}/childCluster#{}".format(i, j)),
                    target_csv=os.path.join(parent_dir, "childInputTarget/dfTargetChild_{}.csv".format(j)),
                    model_out_dir=os.path.join(model_folder, "parentCluster#{}/childCluster#{}".format(i, j)),
                    tag="parent{}-child{}".format(i, j),
                    **leaf_options)

            elif i == 1:
                child_dir = os.path.join(parent_dir, "childCluster{}".format(j))
                num_grandchildren = len(os.listdir(os.path.join(child_dir, "plotCentroid")))

                _export_clustering(
                    centroid_file=os.path.join(child_dir, "outputCluster-parent{}-child{}.pt".format(i, j)),
                    stats_file=os.path.join(child_dir, "stats-kMeans++-{}-parentCluster{}-childCluster{}-#{}grandchildClusters.pt".format(data_norm, i, j, num_grandchildren)),
                    centroid_out=os.path.join(save_folder, 'centroidParent{}child{}.pt'.format(i, j)),
                    mean_out=os.path.join(save_folder, 'meanLocalParent{}-child{}.pt'.format(i, j)),
                    std_out=os.path.join(save_folder, 'stdLocalParent{}-child{}.pt'.format(i, j)),
                    device=device)

                # looping into grandchild clusters
                # NOTE(review): training dirs use 'grandchildcluster' (lower-case c)
                # while exported model dirs use 'grandchildCluster'; kept as-is.
                for k in range(num_grandchildren):
                    print('\r  - grandchildCluster#{}...'.format(k+1), end=' ')
                    _export_leaf(
                        train_leaf_dir=os.path.join(train_dir, "parentCluster#{}/childCluster#{}/grandchildcluster#{}".format(i, j, k)),
                        target_csv=os.path.join(child_dir, "childInputTarget/dfTargetGrandchild_{}.csv".format(k)),
                        model_out_dir=os.path.join(model_folder, "parentCluster#{}/childCluster#{}/grandchildCluster#{}".format(i, j, k)),
                        tag="parent{}-child{}-grandchild{}".format(i, j, k),
                        **leaf_options)

        print(' ')

    print(' ')

#-------------------------------------------------------------------------------

def define_domain_parameters(domain):
    """
    Define domain execution parameters.

    Sets the extra C++ compile flags and link options needed to build the
    user functions against LibTorch and a local VTK 9.4 build; C and
    Fortran extra flags are left unset (None).
    """

    # Additional compiler flags may be passed to the C, C++, or Fortran
    # compilers, and libraries may be added, in case linking of user
    # functions against external libraries is needed.
    # Note that these flags are added before the default ones: this helps
    # ensure added search paths have priority, but also implies that user
    # optimization options may be superseded by the default ones.

    # Read parameters file
    # (already done just prior to this stage when
    # running script with --param option)

    # domain.restart_input = '/home/cenvinzf@coria.fr/workdir/cavity_h2_flame/case_h2-145_air-55/RESU/20240916-1353/checkpoint'

    home_dir = '/home/cenvinzf@coria.fr'
    libtorch_dir = home_dir + '/libtorch'
    vtk_src_dir = home_dir + '/Desktop/VTK/vtk'
    vtk_build_dir = home_dir + '/Desktop/VTK/build'
    vtk_version = '9.4'

    # Header search paths, in the order they are passed to the compiler.
    include_dirs = [
        libtorch_dir + '/include',
        libtorch_dir + '/include/torch/csrc/api/include',
        vtk_src_dir + '/Common/Core',
        vtk_src_dir + '/Common/DataModel',
        vtk_src_dir + '/IO/Legacy',
        vtk_src_dir + '/Utilities/KWIML',
        vtk_build_dir + '/Utilities/KWSys',
        vtk_build_dir + '/Common/Core',
        vtk_build_dir + '/Common/DataModel',
        vtk_build_dir + '/ThirdParty/nlohmannjson',
        vtk_src_dir + '/ThirdParty/nlohmannjson',
        vtk_build_dir + '/IO/Legacy',
        vtk_src_dir + '/Common/ExecutionModel',
        vtk_build_dir + '/Common/ExecutionModel',
        vtk_src_dir + '/Filters/Core',
        vtk_build_dir + '/Filters/Core',
        vtk_src_dir + '/IO/Core',
        vtk_build_dir + '/IO/Core',
    ]

    domain.compile_cflags = None

    # A single -std option: when several are given only the last one takes
    # effect, so a preceding '-std=c++17' would be dead.
    domain.compile_cxxflags = ' '.join(['-std=gnu++17']
                                       + ['-I' + d for d in include_dirs])
    domain.compile_fcflags = None

    lib_dirs = [libtorch_dir + '/lib', vtk_build_dir + '/lib']
    vtk_modules = ('sys', 'FiltersCore', 'IOLegacy', 'CommonExecutionModel',
                   'CommonDataModel', 'CommonCore', 'IOCore')
    libs = (['c10', 'torch', 'torch_cpu', 'stdc++']
            + ['vtk{}-{}'.format(m, vtk_version) for m in vtk_modules])
    domain.compile_libs = ' '.join(['-L' + d for d in lib_dirs]
                                   + ['-l' + lib for lib in libs])

    # Debugging options
    #------------------
    # To run the solver through a debugger, domain.debug should contain
    # the matching command-line arguments, such as:
    #   domain.debug = '--debugger=gdb'
    # or (for Valgrind):
    #   domain.debug = 'valgrind --tool=memcheck'
    # or (for Valgrind and ddd):
    #   domain.debug = '--debugger=ddd valgrind --tool=memcheck --vgdb-error=1'

    # import pprint
    # pprint.pprint(domain.__dict__)

    return

#-------------------------------------------------------------------------------
# End
#-------------------------------------------------------------------------------
