diff --git a/mindnlp/peft/peft_model.py b/mindnlp/peft/peft_model.py
index 402e1076f..4f17d3f1f 100644
--- a/mindnlp/peft/peft_model.py
+++ b/mindnlp/peft/peft_model.py
@@ -125,7 +125,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
         # if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
         #     self.base_model.config.pretraining_tp = 1
 
-    def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
+    def save_pretrained(self, save_directory, safe_serialization=True, **kwargs):
         r"""
         This function saves the adapter model and the adapter configuration files to a directory, so that it can be
         reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
diff --git a/mindnlp/transformers/models/qwen2/modeling_qwen2.py b/mindnlp/transformers/models/qwen2/modeling_qwen2.py
index c16c2b905..ac3bd80ca 100644
--- a/mindnlp/transformers/models/qwen2/modeling_qwen2.py
+++ b/mindnlp/transformers/models/qwen2/modeling_qwen2.py
@@ -826,13 +826,22 @@ def forward(
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :]
             shift_labels = labels[..., 1:]
-            # Flatten the tokens
-            loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
-            # Enable model parallelism
-            loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
-            loss = loss.mean()
+            if ON_ORANGE_PI:
+                # Flatten the tokens
+                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = nn.functional.one_hot(shift_labels.view(-1), self.config.vocab_size)
+                # Compute the per-token loss against one-hot labels
+                loss, _ = loss_fct(shift_logits, shift_labels.to(shift_logits.dtype))
+                loss = loss.mean()
+            else:
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = shift_labels.view(-1)
+                # Compute the cross-entropy loss over the flattened tokens
+                loss = loss_fct(shift_logits, shift_labels)
+
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1004,10 +1013,14 @@ def forward(
                 else:
                     loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "single_label_classification":
-                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-                labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
-                loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
-                loss = loss.mean()
+                if ON_ORANGE_PI:
+                    loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                    labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
+                    loss, _ = loss_fct(pooled_logits.view(-1, self.num_labels), labels.to(pooled_logits.dtype))
+                    loss = loss.mean()
+                else:
+                    loss_fct = CrossEntropyLoss()
+                    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(pooled_logits, labels)
@@ -1086,10 +1099,14 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
-            labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
-            loss, _= loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
-            loss = loss.mean()
+            if ON_ORANGE_PI:
+                loss_fct = mindspore.ops.SoftmaxCrossEntropyWithLogits()
+                labels = nn.functional.one_hot(labels.view(-1), self.num_labels)
+                loss, _ = loss_fct(logits.view(-1, self.num_labels), labels.to(logits.dtype))
+                loss = loss.mean()
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[2:]