Skip to content

Commit

Permalink
[RR-53] handle errors and save to log table
Browse files Browse the repository at this point in the history
  • Loading branch information
tmwclaxton committed Jan 4, 2024
1 parent 9985aa9 commit b97bb3f
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 44 deletions.
163 changes: 119 additions & 44 deletions internal/dispatcher/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,46 +45,123 @@ func Worker(id int, messageQueue <-chan *sqs.Message, svc *sqs.SQS, sqsURL, s3Bu
for {
pass := true

if totalRequests < gracePeriodRequests {
// if worker id is greater than the allowed workers then return
if id > allowedWorkers {
if totalRequests > gracePeriodRequests {
// stop killing grobid >:|
minGapBetweenRequests = 1 * time.Second
}
// if worker id is greater than the allowed workers then return
if id > allowedWorkers {
pass = false
}
if pass {
// Acquire a semaphore before accessing
if err := grobidSemaphore.Acquire(context.Background(), 1); err != nil {
log.Printf("Worker %d could not acquire semaphore: %v\n", id, err)
pass = false
}
if pass {
// Acquire a semaphore before accessing
if err := grobidSemaphore.Acquire(context.Background(), 1); err != nil {
log.Printf("Worker %d could not acquire semaphore: %v\n", id, err)
pass = false
}
lastRequestTimeMu.Lock()
timeSinceLastRequest := time.Since(lastRequestTime)
lastRequestTimeMu.Unlock()
//log.Printf("Worker %d acquired semaphore\n", id)

// If the time since the last request is less than the minimum gap between requests, sleep for the difference
if timeSinceLastRequest < minGapBetweenRequests {
sleepTime := minGapBetweenRequests - timeSinceLastRequest
//log.Printf("Worker %d sleeping for %v to meet the minimum gap between requests\n", id, sleepTime)
time.Sleep(sleepTime)
}
lastRequestTimeMu.Lock()
lastRequestTime = time.Now()
lastRequestTimeMu.Unlock()
grobidSemaphore.Release(1) // Release the semaphore when the function exits

//log.Printf("Worker %d releasing semaphore\n", id)
lastRequestTimeMu.Lock()
timeSinceLastRequest := time.Since(lastRequestTime)
lastRequestTimeMu.Unlock()
//log.Printf("Worker %d acquired semaphore\n", id)

// If the time since the last request is less than the minimum gap between requests, sleep for the difference
if timeSinceLastRequest < minGapBetweenRequests {
sleepTime := minGapBetweenRequests - timeSinceLastRequest
//log.Printf("Worker %d sleeping for %v to meet the minimum gap between requests\n", id, sleepTime)
time.Sleep(sleepTime)
}
lastRequestTimeMu.Lock()
lastRequestTime = time.Now()
lastRequestTimeMu.Unlock()
grobidSemaphore.Release(1) // Release the semaphore when the function exits

//log.Printf("Worker %d releasing semaphore\n", id)
}

if pass {

message := <-messageQueue
processMessage(id, message, svc, sqsURL, s3Bucket, awsRegion, s, cacheSvc)
err := processMessage(id, message, svc, sqsURL, s3Bucket, awsRegion, s, cacheSvc)
if err != nil {
logging.ErrorLogger.Println(err)
err := handleFail(s, cacheSvc, message, svc, sqsURL, err)
if err != nil {
logging.ErrorLogger.Println(err)
}
}
}
time.Sleep(1 * time.Second)
}
}

// handleFail records a processing failure for an SQS message and re-queues the
// message flagged as already handled.
//
// For a message whose JSON body does not yet carry "decrement": true it:
//  1. saves an error entry (built from err and the message's user_id,
//     screen_id and s3Location fields) through s.SaveLog,
//  2. decrements the per-screen "papers_processing" cache counter,
//  3. deletes the original message from the queue, and
//  4. sends a copy of the body with "decrement": true back onto the queue.
//
// A message that already has "decrement": true is left untouched and nil is
// returned, so the log entry and the cache decrement happen at most once per
// flagged message.
//
// err is the processing error that caused the failure; it is stored as the
// log entry's FullLog. The returned error describes the first step that
// failed, or is nil on success.
func handleFail(s *store.Store, cacheSvc *helpers.CacheHelper, message *sqs.Message, sqsSvc *sqs.SQS, sqsURL string, err error) error {
	if message == nil {
		return fmt.Errorf("handleFail: nil message")
	}

	// aws.StringValue yields "" for a nil pointer instead of panicking.
	messageID := aws.StringValue(message.MessageId)
	logging.ErrorLogger.Printf("HANDLING FAILED MESSAGE: %s\n", messageID)

	var msgData map[string]interface{}
	if decodeErr := json.Unmarshal([]byte(aws.StringValue(message.Body)), &msgData); decodeErr != nil {
		return fmt.Errorf("decoding failed message %s: %w", messageID, decodeErr)
	}

	// Only a literal JSON true counts as "already decremented"; a missing
	// field or any other value means the failure still has to be recorded.
	if done, _ := msgData["decrement"].(bool); done {
		return nil
	}
	logging.ErrorLogger.Printf("Message has not been decremented yet, id: %s\n", messageID)

	userID, fieldErr := failedMessageInt64(msgData, "user_id")
	if fieldErr != nil {
		return fieldErr
	}
	// screenID selects the cache key below, so a bad value must not silently
	// become 0 and decrement the wrong counter.
	screenID, fieldErr := failedMessageInt64(msgData, "screen_id")
	if fieldErr != nil {
		return fieldErr
	}

	// s3Location is only used in the user-facing text, so a missing value
	// degrades to an empty string rather than aborting the failure handling.
	s3Location, _ := msgData["s3Location"].(string)

	fullLog := ""
	if err != nil {
		fullLog = err.Error()
	}

	// save log to db
	logEntry := store.Log{
		Level:       "error",
		UserMessage: fmt.Sprintf("Error processing file: %s", s3Location),
		FullLog:     fullLog,
		Stage:       "paper_processing",
		UserID:      userID,
		ScreenID:    screenID,
	}
	if saveErr := s.SaveLog(logEntry); saveErr != nil {
		return saveErr
	}

	// decrement the cache with the screen id
	key := fmt.Sprintf("rapidresearch_cache_:screen:%d:papers_processing", screenID)
	if cacheErr := cacheSvc.DecrOrDeleteCache(key); cacheErr != nil {
		return cacheErr
	}

	// edit message to include decrement field and re-encode it
	msgData["decrement"] = true
	msgJSON, marshalErr := json.Marshal(msgData)
	if marshalErr != nil {
		return marshalErr
	}

	// delete message from queue; stop here on failure so the message is not
	// duplicated by the send below.
	// NOTE(review): deleting before sending means a failed send loses the
	// message — consider sending first if at-least-once delivery is preferred.
	if _, deleteErr := sqsSvc.DeleteMessage(&sqs.DeleteMessageInput{
		QueueUrl:      aws.String(sqsURL),
		ReceiptHandle: message.ReceiptHandle,
	}); deleteErr != nil {
		return fmt.Errorf("deleting failed message %s: %w", messageID, deleteErr)
	}

	// send message onto queue
	if _, sendErr := sqsSvc.SendMessage(&sqs.SendMessageInput{
		MessageBody: aws.String(string(msgJSON)),
		QueueUrl:    aws.String(sqsURL),
	}); sendErr != nil {
		return fmt.Errorf("re-queueing failed message %s: %w", messageID, sendErr)
	}

	return nil
}

// failedMessageInt64 reads msgData[key] as a base-10 string and parses it into an int64.
func failedMessageInt64(msgData map[string]interface{}, key string) (int64, error) {
	raw, ok := msgData[key].(string)
	if !ok {
		return 0, fmt.Errorf("message field %q is missing or not a string", key)
	}
	value, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("parsing message field %q: %w", key, err)
	}
	return value, nil
}

func createAWSSession(region string) *session.Session {
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(region),
Expand Down Expand Up @@ -114,31 +191,31 @@ func downloadFileFromS3(s3Svc *s3.S3, bucket, path string) ([]byte, error) {
return fileContent, nil
}

func processMessage(id int, message *sqs.Message, svc *sqs.SQS, sqsURL, s3Bucket, awsRegion string, s *store.Store, cacheSvc *helpers.CacheHelper) {
func processMessage(id int, message *sqs.Message, svc *sqs.SQS, sqsURL, s3Bucket, awsRegion string, s *store.Store, cacheSvc *helpers.CacheHelper) error {
defer func() {
totalRequests++
log.Printf("Total requests: %d\n", totalRequests)
}()
var msgData map[string]interface{}
if err := json.Unmarshal([]byte(*message.Body), &msgData); err != nil {
log.Println("Error decoding JSON message:", err)
return
return err
}

// check if message has all the required fields if not return error
if _, ok := msgData["s3Location"]; !ok {
log.Println("Message missing s3Location field")
return
return nil
}

if _, ok := msgData["user_id"]; !ok {
log.Println("Message missing user_id field")
return
return nil
}

if _, ok := msgData["screen_id"]; !ok {
log.Println("Message missing screen_id field")
return
return nil
}

path := msgData["s3Location"].(string)
Expand All @@ -156,7 +233,7 @@ func processMessage(id int, message *sqs.Message, svc *sqs.SQS, sqsURL, s3Bucket
if err != nil {
log.Println("Error downloading file from S3:", err)
log.Printf("Bucket: %s, Key: %s\n", s3Bucket, path)
return
return err
}

CrudeGrobidResponse, err := parsing.SendPDF2Grobid(fileContent)
Expand All @@ -169,14 +246,14 @@ func processMessage(id int, message *sqs.Message, svc *sqs.SQS, sqsURL, s3Bucket
os.Exit(1)
}

return
return err
}

// clean up grobid response
tidyGrobidResponse, err := parsing.TidyUpGrobidResponse(CrudeGrobidResponse)
if err != nil {
log.Println("Error tidying up Grobid response:", err)
return
return err
}

crossRefResponse := &parsing.TidyCrossRefResponse{}
Expand Down Expand Up @@ -237,7 +314,7 @@ func processMessage(id int, message *sqs.Message, svc *sqs.SQS, sqsURL, s3Bucket
paper, err = s.CreatePaper(pdfDTO, userID, screenID)
if err != nil {
logging.ErrorLogger.Println(err)
return
return err
} else {
logging.InfoLogger.Printf("Created paper: %v\n", paper.ID)
}
Expand Down Expand Up @@ -294,20 +371,17 @@ func processMessage(id int, message *sqs.Message, svc *sqs.SQS, sqsURL, s3Bucket
log.Printf("Sections iterated: %d\n", len(sections))

key := fmt.Sprintf("rapidresearch_cache_:screen:%d:papers_processing", screenID)
log.Printf("Key: %s\n", key)
// print current cache value
cacheValue, err := cacheSvc.GetCacheValue(key)
// print cache value
val, err := cacheSvc.GetCacheValue(key)
if err != nil {
logging.ErrorLogger.Println(err)
return
return err
}
log.Printf("Current cache value: %s\n", cacheValue)
log.Printf("Cache value: %s\n", val)

// decrement the cache with the screen id
err = cacheSvc.DecrOrDeleteCache(key)
if err != nil {
logging.ErrorLogger.Println(err)
return
return err
}

if helpers.GetEnvVariable("REQUEUE_REQUESTS") == "true" {
Expand Down Expand Up @@ -339,4 +413,5 @@ func processMessage(id int, message *sqs.Message, svc *sqs.SQS, sqsURL, s3Bucket
}

log.Printf("Worker %d finished processing message\n", id)
return nil
}
1 change: 1 addition & 0 deletions internal/dispatcher/worker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package dispatcher
30 changes: 30 additions & 0 deletions internal/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ type Screen struct {
UpdatedAt string `json:"updated_at"`
}

// Log is one log record destined for the database; SaveLog writes each field
// to the column of the same (snake_case) name.
type Log struct {
	// Level, UserMessage, FullLog and Stage are the textual parts of the
	// record, in the order they are inserted.
	Level, UserMessage, FullLog, Stage string

	// UserID and ScreenID are the numeric identifiers stored alongside the
	// text.
	UserID, ScreenID int64
}

// New creates a new Store instance
func New(db *sql.DB) *Store {
return &Store{db: db}
Expand Down Expand Up @@ -193,6 +203,11 @@ func (store *Store) CreatePaper(dto *parsing.PDFDTO, userID int64, screenID int6
logging.WarningLogger.Println("Title or Abstract empty - UserID: " + strconv.FormatInt(userID, 10) + ", ScreenID: " + strconv.FormatInt(screenID, 10) + ", Title: " + dto.Title + ", Abstract: " + dto.Abstract)
}

// if title and doi are empty, return error
if dto.Title == "" && dto.DOI == "" {
return Paper{}, errors.New("CreatePaper: missing required fields: userID: " + strconv.FormatInt(userID, 10) + ", screenID: " + strconv.FormatInt(screenID, 10) + ", title: " + dto.Title + ", abstract: " + dto.Abstract)
}

// create slug
slug := helpers.GenerateRandomString(14)

Expand Down Expand Up @@ -271,3 +286,18 @@ func (store *Store) FindSectionByHeaderAndText(paperID int64, header string, tex
//log.Printf("Section found: %v\n", section.ID)
return section, nil
}

// SaveLog inserts logEntry as a new row in the logs table, letting the
// database stamp created_at and updated_at with NOW(). A failed insert is
// written to the standard logger and the error is returned unchanged;
// success returns nil.
func (store *Store) SaveLog(logEntry Log) error {
	const insertLog = `
INSERT INTO logs (level, user_message, full_log, stage, user_id, screen_id, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, NOW(), NOW())`

	// Placeholder values, in the same order as the column list above.
	values := []interface{}{
		logEntry.Level,
		logEntry.UserMessage,
		logEntry.FullLog,
		logEntry.Stage,
		logEntry.UserID,
		logEntry.ScreenID,
	}

	if _, execErr := store.db.Exec(insertLog, values...); execErr != nil {
		log.Println("Error saving log entry:", execErr)
		return execErr
	}
	return nil
}
38 changes: 38 additions & 0 deletions internal/store/store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package store

import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
"simple-go-app/internal/logging"
"testing"
)

// TestStore_AddLog checks that SaveLog can insert a log entry.
//
// This is an integration test against a MySQL instance on localhost:3306
// (database "rapid_research", user "sail"). It is skipped when that database
// cannot be reached, and fails when the insert itself returns an error.
func TestStore_AddLog(t *testing.T) {
	dbHost := "localhost"
	dbPort := "3306"
	dbUser := "sail"
	dbPassword := "password"
	dbName := "rapid_research"

	db, err := sql.Open("mysql", dbUser+":"+dbPassword+"@tcp("+dbHost+":"+dbPort+")/"+dbName)
	if err != nil {
		logging.ErrorLogger.Println("Error opening database:", err)
		// Continuing with an unusable handle would only produce a confusing
		// failure later on.
		t.Fatalf("Error opening database: %v", err)
	}
	defer db.Close()

	// sql.Open only validates its arguments; Ping establishes a real
	// connection so an unavailable local database skips instead of failing.
	if pingErr := db.Ping(); pingErr != nil {
		t.Skipf("database not reachable at %s:%s, skipping: %v", dbHost, dbPort, pingErr)
	}

	s := New(db)

	err = s.SaveLog(Log{
		Level:       "info",
		UserMessage: "test",
		FullLog:     "test",
		Stage:       "test",
		UserID:      1,
		ScreenID:    1,
	})
	if err != nil {
		t.Errorf("Error adding log entry: %v", err)
	}
}

0 comments on commit b97bb3f

Please sign in to comment.