diff --git a/plot_transmission_timeline.py b/plot_transmission_timeline.py index d03007f..15646a0 100755 --- a/plot_transmission_timeline.py +++ b/plot_transmission_timeline.py @@ -192,16 +192,22 @@ if __name__ == "__main__": # create list fo color indices transmission_df["index"] = transmission_df.index color_dict = dict() + color_list = list() i = 0 - for cell_id in transmission_df["cellID"].unique(): + for cell_id in transmission_df["cellID"]: if cell_id not in color_dict: color_dict[cell_id] = i i += 1 + color_list.append(color_dict[cell_id]) + + transmission_df["cell_color"] = color_list + color_dict = None + color_list = None cmap = matplotlib.cm.get_cmap("Set3") - for c in transmission_df["cellID"].unique(): - bounds = transmission_df[["index", "cellID"]].groupby("cellID").agg(["min", "max"]).loc[c] - host.axvspan(bounds.min(), bounds.max(), alpha=0.3, color=cmap.colors[color_dict[c]]) + for c in transmission_df["cell_color"].unique(): + bounds = transmission_df[["index", "cell_color"]].groupby("cell_color").agg(["min", "max"]).loc[c] + host.axvspan(bounds.min(), bounds.max(), alpha=0.3, color=cmap.colors[c]) plt.subplots_adjust()