From 27f62d1d8a0817a53493ab37f08c0c2ebf83d674 Mon Sep 17 00:00:00 2001 From: Drew DeVault Date: Mon, 19 Jun 2023 10:47:15 +0200 Subject: [PATCH] Route build log requests through API So that we can add authorization --- api/graph/resolver.go | 27 +++++++++--- api/graph/schema.graphqls | 4 +- api/graph/schema.resolvers.go | 6 +-- api/server.go | 77 +++++++++++++++++++++++++++++++++-- buildsrht/blueprints/jobs.py | 16 +++++--- go.mod | 1 + 6 files changed, 110 insertions(+), 21 deletions(-) diff --git a/api/graph/resolver.go b/api/graph/resolver.go index 9e30065..632395b 100644 --- a/api/graph/resolver.go +++ b/api/graph/resolver.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net/http" + "git.sr.ht/~sircmpwn/core-go/config" "github.com/99designs/gqlgen/graphql" "git.sr.ht/~sircmpwn/builds.sr.ht/api/graph/model" @@ -15,8 +16,22 @@ import ( type Resolver struct{} -func FetchLogs(ctx context.Context, url string) (*model.Log, error) { - log := &model.Log{FullURL: url} +func FetchLogs(ctx context.Context, runner string, jobID int, taskName string) (*model.Log, error) { + conf := config.ForContext(ctx) + origin := config.GetOrigin(conf, "builds.sr.ht", true) + + var ( + externalURL string + internalURL string + ) + if taskName == "" { + externalURL = fmt.Sprintf("%s/query/log/%d/log", origin, jobID) + internalURL = fmt.Sprintf("http://%s/logs/%d/log", runner, jobID) + } else { + externalURL = fmt.Sprintf("%s/query/log/%d/%s/log", origin, jobID, taskName) + internalURL = fmt.Sprintf("http://%s/logs/%d/%s/log", runner, jobID, taskName) + } + log := &model.Log{FullURL: externalURL} // If the user hasn't requested the log body, stop here if graphql.GetFieldContext(ctx) != nil { @@ -32,10 +47,10 @@ func FetchLogs(ctx context.Context, url string) (*model.Log, error) { } } - // TODO: It might be possible/desirable to set up an API with the runners - // we can use to fetch logs in bulk, perhaps gzipped, and set up a loader - // for it. - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + // TODO: It might be possible/desirable to set up an API with the + // runners we can use to fetch logs in bulk, perhaps gzipped, and set + // up a loader for it. + req, err := http.NewRequestWithContext(ctx, "GET", internalURL, nil) if err != nil { return nil, err } diff --git a/api/graph/schema.graphqls b/api/graph/schema.graphqls index a5b968d..f7d62e8 100644 --- a/api/graph/schema.graphqls +++ b/api/graph/schema.graphqls @@ -144,8 +144,8 @@ type Log { "The most recently written 128 KiB of the build log." last128KiB: String! """ - The URL at which the full build log can be downloaded with a GET request - (text/plain). + The URL at which the full build log can be downloaded with an authenticated + GET request (text/plain). """ fullURL: String! } diff --git a/api/graph/schema.resolvers.go b/api/graph/schema.resolvers.go index 30cc17d..b28b0c6 100644 --- a/api/graph/schema.resolvers.go +++ b/api/graph/schema.resolvers.go @@ -124,8 +124,7 @@ func (r *jobResolver) Log(ctx context.Context, obj *model.Job) (*model.Log, erro if obj.Runner == nil { return nil, nil } - url := fmt.Sprintf("http://%s/logs/%d/log", *obj.Runner, obj.ID) - return FetchLogs(ctx, url) + return FetchLogs(ctx, *obj.Runner, obj.ID, "") } // Secrets is the resolver for the secrets field. @@ -926,8 +925,7 @@ func (r *taskResolver) Log(ctx context.Context, obj *model.Task) (*model.Log, er if obj.Runner == nil { return nil, nil } - url := fmt.Sprintf("http://%s/logs/%d/%s/log", *obj.Runner, obj.JobID, obj.Name) - return FetchLogs(ctx, url) + return FetchLogs(ctx, *obj.Runner, obj.JobID, obj.Name) } // Job is the resolver for the job field. diff --git a/api/server.go b/api/server.go index f93ebdb..c1c53cb 100644 --- a/api/server.go +++ b/api/server.go @@ -3,12 +3,17 @@ package main import ( "context" "fmt" + "io" + "log" + "net/http" + "strconv" "git.sr.ht/~sircmpwn/core-go/config" "git.sr.ht/~sircmpwn/core-go/server" "git.sr.ht/~sircmpwn/core-go/webhooks" work "git.sr.ht/~sircmpwn/dowork" "github.com/99designs/gqlgen/graphql" + "github.com/go-chi/chi" "git.sr.ht/~sircmpwn/builds.sr.ht/api/account" "git.sr.ht/~sircmpwn/builds.sr.ht/api/graph" @@ -42,7 +47,7 @@ func main() { accountQueue := work.NewQueue("account") webhookQueue := webhooks.NewQueue(schema) - server.NewServer("builds.sr.ht", appConfig). + srv := server.NewServer("builds.sr.ht", appConfig). WithDefaultMiddleware(). WithMiddleware( loaders.Middleware, @@ -53,6 +58,72 @@ func main() { WithQueues( accountQueue, webhookQueue.Queue, - ). - Run() + ) + + srv.Router().Head("/query/log/{job_id}/log", proxyLog) + srv.Router().Head("/query/log/{job_id}/{task_name}/log", proxyLog) + srv.Router().Get("/query/log/{job_id}/log", proxyLog) + srv.Router().Get("/query/log/{job_id}/{task_name}/log", proxyLog) + srv.Run() +} + +func proxyLog(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + jobId, err := strconv.Atoi(chi.URLParam(r, "job_id")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Invalid job ID\r\n")) + return + } + job, err := loaders.ForContext(ctx).JobsByID.Load(jobId) + if err != nil { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("Unknown build job\r\n")) + return + } + if job.Runner == nil { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("This build job has not been started yet\r\n")) + return + } + + var url string + taskName := chi.URLParam(r, "task_name") + if taskName == "" { + url = fmt.Sprintf("http://%s/logs/%d/log", *job.Runner, job.ID) + } else { + url = fmt.Sprintf("http://%s/logs/%d/%s/log", + *job.Runner, job.ID, taskName) + } + req, err := http.NewRequestWithContext(ctx, r.Method, url, nil) + if err != nil { + log.Printf("Error fetching logs: %s", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal server error\r\n")) + return + } + + rrange := r.Header.Get("Range") + if rrange != "" { + req.Header.Add("Range", rrange) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + w.WriteHeader(http.StatusBadGateway) + w.Write([]byte("Failed to retrieve build log\r\n")) + return + } + defer resp.Body.Close() + for key, val := range resp.Header { + for _, val := range val { + w.Header().Add(key, val) + } + } + w.WriteHeader(resp.StatusCode) + + _, err = io.Copy(w, resp.Body) + if err != nil { + log.Printf("Error forwarding log: %s", err.Error()) + } } diff --git a/buildsrht/blueprints/jobs.py b/buildsrht/blueprints/jobs.py index 12d3a62..d37c1a5 100644 --- a/buildsrht/blueprints/jobs.py +++ b/buildsrht/blueprints/jobs.py @@ -10,7 +10,8 @@ from flask import Response, url_for from markupsafe import Markup, escape from prometheus_client import Counter from srht.cache import get_cache, set_cache -from srht.config import cfg +from srht.config import cfg, get_origin +from srht.crypto import encrypt_request_authorization from srht.database import db from srht.flask import paginate_query, session from srht.oauth import current_user, loginrequired, UserType @@ -448,14 +449,17 @@ def job_by_id(username, job_id): if not log: metrics.buildsrht_logcache_miss.inc() try: - r = requests_session.head(log_url) + r = requests_session.head(log_url, + headers=encrypt_request_authorization()) cl = int(r.headers["Content-Length"]) if cl > log_max: r = requests_session.get(log_url, headers={ "Range": f"bytes={cl-log_max}-{cl-1}", + **encrypt_request_authorization(), }, timeout=3) else: - r = requests_session.get(log_url, timeout=3) + r = requests_session.get(log_url, timeout=3, + headers=encrypt_request_authorization()) if r.status_code >= 200 and r.status_code <= 299: log = { "name": name, @@ -477,13 +481,13 @@ def job_by_id(username, job_id): set_cache(cachekey, timedelta(days=2), json.dumps(log)) logs.append(log) return log["more"] - log_url = "http://{}/logs/{}/log".format(job.runner, job.id) + origin = get_origin("builds.sr.ht") + log_url = f"{origin}/query/log/{job.id}/log" if get_log(log_url, None, job.status): for task in sorted(job.tasks, key=lambda t: t.id): if task.status == TaskStatus.pending: continue - log_url = "http://{}/logs/{}/{}/log".format( - job.runner, job.id, task.name) + log_url = f"{origin}/query/log/{job.id}/{task.name}/log" if not get_log(log_url, task.name, task.status): break min_artifact_date = datetime.utcnow() - timedelta(days=90) diff --git a/go.mod b/go.mod index 8b19756..095eaed 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead // indirect github.com/emersion/go-smtp v0.16.0 // indirect github.com/fernet/fernet-go v0.0.0-20211208181803-9f70042a33ee // indirect + github.com/go-chi/chi v4.1.2+incompatible github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/gocelery/gocelery v0.0.0-20201111034804-825d89059344 github.com/google/uuid v1.3.0 -- 2.38.5