{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "d1fb5ac5-9f8b-4dec-bbdd-83452f8ede08",
    "dms": {},
    "execution_count": null
   },
   "source": [
    "# Prerequisites\n",
    "\n",
    "## Make your you have already launch a notebook session with Spark cluster binded (soft link to an AnalycDB's Spark job resource group)\n",
    "\n",
    "## When you define the notebook session, make sure you have grant OSS Access permission to your notebook session.\n",
    "\n",
    "Otherwise you can't publish your python environment into different Spark environment (executor node)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "97244161-ab0b-4b2b-b002-3e4f833f78bf",
    "dms": {
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_902591"
    },
    "execution_count": null
   },
   "source": [
    "# 1. Install the python library for Spark&Notebook environment\n",
    "\n",
    " - For example, reset the library deployment include: `scikit-learn pandas pytz xgboost scipy PyArrow joblib threadpoolctl numpy typing_extensions packaging sqlparse pyyaml click pytz`. \n",
    "    - This is a must have step for Spark based distrubuted machine learning usecase. \n",
    "    - Because you need to ensure consistency across your distributed Python environment.\n",
    "\n",
    " - And then you need to distribute the libraries to all the Spark executor node environment by the command: `!pyp_persist`\n",
    " - When you restart the kernel/session, no need to manually trigger this step again. Just jump it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cell_id": "6b1834fc-c695-416d-b714-42fb0deb912d",
    "dms": {
     "user_id": 2102049,
     "username": "muze.xjw"
    },
    "execution_count": 1,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "!apt update && apt -y install build-essential "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cell_id": "cfda33d0-c4b9-42a9-a533-7d219f975d8a",
    "dms": {
     "user_id": 2102049,
     "username": "muze.xjw"
    },
    "execution_count": 2,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "!pip uninstall -y scikit-learn pandas xgboost scipy PyArrow joblib threadpoolctl numpy python-dateutil six pytz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "cell_id": "3462b100-4bd4-4dde-bc0d-e8789ecab19a",
    "dms": {
     "tenant_id": "633135",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_909895"
    },
    "execution_count": 3,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "!pip install xgboost scikit-learn pandas==2.0.2 PyArrow scipy joblib threadpoolctl python-dateutil six numpy==1.26.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "cell_id": "4d87a51d-c294-475d-8450-ce8d089663fb",
    "dms": {
     "tenant_id": "633135",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_229418"
    },
    "execution_count": 4,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "!pyp_persist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "ab8b4886-1ac8-4bc5-8537-d42a21ee3e4f",
    "dms": {
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_615060"
    },
    "execution_count": null
   },
   "source": [
    "# 2. Create the Spark application and define the resource\n",
    " - When you want to run the machinelearning-dedicated Spark job, you must set the 'config(\"spark.dynamicAllocation.enabled\", \"false\")'\n",
    " - And recommend you set the \"spark.executor.instances\" equals to your Spark max executor count  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "cell_id": "31459f6a-8e04-4954-8a7f-cec27e5a4433",
    "dms": {
     "user_id": 2102049,
     "username": "muze.xjw"
    },
    "execution_count": 5,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "spark = SparkSession.builder.appName(\"MachineLearning\") \\\n",
    "    .config(\"spark.dynamicAllocation.enabled\", \"false\") \\\n",
    "    .config(\"spark.executor.instances\", \"4\") \\\n",
    "    .config(\"spark.jars\", f\"oss://******/mlflow-spark_2.12-3.5.1.jar\")\\\n",
    "    .getOrCreate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "65b38b5b-06dc-4b6b-b29e-3691c3631238",
    "dms": {
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_658156"
    },
    "execution_count": null
   },
   "source": [
    "## Verify whether the py library has beed installed successfully\n",
    "### Make the pandas as an example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "cell_id": "9d34a42e-1a14-47bf-b7f5-f46310e238f4",
    "dms": {
     "tenant_id": "633135",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_474011"
    },
    "execution_count": 6,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "import socket\n",
    "\n",
    "def check_pandas_on_executor(x):\n",
    "    try:\n",
    "        import pandas as pd\n",
    "        return (socket.gethostname(), f\"pandas:{pd.__version__}\")\n",
    "    except ImportError as e:\n",
    "        return (socket.gethostname(), \"FAILED\", str(e))\n",
    "\n",
    "nodes_info = spark.sparkContext.parallelize(range(10), 10) \\\n",
    "                  .map(check_pandas_on_executor) \\\n",
    "                  .distinct() \\\n",
    "                  .collect()\n",
    "\n",
    "print(\"Executor python env check：\")\n",
    "for info in nodes_info:\n",
    "    print(info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "906c1963-a86a-410a-86a7-23aa8e0ef4b8",
    "dms": {
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_469893"
    },
    "execution_count": null
   },
   "source": [
    "# 3. Install the python library dedicated for notebook kernel environment (optional) \n",
    "\n",
    "This operation will only work for the notebook local environment. Has no affect on the Spark executor's Python environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "fd4b93a6-a842-4bdc-abed-d13a6eb4b00e",
    "disabled": true,
    "dms": {
     "tenant_id": "661889",
     "user_id": 1441649,
     "username": "wangyouzhuo",
     "variable": "output_379951"
    },
    "execution_count": null,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration.ipynb"
   },
   "outputs": [],
   "source": [
    "!pip install <xxxxxxxxx>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "4c82ab10-9cfa-4a2e-b44e-444e9a7d3d36",
    "dms": {
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_985222"
    },
    "execution_count": null
   },
   "source": [
    "# 4. Mock the stock dataset and write the dataset into OSS bucket as Parquet file\n",
    " - Remeber to define your own OSS bucket and OSS_OUTPUT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "cell_id": "76fcb2ae-8835-418b-9996-caee4f7ff5eb",
    "dms": {
     "user_id": 2102049,
     "username": "muze.xjw"
    },
    "execution_count": 7,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "OSS_BUCKET = \"aliyun-oa-adb-spark-******-oss-cn-beijing\"\n",
    "OSS_ROOT_PATH = f\"oss://{OSS_BUCKET}/demo20260302_tmp/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "cell_id": "9baebe5a-644c-4643-b4dd-51332b47d706",
    "dms": {
     "tenant_id": "633135",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_597904"
    },
    "execution_count": 8,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.types import *\n",
    "\n",
    "OSS_OUTPUT_PATH = f\"{OSS_ROOT_PATH}raw_market_data\"\n",
    "\n",
    "print(\">>> Start generating mock market data...\")\n",
    "\n",
    "def generate_mock_data():\n",
    "    start_date = \"20230101\"\n",
    "    end_date = \"20231231\"\n",
    "    date_range = pd.date_range(start=start_date, end=end_date, freq='D')\n",
    "    num_stocks = 20  # Base number of shares\n",
    "    \n",
    "    # --- Generate basic stock data ---\n",
    "    dfs = []\n",
    "    for i in range(num_stocks):\n",
    "        symbol = f\"MOCK_{i:04d}\"  # Generate virtual stock codes\n",
    "        \n",
    "        # Generate price series (simulating reasonable volatility)\n",
    "        base_price = np.random.uniform(10, 100)  # The base price is between 10 and 100.\n",
    "        price_series = [base_price]\n",
    "        for _ in range(1, len(date_range)):\n",
    "            # Daily price fluctuations: Random fluctuations within ±5%\n",
    "            change = np.random.normal(0, 0.02)\n",
    "            price_series.append(max(1, price_series[-1] * (1 + change)))\n",
    "        \n",
    "        # Generate trading volume (related to price).\n",
    "        volume_series = [np.random.poisson(price * 1000) for price in price_series]\n",
    "        \n",
    "        # Create a DataFrame\n",
    "        df = pd.DataFrame({\n",
    "            \"date\": date_range,\n",
    "            \"ticker\": symbol,\n",
    "            \"open\": price_series,\n",
    "            \"high\": [p * np.random.uniform(1, 1.02) for p in price_series],\n",
    "            \"low\": [p * np.random.uniform(0.98, 1) for p in price_series],\n",
    "            \"close\": price_series,\n",
    "            \"volume\": volume_series})\n",
    "        \n",
    "        # Ensure the type is correct\n",
    "        df['date'] = df['date'].astype(str)\n",
    "        df['volume'] = df['volume'].astype(float)\n",
    "        \n",
    "        dfs.append(df)\n",
    "    \n",
    "    full_df = pd.concat(dfs)\n",
    "    \n",
    "    # --- Data Augmentation: Simulating Large Data Volumes ---\n",
    "    print(\">>> Perform data fission to simulate massive amounts of data...\")\n",
    "    aug_dfs = []\n",
    "    # Split 20 stocks into 1000 virtual stocks\n",
    "    for i in range(50): \n",
    "        temp = full_df.copy()\n",
    "        # Modify the ticker name, for example, MOCK_0000_001\n",
    "        temp['ticker'] = temp['ticker'] + f\"_{i:03d}\"\n",
    "        # Add small random perturbations to simulate different trends.\n",
    "        noise = np.random.normal(1, 0.005, len(temp))\n",
    "        for col in ['open', 'high', 'low', 'close']:\n",
    "            temp[col] = temp[col] * noise\n",
    "        aug_dfs.append(temp)\n",
    "        \n",
    "    return pd.concat(aug_dfs)\n",
    "\n",
    "# Generate mock data\n",
    "pdf = generate_mock_data()\n",
    "print(f\">>> Data is ready; Pandas DataFrame structure:{pdf.shape}\")\n",
    "print(\">>> Converting to Spark DataFrame...\")\n",
    "\n",
    "# Defining a schema is more robust and avoids errors in automatic inference.\n",
    "schema = StructType([\n",
    "    StructField(\"date\", StringType(), True),\n",
    "    StructField(\"ticker\", StringType(), True),\n",
    "    StructField(\"open\", DoubleType(), True),\n",
    "    StructField(\"high\", DoubleType(), True),\n",
    "    StructField(\"low\", DoubleType(), True),\n",
    "    StructField(\"close\", DoubleType(), True),\n",
    "    StructField(\"volume\", DoubleType(), True)\n",
    "])\n",
    "\n",
    "# Create DataFrame (Optimize transmission using Arrow)\n",
    "sdf = spark.createDataFrame(pdf, schema=schema)\n",
    "\n",
    "print(f\">>> Writing to OSS:{OSS_OUTPUT_PATH} ...\")\n",
    "\n",
    "# Write to Parquet files\n",
    "# mode(\"overwrite\"): Overwrite mode, suitable for repeated runs of the demo\n",
    "# partitionBy(\"date\"): Partition by date, crucial for subsequent Delta Lake or Hive queries\n",
    "(sdf.write\n",
    "    .mode(\"overwrite\")\n",
    "    .partitionBy(\"date\")\n",
    "    .parquet(OSS_OUTPUT_PATH))\n",
    "\n",
    "print(\">>> Writing complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "9567142e-0a1a-444c-ba07-1dbdeaf0d9b7",
    "dms": {
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_188298"
    },
    "execution_count": null
   },
   "source": [
    "# 5. Create database/table to store the extracted stock data as Delta Lake table format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "cell_id": "9a7fda08-adcc-48bd-96ba-b35ab26208d8",
    "dms": {
     "user_id": 2102049,
     "username": "muze.xjw"
    },
    "execution_count": 9,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "DATABASE_LOCATION = f\"{OSS_ROOT_PATH}db_location/\"\n",
    "DB_NAME = 'stockdata'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "cell_id": "b07a4e20-7e3e-41c1-8d15-d6213ac0f72b",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-2zeheg0s047u6exp",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "tenant_id": "661889",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_797651"
    },
    "execution_count": 12,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb",
    "tags": []
   },
   "outputs": [],
   "source": [
    "spark.sql(f\"DROP DATABASE IF EXISTS {DB_NAME};\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "cell_id": "4d6bdab9-0037-45b9-8b0b-bfb8ec91da36",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-2zeheg0s047u6exp",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "tenant_id": "633135",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_972113"
    },
    "execution_count": 13,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb",
    "tags": []
   },
   "outputs": [],
   "source": [
    "spark.sql(f\"CREATE DATABASE IF NOT EXISTS stockdata LOCATION '{DATABASE_LOCATION}';\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "f9434a47-8ae5-4925-b3c7-1fcc5144c0c9",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-2zee0njw3p2bf2gh",
     "exec_type": "spark_sql",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_79899"
    },
    "execution_count": null
   },
   "source": [
    "## 5.1. Read the raw OSS Parquet data and save it as DeltaLake table named 'stockdata.bronze_market_data'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "cell_id": "5c253efa-9f20-478f-b0d9-c05a165a8c18",
    "dms": {
     "tenant_id": "633135",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_438144"
    },
    "execution_count": 14,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "# Read the data (make sure it includes a date column).\n",
    "path = OSS_OUTPUT_PATH + \"/date=2023-01-03/\"\n",
    "raw_df = spark.read \\\n",
    "    .option(\"basePath\", OSS_OUTPUT_PATH) \\\n",
    "    .parquet(path)\n",
    "(raw_df.write.format(\"delta\")\n",
    "    .mode(\"append\")              \n",
    "    .option(\"overwriteSchema\", \"true\") \n",
    "    .partitionBy(\"date\")            # Key: Explicitly specify the partition key to ensure physical storage is isolated by date.\n",
    "    .saveAsTable(\"stockdata.bronze_market_data\"))\n",
    "print(\">>>The table structure has been reset, and the data has been written successfully!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "cell_id": "6858195b-6bef-4183-b791-e6be9bfe4005",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-2zee0njw3p2bf2gh",
     "exec_type": "spark_sql",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "tenant_id": "633135",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_403902"
    },
    "execution_count": 15,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb",
    "tags": [],
    "vscode": {
     "languageId": "sql"
    }
   },
   "outputs": [],
   "source": [
    "SELECT * FROM stockdata.bronze_market_data LIMIT 10;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "cf3e4b80-8dc5-4ec1-ac64-3d0ae9c57868",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-2zee0njw3p2bf2gh",
     "exec_type": "spark_sql",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_825641"
    },
    "execution_count": null
   },
   "source": [
    "## 5.2. Build the silver layer in DeltaLake table format to store the stock feature table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "cell_id": "98bb8a61-2688-4f8c-bef7-f55ec30d8d90",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "633135",
     "user": "",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_235198"
    },
    "execution_count": 16,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import col\n",
    "from pyspark.sql.types import *\n",
    "from delta.tables import *\n",
    "\n",
    "# =================================================================================\n",
    "# 0. Configuration and Initialization\n",
    "# ===================================================================================\n",
    "BRONZE_TABLE = f\"{DB_NAME}.bronze_market_data\"\n",
    "SILVER_TABLE = f\"{DB_NAME}.silver_features\"\n",
    "\n",
    "# [Optimization Point 1]: Enable Arrow optimization configuration to accelerate Pandas UDF data transfer\n",
    "# Enable Delta automatic write optimization (solves small file issue)\n",
    "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
    "spark.conf.set(\"spark.databricks.delta.optimizeWrite.enabled\", \"true\") \n",
    "spark.conf.set(\"spark.databricks.delta.autoCompact.enabled\", \"true\")\n",
    "\n",
    "# =================================================================================\n",
    "# 2. Read Data (Keep Data Clean)\n",
    "# ====================================================================================\n",
    "print(f\">>> [Step 2] Read the table: {BRONZE_TABLE}\")\n",
    "spark.catalog.refreshTable(BRONZE_TABLE)\n",
    "\n",
    "# Clean the input data, only extracting the necessary physical columns.\n",
    "clean_cols = [\"date\", \"ticker\", \"open\", \"high\", \"low\", \"close\", \"volume\"]\n",
    "bronze_df = spark.table(BRONZE_TABLE).select(*clean_cols)\n",
    "\n",
    "# ==============================================================================\n",
    "# 3. Define the calculation logic (fix type errors + performance optimization)\n",
    "# ==============================================================================\n",
    "def calculate_tech_indicators(pdf: pd.DataFrame) -> pd.DataFrame:\n",
    "    # A. Ensure sorting by time\n",
    "    pdf = pdf.sort_values(\"date\")\n",
    "    \n",
    "    # B. Calculate RSI (vectorized computation, extremely fast)    \n",
    "    close_series = pdf['close']\n",
    "    delta = close_series.diff()\n",
    "    up = delta.clip(lower=0)\n",
    "    down = -1 * delta.clip(upper=0)\n",
    "    \n",
    "    # ewm: Exponentially weighted moving average (alpha=1/14)\n",
    "    ma_up = up.ewm(com=13, adjust=False).mean()\n",
    "    ma_down = down.ewm(com=13, adjust=False).mean()\n",
    "    rs = ma_up / ma_down\n",
    "    rsi = 100 - (100 / (1 + rs))\n",
    "    \n",
    "    # Assignment\n",
    "    pdf['rsi_14'] = rsi.fillna(0)\n",
    "    \n",
    "    # [Optimization Point 2 - Critical Fix]: Force date type conversion\n",
    "    # PyArrow does not support direct serialization of Python datetime.date objects; they must be converted to String.\n",
    "    pdf['date'] = pdf['date'].astype(str)\n",
    "    return pdf\n",
    "\n",
    "# ==============================================================================\n",
    "# 4. Perform distributed computing\n",
    "# ==============================================================================\n",
    "print(\">>> [Step 3]Start parallel calculation of RSI by stock grouping...\")\n",
    "\n",
    "# Define Output Schema using DDL\n",
    "output_schema_ddl = \"\"\"\n",
    "    date string,\n",
    "    ticker string,\n",
    "    open double,\n",
    "    high double,\n",
    "    low double,\n",
    "    close double,\n",
    "    volume double,\n",
    "    rsi_14 double\n",
    "\"\"\"\n",
    "\n",
    "# Grouped Parallel Computation\n",
    "# Note: Spark automatically handles shuffle, sending data from the same Ticker to the same Executor.\n",
    "silver_df = bronze_df.groupby(\"ticker\").applyInPandas(\n",
    "    calculate_tech_indicators, \n",
    "    schema=output_schema_ddl)\n",
    "\n",
    "# ==============================================================================\n",
    "# 5. Write to the Silver table (Merge + Z-Order optimization)\n",
    "# ==============================================================================\n",
    "print(f\">>> [Step 4]Ready to write to Silver table: {SILVER_TABLE}\")\n",
    "\n",
    "# Before merging, you must ensure that the (ticker, date) combination is unique; dropDuplicates will keep the first one and discard duplicates.\n",
    "print(\"   -> Deduplication is being performed on the source data...\")\n",
    "silver_df_deduped = silver_df.dropDuplicates([\"ticker\", \"date\"])\n",
    "\n",
    "if spark.catalog.tableExists(SILVER_TABLE):\n",
    "    print(\"   -> The table already exists; execute Delta Merge (Upsert)...\")\n",
    "    deltaTable = DeltaTable.forName(spark, SILVER_TABLE)\n",
    "    \n",
    "    # Perform a merge (using the deduplicated dataframe).\n",
    "    (deltaTable.alias(\"target\")\n",
    "      .merge(\n",
    "        silver_df_deduped.alias(\"source\"), \n",
    "        \"target.ticker = source.ticker AND target.date = source.date\"\n",
    "      )\n",
    "      .whenMatchedUpdateAll()\n",
    "      .whenNotMatchedInsertAll()\n",
    "      .execute())\n",
    "    print(\"   -> Merge complete。\")\n",
    "    \n",
    "else:\n",
    "    print(\"   ->If the table does not exist, perform a full initialization (Create)...\")\n",
    "    (silver_df_deduped.write\n",
    "        .format(\"delta\")\n",
    "        .mode(\"overwrite\")\n",
    "        .partitionBy(\"date\")\n",
    "        .saveAsTable(SILVER_TABLE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "cell_id": "7eae0f20-97b4-47b0-9fef-c1af2bfa6005",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-2zee0njw3p2bf2gh",
     "exec_type": "spark_sql",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "tenant_id": "633135",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_387787"
    },
    "execution_count": 17,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb",
    "tags": [],
    "vscode": {
     "languageId": "sql"
    }
   },
   "outputs": [],
   "source": [
    "select * from stockdata.silver_features limit 10;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "fc9012aa-3496-4b8e-ae18-ebeb181fff62",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-2zee0njw3p2bf2gh",
     "exec_type": "spark_sql",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "tenant_id": "633135",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_715455"
    },
    "execution_count": null
   },
   "source": [
    "## 5.3.Build the gold layer as DeltaLake format and prepare the training dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "cell_id": "c2e45965-9918-4026-a708-b2d84223dc62",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "633135",
     "user": "",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_835002"
    },
    "execution_count": 18,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler\n",
    "from pyspark.sql.functions import lead, col, expr\n",
    "from pyspark.sql.window import Window\n",
    "\n",
    "GOLD_TABLE = f\"{DB_NAME}.gold_training_set\"\n",
    "\n",
    "print(f\">>> [Step 3]Gold Layer: Constructing the training sample table {GOLD_TABLE}\")\n",
    "\n",
    "# 1. Read Silver layer features\n",
    "# ------------------------------------------------------------------\n",
    "silver_df = spark.table(SILVER_TABLE)\n",
    "\n",
    "# 2. Construct a label (prediction target)\n",
    "# Assumption objective: Predict the rate of return \"tomorrow\".\n",
    "# Logic: Use the lead function to align next day's closing price to today's.\n",
    "# ------------------------------------------------------------------\n",
    "window_spec = Window.partitionBy(\"date\").orderBy(\"ticker\")\n",
    "\n",
    "# Label = (Tomorrow's closing price / Today's closing price) - 1\n",
    "gold_df = silver_df.withColumn(\n",
    "    \"label\", \n",
    "    (lead(\"close\", 1).over(window_spec) / col(\"close\")) - 1)\n",
    "\n",
    "# Filter out the last day (because there is no data for tomorrow, so the label will be empty).\n",
    "gold_df = gold_df.dropna(subset=[\"label\"])\n",
    "\n",
    "# 3. Feature vectorization (Vector Assembly)\n",
    "# Spark XGBoost requires merging all feature columns into a single Vector type column.\n",
    "# ------------------------------------------------------------------\n",
    "feature_cols = [\"open\", \"high\", \"low\", \"close\", \"volume\", \"rsi_14\"]\n",
    "assembler = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
    "gold_df_final = assembler.transform(gold_df).select(\"date\", \"ticker\", \"features\", \"label\")\n",
    "\n",
    "# 4. Write to Gold table\n",
    "# ------------------------------------------------------------------\n",
    "print(f\"   -> Write to the Gold table (containing the Features vector and Label)...\")\n",
    "if spark.catalog.tableExists(GOLD_TABLE):\n",
    "    # Scenario where simulation data continuously accumulates\n",
    "    gold_df_final.write.format(\"delta\").mode(\"append\").saveAsTable(GOLD_TABLE)\n",
    "else:\n",
    "    gold_df_final.write.format(\"delta\").mode(\"overwrite\").saveAsTable(GOLD_TABLE)\n",
    "\n",
    "# ================================================================================\n",
    "# Demo: Time Travel\n",
    "# Scenario: We find that the newly generated data today has problems, causing model training errors. I want to read the data from the \"previous version\".\n",
    "# ===================================================================================\n",
    "print(\"\\n>>> [Time Travel Demo] Demo version rollback...\")\n",
    "\n",
    "# Method A: Based on version number (most robust, suitable for demos)\n",
    "# versionAsOf=0 represents the state of the table when it was first created\n",
    "try:\n",
    "    df_v0 = spark.read.format(\"delta\").option(\"versionAsOf\", 0).table(GOLD_TABLE)\n",
    "    print(f\"   -> Successfully read Version 0 data, number of rows: {df_v0.count()}\")\n",
    "except Exception as e:\n",
    "    print(\"   -> Unable to read Version 0 (possibly first creation).\")\n",
    "\n",
    "# Method B: Based on timestamps (commonly used in production environments)\n",
    "# Note: This requires the table to actually exist at that point in time.\n",
    "# df_snapshot = spark.read.format(\"delta\").option(\"timestampAsOf\", \"2023-09-27 12:00:00\").table(GOLD_TABLE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "cell_id": "6fadc1a1-e71a-487c-8546-87c2287e5aa5",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "633135",
     "user": "",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_47517"
    },
    "execution_count": 19,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "gold_df_final.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "de6f4bb6-390a-40eb-9a77-74b0f5e23565",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "633135",
     "user": "",
     "user_id": "1441649",
     "username": "wangyouzhuo",
     "variable": "output_856771"
    },
    "execution_count": null
   },
   "source": [
    "# 6. Start the xgboost model training and evaluate it (one-shot)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "cell_id": "58990547-ec74-40c4-82b9-cd91ad26ef52",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "633135",
     "user": "",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_489796"
    },
    "execution_count": 20,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "from xgboost.spark import SparkXGBRegressor\n",
    "from pyspark.ml.evaluation import RegressionEvaluator\n",
    "\n",
    "# Configure the model save path (OSS path)\n",
    "MODEL_OUTPUT_PATH = f\"oss://{OSS_BUCKET}/models/guccidemo/\"\n",
    "\n",
    "print(f\"\\n>>> [Step 4] Model Training: Distributed XGBoost Training\")\n",
    "\n",
    "# 1. Read the training data (directly read the Delta Lake table in ADB MySQL).\n",
    "# ------------------------------------------------------------------\n",
    "train_df = spark.table(GOLD_TABLE)\n",
    "\n",
    "# We simply divide the training and test sets (split by time would be more rigorous, but here we'll use random splitting for the demo).\n",
    "train_data, test_data = train_df.randomSplit([0.8, 0.2], seed=42)\n",
    "print(f\"   -> Training set size:{train_data.count()}, Test set size:{test_data.count()}\")\n",
    "\n",
    "# 2.Define the XGBoost regressor\n",
    "# ------------------------------------------------------------------\n",
    "# num_workers: Set to the number of Executors in the Spark cluster.\n",
    "xgb = SparkXGBRegressor(\n",
    "    features_col=\"features\",\n",
    "    label_col=\"label\",\n",
    "    num_workers=4,          \n",
    "    learning_rate=0.1,\n",
    "    max_depth=5,\n",
    "    missing=0.0             # Missing value handling\n",
    ")\n",
    "\n",
    "# 3. Model Training (Fit)\n",
    "# ------------------------------------------------------------------\n",
    "print(\"   -> Begin distributed training...\")\n",
    "model = xgb.fit(train_data)\n",
    "\n",
    "# 4. Model Evaluation\n",
    "# ------------------------------------------------------------------\n",
    "predictions = model.transform(test_data)\n",
    "evaluator = RegressionEvaluator(labelCol=\"label\", predictionCol=\"prediction\", metricName=\"rmse\")\n",
    "rmse = evaluator.evaluate(predictions)\n",
    "print(f\"   -> Model evaluation RMSE: {rmse:.6f}\")\n",
    "\n",
    "# 5. Save the model (Model Registry)\n",
    "# ------------------------------------------------------------------\n",
    "# In production environments, MLflow is typically stored; this demonstration shows how to store it in OSS.\n",
    "print(f\"   ->Save the model to:{MODEL_OUTPUT_PATH}\")\n",
    "model.write().overwrite().save(MODEL_OUTPUT_PATH)\n",
    "\n",
    "print(\"\\n>>> The entire process is now complete! From data cleaning to model training, the data journey is finished.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "a754dc43-5e9c-4d3a-9dde-19327d737c80",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "661889",
     "user": "",
     "user_id": "1508718",
     "username": "dms_oneops",
     "variable": "output_648286"
    },
    "execution_count": null
   },
   "source": [
    "# 7. Define the MLFlow's fixed internal service IP address (Notebook&MLflow integration)\n",
    "\n",
    " - Before connecting, make sure you have manually added Notebook VPC's IPv4 CIDR Block into MLFlow's ALB ACL group.\n",
    "\n",
    " - So notebook can talk to your MLflow and register the model path in MLflow's server-side.\n",
    "\n",
    " - Then define your first experiment name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "cell_id": "218c0845-d124-422a-96b2-4fe9698eb0c9",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "661889",
     "user": "",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_878580"
    },
    "execution_count": 21,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "import os\n",
    "from mlflow.exceptions import MlflowException\n",
    "\n",
    "\n",
    "remote_server_uri = \"http://****\"\n",
    "experiment_name = \"/test/TestExperiment_1\"\n",
    "artifact_location = f\"oss://{OSS_BUCKET}/test_mlflow/experiment/\"\n",
    "\n",
    "def init_mlflow_experiment():\n",
    "    mlflow.set_tracking_uri(remote_server_uri)\n",
    "    print(f\"Tracking URI: {mlflow.get_tracking_uri()}\")\n",
    "\n",
    "    try:\n",
    "        # Check if the experiment already exists.\n",
    "        experiment = mlflow.get_experiment_by_name(experiment_name)\n",
    "        \n",
    "        if experiment is None:\n",
    "            print(f\"Creating a new experiment: {experiment_name}\")\n",
    "            # Create an experiment and specify the OSS storage location.\n",
    "            experiment_id = mlflow.create_experiment(\n",
    "                name=experiment_name,\n",
    "                artifact_location=artifact_location\n",
    "            )\n",
    "        else:\n",
    "            # Check if the experiment is in a deleted state.\n",
    "            if experiment.lifecycle_stage == \"deleted\":\n",
    "                print(f\"Warning: Experiment '{experiment_name}' It has been marked for deletion. Please restore it or change its name in the MLflow UI.\")\n",
    "                experiment_id = experiment.experiment_id\n",
    "            else:\n",
    "                print(f\"Experiment '{experiment_name}'already exists. Please reuse it.\")\n",
    "                experiment_id = experiment.experiment_id\n",
    "        \n",
    "        mlflow.set_experiment(experiment_name)\n",
    "        return experiment_id\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"An error occurred while initializing the MLflow experiment: {e}\")\n",
    "        return None\n",
    "\n",
    "exp_id = init_mlflow_experiment()\n",
    "\n",
    "# --- Test the first round Run ---\n",
    "if exp_id:\n",
    "    with mlflow.start_run():\n",
    "        mlflow.log_param(\"status\", \"successfully_initialized\")\n",
    "        print(f\"Training record successfully started, experiment ID:{exp_id}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "f2c9d7bd-72ac-4ff8-a854-efce1e0208ec",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "661889",
     "user": "",
     "user_id": "1508718",
     "username": "dms_oneops",
     "variable": "output_448330"
    },
    "execution_count": null
   },
   "source": [
    "# 8. Performing large-scale distributed model training and hyperparameter tuning using PySpark, XGBoost, and MLflow."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "6d844e9c-1bb4-4c48-83c3-36467be65739",
    "dms": {},
    "execution_count": null
   },
   "source": [
    "## 8.1. Experiment tracking and hyperparameter tuning\n",
    "\n",
    " - *Business Pain Point*:  The user repeatedly tried different learning rates (0.1, 0.05, 0.01), running the training many times, and eventually forgot which set of parameters yielded the lowest RMSE, and also couldn't remember which day's data was used.\n",
    " - *Solution*:  Use `mlflow.start_run` in conjunction with a loop to automatically record the parameters, metrics, and model files for each training run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "cell_id": "1a6f0913-9c59-44aa-aca7-e3e82e2cfe49",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "661889",
     "user": "",
     "user_id": 2102049,
     "username": "muze.xjw",
     "variable": "output_627185"
    },
    "execution_count": 22,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "import mlflow.spark\n",
    "from xgboost.spark import SparkXGBRegressor\n",
    "from pyspark.ml.evaluation import RegressionEvaluator\n",
    "from pyspark.sql.functions import col\n",
    "\n",
    "# 1. Set up the experiment name (just like creating a repository in Git).\n",
    "experiment_name = \"/test/TestExperiment_1\"\n",
    "mlflow.set_experiment(experiment_name)\n",
    "\n",
    "# 2. Load training data from the existing table -> Gold Table\n",
    "gold_df = spark.table(f\"{DB_NAME}.gold_training_set\")\n",
    "train_data, test_data = gold_df.randomSplit([0.8, 0.2], seed=42)\n",
    "\n",
    "# 3. Define the hyperparameter search space.\n",
    "params_grid = [\n",
    "    {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 100},\n",
    "    {'learning_rate': 0.05, 'max_depth': 5, 'n_estimators': 200},\n",
    "    {'learning_rate': 0.01, 'max_depth': 6, 'n_estimators': 300}]\n",
    "\n",
    "print(\">>> Starting hyperparameter tuning (Experiment Tracking)...\")\n",
    "\n",
    "# Enable Spark automatic logging (which will record Spark version, data path, etc.)\n",
    "mlflow.spark.autolog()\n",
    "\n",
    "for params in params_grid:\n",
    "    # Key step: Starting an MLflow Run\n",
    "    with mlflow.start_run(run_name=f\"lr_{params['learning_rate']}_depth_{params['max_depth']}\"):\n",
    "        # A. Record parameters\n",
    "        mlflow.log_params(params)\n",
    "        \n",
    "        # B. Recording data versions (Delta Lake's Time Travel feature)\n",
    "        # Obtain the latest version number of the Delta Table to ensure reproducibility.\n",
    "        delta_version = spark.sql(f\"DESCRIBE HISTORY {DB_NAME}.gold_training_set\").select(\"version\").first()[0]\n",
    "        mlflow.log_param(\"data_version\", delta_version)\n",
    "        mlflow.set_tag(\"data_source\", f\"{DB_NAME}.gold_training_set\")\n",
    "\n",
    "        # C. Train the model\n",
    "        xgb = SparkXGBRegressor(\n",
    "            features_col=\"features\", \n",
    "            label_col=\"label\",\n",
    "            num_workers=2,\n",
    "            **params\n",
    "        )\n",
    "        model = xgb.fit(train_data)\n",
    "        \n",
    "        # D. Evaluate and record the indicators.\n",
    "        predictions = model.transform(test_data)\n",
    "        evaluator = RegressionEvaluator(labelCol=\"label\", predictionCol=\"prediction\", metricName=\"rmse\")\n",
    "        rmse = evaluator.evaluate(predictions)\n",
    "        \n",
    "        # Manually record key metrics (Autolog will also record them, but manual recording is safer).\n",
    "        mlflow.log_metric(\"rmse\", rmse)\n",
    "        \n",
    "        # E. Record model entities\n",
    "        mlflow.spark.log_model(model, \"model\")\n",
    "        print(f\"Run completed - Params: {params} | RMSE: {rmse:.6f}\")\n",
    "print(\">>> All experiments have finished running. Please check the comparison charts in the MLflow UI.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "80f73933-b819-45e4-8cca-4d390188fee6",
    "dms": {
     "datalake_type": "AdbMySQL",
     "db_cluster": "amv-uf6lkx9h79t76q12",
     "exec_type": "",
     "page_index": [
      1
     ],
     "page_size": [
      20
     ],
     "password": "",
     "tenant_id": "661889",
     "user": "",
     "user_id": "1508718",
     "username": "dms_oneops",
     "variable": "output_462486"
    },
    "execution_count": null
   },
   "source": [
    "## 8.2. Model selection and registration\n",
    "\n",
    " - **Business Pain Point**: After the experiment is completed, the Run ID is a string of random characters (e.g., a1b2c3...), making it difficult to tell the trading system to \"deploy the model with the lowest RMSE.\"\n",
    " - **Solution**: Automatically select the best run through code and register it as a model version with a \"friendly name\" (Version 1, 2...)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "cell_id": "b34693de-2dda-4469-bc07-a233215c54be",
    "dms": {
     "user_id": 2102049,
     "username": "muze.xjw"
    },
    "execution_count": 24,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "from mlflow.tracking import MlflowClient\n",
    "client = MlflowClient()\n",
    "\n",
    "# 1. Search for the best experimental records\n",
    "runs = client.search_runs(\n",
    "    experiment_ids=[client.get_experiment_by_name(experiment_name).experiment_id],\n",
    "    filter_string=\"\",\n",
    "    order_by=[\"metrics.rmse ASC\"],\n",
    "    max_results=1)\n",
    "\n",
    "best_run = runs[0]\n",
    "best_run_id = best_run.info.run_id\n",
    "best_rmse = best_run.data.metrics['rmse']\n",
    "print(f\">>> Found best model Run ID: {best_run_id}, RMSE: {best_rmse}\")\n",
    "\n",
    "# 2. Model Registry\n",
    "# Model name: Quant_A_Share_Prediction\n",
    "model_name = \"Quant_A_Share_Prediction\"\n",
    "model_uri = f\"runs:/{best_run_id}/model\"\n",
    "\n",
    "print(f\">>> Registering the model with the Registry...: {model_name}...\")\n",
    "model_details = mlflow.register_model(model_uri=model_uri, name=model_name)\n",
    "\n",
    "# 3. Simulated approval process: Promoting the model from \"None\" to \"Staging\" (pre-production environment).\n",
    "client.transition_model_version_stage(\n",
    "    name=model_name,\n",
    "    version=model_details.version,\n",
    "    stage=\"Staging\",\n",
    "    archive_existing_versions=True  # Automatically archive old versions.\n",
    ")\n",
    "\n",
    "print(f\">>> Model {model_name} version {model_details.version} Successfully promoted to Staging status!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "4a2eabcd-edc8-45d4-a50b-5cf7c66d5c6c",
    "dms": {
     "tenant_id": "661889",
     "user_id": "1508718",
     "variable": "output_475299"
    },
    "execution_count": null
   },
   "source": [
    "## 8.3. Distributed inference in a production environment\n",
    "\n",
    " - **Business Pain Point**: The next morning, the user has new data for 5000 stocks (stored in the Silver table), and needs to perform large-scale parallel predictions using a model in the Staging state to generate trading signals.\n",
    "- **Solution**: Use mlflow.pyfunc.spark_udf. This is a powerful tool for Spark projects; it can load any model managed by MLflow as a Spark UDF, leveraging the cluster's parallel computing capabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "cell_id": "19964083-5e26-4b8f-a70e-ad1a7f3e26f3",
    "dms": {
     "user_id": 2102049,
     "username": "muze.xjw"
    },
    "execution_count": 25,
    "file_path": "/Workspace/code/MarkovMLFlowIntegration (2) (2).ipynb"
   },
   "outputs": [],
   "source": [
    "import mlflow.pyfunc\n",
    "from pyspark.sql.functions import struct, col\n",
    "\n",
    "# 1. Dynamically load the model from the \"Staging\" environment\n",
    "# Regardless of whether the backend version is V5 or V10, the code only targets \"Staging\"\n",
    "model_uri = \"models:/Quant_A_Share_Prediction/Staging\"\n",
    "\n",
    "print(f\">>> Loading Staging model from MLflow for inference: {model_uri}\")\n",
    "\n",
    "# 2. Wrap the model as a Spark UDF\n",
    "# This step automatically handles Python dependencies and serialization\n",
    "predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri)\n",
    "\n",
    "# 3. Read the latest prediction data (assuming it's today's Silver data)\n",
    "# Note: You need to construct the same feature vector as used during training, \n",
    "# typically by reusing VectorAssembler logic.\n",
    "# For demonstration purposes, we assume input_df is the data to be predicted.\n",
    "input_df = spark.table(f\"{DB_NAME}.gold_training_set\").filter(\"date = '2023-12-29'\")\n",
    "\n",
    "# 4. Perform distributed prediction\n",
    "# Note: XGBoost models usually require Vector types or specific columns as input.\n",
    "# The calling method depends on whether you saved a Pipeline or just an Estimator during log_model.\n",
    "# If a Pipeline (containing VectorAssembler) was saved, raw columns can be passed directly.\n",
    "# If only the Model was saved, the \"features\" column must be passed.\n",
    "\n",
    "print(\">>> Starting distributed prediction...\")\n",
    "predictions_df = input_df.withColumn(\"predicted_return\", \n",
    "    predict_udf(struct(\"features\")) # Pass the \"features\" column to the UDF\n",
    ")\n",
    "\n",
    "# 5. Generate trading signals (Example: Buy if predicted return > 2%)\n",
    "signals_df = predictions_df.select(\"date\", \"ticker\", \"predicted_return\") \\\n",
    "    .filter(\"predicted_return > 0.02\") \\\n",
    "    .orderBy(col(\"predicted_return\").desc())\n",
    "\n",
    "signals_df.show(10)\n",
    "print(\">>> Trading signals generated!\")"
   ]
  }
 ],
 "metadata": {
  "comm_id": "c6d32e60-1ebb-4a9d-a036-4ef5000cf599",
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
