1
+ import numpy as np
2
+
3
+ import matplotlib as mpl
4
+ import matplotlib .pyplot as plt
5
+ import matplotlib .ticker as ticker
6
+
7
+ from matplotlib .colors import ListedColormap
8
+ import colorsys
9
+
10
+ default_color_cycle = ["#B65655FF" , "#5471abFF" , "#6aa66eFF" , "#A66E6AFF" ]
11
+
12
+ # Function to lighten a color
13
+ def lighten_color (color , amount = 0.3 ):
14
+ # Convert color from hexadecimal to RGB
15
+ r , g , b , a = tuple (int (color [i :i + 2 ], 16 ) for i in (1 , 3 , 5 , 7 ))
16
+ # Convert RGB to HLS
17
+ h , l , s = colorsys .rgb_to_hls (r / 255 , g / 255 , b / 255 )
18
+ # Lighten the luminance component
19
+ l = min (1 , l + amount )
20
+ # Convert HLS back to RGB
21
+ r , g , b = tuple (round (c * 255 ) for c in colorsys .hls_to_rgb (h , l , s ))
22
+ # Convert RGB back to hexadecimal
23
+ new_color = f"#{ r :02x} { g :02x} { b :02x} { a :02x} "
24
+ return new_color
25
+
26
+ def set_color_cycle (color_cycle , alpha = 0.3 , mfc = True ):
27
+ color_cycle_light = [lighten_color (color , alpha ) for color in color_cycle ]
28
+ if mfc :
29
+ colors = mpl .cycler (mfc = color_cycle_light , color = color_cycle , markeredgecolor = color_cycle )
30
+ else :
31
+ colors = mpl .cycler (color = color_cycle , markeredgecolor = color_cycle )
32
+ mpl .rc ('axes' , prop_cycle = colors )
33
+
34
+ set_color_cycle (default_color_cycle )
35
+ # mpl.rc('axes', grid=True, edgecolor='k', prop_cycle=colors)
36
+ # mpl.rcParams['axes.prop_cycle'] = colors
37
+ # mpl.rcParams['lines.markeredgecolor'] = 'C'
38
+
39
+ mpl .rcParams ['font.family' ] = 'sans-serif' # 'Helvetica'
40
+ mpl .rcParams ['axes.linewidth' ] = 1.5
41
+ mpl .rcParams ["xtick.direction" ] = 'out' # 'out'
42
+ mpl .rcParams ["ytick.direction" ] = 'out'
43
+ mpl .rcParams ['xtick.major.width' ] = 1.5
44
+ mpl .rcParams ['ytick.major.width' ] = 1.5
45
+ mpl .rcParams ['ytick.minor.width' ] = 1.5
46
+ mpl .rcParams ['lines.markersize' ] = 11
47
+ mpl .rcParams ['legend.frameon' ] = True
48
+ mpl .rcParams ['lines.linewidth' ] = 1.5
49
+ # plt.rcParams['lines.markeredgecolor'] = 'k'
50
+ mpl .rcParams ['lines.markeredgewidth' ] = 1.5
51
+ mpl .rcParams ['figure.dpi' ] = 100
52
+ mpl .rcParams ['figure.figsize' ] = (8 , 6 )
53
+ mpl .rcParams ['figure.autolayout' ] = True
54
+ mpl .rcParams ['axes.grid' ] = True
55
+ mpl .rcParams ['savefig.bbox' ] = 'tight'
56
+ mpl .rcParams ['savefig.transparent' ] = True
57
+
58
+ SMALL_SIZE = 14
59
+ MEDIUM_SIZE = 18 #default 10
60
+ LARGE_SIZE = 24
61
+ # MARKER_SIZE = 10
62
+
63
+ plt .rc ('font' , size = MEDIUM_SIZE ) # controls default text sizes
64
+ plt .rc ('axes' , titlesize = LARGE_SIZE + 2 ) # fontsize of the axes title
65
+ plt .rc ('axes' , labelsize = LARGE_SIZE ) # fontsize of the x and y labels
66
+ plt .rc ('xtick' , labelsize = LARGE_SIZE ) # fontsize of the tick labels
67
+ plt .rc ('ytick' , labelsize = LARGE_SIZE ) # fontsize of the tick labels
68
+ plt .rc ('legend' , fontsize = MEDIUM_SIZE - 2 ) # legend fontsize
69
+ plt .rc ('figure' , titlesize = LARGE_SIZE ) # fontsize of the figure title
70
+
71
+ # def data_plot(x, y, marker, label, alpha=1, linewidth=1, loglog=True, markeredgecolor='black'):
72
+ # if loglog:
73
+ # plt.loglog(x, y, marker, label=label, linewidth=linewidth, markeredgecolor=markeredgecolor, markeredgewidth=0.5, alpha=alpha)
74
+ # else:
75
+ # plt.plot(x, y, marker, label=label, linewidth=linewidth, markeredgecolor=markeredgecolor, markeredgewidth=0.5, alpha=alpha)
76
+
77
+ from scipy .optimize import curve_fit
78
+ from math import ceil , floor , log , exp
79
+
80
+ def linear_loglog_fit (x , y , verbose = False ):
81
+ # Define the linear function
82
+ def linear_func (x , a , b ):
83
+ return a * x + b
84
+
85
+ log_x = np .array ([log (n ) for n in x ])
86
+ log_y = np .array ([log (cost ) for cost in y ])
87
+ # Fit the linear function to the data
88
+ params , covariance = curve_fit (linear_func , log_x , log_y )
89
+ # Extract the parameters
90
+ a , b = params
91
+ # Predict y values
92
+ y_pred = linear_func (log_x , a , b )
93
+ # Print the parameters
94
+ if verbose : print ('Slope (a):' , a , '; Intercept (b):' , b )
95
+ exp_y_pred = [exp (cost ) for cost in y_pred ]
96
+
97
+ return exp_y_pred , a , b
98
+
99
+ def plot_fit (ax , x , y , var = 't' , x_offset = 1.07 , y_offset = 1.0 , label = '' , ext_x = [], linestyle = 'k--' , linewidth = 1.5 , fontsize = MEDIUM_SIZE , verbose = True ):
100
+ y_pred_em , a_em , b_em = linear_loglog_fit (x , y )
101
+ if verbose : print (f'a_em: { a_em } ; b_em: { b_em } ' )
102
+ if abs (a_em ) < 1e-3 :
103
+ text_a_em = "{:.2f}" .format (round (abs (a_em ), 4 ))
104
+ else :
105
+ text_a_em = "{:.2f}" .format (round (a_em , 4 ))
106
+
107
+ if ext_x != []: x = ext_x
108
+ y_pred_em = [exp (cost ) for cost in a_em * np .array ([log (n ) for n in x ]) + b_em ]
109
+ if label == '' :
110
+ ax .plot (x , y_pred_em , linestyle , linewidth = linewidth )
111
+ else :
112
+ ax .plot (x , y_pred_em , linestyle , linewidth = linewidth , label = label )
113
+ ax .annotate (r'$O(%s^{%s})$' % (var , text_a_em ), xy = (x [- 1 ], np .real (y_pred_em )[- 1 ]), xytext = (x [- 1 ]* x_offset , np .real (y_pred_em )[- 1 ]* y_offset ), fontsize = fontsize )
114
+
115
+ return a_em , b_em
116
+
117
+ def ax_set_text (ax , x_label , y_label , title = None , legend = 'best' , xticks = None , yticks = None , grid = None , log = '' , ylim = None ):
118
+ ax .set_xlabel (x_label )
119
+ ax .set_ylabel (y_label )
120
+ if title : ax .set_title (title )
121
+ if legend : ax .legend (loc = legend )
122
+
123
+ if log == 'x' :
124
+ ax .set_xscale ('log' )
125
+ elif log == 'y' :
126
+ ax .set_yscale ('log' )
127
+ elif log == 'xy' :
128
+ # ax.set_xscale('log')
129
+ # ax.set_yscale('log')
130
+ ax .loglog ()
131
+ else :
132
+ pass
133
+
134
+ if grid : ax .grid ()
135
+ if ylim : ax .set_ylim ([ylim [0 ]* 0.85 , ylim [1 ]* 1.15 ])
136
+
137
+ if xticks is not None :
138
+ ax .set_xticks (xticks )
139
+ ax .get_xaxis ().set_major_formatter (mpl .ticker .ScalarFormatter ())
140
+
141
+ if yticks is not None :
142
+ ax .set_yticks (yticks )
143
+ ax .get_yaxis ().set_major_formatter (mpl .ticker .ScalarFormatter ())
144
+
145
+ def matrix_plot (M ):
146
+ fig , ax = plt .subplots ()
147
+ real_matrix = np .real (M )
148
+ # Plot the real part using a colormap
149
+ ax .imshow (real_matrix , cmap = 'RdYlBu' , interpolation = 'nearest' , origin = 'upper' )
150
+ # Create grid lines
151
+ ax .grid (True , which = 'both' , color = 'black' , linewidth = 1 )
152
+ # Add color bar for reference
153
+ cbar = plt .colorbar (ax .imshow (real_matrix , cmap = 'RdYlBu' , interpolation = 'nearest' , origin = 'upper' ), ax = ax , orientation = 'vertical' )
154
+ cbar .set_label ('Real Part' )
155
+ # Add labels to the x and y axes
156
+ plt .xlabel ('X-axis' )
157
+ plt .ylabel ('Y-axis' )
158
+ # Show the plot
159
+ plt .title ('Complex Matrix with Grid' )
160
+ plt .show ()
0 commit comments