1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/mirrors-vn-py

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
lgb_model.py 4.7 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
vn.py Отправлено 26.03.2025 08:56 ec8f986
from typing import cast
import numpy as np
import polars as pl
import lightgbm as lgb
import matplotlib.pyplot as plt
from vnpy.alpha.dataset import AlphaDataset, Segment
from vnpy.alpha.model import AlphaModel
class LgbModel(AlphaModel):
"""LightGBM ensemble learning algorithm"""
def __init__(
self,
learning_rate: float = 0.1,
num_leaves: int = 31,
num_boost_round: int = 1000,
early_stopping_rounds: int = 50,
log_evaluation_period: int = 1,
seed: int | None = None
):
"""
Parameters
----------
learning_rate : float
Learning rate
num_leaves : int
Number of leaf nodes
num_boost_round : int
Maximum number of training rounds
early_stopping_rounds : int
Number of rounds for early stopping
log_evaluation_period : int
Interval rounds for printing training logs
seed : int | None
Random seed
"""
self.params: dict = {
"objective": "mse",
"learning_rate": learning_rate,
"num_leaves": num_leaves,
"seed": seed
}
self.num_boost_round: int = num_boost_round
self.early_stopping_rounds: int = early_stopping_rounds
self.log_evaluation_period: int = log_evaluation_period
self.model: lgb.Booster | None = None
def _prepare_data(self, dataset: AlphaDataset) -> list[lgb.Dataset]:
"""
Prepare data for training and validation
Parameters
----------
dataset : AlphaDataset
The dataset containing features and labels
Returns
-------
list[lgb.Dataset]
List of LightGBM datasets for training and validation
"""
ds: list[lgb.Dataset] = []
# Process training and validation separately
for segment in [Segment.TRAIN, Segment.VALID]:
# Get data for learning
df: pl.DataFrame = dataset.fetch_learn(segment)
df = df.sort(["datetime", "vt_symbol"])
# Convert to numpy arrays
data = df.select(df.columns[2: -1]).to_pandas()
label = np.array(df["label"])
# Add training data
ds.append(lgb.Dataset(data, label=label))
return ds
def fit(self, dataset: AlphaDataset) -> None:
"""
Fit the model using the dataset
Parameters
----------
dataset : AlphaDataset
The dataset containing features and labels
Returns
-------
None
"""
# Prepare task data
ds: list[lgb.Dataset] = self._prepare_data(dataset)
# Execute model training
self.model = lgb.train(
self.params,
ds[0],
num_boost_round=self.num_boost_round,
valid_sets=ds,
valid_names=["train", "valid"],
callbacks=[
lgb.early_stopping(self.early_stopping_rounds), # Early stopping callback
lgb.log_evaluation(self.log_evaluation_period) # Logging callback
]
)
def predict(self, dataset: AlphaDataset, segment: Segment) -> np.ndarray:
"""
Make predictions using the trained model
Parameters
----------
dataset : AlphaDataset
The dataset containing features
segment : Segment
The segment to make predictions on
Returns
-------
np.ndarray
Prediction results
Raises
------
ValueError
If the model has not been fitted yet
"""
# Check if model exists
if self.model is None:
raise ValueError("model is not fitted yet!")
# Get data for inference
df: pl.DataFrame = dataset.fetch_infer(segment)
df = df.sort(["datetime", "vt_symbol"])
# Convert to numpy array
data: np.ndarray = df.select(df.columns[2: -1]).to_numpy()
# Return prediction results
result: np.ndarray = cast(np.ndarray, self.model.predict(data))
return result
def detail(self) -> None:
"""
Display model details with feature importance plots
Generates two plots showing feature importance based on
'split' and 'gain' metrics.
Returns
-------
None
"""
if not self.model:
return
for importance_type in ["split", "gain"]:
ax: plt.Axes = lgb.plot_importance(
self.model,
max_num_features=50,
importance_type=importance_type,
figsize=(10, 20)
)
ax.set_title(f"Feature Importance ({importance_type})")

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://api.gitlife.ru/oschina-mirror/mirrors-vn-py.git
git@api.gitlife.ru:oschina-mirror/mirrors-vn-py.git
oschina-mirror
mirrors-vn-py
mirrors-vn-py
master