From 6c5de7656257ffd8f2dcc441a313ab9aed97e834 Mon Sep 17 00:00:00 2001 From: r Date: Sat, 20 Mar 2021 05:12:48 +0000 Subject: Refactor --- service/transport.go | 301 ++++++++++++++++++++++----------------------------- 1 file changed, 131 insertions(+), 170 deletions(-) (limited to 'service/transport.go') diff --git a/service/transport.go b/service/transport.go index 1180f6c..317ce5b 100644 --- a/service/transport.go +++ b/service/transport.go @@ -1,8 +1,8 @@ package service import ( + "context" "encoding/json" - "errors" "log" "net/http" "strconv" @@ -10,44 +10,34 @@ import ( "bloat/mastodon" "bloat/model" + "bloat/renderer" "github.com/gorilla/mux" ) -var ( - errInvalidSession = errors.New("invalid session") - errInvalidCSRFToken = errors.New("invalid csrf token") -) - const ( sessionExp = 365 * 24 * time.Hour ) -type respType int - const ( - HTML respType = iota + HTML int = iota JSON ) -type authType int - const ( - NOAUTH authType = iota + NOAUTH int = iota SESSION CSRF ) type client struct { *mastodon.Client - http.ResponseWriter - Req *http.Request - CSRFToken string - Session model.Session -} - -func (c *client) url() string { - return c.Req.URL.RequestURI() + w http.ResponseWriter + r *http.Request + s model.Session + csrf string + ctx context.Context + rctx *renderer.Context } func setSessionCookie(w http.ResponseWriter, sid string, exp time.Duration) { @@ -59,66 +49,50 @@ func setSessionCookie(w http.ResponseWriter, sid string, exp time.Duration) { } func writeJson(c *client, data interface{}) error { - return json.NewEncoder(c).Encode(map[string]interface{}{ + return json.NewEncoder(c.w).Encode(map[string]interface{}{ "data": data, }) } func redirect(c *client, url string) { - c.Header().Add("Location", url) - c.WriteHeader(http.StatusFound) + c.w.Header().Add("Location", url) + c.w.WriteHeader(http.StatusFound) } func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { r := mux.NewRouter() - writeError := func(c *client, err error, t respType) { + writeError := func(c *client, err error, t int) { switch t { case HTML: - c.WriteHeader(http.StatusInternalServerError) + c.w.WriteHeader(http.StatusInternalServerError) s.ErrorPage(c, err) case JSON: - c.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(c).Encode(map[string]string{ + c.w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(c.w).Encode(map[string]string{ "error": err.Error(), }) } } - authenticate := func(c *client, t authType) error { - if t >= SESSION { - cookie, err := c.Req.Cookie("session_id") - if err != nil || len(cookie.Value) < 1 { - return errInvalidSession - } - c.Session, err = s.sessionRepo.Get(cookie.Value) - if err != nil { - return errInvalidSession - } - app, err := s.appRepo.Get(c.Session.InstanceDomain) - if err != nil { - return err - } - c.Client = mastodon.NewClient(&mastodon.Config{ - Server: app.InstanceURL, - ClientID: app.ClientID, - ClientSecret: app.ClientSecret, - AccessToken: c.Session.AccessToken, - }) + authenticate := func(c *client, t int) error { + var sid string + if cookie, _ := c.r.Cookie("session_id"); cookie != nil { + sid = cookie.Value } - if t >= CSRF { - c.CSRFToken = c.Req.FormValue("csrf_token") - if len(c.CSRFToken) < 1 || c.CSRFToken != c.Session.CSRFToken { - return errInvalidCSRFToken - } - } - return nil + csrf := c.r.FormValue("csrf_token") + ref := c.r.URL.RequestURI() + return s.authenticate(c, sid, csrf, ref, t) } - handle := func(f func(c *client) error, at authType, rt respType) http.HandlerFunc { + handle := func(f func(c *client) error, at int, rt int) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { var err error - c := &client{Req: req, ResponseWriter: w} + c := &client{ + ctx: req.Context(), + w: w, + r: req, + } defer func(begin time.Time) { logger.Printf("path=%s, err=%v, took=%v\n", @@ -132,7 +106,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { case JSON: ct = "application/json" } - c.Header().Add("Content-Type", ct) + c.w.Header().Add("Content-Type", ct) err = authenticate(c, at) if err != nil { @@ -149,12 +123,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } rootPage := handle(func(c *client) error { - sid, _ := c.Req.Cookie("session_id") - if sid == nil || len(sid.Value) < 0 { - redirect(c, "/signin") - return nil - } - session, err := s.sessionRepo.Get(sid.Value) + err := authenticate(c, SESSION) if err != nil { if err == errInvalidSession { redirect(c, "/signin") @@ -162,7 +131,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } return err } - if len(session.AccessToken) < 1 { + if !c.s.IsLoggedIn() { redirect(c, "/signin") return nil } @@ -178,18 +147,18 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if !ok { return s.SigninPage(c) } - url, sid, err := s.NewSession(instance) + url, sid, err := s.NewSession(c, instance) if err != nil { return err } - setSessionCookie(c, sid, sessionExp) + setSessionCookie(c.w, sid, sessionExp) redirect(c, url) return nil }, NOAUTH, HTML) timelinePage := handle(func(c *client) error { - tType, _ := mux.Vars(c.Req)["type"] - q := c.Req.URL.Query() + tType, _ := mux.Vars(c.r)["type"] + q := c.r.URL.Query() instance := q.Get("instance") maxID := q.Get("max_id") minID := q.Get("min_id") @@ -202,41 +171,41 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, SESSION, HTML) threadPage := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - q := c.Req.URL.Query() + id, _ := mux.Vars(c.r)["id"] + q := c.r.URL.Query() reply := q.Get("reply") return s.ThreadPage(c, id, len(reply) > 1) }, SESSION, HTML) likedByPage := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] return s.LikedByPage(c, id) }, SESSION, HTML) retweetedByPage := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] return s.RetweetedByPage(c, id) }, SESSION, HTML) notificationsPage := handle(func(c *client) error { - q := c.Req.URL.Query() + q := c.r.URL.Query() maxID := q.Get("max_id") minID := q.Get("min_id") return s.NotificationPage(c, maxID, minID) }, SESSION, HTML) userPage := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - pageType, _ := mux.Vars(c.Req)["type"] - q := c.Req.URL.Query() + id, _ := mux.Vars(c.r)["id"] + pageType, _ := mux.Vars(c.r)["type"] + q := c.r.URL.Query() maxID := q.Get("max_id") minID := q.Get("min_id") return s.UserPage(c, id, pageType, maxID, minID) }, SESSION, HTML) userSearchPage := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - q := c.Req.URL.Query() + id, _ := mux.Vars(c.r)["id"] + q := c.r.URL.Query() sq := q.Get("q") offset, _ := strconv.Atoi(q.Get("offset")) return s.UserSearchPage(c, id, sq, offset) @@ -251,7 +220,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, SESSION, HTML) searchPage := handle(func(c *client) error { - q := c.Req.URL.Query() + q := c.r.URL.Query() sq := q.Get("q") qType := q.Get("type") offset, _ := strconv.Atoi(q.Get("offset")) @@ -267,49 +236,41 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, SESSION, HTML) signin := handle(func(c *client) error { - instance := c.Req.FormValue("instance") - url, sid, err := s.NewSession(instance) + instance := c.r.FormValue("instance") + url, sid, err := s.NewSession(c, instance) if err != nil { return err } - setSessionCookie(c, sid, sessionExp) + setSessionCookie(c.w, sid, sessionExp) redirect(c, url) return nil }, NOAUTH, HTML) oauthCallback := handle(func(c *client) error { - q := c.Req.URL.Query() + q := c.r.URL.Query() token := q.Get("code") - token, userID, err := s.Signin(c, token) - if err != nil { - return err - } - - c.Session.AccessToken = token - c.Session.UserID = userID - err = s.sessionRepo.Add(c.Session) + err := s.Signin(c, token) if err != nil { return err } - redirect(c, "/") return nil }, SESSION, HTML) post := handle(func(c *client) error { - content := c.Req.FormValue("content") - replyToID := c.Req.FormValue("reply_to_id") - format := c.Req.FormValue("format") - visibility := c.Req.FormValue("visibility") - isNSFW := c.Req.FormValue("is_nsfw") == "on" - files := c.Req.MultipartForm.File["attachments"] + content := c.r.FormValue("content") + replyToID := c.r.FormValue("reply_to_id") + format := c.r.FormValue("format") + visibility := c.r.FormValue("visibility") + isNSFW := c.r.FormValue("is_nsfw") == "true" + files := c.r.MultipartForm.File["attachments"] id, err := s.Post(c, content, replyToID, format, visibility, isNSFW, files) if err != nil { return err } - location := c.Req.FormValue("referrer") + location := c.r.FormValue("referrer") if len(replyToID) > 0 { location = "/thread/" + replyToID + "#status-" + id } @@ -318,8 +279,8 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, CSRF, HTML) like := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - rid := c.Req.FormValue("retweeted_by_id") + id, _ := mux.Vars(c.r)["id"] + rid := c.r.FormValue("retweeted_by_id") _, err := s.Like(c, id) if err != nil { return err @@ -327,13 +288,13 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.Req.FormValue("referrer")+"#status-"+id) + redirect(c, c.r.FormValue("referrer")+"#status-"+id) return nil }, CSRF, HTML) unlike := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - rid := c.Req.FormValue("retweeted_by_id") + id, _ := mux.Vars(c.r)["id"] + rid := c.r.FormValue("retweeted_by_id") _, err := s.UnLike(c, id) if err != nil { return err @@ -341,13 +302,13 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.Req.FormValue("referrer")+"#status-"+id) + redirect(c, c.r.FormValue("referrer")+"#status-"+id) return nil }, CSRF, HTML) retweet := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - rid := c.Req.FormValue("retweeted_by_id") + id, _ := mux.Vars(c.r)["id"] + rid := c.r.FormValue("retweeted_by_id") _, err := s.Retweet(c, id) if err != nil { return err @@ -355,13 +316,13 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.Req.FormValue("referrer")+"#status-"+id) + redirect(c, c.r.FormValue("referrer")+"#status-"+id) return nil }, CSRF, HTML) unretweet := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - rid := c.Req.FormValue("retweeted_by_id") + id, _ := mux.Vars(c.r)["id"] + rid := c.r.FormValue("retweeted_by_id") _, err := s.UnRetweet(c, id) if err != nil { return err @@ -369,25 +330,25 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.Req.FormValue("referrer")+"#status-"+id) + redirect(c, c.r.FormValue("referrer")+"#status-"+id) return nil }, CSRF, HTML) vote := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - statusID := c.Req.FormValue("status_id") - choices, _ := c.Req.PostForm["choices"] + id, _ := mux.Vars(c.r)["id"] + statusID := c.r.FormValue("status_id") + choices, _ := c.r.PostForm["choices"] err := s.Vote(c, id, choices) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")+"#status-"+statusID) + redirect(c, c.r.FormValue("referrer")+"#status-"+statusID) return nil }, CSRF, HTML) follow := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - q := c.Req.URL.Query() + id, _ := mux.Vars(c.r)["id"] + q := c.r.URL.Query() var reblogs *bool if r, ok := q["reblogs"]; ok && len(r) > 0 { reblogs = new(bool) @@ -397,111 +358,111 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) unfollow := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.UnFollow(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) accept := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.Accept(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) reject := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.Reject(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) mute := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.Mute(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) unMute := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.UnMute(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) block := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.Block(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) unBlock := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.UnBlock(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) subscribe := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.Subscribe(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) unSubscribe := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.UnSubscribe(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) settings := handle(func(c *client) error { - visibility := c.Req.FormValue("visibility") - format := c.Req.FormValue("format") - copyScope := c.Req.FormValue("copy_scope") == "true" - threadInNewTab := c.Req.FormValue("thread_in_new_tab") == "true" - hideAttachments := c.Req.FormValue("hide_attachments") == "true" - maskNSFW := c.Req.FormValue("mask_nsfw") == "true" - ni, _ := strconv.Atoi(c.Req.FormValue("notification_interval")) - fluorideMode := c.Req.FormValue("fluoride_mode") == "true" - darkMode := c.Req.FormValue("dark_mode") == "true" - antiDopamineMode := c.Req.FormValue("anti_dopamine_mode") == "true" + visibility := c.r.FormValue("visibility") + format := c.r.FormValue("format") + copyScope := c.r.FormValue("copy_scope") == "true" + threadInNewTab := c.r.FormValue("thread_in_new_tab") == "true" + hideAttachments := c.r.FormValue("hide_attachments") == "true" + maskNSFW := c.r.FormValue("mask_nsfw") == "true" + ni, _ := strconv.Atoi(c.r.FormValue("notification_interval")) + fluorideMode := c.r.FormValue("fluoride_mode") == "true" + darkMode := c.r.FormValue("dark_mode") == "true" + antiDopamineMode := c.r.FormValue("anti_dopamine_mode") == "true" settings := &model.Settings{ DefaultVisibility: visibility, @@ -525,49 +486,49 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, CSRF, HTML) muteConversation := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.MuteConversation(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) unMuteConversation := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.UnMuteConversation(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) delete := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.Delete(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) readNotifications := handle(func(c *client) error { - q := c.Req.URL.Query() + q := c.r.URL.Query() maxID := q.Get("max_id") err := s.ReadNotifications(c, maxID) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) bookmark := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - rid := c.Req.FormValue("retweeted_by_id") + id, _ := mux.Vars(c.r)["id"] + rid := c.r.FormValue("retweeted_by_id") err := s.Bookmark(c, id) if err != nil { return err @@ -575,13 +536,13 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.Req.FormValue("referrer")+"#status-"+id) + redirect(c, c.r.FormValue("referrer")+"#status-"+id) return nil }, CSRF, HTML) unBookmark := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] - rid := c.Req.FormValue("retweeted_by_id") + id, _ := mux.Vars(c.r)["id"] + rid := c.r.FormValue("retweeted_by_id") err := s.UnBookmark(c, id) if err != nil { return err @@ -589,40 +550,40 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.Req.FormValue("referrer")+"#status-"+id) + redirect(c, c.r.FormValue("referrer")+"#status-"+id) return nil }, CSRF, HTML) filter := handle(func(c *client) error { - phrase := c.Req.FormValue("phrase") - wholeWord := c.Req.FormValue("whole_word") == "true" + phrase := c.r.FormValue("phrase") + wholeWord := c.r.FormValue("whole_word") == "true" err := s.Filter(c, phrase, wholeWord) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) unFilter := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] err := s.UnFilter(c, id) if err != nil { return err } - redirect(c, c.Req.FormValue("referrer")) + redirect(c, c.r.FormValue("referrer")) return nil }, CSRF, HTML) signout := handle(func(c *client) error { s.Signout(c) - setSessionCookie(c, "", 0) + setSessionCookie(c.w, "", 0) redirect(c, "/") return nil }, CSRF, HTML) fLike := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] count, err := s.Like(c, id) if err != nil { return err @@ -631,7 +592,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, CSRF, JSON) fUnlike := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] count, err := s.UnLike(c, id) if err != nil { return err @@ -640,7 +601,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, CSRF, JSON) fRetweet := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] count, err := s.Retweet(c, id) if err != nil { return err @@ -649,7 +610,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, CSRF, JSON) fUnretweet := handle(func(c *client) error { - id, _ := mux.Vars(c.Req)["id"] + id, _ := mux.Vars(c.r)["id"] count, err := s.UnRetweet(c, id) if err != nil { return err -- cgit v1.2.3