Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create art module #215

Open
wants to merge 4 commits into
base: development
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -13,3 +13,8 @@ __pycache__/

# Generated documentation
_build

*.png
*.sas7bdat
*.csv
*.sh
98 changes: 98 additions & 0 deletions plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# files = ["output\simulation_output_20240205-230047\simulation_output_20240205-230047.csv",
# "output\simulation_output_20240205-231331\simulation_output_20240205-231331.csv",
# "output\simulation_output_20240206-104041\simulation_output_20240206-104041.csv",
# "output\simulation_output_20240206-104444\simulation_output_20240206-104444.csv",
# "output\simulation_output_20240206-104524\simulation_output_20240206-104524.csv",
# "output\simulation_output_20240206-104650\simulation_output_20240206-104650.csv",
# "output\simulation_output_20240206-105003\simulation_output_20240206-105003.csv",
# "output\simulation_output_20240206-134436\simulation_output_20240206-134436.csv",
# "output\simulation_output_20240206-135429\simulation_output_20240206-135429.csv",
# "output\simulation_output_20240206-135605\simulation_output_20240206-135605.csv"]
files = ["output\simulation_output_20240220-154213\simulation_output_20240220-154213.csv"]
data = [pd.read_csv(file) for file in files]

men_15_24 = np.array([d["Short term partners (15-64)"] for d in data]).transpose()
medians = np.array([np.median(tstep) for tstep in men_15_24])
plt.plot(medians)
plt.show()

def psb(i, sex):
if(sex == 0):
sex_string = "male"
else:
sex_string = "female"
return f"Partner sex balance ({i}-{i+9}, {sex_string})"

def pps(i, sex):
if(sex == 0):
sex_string = "male"
else:
sex_string = "female"
return f"Short term partners ({i}-{i+9}, {sex})"

sexes = [0, 1]
ages = [15, 25, 35, 45, 55]

for (a,s) in product(ages, sexes):
for d in data:
plt.plot(1989 + 0.25 * d["Time Step"], d[psb(a, s)])
plt.plot(1989 + 0.25 * data[0]["Time Step"], np.zeros(len(data[0])))
print(psb(a,s), f"mean = {np.mean(d[psb(a,s)])}", f"std = {np.std(d[psb(a,s)])}", 10**(np.mean(d[psb(a,s)])))
plt.title(psb(a, s))
plt.show()

# Create a figure with subplots
fig, axs = plt.subplots(len(ages), len(sexes), figsize=(12, 8))
fig.suptitle("Short Term Partners and Partner Sex Balance")

# Iterate over ages and sexes
# for i, a in enumerate(ages):
# for j, s in enumerate(sexes):
# ax = axs[i, j]
# ax.plot(1989 + 0.25 * data[0]["Time Step"], data[0][psb(a, s)], label=pps(a, s))
# ax.plot(1989 + 0.25 * data[0]["Time Step"], np.zeros(len(data[0])), linestyle="--", color="gray")
# ax.set_title(f"{psb(a, s)}")
# ax.set_xlabel("Year")
# ax.set_ylabel("Log Partner Balance")
# ax.set_aspect(1.0 / ax.get_data_ratio(), adjustable='box')
# vals = np.ma.masked_invalid(data[0][psb(a,s)])
# print(psb(a,s), f"mean = {np.mean(vals)}", f"std = {np.std(vals)}", f"average unbalance = {np.round((10**(np.mean(vals))-1)*100,2)}%")

# Add legend
axs[0, 0].legend()

# Adjust spacing between subplots
plt.tight_layout()

# Show the combined plot
plt.show()

# for (a,s) in product(ages, sexes):
# for d in data:
# plt.scatter(1989 + 0.25 * d["Time Step"], d[pps(a, s)])
# plt.plot(1989 + 0.25 * data[0]["Time Step"], np.zeros(len(data[0])))
# print(pps(a,s), np.mean(d[pps(a,s)]), np.std(d[pps(a,s)]), np.exp(np.mean(d[pps(a,s)])))
# plt.title(pps(a, s))
# plt.show()

# for d in data:
# plt.scatter(d["Time Step"], d["Partner sex balance (15-24, male)"])
# plt.plot(data[0]["Time Step"], np.zeros(len(data[0])))
# plt.show()
#
# for d in data:
# plt.scatter(d["Time Step"], d["Partner sex balance (25-34, male)"])
# plt.plot(data[0]["Time Step"], np.zeros(len(data[0])))
# plt.show()
#
# for d in data:
# plt.scatter(d["Time Step"], d["Partner sex balance (35-44, male)"])
# plt.plot(data[0]["Time Step"], np.zeros(len(data[0])))
# plt.show()

157 changes: 157 additions & 0 deletions src/hivpy/art.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .population import Population

import operator as op

import numpy as np
import pandas as pd

import hivpy.column_names as col

from . import output
from .common import COND, SexType, opposite_sex, rng, timedelta, date
from enum import Enum, IntEnum

class HivMonitoringStrategy(Enum):
# strategy for monitoring HIV positive people naive to ART 1: presence of tb or who4
presence_tb_who4 = 1
# strategy for monitoring HIV positive people naive to ART 2: cd4 6 monthly + presence of tb or who4
cd4_6_monthly = 2

class ArtInitiationStrategy(Enum):
# all with who4
all_who4 = 1
# all with tb or who4
all_tb_who4 = 2
# all with hiv diagnosed
all_hiv_diagnosed = 3
# cd4 < 200 or who4
cd4_lt_200_who4 = 4
# cd4 < 200 or tb or who4
cd4_lt_200_tb_who4 = 5
# cd4 < 350 or who4
cd4_lt_350_who4 = 6
# cd4 < 350 and ART immediately to pregnant women
cd4_lt_350_pregnant = 9
# cd4 < 500 and ART immediately to pregnant women
cd4_lt_500_pregnant = 10

