Skip to content
Snippets Groups Projects
Commit 9dda38f2 authored by Travis Ralston's avatar Travis Ralston
Browse files

Update tasks API

parent dcb55ec7
No related branches found
No related tags found
No related merge requests found
...@@ -5,22 +5,22 @@ import ( ...@@ -5,22 +5,22 @@ import (
"github.com/turt2live/matrix-media-repo/api/_apimeta" "github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses" "github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers" "github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/database"
"net/http" "net/http"
"strconv" "strconv"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/storage"
) )
type TaskStatus struct { type TaskStatus struct {
TaskID int `json:"task_id"` TaskID int `json:"task_id"`
Name string `json:"task_name"` Name string `json:"task_name"`
Params map[string]interface{} `json:"params"` Params *database.AnonymousJson `json:"params"`
StartTs int64 `json:"start_ts"` StartTs int64 `json:"start_ts"`
EndTs int64 `json:"end_ts"` EndTs int64 `json:"end_ts"`
IsFinished bool `json:"is_finished"` IsFinished bool `json:"is_finished"`
} }
func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
...@@ -35,17 +35,20 @@ func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserIn ...@@ -35,17 +35,20 @@ func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserIn
"taskId": taskId, "taskId": taskId,
}) })
db := storage.GetDatabase().GetMetadataStore(rctx) db := database.GetInstance().Tasks.Prepare(rctx)
task, err := db.GetBackgroundTask(taskId) task, err := db.Get(taskId)
if err != nil { if err != nil {
rctx.Log.Error(err) rctx.Log.Error(err)
sentry.CaptureException(err) sentry.CaptureException(err)
return _responses.InternalServerError("failed to get task information") return _responses.InternalServerError("failed to get task information")
} }
if task == nil {
return _responses.NotFoundError()
}
return &_responses.DoNotCacheResponse{Payload: &TaskStatus{ return &_responses.DoNotCacheResponse{Payload: &TaskStatus{
TaskID: task.ID, TaskID: task.TaskId,
Name: task.Name, Name: task.Name,
Params: task.Params, Params: task.Params,
StartTs: task.StartTs, StartTs: task.StartTs,
...@@ -55,9 +58,9 @@ func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserIn ...@@ -55,9 +58,9 @@ func GetTask(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserIn
} }
func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
db := storage.GetDatabase().GetMetadataStore(rctx) db := database.GetInstance().Tasks.Prepare(rctx)
tasks, err := db.GetAllBackgroundTasks() tasks, err := db.GetAll(true)
if err != nil { if err != nil {
logrus.Error(err) logrus.Error(err)
sentry.CaptureException(err) sentry.CaptureException(err)
...@@ -67,7 +70,7 @@ func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.U ...@@ -67,7 +70,7 @@ func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.U
statusObjs := make([]*TaskStatus, 0) statusObjs := make([]*TaskStatus, 0)
for _, task := range tasks { for _, task := range tasks {
statusObjs = append(statusObjs, &TaskStatus{ statusObjs = append(statusObjs, &TaskStatus{
TaskID: task.ID, TaskID: task.TaskId,
Name: task.Name, Name: task.Name,
Params: task.Params, Params: task.Params,
StartTs: task.StartTs, StartTs: task.StartTs,
...@@ -80,9 +83,9 @@ func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.U ...@@ -80,9 +83,9 @@ func ListAllTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.U
} }
func ListUnfinishedTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { func ListUnfinishedTasks(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
db := storage.GetDatabase().GetMetadataStore(rctx) db := database.GetInstance().Tasks.Prepare(rctx)
tasks, err := db.GetAllBackgroundTasks() tasks, err := db.GetAll(false)
if err != nil { if err != nil {
logrus.Error(err) logrus.Error(err)
sentry.CaptureException(err) sentry.CaptureException(err)
...@@ -91,11 +94,8 @@ func ListUnfinishedTasks(r *http.Request, rctx rcontext.RequestContext, user _ap ...@@ -91,11 +94,8 @@ func ListUnfinishedTasks(r *http.Request, rctx rcontext.RequestContext, user _ap
statusObjs := make([]*TaskStatus, 0) statusObjs := make([]*TaskStatus, 0)
for _, task := range tasks { for _, task := range tasks {
if task.EndTs > 0 {
continue
}
statusObjs = append(statusObjs, &TaskStatus{ statusObjs = append(statusObjs, &TaskStatus{
TaskID: task.ID, TaskID: task.TaskId,
Name: task.Name, Name: task.Name,
Params: task.Params, Params: task.Params,
StartTs: task.StartTs, StartTs: task.StartTs,
......
...@@ -17,10 +17,14 @@ type DbTask struct { ...@@ -17,10 +17,14 @@ type DbTask struct {
const selectTask = "SELECT id, task, params, start_ts, end_ts FROM background_tasks WHERE id = $1;" const selectTask = "SELECT id, task, params, start_ts, end_ts FROM background_tasks WHERE id = $1;"
const insertTask = "INSERT INTO background_tasks (task, params, start_ts, end_ts) VALUES ($1, $2, $3, 0) RETURNING id, task, params, start_ts, end_ts;" const insertTask = "INSERT INTO background_tasks (task, params, start_ts, end_ts) VALUES ($1, $2, $3, 0) RETURNING id, task, params, start_ts, end_ts;"
const selectAllTasks = "SELECT id, task, params, start_ts, end_ts FROM background_tasks;"
const selectIncompleteTasks = "SELECT id, task, params, start_ts, end_ts FROM background_tasks WHERE end_ts <= 0;"
type tasksTableStatements struct { type tasksTableStatements struct {
selectTask *sql.Stmt selectTask *sql.Stmt
insertTask *sql.Stmt insertTask *sql.Stmt
selectAllTasks *sql.Stmt
selectIncompleteTasks *sql.Stmt
} }
type tasksTableWithContext struct { type tasksTableWithContext struct {
...@@ -38,6 +42,12 @@ func prepareTasksTables(db *sql.DB) (*tasksTableStatements, error) { ...@@ -38,6 +42,12 @@ func prepareTasksTables(db *sql.DB) (*tasksTableStatements, error) {
if stmts.insertTask, err = db.Prepare(insertTask); err != nil { if stmts.insertTask, err = db.Prepare(insertTask); err != nil {
return nil, errors.New("error preparing insertTask: " + err.Error()) return nil, errors.New("error preparing insertTask: " + err.Error())
} }
if stmts.selectAllTasks, err = db.Prepare(selectAllTasks); err != nil {
return nil, errors.New("error preparing selectAllTasks: " + err.Error())
}
if stmts.selectIncompleteTasks, err = db.Prepare(selectIncompleteTasks); err != nil {
return nil, errors.New("error preparing selectIncompleteTasks: " + err.Error())
}
return stmts, nil return stmts, nil
} }
...@@ -69,3 +79,26 @@ func (s *tasksTableWithContext) Get(id int) (*DbTask, error) { ...@@ -69,3 +79,26 @@ func (s *tasksTableWithContext) Get(id int) (*DbTask, error) {
} }
return val, err return val, err
} }
func (s *tasksTableWithContext) GetAll(includingFinished bool) ([]*DbTask, error) {
results := make([]*DbTask, 0)
q := s.statements.selectAllTasks
if !includingFinished {
q = s.statements.selectIncompleteTasks
}
rows, err := q.QueryContext(s.ctx)
if err != nil {
if err == sql.ErrNoRows {
return results, nil
}
return nil, err
}
for rows.Next() {
val := &DbTask{}
if err = rows.Scan(&val.TaskId, &val.Name, &val.Params, &val.StartTs, &val.EndTs); err != nil {
return nil, err
}
results = append(results, val)
}
return results, nil
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment