
import numpy as np
import matplotlib.pyplot as plt
import flow360 as fl
from flow360.examples import DTU_WindTurbine
# Fetch the example DTU wind-turbine geometry files before the project is created.
DTU_WindTurbine.get_files()

# Create a Flow360 project from the downloaded geometry; lengths are in metres.
_project_name = "DTU Wind Turbine stopping-criterion tutorial"
project = fl.Project.from_geometry(
    DTU_WindTurbine.geometry,
    length_unit="m",
    name=_project_name,
)
# Geometry handle; surfaces are selected from it with geometry["*"] further down.
geometry = project.geometry

# Blade radius of the DTU 10MW reference turbine [m]; a plain float that the
# SI unit system later interprets as metres (reference length/area).
R = 89.166

with fl.SI_unit_system:
    # Far-field boundary generated automatically by the mesher.
    farfield = fl.AutomatedFarfield()

    # Cylinder around the rotor; it is used as the rotating volume zone and
    # as a uniform-refinement region.
    rotor_cylinder = fl.Cylinder(
        name="RotorCylinder",
        center=(0, 0, 0),
        axis=(0, 0, 1),
        outer_radius=110 * fl.u.m,
        height=40 * fl.u.m,
    )

    # Larger cylinder shifted along +Z (the inflow direction) to refine the wake.
    wake_cylinder = fl.Cylinder(
        name="WakeCylinder",
        center=(0, 0, 55) * fl.u.m,
        axis=(0, 0, 1),
        outer_radius=200 * fl.u.m,
        height=200 * fl.u.m,
    )

# Rotating volume zone built on rotor_cylinder; it encloses every geometry surface.
rotating_zone = fl.RotationVolume(
    name="RotatingZone",
    entities=rotor_cylinder,
    enclosed_entities=geometry["*"],
    spacing_axial=2 * fl.u.m,
    spacing_circumferential=2 * fl.u.m,
    spacing_radial=2 * fl.u.m,
)

# Global surface / boundary-layer sizing.
meshing_defaults = fl.MeshingDefaults(
    boundary_layer_first_layer_thickness=1e-4 * fl.u.m,
    surface_max_edge_length=2 * fl.u.m,
)
# Coarser uniform spacing in the wake, finer spacing around the rotor.
wake_refinement = fl.UniformRefinement(entities=wake_cylinder, spacing=4 * fl.u.m)
rotor_refinement = fl.UniformRefinement(entities=rotor_cylinder, spacing=2 * fl.u.m)

meshing = fl.MeshingParams(
    defaults=meshing_defaults,
    volume_zones=[farfield, rotating_zone],
    refinements=[wake_refinement, rotor_refinement],
)

# Rotor angular velocity (DTU 10MW). With R = 89.166 m and the 11 m/s inflow
# used below, this corresponds to a tip-speed ratio of ~7.5.
omega = 8.836 * fl.u.rpm
dt_deg = 6.0                        # physical time step [degrees of rotation]; must divide 360
revs = 50                           # maximum number of revolutions to simulate

# Physical steps per revolution. ``round`` rather than ``int`` so float noise in
# 360 / dt_deg (e.g. 59.99999...) cannot silently drop a step, and reject step
# sizes that do not tile a full revolution: every "N-revolution" window below
# (rolling mean, stopping-criterion window, reported revolutions) relies on it.
steps_per_rev = round(360 / dt_deg)
if steps_per_rev < 1 or not np.isclose(steps_per_rev * dt_deg, 360.0):
    raise ValueError(
        f"dt_deg={dt_deg} must divide one revolution (360 deg) into a whole "
        "number of steps; revolution-based windows would otherwise be inexact."
    )

# Physical time for the rotor to sweep dt_deg degrees at angular velocity omega.
step_size = (dt_deg * fl.u.deg / omega.to("deg/s")).to("s")

time_stepping = fl.Unsteady(
    steps=revs * steps_per_rev,
    step_size=step_size,
    max_pseudo_steps=50,
)

# Prescribed rotation of the rotor volume at the angular velocity omega.
rotation_spec = fl.AngularVelocity(omega)
rotation = fl.Rotation(
    name="RotatingZone",
    volumes=rotor_cylinder,
    spec=rotation_spec,
)

# Length of the rolling-mean window, expressed in whole revolutions.
revs_for_mean = 2
window_steps = steps_per_rev * revs_for_mean

# All geometry surfaces are walls, modelled with a wall function.
walls = fl.Wall(use_wall_function=True, surfaces=geometry["*"])

# Force monitor over the wall surfaces: CFz and CMz plus a moving mean of
# window_steps samples.
thrust_statistic = fl.MovingStatistic(method="mean", moving_window_size=window_steps)
thrust_output = fl.ForceOutput(
    name="thrust_monitor",
    output_fields=["CFz", "CMz"],
    moving_statistic=thrust_statistic,
    models=[walls],
)

# Stopping criterion on the thrust monitor's CFz, evaluated over a window of
# revs_for_stop_criterion revolutions. Presumably the run stops once CFz varies
# by less than `tolerance` across that window -- confirm against Flow360 docs.
revs_for_stop_criterion = 3
tolerance_window_size = steps_per_rev * revs_for_stop_criterion

thrust_criterion = fl.StoppingCriterion(
    name="thrust_convergence",
    monitor_output=thrust_output,
    monitor_field="CFz",
    tolerance=1e-4,
    tolerance_window_size=tolerance_window_size,
)
run_control = fl.RunControl(stopping_criteria=[thrust_criterion])

