Skip to content

Commit

Permalink
update version 2.0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
sunhailin committed Jul 9, 2024
1 parent d418c7a commit 4dc88a0
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ func main() {

### Version

* version 2.0.4 - 2024/07/09
* Fix the `W2NER` input feature problem (missing `MaxSeqLength` config).
* Code style fixes; reduce nil cases.
* Add the `utils.StringSliceTruncatePrecisely` function to handle truncation of `[][]string` data.

* version 2.0.3 - 2024/07/08
* Fix an inference error caused by a nil slice in `w2ner.pieces2word`.

Expand Down
2 changes: 2 additions & 0 deletions models/transformers/bert_w2ner.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ func (w *W2NerModelService) getBertInputFeature(batchInferData [][]string) []*W2
for j, token := range inferData {
tokens[j] = w.getTokenizerResult(token)
}
// The minus 2 is due to the retention of the CLS and SEP positions.
tokens = utils.StringSliceTruncatePrecisely(tokens, w.MaxSeqLength-2)
batchInferTokens[i] = tokens
batchInferPieces[i] = utils.Flatten2DSlice(tokens)
batchInputFeatures[i] = &W2NERInputFeature{}
Expand Down
60 changes: 60 additions & 0 deletions nvidia_inferenceserver/triton_service_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
"time"

Expand Down Expand Up @@ -291,7 +292,14 @@ func (t *TritonClientService) ModelHTTPInfer(
timeout)
defer fasthttp.ReleaseResponse(modelInferResponse)

if modelInferResponse == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, errors.New("modelInferResponse is nil"))
}

