import importlib
import io
import zipfile
from pathlib import Path
from jp2rt import HAS_PLOT
if HAS_PLOT:
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import numpy as np
from packaging.version import Version, parse
from scipy.stats.mstats import mquantiles
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import PredictionErrorDisplay
from sklearn.model_selection import cross_val_predict, cross_validate
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from tabulate import tabulate
from jp2rt import __version__
# Version of the manifest format; load_model rejects files declaring a higher one.
MANIFEST_VERSION = Version('1.0')
# Version of the jp2rt package in use, parsed so that it can be compared.
JP2RT_VERSION = parse(__version__)
# Content of the MANIFEST.txt entry that save_model stores next to the model.
MANIFEST = f'Manifest-Version: {MANIFEST_VERSION}\nJP2RT-Version: {JP2RT_VERSION}\n'
# Module in which the ensemble regressors are looked up by name.
ENSEMBLE_REGRESSOR_MODULE = importlib.import_module('sklearn.ensemble')
[docs]
def save_model(model, path):
    """
    Serializes a model into a ``.jp2rt`` archive.

    The archive is a deflated zip with two entries, both stored in a folder
    named after the stem of ``path``: ``MANIFEST.txt`` (the module level
    manifest, UTF-8 encoded) and ``model.joblib`` (the model as dumped by
    joblib).

    Args:
        model (:obj:`sklearn.base.BaseEstimator`): The model to save.
        path (:obj:`str`): The path of the file to save the model to; the file
            actually written is ``path.with_suffix('.jp2rt')``.

    Returns:
        :obj:`int`: The number of bytes written to the file.
    """
    target = path if isinstance(path, Path) else Path(path)
    folder = target.stem

    # Dump the model in memory so that it can be stored as a zip entry.
    serialized = io.BytesIO()
    joblib.dump(model, serialized)
    serialized.flush()

    # Assemble the whole archive in memory, it is written to disk in one go.
    archive = io.BytesIO()
    with zipfile.ZipFile(archive, mode='a', compression=zipfile.ZIP_DEFLATED, allowZip64=True) as ouf:
        ouf.writestr(f'{folder}/MANIFEST.txt', MANIFEST.encode('utf-8'))
        ouf.writestr(f'{folder}/model.joblib', serialized.getvalue())

    return target.with_suffix('.jp2rt').write_bytes(archive.getvalue())
[docs]
def load_model(path):
    """
    Loads a model from a file.

    The file read is ``path.with_suffix('.jp2rt')``; it must be a zip archive
    containing, in a folder named after the stem of ``path``, a
    ``MANIFEST.txt`` and a ``model.joblib`` entry. The manifest is parsed and
    validated *before* the model is deserialized, so that a file written by a
    newer (incompatible) version is rejected without being unpickled.

    Args:
        path (:obj:`str`): The path of the file to load the model from.

    Returns:
        :obj:`sklearn.base.BaseEstimator`: The loaded model.

    Raises:
        ValueError: If a manifest line is malformed, if the manifest lacks the
            ``Manifest-Version`` or ``JP2RT-Version`` keys, or if either
            version is higher than the one supported by this module.
    """
    if not isinstance(path, Path):
        path = Path(path)
    zbuf = io.BytesIO(path.with_suffix('.jp2rt').read_bytes())
    with zipfile.ZipFile(zbuf, 'r') as inf:
        manifest_text = inf.read(f'{path.stem}/MANIFEST.txt').decode('utf-8')

        # Parse "Key: value" lines; blank lines are skipped and only the first
        # separator splits, so a value may itself contain ': '.
        manifest = {}
        for raw_line in manifest_text.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            key, sep, value = line.partition(': ')
            if not sep:
                raise ValueError(f'Invalid model file, malformed manifest line: {line!r}')
            manifest[key] = value

        if 'Manifest-Version' not in manifest:
            raise ValueError('Invalid model file, manifest missing Manifest-Version')
        if parse(manifest['Manifest-Version']) > MANIFEST_VERSION:
            raise ValueError('Invalid model file, manifest version too high')
        if 'JP2RT-Version' not in manifest:
            raise ValueError('Invalid model file, manifest missing JP2RT-Version')
        if parse(manifest['JP2RT-Version']) > JP2RT_VERSION:
            raise ValueError('Invalid model file, jp2rt version too high')

        # NOTE: joblib.load relies on pickle and can execute arbitrary code,
        # only load model files coming from a trusted source.
        with inf.open(f'{path.stem}/model.joblib') as mbuf:
            model = joblib.load(mbuf)
    return model
