#include "cs_defs.h"

/*----------------------------------------------------------------------------
 * Standard C library headers
 *----------------------------------------------------------------------------*/

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include "cs_base.h"
#include "cs_halo.h"
#include "cs_mesh.h"
#include "cs_mesh_quantities.h"

/*----------------------------------------------------------------------------
 * PLE library headers
 *----------------------------------------------------------------------------*/

#include <ple_coupling.h>

/*----------------------------------------------------------------------------
 * Local headers
 *----------------------------------------------------------------------------*/

#include "cs_headers.h"

/*----------------------------------------------------------------------------*/

BEGIN_C_DECLS

void
cs_user_source_terms(cs_domain_t  *domain,
                      int           f_id,
                      cs_real_t    *st_exp,
                      cs_real_t    *st_imp)
{
  const cs_field_t *f = cs_field_by_id(f_id);
  if (f != CS_F_(t)) return;

  const cs_lnum_t n_cells = domain->mesh->n_cells;
  const cs_lnum_t n_i_faces = domain->mesh->n_i_faces;
  const cs_lnum_t n_b_faces = domain->mesh->n_b_faces;
  const int nt_cur = cs_glob_time_step->nt_cur;

  const cs_real_t *cell_vol = domain->mesh_quantities->cell_vol;
  const cs_real_3_t *i_fac_cog = (const cs_real_3_t *)domain->mesh_quantities->i_face_cog;
  const cs_real_3_t *i_face_normal = (const cs_real_3_t *)domain->mesh_quantities->i_face_normal;
  const cs_real_t *i_face_surf = domain->mesh_quantities->i_face_surf;

  const cs_real_t *cvar_temp = CS_F_(t)->val;
  const cs_real_3_t *cvar_vel = (const cs_real_3_t *)CS_F_(vel)->val;

  const cs_real_t cp = 1005.0;
  const cs_real_t L = 1.6;
  const cs_real_t rho = 1.0;

  const int kimasf = cs_field_key_id("inner_mass_flux_id");
  const cs_real_t *i_massflux_base = cs_field_by_id(cs_field_get_key_int(f, kimasf))->val;

  cs_real_t mflow = 0, tot_sur = 0;
  for (cs_lnum_t face_id = 0; face_id < n_i_faces; face_id++) {
    if (i_fac_cog[face_id][0] < 0.0001) {
      tot_sur += i_face_surf[face_id];
      mflow += i_massflux_base[face_id] * i_face_normal[face_id][0] / i_face_surf[face_id];
    }
  }
  cs_real_t bbuf[2] = {tot_sur, mflow};
  cs_parall_sum(2, CS_REAL_TYPE, bbuf);
  tot_sur = bbuf[0];
  mflow = bbuf[1];

  if (mflow == 0.0) printf("Warning: Mass flow rate is zero, setting a_T to zero.\n");

  const int location_id = CS_MESH_LOCATION_BOUNDARY_FACES;
  const cs_lnum_t n_elts = cs_mesh_location_get_n_elts(location_id)[0];
  cs_real_t *boundary_flux = NULL;
  BFT_MALLOC(boundary_flux, n_elts, cs_real_t);

  const cs_field_t *fl = cs_thermal_model_field();
  cs_post_boundary_flux(fl->name, n_elts, NULL, boundary_flux);

  const cs_real_t *b_face_surf = cs_glob_mesh_quantities->b_face_surf;
  const cs_lnum_t (*b_face_cog)[3] = cs_glob_mesh_quantities->b_face_cog;

  cs_real_t integral_flux = 0.0, total_surface_area = 0.0;
  for (cs_lnum_t j = 0; j < n_elts; j++) {
    if (b_face_cog[j][2] < 0.00001) {
      integral_flux += boundary_flux[j] * b_face_surf[j];
      total_surface_area += b_face_surf[j];
    }
  }
  cs_real_t bbuf2[2] = {integral_flux, total_surface_area};
  cs_parall_sum(2, CS_REAL_TYPE, bbuf2);
  integral_flux = bbuf2[0];
  total_surface_area = bbuf2[1];
  BFT_FREE(boundary_flux);

  cs_real_t a_T = integral_flux / (mflow * cp * L);

  /* Compute phi_face = cp * rho * a_T * u_face */
  cs_real_t *i_phi_face = NULL, *b_phi_face = NULL;
  BFT_MALLOC(i_phi_face, n_i_faces, cs_real_t);
  BFT_MALLOC(b_phi_face, n_b_faces, cs_real_t);

  for (cs_lnum_t iface = 0; iface < n_i_faces; iface++) {
    const cs_lnum_t c0 = domain->mesh->i_face_cells[2 * iface];
    const cs_lnum_t c1 = domain->mesh->i_face_cells[2 * iface + 1];
    const cs_real_t u0 = cvar_vel[c0][0];
    const cs_real_t u1 = cvar_vel[c1][0];
    const cs_real_t u_face = 0.5 * (u0 + u1);
    i_phi_face[iface] = rho * cp * a_T * u_face;
  }

  for (cs_lnum_t iface = 0; iface < n_b_faces; iface++) {
    const cs_lnum_t c0 = domain->mesh->b_face_cells[iface];
    const cs_real_t u0 = cvar_vel[c0][0];
    b_phi_face[iface] = rho * cp * a_T * u0;
  }

  /* Compute divergence */
  cs_real_t *divergence = NULL;
  cs_real_t feedbackterm = 0.0;
  BFT_MALLOC(divergence, n_cells, cs_real_t);
  cs_divergence(domain->mesh, 1, i_phi_face, b_phi_face, divergence);

  /* Apply source term */
  cs_real_t relaxation = fmin((nt_cur - 800000) / 100000.0 * 0.00005, 0.0001);

for (cs_lnum_t i = 0; i < n_cells; i++) {
  if (nt_cur > 800100) {
    feedbackterm = -divergence[i] * relaxation;
  }
      st_exp[i] += feedbackterm;
      st_imp[i] = 0.0;
  }

  BFT_FREE(i_phi_face);
  BFT_FREE(b_phi_face);
  BFT_FREE(divergence);
}

END_C_DECLS
