-
Notifications
You must be signed in to change notification settings - Fork 928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More parameter functionality in matplotlib (default) drawer visualization #2242
Closed
rmhopkins4
wants to merge
10
commits into
projectmesa:main
from
rmhopkins4:matplotlib-viz-param-refactor
Closed
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
8629310
Update matplotlib.py
rmhopkins4 2cc3f10
matplotlib visualization supports more params
rmhopkins4 bc67c60
old param keywords, x,y pos
rmhopkins4 9632600
implemented 'norm', reformatted colormap application
rmhopkins4 ed6cf81
formatting
rmhopkins4 5b67ade
update to cmap, norm flow
rmhopkins4 2e4cdd0
colormap application moved to its own function
rmhopkins4 bf59cbc
Update matplotlib.py
rmhopkins4 ddec45a
improved num_agents iteration
rmhopkins4 c6a8a65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,14 @@ | ||
from collections import defaultdict | ||
|
||
import matplotlib.pyplot as plt | ||
import networkx as nx | ||
import solara | ||
from matplotlib.colors import Normalize | ||
from matplotlib.figure import Figure | ||
from matplotlib.ticker import MaxNLocator | ||
|
||
import mesa | ||
from mesa.space import GridContent | ||
|
||
|
||
@solara.component | ||
|
@@ -22,47 +25,100 @@ | |
_draw_continuous_space(space, space_ax, agent_portrayal) | ||
else: | ||
_draw_grid(space, space_ax, agent_portrayal) | ||
|
||
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) | ||
|
||
|
||
# used to make non(less?)-breaking change | ||
# this *does* however block the matplotlib 'color' param which is somewhat distinct from 'c'. | ||
# maybe translate 'size' and 'shape' but not 'color'? | ||
def _translate_old_keywords(data): | ||
""" | ||
Translates old keyword names in the given dictionary to the new names. | ||
""" | ||
key_mapping = {"size": "s", "color": "c", "shape": "marker"} | ||
return {key_mapping.get(key, key): val for (key, val) in data.items()} | ||
|
||
|
||
def _apply_color_map(color, cmap=None, norm=None, vmin=None, vmax=None): | ||
""" | ||
Given parameters for manual colormap application, applies color map | ||
according to default implementation in matplotlib | ||
""" | ||
if not cmap: # if no colormap is provided, return original color | ||
return color | ||
color_map = plt.get_cmap(cmap) | ||
if norm: # check if norm is provided and apply it | ||
if not isinstance(norm, Normalize): | ||
raise TypeError( | ||
"'norm' must be an instance of Normalize or its subclasses." | ||
) | ||
return color_map(norm(color)) | ||
if not (vmin == None or vmax == None): # check for custom norm params | ||
new_norm = Normalize(vmin, vmax) | ||
return color_map(new_norm(color)) | ||
try: | ||
return color_map(color) | ||
except Exception as e: | ||
raise ValueError("Color mapping failed due to invalid arguments") from e | ||
|
||
|
||
# matplotlib scatter does not allow for multiple shapes in one call | ||
def _split_and_scatter(portray_data, space_ax): | ||
grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []}) | ||
|
||
# Extract data from the dictionary | ||
x = portray_data["x"] | ||
y = portray_data["y"] | ||
s = portray_data["s"] | ||
c = portray_data["c"] | ||
m = portray_data["m"] | ||
|
||
if not (len(x) == len(y) == len(s) == len(c) == len(m)): | ||
raise ValueError( | ||
"Length mismatch in portrayal data lists: " | ||
f"x: {len(x)}, y: {len(y)}, size: {len(s)}, " | ||
f"color: {len(c)}, marker: {len(m)}" | ||
) | ||
|
||
# Group the data by marker | ||
for i in range(len(x)): | ||
marker = m[i] | ||
grouped_data[marker]["x"].append(x[i]) | ||
grouped_data[marker]["y"].append(y[i]) | ||
grouped_data[marker]["s"].append(s[i]) | ||
grouped_data[marker]["c"].append(c[i]) | ||
|
||
# Plot each group with the same marker | ||
def _split_and_scatter(portray_data: dict, space_ax) -> None: | ||
# if any of the following params are passed into portray(), this is true | ||
cmap_exists = portray_data.pop("cmap", None) | ||
norm_exists = portray_data.pop("norm", None) | ||
vmin_exists = portray_data.pop("vmin", None) | ||
vmax_exists = portray_data.pop("vmax", None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a feeling this can be done more elegant |
||
|
||
# enforce marker iterability | ||
markers = portray_data.pop("marker", ["o"] * len(portray_data["x"])) | ||
# enforce default color | ||
if ( # if no 'color' or 'facecolor' or 'c' then default to "tab:blue" color | ||
"color" not in portray_data | ||
and "facecolor" not in portray_data | ||
and "c" not in portray_data | ||
): | ||
portray_data["color"] = ["tab:blue"] * len(portray_data["x"]) | ||
|
||
grouped_data = defaultdict(lambda: {key: [] for key in portray_data}) | ||
|
||
for i, marker in enumerate(markers): | ||
for key in portray_data: | ||
if key == "c": # apply colormap if possible | ||
# prepare arguments | ||
cmap = cmap_exists[i] if cmap_exists else None | ||
norm = norm_exists[i] if norm_exists else None | ||
vmin = vmin_exists[i] if vmin_exists else None | ||
vmax = vmax_exists[i] if vmax_exists else None | ||
# apply colormap with prepared arguments | ||
portray_data["c"][i] = _apply_color_map( | ||
portray_data["c"][i], cmap, norm, vmin, vmax | ||
) | ||
|
||
grouped_data[marker][key].append(portray_data[key][i]) | ||
|
||
for marker, data in grouped_data.items(): | ||
space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker) | ||
space_ax.scatter(marker=marker, **data) | ||
|
||
|
||
def _draw_grid(space, space_ax, agent_portrayal): | ||
def portray(g): | ||
x = [] | ||
y = [] | ||
s = [] # size | ||
c = [] # color | ||
m = [] # shape | ||
default_values = { | ||
"size": (180 / max(g.width, g.height)) ** 2, | ||
} | ||
|
||
out = {} | ||
num_agents = 0 | ||
for content in g: | ||
if not content: | ||
continue | ||
if isinstance(content, GridContent): # one agent | ||
num_agents += 1 | ||
continue | ||
num_agents += len(content) | ||
|
||
index = 0 | ||
for i in range(g.width): | ||
for j in range(g.height): | ||
content = g._grid[i][j] | ||
|
@@ -73,27 +129,25 @@ | |
content = [content] | ||
for agent in content: | ||
data = agent_portrayal(agent) | ||
x.append(i) | ||
y.append(j) | ||
|
||
# This is the default value for the marker size, which auto-scales | ||
# according to the grid area. | ||
default_size = (180 / max(g.width, g.height)) ** 2 | ||
# establishing a default prevents misalignment if some agents are not given size, color, etc. | ||
size = data.get("size", default_size) | ||
s.append(size) | ||
color = data.get("color", "tab:blue") | ||
c.append(color) | ||
mark = data.get("shape", "o") | ||
m.append(mark) | ||
out = {"x": x, "y": y, "s": s, "c": c, "m": m} | ||
return out | ||
data["x"] = i | ||
data["y"] = j | ||
|
||
for key, value in data.items(): | ||
if key not in out: | ||
# initialize list | ||
out[key] = [default_values.get(key)] * num_agents | ||
out[key][index] = value | ||
index += 1 | ||
|
||
return _translate_old_keywords(out) | ||
|
||
space_ax.set_xlim(-1, space.width) | ||
space_ax.set_ylim(-1, space.height) | ||
|
||
_split_and_scatter(portray(space), space_ax) | ||
|
||
|
||
# draws using networkx's matplotlib integration | ||
def _draw_network_grid(space, space_ax, agent_portrayal): | ||
graph = space.G | ||
pos = nx.spring_layout(graph, seed=0) | ||
|
@@ -107,28 +161,23 @@ | |
|
||
def _draw_continuous_space(space, space_ax, agent_portrayal): | ||
def portray(space): | ||
x = [] | ||
y = [] | ||
s = [] # size | ||
c = [] # color | ||
m = [] # shape | ||
for agent in space._agent_to_index: | ||
# TODO: look into if more default values are needed | ||
# especially relating to 'color', 'facecolor', and 'c' params & | ||
# interactions w/ the current implementation of _split_and_scatter | ||
default_values = {"s": 20} | ||
out = {} | ||
num_agents = len(space._agent_to_index) | ||
|
||
for i, agent in enumerate(space._agent_to_index): | ||
data = agent_portrayal(agent) | ||
_x, _y = agent.pos | ||
x.append(_x) | ||
y.append(_y) | ||
|
||
# This is matplotlib's default marker size | ||
default_size = 20 | ||
# establishing a default prevents misalignment if some agents are not given size, color, etc. | ||
size = data.get("size", default_size) | ||
s.append(size) | ||
color = data.get("color", "tab:blue") | ||
c.append(color) | ||
mark = data.get("shape", "o") | ||
m.append(mark) | ||
out = {"x": x, "y": y, "s": s, "c": c, "m": m} | ||
return out | ||
data["x"], data["y"] = agent.pos | ||
|
||
for key, value in data.items(): | ||
if key not in out: # initialize list | ||
out[key] = [default_values.get(key, default=None)] * num_agents | ||
out[key][i] = value | ||
|
||
return _translate_old_keywords(out) | ||
|
||
# Determine border style based on space.torus | ||
border_style = "solid" if not space.torus else (0, (5, 10)) | ||
|
@@ -146,7 +195,6 @@ | |
space_ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding) | ||
space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding) | ||
|
||
# Portray and scatter the agents in the space | ||
_split_and_scatter(portray(space), space_ax) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain or link to how
color
andc
differ in matplotlib (scatter)?