with fl.SI_unit_system:
    # Blade-tip speed omega * R [m/s], passed as the reference velocity.
    tip_speed = omega.to("rad/s").to_value() * R

    operating_condition = fl.AerospaceCondition(
        velocity_magnitude=11.0,                # inflow wind speed [m/s]
        alpha=90 * fl.u.deg,                    # inflow along the rotor axis (+Z)
        reference_velocity_magnitude=tip_speed,
    )
    # Moments about the origin with length R; reference area = rotor disk pi * R^2.
    reference_geometry = fl.ReferenceGeometry(
        moment_center=(0, 0, 0),
        moment_length=(R, R, R),
        area=np.pi * R**2,
    )
    fluid = fl.Fluid(
        navier_stokes_solver=fl.NavierStokesSolver(
            absolute_tolerance=1e-10,
            relative_tolerance=1e-1,
        ),
        turbulence_model_solver=fl.SpalartAllmaras(
            absolute_tolerance=1e-10,
            relative_tolerance=1e-1,
        ),
    )
    freestream = fl.Freestream(surfaces=farfield.farfield)
    surface_output = fl.SurfaceOutput(
        surfaces=geometry["*"],
        output_fields=["Cp", "Cf", "CfVec"],
    )

    params = fl.SimulationParams(
        meshing=meshing,
        operating_condition=operating_condition,
        reference_geometry=reference_geometry,
        time_stepping=time_stepping,
        run_control=run_control,
        models=[fluid, walls, freestream, rotation],
        outputs=[surface_output, thrust_output],
    )

# Submit the case with the beta mesher, then call case.wait() before post-processing.
_case_name = "Time-Accurate Rotor with Thrust-Convergence Stopping Criterion"
case = project.run_case(params=params, name=_case_name, use_beta_mesher=True)

case.wait()

# --- Convert the monitored force coefficient (CFz) into dimensional thrust [N] ---
# Force coefficients are normalised by 0.5 * rho * U_ref**2 * A_ref. This case sets
# ``reference_velocity_magnitude`` (the blade-tip speed) for non-dimensionalisation,
# so that velocity -- not the free-stream ``velocity_magnitude`` -- is the scale that
# turns CFz back into Newtons. Fall back to the free-stream speed if it is unset.
operating_condition = case.params.operating_condition
density = operating_condition.thermal_state.density
U_ref = operating_condition.reference_velocity_magnitude
if U_ref is None:
    U_ref = operating_condition.velocity_magnitude
A_ref = case.params.reference_geometry.area
# Force scale [N] = dynamic pressure [Pa] * reference area [m^2]. Units are made
# explicit so the number does not depend on how each quantity happens to be stored.
q = 0.5 * density.to_value("kg/m**3") * U_ref.to_value("m/s") ** 2 * A_ref.to_value("m**2")

# Moving-statistic (rolling-mean) monitor; "end_index" is presumably the last
# physical step of each averaging window -- confirm against the Flow360 docs.
thrust_monitor_mean = case.results.custom_forces["thrust_monitor_moving_statistic"].as_dataframe()
thrust_monitor_mean["Thrust_mean [N]"] = thrust_monitor_mean["totalCFz_mean"] * q
thrust_monitor_mean_step = thrust_monitor_mean["end_index"]
# Instantaneous monitor, indexed by physical step.
thrust_monitor = case.results.custom_forces["thrust_monitor"].as_dataframe()
thrust_monitor["Thrust [N]"] = thrust_monitor["totalCFz"] * q
thrust_monitor_step = thrust_monitor["physical_step"]

fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=False)
ax_thrust, ax_resid = axes[0], axes[1]

# Top panel: instantaneous thrust and its rolling mean versus physical step.
ax_thrust.plot(
    thrust_monitor["physical_step"],
    thrust_monitor["Thrust [N]"],
    color="steelblue",
    alpha=0.5,
    linewidth=0.8,
    label="Instantaneous",
)
ax_thrust.plot(
    thrust_monitor_mean["end_index"],
    thrust_monitor_mean["Thrust_mean [N]"],
    color="navy",
    linewidth=2,
    label=f"{revs_for_mean}-rev rolling mean",
)
ax_thrust.set_ylabel("Thrust [N]")
ax_thrust.legend()
ax_thrust.grid(True, alpha=0.3)
ax_thrust.set_title("Thrust convergence")

# Bottom panel: nonlinear residuals, zoomed to the last simulated revolution.
final_step = thrust_monitor["physical_step"].iloc[-1]
non_linear = case.results.nonlinear_residuals.as_dataframe()
nl = non_linear.copy()
# Place pseudo steps at fractional positions inside their physical step so that
# both iteration levels share one x-axis.
pseudo_span = nl["pseudo_step"].max() + 1
nl["step_with_pseudo"] = nl["physical_step"] + nl["pseudo_step"] / pseudo_span
residual_names = ["0_cont", "1_momx", "2_momy", "3_momz", "4_energ", "5_nuHat"]
present_residuals = [name for name in residual_names if name in nl.columns]
for name in present_residuals:
    ax_resid.semilogy(nl["step_with_pseudo"], nl[name], label=name, alpha=0.7)
ax_resid.set_xlim(final_step - steps_per_rev + 1, final_step + 1)
ax_resid.set_xlabel("Physical step (with pseudo steps)")
ax_resid.set_ylabel("Nonlinear residual")
ax_resid.legend(fontsize=8)
ax_resid.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

revolutions_done = final_step / steps_per_rev
print(f"Simulation stopped at physical step {final_step} ({revolutions_done:.1f} revolutions)")
