diff --git a/plot_transmission_timeline.py b/plot_transmission_timeline.py index 0700396..19b781d 100755 --- a/plot_transmission_timeline.py +++ b/plot_transmission_timeline.py @@ -189,12 +189,19 @@ if __name__ == "__main__": host = host_subplot(111, axes_class=axisartist.Axes) + # create list fo color indices transmission_df["index"] = transmission_df.index + color_dict = dict() + i = 0 + for cell_id in transmission_df["cellID"].unique(): + if cell_id not in color_dict: + color_dict[cell_id] = i + i += 1 cmap = matplotlib.cm.get_cmap("Set3") - for c in transmission_df["cellID"].unique(): - bounds = transmission_df[["index", "cellID"]].groupby("cellID").agg(["min", "max"]).loc[c] - host.axvspan(bounds.min(), bounds.max(), alpha=0.3, color=cmap.colors[c]) + for cell_id in transmission_df["cellID"].unique(): + bounds = transmission_df[["index", "cellID"]].groupby("cellID").agg(["min", "max"]).loc[cell_id] + host.axvspan(bounds.min(), bounds.max(), alpha=0.3, color=cmap.colors[color_dict[cell_id]]) plt.subplots_adjust()