Skip to content

Commit

Permalink
linear_loglog_fit support different fitting func
Browse files Browse the repository at this point in the history
  • Loading branch information
Jue-Xu committed Aug 4, 2024
1 parent c4788de commit 7b75ace
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 13 deletions.
40 changes: 30 additions & 10 deletions quantum_simulation_recipe/plot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,35 +103,55 @@ def get_colors(self, c: str):
from scipy.optimize import curve_fit
from math import ceil, floor, log, exp

def linear_loglog_fit(x, y, verbose=False):
def linear_loglog_fit(x, y, log_axis='xy', verbose=False):
# Define the linear function
def linear_func(x, a, b):
return a * x + b

log_x = np.array([log(n) for n in x])
log_y = np.array([log(cost) for cost in y])
if log_axis == 'xy':
x = np.array([np.log(n) for n in x])
y = np.array([np.log(cost) for cost in y])
elif log_axis == 'y':
y = np.array([np.log(cost) for cost in y])
elif log_axis == 'x':
x = np.array([np.log(n) for n in x])
elif log_axis == '':
pass
else:
raise ValueError('Invalid log value')
# Fit the linear function to the data
params, covariance = curve_fit(linear_func, log_x, log_y)
params, covariance = curve_fit(linear_func, x, y)
# Extract the parameters
a, b = params
# Predict y values
y_pred = linear_func(log_x, a, b)
y_pred = linear_func(x, a, b)
# Print the parameters
if verbose: print('Slope (a):', a, '; Intercept (b):', b)
exp_y_pred = [exp(cost) for cost in y_pred]
if log == 'xy' or log == 'y':
y_pred = [exp(cost) for cost in y_pred]

return exp_y_pred, a, b
return y_pred, a, b

def plot_fit(ax, x, y, var='t', x_offset=1.07, y_offset=1.0, label='', ext_x=[], linestyle='k--', linewidth=WIDTH, fontsize=MEDIUM_SIZE, verbose=True):
y_pred_em, a_em, b_em = linear_loglog_fit(x, y)
def plot_fit(ax, x, y, var='t', log_axis='xy', x_offset=1.07, y_offset=1.0, label='', ext_x=[], linestyle='k--', linewidth=WIDTH, fontsize=MEDIUM_SIZE, verbose=True):
y_pred_em, a_em, b_em = linear_loglog_fit(x, y, log_axis=log_axis)
if verbose: print(f'a_em: {a_em}; b_em: {b_em}')
if abs(a_em) < 1e-3:
text_a_em = "{:.2f}".format(round(abs(a_em), 4))
else:
text_a_em = "{:.2f}".format(round(a_em, 4))

if ext_x != []: x = ext_x
y_pred_em = [exp(cost) for cost in a_em*np.array([log(n) for n in x]) + b_em]
if log_axis == 'xy':
y_pred_em = [np.exp(cost) for cost in a_em*np.array([np.log(n) for n in x]) + b_em]
elif log_axis == 'y':
y_pred_em = [np.exp(cost) for cost in a_em*np.array(x) + b_em]
elif log_axis == 'x':
y_pred_em = [cost for cost in a_em*np.array([np.log(n) for n in x]) + b_em]
elif log_axis == '':
y_pred_em = [cost for cost in a_em*np.array([n for n in x]) + b_em]
else:
raise ValueError('Invalid log value')

if label =='':
ax.plot(x, y_pred_em, linestyle, linewidth=linewidth)
else:
Expand Down
65 changes: 62 additions & 3 deletions quantum_simulation_recipe/test.ipynb

Large diffs are not rendered by default.

0 comments on commit 7b75ace

Please sign in to comment.