Skip to content

Commit

Permalink
Updates to figure generation for re-submission
Browse files Browse the repository at this point in the history
  • Loading branch information
joshmoore committed Oct 4, 2021
1 parent 1812b6b commit 7a31150
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
26 changes: 20 additions & 6 deletions benchmark/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,23 @@
import ptitprince as pt


def plot_csv(filename: str, pngpath: str):
def plot_csv(
filename: str,
outpath: str,
font: int = 14,
width: int = 10,
height: int = 4,
):
csv = pd.read_csv(filename)

f, ax = plt.subplots(figsize=(8, 6))
f, ax = plt.subplots(figsize=(width, height))
ax = pt.RainCloud(
x="type",
y="seconds",
hue="source",
data=csv,
palette="Set2",
order=("Overhead", "Zarr", "TIFF", "HDF5"),
# bw = .2,
width_viol=0.6,
ax=ax,
Expand All @@ -30,16 +37,23 @@ def plot_csv(filename: str, pngpath: str):
# ax.set(ylim=(0.0002, 5))
ax.set_xscale("log")
ax.set_xlabel("seconds per chunk")
for item in (
[ax.title, ax.xaxis.label, ax.yaxis.label]
+ ax.get_xticklabels()
+ ax.get_yticklabels()
):
item.set_fontsize(font)

ax.axes.get_yaxis().get_label().set_visible(False)
handles, labels = ax.get_legend_handles_labels()
plt.legend(handles[0:3], labels[0:3], loc="lower left")
plt.legend(handles[0:3], labels[0:3], loc="lower left", prop={"size": font})
plt.tight_layout()
f.savefig(pngpath, dpi=600)
f.savefig(outpath)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("csv", help="input csv file")
parser.add_argument("png", help="output png file")
parser.add_argument("out", help="output filename")
ns = parser.parse_args()
plot_csv(ns.csv, ns.png)
plot_csv(ns.csv, ns.out)
42 changes: 23 additions & 19 deletions notebooks/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,7 @@ def file_count(shape, chunkXY, chunkZ=1, chunkT=1, chunkC=1):
)


grays = (
(0.2, 0.2, 0.2),
(0.4, 0.4, 0.4),
(0.6, 0.6, 0.6),
(0.8, 0.8, 0.8),
)


def plot(ax, twoD=True):
def plot(ax, twoD=True, font=16):
if twoD:
shape = (1, 8, 1, 2 ** 16, 2 ** 16)
chunkSizesXY = [32, 1024]
Expand All @@ -41,17 +33,24 @@ def plot(ax, twoD=True):

if twoD:
ax.set_xlabel("Chunk size (X and Y)")
ax.set_title("XYZCT: 64k x 64k x 1 x 8 x 1")
ax.set_title("XYZCT: (64k, 64k, 1, 8, 1)")
chunkDim = "C"
annTitle = "Chosen chunk size:\n256 x 256 x 1 x 1 x 1"
annTitle = "Chosen chunk size:\n(256, 256, 1, 1, 1)"
xy = ((256), file_count(shape, 256))
else:
ax.set_xlabel("Chunk size (XYZ)")
ax.set_title("XYZCT: 1k x 1k x 1k x 1 x 100")
ax.set_title("XYZCT: (1k, 1k, 1k, 1, 100)")
chunkDim = "T"
annTitle = "Chosen chunk size:\n32 x 32 x 32 x 1 x 1"
annTitle = "Chosen chunk size:\n(32, 32, 32, 1, 1)"
xy = ((32), file_count(shape, 32, chunkZ=32))

for item in (
[ax.title, ax.xaxis.label, ax.yaxis.label]
+ ax.get_xticklabels()
+ ax.get_yticklabels()
):
item.set_fontsize(font)

styles = ["solid", "dashed", "dashdot", "dotted"]
for whichChunk, chunkOther in enumerate(chunkSizesOther):
numFiles = []
Expand All @@ -69,10 +68,9 @@ def plot(ax, twoD=True):
ax.plot(
fileSize,
numFiles,
linewidth=4.0,
linewidth=0.5,
label=f"{chunkOther}",
linestyle=styles.pop(0),
color=grays[whichChunk],
)

ax.annotate(
Expand All @@ -84,20 +82,26 @@ def plot(ax, twoD=True):
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="left",
verticalalignment="center",
fontsize=font - 4,
)
leg = ax.legend(
loc="lower left", title=f"Chunk size ({chunkDim})", frameon=False
loc="lower left",
title=f"Chunk size ({chunkDim})",
frameon=False,
prop={"size": font},
)
for legobj in leg.legendHandles:
legobj.set_linewidth(2.0)
legobj.set_linewidth(0.5)

for axis in ["top", "bottom", "left", "right"]:
ax.spines[axis].set_linewidth(0.5)

return fig


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("filename")
parser.add_argument("--dpi", type=int, default=600)
ns = parser.parse_args()
# fig = plt.figure()
# ax2D = fig.add_subplot(2, 1, 1)
Expand All @@ -107,4 +111,4 @@ def plot(ax, twoD=True):
plot(ax[1], False)
plot(ax[0], True)

plt.savefig(ns.filename, dpi=ns.dpi)
plt.savefig(ns.filename)

0 comments on commit 7a31150

Please sign in to comment.