Skip to content

Commit f775cdb

Browse files
committed
Improve documentation and add comments for scripts.
1 parent 5cb8eee commit f775cdb

File tree

5 files changed

+144
-114
lines changed

5 files changed

+144
-114
lines changed

scripts/eval_causal_support.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,32 @@
88
import pandas as pd
99
import numpy as np
1010

11-
import knowledge_tracing.utils.visualize as visualize
12-
import knowledge_tracing.utils.utils as utils
13-
14-
import matplotlib.pyplot as plt
15-
1611

1712
def parse_args(parser):
    """Register the command-line options of the causal-support evaluation.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object) to
            which the options are added.

    Returns:
        The same parser instance, with the options registered.
    """
    parser.add_argument(
        "--base_pth",
        type=str,
        default="../data",
        help="path to the directory containing the dataset",
    )
    parser.add_argument("--dataset", type=str, help="Name of dataset")
    parser.add_argument(
        "--gap",
        type=int,
        default=1,
        help="the order of transition in evaluation causal support",
    )
    parser.add_argument(
        "--test",
        type=int,
        default=1,
        help="whether using test data in evaluation causal support",
    )
    parser.add_argument(
        "--num_sample",
        type=int,
        # argparse applies ``type`` only to string defaults, so a float
        # literal such as 1e6 would reach the caller as a float and break
        # code that uses the value as an array dimension. Keep it an int.
        default=1_000_000,
        help="number of samples in estimating causal support",
    )

    return parser
@@ -48,31 +45,47 @@ def parse_args(parser):
4845
parser = parse_args(parser)
4946
args, extras = parser.parse_known_args()
5047

51-
# assistment12
48+
# Load interaction data and skill corpus from files
49+
# Read interactions CSV file
5250
inter = pd.read_csv(
5351
f"{args.base_pth}/{args.dataset}/multi_skill/interactions.csv", sep="\t"
5452
)
53+
# Load corpus object from pickle file
5554
with open(f"{args.base_pth}/{args.dataset}/multi_skill/Corpus.pkl", "rb") as f:
5655
corpus = pickle.load(f)
5756

57+
# Extract unique skill IDs and count the number of nodes (skills)
5858
skill_id = list(inter.skill_id.unique())
5959
num_node = len(skill_id)
6060

61+
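# Build the list of skill texts, one per skill; NOTE(review): the lookup below uses the loop position i, not skill_id[i], so it assumes skill IDs are exactly 0..num_node-1 — confirm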
6162
skill_list = []
6263
for i in range(len(skill_id)):
6364
text = list(inter.loc[inter["skill_id"] == i]["skill_text"])[0]
6465
skill_list.append(text)
6566

66-
# ----- Calculate transition matrix -----
67+
# Calculate transition matrix for skill sequences
68+
# The gap and start variables define the range of transitions to consider
6769
gap = args.gap
68-
start = 10 if args.test else 0
69-
T = np.zeros((num_node, num_node, 4)) # 0-1, 0-0, 1-1, 1-0
70-
N = np.zeros((num_node, num_node))
70+
start = (
71+
10 if args.test else 0
72+
) # Start index for considering transitions, based on whether it's a test run
73+
T = np.zeros(
74+
(num_node, num_node, 4)
75+
) # Transition counts for each pair of skills and outcomes (0-1, 0-0, 1-1, 1-0)
76+
N = np.zeros(
77+
(num_node, num_node)
78+
) # Total transition counts between each pair of skills
79+
80+
# Iterate through each user sequence in the corpus
7181
for l in range(len(corpus.user_seq_df)):
7282
correct = corpus.user_seq_df["correct_seq"][l]
7383
index = corpus.user_seq_df["skill_seq"][l]
7484

75-
for i in range(start, start + 10 - gap):
85+
# Count transitions and outcomes for each sequence, considering the defined gap
86+
for i in range(
87+
start, start + 10 - gap
88+
): # Slide over the 10-step window starting at `start`; only transitions between different skills are counted (checked on the next line)
7689
if index[i + gap] != index[i]:
7790
if correct[i] == 0:
7891
if correct[i + gap] == 1:
@@ -84,25 +97,29 @@ def parse_args(parser):
8497
T[index[i], index[i + gap], 2] += 1
8598
else:
8699
T[index[i], index[i + gap], 3] += 1
87-
N[index[i], index[i + gap]] += 1
100+
N[index[i], index[i + gap]] += 1 # Increment total transition count
101+
102+
# Fraction of transitions starting from a correct answer that also end in a correct answer: 1-1 / (1-1 + 1-0); the 1e-6 avoids division by zero
88103
success_transition = abs(T[..., 2]) / (T[..., 2] + T[..., 3] + 1e-6)
104+
# Create a mask to filter transitions with sufficient data
89105
mask = T[..., 2] + T[..., 3] + T[..., 0] + T[..., 1] > 1
90106

107+
# Marginal counts for causal support: Nc_* count by the first answer, Ne_* by the second answer (minus = incorrect, plus = correct)
91108
Nc_minus = T[..., 0] + T[..., 1]
92109
Nc_plus = T[..., 2] + T[..., 3]
93110
Ne_minus = T[..., 1] + T[..., 3]
94111
Ne_plus = T[..., 0] + T[..., 2]
95112

96-
# ----- Compute causal support -----
97-
# P(D|G0)
113+
# Compute causal support for the transitions
114+
# Probability of data given no causal relationship (P(D|G0))
98115
num_sample = args.num_sample
99116
w0 = np.arange(0, num_sample, 1) / num_sample
100117
w0 = w0.reshape(num_sample, 1, 1).repeat(num_node, 1).repeat(num_node, -1)
101118
p0 = np.power(w0, np.expand_dims(Ne_plus, 0).repeat(num_sample, 0)) * np.power(
102119
1 - w0, np.expand_dims(Ne_minus, 0).repeat(num_sample, 0)
103120
)
104121

105-
# P(D|G1)
122+
# Probability of data given a causal relationship (P(D|G1))
106123
w0 = np.arange(0, num_sample, 1) / num_sample
107124
w0 = w0.reshape(num_sample, 1, 1).repeat(num_node, 1).repeat(num_node, -1)
108125
w0 = w0.repeat(num_sample, 0)
@@ -119,5 +136,5 @@ def parse_args(parser):
119136

120137
p1 = np.multiply(p_e1_c1, p_e1_c0)
121138

122-
# Support
139+
# Causal support: log of the mean likelihood under G1 minus log of the mean likelihood under G0 (1e-6 guards against log(0))
123140
support = np.log(p1.mean(0) + 1e-6) - np.log(p0.mean(0) + 1e-6)

scripts/generate_synthetic_data.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@
1818
def parse_args(parser):
1919
# ----- global -----
2020
parser.add_argument(
21-
"--random_seed", type=int, default=1, help="Random seed for reproducibility"
21+
"--random_seed", type=int, default=1, help="random seed for reproducibility"
2222
)
2323
parser.add_argument(
24-
"--num_sequence", type=int, default=1, help="Number of sequences to generate"
24+
"--num_sequence", type=int, default=1, help="number of sequences to generate"
2525
)
2626
parser.add_argument(
2727
"--learner_model",
2828
type=str,
2929
default="graph_ou",
3030
choices=["hlr", "ppe", "ou", "graph_ou"],
31-
help="Type of learner model: hlr, ou, graph_ou, egraph_ou, ppe",
31+
help="type of learner models: hlr, ou, graph_ou, egraph_ou, ppe",
3232
)
3333

3434
# ----- time points -----
@@ -37,71 +37,71 @@ def parse_args(parser):
3737
type=str,
3838
default="random",
3939
choices=["random", "uniform"],
40-
help="Type of time distribution: random or uniform",
40+
help="type of time distributions: random or uniform",
4141
)
4242
parser.add_argument(
43-
"--time_step", type=int, default=20, help="Time step between points"
43+
"--time_step", type=int, default=20, help="time step between points"
4444
)
4545
parser.add_argument(
46-
"--max_time_step", type=int, default=250, help="Maximum time step"
46+
"--max_time_step", type=int, default=250, help="maximum time step"
4747
)
4848

4949
# ----- random graph -----
5050
parser.add_argument(
51-
"--num_node", type=int, default=2, help="Number of nodes in the random graph"
51+
"--num_node", type=int, default=10, help="number of nodes in the random graph"
5252
)
5353
parser.add_argument(
5454
"--edge_prob",
5555
type=float,
5656
default=0.4,
57-
help="Probability of an edge between nodes",
57+
help="probability of an edge between nodes",
5858
)
5959

6060
# ----- ou process -----
6161
parser.add_argument(
6262
"--mean_rev_speed",
6363
type=float,
6464
default=0.02,
65-
help="Mean reversion speed parameter",
65+
help="mean reversion speed parameter",
6666
)
6767
parser.add_argument(
6868
"--mean_rev_level",
6969
type=float,
7070
default=0.7,
71-
help="Mean reversion level parameter",
71+
help="mean reversion level parameter",
7272
)
73-
parser.add_argument("--vola", type=float, default=0.01, help="Volatility parameter")
74-
parser.add_argument("--rho", type=float, default=2, help="Rho parameter")
75-
parser.add_argument("--omega", type=float, default=0.75, help="Omega parameter")
73+
parser.add_argument("--vola", type=float, default=0.01, help="volatility parameter")
74+
parser.add_argument("--rho", type=float, default=2, help="rho parameter")
75+
parser.add_argument("--omega", type=float, default=0.75, help="omega parameter")
7676
parser.add_argument(
77-
"--gamma", type=float, default=[0.1, 0.2, 0.5, 0.75, 1], help="Gamma parameter"
77+
"--gamma", type=float, default=[0.1, 0.2, 0.5, 0.75, 1], help="gamma parameter"
7878
)
7979

8080
# ----- hlr process -----
8181
parser.add_argument(
8282
"--theta",
8383
type=list,
8484
default=[1 / 4, 1 / 2, -1 / 3],
85-
help="List of theta parameters",
85+
help="list of theta parameters",
8686
)
8787

8888
# ----- ppe process -----
8989
parser.add_argument(
9090
"--learning_rate",
9191
type=float,
9292
default=[0.01, 0.05, 0.1, 0.2, 0.5, 1],
93-
help="Learning rate for the PPE process",
93+
help="learning rate for the PPE process",
9494
)
9595
parser.add_argument(
96-
"--decay_rate", type=float, default=0.2, help="Decay rate for the PPE process"
96+
"--decay_rate", type=float, default=0.2, help="decay rate for the PPE process"
9797
)
9898

9999
# ----- save path -----
100100
parser.add_argument(
101101
"--save_path",
102102
type=str,
103103
default="..kt_data/synthetic",
104-
help="Path to save results",
104+
help="path to save results",
105105
)
106106

107107
return parser

scripts/predict_learner_performance_baseline.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def global_parse_args():
2929
parser = argparse.ArgumentParser(description="Global")
3030

3131
parser.add_argument(
32-
"--model_name", type=str, default="CausalKT", help="Choose a model to run."
32+
"--model_name", type=str, default="CausalKT", help="choose a model to run."
3333
)
3434

3535
return parser
@@ -52,7 +52,6 @@ def global_parse_args():
5252
# ----- args -----
5353
# reference:
5454
# # https://docs.python.org/3/library/argparse.html?highlight=parse_known_args#argparse.ArgumentParser.parse_known_args
55-
5655
global_args.model_name = model_name
5756
global_args.time = datetime.datetime.now().isoformat()
5857
global_args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

0 commit comments

Comments
 (0)