"""
G x E x M Framework — ML Microservice Stub (Stages 3-4)
=========================================================
This is a RUNNABLE placeholder for the "AI Brain" stages of the framework
(Random Forest / XGBoost / DNN / LSTM / Ensemble Stack). It implements the
exact request/response contract the PHP app expects, but scores genotypes
with a simple heuristic instead of real trained models. Replace the
placeholder scoring inside `predict()` with real scikit-learn / XGBoost /
TensorFlow code when you're ready.

Run:
    pip install flask
    python app.py
    # -> listening on http://localhost:5000

The PHP app (config/db.php -> ML_API_BASE_URL) must point at this service.
"""

import hmac
import math
import os
import random

from flask import Flask, request, jsonify

app = Flask(__name__)

# Shared secret expected in the X-API-Key request header; must match
# config/db.php -> ML_API_KEY. It is read from the ML_API_KEY environment
# variable so the real secret does not have to live in source control. An
# unset or empty variable falls back to the original placeholder value.
API_KEY = os.environ.get("ML_API_KEY") or "CHANGE_ME_SHARED_SECRET"


@app.before_request
def check_api_key():
    """
    Enforce the shared-secret header on every endpoint except ``health``.

    Returns:
        None to let the request continue, or a ``(json, 401)`` response when
        the ``X-API-Key`` header is missing or does not match ``API_KEY``.
    """
    if request.endpoint == "health":
        return None
    key = request.headers.get("X-API-Key")
    # hmac.compare_digest takes time independent of where the inputs differ,
    # so response timing does not reveal how much of a guessed key matched.
    # Compare bytes: compare_digest raises TypeError on non-ASCII str input.
    if key is None or not hmac.compare_digest(
        key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        return jsonify({"error": "invalid or missing X-API-Key"}), 401
    return None


@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers ``{"status": "ok"}``; needs no API key."""
    status_payload = {"status": "ok"}
    return jsonify(status_payload)


def _group_yields(trials):
    """Collect finite ``yield_kg_ha`` values per ``genotype_id``, skipping unusable rows."""
    yields_by_genotype = {}
    for t in trials:
        if not isinstance(t, dict):
            continue
        gid = t.get("genotype_id")
        y = t.get("yield_kg_ha")
        # JSON lists/objects are the only unhashable JSON values; they cannot
        # be dict keys, so such rows are skipped like rows with missing ids.
        if gid is None or y is None or isinstance(gid, (list, dict)):
            continue
        try:
            y = float(y)  # PHP often encodes DB numerics as strings
        except (TypeError, ValueError):
            continue
        # NaN/inf would be serialised as NaN/Infinity, which is not valid JSON.
        if not math.isfinite(y):
            continue
        yields_by_genotype.setdefault(gid, []).append(y)
    return yields_by_genotype


def _score_genotypes(yields_by_genotype, rng):
    """Turn per-genotype yields into jittered recommendations, best first."""
    scored = []
    for gid, yields in yields_by_genotype.items():
        avg_yield = sum(yields) / len(yields)
        jitter = rng.uniform(-0.05, 0.05)
        predicted_yield = round(avg_yield * (1 + jitter), 2)
        gg_score = round(min(1.0, max(0.0, (predicted_yield / 10000) + jitter)), 4)
        scored.append({
            "genotype_id": gid,
            "predicted_yield": predicted_yield,
            "gg_score": gg_score,
        })
    scored.sort(key=lambda r: r["predicted_yield"], reverse=True)
    return scored


@app.route("/predict", methods=["POST"])
def predict():
    """
    Rank candidate genotypes for a target scenario (placeholder heuristic).

    Expected request body (sent by PHP model_runs.php):
    {
      "model_type": "RandomForest" | "XGBoost" | "DNN" | "LSTM" | "Ensemble_Stack",
      "training_trials": [
         {"trial_id":.., "genotype_id":.., "environment_id":.., "management_id":..,
          "yield_kg_ha":.., "quality_score":.., "stress_index":..}, ...
      ],
      "target_environment_id": <int>,
      "target_management_id": <int>
    }

    Response body (consumed by PHP model_runs.php + recommendations.php):
    {
      "recommendations": [
         {"genotype_id": <int>, "predicted_yield": <float>, "gg_score": <float>}, ...
      ],
      "metrics": {"rmse": .., "r2": .., "model_type": .., "n_training_rows": ..}
    }

    Errors use the same ``{"error": ...}`` shape as the API-key check and
    return status 400 when:
      * the body is not valid JSON, or is not a JSON object, or
      * ``training_trials`` is neither a list nor an object.

    Trial rows that are not objects, lack ``genotype_id``/``yield_kg_ha``,
    or have a non-numeric or non-finite yield are ignored. They still count
    toward ``n_training_rows``.
    """
    # silent=True turns malformed JSON into None so it gets our 400 below.
    body = request.get_json(force=True, silent=True)
    if not isinstance(body, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    model_type = body.get("model_type", "Ensemble_Stack")
    trials = body.get("training_trials", [])
    if isinstance(trials, dict):
        # PHP's json_encode emits an object for arrays with non-sequential keys.
        trials = list(trials.values())
    if not isinstance(trials, list):
        return jsonify({"error": "training_trials must be a list"}), 400

    # ---- Real implementation would go here: -----------------------------
    #   1. Load genomic_data_ref / weather_data_ref for each trial
    #   2. Build feature matrix (Stage 2 preprocessing: imputation, scaling)
    #   3. Feature selection (Stage 3: RFE, SHAP, MI, KPCA)
    #   4. Train / apply RandomForest, XGBoost, DNN, LSTM (Stage 4)
    #   5. Stack predictions with a meta-learner
    #   6. Project performance (GG) for each candidate genotype in the
    #      target environment x management scenario (Stage 5)
    # -----------------------------------------------------------------------

    # Placeholder heuristic: rank distinct genotypes seen in training data
    # by their average historical yield, with a little jitter to simulate
    # model variance across model_type.
    #
    # A private RNG seeded with the model_type string makes the jitter
    # reproducible across restarts. hash() of a str is salted per process,
    # so the old hash-based seed was not. The private RNG also avoids
    # reseeding the global `random` module that concurrent requests share.
    rng = random.Random(str(model_type))
    scored = _score_genotypes(_group_yields(trials), rng)

    metrics = {
        "model_type": model_type,
        "n_training_rows": len(trials),
        "rmse": round(rng.uniform(150, 400), 2),   # placeholder
        "r2": round(rng.uniform(0.6, 0.9), 3),      # placeholder
        "note": "Placeholder metrics — replace run_model() with real training.",
    }

    return jsonify({"recommendations": scored, "metrics": metrics})


if __name__ == "__main__":
    # The Werkzeug interactive debugger lets anyone who can reach the server
    # execute arbitrary Python. It must therefore never be on by default while
    # binding to all interfaces. Opt in for local work with ML_API_DEBUG=1.
    app.run(
        host="0.0.0.0",
        port=5000,
        debug=os.environ.get("ML_API_DEBUG") == "1",
    )
