diff options
Diffstat (limited to 'modules/web')
-rw-r--r-- | modules/web/handler.go | 193 | ||||
-rw-r--r-- | modules/web/middleware/binding.go | 162 | ||||
-rw-r--r-- | modules/web/middleware/cookie.go | 85 | ||||
-rw-r--r-- | modules/web/middleware/data.go | 63 | ||||
-rw-r--r-- | modules/web/middleware/flash.go | 65 | ||||
-rw-r--r-- | modules/web/middleware/locale.go | 59 | ||||
-rw-r--r-- | modules/web/middleware/request.go | 14 | ||||
-rw-r--r-- | modules/web/route.go | 211 | ||||
-rw-r--r-- | modules/web/route_test.go | 179 | ||||
-rw-r--r-- | modules/web/routemock.go | 61 | ||||
-rw-r--r-- | modules/web/routemock_test.go | 71 | ||||
-rw-r--r-- | modules/web/routing/context.go | 49 | ||||
-rw-r--r-- | modules/web/routing/funcinfo.go | 172 | ||||
-rw-r--r-- | modules/web/routing/funcinfo_test.go | 80 | ||||
-rw-r--r-- | modules/web/routing/logger.go | 109 | ||||
-rw-r--r-- | modules/web/routing/logger_manager.go | 124 | ||||
-rw-r--r-- | modules/web/routing/requestrecord.go | 28 | ||||
-rw-r--r-- | modules/web/types/response.go | 10 |
18 files changed, 1735 insertions, 0 deletions
diff --git a/modules/web/handler.go b/modules/web/handler.go new file mode 100644 index 00000000..728cc5a1 --- /dev/null +++ b/modules/web/handler.go @@ -0,0 +1,193 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package web + +import ( + goctx "context" + "fmt" + "net/http" + "reflect" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/web/routing" + "code.gitea.io/gitea/modules/web/types" +) + +var responseStatusProviders = map[reflect.Type]func(req *http.Request) types.ResponseStatusProvider{} + +func RegisterResponseStatusProvider[T any](fn func(req *http.Request) types.ResponseStatusProvider) { + responseStatusProviders[reflect.TypeOf((*T)(nil)).Elem()] = fn +} + +// responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written +type responseWriter struct { + respWriter http.ResponseWriter + status int +} + +var _ types.ResponseStatusProvider = (*responseWriter)(nil) + +func (r *responseWriter) WrittenStatus() int { + return r.status +} + +func (r *responseWriter) Header() http.Header { + return r.respWriter.Header() +} + +func (r *responseWriter) Write(bytes []byte) (int, error) { + if r.status == 0 { + r.status = http.StatusOK + } + return r.respWriter.Write(bytes) +} + +func (r *responseWriter) WriteHeader(statusCode int) { + r.status = statusCode + r.respWriter.WriteHeader(statusCode) +} + +var ( + httpReqType = reflect.TypeOf((*http.Request)(nil)) + respWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem() + cancelFuncType = reflect.TypeOf((*goctx.CancelFunc)(nil)).Elem() +) + +// preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup +func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) { + hasStatusProvider := false + for _, argIn := range argsIn { + if _, hasStatusProvider = argIn.Interface().(types.ResponseStatusProvider); hasStatusProvider { + break + } + } + if !hasStatusProvider { + panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type())) + } + if fn.Type().NumOut() != 0 && fn.Type().NumIn() != 1 { + panic(fmt.Sprintf("handler should have no return value or only one argument, but got %s", fn.Type())) + } + if fn.Type().NumOut() == 1 && fn.Type().Out(0) != cancelFuncType { + panic(fmt.Sprintf("handler should return a cancel function, but got %s", fn.Type())) + } +} + +func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value { + defer func() { + if err := recover(); err != nil { + log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err) + panic(err) + } + }() + isPreCheck := req == nil + + argsIn := make([]reflect.Value, fn.Type().NumIn()) + for i := 0; i < fn.Type().NumIn(); i++ { + argTyp := fn.Type().In(i) + switch argTyp { + case respWriterType: + argsIn[i] = reflect.ValueOf(resp) + case httpReqType: + argsIn[i] = reflect.ValueOf(req) + default: + if argFn, ok := responseStatusProviders[argTyp]; ok { + if isPreCheck { + argsIn[i] = reflect.ValueOf(&responseWriter{}) + } else { + argsIn[i] = reflect.ValueOf(argFn(req)) + } + } else { + panic(fmt.Sprintf("unsupported argument type: %s", argTyp)) + } + } + } + return argsIn +} + +func handleResponse(fn reflect.Value, ret []reflect.Value) goctx.CancelFunc { + if len(ret) == 1 { + if cancelFunc, ok := ret[0].Interface().(goctx.CancelFunc); ok { + return cancelFunc + } + panic(fmt.Sprintf("unsupported return type: %s", ret[0].Type())) + } else if len(ret) > 1 { + panic(fmt.Sprintf("unsupported return values: %s", fn.Type())) + } + return nil +} + +func hasResponseBeenWritten(argsIn []reflect.Value) bool { + for _, argIn := range argsIn { + if statusProvider, ok := argIn.Interface().(types.ResponseStatusProvider); ok { + if statusProvider.WrittenStatus() != 0 { + return true + } + } + } + return false +} + +// toHandlerProvider converts a handler to a handler provider +// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware +func toHandlerProvider(handler any) func(next http.Handler) http.Handler { + funcInfo := routing.GetFuncInfo(handler) + fn := reflect.ValueOf(handler) + if fn.Type().Kind() != reflect.Func { + panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type())) + } + + if hp, ok := handler.(func(next http.Handler) http.Handler); ok { + return func(next http.Handler) http.Handler { + h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + routing.UpdateFuncInfo(req.Context(), funcInfo) + h.ServeHTTP(resp, req) + }) + } + } + + if hp, ok := handler.(func(next http.Handler) http.HandlerFunc); ok { + return func(next http.Handler) http.Handler { + h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + routing.UpdateFuncInfo(req.Context(), funcInfo) + h.ServeHTTP(resp, req) + }) + } + } + + provider := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) { + // wrap the response writer to check whether the response has been written + resp := respOrig + if _, ok := resp.(types.ResponseStatusProvider); !ok { + resp = &responseWriter{respWriter: resp} + } + + // prepare the arguments for the handler and do pre-check + argsIn := prepareHandleArgsIn(resp, req, fn, funcInfo) + if req == nil { + preCheckHandler(fn, argsIn) + return // it's doing pre-check, just return + } + + routing.UpdateFuncInfo(req.Context(), funcInfo) + ret := fn.Call(argsIn) + + // handle the return value, and defer the cancel function if there is one + cancelFunc := handleResponse(fn, ret) + if cancelFunc != nil { + defer cancelFunc() + } + + // if the response has not been written, call the next handler + if next != nil && !hasResponseBeenWritten(argsIn) { + next.ServeHTTP(resp, req) + } + }) + } + + provider(nil).ServeHTTP(nil, nil) // do a pre-check to make sure all arguments and return values are supported + return provider +} diff --git a/modules/web/middleware/binding.go b/modules/web/middleware/binding.go new file mode 100644 index 00000000..8fa71a81 --- /dev/null +++ b/modules/web/middleware/binding.go @@ -0,0 +1,162 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package middleware + +import ( + "reflect" + "strings" + + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/translation" + "code.gitea.io/gitea/modules/util" + "code.gitea.io/gitea/modules/validation" + + "gitea.com/go-chi/binding" +) + +// Form form binding interface +type Form interface { + binding.Validator +} + +func init() { + binding.SetNameMapper(util.ToSnakeCase) +} + +// AssignForm assign form values back to the template data. +func AssignForm(form any, data map[string]any) { + typ := reflect.TypeOf(form) + val := reflect.ValueOf(form) + + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + val = val.Elem() + } + + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + fieldName := field.Tag.Get("form") + // Allow ignored fields in the struct + if fieldName == "-" { + continue + } else if len(fieldName) == 0 { + fieldName = util.ToSnakeCase(field.Name) + } + + data[fieldName] = val.Field(i).Interface() + } +} + +func getRuleBody(field reflect.StructField, prefix string) string { + for _, rule := range strings.Split(field.Tag.Get("binding"), ";") { + if strings.HasPrefix(rule, prefix) { + return rule[len(prefix) : len(rule)-1] + } + } + return "" +} + +// GetSize get size int form tag +func GetSize(field reflect.StructField) string { + return getRuleBody(field, "Size(") +} + +// GetMinSize get minimal size in form tag +func GetMinSize(field reflect.StructField) string { + return getRuleBody(field, "MinSize(") +} + +// GetMaxSize get max size in form tag +func GetMaxSize(field reflect.StructField) string { + return getRuleBody(field, "MaxSize(") +} + +// GetInclude get include in form tag +func GetInclude(field reflect.StructField) string { + return getRuleBody(field, "Include(") +} + +// Validate populates the data with validation error (if any). +func Validate(errs binding.Errors, data map[string]any, f any, l translation.Locale) binding.Errors { + if errs.Len() == 0 { + return errs + } + + data["HasError"] = true + // If the field with name errs[0].FieldNames[0] is not found in form + // somehow, some code later on will panic on Data["ErrorMsg"].(string). + // So initialize it to some default. + data["ErrorMsg"] = l.Tr("form.unknown_error") + AssignForm(f, data) + + typ := reflect.TypeOf(f) + + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + if field, ok := typ.FieldByName(errs[0].FieldNames[0]); ok { + fieldName := field.Tag.Get("form") + if fieldName != "-" { + data["Err_"+field.Name] = true + + trName := field.Tag.Get("locale") + if len(trName) == 0 { + trName = l.TrString("form." + field.Name) + } else { + trName = l.TrString(trName) + } + + switch errs[0].Classification { + case binding.ERR_REQUIRED: + data["ErrorMsg"] = trName + l.TrString("form.require_error") + case binding.ERR_ALPHA_DASH: + data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_error") + case binding.ERR_ALPHA_DASH_DOT: + data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_dot_error") + case validation.ErrGitRefName: + data["ErrorMsg"] = trName + l.TrString("form.git_ref_name_error") + case binding.ERR_SIZE: + data["ErrorMsg"] = trName + l.TrString("form.size_error", GetSize(field)) + case binding.ERR_MIN_SIZE: + data["ErrorMsg"] = trName + l.TrString("form.min_size_error", GetMinSize(field)) + case binding.ERR_MAX_SIZE: + data["ErrorMsg"] = trName + l.TrString("form.max_size_error", GetMaxSize(field)) + case binding.ERR_EMAIL: + data["ErrorMsg"] = trName + l.TrString("form.email_error") + case binding.ERR_URL: + data["ErrorMsg"] = trName + l.TrString("form.url_error", errs[0].Message) + case binding.ERR_INCLUDE: + data["ErrorMsg"] = trName + l.TrString("form.include_error", GetInclude(field)) + case validation.ErrGlobPattern: + data["ErrorMsg"] = trName + l.TrString("form.glob_pattern_error", errs[0].Message) + case validation.ErrRegexPattern: + data["ErrorMsg"] = trName + l.TrString("form.regex_pattern_error", errs[0].Message) + case validation.ErrUsername: + if setting.Service.AllowDotsInUsernames { + data["ErrorMsg"] = trName + l.TrString("form.username_error") + } else { + data["ErrorMsg"] = trName + l.TrString("form.username_error_no_dots") + } + case validation.ErrInvalidGroupTeamMap: + data["ErrorMsg"] = trName + l.TrString("form.invalid_group_team_map_error", errs[0].Message) + default: + msg := errs[0].Classification + if msg != "" && errs[0].Message != "" { + msg += ": " + } + + msg += errs[0].Message + if msg == "" { + msg = l.TrString("form.unknown_error") + } + data["ErrorMsg"] = trName + ": " + msg + } + return errs + } + } + return errs +} diff --git a/modules/web/middleware/cookie.go b/modules/web/middleware/cookie.go new file mode 100644 index 00000000..f2d25f5b --- /dev/null +++ b/modules/web/middleware/cookie.go @@ -0,0 +1,85 @@ +// Copyright 2020 The Macaron Authors +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package middleware + +import ( + "net/http" + "net/url" + "strings" + + "code.gitea.io/gitea/modules/session" + "code.gitea.io/gitea/modules/setting" +) + +// SetRedirectToCookie convenience function to set the RedirectTo cookie consistently +func SetRedirectToCookie(resp http.ResponseWriter, value string) { + SetSiteCookie(resp, "redirect_to", value, 0) +} + +// DeleteRedirectToCookie convenience function to delete most cookies consistently +func DeleteRedirectToCookie(resp http.ResponseWriter) { + SetSiteCookie(resp, "redirect_to", "", -1) +} + +// GetSiteCookie returns given cookie value from request header. +func GetSiteCookie(req *http.Request, name string) string { + cookie, err := req.Cookie(name) + if err != nil { + return "" + } + val, _ := url.QueryUnescape(cookie.Value) + return val +} + +// SetSiteCookie returns given cookie value from request header. +func SetSiteCookie(resp http.ResponseWriter, name, value string, maxAge int) { + // Previous versions would use a cookie path with a trailing /. + // These are more specific than cookies without a trailing /, so + // we need to delete these if they exist. + deleteLegacySiteCookie(resp, name) + cookie := &http.Cookie{ + Name: name, + Value: url.QueryEscape(value), + MaxAge: maxAge, + Path: setting.SessionConfig.CookiePath, + Domain: setting.SessionConfig.Domain, + Secure: setting.SessionConfig.Secure, + HttpOnly: true, + SameSite: setting.SessionConfig.SameSite, + } + resp.Header().Add("Set-Cookie", cookie.String()) +} + +// deleteLegacySiteCookie deletes the cookie with the given name at the cookie +// path with a trailing /, which would unintentionally override the cookie. +func deleteLegacySiteCookie(resp http.ResponseWriter, name string) { + if setting.SessionConfig.CookiePath == "" || strings.HasSuffix(setting.SessionConfig.CookiePath, "/") { + // If the cookie path ends with /, no legacy cookies will take + // precedence, so do nothing. The exception is that cookies with no + // path could override other cookies, but it's complicated and we don't + // currently handle that. + return + } + + cookie := &http.Cookie{ + Name: name, + Value: "", + MaxAge: -1, + Path: setting.SessionConfig.CookiePath + "/", + Domain: setting.SessionConfig.Domain, + Secure: setting.SessionConfig.Secure, + HttpOnly: true, + SameSite: setting.SessionConfig.SameSite, + } + resp.Header().Add("Set-Cookie", cookie.String()) +} + +func init() { + session.BeforeRegenerateSession = append(session.BeforeRegenerateSession, func(resp http.ResponseWriter, _ *http.Request) { + // Ensure that a cookie with a trailing slash does not take precedence over + // the cookie written by the middleware. + deleteLegacySiteCookie(resp, setting.SessionConfig.CookieName) + }) +} diff --git a/modules/web/middleware/data.go b/modules/web/middleware/data.go new file mode 100644 index 00000000..08d83f94 --- /dev/null +++ b/modules/web/middleware/data.go @@ -0,0 +1,63 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package middleware + +import ( + "context" + "time" + + "code.gitea.io/gitea/modules/setting" +) + +// ContextDataStore represents a data store +type ContextDataStore interface { + GetData() ContextData +} + +type ContextData map[string]any + +func (ds ContextData) GetData() ContextData { + return ds +} + +func (ds ContextData) MergeFrom(other ContextData) ContextData { + for k, v := range other { + ds[k] = v + } + return ds +} + +const ContextDataKeySignedUser = "SignedUser" + +type contextDataKeyType struct{} + +var contextDataKey contextDataKeyType + +func WithContextData(c context.Context) context.Context { + return context.WithValue(c, contextDataKey, make(ContextData, 10)) +} + +func GetContextData(c context.Context) ContextData { + if ds, ok := c.Value(contextDataKey).(ContextData); ok { + return ds + } + return nil +} + +func CommonTemplateContextData() ContextData { + return ContextData{ + "IsLandingPageOrganizations": setting.LandingPageURL == setting.LandingPageOrganizations, + + "ShowRegistrationButton": setting.Service.ShowRegistrationButton, + "ShowMilestonesDashboardPage": setting.Service.ShowMilestonesDashboardPage, + "ShowFooterVersion": setting.Other.ShowFooterVersion, + "DisableDownloadSourceArchives": setting.Repository.DisableDownloadSourceArchives, + + "EnableSwagger": setting.API.EnableSwagger, + "EnableOpenIDSignIn": setting.Service.EnableOpenIDSignIn, + "PageStartTime": time.Now(), + + "RunModeIsProd": setting.IsProd, + } +} diff --git a/modules/web/middleware/flash.go b/modules/web/middleware/flash.go new file mode 100644 index 00000000..88da2049 --- /dev/null +++ b/modules/web/middleware/flash.go @@ -0,0 +1,65 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package middleware + +import ( + "fmt" + "html/template" + "net/url" +) + +// Flash represents a one time data transfer between two requests. +type Flash struct { + DataStore ContextDataStore + url.Values + ErrorMsg, WarningMsg, InfoMsg, SuccessMsg string +} + +func (f *Flash) set(name, msg string, current ...bool) { + if f.Values == nil { + f.Values = make(map[string][]string) + } + showInCurrentPage := len(current) > 0 && current[0] + if showInCurrentPage { + // assign it to the context data, then the template can use ".Flash.XxxMsg" to render the message + f.DataStore.GetData()["Flash"] = f + } else { + // the message map will be saved into the cookie and be shown in next response (a new page response which decodes the cookie) + f.Set(name, msg) + } +} + +func flashMsgStringOrHTML(msg any) string { + switch v := msg.(type) { + case string: + return v + case template.HTML: + return string(v) + } + panic(fmt.Sprintf("unknown type: %T", msg)) +} + +// Error sets error message +func (f *Flash) Error(msg any, current ...bool) { + f.ErrorMsg = flashMsgStringOrHTML(msg) + f.set("error", f.ErrorMsg, current...) +} + +// Warning sets warning message +func (f *Flash) Warning(msg any, current ...bool) { + f.WarningMsg = flashMsgStringOrHTML(msg) + f.set("warning", f.WarningMsg, current...) +} + +// Info sets info message +func (f *Flash) Info(msg any, current ...bool) { + f.InfoMsg = flashMsgStringOrHTML(msg) + f.set("info", f.InfoMsg, current...) +} + +// Success sets success message +func (f *Flash) Success(msg any, current ...bool) { + f.SuccessMsg = flashMsgStringOrHTML(msg) + f.set("success", f.SuccessMsg, current...) +} diff --git a/modules/web/middleware/locale.go b/modules/web/middleware/locale.go new file mode 100644 index 00000000..34a16f04 --- /dev/null +++ b/modules/web/middleware/locale.go @@ -0,0 +1,59 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package middleware + +import ( + "net/http" + + "code.gitea.io/gitea/modules/translation" + "code.gitea.io/gitea/modules/translation/i18n" + + "golang.org/x/text/language" +) + +// Locale handle locale +func Locale(resp http.ResponseWriter, req *http.Request) translation.Locale { + // 1. Check URL arguments. + lang := req.URL.Query().Get("lang") + changeLang := lang != "" + + // 2. Get language information from cookies. + if len(lang) == 0 { + ck, _ := req.Cookie("lang") + if ck != nil { + lang = ck.Value + } + } + + // Check again in case someone changes the supported language list. + if lang != "" && !i18n.DefaultLocales.HasLang(lang) { + lang = "" + changeLang = false + } + + // 3. Get language information from 'Accept-Language'. + // The first element in the list is chosen to be the default language automatically. + if len(lang) == 0 { + tags, _, _ := language.ParseAcceptLanguage(req.Header.Get("Accept-Language")) + tag := translation.Match(tags...) + lang = tag.String() + } + + if changeLang { + SetLocaleCookie(resp, lang, 1<<31-1) + } + + return translation.NewLocale(lang) +} + +// SetLocaleCookie convenience function to set the locale cookie consistently +func SetLocaleCookie(resp http.ResponseWriter, lang string, maxAge int) { + SetSiteCookie(resp, "lang", lang, maxAge) +} + +// DeleteLocaleCookie convenience function to delete the locale cookie consistently +// Setting the lang cookie will trigger the middleware to reset the language to previous state. +func DeleteLocaleCookie(resp http.ResponseWriter) { + SetSiteCookie(resp, "lang", "", -1) +} diff --git a/modules/web/middleware/request.go b/modules/web/middleware/request.go new file mode 100644 index 00000000..0bb155df --- /dev/null +++ b/modules/web/middleware/request.go @@ -0,0 +1,14 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package middleware + +import ( + "net/http" + "strings" +) + +// IsAPIPath returns true if the specified URL is an API path +func IsAPIPath(req *http.Request) bool { + return strings.HasPrefix(req.URL.Path, "/api/") +} diff --git a/modules/web/route.go b/modules/web/route.go new file mode 100644 index 00000000..805fcb44 --- /dev/null +++ b/modules/web/route.go @@ -0,0 +1,211 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package web + +import ( + "net/http" + "strings" + + "code.gitea.io/gitea/modules/web/middleware" + + "gitea.com/go-chi/binding" + "github.com/go-chi/chi/v5" +) + +// Bind binding an obj to a handler's context data +func Bind[T any](_ T) http.HandlerFunc { + return func(resp http.ResponseWriter, req *http.Request) { + theObj := new(T) // create a new form obj for every request but not use obj directly + data := middleware.GetContextData(req.Context()) + binding.Bind(req, theObj) + SetForm(data, theObj) + middleware.AssignForm(theObj, data) + } +} + +// SetForm set the form object +func SetForm(dataStore middleware.ContextDataStore, obj any) { + dataStore.GetData()["__form"] = obj +} + +// GetForm returns the validate form information +func GetForm(dataStore middleware.ContextDataStore) any { + return dataStore.GetData()["__form"] +} + +// Route defines a route based on chi's router +type Route struct { + R chi.Router + curGroupPrefix string + curMiddlewares []any +} + +// NewRoute creates a new route +func NewRoute() *Route { + r := chi.NewRouter() + return &Route{R: r} +} + +// Use supports two middlewares +func (r *Route) Use(middlewares ...any) { + for _, m := range middlewares { + if m != nil { + r.R.Use(toHandlerProvider(m)) + } + } +} + +// Group mounts a sub-Router along a `pattern` string. +func (r *Route) Group(pattern string, fn func(), middlewares ...any) { + previousGroupPrefix := r.curGroupPrefix + previousMiddlewares := r.curMiddlewares + r.curGroupPrefix += pattern + r.curMiddlewares = append(r.curMiddlewares, middlewares...) + + fn() + + r.curGroupPrefix = previousGroupPrefix + r.curMiddlewares = previousMiddlewares +} + +func (r *Route) getPattern(pattern string) string { + newPattern := r.curGroupPrefix + pattern + if !strings.HasPrefix(newPattern, "/") { + newPattern = "/" + newPattern + } + if newPattern == "/" { + return newPattern + } + return strings.TrimSuffix(newPattern, "/") +} + +func (r *Route) wrapMiddlewareAndHandler(h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) { + handlerProviders := make([]func(http.Handler) http.Handler, 0, len(r.curMiddlewares)+len(h)+1) + for _, m := range r.curMiddlewares { + if m != nil { + handlerProviders = append(handlerProviders, toHandlerProvider(m)) + } + } + for _, m := range h { + if h != nil { + handlerProviders = append(handlerProviders, toHandlerProvider(m)) + } + } + middlewares := handlerProviders[:len(handlerProviders)-1] + handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP + mockPoint := RouteMockPoint(MockAfterMiddlewares) + if mockPoint != nil { + middlewares = append(middlewares, mockPoint) + } + return middlewares, handlerFunc +} + +// Methods adds the same handlers for multiple http "methods" (separated by ","). +// If any method is invalid, the lower level router will panic. +func (r *Route) Methods(methods, pattern string, h ...any) { + middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h) + fullPattern := r.getPattern(pattern) + if strings.Contains(methods, ",") { + methods := strings.Split(methods, ",") + for _, method := range methods { + r.R.With(middlewares...).Method(strings.TrimSpace(method), fullPattern, handlerFunc) + } + } else { + r.R.With(middlewares...).Method(methods, fullPattern, handlerFunc) + } +} + +// Mount attaches another Route along ./pattern/* +func (r *Route) Mount(pattern string, subR *Route) { + subR.Use(r.curMiddlewares...) + r.R.Mount(r.getPattern(pattern), subR.R) +} + +// Any delegate requests for all methods +func (r *Route) Any(pattern string, h ...any) { + middlewares, handlerFunc := r.wrapMiddlewareAndHandler(h) + r.R.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc) +} + +// Delete delegate delete method +func (r *Route) Delete(pattern string, h ...any) { + r.Methods("DELETE", pattern, h...) +} + +// Get delegate get method +func (r *Route) Get(pattern string, h ...any) { + r.Methods("GET", pattern, h...) +} + +// Head delegate head method +func (r *Route) Head(pattern string, h ...any) { + r.Methods("HEAD", pattern, h...) +} + +// Post delegate post method +func (r *Route) Post(pattern string, h ...any) { + r.Methods("POST", pattern, h...) +} + +// Put delegate put method +func (r *Route) Put(pattern string, h ...any) { + r.Methods("PUT", pattern, h...) +} + +// Patch delegate patch method +func (r *Route) Patch(pattern string, h ...any) { + r.Methods("PATCH", pattern, h...) +} + +// ServeHTTP implements http.Handler +func (r *Route) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.R.ServeHTTP(w, req) +} + +// NotFound defines a handler to respond whenever a route could not be found. +func (r *Route) NotFound(h http.HandlerFunc) { + r.R.NotFound(h) +} + +// Combo delegates requests to Combo +func (r *Route) Combo(pattern string, h ...any) *Combo { + return &Combo{r, pattern, h} +} + +// Combo represents a tiny group routes with same pattern +type Combo struct { + r *Route + pattern string + h []any +} + +// Get delegates Get method +func (c *Combo) Get(h ...any) *Combo { + c.r.Get(c.pattern, append(c.h, h...)...) + return c +} + +// Post delegates Post method +func (c *Combo) Post(h ...any) *Combo { + c.r.Post(c.pattern, append(c.h, h...)...) + return c +} + +// Delete delegates Delete method +func (c *Combo) Delete(h ...any) *Combo { + c.r.Delete(c.pattern, append(c.h, h...)...) + return c +} + +// Put delegates Put method +func (c *Combo) Put(h ...any) *Combo { + c.r.Put(c.pattern, append(c.h, h...)...) + return c +} + +// Patch delegates Patch method +func (c *Combo) Patch(h ...any) *Combo { + c.r.Patch(c.pattern, append(c.h, h...)...) + return c +} diff --git a/modules/web/route_test.go b/modules/web/route_test.go new file mode 100644 index 00000000..d8015d6e --- /dev/null +++ b/modules/web/route_test.go @@ -0,0 +1,179 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package web + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + chi "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRoute1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + r := NewRoute() + r.Get("/{username}/{reponame}/{type:issues|pulls}", func(resp http.ResponseWriter, req *http.Request) { + username := chi.URLParam(req, "username") + assert.EqualValues(t, "gitea", username) + reponame := chi.URLParam(req, "reponame") + assert.EqualValues(t, "gitea", reponame) + tp := chi.URLParam(req, "type") + assert.EqualValues(t, "issues", tp) + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) +} + +func TestRoute2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + hit := -1 + + r := NewRoute() + r.Group("/{username}/{reponame}", func() { + r.Group("", func() { + r.Get("/{type:issues|pulls}", func(resp http.ResponseWriter, req *http.Request) { + username := chi.URLParam(req, "username") + assert.EqualValues(t, "gitea", username) + reponame := chi.URLParam(req, "reponame") + assert.EqualValues(t, "gitea", reponame) + tp := chi.URLParam(req, "type") + assert.EqualValues(t, "issues", tp) + hit = 0 + }) + + r.Get("/{type:issues|pulls}/{index}", func(resp http.ResponseWriter, req *http.Request) { + username := chi.URLParam(req, "username") + assert.EqualValues(t, "gitea", username) + reponame := chi.URLParam(req, "reponame") + assert.EqualValues(t, "gitea", reponame) + tp := chi.URLParam(req, "type") + assert.EqualValues(t, "issues", tp) + index := chi.URLParam(req, "index") + assert.EqualValues(t, "1", index) + hit = 1 + }) + }, func(resp http.ResponseWriter, req *http.Request) { + if stop, err := strconv.Atoi(req.FormValue("stop")); err == nil { + hit = stop + resp.WriteHeader(http.StatusOK) + } + }) + + r.Group("/issues/{index}", func() { + r.Get("/view", func(resp http.ResponseWriter, req *http.Request) { + username := chi.URLParam(req, "username") + assert.EqualValues(t, "gitea", username) + reponame := chi.URLParam(req, "reponame") + assert.EqualValues(t, "gitea", reponame) + index := chi.URLParam(req, "index") + assert.EqualValues(t, "1", index) + hit = 2 + }) + }) + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 0, hit) + + req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 1, hit) + + req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1?stop=100", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 100, hit) + + req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1/view", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 2, hit) +} + +func TestRoute3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + hit := -1 + + m := NewRoute() + r := NewRoute() + r.Mount("/api/v1", m) + + m.Group("/repos", func() { + m.Group("/{username}/{reponame}", func() { + m.Group("/branch_protections", func() { + m.Get("", func(resp http.ResponseWriter, req *http.Request) { + hit = 0 + }) + m.Post("", func(resp http.ResponseWriter, req *http.Request) { + hit = 1 + }) + m.Group("/{name}", func() { + m.Get("", func(resp http.ResponseWriter, req *http.Request) { + hit = 2 + }) + m.Patch("", func(resp http.ResponseWriter, req *http.Request) { + hit = 3 + }) + m.Delete("", func(resp http.ResponseWriter, req *http.Request) { + hit = 4 + }) + }) + }) + }) + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 0, hit) + + req, err = http.NewRequest("POST", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code, http.StatusOK) + assert.EqualValues(t, 1, hit) + + req, err = http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 2, hit) + + req, err = http.NewRequest("PATCH", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 3, hit) + + req, err = http.NewRequest("DELETE", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.EqualValues(t, http.StatusOK, recorder.Code) + assert.EqualValues(t, 4, hit) +} diff --git a/modules/web/routemock.go b/modules/web/routemock.go new file mode 100644 index 00000000..cb41f63b --- /dev/null +++ b/modules/web/routemock.go @@ -0,0 +1,61 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package web + +import ( + "net/http" + + "code.gitea.io/gitea/modules/setting" +) + +// MockAfterMiddlewares is a general mock point, it's between middlewares and the handler +const MockAfterMiddlewares = "MockAfterMiddlewares" + +var routeMockPoints = map[string]func(next http.Handler) http.Handler{} + +// RouteMockPoint registers a mock point as a middleware for testing, example: +// +// r.Use(web.RouteMockPoint("my-mock-point-1")) +// r.Get("/foo", middleware2, web.RouteMockPoint("my-mock-point-2"), middleware2, handler) +// +// Then use web.RouteMock to mock the route execution. +// It only takes effect in testing mode (setting.IsInTesting == true). +func RouteMockPoint(pointName string) func(next http.Handler) http.Handler { + if !setting.IsInTesting { + return nil + } + routeMockPoints[pointName] = nil + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if h := routeMockPoints[pointName]; h != nil { + h(next).ServeHTTP(w, r) + } else { + next.ServeHTTP(w, r) + } + }) + } +} + +// RouteMock uses the registered mock point to mock the route execution, example: +// +// defer web.RouteMockReset() +// web.RouteMock(web.MockAfterMiddlewares, func(ctx *context.Context) { +// ctx.WriteResponse(...) +// } +// +// Then the mock function will be executed as a middleware at the mock point. +// It only takes effect in testing mode (setting.IsInTesting == true). +func RouteMock(pointName string, h any) { + if _, ok := routeMockPoints[pointName]; !ok { + panic("route mock point not found: " + pointName) + } + routeMockPoints[pointName] = toHandlerProvider(h) +} + +// RouteMockReset resets all mock points (no mock anymore) +func RouteMockReset() { + for k := range routeMockPoints { + routeMockPoints[k] = nil // keep the keys because RouteMock will check the keys to make sure no misspelling + } +} diff --git a/modules/web/routemock_test.go b/modules/web/routemock_test.go new file mode 100644 index 00000000..cd99b993 --- /dev/null +++ b/modules/web/routemock_test.go @@ -0,0 +1,71 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package web + +import ( + "net/http" + "net/http/httptest" + "testing" + + "code.gitea.io/gitea/modules/setting" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRouteMock(t *testing.T) { + setting.IsInTesting = true + + r := NewRoute() + middleware1 := func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-Middleware1", "m1") + } + middleware2 := func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-Middleware2", "m2") + } + handler := func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-Handler", "h") + } + r.Get("/foo", middleware1, RouteMockPoint("mock-point"), middleware2, handler) + + // normal request + recorder := httptest.NewRecorder() + req, err := http.NewRequest("GET", "http://localhost:8000/foo", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.Len(t, recorder.Header(), 3) + assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) + assert.EqualValues(t, "m2", recorder.Header().Get("X-Test-Middleware2")) + assert.EqualValues(t, "h", recorder.Header().Get("X-Test-Handler")) + RouteMockReset() + + // mock at "mock-point" + RouteMock("mock-point", func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-MockPoint", "a") + resp.WriteHeader(http.StatusOK) + }) + recorder = httptest.NewRecorder() + req, err = http.NewRequest("GET", "http://localhost:8000/foo", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.Len(t, recorder.Header(), 2) + assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) + assert.EqualValues(t, "a", recorder.Header().Get("X-Test-MockPoint")) + RouteMockReset() + + // mock at MockAfterMiddlewares + RouteMock(MockAfterMiddlewares, func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Set("X-Test-MockPoint", "b") + resp.WriteHeader(http.StatusOK) + }) + recorder = httptest.NewRecorder() + req, err = http.NewRequest("GET", "http://localhost:8000/foo", nil) + require.NoError(t, err) + r.ServeHTTP(recorder, req) + assert.Len(t, recorder.Header(), 3) + assert.EqualValues(t, "m1", recorder.Header().Get("X-Test-Middleware1")) + assert.EqualValues(t, "m2", recorder.Header().Get("X-Test-Middleware2")) + assert.EqualValues(t, "b", recorder.Header().Get("X-Test-MockPoint")) + RouteMockReset() +} diff --git a/modules/web/routing/context.go b/modules/web/routing/context.go new file mode 100644 index 00000000..c5e85a41 --- /dev/null +++ b/modules/web/routing/context.go @@ -0,0 +1,49 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package routing + +import ( + "context" + "net/http" +) + +type contextKeyType struct{} + +var contextKey contextKeyType + +// UpdateFuncInfo updates a context's func info +func UpdateFuncInfo(ctx context.Context, funcInfo *FuncInfo) { + record, ok := ctx.Value(contextKey).(*requestRecord) + if !ok { + return + } + + record.lock.Lock() + record.funcInfo = funcInfo + record.lock.Unlock() +} + +// MarkLongPolling marks the request is a long-polling request, and the logger may output different message for it +func MarkLongPolling(resp http.ResponseWriter, req *http.Request) { + record, ok := req.Context().Value(contextKey).(*requestRecord) + if !ok { + return + } + + record.lock.Lock() + record.isLongPolling = true + record.lock.Unlock() +} + +// UpdatePanicError updates a context's error info, a panic may be recovered by other middlewares, but we still need to know that. +func UpdatePanicError(ctx context.Context, err any) { + record, ok := ctx.Value(contextKey).(*requestRecord) + if !ok { + return + } + + record.lock.Lock() + record.panicError = err + record.lock.Unlock() +} diff --git a/modules/web/routing/funcinfo.go b/modules/web/routing/funcinfo.go new file mode 100644 index 00000000..f4e9731a --- /dev/null +++ b/modules/web/routing/funcinfo.go @@ -0,0 +1,172 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package routing + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "sync" +) + +var ( + funcInfoMap = map[uintptr]*FuncInfo{} + funcInfoNameMap = map[string]*FuncInfo{} + funcInfoMapMu sync.RWMutex +) + +// FuncInfo contains information about the function to be logged by the router log +type FuncInfo struct { + file string + shortFile string + line int + name string + shortName string +} + +// String returns a string form of the FuncInfo for logging +func (info *FuncInfo) String() string { + if info == nil { + return "unknown-handler" + } + return fmt.Sprintf("%s:%d(%s)", info.shortFile, info.line, info.shortName) +} + +// GetFuncInfo returns the FuncInfo for a provided function and friendlyname +func GetFuncInfo(fn any, friendlyName ...string) *FuncInfo { + // ptr represents the memory position of the function passed in as v. + // This will be used as program counter in FuncForPC below + ptr := reflect.ValueOf(fn).Pointer() + + // if we have been provided with a friendlyName look for the named funcs + if len(friendlyName) == 1 { + name := friendlyName[0] + funcInfoMapMu.RLock() + info, ok := funcInfoNameMap[name] + funcInfoMapMu.RUnlock() + if ok { + return info + } + } + + // Otherwise attempt to get pre-cached information for this function pointer + funcInfoMapMu.RLock() + info, ok := funcInfoMap[ptr] + funcInfoMapMu.RUnlock() + + if ok { + if len(friendlyName) == 1 { + name := friendlyName[0] + info = copyFuncInfo(info) + info.shortName = name + + funcInfoNameMap[name] = info + funcInfoMapMu.Lock() + funcInfoNameMap[name] = info + funcInfoMapMu.Unlock() + } + return info + } + + // This is likely the first time we have seen this function + // + // Get the runtime.func for this function (if we can) + f := runtime.FuncForPC(ptr) + if f != nil { + info = convertToFuncInfo(f) + + // cache this info globally + funcInfoMapMu.Lock() + funcInfoMap[ptr] = info + + // if we have been provided with a friendlyName override the short name we've generated + if len(friendlyName) == 1 { + name := friendlyName[0] + info = copyFuncInfo(info) + info.shortName = name + funcInfoNameMap[name] = info + } + funcInfoMapMu.Unlock() + } + return info +} + +// convertToFuncInfo take a runtime.Func and convert it to a logFuncInfo, fill in shorten filename, etc +func convertToFuncInfo(f *runtime.Func) *FuncInfo { + file, line := f.FileLine(f.Entry()) + + info := &FuncInfo{ + file: strings.ReplaceAll(file, "\\", "/"), + line: line, + name: f.Name(), + } + + // only keep last 2 names in path, fall back to funcName if not + info.shortFile = shortenFilename(info.file, info.name) + + // remove package prefix. eg: "xxx.com/pkg1/pkg2.foo" => "pkg2.foo" + pos := strings.LastIndexByte(info.name, '/') + if pos >= 0 { + info.shortName = info.name[pos+1:] + } else { + info.shortName = info.name + } + + // remove ".func[0-9]*" suffix for anonymous func + info.shortName = trimAnonymousFunctionSuffix(info.shortName) + + return info +} + +func copyFuncInfo(l *FuncInfo) *FuncInfo { + return &FuncInfo{ + file: l.file, + shortFile: l.shortFile, + line: l.line, + name: l.name, + shortName: l.shortName, + } +} + +// shortenFilename generates a short source code filename from a full package path, eg: "code.gitea.io/routers/common/logger_context.go" => "common/logger_context.go" +func shortenFilename(filename, fallback string) string { + if filename == "" { + return fallback + } + if lastIndex := strings.LastIndexByte(filename, '/'); lastIndex >= 0 { + if secondLastIndex := strings.LastIndexByte(filename[:lastIndex], '/'); secondLastIndex >= 0 { + return filename[secondLastIndex+1:] + } + } + return filename +} + +// trimAnonymousFunctionSuffix trims ".func[0-9]*" from the end of anonymous function names, we only want to see the main function names in logs +func trimAnonymousFunctionSuffix(name string) string { + // if the name is an anonymous name, it should be like "{main-function}.func1", so the length can not be less than 7 + if len(name) < 7 { + return name + } + + funcSuffixIndex := strings.LastIndex(name, ".func") + if funcSuffixIndex < 0 { + return name + } + + hasFuncSuffix := true + + // len(".func") = 5 + for i := funcSuffixIndex + 5; i < len(name); i++ { + if name[i] < '0' || name[i] > '9' { + hasFuncSuffix = false + break + } + } + + if hasFuncSuffix { + return name[:funcSuffixIndex] + } + return name +} diff --git a/modules/web/routing/funcinfo_test.go b/modules/web/routing/funcinfo_test.go new file mode 100644 index 00000000..2ab59603 --- /dev/null +++ b/modules/web/routing/funcinfo_test.go @@ -0,0 +1,80 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package routing + +import ( + "fmt" + "testing" +) + +func Test_shortenFilename(t *testing.T) { + tests := []struct { + filename string + fallback string + expected string + }{ + { + "code.gitea.io/routers/common/logger_context.go", + "NO_FALLBACK", + "common/logger_context.go", + }, + { + "common/logger_context.go", + "NO_FALLBACK", + "common/logger_context.go", + }, + { + "logger_context.go", + "NO_FALLBACK", + "logger_context.go", + }, + { + "", + "USE_FALLBACK", + "USE_FALLBACK", + }, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("shortenFilename('%s')", tt.filename), func(t *testing.T) { + if gotShort := shortenFilename(tt.filename, tt.fallback); gotShort != tt.expected { + t.Errorf("shortenFilename('%s'), expect '%s', but get '%s'", tt.filename, tt.expected, gotShort) + } + }) + } +} + +func Test_trimAnonymousFunctionSuffix(t *testing.T) { + tests := []struct { + name string + want string + }{ + { + "notAnonymous", + "notAnonymous", + }, + { + "anonymous.func1", + "anonymous", + }, + { + "notAnonymous.funca", + "notAnonymous.funca", + }, + { + "anonymous.func100", + "anonymous", + }, + { + "anonymous.func100.func6", + "anonymous.func100", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := trimAnonymousFunctionSuffix(tt.name); got != tt.want { + t.Errorf("trimAnonymousFunctionSuffix() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/modules/web/routing/logger.go b/modules/web/routing/logger.go new file mode 100644 index 00000000..5f3a7592 --- /dev/null +++ b/modules/web/routing/logger.go @@ -0,0 +1,109 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package routing + +import ( + "net/http" + "strings" + "time" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/web/types" +) + +// NewLoggerHandler is a handler that will log routing to the router log taking account of +// routing information +func NewLoggerHandler() func(next http.Handler) http.Handler { + manager := requestRecordsManager{ + requestRecords: map[uint64]*requestRecord{}, + } + manager.startSlowQueryDetector(3 * time.Second) + + logger := log.GetLogger("router") + manager.print = logPrinter(logger) + return manager.handler +} + +var ( + startMessage = log.NewColoredValue("started ", log.DEBUG.ColorAttributes()...) + slowMessage = log.NewColoredValue("slow ", log.WARN.ColorAttributes()...) + pollingMessage = log.NewColoredValue("polling ", log.INFO.ColorAttributes()...) + failedMessage = log.NewColoredValue("failed ", log.WARN.ColorAttributes()...) + completedMessage = log.NewColoredValue("completed", log.INFO.ColorAttributes()...) + unknownHandlerMessage = log.NewColoredValue("completed", log.ERROR.ColorAttributes()...) +) + +func logPrinter(logger log.Logger) func(trigger Event, record *requestRecord) { + return func(trigger Event, record *requestRecord) { + if trigger == StartEvent { + if !logger.LevelEnabled(log.TRACE) { + // for performance, if the "started" message shouldn't be logged, we just return as early as possible + // developers can set the router log level to TRACE to get the "started" request messages. + return + } + // when a request starts, we have no information about the handler function information, we only have the request path + req := record.request + logger.Trace("router: %s %v %s for %s", startMessage, log.ColoredMethod(req.Method), req.RequestURI, req.RemoteAddr) + return + } + + req := record.request + + // Get data from the record + record.lock.Lock() + handlerFuncInfo := record.funcInfo.String() + isLongPolling := record.isLongPolling + isUnknownHandler := record.funcInfo == nil + panicErr := record.panicError + record.lock.Unlock() + + if trigger == StillExecutingEvent { + message := slowMessage + logf := logger.Warn + if isLongPolling { + logf = logger.Info + message = pollingMessage + } + logf("router: %s %v %s for %s, elapsed %v @ %s", + message, + log.ColoredMethod(req.Method), req.RequestURI, req.RemoteAddr, + log.ColoredTime(time.Since(record.startTime)), + handlerFuncInfo, + ) + return + } + + if panicErr != nil { + logger.Warn("router: %s %v %s for %s, panic in %v @ %s, err=%v", + failedMessage, + log.ColoredMethod(req.Method), req.RequestURI, req.RemoteAddr, + log.ColoredTime(time.Since(record.startTime)), + handlerFuncInfo, + panicErr, + ) + return + } + + var status int + if v, ok := record.responseWriter.(types.ResponseStatusProvider); ok { + status = v.WrittenStatus() + } + logf := logger.Info + if strings.HasPrefix(req.RequestURI, "/assets/") { + logf = logger.Trace + } + message := completedMessage + if isUnknownHandler { + logf = logger.Error + message = unknownHandlerMessage + } + + logf("router: %s %v %s for %s, %v %v in %v @ %s", + message, + log.ColoredMethod(req.Method), req.RequestURI, req.RemoteAddr, + log.ColoredStatus(status), log.ColoredStatus(status, http.StatusText(status)), log.ColoredTime(time.Since(record.startTime)), + handlerFuncInfo, + ) + } +} diff --git a/modules/web/routing/logger_manager.go b/modules/web/routing/logger_manager.go new file mode 100644 index 00000000..aa25ec3a --- /dev/null +++ b/modules/web/routing/logger_manager.go @@ -0,0 +1,124 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package routing + +import ( + "context" + "net/http" + "sync" + "time" + + "code.gitea.io/gitea/modules/graceful" + "code.gitea.io/gitea/modules/process" +) + +// Event indicates when the printer is triggered +type Event int + +const ( + // StartEvent at the beginning of a request + StartEvent Event = iota + + // StillExecutingEvent the request is still executing + StillExecutingEvent + + // EndEvent the request has ended (either completed or failed) + EndEvent +) + +// Printer is used to output the log for a request +type Printer func(trigger Event, record *requestRecord) + +type requestRecordsManager struct { + print Printer + + lock sync.Mutex + + requestRecords map[uint64]*requestRecord + count uint64 +} + +func (manager *requestRecordsManager) startSlowQueryDetector(threshold time.Duration) { + go graceful.GetManager().RunWithShutdownContext(func(ctx context.Context) { + ctx, _, finished := process.GetManager().AddTypedContext(ctx, "Service: SlowQueryDetector", process.SystemProcessType, true) + defer finished() + // This go-routine checks all active requests every second. + // If a request has been running for a long time (eg: /user/events), we also print a log with "still-executing" message + // After the "still-executing" log is printed, the record will be removed from the map to prevent from duplicated logs in future + + // We do not care about accurate duration here. It just does the check periodically, 0.5s or 1.5s are all OK. + t := time.NewTicker(time.Second) + for { + select { + case <-ctx.Done(): + return + case <-t.C: + now := time.Now() + + var slowRequests []*requestRecord + + // find all slow requests with lock + manager.lock.Lock() + for index, record := range manager.requestRecords { + if now.Sub(record.startTime) < threshold { + continue + } + + slowRequests = append(slowRequests, record) + delete(manager.requestRecords, index) + } + manager.lock.Unlock() + + // print logs for slow requests + for _, record := range slowRequests { + manager.print(StillExecutingEvent, record) + } + } + } + }) +} + +func (manager *requestRecordsManager) handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + record := &requestRecord{ + startTime: time.Now(), + request: req, + responseWriter: w, + } + + // generate a record index an insert into the map + manager.lock.Lock() + record.index = manager.count + manager.count++ + manager.requestRecords[record.index] = record + manager.lock.Unlock() + + defer func() { + // just in case there is a panic. now the panics are all recovered in middleware.go + localPanicErr := recover() + if localPanicErr != nil { + record.lock.Lock() + record.panicError = localPanicErr + record.lock.Unlock() + } + + // remove from the record map + manager.lock.Lock() + delete(manager.requestRecords, record.index) + manager.lock.Unlock() + + // log the end of the request + manager.print(EndEvent, record) + + if localPanicErr != nil { + // the panic wasn't recovered before us, so we should pass it up, and let the framework handle the panic error + panic(localPanicErr) + } + }() + + req = req.WithContext(context.WithValue(req.Context(), contextKey, record)) + manager.print(StartEvent, record) + next.ServeHTTP(w, req) + }) +} diff --git a/modules/web/routing/requestrecord.go b/modules/web/routing/requestrecord.go new file mode 100644 index 00000000..cc61fc4d --- /dev/null +++ b/modules/web/routing/requestrecord.go @@ -0,0 +1,28 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package routing + +import ( + "net/http" + "sync" + "time" +) + +type requestRecord struct { + // index of the record in the records map + index uint64 + + // immutable fields + startTime time.Time + request *http.Request + responseWriter http.ResponseWriter + + // mutex + lock sync.RWMutex + + // mutable fields + isLongPolling bool + funcInfo *FuncInfo + panicError any +} diff --git a/modules/web/types/response.go b/modules/web/types/response.go new file mode 100644 index 00000000..834f4912 --- /dev/null +++ b/modules/web/types/response.go @@ -0,0 +1,10 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package types + +// ResponseStatusProvider is an interface to get the written status in the response +// Many packages need this interface, so put it in the separate package to avoid import cycle +type ResponseStatusProvider interface { + WrittenStatus() int +} |