Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 24, 2024
1 parent 990fc94 commit 5bb5133
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 189 deletions.
13 changes: 7 additions & 6 deletions folktexts/col_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,15 @@ def __init__(
# > infer `question` from value map (if possible)
elif (
self._value_map is not None
and isinstance(self._value_map, dict)
and self._question is None
):
self._question = MultipleChoiceQA.create_question_from_value_map(
column=self.name,
value_map=self._value_map,
attribute=self.short_description,
)
if isinstance(self._value_map, dict):
logging.debug(f"Inferring multiple-choice question for column '{self.name}'.")
self._question = MultipleChoiceQA.create_question_from_value_map(
column=self.name,
value_map=self._value_map,
attribute=self.short_description,
)

# Else, warn if both were provided (as they may use inconsistent value maps)
elif self._value_map is not None and self._question is not None:
Expand Down
196 changes: 13 additions & 183 deletions notebooks/run-benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Got both `value_map` and `question` for column 'PUBCOV'. Please make sure value mappings are consistent.\n",
"INFO:root:Got both `value_map` and `question` for column 'PUBCOV==1'. Please make sure value mappings are consistent.\n",
"INFO:root:Loading model '/Users/acruz/huggingface-models/google--gemma-2b'\n",
"Gemma's activation function should be approximate GeLU and not exact GeLU.\n",
"Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
Expand All @@ -144,7 +142,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "777772d26ec34d19b9885eadd730d46e",
"model_id": "8808498339a344dbb460f18b07b72ca7",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -254,8 +252,8 @@
"output_type": "stream",
"text": [
"Loading ACS data...\n",
"CPU times: user 44.3 s, sys: 14.5 s, total: 58.8 s\n",
"Wall time: 1min 4s\n"
"CPU times: user 43.2 s, sys: 10.1 s, total: 53.4 s\n",
"Wall time: 53.2 s\n"
]
}
],
Expand Down Expand Up @@ -346,155 +344,6 @@
"Print a few example prompts:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7d6e30b6-7e47-4211-8d96-af1d1b0f50f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ACSTaskMetadata(name='ACSIncome', description=\"predict whether an individual's income is above $50,000\", features=['AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX', 'RAC1P'], target='PINCP', cols_to_text={'AGEP': <folktexts.col_to_text.ColumnToText object at 0x14a6f2910>, 'COW': <folktexts.col_to_text.ColumnToText object at 0x14a6f2790>, 'SCHL': <folktexts.col_to_text.ColumnToText object at 0x14a6f3810>, 'MAR': <folktexts.col_to_text.ColumnToText object at 0x14a70ca90>, 'OCCP': <folktexts.col_to_text.ColumnToText object at 0x14a70cdd0>, 'POBP': <folktexts.col_to_text.ColumnToText object at 0x14a70ce10>, 'RELP': <folktexts.col_to_text.ColumnToText object at 0x14a70ce50>, 'WKHP': <folktexts.col_to_text.ColumnToText object at 0x14a70d9d0>, 'SEX': <folktexts.col_to_text.ColumnToText object at 0x14a70da50>, 'RAC1P': <folktexts.col_to_text.ColumnToText object at 0x14a70de50>, 'PINCP': <folktexts.col_to_text.ColumnToText object at 0x14a70e490>, 'PINCP>50000': <folktexts.col_to_text.ColumnToText object at 0x14a70e690>, 'PUBCOV': <folktexts.col_to_text.ColumnToText object at 0x14a70e790>, 'PUBCOV==1': <folktexts.col_to_text.ColumnToText object at 0x14a70ea10>, 'DIS': <folktexts.col_to_text.ColumnToText object at 0x14a70eb50>, 'ESP': <folktexts.col_to_text.ColumnToText object at 0x14a70f150>, 'CIT': <folktexts.col_to_text.ColumnToText object at 0x14a70fcd0>, 'MIG': <folktexts.col_to_text.ColumnToText object at 0x14a71d550>, 'MIG==1': <folktexts.col_to_text.ColumnToText object at 0x14a71d8d0>, 'MIL': <folktexts.col_to_text.ColumnToText object at 0x14a71d910>, 'ANC': <folktexts.col_to_text.ColumnToText object at 0x14a71db90>, 'NATIVITY': <folktexts.col_to_text.ColumnToText object at 0x14a71de10>, 'DEAR': <folktexts.col_to_text.ColumnToText object at 0x14a71dfd0>, 'DEYE': <folktexts.col_to_text.ColumnToText object at 0x14a71e190>, 'DREM': <folktexts.col_to_text.ColumnToText object at 0x14a71e350>, 'ESR': <folktexts.col_to_text.ColumnToText object at 0x14a71e510>, 'ESR==1': <folktexts.col_to_text.ColumnToText object at 0x14a71ea10>, 'ST': <folktexts.col_to_text.ColumnToText object at 0x14a71ea50>, 'FER': <folktexts.col_to_text.ColumnToText object at 0x14a71ead0>, 'JWMNP': <folktexts.col_to_text.ColumnToText object at 0x14a71eb10>, 'JWMNP>20': <folktexts.col_to_text.ColumnToText object at 0x14a71eed0>, 'JWTR': <folktexts.col_to_text.ColumnToText object at 0x14a71ef10>, 'POVPIP': <folktexts.col_to_text.ColumnToText object at 0x14a71f650>, 'GCL': <folktexts.col_to_text.ColumnToText object at 0x14a71f690>, 'PUMA': <folktexts.col_to_text.ColumnToText object at 0x14a71f6d0>, 'POWPUMA': <folktexts.col_to_text.ColumnToText object at 0x14a71f910>, 'HINS2': <folktexts.col_to_text.ColumnToText object at 0x14a71f950>, 'HINS2==1': <folktexts.col_to_text.ColumnToText object at 0x14a71fd10>}, sensitive_attribute='RAC1P', target_threshold=Threshold(value=50000, op='>'), folktables_obj=<folktables.folktables.BasicProblem object at 0x14a6e2e50>)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset.task"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8ca8457a-bde4-4587-a896-f504c517774d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>AGEP</th>\n",
" <th>COW</th>\n",
" <th>SCHL</th>\n",
" <th>MAR</th>\n",
" <th>OCCP</th>\n",
" <th>POBP</th>\n",
" <th>RELP</th>\n",
" <th>WKHP</th>\n",
" <th>SEX</th>\n",
" <th>RAC1P</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1815911</th>\n",
" <td>31</td>\n",
" <td>1.0</td>\n",
" <td>16.0</td>\n",
" <td>1</td>\n",
" <td>4700.0</td>\n",
" <td>49</td>\n",
" <td>0</td>\n",
" <td>45.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1699859</th>\n",
" <td>42</td>\n",
" <td>1.0</td>\n",
" <td>9.0</td>\n",
" <td>5</td>\n",
" <td>4251.0</td>\n",
" <td>303</td>\n",
" <td>0</td>\n",
" <td>40.0</td>\n",
" <td>1</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1658375</th>\n",
" <td>61</td>\n",
" <td>1.0</td>\n",
" <td>19.0</td>\n",
" <td>3</td>\n",
" <td>8130.0</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>40.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2347220</th>\n",
" <td>47</td>\n",
" <td>1.0</td>\n",
" <td>19.0</td>\n",
" <td>5</td>\n",
" <td>8740.0</td>\n",
" <td>40</td>\n",
" <td>0</td>\n",
" <td>50.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2260528</th>\n",
" <td>38</td>\n",
" <td>1.0</td>\n",
" <td>18.0</td>\n",
" <td>1</td>\n",
" <td>8130.0</td>\n",
" <td>39</td>\n",
" <td>1</td>\n",
" <td>50.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" AGEP COW SCHL MAR OCCP POBP RELP WKHP SEX RAC1P\n",
"1815911 31 1.0 16.0 1 4700.0 49 0 45.0 1 1\n",
"1699859 42 1.0 9.0 5 4251.0 303 0 40.0 1 8\n",
"1658375 61 1.0 19.0 3 8130.0 29 0 40.0 1 1\n",
"2347220 47 1.0 19.0 5 8740.0 40 0 50.0 1 1\n",
"2260528 38 1.0 18.0 1 8130.0 39 1 50.0 1 1"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_sample, _y_sample = dataset.sample_n_train_examples(n=5)\n",
"X_sample"
]
},
{
"cell_type": "code",
"execution_count": 16,
Expand Down Expand Up @@ -636,7 +485,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "92df961a8e4d4a7f87c2df043d077f29",
"model_id": "da09cf7633c34ed589b80bce7273facb",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -658,8 +507,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 338 ms, sys: 1.99 s, total: 2.33 s\n",
"Wall time: 9.68 s\n"
"CPU times: user 419 ms, sys: 276 ms, total: 695 ms\n",
"Wall time: 7.77 s\n"
]
},
{
Expand Down Expand Up @@ -1069,20 +918,20 @@
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LLMClassifier(encode_row=functools.partial(&lt;function encode_row_prompt at 0x30c044720&gt;, task=ACSTaskMetadata(name=&#x27;ACSIncome&#x27;, description=&quot;predict whether an individual&#x27;s income is above $50,000&quot;, features=[&#x27;AGEP&#x27;, &#x27;COW&#x27;, &#x27;SCHL&#x27;, &#x27;MAR&#x27;, &#x27;OCCP&#x27;, &#x27;POBP&#x27;, &#x27;RELP&#x27;, &#x27;WKHP&#x27;, &#x27;SEX&#x27;, &#x27;RAC1P&#x27;], target=&#x27;PINCP&#x27;, cols_to_text={&#x27;AGEP&#x27;: &lt;folktexts.col_to_text.ColumnToText object at 0x14a6f...\n",
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LLMClassifier(encode_row=functools.partial(&lt;function encode_row_prompt at 0x307164720&gt;, task=ACSTaskMetadata(name=&#x27;ACSIncome&#x27;, description=&quot;predict whether an individual&#x27;s income is above $50,000&quot;, features=[&#x27;AGEP&#x27;, &#x27;COW&#x27;, &#x27;SCHL&#x27;, &#x27;MAR&#x27;, &#x27;OCCP&#x27;, &#x27;POBP&#x27;, &#x27;RELP&#x27;, &#x27;WKHP&#x27;, &#x27;SEX&#x27;, &#x27;RAC1P&#x27;], target=&#x27;PINCP&#x27;, cols_to_text={&#x27;AGEP&#x27;: &lt;folktexts.col_to_text.ColumnToText object at 0x151a0...\n",
"\t213: AddedToken(&quot;&lt;/s&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t214: AddedToken(&quot;&lt;/sub&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t215: AddedToken(&quot;&lt;/sup&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t216: AddedToken(&quot;&lt;/code&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"})</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;LLMClassifier<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>LLMClassifier(encode_row=functools.partial(&lt;function encode_row_prompt at 0x30c044720&gt;, task=ACSTaskMetadata(name=&#x27;ACSIncome&#x27;, description=&quot;predict whether an individual&#x27;s income is above $50,000&quot;, features=[&#x27;AGEP&#x27;, &#x27;COW&#x27;, &#x27;SCHL&#x27;, &#x27;MAR&#x27;, &#x27;OCCP&#x27;, &#x27;POBP&#x27;, &#x27;RELP&#x27;, &#x27;WKHP&#x27;, &#x27;SEX&#x27;, &#x27;RAC1P&#x27;], target=&#x27;PINCP&#x27;, cols_to_text={&#x27;AGEP&#x27;: &lt;folktexts.col_to_text.ColumnToText object at 0x14a6f...\n",
"})</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;LLMClassifier<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>LLMClassifier(encode_row=functools.partial(&lt;function encode_row_prompt at 0x307164720&gt;, task=ACSTaskMetadata(name=&#x27;ACSIncome&#x27;, description=&quot;predict whether an individual&#x27;s income is above $50,000&quot;, features=[&#x27;AGEP&#x27;, &#x27;COW&#x27;, &#x27;SCHL&#x27;, &#x27;MAR&#x27;, &#x27;OCCP&#x27;, &#x27;POBP&#x27;, &#x27;RELP&#x27;, &#x27;WKHP&#x27;, &#x27;SEX&#x27;, &#x27;RAC1P&#x27;], target=&#x27;PINCP&#x27;, cols_to_text={&#x27;AGEP&#x27;: &lt;folktexts.col_to_text.ColumnToText object at 0x151a0...\n",
"\t213: AddedToken(&quot;&lt;/s&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t214: AddedToken(&quot;&lt;/sub&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t215: AddedToken(&quot;&lt;/sup&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t216: AddedToken(&quot;&lt;/code&gt;&quot;, rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"})</pre></div> </div></div></div></div>"
],
"text/plain": [
"LLMClassifier(encode_row=functools.partial(<function encode_row_prompt at 0x30c044720>, task=ACSTaskMetadata(name='ACSIncome', description=\"predict whether an individual's income is above $50,000\", features=['AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX', 'RAC1P'], target='PINCP', cols_to_text={'AGEP': <folktexts.col_to_text.ColumnToText object at 0x14a6f...\n",
"LLMClassifier(encode_row=functools.partial(<function encode_row_prompt at 0x307164720>, task=ACSTaskMetadata(name='ACSIncome', description=\"predict whether an individual's income is above $50,000\", features=['AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX', 'RAC1P'], target='PINCP', cols_to_text={'AGEP': <folktexts.col_to_text.ColumnToText object at 0x151a0...\n",
"\t213: AddedToken(\"</s>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t214: AddedToken(\"</sub>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t215: AddedToken(\"</sup>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
Expand Down Expand Up @@ -1119,27 +968,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Test data features shape: (166, 10)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f918dd9254ed4ff5af4f055bb1a9fdba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Computing risk estimates: 0%| | 0/6 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Test data features shape: (166, 10)\n",
"INFO:root:Loaded predictions from /Users/acruz/folktexts-results/google--gemma-2b/google--gemma-2b_bench-3421378798/ACSIncome_subsampled-0.001_seed-42_hash-3607287350.test_predictions.csv.\n",
"INFO:root:\n",
"** Test results **\n",
"Model: google--gemma-2b;\n",
Expand All @@ -1158,8 +988,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.9 s, sys: 39.6 s, total: 44.5 s\n",
"Wall time: 4min 3s\n"
"CPU times: user 1.44 s, sys: 4.44 s, total: 5.88 s\n",
"Wall time: 860 ms\n"
]
},
{
Expand Down

0 comments on commit 5bb5133

Please sign in to comment.