class ArtMonitoringStrategy(Enum):
# 1. Clinical monitoring alone
only_clinical = 1
# 2. Clinical monitoring with single VL confirmation
clinical_single_vl = 2
# 3. Clinical monitoring with VL confirmation
clinical_vl_confirm = 3
# 7. Clinical monitoring with CD4 confirmation
clinical_cd4_confirm = 7
# 8. CD4 monitoring (6 mthly) alone
only_cd4_monitor = 8
# 9. CD4 count monitoring (6 mthly) with single VL confirmation
cd4_monitor_single_vl = 9
# 10. CD4 monitoring (6 mthly) with VL confirmation
cd4_monitor_vl_confirm = 10
# 150. Viral load monitoring (6m, 12m, annual) - WHO
vl_monitor_who = 150
# 152. As above with 2 yearly viral load monitoring
vl_monitor_biannual = 152
# 153. Viral load monitoring (6m, annual) no confirmation
vl_monitor_no_confirm = 153
# 1500.Viral load monitoring (6m, 12m, annual) + adh > 0.8 based on tdf level test;
vl_monitor_tdf_test = 1500

class VmFormat(Enum):
# vm_format=1 plasma lab
plasma_lab = 1
# vm_format=2 whb lab
whb_lab = 2
# vm_format=3 plasma poc
plasma_poc = 3
# vm_format=4 whb poc
whb_poc = 4

class ARTModule:
hiv_monitoring_strategy = HivMonitoringStrategy.presence_tb_who4
art_initiation_strategy = ArtInitiationStrategy.all_tb_who4
art_monitoring_strategy = ArtMonitoringStrategy.only_clinical
vm_format = VmFormat.whb_lab # TODO: check correct initial value?
vl_threshold = 1000
time_of_first_vm = 0.5
min_time_repeat_vm = 0.25 # 3 months?
poc_vl_monitoring = False
cd4_monitoring = False
year_intervention = 2024 # based on current date?

## ART coverage changes
lower_future_art_coverage = False


rate_change_art_init_strategy = {
ArtInitiationStrategy.cd4_lt_200_who4: 0.4,
ArtInitiationStrategy.cd4_lt_350_pregnant: 0.4,
ArtInitiationStrategy.cd4_lt_500_pregnant: 0.4,
ArtInitiationStrategy.all_hiv_diagnosed: 0.4
}

def __init__(self):
# set cd4_monitoring and prob_vl_meas_done
pass


def update_strategies(self, current_date: date):
if current_date < date(2005, 6, 1):
self.hiv_monitoring_strategy = HivMonitoringStrategy.presence_tb_who4
self.art_initiation_strategy = ArtInitiationStrategy.all_tb_who4
self.art_monitoring_strategy = ArtMonitoringStrategy.only_clinical

def set_initiation_strategy(art_strategy: ArtInitiationStrategy,
start_date: date,
end_date: date,
hiv_strategy = None):
if ((self.art_initiation_strategy != art_strategy)
and (start_date <= current_date < end_date)
and (rng.uniform() < self.rate_change_art_init_strategy[art_strategy])
):
self.art_initiation_strategy = art_strategy
if hiv_strategy is not None:
self.hiv_monitoring_strategy = hiv_strategy

set_initiation_strategy(ArtInitiationStrategy.cd4_lt_200_who4,
date(2008, 1, 1),
date(2011, 6, 1),
HivMonitoringStrategy.cd4_6_monthly)

set_initiation_strategy(ArtInitiationStrategy.cd4_lt_350_pregnant,
date(2011, 6, 1),
date(2014, 1, 1))

set_initiation_strategy(ArtInitiationStrategy.cd4_lt_500_pregnant,
date(2014, 1, 1),
date(2016, 6, 1))

# FIXME: This end date is silly because there is no end date for this policy
set_initiation_strategy(ArtInitiationStrategy.all_hiv_diagnosed,
date(2016, 6, 1),
date(3000, 1, 1),
HivMonitoringStrategy.presence_tb_who4)

if current_date >= date(2016, 3, 1):
self.art_monitoring_strategy = ArtMonitoringStrategy.vl_monitor_who
self.vm_format = VmFormat.whb_lab
self.vl_threshold = 1000
self.time_of_first_vm = 0.5
self.min_time_repeat_vm = 0.25
if(self.poc_vl_monitoring):
self.vm_format = VmFormat.whb_poc

if ((current_date >= date(2016, 6, 1)) and self.cd4_monitoring):
self.art_monitoring_strategy = ArtMonitoringStrategy.only_cd4_monitor

if (current_date == date(self.year_intervention, 1, 1)):
# lower future ART coverage

# higher future oral prep coverage
22 changes: 22 additions & 0 deletions src/hivpy/art_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np

from hivpy.exceptions import DataLoadException

from .common import SexType, rng
from .data_reader import DataReader


class ArtData(DataReader):
"""
Class to hold and interpret sexual behaviour data loaded from a yaml file.
"""

def __init__(self, filename):
super().__init__(filename)

try:
self.lower_future_art_coverage = self._get_discrete_dist("lower_future_art_coverage")

except KeyError as ke:
print(ke.args)
raise DataLoadException
3 changes: 3 additions & 0 deletions src/hivpy/data/art.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
lower_future_art_coverage:
Value: [False, True]
Probability: [0.97, 0.03]