
import flow360 as fl
from flow360.examples import NLFAirfoil2D

# Fetch the NLF airfoil example files locally, then create a Flow360 project
# from its volume mesh and keep a handle to that mesh for surface lookups.
NLFAirfoil2D.get_files()
project = fl.Project.from_volume_mesh(
    NLFAirfoil2D.mesh_filename,
    name="Transition Model 2D Airfoil from Python",
)
vm = project.volume_mesh

# Freestream state: Mach 0.1, Reynolds number 4e6 (given per mesh unit, per the
# argument name), 540 R static temperature, zero angle of attack.
operating_condition = fl.AerospaceCondition.from_mach_reynolds(
    mach=0.1,
    reynolds_mesh_unit=4e6,
    project_length_unit=project.length_unit,
    temperature=540.0 * fl.u.R,
    alpha=0.0 * fl.u.deg,
)

# Steady solve capped at 20000 pseudo-steps, driven by an adaptive CFL.
adaptive_cfl = fl.AdaptiveCFL(convergence_limiting_factor=0.4)
time_stepping = fl.Steady(max_steps=20000, CFL=adaptive_cfl)

# Transition model settings: N_crit = 7.2, with Jacobian-update and
# equation-evaluation frequencies both set to 1.
transition_model_solver = fl.TransitionModelSolver(
    N_crit=7.2,
    absolute_tolerance=1e-10,
    linear_solver=fl.LinearSolver(max_iterations=30),
    equation_evaluation_frequency=1,
    update_jacobian_frequency=1,
)

# Fields written both on the aerofoil surface and along the y-normal slice.
# A tuple is kept here and a fresh list is built for each output.
surface_fields = ("Cp", "Cf", "yPlus", "CfVec")

aerofoil_surface_output = fl.SurfaceOutput(
    surfaces=vm["Block/Aerofoil"],
    output_format="both",
    output_fields=list(surface_fields),
)
y_slice = fl.Slice(name="y", normal=(0, 1, 0), origin=(0, -0.5, 0) * fl.u.m)
aerofoil_slice_output = fl.SurfaceSliceOutput(
    name="surface_slices",
    entities=[y_slice],
    target_surfaces=vm["Block/Aerofoil"],
    output_format="paraview",
    output_fields=list(surface_fields),
)
outputs = [aerofoil_surface_output, aerofoil_slice_output]

# Build every model and the SimulationParams inside Flow360's SI unit-system
# context, then submit the case on the project.
with fl.SI_unit_system:
    # Boundary conditions, keyed by the mesh's surface names.
    boundary_models = [
        fl.Wall(surfaces=vm["Block/Aerofoil"]),
        fl.Freestream(surfaces=vm["Block/Farfield"]),
        fl.SlipWall(entities=[vm["Block/Symmetry"]]),
    ]

    # Spalart-Allmaras turbulence solver settings.
    spalart_allmaras = fl.SpalartAllmaras(
        absolute_tolerance=1e-10,
        linear_solver=fl.LinearSolver(max_iterations=30),
        update_jacobian_frequency=1,
        equation_evaluation_frequency=1,
    )
    # Navier-Stokes solver settings (Jacobian refreshed every 4 steps).
    ns_solver = fl.NavierStokesSolver(
        linear_solver=fl.LinearSolver(max_iterations=50),
        absolute_tolerance=1e-12,
        update_jacobian_frequency=4,
        equation_evaluation_frequency=1,
    )
    fluid_model = fl.Fluid(
        navier_stokes_solver=ns_solver,
        turbulence_model_solver=spalart_allmaras,
        transition_model_solver=transition_model_solver,
    )

    params = fl.SimulationParams(
        operating_condition=operating_condition,
        time_stepping=time_stepping,
        models=[*boundary_models, fluid_model],
        outputs=outputs,
    )

case = project.run_case(
    params=params,
    name="Transition Model 2D Airfoil from Python",
)

import tempfile
import os
import tarfile
import pyvista as pv
import matplotlib.pyplot as plt

def extract_results(results, slice_filename="surface_slice_y.pvtu"):
    """Download a case's surface outputs and return the y-slice data sorted by x.

    The surface archive fetched by ``results.download(surface=True, ...)`` is
    unpacked into a temporary directory and the slice file is read with
    PyVista. The temporary directory is deleted when this function returns.

    Args:
        results: Case results handle (``case.results``). It must provide
            ``download(surface=..., destination=...)``, which is expected to
            write ``surfaces.tar.gz`` into ``destination``.
        slice_filename: Name of the slice file inside the surface archive.
            The default matches the file this script's ``y`` slice is
            expected to produce.

    Returns:
        Tuple ``(x, y, z, cf, cp)``: the slice point coordinates and the
        ``Cf`` / ``Cp`` point-data arrays, all reordered by increasing ``x``.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        destination = os.path.join(temp_dir, "case_data")
        os.makedirs(destination, exist_ok=True)

        results.download(
            surface=True,
            destination=destination,
        )

        surfaces_path = os.path.join(destination, "surfaces.tar.gz")
        # The "data" extraction filter is available on Python 3.12+ and on
        # security-backport releases. It rejects absolute paths, ".." traversal
        # and special files in the downloaded archive. It also avoids the 3.12+
        # DeprecationWarning for filterless extractall(). Older interpreters
        # fall back to the legacy behaviour.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(surfaces_path) as tar:
            tar.extractall(path=temp_dir, **extract_kwargs)  # extract directly into temp_dir

        mesh = pv.read(os.path.join(temp_dir, slice_filename))

        points = mesh.points
        cf = mesh.point_data["Cf"]
        cp = mesh.point_data["Cp"]

        # Use a stable sort so points sharing an x value (e.g. matching upper
        # and lower surface stations) come out in a deterministic order.
        order = points[:, 0].argsort(kind="stable")
        x, y, z = (points[order, axis] for axis in range(3))
        return x, y, z, cf[order], cp[order]

# Block until the case finishes, then fetch its results handle.
case.wait()
results = case.results

# History plots start at this pseudo-step (presumably to skip the initial
# transient so it does not dominate the plot scale).
PLOT_START_STEP = 5000

total_forces = results.total_forces.as_dataframe()
total_forces_filtered = total_forces[total_forces["pseudo_step"] >= PLOT_START_STEP]
total_forces_filtered.plot("pseudo_step", ["CD"])
plt.ylabel("Drag Coefficient")

non_linear = results.nonlinear_residuals.as_dataframe()
non_linear_filtered = non_linear[non_linear["pseudo_step"] >= PLOT_START_STEP]
non_linear_filtered.plot(
    "pseudo_step",
    ["0_cont", "1_momx", "2_momy", "3_momz", "4_energ", "5_nuHat"],
    logy=True,
)
plt.ylabel("Non-Linear Residual")

x, y, z, cf, cp = extract_results(results)

plt.figure(figsize=(10, 6))
plt.scatter(x, cf, label="Friction Coefficient")
plt.xlabel("Chord Position")
plt.ylabel("Friction Coefficient")
# The scatter carries a label; without a legend that label is never shown.
plt.legend()

# Display all figures when the file is run as a plain script.
plt.show()
