Skip to content

Commit 6450351

Browse files
authored
Merge pull request #191 from piotrbartman/thrust_storage
rollback indexed_storage.py changes
2 parents dd0a2e3 + 785fa2b commit 6450351

File tree

26 files changed

+34
-96
lines changed

26 files changed

+34
-96
lines changed

PySDM/backends/numba/impl/_algorithmic_methods.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def calculate_displacement(dim, scheme, displacement, courant, cell_origin, posi
3030
@staticmethod
3131
@numba.njit(int64(int64[:], float64[:], int64[:], int64, float64[:, :], float64[:, :], float64[:], int64[:], numba.boolean, int64, float64[:]),
3232
**{**conf.JIT_FLAGS, **{'parallel': False}})
33-
# TODO: waits for https://github.com/numba/numba/issues/5279
33+
# TODO: reopen https://github.com/numba/numba/issues/5279 with minimal rep. ex.
3434
def coalescence_body(n, volume, idx, length, intensive, extensive, gamma, healthy, adaptive, subs, adaptive_memory):
3535
result = 1
3636
for i in prange(length - 1):
@@ -86,10 +86,6 @@ def compute_gamma_body(prob, rand):
8686
"""
8787
for i in prange(len(prob)):
8888
prob[i] = np.ceil(prob[i] - rand[i // 2])
89-
# TODO: same in Thrust?
90-
# prob[i] *= -1.
91-
# prob[i] += rand[i // 2]
92-
# prob[i] = -np.floor(prob[i])
9389

9490
@staticmethod
9591
def compute_gamma(prob, rand):

PySDM/backends/numba/impl/_algorithmic_step_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def amin(row, idx, length):
2525
return result
2626

2727
@staticmethod
28-
# @numba.njit(**conf.JIT_FLAGS) # TODO: "np.dot() only supported on float and complex arrays"
28+
# @numba.njit(**conf.JIT_FLAGS) # Note: in Numba 0.51 "np.dot() only supported on float and complex arrays"
2929
def cell_id(cell_id, cell_origin, strides):
3030
cell_id.data[:] = np.dot(strides.data, cell_origin.data)
3131

PySDM/backends/numba/impl/_maths_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def floor(output):
3535
@staticmethod
3636
@numba.njit(void(int64[:, :], float64[:, :]), **conf.JIT_FLAGS)
3737
def floor_out_of_place(output, input_data):
38-
output[:] = np.floor(input_data) # TODO: Try input_data//1 instead of np.floor(input_data)
38+
output[:] = np.floor(input_data)
3939

4040
@staticmethod
4141
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False}})

PySDM/backends/numba/storage/indexed_storage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def to_ndarray(self):
6464
return self.data[:self.length].copy()
6565

6666
def read_row(self, i):
67-
result = IndexedStorage(self.idx, self.data[i, :], (1, *self.shape[1:]), self.dtype)
67+
# TODO: shape like in ThrustRTC
68+
result = IndexedStorage(self.idx, self.data[i, :], *self.shape[1:], self.dtype)
6869
return result
6970

7071
def remove_zeros(self):

PySDM/backends/numba/storage/storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,6 @@ def to_ndarray(self):
166166
def upload(self, data):
167167
np.copyto(self.data, data, casting='safe')
168168

169-
# TODO: optimize
169+
# TODO: remove
170170
def write_row(self, i, row):
171171
self.data[i, :] = row.data

PySDM/backends/thrustRTC/impl/_maths_methods.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
Created at 10.12.2019
33
"""
44

5-
import numpy as np
65
import ThrustRTC as trtc
7-
import CURandRTC as rndrtc
86
from ._storage_methods import StorageMethods
97
from PySDM.backends.thrustRTC.nice_thrust import nice_thrust
108
from PySDM.backends.thrustRTC.conf import NICE_THRUST_FLAGS
@@ -119,31 +117,3 @@ def power(output, exponent):
119117
@nice_thrust(**NICE_THRUST_FLAGS)
120118
def subtract(output, subtrahend):
121119
MathsMethods.__subtract_body.launch_n(output.size(), [output, subtrahend])
122-
# trtc.Transform_Binary(output, subtrahend, output, trtc.Minus())
123-
124-
__urand_init_rng_state_body = trtc.For(['rng', 'states', 'seed'], 'i', '''
125-
rng.state_init(1234, i, 0, states[i]);
126-
''')
127-
128-
__urand_body = trtc.For(['states', 'vec_rnd'], 'i', '''
129-
vec_rnd[i]=states[i].rand01();
130-
''')
131-
132-
__rng = rndrtc.DVRNG()
133-
states = trtc.device_vector('RNGState', 2**19)
134-
__urand_init_rng_state_body.launch_n(states.size(), [__rng, states, trtc.DVInt64(12)])
135-
136-
@staticmethod
137-
@nice_thrust(**NICE_THRUST_FLAGS)
138-
def urand(data, seed=None):
139-
# TODO: print("Numpy import!: ThrustRTC.urand(...)")
140-
141-
seed = seed or np.random.randint(2**16)
142-
dseed = trtc.DVInt64(seed)
143-
# MathsMethods.__urand_init_rng_state_body.launch_n(MathsMethods.states.size(), [MathsMethods.__rng, MathsMethods.states, dseed])
144-
MathsMethods.__urand_body.launch_n(data.size(), [MathsMethods.states, data])
145-
# hdata = data.to_host()
146-
# print(np.mean(hdata))
147-
# np.random.seed(seed)
148-
# output = np.random.uniform(0, 1, data.shape)
149-
# StorageMethods.upload(output, data)

PySDM/backends/thrustRTC/impl/_storage_methods.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010

1111
class StorageMethods:
12-
# TODO check static For
1312
storage = trtc.DVVector.DVVector
1413
integer = np.int64
1514
double = np.float64
@@ -27,7 +26,6 @@ def array(shape, dtype):
2726
raise NotImplementedError
2827

2928
data = trtc.device_vector(elem_cls, int(np.prod(shape)))
30-
# TODO: trtc.Fill(data, trtc.DVConstant(np.nan))
3129

3230
StorageMethods.__equip(data, shape, elem_dtype)
3331
return data
@@ -109,19 +107,10 @@ def shuffle_global(idx, length, u01):
109107
@nice_thrust(**NICE_THRUST_FLAGS)
110108
def shuffle_local(idx, u01, cell_start):
111109
StorageMethods.__shuffle_local_body.launch_n(cell_start.size() - 1, [cell_start, u01, idx])
112-
# TODO: print("Numba import!: ThrustRTC.shuffle_local(...)")
113-
# from PySDM.backends.numba.numba import Numba
114-
# host_idx = StorageMethods.to_ndarray(idx)
115-
# host_u01 = StorageMethods.to_ndarray(u01)
116-
# host_cell_start = StorageMethods.to_ndarray(cell_start)
117-
# Numba.shuffle_local(host_idx, host_u01, host_cell_start)
118-
# device_idx = StorageMethods.from_ndarray(host_idx)
119-
# trtc.Copy(device_idx, idx)
120110

121111
@staticmethod
122112
@nice_thrust(**NICE_THRUST_FLAGS)
123113
def to_ndarray(data):
124-
# TODO: move to __equip??
125114
if isinstance(data, StorageMethods.storage):
126115
pass
127116
elif isinstance(data, trtc.DVVector.DVRange):

PySDM/dynamics/condensation/condensation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def register(self, builder):
4242
self.max_substeps = int(self.core.dt)
4343
self.max_substeps = int(self.core.dt)
4444
self.substeps = self.core.Storage.empty(self.core.mesh.n_cell, dtype=int)
45-
self.substeps[:] = np.maximum(1, int(self.core.dt)) # TODO: reset substeps
45+
self.substeps[:] = np.maximum(1, int(self.core.dt)) # TODO: min substep length
4646
self.ripening_flags = self.core.Storage.empty(self.core.mesh.n_cell, dtype=int)
4747
self.ripening_flags[:] = 0
4848

PySDM/mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def mesh_0d(dv=None):
3737

3838
@staticmethod
3939
def __strides(grid):
40-
domain = np.empty(tuple(grid)) # TODO optimize
40+
domain = np.empty(tuple(grid))
4141
strides = np.array(domain.strides).reshape(1, -1) // domain.itemsize
4242
return strides
4343

PySDM/state/products/aerosol_specific_concentration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, radius_threshold):
2222
def get(self):
2323
self.download_moment_to_buffer('volume', rank=0,
2424
filter_range=[0, phys.volume(self.radius_threshold)])
25-
result = self.buffer.copy() # TODO !!!
25+
result = self.buffer.copy() # TODO
2626
self.download_to_buffer(self.core.environment['rhod'])
2727
result[:] /= self.core.mesh.dv
2828
result[:] /= self.buffer

0 commit comments

Comments
 (0)