Skip to content

Commit

Permalink
📊 AI: AI epoch adding regressions code
Browse files Browse the repository at this point in the history
  • Loading branch information
veronikasamborska1994 committed Oct 1, 2024
1 parent 5112587 commit 6188c60
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def run_regression(tb):

metrics = ["training_computation_petaflop", "parameters", "training_dataset_size__datapoints"]
# metrics = ["training_computation_petaflop"]

new_columns = []
for metric in metrics:
# Filter out models without the metric information
tb_metric = tb[pd.notnull(tb[metric])]
Expand Down Expand Up @@ -118,13 +118,23 @@ def run_regression(tb):
dl_line = 10 ** (dl_fit[0] + dl_year_grid * dl_fit[1])

# Add the lines back into the table as separate columns
tb.loc[tb["frac_year"] < DL_ERA_START, f"pre_dl_line_{metric}"] = pre_dl_line.reindex(
pre_dl_col_name = f"pre_dl_line_{metric}"
dl_col_name = f"dl_line_{metric}"

tb.loc[tb["frac_year"] < DL_ERA_START, pre_dl_col_name] = pre_dl_line.reindex(
tb.index[tb["frac_year"] < DL_ERA_START]
).values
tb.loc[tb["frac_year"] >= DL_ERA_START, f"dl_line_{metric}"] = dl_line.reindex(
tb.loc[tb["frac_year"] >= DL_ERA_START, dl_col_name] = dl_line.reindex(
tb.index[tb["frac_year"] >= DL_ERA_START]
).values

# Append new column names to the list
new_columns.append(pre_dl_col_name)
new_columns.append(dl_col_name)
for column in new_columns:
# Copy the origins metadata from the "domain" column onto each new regression-line column
tb[column].metadata.origins = tb["domain"].metadata.origins

tb = tb.drop("frac_year", axis=1)

return tb
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ def run(dest_dir: str) -> None:
# Load inputs.
#
# Load garden dataset.
ds_garden = paths.load_dataset("epoch")
ds_garden = paths.load_dataset("epoch_regressions")

# Read table from garden dataset.
tb = ds_garden["epoch"]
tb = tb.rename_index(columns={"system": "country", "days_since_1949": "year"})
tb = tb.rename_index_names({"system": "country", "days_since_1949": "year"})
#
# Save outputs.
#
Expand Down

0 comments on commit 6188c60

Please sign in to comment.