# -*- coding: utf-8 -*-
"""
Created on Mon Mar 21 13:16:05 2022
@author: Alexandre Kenshilik Coche
@contact: alexandre.co@hotmail.fr
"""
#%% Imports:
import pandas as pd
import numpy as np
import xarray as xr
xr.set_options(keep_attrs = True)
import matplotlib.pyplot as plt
import numbers
from geop4th.graphics import cmapgenerator as cmg
from geop4th import geobricks as geo
import warnings
# import matplotlib.pyplot as plt, mpld3
# import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
# import plotly.graph_objects as go
# import plotly.offline as offline
# from plotly.subplots import make_subplots
# Animation dependencies - imported as needed
try:
import matplotlib.animation as manimation
except ImportError:
manimation = None
try:
from PIL import Image
except ImportError:
Image = None
#%% plot_time_series
def plot_time_series(
        *,
        figweb = None,
        data,
        labels,
        title = 'title',
        linecolors = None,
        fillcolors = None,
        cumul = False,
        date_ini_cumul = None,
        reference = None,
        ref_norm = None,
        mean_norm = False,
        mean_center = False,
        legendgroup = None,
        legendgrouptitle_text = None,
        stack = False,
        stackgroup = None,
        col = None,
        row = None,
        lwidths = None,
        lstyles = None,
        yaxis = "y1",
        fill = None,
        mode = "lines",
        markers = None,
        showlegend = True,
        visible = True,
        bar_widths = None,
        ):
    r"""
    Plot one or several time series with *plotly*, with optional
    post-processing (cumulated, normalized or centered values, metrics
    against a reference...).

    This function provides a wrapper to facilitate the use of
    `plotly.graph_objects <https://plotly.com/python-api-reference/plotly.graph_objects.html>`_.
    It facilitates the input of several arguments:

    - data can be passed in any format (``xarray.Dataset``/``DataArray``,
      ``pandas.DataFrame``/``Series``, or lists thereof)
    - colors can be passed in any format (np.arrays, lists of strings...)
      which makes it possible to use indifferently *plotly* or *matplotlib*
      colormap functions
    - colors are passed universally as ``linecolors`` and ``fillcolors``,
      no matter what plotly trace type is used underneath (``go.Bar``
      expects colors in the ``marker`` dict whereas ``go.Scatter`` expects
      them in the ``line`` dict and a ``fillcolor`` argument)

    Parameters
    ----------
    figweb : plotly figure, optional
        Previous plotly figure to draw on top of. A new ``go.Figure`` is
        created when not provided.
    data : (list of) pandas.DataFrame/Series or xarray.Dataset/DataArray
        Data to plot. Each element is normalized internally to a DataFrame
        with a ``'val'`` and a ``'time'`` column.
    labels : (list of) str
        Texts for the legend (one per curve).
    title : str, optional
        Title used for the matplotlib figure (``fig1``, ``ax1``).
        Not used for the plotly figure (``figweb``).
    linecolors, fillcolors : optional
        Line/fill colors: a single color, a list of colors (RGBA tuples or
        strings), or the name of a plotly colorscale, in which case
        ``n_curves`` colors are sampled from it.
    cumul : bool, optional
        Option to plot time-integrated (cumulated) curves.
    date_ini_cumul : str, optional
        Date ('YYYY-MM-DD') on which cumulated curves are aligned
        (e.g. ``date_ini_cumul = '2000-07-31'``). Only used when
        ``cumul=True``.
    reference : pandas.DataFrame, optional
        Used for displaying metrics (NSE, NSElog, VOLerr) computed against
        this reference data.
    ref_norm : pandas.DataFrame (or xarray.Dataset, beta), optional
        Used for plotting values normalized against this reference data.
    mean_norm : bool, optional
        Option to normalize each curve by its mean.
    mean_center : bool, optional
        Option to center each curve on its mean.
    legendgroup : str, optional
        Group curves under the provided group identifier (one at a time).
    legendgrouptitle_text : str, optional
        Text associated with the group identifier (rendered in bold).
    stack : bool, optional
        Option to plot stacked curves.
    stackgroup : optional
        Plotly stackgroup identifier; a random one is generated when
        ``stack=True`` and none is given.
    col, row : int or list of int, optional
        Column/row number(s) for subplots.
    lwidths : number or list, optional
        Line width(s). Forced to 0 when ``stack=True``.
    lstyles : str or list, optional
        Line style(s); matplotlib-like codes ``'-'``, ``'--'``, ``'dotted'``
        are translated to plotly dash codes, anything else is passed
        through to plotly unchanged.
    yaxis : str, optional, default "y1"
        Target y-axis.
    fill : optional
        Passed to ``go.Scatter(fill=...)``.
    mode : {"markers", "lines", "lines+markers", "bar"}, optional
        Representation mode ("bar" switches to ``go.Bar`` traces).
    markers : (list of) dict, optional
        Plotly ``marker`` dict(s); missing color/width entries are filled
        from ``fillcolors``/``linecolors``/``lwidths`` in bar mode.
    showlegend : bool, optional
    visible : {True, False, "legendonly"}, optional, default True
        Determines whether or not the traces are visible. If
        ``"legendonly"``, the trace is not drawn, but can appear as a
        legend item (provided that the legend itself is visible).
    bar_widths : number or list, optional
        Bar width(s), only used when ``mode="bar"``.

    Returns
    -------
    fig1 : matplotlib figure
        OBSOLETE: recent developments (normalized curves, stacked
        curves...) have not been implemented in this figure.
    ax1 : matplotlib axis
        Related to the previous figure. OBSOLETE.
    figweb : plotly figure
        This figure version includes all options.

    Example
    -------
    >>> from geop4th import ncplot as ncp
    >>> [_, _, figweb] = ncp.plot_time_series(data = [dfA, dfV],
    ...                                       labels = ['Altrui', 'Vergogn'])
    """
    # ---- Initialization of data
    # ============================
    # 0. if data is None
    if data is None:
        print("Error: `data` is None")
        return None, None, figweb
    # 1. a single xarray.DataArray is embedded into a list
    if isinstance(data, xr.DataArray):
        data = [data]
    # 2. a single pandas.Series is embedded into a list
    elif isinstance(data, pd.Series):
        data = [data]
    # 3. xarray.Dataset
    elif isinstance(data, xr.Dataset):
        # 3-A. if there is only one data_vars, it is embedded into a list
        if len(data.data_vars) == 1:
            data = [data]
        # 3-B. if there are several data_vars, it is split into a list of
        # single-variable Datasets.
        # (fix: ``data.loc[var]`` indexes coordinates, not variables;
        #  ``data[[var]]`` is the variable selection intended here)
        else:
            data = [data[[var]] for var in data.data_vars]
    # 4. if data is a single pandas.DataFrame, we first determine if this
    # frame contains a single data set, or several
    elif isinstance(data, pd.DataFrame):
        # 4-A. if there is only one column, it is embedded into a list
        if len(data.columns) == 1:
            data = [data]
        # 4-B. if there are several columns, we determine whether one of
        # the columns is the time index
        else:
            # 4-B-1. select datetime columns, or columns named 'time'/'t'
            assumed_time_col = data.select_dtypes(
                include = ['datetime', 'datetimetz']).columns.union(
                    data.columns.intersection(['time', 't']))
            is_index_time = (data.index.name in ['time', 't']) \
                | isinstance(data.index, pd.DatetimeIndex)
            if len(assumed_time_col) > 0:
                if is_index_time:
                    # If there is already a time index, the time columns
                    # will be discarded
                    data = data.drop(assumed_time_col, axis = 1)
                    print(f"Warning: '{', '.join(assumed_time_col)}' columns have been discarded")
                else:
                    data = data.set_index(assumed_time_col[0], drop = True)
                    if len(assumed_time_col) > 1:
                        data = data.drop(assumed_time_col[1:], axis = 1)
                        print(f"Warning: '{', '.join(assumed_time_col[1:])}' columns have been discarded")
            # 4-B-2. if not, the index is assumed to be the time index
            # Convert each column into a single-column DataFrame
            data = [data[[c]] for c in data.columns]

    n_curves = len(data)

    # ---- Convert per-curve arguments to lists
    # ==========================================
    if labels is None:
        labels = [None]*n_curves
    else:
        if isinstance(labels, str):
            labels = [labels]
        else:
            labels = list(labels)
        if len(labels) != n_curves:
            # NB: raw-string to avoid the invalid '\ ' escape sequence
            print(rf"/!\ {n_curves} data series but {len(labels)} labels")
    if lwidths is None:
        lwidths = [None]*n_curves
    else:
        if isinstance(lwidths, numbers.Number):
            lwidths = [lwidths]*n_curves
        elif isinstance(lwidths, tuple):
            lwidths = list(lwidths)
    if lstyles is None:
        lstyles = [None]*n_curves
    else:
        if isinstance(lstyles, str):
            lstyles = [lstyles]*n_curves
        elif isinstance(lstyles, tuple):
            lstyles = list(lstyles)
    if bar_widths is None:
        bar_widths = [None]*n_curves
    else:
        if isinstance(bar_widths, numbers.Number):
            bar_widths = [bar_widths]*n_curves

    # Handling colors
    # (fix: an unconditional pre-processing block used to run before the
    #  None check below; it crashed on ``linecolors=None`` and turned a
    #  colorscale *name* into a repeated list before the
    #  ``pc.named_colorscales()`` branch could ever fire, defeating the
    #  colorscale-sampling feature. It has been removed.)
    if linecolors is None:
        linecolors = [None]*n_curves
    else:
        linecolors = cmg.to_rgba_str(linecolors)
        if isinstance(linecolors, tuple):
            linecolors = [linecolors]*n_curves
        elif isinstance(linecolors, str):
            if linecolors in pc.named_colorscales():
                # Sample n_curves colors from the named colorscale
                linecolors = pc.sample_colorscale(
                    linecolors, n_curves, low=0.0, high=1.0,
                    colortype='rgb')
            else:
                linecolors = [linecolors]*n_curves
        elif isinstance(linecolors, list):
            if not isinstance(linecolors[0], (str, list, tuple)):
                # A single RGBA-like sequence: repeat it for every curve
                linecolors = [linecolors]*n_curves
    if fillcolors is None:
        fillcolors = [None]*n_curves
    else:
        fillcolors = cmg.to_rgba_str(fillcolors)
        if isinstance(fillcolors, tuple):
            fillcolors = [fillcolors]*n_curves
        elif isinstance(fillcolors, str):
            if fillcolors in pc.named_colorscales():
                fillcolors = pc.sample_colorscale(
                    fillcolors, n_curves, low=0.0, high=1.0,
                    colortype='rgb')
            else:
                fillcolors = [fillcolors]*n_curves
        elif isinstance(fillcolors, list):
            if not isinstance(fillcolors[0], (str, list, tuple)):
                fillcolors = [fillcolors]*n_curves
    if markers is None:
        markers = [None]*n_curves
    else:
        if isinstance(markers, dict):
            markers = [markers]*n_curves
    if legendgrouptitle_text is not None:
        legendgrouptitle_text = '<b>' + legendgrouptitle_text + '</b>'
    if row is not None:
        if isinstance(row, (list, tuple)):
            row = [int(r) for r in row]
        else:
            row = list(np.tile([int(row)], n_curves))
    if col is not None:
        if isinstance(col, (list, tuple)):
            col = [int(r) for r in col]
        else:
            col = list(np.tile([int(col)], n_curves))

    # ---- Formating
    # ===============
    for i in range(0, len(data)):
        data[i] = data[i].copy()  # avoid mutating the caller's variables

    n_curves = len(data)
    if len(labels) != n_curves:
        print('/!\\ ' + str(n_curves) + ' séries de données mais ' + str(len(labels)) + ' étiquettes')

    # Convert every element to a DataFrame with 2 columns: 'val' and 'time'
    for i in range(0, n_curves):
        # xarray.Dataset -> pandas.DataFrame
        if isinstance(data[i], xr.Dataset):
            # Determine the field
            _tmp_fields = list(data[i].data_vars)
            # Build the DataFrame on the first data variable
            _tmp_df = pd.DataFrame(data = data[i][_tmp_fields[0]].values,
                                   columns = ['val'])
            # Add the time column
            _tmp_df['time'] = data[i]['time'].values
            data[i] = _tmp_df
        # pandas.Series -> pandas.DataFrame
        if isinstance(data[i], pd.Series):
            data[i] = data[i].to_frame(name = 'val')
        # Create the 'time' column
        if 'time' not in data[i].columns:
            if data[i].index.name == 'time':
                data[i]['time'] = data[i].index
            elif isinstance(data[i].index, pd.DatetimeIndex):
                data[i]['time'] = data[i].index
                print(f" Warning: Data {i+1}/{len(data)}: index is used as time axis")
            elif data[i].select_dtypes(include=['datetime', 'datetimetz']).shape[1] > 0:
                assumed_time_col = data[i].select_dtypes(include=['datetime', 'datetimetz']).columns[0]
                # fix: was ``data.rename(...)`` which targeted the outer
                # list instead of the current DataFrame
                data[i].rename(columns = {assumed_time_col: 'time'}, inplace = True)
                print(f" Warning: Data {i+1}/{len(data)}: column '{assumed_time_col}' is used as time axis")
        # Create the 'val' column
        if 'val' not in data[i].columns:
            # Candidate value columns = every non-datetime, non-object column
            val_col = data[i].columns.difference(
                data[i].select_dtypes(include=['datetime', 'datetimetz', object]).columns,
                sort=False)
            data[i].rename(columns = {val_col[0]: 'val'}, inplace = True)
            if val_col.size > 1:
                # fix: message used {i} where sibling messages use {i+1}
                print(f" Warning: Data {i+1}/{len(data)}: column '{val_col[0]}' is used as main values column, but there are {val_col.size - 1} other candidate columns: {', '.join(val_col[1:])}")

    # Cumulated values
    if cumul:
        for i in range(0, n_curves):
            _tmp_df = data[i].copy(deep = True)
            # Time spans between successive values, in days
            timespan = _tmp_df.loc[
                :, _tmp_df.columns == 'time'
                ].diff().shift(-1, fill_value = 0)/np.timedelta64(1, 'D')
            # Time-integrated cumulated values
            _tmp_df[['val'
                     ]] = (_tmp_df[['val'
                                    ]] * timespan.values).cumsum(axis = 0)
            # Alignment on a common date
            cond = _tmp_df.time.dt.normalize() == date_ini_cumul
            if cond.sum() != 0:
                # date_ini_cumul exists in the dataframe
                _tmp_df.loc[
                    :, 'val'] = _tmp_df.loc[
                        :, 'val'] - _tmp_df.loc[cond, 'val'].values
            else:
                # Fallback: align on the last value
                _tmp_df.loc[
                    :, 'val'] = _tmp_df.loc[
                        :, 'val'] - _tmp_df.iloc[-1].loc['val']
            data[i] = _tmp_df

    # Values normalized by their mean
    if mean_norm:
        for i in range(0, n_curves):
            _tmp_df = data[i].copy(deep = True)
            _tmp_df['val'] = _tmp_df['val'] / _tmp_df['val'].mean()
            data[i] = _tmp_df.copy(deep = True)

    # Values centered on their mean
    if mean_center:
        for i in range(0, n_curves):
            _tmp_df = data[i].copy()
            _tmp_df['val'] = _tmp_df['val'] - _tmp_df['val'].mean()
            data[i] = _tmp_df

    # Values normalized against a reference
    if ref_norm is not None:
        # Conversion to pandas.DataFrame
        if isinstance(ref_norm, xr.Dataset):
            _tmp_fields = list(ref_norm.data_vars)
            _tmp_df = pd.DataFrame(data = ref_norm[_tmp_fields[0]].values,
                                   columns = ['val'])
            _tmp_df['time'] = ref_norm['time'].values
            ref_norm = _tmp_df.copy(deep = True)
        for i in range(0, n_curves):
            _tmp_df = data[i].copy(deep = True)
            _tmp_df['val'] = _tmp_df['val'] / ref_norm['val']
            data[i] = _tmp_df

    # Metrics display (NSE, NSElog, VOLerr)
    if reference is not None:
        VOLerr = [0]*len(data)
        NSE = [0]*len(data)
        NSElog = [0]*len(data)
        # fix: the result of tz_localize(None) was previously discarded,
        # leaving the reference time axis tz-aware and breaking the merge
        # on 'time' below. Guards added: tz_localize(None) raises on
        # already tz-naive series.
        reference = reference.copy()
        if reference['time'].dt.tz is not None:
            reference['time'] = reference['time'].dt.tz_localize(None)
        # Metrics computation
        for i in range(0, n_curves):
            if data[i]['time'].dt.tz is not None:
                data[i]['time'] = data[i]['time'].dt.tz_localize(None)
            temp = reference.merge(data[i], left_on = 'time', right_on = 'time')
            NSE[i] = 1 - (np.sum((temp.val_x - temp.val_y)**2) / np.sum((temp.val_x - temp.val_x.mean())**2))
            # NSElog is only defined on strictly positive values
            cond = (temp.val_x > 0) & (temp.val_y > 0)
            NSElog[i] = 1 - (np.sum((np.log(temp.val_x[cond]) - np.log(temp.val_y[cond]))**2) / np.sum((np.log(temp.val_x[cond]) - np.log(temp.val_x.mean()))**2))
            VOLerr[i] = np.sum(temp.val_y - temp.val_x) / np.sum(temp.val_x)  # NB: not expressed in %

    # Stacked curves or not
    if stack:
        if stackgroup is None:
            # Random identifier so that all these traces stack together
            stackgroup = np.random.rand()
    else:
        stackgroup = None

    # ---- Graphics
    # ==============
    # Matplotlib figure (obsolete, kept for backward compatibility)
    fig1, ax1 = plt.subplots(1, figsize = (20, 12))

    # Line widths
    # -----------
    if stack:
        lwidths = [0]*n_curves
    # Line styles
    # -----------
    # Translate matplotlib-like codes to plotly dash codes; unknown codes
    # are passed through unchanged so plotly-native dash strings also work
    # (previously an unknown style raised a KeyError)
    lstyle_convert = {'-': 'solid', '--': '5, 2', 'dotted': 'dot', None: None}
    lstyle_plotly = [lstyle_convert.get(style, style) for style in lstyles]

    if not isinstance(figweb, go.Figure):
        figweb = go.Figure()

    for i in range(0, n_curves):
        if mode != 'bar':
            if reference is not None:  # Displays NSE, NSElog, VOLerr indications
                figweb.add_trace(go.Scatter(
                    x = data[i].index,
                    y = data[i].val,
                    name = labels[i] +
                    '<br>(VOL<sub>err</sub>: ' + "{:.2f}".format(VOLerr[i]) + ' | NSE: ' + "{:.2f}".format(NSE[i]) + ' | NSE<sub>log</sub>: ' + "{:.2f})".format(NSElog[i]),
                    line = {'color': cmg.to_rgba_str(linecolors[i]),
                            'width': lwidths[i],
                            'dash': lstyle_plotly[i]},
                    legendgroup = legendgroup,
                    legendgrouptitle_text = legendgrouptitle_text,
                    stackgroup = stackgroup,
                    yaxis = yaxis,
                    fill = fill,
                    fillcolor = cmg.to_rgba_str(fillcolors[i]),
                    mode = mode,
                    marker = markers[i],
                    showlegend = showlegend,
                    # fix: `visible` was only honored by the no-reference
                    # branch
                    visible = visible,
                    ),
                    row = row,
                    col = col,
                    )
            else:
                figweb.add_trace(go.Scatter(
                    x = data[i].index,
                    y = data[i].val,
                    name = labels[i],
                    line = {'color': cmg.to_rgba_str(linecolors[i]),
                            'width': lwidths[i],
                            'dash': lstyle_plotly[i]},
                    legendgroup = legendgroup,
                    legendgrouptitle_text = legendgrouptitle_text,
                    stackgroup = stackgroup,
                    yaxis = yaxis,
                    fill = fill,
                    fillcolor = cmg.to_rgba_str(fillcolors[i]),
                    mode = mode,
                    marker = markers[i],
                    showlegend = showlegend,
                    visible = visible,
                    ),
                    row = row,
                    col = col,
                    )
        elif mode == 'bar':
            # go.Bar takes colors through the marker dict: fill any
            # missing (or None) entry from fillcolors/linecolors/lwidths
            if markers[i] is None:
                markers[i] = dict()
            if markers[i].get('color') is None:
                markers[i]['color'] = fillcolors[i]
            if 'line' not in markers[i]:
                markers[i]['line'] = dict()
            if markers[i].get('line', {}).get('color') is None:
                # fix: the dict branch used fillcolors here, inconsistent
                # with the None branch which used linecolors
                markers[i]['line']['color'] = linecolors[i]
            if markers[i]['line'].get('width') is None:
                markers[i]['line']['width'] = lwidths[i]
            if reference is not None:  # Displays NSE, NSElog, VOLerr indications
                figweb.add_trace(go.Bar(
                    x = data[i].index,
                    y = data[i].val,
                    width = bar_widths,
                    marker = markers[i],
                    name = labels[i] +
                    '<br>(VOL<sub>err</sub>: ' + "{:.2f}".format(VOLerr[i]) + ' | NSE: ' + "{:.2f}".format(NSE[i]) + ' | NSE<sub>log</sub>: ' + "{:.2f})".format(NSElog[i]),
                    legendgroup = legendgroup,
                    legendgrouptitle_text = legendgrouptitle_text,
                    yaxis = yaxis,
                    showlegend = showlegend,
                    visible = visible,
                    ),
                    row = row,
                    col = col,
                    )
            else:
                figweb.add_trace(go.Bar(
                    x = data[i].index,
                    y = data[i].val,
                    width = bar_widths,
                    marker = markers[i],
                    name = labels[i],
                    legendgroup = legendgroup,
                    legendgrouptitle_text = legendgrouptitle_text,
                    yaxis = yaxis,
                    showlegend = showlegend,
                    visible = visible,
                    ),
                    row = row,
                    col = col,
                    )

    # Matplotlib decoration (obsolete figure)
    ax1.set_xlabel('Temps [j]', fontsize = 16)
    ax1.tick_params(axis = 'both', labelsize = 14)
    plt.legend(loc = "upper right", title = "Légende", frameon = False,
               fontsize = 18)
    ax1.set_title(title, fontsize = 24)

    # Hover legends on the curves
    if reference is not None:
        figweb.update_traces(
            hoverinfo = 'all',
            hovertemplate = "t: %{x}<br>" + "y: %{y}<br>",
            )

    return fig1, ax1, figweb