[docs]
def load_retention_times(path):
    """Loads retention times values from a file and returns them as a numpy array.

    The input file must be in tab separated format, must not have a header, and the
    retention time must be the first field on every row. The ``#`` character is
    not treated as a comment marker (it can appear in other fields, for instance
    in a SMILES).

    Args:
        path (:obj:`str`): The path of the tab separated file to read.

    Returns:
        :obj:`numpy.array`: the retention times values, as a one dimensional
        array with one element per row (also when the file has a single row).
    """
    values = np.genfromtxt(path, delimiter='\t', comments=None, usecols=(0,))
    # genfromtxt squeezes its result, so a single row file yields a 0-d array:
    # restore the dimension so that callers always get a sequence of values.
    return np.atleast_1d(values)
[docs]
def load_descriptors(path):
    """Loads molecular descriptors values from a file and returns them as a numpy array.

    The input file must be in tab separated format, must not have a header, and the
    descriptors must be the last fields of every row and must be preceded by a non
    numeric field (for instance a SMILES) on every row. The number of descriptors
    is determined by inspecting the first row only. The ``#`` character is not
    treated as a comment marker.

    Args:
        path (:obj:`str`): The path of the tab separated file to read.

    Returns:
        :obj:`numpy.array`: the descriptor values, always as a two dimensional
        array of shape ``(n_rows, n_descriptors)`` (also when the file has a
        single row or a single descriptor).

    Raises:
        ValueError: If the first row does not end with at least one numeric
            field preceded by another field (this includes an empty file).
    """
    with open(path) as inf:
        fields = inf.readline().split('\t')
    n_fields = len(fields)

    # Walk the first row backwards: when the loop stops, n_descriptors is the
    # number of trailing numeric fields found before the first non numeric one
    # (if every field is numeric, the first one is not counted as a descriptor).
    n_descriptors = 0
    for n_descriptors, field in enumerate(reversed(fields)):  # noqa: B007
        try:
            float(field)
        except ValueError:
            break

    if n_descriptors == 0:
        # An empty usecols would make genfromtxt silently read every column.
        raise ValueError(f'No trailing numeric descriptor fields found in the first row of {path}')

    values = np.genfromtxt(
        path, delimiter='\t', comments=None, usecols=range(n_fields - n_descriptors, n_fields)
    )
    # genfromtxt squeezes its result (a single row or a single column comes
    # back one dimensional): restore the (n_rows, n_descriptors) shape.
    return values.reshape(-1, n_descriptors)
[docs]
def list_ensemble_models():
    """Lists the available ensemble models names.

    A name is available if the ensemble module exports a ``<name>Regressor``
    that can be instantiated without arguments; the ``Regressor`` suffix is
    not part of the returned names.

    Returns:
        :obj:`list`: the sorted names of the available models.
    """
    suffix = 'Regressor'
    names = []
    for exported in ENSEMBLE_REGRESSOR_MODULE.__all__:
        if exported.endswith(suffix):
            try:
                # Regressors that require constructor arguments raise a
                # TypeError here and are left out of the list.
                getattr(ENSEMBLE_REGRESSOR_MODULE, exported)()
                names.append(exported.removesuffix(suffix))
            except TypeError:
                continue
    names.sort()
    return names
