Commit d1af140

fix: use dask to speed up cxg conversion and run in constant memory (#7364)

Co-authored-by: nayib-jose-gloria <[email protected]>
Co-authored-by: Trent Smith
Bento007 and nayib-jose-gloria authored Jan 14, 2025
1 parent 921aa4a commit d1af140
Showing 27 changed files with 259 additions and 398 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rdev-tests.yml
@@ -234,6 +234,6 @@ jobs:
if: always()
with:
name: logged-in-test-results
path: frontend/playwright-report/
path: /home/runner/work/single-cell-data-portal/single-cell-data-portal/frontend/playwright-report
retention-days: 30
if-no-files-found: error
48 changes: 48 additions & 0 deletions .happy/terraform/modules/batch/main.tf
@@ -11,6 +11,54 @@ resource aws_batch_job_definition batch_job_def {
container_properties = jsonencode({
"jobRoleArn": "${var.batch_role_arn}",
"image": "${var.image}",
"memory": 8000,
"environment": [
{
"name": "ARTIFACT_BUCKET",
"value": "${var.artifact_bucket}"
},
{
"name": "CELLXGENE_BUCKET",
"value": "${var.cellxgene_bucket}"
},
{
"name": "DATASETS_BUCKET",
"value": "${var.datasets_bucket}"
},
{
"name": "DEPLOYMENT_STAGE",
"value": "${var.deployment_stage}"
},
{
"name": "AWS_DEFAULT_REGION",
"value": "${data.aws_region.current.name}"
},
{
"name": "REMOTE_DEV_PREFIX",
"value": "${var.remote_dev_prefix}"
},
{
"name": "FRONTEND_URL",
"value": "${var.frontend_url}"
}
],
"vcpus": 1,
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "${aws_cloudwatch_log_group.cloud_watch_logs_group.id}",
"awslogs-region": "${data.aws_region.current.name}"
}
}
})
}

resource aws_batch_job_definition cxg_job_def {
type = "container"
name = "dp-${var.deployment_stage}-${var.custom_stack_name}-convert"
container_properties = jsonencode({
"jobRoleArn": "${var.batch_role_arn}",
"image": "${var.image}",
"memory": 16000,
"environment": [
{
5 changes: 5 additions & 0 deletions .happy/terraform/modules/batch/outputs.tf
@@ -8,6 +8,11 @@ output batch_job_definition_no_revision {
description = "ARN for the batch job definition"
}

output cxg_job_definition_no_revision {
value = "arn:aws:batch:${data.aws_region.current.name}:${data.aws_caller_identity.current.account_id}:job-definition/${aws_batch_job_definition.cxg_job_def.name}"
description = "ARN for the cxg batch job definition"
}

output batch_job_log_group {
value = aws_cloudwatch_log_group.cloud_watch_logs_group.id
description = "Name of the CloudWatch log group for the batch job"
1 change: 1 addition & 0 deletions .happy/terraform/modules/ecs-stack/main.tf
@@ -386,6 +386,7 @@ module upload_error_lambda {
module upload_sfn {
source = "../sfn"
job_definition_arn = module.upload_batch.batch_job_definition_no_revision
cxg_definition_arn = module.upload_batch.cxg_job_definition_no_revision
job_queue_arn = local.job_queue_arn
role_arn = local.sfn_role_arn
custom_stack_name = local.custom_stack_name
73 changes: 10 additions & 63 deletions .happy/terraform/modules/schema_migration/main.tf
@@ -5,66 +5,13 @@ data aws_caller_identity current {}
locals {
name = "schema-migration"
job_definition_arn = "arn:aws:batch:${data.aws_region.current.name}:${data.aws_caller_identity.current.account_id}:job-definition/dp-${var.deployment_stage}-${var.custom_stack_name}-schema-migration"
swap_job_definition_arn = "${local.job_definition_arn}-swap"
}

resource aws_cloudwatch_log_group batch_cloud_watch_logs_group {
retention_in_days = 365
name = "/dp/${var.deployment_stage}/${var.custom_stack_name}/${local.name}-batch"
}

resource aws_batch_job_definition schema_migrations_swap {
type = "container"
name = "dp-${var.deployment_stage}-${var.custom_stack_name}-${local.name}-swap"
container_properties = jsonencode({
jobRoleArn= var.batch_role_arn,
image= var.image,
environment= [
{
name= "ARTIFACT_BUCKET",
value= var.artifact_bucket
},
{
name= "DEPLOYMENT_STAGE",
value= var.deployment_stage
},
{
name= "AWS_DEFAULT_REGION",
value= data.aws_region.current.name
},
{
name= "REMOTE_DEV_PREFIX",
value= var.remote_dev_prefix
},
{
name= "DATASETS_BUCKET",
value= var.datasets_bucket
},
],
resourceRequirements = [
{
type= "VCPU",
Value="32"
},
{
Type="MEMORY",
Value = "256000"
}
]
linuxParameters= {
maxSwap= 0,
swappiness= 60
},
logConfiguration= {
logDriver= "awslogs",
options= {
awslogs-group= aws_cloudwatch_log_group.batch_cloud_watch_logs_group.id,
awslogs-region= data.aws_region.current.name
}
}
})
}

resource aws_batch_job_definition schema_migrations {
type = "container"
name = "dp-${var.deployment_stage}-${var.custom_stack_name}-${local.name}"
@@ -94,14 +41,14 @@ resource aws_batch_job_definition schema_migrations {
},
],
resourceRequirements = [
{
type= "VCPU",
Value="2"
},
{
Type="MEMORY",
Value = "2048"
}
{
type= "VCPU",
Value="1"
},
{
Type="MEMORY",
Value = "8000"
}
]
logConfiguration= {
logDriver= "awslogs",
@@ -385,7 +332,7 @@ resource aws_sfn_state_machine sfn_schema_migration {
"Type": "Task",
"Resource": "arn:aws:states:::batch:submitJob.sync",
"Parameters": {
"JobDefinition": "${resource.aws_batch_job_definition.schema_migrations_swap.arn}",
"JobDefinition": "${resource.aws_batch_job_definition.schema_migrations.arn}",
"JobName": "dataset_migration",
"JobQueue": "${var.job_queue_arn}",
"Timeout": {
@@ -518,7 +465,7 @@ resource aws_sfn_state_machine sfn_schema_migration {
"Key.$": "$.key_name"
}
},
"MaxConcurrency": 10,
"MaxConcurrency": 30,
"Next": "report",
"Catch": [
{
17 changes: 9 additions & 8 deletions .happy/terraform/modules/sfn/main.tf
@@ -1,6 +1,7 @@
# This is used for environment (dev, staging, prod) deployments
locals {
timeout = 86400 # 24 hours
h5ad_timeout = 86400 # 24 hours
cxg_timeout = 172800 # 48 hours
}

data aws_region current {}
@@ -33,7 +34,7 @@ resource "aws_sfn_state_machine" "state_machine" {
"Validate": {
"Type": "Task",
"Resource": "arn:aws:states:::batch:submitJob.sync",
"Next": "CxgSeuratParallel",
"Next": "Cxg",
"Parameters": {
"JobDefinition":"${var.job_definition_arn}",
"JobName": "validate",
@@ -60,7 +61,7 @@ resource "aws_sfn_state_machine" "state_machine" {
}
},
"ResultPath": null,
"TimeoutSeconds": ${local.timeout},
"TimeoutSeconds": ${local.h5ad_timeout},
"Retry": [ {
"ErrorEquals": ["AWS.Batch.TooManyRequestsException", "Batch.BatchException", "Batch.AWSBatchException"],
"IntervalSeconds": 2,
@@ -82,7 +83,7 @@ resource "aws_sfn_state_machine" "state_machine" {
"Next": "HandleSuccess",
"Resource": "arn:aws:states:::batch:submitJob.sync",
"Parameters": {
"JobDefinition.$": "$.batch.JobDefinitionName",
"JobDefinition":"${var.cxg_definition_arn}",
"JobName": "cxg",
"JobQueue.$": "$.job_queue",
"ContainerOverrides": {
@@ -114,7 +115,7 @@ resource "aws_sfn_state_machine" "state_machine" {
}
],
"ResultPath": null,
"TimeoutSeconds": 360000
"TimeoutSeconds": ${local.cxg_timeout}
},
"CatchCxgFailure": {
"Type": "Pass",
@@ -182,7 +183,7 @@ resource "aws_sfn_state_machine" "state_machine" {
},
"ConversionError": {
"Type": "Fail",
"Cause": "CXG and/or Seurat conversion failed."
"Cause": "CXG conversion failed."
},
"DownloadValidateError": {
"Type": "Fail",
@@ -210,7 +211,7 @@ resource "aws_sfn_state_machine" "state_machine_cxg_remaster" {
"End": true,
"Resource": "arn:aws:states:::batch:submitJob.sync",
"Parameters": {
"JobDefinition": "${var.job_definition_arn}",
"JobDefinition": "${var.cxg_definition_arn}",
"JobName": "cxg_remaster",
"JobQueue": "${var.job_queue_arn}",
"ContainerOverrides": {
@@ -226,7 +227,7 @@ resource "aws_sfn_state_machine" "state_machine_cxg_remaster" {
]
}
},
"TimeoutSeconds": ${local.timeout}
"TimeoutSeconds": ${local.cxg_timeout}
}
}
}
5 changes: 5 additions & 0 deletions .happy/terraform/modules/sfn/variables.tf
@@ -3,6 +3,11 @@ variable job_definition_arn {
description = "ARN of the batch job definition"
}

variable cxg_definition_arn {
type = string
description = "ARN of the cxg batch job definition"
}

variable job_queue_arn {
type = string
description = "ARN of the batch job queue"
2 changes: 1 addition & 1 deletion backend/Makefile
@@ -62,7 +62,7 @@ endif

db/connect_internal:
$(eval DB_PW = $(shell aws secretsmanager get-secret-value --secret-id corpora/backend/${DEPLOYMENT_STAGE}/database --region us-west-2 | jq -r '.SecretString | match(":([^:]*)@").captures[0].string'))
PGOPTIONS='-csearch_path=persistence_schema' PGPASSWORD=${DB_PW} psql --dbname ${DB_NAME} --username ${DB_USER} --host 0.0.0.0 $(ARGS)
PGOPTIONS='-csearch_path=persistence_schema' PGPASSWORD=${DB_PW} psql --dbname ${DB_NAME} --username ${DB_USER} --host 0.0.0.0 --port 5433 $(ARGS)

db/console: db/connect # alias

31 changes: 19 additions & 12 deletions backend/layers/processing/h5ad_data_file.py
@@ -3,9 +3,10 @@
from os import path
from typing import Dict, Optional

import anndata
import dask
import numpy as np
import tiledb
from cellxgene_schema.utils import read_h5ad

from backend.common.utils.corpora_constants import CorporaConstants
from backend.common.utils.cxg_constants import CxgConstants
@@ -32,6 +33,8 @@ class H5ADDataFile:

tile_db_ctx_config = {
"sm.consolidation.buffer_size": consolidation_buffer_size(0.1),
"sm.consolidation.step_min_frags": 2,
"sm.consolidation.step_max_frags": 20, # see https://docs.tiledb.com/main/how-to/performance/performance-tips/tuning-consolidation
"py.deduplicate": True, # May reduce memory requirements at cost of performance
}

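As an aside, here is a minimal sketch (not this repository's code) of how a settings dict like tile_db_ctx_config above is typically applied: the parameters seed a tiledb.Ctx, the context is passed to tiledb.consolidate to merge the many small fragments produced by chunked writes, and tiledb.vacuum then removes the fragments that were merged away. The array path below is hypothetical.

import tiledb

array_uri = "example.cxg/X"  # hypothetical path; the pipeline writes to "<output_cxg_directory>/X"

# Build a context from consolidation settings like tile_db_ctx_config above.
# TileDB config parameters are string-typed, so values are given as strings here.
ctx = tiledb.Ctx(
    tiledb.Config(
        {
            "sm.consolidation.step_min_frags": "2",   # begin merging once 2 fragments exist
            "sm.consolidation.step_max_frags": "20",  # merge at most 20 fragments per step
            "py.deduplicate": "true",                 # may reduce memory at some performance cost
        }
    )
)

tiledb.consolidate(array_uri, ctx=ctx)  # merge small write fragments into larger ones
if hasattr(tiledb, "vacuum"):           # vacuum is absent in very old tiledb-py releases
    tiledb.vacuum(array_uri)            # delete the now-redundant source fragments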
@@ -95,19 +98,23 @@ def to_cxg(

def write_anndata_x_matrices_to_cxg(self, output_cxg_directory, ctx, sparse_threshold):
matrix_container = f"{output_cxg_directory}/X"

x_matrix_data = self.anndata.X
is_sparse = is_matrix_sparse(x_matrix_data, sparse_threshold) # big memory usage
logging.info(f"is_sparse: {is_sparse}")

convert_matrices_to_cxg_arrays(matrix_container, x_matrix_data, is_sparse, ctx) # big memory usage
with dask.config.set(
{
"num_workers": 2, # match the number of workers to the number of vCPUs
"threads_per_worker": 1,
"distributed.worker.memory.limit": "6GB",
"scheduler": "threads",
}
):
is_sparse = is_matrix_sparse(x_matrix_data, sparse_threshold)
logging.info(f"is_sparse: {is_sparse}")
convert_matrices_to_cxg_arrays(matrix_container, x_matrix_data, is_sparse, self.tile_db_ctx_config)

suffixes = ["r", "c"] if is_sparse else [""]
logging.info("start consolidating")
for suffix in suffixes:
tiledb.consolidate(matrix_container + suffix, ctx=ctx)
if hasattr(tiledb, "vacuum"):
tiledb.vacuum(matrix_container + suffix)
tiledb.consolidate(matrix_container, ctx=ctx)
if hasattr(tiledb, "vacuum"):
tiledb.vacuum(matrix_container)

def write_anndata_embeddings_to_cxg(self, output_cxg_directory, ctx):
def is_valid_embedding(adata, embedding_name, embedding_array):
@@ -183,7 +190,7 @@ def validate_anndata(self):

def extract_anndata_elements_from_file(self):
logging.info(f"Reading in AnnData dataset: {path.basename(self.input_filename)}")
self.anndata = anndata.read_h5ad(self.input_filename)
self.anndata = read_h5ad(self.input_filename, chunk_size=7500)
logging.info("Completed reading in AnnData dataset!")

self.obs = self.transform_dataframe_index_into_column(self.anndata.obs, "obs", self.obs_index_column_name)
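To illustrate the pattern the h5ad_data_file.py changes above rely on, here is a minimal sketch (not this repository's code): dask.config.set scopes the threaded scheduler and worker count to a single computation, so a chunked matrix is processed block by block and peak memory stays near the size of one chunk. The random array is a stand-in for adata.X, and its 7,500-row blocks mirror the chunk_size=7500 passed to read_h5ad in the diff.

import dask
import dask.array as da

# Stand-in for adata.X: 100,000 x 2,000 values split into 7,500-row blocks.
x = da.random.random((100_000, 2_000), chunks=(7_500, 2_000))

with dask.config.set(
    {
        "scheduler": "threads",  # single-process threaded scheduler, as in the diff
        "num_workers": 2,        # keep worker count in line with the job's vCPUs
    }
):
    # Each block is materialized, reduced, and released in turn, so the whole
    # matrix never needs to be resident in memory at once.
    nonzero_fraction = float((x != 0).mean().compute())

print(f"fraction of non-zero entries: {nonzero_fraction:.3f}")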
8 changes: 6 additions & 2 deletions backend/layers/processing/process_validate.py
@@ -124,7 +124,9 @@ def validate_h5ad_file_and_add_labels(

output_filename = CorporaConstants.LABELED_H5AD_ARTIFACT_FILENAME
try:
is_valid, errors, _ = self.schema_validator.validate_and_save_labels(local_filename, output_filename)
is_valid, errors, _ = self.schema_validator.validate_and_save_labels(
local_filename, output_filename, n_workers=1
) # match the number of workers to the number of vCPUs
except Exception as e:
self.logger.exception("validation failed")
raise ValidationFailed([str(e)]) from None
@@ -139,6 +141,8 @@ def validate_h5ad_file_and_add_labels(
self.update_processing_status(
dataset_version_id, DatasetStatusKey.VALIDATION, DatasetValidationStatus.VALID
)
# Skip seurat conversion
self.update_processing_status(dataset_version_id, DatasetStatusKey.RDS, DatasetConversionStatus.SKIPPED)
return output_filename

def populate_dataset_citation(
@@ -278,7 +282,7 @@ def process(

# Validate and label the dataset
file_with_labels = self.validate_h5ad_file_and_add_labels(
collection_version_id, dataset_version_id, original_h5ad_artifact_file_name
collection_version_id, dataset_version_id, local_filename
)
# Process metadata
metadata = self.extract_metadata(file_with_labels)