#%% netcdf_to_animation
def netcdf_to_animation(data,
                        output_path,
                        *,
                        variable_name=None,
                        start_date=None,
                        end_date=None,
                        vmin=None,
                        vmax=None,
                        cmap='viridis',
                        fps=2,
                        figsize=(12, 8),
                        title_template=None,
                        mask_values=None,
                        background_data=None,
                        extent=None,
                        dpi=100,
                        bitrate=1800,
                        **kwargs):
    r"""
    Convert NetCDF data to animated GIF or MP4.

    This function creates animations from spatio-temporal NetCDF data without
    writing individual frame files to disk, using matplotlib's animation
    capabilities.

    Parameters
    ----------
    data : path (str or pathlib.Path) or xarray.Dataset
        NetCDF data to animate. Can be a file path or loaded xarray Dataset.
    output_path : str or pathlib.Path
        Output path for the animation file. Format is determined by file
        extension (.gif or .mp4).
    variable_name : str, optional
        Name of the variable to animate. If None, will use the main variable
        detected by geobricks.main_vars().
    start_date : str, optional
        Start date for the animation in 'YYYY-MM-DD' format.
    end_date : str, optional
        End date for the animation in 'YYYY-MM-DD' format.
    vmin, vmax : float, optional
        Min/max values for color scaling. If None, computed from data.
    cmap : str or matplotlib colormap, default 'viridis'
        Colormap to use for visualization.
    fps : int, default 2
        Frames per second for the animation.
    figsize : tuple, default (12, 8)
        Figure size in inches (width, height).
    title_template : str, optional
        Template for frame titles. Use {time} as placeholder for time values.
        Default: '{variable_name} - {time}'
    mask_values : tuple or list, optional
        Range of values to mask (e.g., (-100, 100) to mask values between
        -100 and 100).
    background_data : path or xarray.Dataset, optional
        Background data (e.g., DEM) to display under the main data.
        Only the portion corresponding to the data extent will be shown.
    extent : list or tuple, optional
        Custom spatial extent [west, east, south, north] in the same CRS as
        the data. If None, uses the extent of the main data.
    dpi : int, default 100
        Resolution in dots per inch.
    bitrate : int, default 1800
        Bitrate in kbps (kilobits per second) for MP4 videos. Higher values
        = better quality but larger file size. Common values: 500 (low
        quality), 1800 (medium), 3500 (high).
        Only used for MP4 format, ignored for GIF.
    **kwargs
        Additional arguments passed to matplotlib's imshow or pcolormesh.

    Returns
    -------
    str
        Path to the created animation file.

    Raises
    ------
    ImportError
        When the optional animation dependencies (matplotlib.animation,
        Pillow) are not available.
    ValueError
        When the output extension is neither .gif nor .mp4, or no time
        dimension is found.

    Notes
    -----
    This function uses matplotlib's animation framework instead of saving
    individual frames to disk.
    For very large datasets, consider pre-processing the data to reduce
    the temporal resolution (e.g., daily to monthly) before creating
    animations.
    Inspired by https://github.com/johannesuhl/netcdf2mp4
    """
    # Determine format from file extension
    import os
    file_ext = os.path.splitext(str(output_path))[1].lower()
    if file_ext == '.mp4':
        if manimation is None:
            raise ImportError("matplotlib.animation required for MP4 export. Install with: pip install matplotlib[animation]")
    elif file_ext == '.gif':
        # fix: FuncAnimation/PillowWriter below also need
        # matplotlib.animation, which was only checked for .mp4 — a
        # missing module used to crash later with an AttributeError
        if manimation is None:
            raise ImportError("matplotlib.animation required for GIF export. Install with: pip install matplotlib[animation]")
        if Image is None:
            raise ImportError("PIL/Pillow required for GIF export. Install with: pip install Pillow")
    else:
        raise ValueError(f"Unsupported file extension: {file_ext}. Use .gif or .mp4")

    # Load data using geobricks
    ds = geo.load(data, decode_times=True)

    # Get main variable if not specified
    if variable_name is None:
        variable_name = geo.main_vars(ds)[0]

    # Get main spatial and time dimensions
    # NOTE(review): here the first element of main_space_dims() is taken,
    # whereas the background branch below uses the full return value —
    # one of the two is probably wrong; confirm against geobricks API
    space_dims = geo.main_space_dims(ds)[0]
    time_dims = geo.main_time_dims(ds)
    if not time_dims:
        raise ValueError("No time dimension found in dataset")
    time_dim = time_dims[0]

    # Subset data by date range
    if start_date or end_date:
        ds = ds.sel({time_dim: slice(start_date, end_date)})

    # Get the variable data
    var_data = ds[variable_name]

    # Compute vmin/vmax from the whole (subset) series if not provided,
    # so the color scale is stable across frames
    if vmin is None:
        vmin = float(var_data.min())
    if vmax is None:
        vmax = float(var_data.max())

    # Setup title template
    if title_template is None:
        title_template = f"{variable_name} - {{time}}"

    # Create figure and axis
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Get coordinate arrays for plotting
    if len(space_dims) == 2:
        x_coord, y_coord = space_dims
        X, Y = np.meshgrid(ds[x_coord], ds[y_coord])
    else:
        # Fallback to simple indexing
        X, Y = None, None

    # Determine plot extent
    if extent is not None:
        # User provided custom extent [west, east, south, north]
        plot_extent = extent
        west, east, south, north = extent
    elif X is not None and Y is not None:
        # Use data extent
        west, east = float(ds[x_coord].min()), float(ds[x_coord].max())
        south, north = float(ds[y_coord].min()), float(ds[y_coord].max())
        plot_extent = [west, east, south, north]
    else:
        plot_extent = None

    # Load and clip background data if provided (after plot_extent is defined)
    background = None
    background_coords = None
    if background_data is not None:
        background = geo.load(background_data)
        # Clip background to plot extent if possible
        if plot_extent is not None:
            west, east, south, north = plot_extent
            try:
                # Use geobricks.transform to clip background to data bounds
                background_clipped = geo.transform(
                    background,
                    bounds=[west, south, east, north],  # [west, south, east, north]
                    to_file=False
                )
                background = background_clipped
            except Exception as e:
                # Best effort: keep the unclipped background
                print(f"Warning: Background clipping failed ({e}), using full background")
        # Get background spatial coordinates after clipping
        # NOTE(review): no ``[0]`` here, unlike the main-data call above —
        # confirm which form matches geo.main_space_dims()
        bg_space_dims = geo.main_space_dims(background)
        if len(bg_space_dims) == 2:
            bg_x_coord, bg_y_coord = bg_space_dims
            background_coords = {
                'x_coord': bg_x_coord,
                'y_coord': bg_y_coord,
                'x': background[bg_x_coord],
                'y': background[bg_y_coord]
            }

    # Create colorbar outside of animation function to avoid duplicates
    cbar = None

    # Animation function (called once per frame by FuncAnimation)
    def animate(frame_idx):
        nonlocal cbar
        ax.clear()
        # Get current time step data
        current_data = var_data.isel({time_dim: frame_idx})
        current_time = var_data[time_dim].isel({time_dim: frame_idx})
        # Plot background if provided (already clipped)
        if background is not None:
            bg_var = geo.main_vars(background)[0]
            bg_data = background[bg_var]
            if background_coords is not None:
                # Use background's own coordinates (already clipped)
                bg_x = background_coords['x']
                bg_y = background_coords['y']
                BG_X, BG_Y = np.meshgrid(bg_x, bg_y)
                # pcolormesh for proper georeferencing
                ax.pcolormesh(BG_X, BG_Y, bg_data, cmap='gray', alpha=0.5, shading='auto')
            else:
                # Fallback: use imshow with extent
                if plot_extent is not None:
                    ax.imshow(bg_data, cmap='gray', alpha=0.5, extent=plot_extent, aspect='auto')
        # Apply masking if specified
        plot_data = current_data.values
        if mask_values is not None:
            mask_min, mask_max = mask_values
            plot_data = np.ma.masked_where(
                (plot_data >= mask_min) & (plot_data <= mask_max),
                plot_data
            )
        # Plot the main data
        if X is not None and Y is not None:
            im = ax.pcolormesh(X, Y, plot_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        else:
            im = ax.imshow(plot_data, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto', **kwargs)
        # Format time for title (fix: narrowed the bare ``except:``)
        try:
            time_str = pd.to_datetime(current_time.values).strftime('%Y-%m-%d')
        except Exception:
            time_str = str(current_time.values)
        # Set title
        title = title_template.format(time=time_str, variable_name=variable_name)
        ax.set_title(title, fontsize=14, fontweight='bold')
        # Add colorbar only once on first frame
        if frame_idx == 0 and cbar is None:
            cbar = fig.colorbar(im, ax=ax, orientation='horizontal', pad=0.1)
            cbar.set_label(variable_name, fontsize=12)
        # Set axis properties and limits
        if X is not None and Y is not None:
            ax.set_xlabel(x_coord)
            ax.set_ylabel(y_coord)
            ax.set_aspect('equal')
        else:
            ax.set_aspect('auto')
        # Set axis limits to plot extent
        if plot_extent is not None:
            west, east, south, north = plot_extent
            ax.set_xlim(west, east)
            ax.set_ylim(south, north)
        return [im]

    # Create animation
    n_frames = len(var_data[time_dim])
    print(f"Creating animation with {n_frames} frames...")
    anim = manimation.FuncAnimation(
        fig, animate, frames=n_frames, interval=1000//fps, blit=False
    )

    # Save animation based on file extension
    output_path = str(output_path)
    if file_ext == '.gif':
        writer = manimation.PillowWriter(fps=fps)
        anim.save(output_path, writer=writer, dpi=dpi)
    elif file_ext == '.mp4':
        try:
            writer = manimation.FFMpegWriter(fps=fps, bitrate=bitrate)
            anim.save(output_path, writer=writer, dpi=dpi)
        except Exception as e:
            print(f"MP4 export failed: {e}")
            print("Try installing FFmpeg or use .gif extension")
            raise
    plt.close(fig)

    # Cleanup: animations can hold large frame buffers
    import gc
    gc.collect()

    print(f"Animation saved to: {output_path}")
    return output_path
#%% Other tools
# The previous functions have been moved to: cmapgenerator.py
# def custom(n_steps, *args):
# def custom_two_colors(n_steps, first_color, last_color):