[docs]
def simple_ensemble_model_estimate(regressor_name, X, y):
    """
    Fits a pipeline built around the named ensemble regressor.

    The pipeline drops the columns of ``X`` that contain only NaN values,
    replaces the remaining missing values with the column mean, standardizes
    the features and finally applies the regressor (instantiated with its
    default arguments).

    Args:
        regressor_name (:obj:`str`): The name of the regressor to use, one of
            those returned by :func:`list_ensemble_models`.
        X (:obj:`numpy.array`): The input data.
        y (:obj:`numpy.array`): The target values.

    Returns:
        :obj:`sklearn.base.BaseEstimator`: The trained model.

    Raises:
        ValueError: If ``regressor_name`` is not an available model name.
    """
    available = list_ensemble_models()
    if regressor_name not in available:
        raise ValueError(f'Invalid regressor name: {regressor_name}')
    regressor_class = getattr(ENSEMBLE_REGRESSOR_MODULE, regressor_name + 'Regressor')
    regressor = regressor_class()

    # Indices of the columns that hold no value at all in the training data.
    all_nan_cols = []
    for idx, column in enumerate(X.T):
        if all(np.isnan(column)):
            all_nan_cols.append(idx)

    steps = (
        ColumnTransformer([('drop_all_nan_cols', 'drop', all_nan_cols)], remainder='passthrough'),
        SimpleImputer(strategy='mean'),
        StandardScaler(),
        regressor,
    )
    model = make_pipeline(*steps)
    model.fit(X, y)
    return model
[docs]
def evaluate_model(model, X, y, n_splits=5, prob=0.95):
    """Evaluates a model using cross validation and visual inspection.

    The model is cross validated to collect the R2 and RMSE scores of every
    split, and cross validated predictions are used to compute the residuals
    (actual minus predicted values). If plotting is available, a figure with
    three panels (actual vs. predicted, residuals vs. predicted, residuals
    distribution) is produced and closed, so that it is not displayed.

    Args:
        model (:obj:`sklearn.base.BaseEstimator`): The model to evaluate.
        X (:obj:`numpy.array`): The input data.
        y (:obj:`numpy.array`): The target values.
        n_splits (:obj:`int`, optional): The number of splits to use for cross validation. Defaults to 5.
        prob (:obj:`float`, optional): The probability defining the band shown in the residuals
            distribution plot, that spans the ``1 - prob`` and ``prob`` quantiles of the residuals.
            Defaults to 0.95.

    Returns:
        :obj:`dict`: A dictionary with the evaluation results, with keys ``args`` (the
        ``n_splits`` and ``prob`` used), ``details`` (the per split scores), ``metrics``
        (mean and standard deviation of the scores, plus the two quantiles), ``plot``
        (the figure, or ``None`` if plotting is not available) and ``table`` (the
        metrics formatted as a text table).
    """
    scores = cross_validate(model, X, y, cv=n_splits, n_jobs=-1, scoring=['neg_root_mean_squared_error', 'r2'])
    y_pred = cross_val_predict(model, X, y, cv=n_splits, n_jobs=-1)
    residuals = y - y_pred
    q0, q1 = mquantiles(residuals, prob=[1 - prob, prob])

    r2 = scores['test_r2']
    neg_rmse = scores['test_neg_root_mean_squared_error']

    fig = None
    if HAS_PLOT:
        fig, (ax_actual, ax_residual, ax_hist) = plt.subplots(ncols=3, figsize=(12, 4))
        scatter_panels = (
            (ax_actual, 'Actual vs. Predicted values', 'actual_vs_predicted'),
            (ax_residual, 'Residuals vs. Predicted Values', 'residual_vs_predicted'),
        )
        for ax, title, kind in scatter_panels:
            ax.set_title(title)
            PredictionErrorDisplay.from_predictions(y, y_pred=y_pred, kind=kind, ax=ax, scatter_kwargs={'s': 1})
        ax_hist.set_title('Residuals distribution')
        sns.histplot(residuals, ax=ax_hist, bins=50, kde=True)
        ax_hist.axvspan(q0, q1, alpha=0.2)
        plt.tight_layout()
        plt.close(fig)  # to avoid displaying it

    metrics = {
        'r2 mean': r2.mean(),
        'r2 std': r2.std(),
        'rmse mean': -neg_rmse.mean(),
        'rmse std': neg_rmse.std(),
        'q0': q0,
        'q1': q1,
    }
    table = tabulate(metrics.items(), tablefmt='outline', floatfmt='.4f')
    return {
        'args': {'n_splits': n_splits, 'prob': prob},
        'details': {'r2': r2, 'rmse': -neg_rmse},
        'metrics': metrics,
        'plot': fig,
        'table': table,
    }