Commit
Merge pull request #1 from databricks-industry-solutions/feature-simplify-pandas-and-dlt

Feature simplify pandas and dlt
zavoraad authored Jan 9, 2024
2 parents 5a16aaf + 7e12cc9 commit 8086e14
Showing 6 changed files with 12,039 additions and 77 deletions.
54 changes: 22 additions & 32 deletions 01_data_preparation.py
@@ -1,44 +1,34 @@
 # Databricks notebook source
-import pyspark.pandas as ps
-import pyspark.sql.utils
-import pandas as pd
-import re
-import dlt
+import os
 from pyspark.sql.functions import *
-ps.set_option('compute.ops_on_diff_frames', True)
+#
+# load handwritten notes from csv
+#
+def load_data(path="/data/12k_handwritten_clinical_notes.csv"):
+    return (spark.read.format("csv")
+            .option("header",True)
+            .load("file:///" + os.getcwd() + path)
+            )
 
 # COMMAND ----------
 
-@dlt.table
-def load_data():
-    data = pd.read_csv(
-        "//Volumes/ang_nara_catalog/rad_llm/clinical_data/12k_handwritten_clinical_notes.csv"
-    )
-    data = data.drop(['Unnamed: 0'], axis=1)
-    return spark.createDataFrame(data)
-
-
-# COMMAND ----------
-
-@dlt.table
-def remove_label_counts_less_than_50():
-    df = dlt.read('load_data')
-    df.write.format("delta").mode("overwrite").option("overwriteSchema",True).saveAsTable("ang_nara_catalog.rad_llm.delta_rad")
-    df = spark.sql("""
-    SELECT t.input, t.radiology_labels
+#
+# Only use where there exists 50 or more labels
+#
+def filtered_table():
+    df = load_data()
+    df.createOrReplaceTempView("radiology_data")
+    return spark.sql("""
+    SELECT t.input, t.radiology_labels
     FROM (
       SELECT t.*, COUNT(*) OVER (PARTITION BY radiology_labels) AS cnt
-      FROM ang_nara_catalog.rad_llm.delta_rad t
+      FROM radiology_data t
     ) t
     WHERE cnt > 50
-    """)
-    return df
+    """).withColumn("instruction", lit('predict radiology label for the clinical notes'))
 
 # COMMAND ----------
 
-@dlt.table
-def filtered_table():
-    df = dlt.read('remove_label_counts_less_than_50')
-    df = df.withColumn("instruction", lit('predict radiology label for the clinical notes'))
-    df.write.format("delta").mode("overwrite").option("overwriteSchema",True).saveAsTable("ang_nara_catalog.rad_llm.delta_rad_filtered")
-    return df
+data = filtered_table()
+#TODO save to a volume
+data.show()
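
Note on the "#TODO save to a volume" left in the new code: one way to finish that step is to write the filtered DataFrame (input, radiology_labels, instruction) out to a Unity Catalog volume path. This is a sketch only, not part of this commit; the catalog, schema, and volume names are placeholders.

# Sketch: persist the filtered notes to a Unity Catalog volume (placeholder path).
data = filtered_table()
output_path = "/Volumes/<catalog>/<schema>/<volume>/filtered_clinical_notes"  # placeholder
data.write.mode("overwrite").parquet(output_path)
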
43 changes: 0 additions & 43 deletions 02_create_monitor.py

This file was deleted.

15 changes: 13 additions & 2 deletions 03_train_llm.py
@@ -1,16 +1,26 @@
 # Databricks notebook source
+# DBTITLE 1,Setup Notebook Parameters
+dbutils.widgets.text("catalog", "hls_healthcare") #catalog, default value hls_healthcare
+dbutils.widgets.text("volume", "hls_dev.radiology_llm") #volume name, default value hls_dev.radiology_llm
+
+# COMMAND ----------
+
+#TODO require cluster type === ???
+
+# COMMAND ----------
+
 #install libraries
 !pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7
 
 # COMMAND ----------
 
 # MAGIC %sql
-# MAGIC USE CATALOG ang_nara_catalog
+# MAGIC USE CATALOG ${catalog}
 
 # COMMAND ----------
 
 # MAGIC %sql
-# MAGIC CREATE VOLUME IF NOT EXISTS ang_nara_catalog.rad_llm.results
+# MAGIC CREATE VOLUME IF NOT EXISTS ${volume}
 
 # COMMAND ----------
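
A quick note on the new widgets: dbutils.widgets.text defines notebook parameters whose values the %sql cells pick up via ${catalog} and ${volume}. If the same values are needed from Python, they can be read back with dbutils.widgets.get. A minimal sketch, with illustrative variable names only:

# Sketch: read the notebook parameters defined above back into Python.
catalog = dbutils.widgets.get("catalog")   # e.g. "hls_healthcare"
volume = dbutils.widgets.get("volume")     # e.g. "hls_dev.radiology_llm"
spark.sql(f"USE CATALOG {catalog}")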

@@ -111,6 +121,7 @@ def load_model(model_name):

 # COMMAND ----------
 
+#TODO replace with secrets and require
 !huggingface-cli login --token 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
 
 # COMMAND ----------
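
On the "#TODO replace with secrets and require": the hard-coded Hugging Face token could be read from a Databricks secret scope instead. A sketch only, assuming a scope and key that would have to be created separately; the same approach applies to the login cell in 04_model_prediction_premlflow.py below.

# Sketch: fetch the Hugging Face token from a secret scope (scope/key names are assumptions).
from huggingface_hub import login
hf_token = dbutils.secrets.get(scope="hf-creds", key="token")
login(token=hf_token)
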
2 changes: 2 additions & 0 deletions 04_model_prediction_premlflow.py
@@ -22,6 +22,7 @@

 # COMMAND ----------
 
+#TODO replace with secrets and require
 !huggingface-cli login --token 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
 
 # COMMAND ----------
@@ -96,6 +97,7 @@ def pred_wrapper(model, tokenizer, prompt, model_id=1, show_metrics=True, temp=0

 # COMMAND ----------
 
+#TODO replace as a SQL UDF
 prediction = list()
 for index, row in df_test.iterrows():
     prompt = row['clinical_notes']
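
On the "#TODO replace as a SQL UDF": instead of looping over df_test.iterrows(), the prediction call could be wrapped in a Spark UDF and applied column-wise (or registered for SQL). A rough sketch only; it assumes the model, tokenizer, and pred_wrapper objects from earlier cells, that pred_wrapper returns just the label when show_metrics=False, and the function name here is illustrative.

# Sketch: expose the prediction as a UDF so it can be applied per row or called from SQL.
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

@udf(returnType=StringType())
def predict_label(clinical_note):
    # placeholder: delegate to the notebook's pred_wrapper / fine-tuned model
    return pred_wrapper(model, tokenizer, clinical_note, show_metrics=False)

spark.udf.register("predict_label", predict_label)  # optional: makes it callable from SQL
predictions_df = spark.createDataFrame(df_test).withColumn(
    "predicted_label", predict_label("clinical_notes")
)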