if inferErr != nil || modelInferResponse.StatusCode() != fasthttp.StatusOK {
if inferErr == nil && modelInferResponse.Body() != nil {
inferErr = errors.New("Triton error resp: " + string(modelInferResponse.Body()))
}
return nil, t.httpErrorHandler(modelInferResponse.StatusCode(), inferErr)
}
// decode Result.
Expand Down Expand Up @@ -341,6 +349,9 @@ func (t *TritonClientService) CheckServerAlive(timeout time.Duration) (bool, err
}
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForServerIsLive, nil, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return false, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return false, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -362,6 +373,9 @@ func (t *TritonClientService) CheckServerReady(timeout time.Duration) (bool, err
}
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForServerIsReady, nil, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return false, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return false, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -386,6 +400,9 @@ func (t *TritonClientService) CheckModelReady(modelName, modelVersion string, ti
t.getServerURL()+TritonAPIForModelPrefix+modelName+TritonAPIForModelVersionPrefix+modelVersion+"/ready",
nil, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return false, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return false, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -404,6 +421,9 @@ func (t *TritonClientService) ServerMetadata(timeout time.Duration) (*ServerMeta
}
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIPrefix, nil, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand Down Expand Up @@ -431,6 +451,9 @@ func (t *TritonClientService) ModelMetadataRequest(
t.getServerURL()+TritonAPIForModelPrefix+modelName+TritonAPIForModelVersionPrefix+modelVersion,
nil, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand Down Expand Up @@ -460,6 +483,9 @@ func (t *TritonClientService) ModelIndex(
}
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForRepoIndex, reqBody, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -486,6 +512,9 @@ func (t *TritonClientService) ModelConfiguration(
t.getServerURL()+TritonAPIForModelPrefix+modelName+
TritonAPIForModelVersionPrefix+modelVersion+"/config", timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -512,6 +541,9 @@ func (t *TritonClientService) ModelInferStats(
t.getServerURL()+TritonAPIForModelPrefix+modelName+TritonAPIForModelVersionPrefix+modelVersion+"/stats",
timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -532,6 +564,9 @@ func (t *TritonClientService) ModelLoadWithHTTP(
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
t.getServerURL()+TritonAPIForRepoModelPrefix+modelName+"/load", modelConfigBody, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand Down Expand Up @@ -564,6 +599,10 @@ func (t *TritonClientService) ModelUnloadWithHTTP(
) (*RepositoryModelUnloadResponse, error) {
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
t.getServerURL()+TritonAPIForRepoModelPrefix+modelName+"/unload", modelConfigBody, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand Down Expand Up @@ -624,6 +663,9 @@ func (t *TritonClientService) ShareMemoryStatus(
}
apiResp, httpErr := t.makeHTTPGetRequestWithDoTimeout(uri, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return false, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand Down Expand Up @@ -669,6 +711,9 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
t.getServerURL()+TritonAPIForCudaMemoryRegionPrefix+regionName+"/register", reqBody, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -695,6 +740,9 @@ func (t *TritonClientService) ShareCUDAMemoryUnRegister(
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
t.getServerURL()+TritonAPIForCudaMemoryRegionPrefix+regionName+"/unregister", nil, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand Down Expand Up @@ -732,6 +780,9 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
t.getServerURL()+TritonAPIForSystemMemoryRegionPrefix+regionName+"/register", reqBody, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -758,6 +809,9 @@ func (t *TritonClientService) ShareSystemMemoryUnRegister(
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
t.getServerURL()+TritonAPIForSystemMemoryRegionPrefix+regionName+"/unregister", nil, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand All @@ -784,6 +838,9 @@ func (t *TritonClientService) GetModelTracingSetting(
apiResp, httpErr := t.makeHTTPGetRequestWithDoTimeout(
t.getServerURL()+TritonAPIForModelPrefix+modelName+"/trace/setting", timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand Down Expand Up @@ -816,6 +873,9 @@ func (t *TritonClientService) SetModelTracingSetting(
apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
t.getServerURL()+TritonAPIForModelPrefix+modelName+"/trace/setting", reqBody, timeout)
defer fasthttp.ReleaseResponse(apiResp)
if apiResp == nil {
return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
}
if httpErr != nil || apiResp.StatusCode() != fasthttp.StatusOK {
return nil, t.httpErrorHandler(apiResp.StatusCode(), httpErr)
}
Expand Down
22 changes: 22 additions & 0 deletions test/slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,25 @@ func TestGenerateRange(t *testing.T) {
t.Errorf("Test case 1 failed. Expected %v, got %v", expected1, result1)
}
}

// TestRemoveOldestElements checks that utils.StringSliceTruncatePrecisely
// trims a nested string slice from the tail: trailing sub-slices are dropped
// whole, and the last surviving sub-slice is shortened so that exactly the
// requested number of elements remains.
func TestRemoveOldestElements(t *testing.T) {
	// 26 elements in total; keeping 12 means the 4th group is cut to 3 items
	// and every group after it disappears.
	const limit = 12

	nested := [][]string{
		{"a", "b", "c"},
		{"d", "e", "f", "g"},
		{"h", "i"},
		{"j", "k", "l", "m", "n"},
		{"o", "p", "q", "r", "s", "t", "u", "v", "w"},
		{"x", "y", "z"},
	}
	want := [][]string{
		{"a", "b", "c"},
		{"d", "e", "f", "g"},
		{"h", "i"},
		{"j", "k", "l"},
	}

	got := utils.StringSliceTruncatePrecisely(nested, limit)
	if reflect.DeepEqual(got, want) {
		return
	}
	t.Errorf("Test case 1 failed. Expected %v, got %v", want, got)
}
1 change: 1 addition & 0 deletions utils/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
)

var (
ErrApiRespNil = errors.New("apiResp is nil") // empty http response body.
ErrEmptyVocab = errors.New("empty vocab") // empty vocab error.
ErrEmptyCallbackFunc = errors.New("callback function is nil") // empty callback function.
ErrEmptyHTTPRequestBody = errors.New("http request body is nil") // empty http request body.
Expand Down
40 changes: 40 additions & 0 deletions utils/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,43 @@ func GenerateRange[T IntNumeric](start, end int) []T {

return result
}

// StringSliceTruncatePrecisely truncates a two-dimensional string slice so
// that the total number of inner elements does not exceed maxLen. Truncation
// is controlled at sub-element granularity (more precise than
// StringSliceTruncate): elements are removed from the end, dropping whole
// trailing sub-slices first and then shortening the last remaining sub-slice
// when only part of it has to go.
//
// A negative maxLen is treated as 0, so the result is empty instead of the
// function running off the front of the slice. When the input already fits,
// it is returned unchanged. The caller's outer slice is never modified; the
// surviving inner slices share their backing arrays with the input.
func StringSliceTruncatePrecisely(slices [][]string, maxLen int) [][]string {
	// Callers may derive maxLen from configuration (e.g. "max - 2"); a
	// negative budget simply means nothing can be kept.
	if maxLen < 0 {
		maxLen = 0
	}

	// Count the total number of inner elements.
	totalLen := 0
	for _, slice := range slices {
		totalLen += len(slice)
	}

	// Early return: nothing to remove.
	if totalLen <= maxLen {
		return slices
	}

	// Walk backwards from the end, consuming the surplus.
	removeCount := totalLen - maxLen
	end := len(slices)
	for removeCount > 0 && end > 0 {
		last := slices[end-1]
		if len(last) > removeCount {
			// Only part of this sub-slice must go. Copy the outer slice
			// before replacing the element so the caller's data stays intact.
			kept := make([][]string, end)
			copy(kept, slices[:end])
			kept[end-1] = last[:len(last)-removeCount]
			return kept
		}
		// The whole sub-slice fits inside the surplus: drop it entirely.
		removeCount -= len(last)
		end--
	}

	// Cap the capacity so a later append on the result cannot overwrite the
	// dropped tail of the caller's slice.
	return slices[:end:end]
}

0 comments on commit 4dc88a0

Please sign in to comment.