Skip to content

Commit fb7336f

Browse files
Author: Liu Zhengyun — committed "sync codes for ainode"
1 parent 6228880 commit fb7336f

File tree

8 files changed

+135
-90
lines changed

8 files changed

+135
-90
lines changed

iotdb-core/ainode/iotdb/ainode/core/config.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
AINODE_CLUSTER_INGRESS_ADDRESS,
2424
AINODE_CLUSTER_INGRESS_PASSWORD,
2525
AINODE_CLUSTER_INGRESS_PORT,
26-
AINODE_CLUSTER_INGRESS_TIME_ZONE,
2726
AINODE_CLUSTER_INGRESS_USERNAME,
2827
AINODE_CLUSTER_NAME,
2928
AINODE_CONF_DIRECTORY_NAME,
@@ -69,7 +68,6 @@ def __init__(self):
6968
self._ain_cluster_ingress_port = AINODE_CLUSTER_INGRESS_PORT
7069
self._ain_cluster_ingress_username = AINODE_CLUSTER_INGRESS_USERNAME
7170
self._ain_cluster_ingress_password = AINODE_CLUSTER_INGRESS_PASSWORD
72-
self._ain_cluster_ingress_time_zone = AINODE_CLUSTER_INGRESS_TIME_ZONE
7371

7472
# Inference configuration
7573
self._ain_inference_batch_interval_in_ms: int = (
@@ -287,14 +285,6 @@ def set_ain_cluster_ingress_password(
287285
) -> None:
288286
self._ain_cluster_ingress_password = ain_cluster_ingress_password
289287

290-
def get_ain_cluster_ingress_time_zone(self) -> str:
291-
return self._ain_cluster_ingress_time_zone
292-
293-
def set_ain_cluster_ingress_time_zone(
294-
self, ain_cluster_ingress_time_zone: str
295-
) -> None:
296-
self._ain_cluster_ingress_time_zone = ain_cluster_ingress_time_zone
297-
298288

299289
@singleton
300290
class AINodeDescriptor(object):
@@ -432,11 +422,6 @@ def _load_config_from_file(self) -> None:
432422
file_configs["ain_cluster_ingress_password"]
433423
)
434424

435-
if "ain_cluster_ingress_time_zone" in config_keys:
436-
self._config.set_ain_cluster_ingress_time_zone(
437-
file_configs["ain_cluster_ingress_time_zone"]
438-
)
439-
440425
except BadNodeUrlException:
441426
logger.warning("Cannot load AINode conf file, use default configuration.")
442427

iotdb-core/ainode/iotdb/ainode/core/constant.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
AINODE_CLUSTER_INGRESS_PORT = 6667
4040
AINODE_CLUSTER_INGRESS_USERNAME = "root"
4141
AINODE_CLUSTER_INGRESS_PASSWORD = "root"
42-
AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8"
4342

