from typing import cast
import numpy as np
import polars as pl
import lightgbm as lgb
import matplotlib.pyplot as plt
from vnpy.alpha.dataset import AlphaDataset, Segment
from vnpy.alpha.model import AlphaModel


class LgbModel(AlphaModel):
    """LightGBM ensemble learning algorithm"""

    def __init__(
        self,
        learning_rate: float = 0.1,
        num_leaves: int = 31,
        num_boost_round: int = 1000,
        early_stopping_rounds: int = 50,
        log_evaluation_period: int = 1,
        seed: int | None = None
    ):
        """
        Parameters
        ----------
        learning_rate : float
            Learning rate
        num_leaves : int
            Number of leaf nodes
        num_boost_round : int
            Maximum number of training rounds
        early_stopping_rounds : int
            Number of rounds for early stopping
        log_evaluation_period : int
            Interval rounds for printing training logs
        seed : int | None
            Random seed
        """
        self.params: dict = {
            "objective": "mse",
            "learning_rate": learning_rate,
            "num_leaves": num_leaves,
            "seed": seed
        }

        self.num_boost_round: int = num_boost_round
        self.early_stopping_rounds: int = early_stopping_rounds
        self.log_evaluation_period: int = log_evaluation_period

        self.model: lgb.Booster | None = None

    def _prepare_data(self, dataset: AlphaDataset) -> list[lgb.Dataset]:
        """
        Prepare data for training and validation

        Parameters
        ----------
        dataset : AlphaDataset
            The dataset containing features and labels

        Returns
        -------
        list[lgb.Dataset]
            List of LightGBM datasets for training and validation
        """
        ds: list[lgb.Dataset] = []

        # Process training and validation segments separately
        for segment in [Segment.TRAIN, Segment.VALID]:
            # Get data for learning
            df: pl.DataFrame = dataset.fetch_learn(segment)
            df = df.sort(["datetime", "vt_symbol"])

            # Extract feature columns (as pandas) and label column (as numpy)
            data = df.select(df.columns[2: -1]).to_pandas()
            label = np.array(df["label"])

            # Add the LightGBM dataset for this segment
            ds.append(lgb.Dataset(data, label=label))

        return ds

    def fit(self, dataset: AlphaDataset) -> None:
        """
        Fit the model using the dataset

        Parameters
        ----------
        dataset : AlphaDataset
            The dataset containing features and labels

        Returns
        -------
        None
        """
        # Prepare training and validation data
        ds: list[lgb.Dataset] = self._prepare_data(dataset)

        # Execute model training
        self.model = lgb.train(
            self.params,
            ds[0],
            num_boost_round=self.num_boost_round,
            valid_sets=ds,
            valid_names=["train", "valid"],
            callbacks=[
                lgb.early_stopping(self.early_stopping_rounds),     # Early stopping callback
                lgb.log_evaluation(self.log_evaluation_period)      # Logging callback
            ]
        )

    def predict(self, dataset: AlphaDataset, segment: Segment) -> np.ndarray:
        """
        Make predictions using the trained model

        Parameters
        ----------
        dataset : AlphaDataset
            The dataset containing features
        segment : Segment
            The segment to make predictions on

        Returns
        -------
        np.ndarray
            Prediction results

        Raises
        ------
        ValueError
            If the model has not been fitted yet
        """
        # Check if model exists
        if self.model is None:
            raise ValueError("model is not fitted yet!")

        # Get data for inference
        df: pl.DataFrame = dataset.fetch_infer(segment)
        df = df.sort(["datetime", "vt_symbol"])

        # Convert feature columns to a numpy array
        data: np.ndarray = df.select(df.columns[2: -1]).to_numpy()

        # Return prediction results
        result: np.ndarray = cast(np.ndarray, self.model.predict(data))
        return result

    def detail(self) -> None:
        """
        Display model details with feature importance plots

        Generates two plots showing feature importance based on
        the "split" and "gain" metrics.

        Returns
        -------
        None
        """
        if not self.model:
            return

        for importance_type in ["split", "gain"]:
            ax: plt.Axes = lgb.plot_importance(
                self.model,
                max_num_features=50,
                importance_type=importance_type,
                figsize=(10, 20)
            )
            ax.set_title(f"Feature Importance ({importance_type})")