4443
# RPC config
4544
AINODE_THRIFT_COMPRESSION_ENABLED = False

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def _step(self):
127127
for i in range(batch_inputs.size(0)):
128128
batch_input_list.append({"targets": batch_inputs[i]})
129129
batch_inputs = self._inference_pipeline.preprocess(
130-
batch_input_list, output_length=requests[0].output_length
130+
batch_input_list,
131+
output_length=requests[0].output_length,
132+
auto_adapt=True,
131133
)
132134
if isinstance(self._inference_pipeline, ForecastPipeline):
133135
batch_output = self._inference_pipeline.forecast(

iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,16 @@
1919
from abc import ABC, abstractmethod
2020

2121
import torch
22+
from torch.nn import functional as F
2223

2324
from iotdb.ainode.core.exception import InferenceModelInternalException
25+
from iotdb.ainode.core.log import Logger
2426
from iotdb.ainode.core.manager.device_manager import DeviceManager
2527
from iotdb.ainode.core.model.model_info import ModelInfo
2628
from iotdb.ainode.core.model.model_loader import load_model
2729

2830
BACKEND = DeviceManager()
31+
logger = Logger()
2932

3033

3134
class BasicPipeline(ABC):
@@ -70,6 +73,7 @@ def preprocess(
7073
7174
infer_kwargs (dict, optional): Additional keyword arguments for inference, such as:
7275
- `output_length`(int): Used to check validation of 'future_covariates' if provided.
76+
- `auto_adapt`(bool): Whether to automatically adapt the covariates.
7377
7478
Raises:
7579
ValueError: If the input format is incorrect (e.g., missing keys, invalid tensor shapes).
@@ -80,6 +84,7 @@ def preprocess(
8084

8185
if isinstance(inputs, list):
8286
output_length = infer_kwargs.get("output_length", 96)
87+
auto_adapt = infer_kwargs.get("auto_adapt", True)
8388
for idx, input_dict in enumerate(inputs):
8489
# Check if the dictionary contains the expected keys
8590
if not isinstance(input_dict, dict):
@@ -121,10 +126,30 @@ def preprocess(
121126
raise ValueError(
122127
f"Each value in 'past_covariates' must be torch.Tensor, but got {type(cov_value)} for key '{cov_key}' at index {idx}."
123128
)
124-
if cov_value.ndim != 1 or cov_value.shape[0] != input_length:
129+
if cov_value.ndim != 1:
125130
raise ValueError(
126-
f"Each covariate in 'past_covariates' must have shape ({input_length},), but got shape {cov_value.shape} for key '{cov_key}' at index {idx}."
131+
f"Individual `past_covariates` must be 1-d, found: {cov_key} with {cov_value.ndim} dimensions in element at index {idx}."
127132
)
133+
# If any past_covariate's length is not equal to input_length, process it accordingly.
134+
if cov_value.shape[0] != input_length:
135+
if auto_adapt:
136+
if cov_value.shape[0] > input_length:
137+
logger.warning(
138+
f"Past covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (> {input_length}), which will be truncated from the beginning."
139+
)
140+
past_covariates[cov_key] = cov_value[-input_length:]
141+
else:
142+
logger.warning(
143+
f"Past covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {input_length}), which will be padded with zeros at the beginning."
144+
)
145+
pad_size = input_length - cov_value.shape[0]
146+
past_covariates[cov_key] = F.pad(
147+
cov_value, (pad_size, 0)
148+
)
149+
else:
150+
raise ValueError(
151+
f"Individual `past_covariates` must be 1-d with length equal to the length of `target` (= {input_length}), found: {cov_key} with shape {tuple(cov_value.shape)} in element at index {idx}."
152+
)
128153

129154
# Check 'future_covariates' if it exists (optional)
130155
future_covariates = input_dict.get("future_covariates", {})
@@ -134,19 +159,52 @@ def preprocess(
134159
)
135160
# If future_covariates exists, check if they are a subset of past_covariates
136161
if future_covariates:
137-
for cov_key, cov_value in future_covariates.items():
162+
for cov_key, cov_value in list(future_covariates.items()):
163+
# If any future_covariate not found in past_covariates, ignore it or raise an error.
138164
if cov_key not in past_covariates:
139-
raise ValueError(
140-
f"Key '{cov_key}' in 'future_covariates' is not in 'past_covariates' at index {idx}."
141-
)
165+
if auto_adapt:
166+
future_covariates.pop(cov_key)
167+
logger.warning(
168+
f"Future covariate {cov_key} not found in past_covariates {list(past_covariates.keys())}, which will be ignored when executing forecasting."
169+
)
170+
if not future_covariates:
171+
input_dict.pop("future_covariates")
172+
continue
173+
else:
174+
raise ValueError(
175+
f"Expected keys in `future_covariates` to be a subset of `past_covariates` {list(past_covariates.keys())}, "
176+
f"but found {cov_key} in element at index {idx}."
177+
)
142178
if not isinstance(cov_value, torch.Tensor):
143179
raise ValueError(
144180
f"Each value in 'future_covariates' must be torch.Tensor, but got {type(cov_value)} for key '{cov_key}' at index {idx}."
145181
)
146-
if cov_value.ndim != 1 or cov_value.shape[0] != output_length:
182+
if cov_value.ndim != 1:
147183
raise ValueError(
148-
f"Each covariate in 'future_covariates' must have shape ({output_length},), but got shape {cov_value.shape} for key '{cov_key}' at index {idx}."
184+
f"Individual `future_covariates` must be 1-d, found: {cov_key} with {cov_value.ndim} dimensions in element at index {idx}."
149185
)
186+
# If any future_covariate's length is not equal to output_length, process it accordingly.
187+
if cov_value.shape[0] != output_length:
188+
if auto_adapt:
189+
if cov_value.shape[0] > output_length:
190+
logger.warning(
191+
f"Future covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (> {output_length}), which will be truncated from the end."
192+
)
193+
future_covariates[cov_key] = cov_value[
194+
:output_length
195+
]
196+
else:
197+
logger.warning(
198+
f"Future covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {output_length}), which will be padded with zeros at the end."
199+
)
200+
pad_size = output_length - cov_value.shape[0]
201+
future_covariates[cov_key] = F.pad(
202+
cov_value, (0, pad_size)
203+
)
204+
else:
205+
raise ValueError(
206+
f"Individual `future_covariates` must be 1-d with length equal to `output_length` (= {output_length}), found: {cov_key} with shape {tuple(cov_value.shape)} in element at index {idx}."
207+
)
150208
else:
151209
raise ValueError(
152210
f"The inputs must be a list of dictionaries, but got {type(inputs)}."

iotdb-core/ainode/iotdb/ainode/core/ingress/iotdb.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ def __init__(
6969
password: str = AINodeDescriptor()
7070
.get_config()
7171
.get_ain_cluster_ingress_password(),
72-
time_zone: str = AINodeDescriptor()
73-
.get_config()
74-
.get_ain_cluster_ingress_time_zone(),
7572
use_rate: float = 1.0,
7673
offset_rate: float = 0.0,
7774
):
@@ -90,7 +87,6 @@ def __init__(
9087
node_urls=[f"{ip}:{port}"],
9188
user=username,
9289
password=password,
93-
zone_id=time_zone,
9490
use_ssl=AINodeDescriptor()
9591
.get_config()
9692
.get_ain_cluster_ingress_ssl_enabled(),
@@ -258,9 +254,6 @@ def __init__(
258254
password: str = AINodeDescriptor()
259255
.get_config()
260256
.get_ain_cluster_ingress_password(),
261-
time_zone: str = AINodeDescriptor()
262-
.get_config()
263-
.get_ain_cluster_ingress_time_zone(),
264257
use_rate: float = 1.0,
265258
offset_rate: float = 0.0,
266259
):
@@ -272,7 +265,6 @@ def __init__(
272265
node_urls=[f"{ip}:{port}"],
273266
username=username,
274267
password=password,
275-
time_zone=time_zone,
276268
use_ssl=AINodeDescriptor()
277269
.get_config()
278270
.get_ain_cluster_ingress_ssl_enabled(),

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py

Lines changed: 64 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,66 @@ def _process_request(self, req):
175175
with self._result_wrapper_lock:
176176
del self._result_wrapper_map[req_id]
177177

178+
def _do_inference_and_construct_resp(
179+
self,
180+
model_id: str,
181+
model_inputs_list: list[dict[str, torch.Tensor | dict[str, torch.Tensor]]],
182+
output_length: int,
183+
inference_attrs: dict,
184+
**kwargs,
185+
) -> list[bytes]:
186+
auto_adapt = kwargs.get("auto_adapt", True)
187+
if (
188+
output_length
189+
> AINodeDescriptor().get_config().get_ain_inference_max_output_length()
190+
):
191+
raise NumericalRangeException(
192+
"output_length",
193+
output_length,
194+
1,
195+
AINodeDescriptor().get_config().get_ain_inference_max_output_length(),
196+
)
197+
198+
if self._pool_controller.has_running_pools(model_id):
199+
infer_req = InferenceRequest(
200+
req_id=generate_req_id(),
201+
model_id=model_id,
202+
inputs=torch.stack(
203+
[data["targets"] for data in model_inputs_list], dim=0
204+
),
205+
output_length=output_length,
206+
)
207+
outputs = self._process_request(infer_req)
208+
else:
209+
model_info = self._model_manager.get_model_info(model_id)
210+
inference_pipeline = load_pipeline(
211+
model_info, device=self._backend.torch_device("cpu")
212+
)
213+
inputs = inference_pipeline.preprocess(
214+
model_inputs_list,
215+
output_length=output_length,
216+
auto_adapt=auto_adapt,
217+
)
218+
if isinstance(inference_pipeline, ForecastPipeline):
219+
outputs = inference_pipeline.forecast(
220+
inputs, output_length=output_length, **inference_attrs
221+
)
222+
elif isinstance(inference_pipeline, ClassificationPipeline):
223+
outputs = inference_pipeline.classify(inputs)
224+
elif isinstance(inference_pipeline, ChatPipeline):
225+
outputs = inference_pipeline.chat(inputs)
226+
else:
227+
outputs = None
228+
logger.error("[Inference] Unsupported pipeline type.")
229+
outputs = inference_pipeline.postprocess(outputs)
230+
231+
# convert tensor into tsblock for the output in each batch
232+
resp_list = []
233+
for batch_idx, output in enumerate(outputs):
234+
resp = convert_tensor_to_tsblock(output)
235+
resp_list.append(resp)
236+
return resp_list
237+
178238
def _run(
179239
self,
180240
req,
@@ -191,65 +251,17 @@ def _run(
191251
inference_attrs = extract_attrs(req)
192252
output_length = int(inference_attrs.pop("output_length", 96))
193253

194-
# model_inputs_list: Each element is a dict, which contains the following keys:
195-
# `targets`: The input tensor for the target variable(s), whose shape is [target_count, input_length].
196254
model_inputs_list: list[
197255
dict[str, torch.Tensor | dict[str, torch.Tensor]]
198256
] = [{"targets": inputs[0]}]
199257

200-
if (
201-
output_length
202-
> AINodeDescriptor().get_config().get_ain_inference_max_output_length()
203-
):
204-
raise NumericalRangeException(
205-
"output_length",
206-
output_length,
207-
1,
208-
AINodeDescriptor()
209-
.get_config()
210-
.get_ain_inference_max_output_length(),
211-
)
212-
213-
if self._pool_controller.has_running_pools(model_id):
214-
infer_req = InferenceRequest(
215-
req_id=generate_req_id(),
216-
model_id=model_id,
217-
inputs=torch.stack(
218-
[data["targets"] for data in model_inputs_list], dim=0
219-
),
220-
output_length=output_length,
221-
)
222-
outputs = self._process_request(infer_req)
223-
else:
224-
model_info = self._model_manager.get_model_info(model_id)
225-
inference_pipeline = load_pipeline(
226-
model_info, device=self._backend.torch_device("cpu")
227-
)
228-
inputs = inference_pipeline.preprocess(
229-
model_inputs_list, output_length=output_length
230-
)
231-
if isinstance(inference_pipeline, ForecastPipeline):
232-
outputs = inference_pipeline.forecast(
233-
inputs, output_length=output_length, **inference_attrs
234-
)
235-
elif isinstance(inference_pipeline, ClassificationPipeline):
236-
outputs = inference_pipeline.classify(inputs)
237-
elif isinstance(inference_pipeline, ChatPipeline):
238-
outputs = inference_pipeline.chat(inputs)
239-
else:
240-
outputs = None
241-
logger.error("[Inference] Unsupported pipeline type.")
242-
outputs = inference_pipeline.postprocess(outputs)
243-
244-
# convert tensor into tsblock for the output in each batch
245-
output_list = []
246-
for batch_idx, output in enumerate(outputs):
247-
output = convert_tensor_to_tsblock(output)
248-
output_list.append(output)
258+
resp_list = self._do_inference_and_construct_resp(
259+
model_id, model_inputs_list, output_length, inference_attrs
260+
)
249261

250262
return resp_cls(
251263
get_status(TSStatusCode.SUCCESS_STATUS),
252-
[output_list[0]] if single_batch else output_list,
264+
[resp_list[0]] if single_batch else resp_list,
253265
)
254266

255267
except Exception as e:

iotdb-core/ainode/resources/conf/iotdb-ainode.properties

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ ain_cluster_ingress_username=root
5252
# Datatype: String
5353
ain_cluster_ingress_password=root
5454

55-
# The time zone of the IoTDB cluster.
56-
# Datatype: String
57-
ain_cluster_ingress_time_zone=UTC+8
58-
5955
# The device space allocated for inference
6056
# Datatype: Float
6157
ain_inference_memory_usage_ratio=0.2

iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ struct TForecastReq {
8787
3: required i32 outputLength
8888
4: optional string historyCovs
8989
5: optional string futureCovs
90-
6: optional map<string, string> options
90+
6: optional bool autoAdapt
91+
7: optional map<string, string> options
9192
}
9293

9394
struct TForecastResp {

0 commit comments

Comments (0)