From fa27d9c6eb525b2e55c3faab5dd8a3e0e9658536 Mon Sep 17 00:00:00 2001 From: r Date: Sun, 22 Nov 2020 17:29:58 +0000 Subject: Refactor things - Remove separate auth/logging and merge them into transport.go - Add helper function for http handlers --- service/transport.go | 1026 +++++++++++++++++++++----------------------------- 1 file changed, 419 insertions(+), 607 deletions(-) (limited to 'service/transport.go') diff --git a/service/transport.go b/service/transport.go index 3c9392a..80ad7f1 100644 --- a/service/transport.go +++ b/service/transport.go @@ -2,597 +2,478 @@ package service import ( "encoding/json" - "io" - "mime/multipart" + "errors" + "log" "net/http" "strconv" "time" + "bloat/mastodon" "bloat/model" "github.com/gorilla/mux" ) +var ( + errInvalidSession = errors.New("invalid session") + errInvalidCSRFToken = errors.New("invalid csrf token") +) + const ( sessionExp = 365 * 24 * time.Hour ) -func newClient(w io.Writer, req *http.Request, csrfToken string) *model.Client { - var sessionID string - if req != nil { - c, err := req.Cookie("session_id") - if err == nil { - sessionID = c.Value - } - } - return &model.Client{ - Writer: w, - Ctx: model.ClientCtx{ - SessionID: sessionID, - CSRFToken: csrfToken, - }, - } +type respType int + +const ( + HTML respType = iota + JSON +) + +type authType int + +const ( + NOAUTH authType = iota + SESSION + CSRF +) + +type client struct { + *mastodon.Client + http.ResponseWriter + Req *http.Request + CSRFToken string + Session model.Session } -func setSessionCookie(w http.ResponseWriter, sessionID string, exp time.Duration) { +func setSessionCookie(w http.ResponseWriter, sid string, exp time.Duration) { http.SetCookie(w, &http.Cookie{ Name: "session_id", - Value: sessionID, + Value: sid, Expires: time.Now().Add(exp), }) } -func getMultipartFormValue(mf *multipart.Form, key string) (val string) { - vals, ok := mf.Value[key] - if !ok { - return "" - } - if len(vals) < 1 { - return "" - } - return vals[0] -} - -func serveJson(w io.Writer, data interface{}) (err error) { - var d = make(map[string]interface{}) - d["data"] = data - return json.NewEncoder(w).Encode(d) +func writeJson(c *client, data interface{}) error { + return json.NewEncoder(c).Encode(map[string]interface{}{ + "data": data, + }) } -func serveJsonError(w http.ResponseWriter, err error) { - var d = make(map[string]interface{}) - d["error"] = err.Error() - w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(d) - return +func redirect(c *client, url string) { + c.Header().Add("Location", url) + c.WriteHeader(http.StatusFound) } -func NewHandler(s Service, staticDir string) http.Handler { +func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { r := mux.NewRouter() - rootPage := func(w http.ResponseWriter, req *http.Request) { - sessionID, _ := req.Cookie("session_id") - if sessionID != nil && len(sessionID.Value) > 0 { - c := newClient(w, req, "") - err := s.ServeRootPage(c) - if err != nil { - if err == errInvalidAccessToken { - w.Header().Add("Location", "/signin") - w.WriteHeader(http.StatusFound) - return - } - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } else { - w.Header().Add("Location", "/signin") - w.WriteHeader(http.StatusFound) - } - } - - navPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - err := s.ServeNavPage(c) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + writeError := func(c *client, err error, t respType) { + switch t { + case HTML: + c.WriteHeader(http.StatusInternalServerError) + s.ErrorPage(c, err) + case JSON: + c.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(c).Encode(map[string]string{ + "error": err.Error(), + }) } } - signinPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, nil, "") - instance, ok := s.SingleInstance() - if ok { - url, sessionID, err := s.NewSession(instance) + authenticate := func(c *client, t authType) error { + if t >= SESSION { + cookie, err := c.Req.Cookie("session_id") + if err != nil || len(cookie.Value) < 1 { + return errInvalidSession + } + c.Session, err = s.sessionRepo.Get(cookie.Value) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return errInvalidSession } - - setSessionCookie(w, sessionID, sessionExp) - w.Header().Add("Location", url) - w.WriteHeader(http.StatusFound) - } else { - err := s.ServeSigninPage(c) + app, err := s.appRepo.Get(c.Session.InstanceDomain) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err + } + c.Client = mastodon.NewClient(&mastodon.Config{ + Server: app.InstanceURL, + ClientID: app.ClientID, + ClientSecret: app.ClientSecret, + AccessToken: c.Session.AccessToken, + }) + } + if t >= CSRF { + c.CSRFToken = c.Req.FormValue("csrf_token") + if len(c.CSRFToken) < 1 || c.CSRFToken != c.Session.CSRFToken { + return errInvalidCSRFToken } } + return nil } - timelinePage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - tType, _ := mux.Vars(req)["type"] - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeTimelinePage(c, tType, maxID, minID) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - defaultTimelinePage := func(w http.ResponseWriter, req *http.Request) { - w.Header().Add("Location", "/timeline/home") - w.WriteHeader(http.StatusFound) - } - - threadPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - reply := req.URL.Query().Get("reply") - - err := s.ServeThreadPage(c, id, len(reply) > 1) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - likedByPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - - err := s.ServeLikedByPage(c, id) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - retweetedByPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - - err := s.ServeRetweetedByPage(c, id) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - notificationsPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeNotificationPage(c, maxID, minID) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } + handle := func(f func(c *client) error, at authType, rt respType) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + var err error + c := &client{Req: req, ResponseWriter: w} - userPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - pageType, _ := mux.Vars(req)["type"] - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") + defer func(begin time.Time) { + logger.Printf("path=%s, err=%v, took=%v\n", + req.URL.Path, err, time.Since(begin)) + }(time.Now()) - err := s.ServeUserPage(c, id, pageType, maxID, minID) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - userSearchPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - q := req.URL.Query().Get("q") - offsetStr := req.URL.Query().Get("offset") + var ct string + switch rt { + case HTML: + ct = "text/html; charset=utf-8" + case JSON: + ct = "application/json" + } + c.Header().Add("Content-Type", ct) - var offset int - var err error - if len(offsetStr) > 1 { - offset, err = strconv.Atoi(offsetStr) + err = authenticate(c, at) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) + writeError(c, err, rt) return } - } - err = s.ServeUserSearchPage(c, id, q, offset) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - aboutPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - - err := s.ServeAboutPage(c) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - emojisPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - - err := s.ServeEmojiPage(c) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - searchPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - q := req.URL.Query().Get("q") - qType := req.URL.Query().Get("type") - offsetStr := req.URL.Query().Get("offset") - - var offset int - var err error - if len(offsetStr) > 1 { - offset, err = strconv.Atoi(offsetStr) + err = f(c) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) + writeError(c, err, rt) return } } - - err = s.ServeSearchPage(c, q, qType, offset) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } } - settingsPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - - err := s.ServeSettingsPage(c) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + rootPage := handle(func(c *client) error { + sid, _ := c.Req.Cookie("session_id") + if sid == nil || len(sid.Value) < 0 { + redirect(c, "/signin") + return nil } - } - - signin := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, nil, "") - instance := req.FormValue("instance") - - url, sessionID, err := s.NewSession(instance) + session, err := s.sessionRepo.Get(sid.Value) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + if err == errInvalidSession { + redirect(c, "/signin") + return nil + } + return err } - - setSessionCookie(w, sessionID, sessionExp) - w.Header().Add("Location", url) - w.WriteHeader(http.StatusFound) - } - - oauthCallback := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - token := req.URL.Query().Get("code") - - _, _, err := s.Signin(c, "", token) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + if len(session.AccessToken) < 1 { + redirect(c, "/signin") + return nil } + return s.RootPage(c) + }, NOAUTH, HTML) - w.Header().Add("Location", "/") - w.WriteHeader(http.StatusFound) - } - - post := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - err := req.ParseMultipartForm(4 << 20) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } + navPage := handle(func(c *client) error { + return s.NavPage(c) + }, SESSION, HTML) - c = newClient(w, req, - getMultipartFormValue(req.MultipartForm, "csrf_token")) - content := getMultipartFormValue(req.MultipartForm, "content") - replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id") - format := getMultipartFormValue(req.MultipartForm, "format") - visibility := getMultipartFormValue(req.MultipartForm, "visibility") - isNSFW := "on" == getMultipartFormValue(req.MultipartForm, "is_nsfw") - files := req.MultipartForm.File["attachments"] + signinPage := handle(func(c *client) error { + instance, ok := s.SingleInstance() + if !ok { + return s.SigninPage(c) + } + url, sid, err := s.NewSession(instance) + if err != nil { + return err + } + setSessionCookie(c, sid, sessionExp) + redirect(c, url) + return nil + }, NOAUTH, HTML) + + timelinePage := handle(func(c *client) error { + tType, _ := mux.Vars(c.Req)["type"] + q := c.Req.URL.Query() + maxID := q.Get("max_id") + minID := q.Get("min_id") + return s.TimelinePage(c, tType, maxID, minID) + return nil + }, SESSION, HTML) + + defaultTimelinePage := handle(func(c *client) error { + redirect(c, "/timeline/home") + return nil + }, SESSION, HTML) + + threadPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + q := c.Req.URL.Query() + reply := q.Get("reply") + return s.ThreadPage(c, id, len(reply) > 1) + }, SESSION, HTML) + + likedByPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + return s.LikedByPage(c, id) + }, SESSION, HTML) + + retweetedByPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + return s.RetweetedByPage(c, id) + }, SESSION, HTML) + + notificationsPage := handle(func(c *client) error { + q := c.Req.URL.Query() + maxID := q.Get("max_id") + minID := q.Get("min_id") + return s.NotificationPage(c, maxID, minID) + }, SESSION, HTML) + + userPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + pageType, _ := mux.Vars(c.Req)["type"] + q := c.Req.URL.Query() + maxID := q.Get("max_id") + minID := q.Get("min_id") + return s.UserPage(c, id, pageType, maxID, minID) + }, SESSION, HTML) + + userSearchPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + q := c.Req.URL.Query() + sq := q.Get("q") + offset, _ := strconv.Atoi(q.Get("offset")) + return s.UserSearchPage(c, id, sq, offset) + }, SESSION, HTML) + + aboutPage := handle(func(c *client) error { + return s.AboutPage(c) + }, SESSION, HTML) + + emojisPage := handle(func(c *client) error { + return s.EmojiPage(c) + }, SESSION, HTML) + + searchPage := handle(func(c *client) error { + q := c.Req.URL.Query() + sq := q.Get("q") + qType := q.Get("type") + offset, _ := strconv.Atoi(q.Get("offset")) + return s.SearchPage(c, sq, qType, offset) + }, SESSION, HTML) + + settingsPage := handle(func(c *client) error { + return s.SettingsPage(c) + }, SESSION, HTML) + + signin := handle(func(c *client) error { + instance := c.Req.FormValue("instance") + url, sid, err := s.NewSession(instance) + if err != nil { + return err + } + setSessionCookie(c, sid, sessionExp) + redirect(c, url) + return nil + }, NOAUTH, HTML) + + oauthCallback := handle(func(c *client) error { + q := c.Req.URL.Query() + token := q.Get("code") + token, userID, err := s.Signin(c, token) + if err != nil { + return err + } + + c.Session.AccessToken = token + c.Session.UserID = userID + err = s.sessionRepo.Add(c.Session) + if err != nil { + return err + } + + redirect(c, "/") + return nil + }, SESSION, HTML) + + post := handle(func(c *client) error { + content := c.Req.FormValue("content") + replyToID := c.Req.FormValue("reply_to_id") + format := c.Req.FormValue("format") + visibility := c.Req.FormValue("visibility") + isNSFW := c.Req.FormValue("is_nsfw") == "on" + files := c.Req.MultipartForm.File["attachments"] id, err := s.Post(c, content, replyToID, format, visibility, isNSFW, files) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - location := req.Header.Get("Referer") + location := c.Req.Header.Get("Referer") if len(replyToID) > 0 { location = "/thread/" + replyToID + "#status-" + id } - w.Header().Add("Location", location) - w.WriteHeader(http.StatusFound) - } - - like := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, location) + return nil + }, CSRF, HTML) + like := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") _, err := s.Like(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - unlike := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + unlike := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") _, err := s.UnLike(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - retweet := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + retweet := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") _, err := s.Retweet(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - unretweet := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + unretweet := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") _, err := s.UnRetweet(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - vote := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - statusID := req.FormValue("status_id") - choices, _ := req.PostForm["choices"] - + vote := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + statusID := c.Req.FormValue("status_id") + choices, _ := c.Req.PostForm["choices"] err := s.Vote(c, id, choices) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")+"#status-"+statusID) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+statusID) - w.WriteHeader(http.StatusFound) - } - - follow := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + follow := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + q := c.Req.URL.Query() var reblogs *bool - r, ok := req.URL.Query()["reblogs"] - if ok && len(r) > 0 { + if r, ok := q["reblogs"]; ok && len(r) > 0 { reblogs = new(bool) *reblogs = r[0] == "true" } - err := s.Follow(c, id, reblogs) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unfollow := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unfollow := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnFollow(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - mute := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + mute := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.Mute(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unMute := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unMute := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnMute(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - block := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + block := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.Block(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unBlock := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unBlock := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnBlock(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - subscribe := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + subscribe := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.Subscribe(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unSubscribe := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unSubscribe := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnSubscribe(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - settings := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - visibility := req.FormValue("visibility") - format := req.FormValue("format") - copyScope := req.FormValue("copy_scope") == "true" - threadInNewTab := req.FormValue("thread_in_new_tab") == "true" - hideAttachments := req.FormValue("hide_attachments") == "true" - maskNSFW := req.FormValue("mask_nsfw") == "true" - ni, _ := strconv.Atoi(req.FormValue("notification_interval")) - fluorideMode := req.FormValue("fluoride_mode") == "true" - darkMode := req.FormValue("dark_mode") == "true" - antiDopamineMode := req.FormValue("anti_dopamine_mode") == "true" + settings := handle(func(c *client) error { + visibility := c.Req.FormValue("visibility") + format := c.Req.FormValue("format") + copyScope := c.Req.FormValue("copy_scope") == "true" + threadInNewTab := c.Req.FormValue("thread_in_new_tab") == "true" + hideAttachments := c.Req.FormValue("hide_attachments") == "true" + maskNSFW := c.Req.FormValue("mask_nsfw") == "true" + ni, _ := strconv.Atoi(c.Req.FormValue("notification_interval")) + fluorideMode := c.Req.FormValue("fluoride_mode") == "true" + darkMode := c.Req.FormValue("dark_mode") == "true" + antiDopamineMode := c.Req.FormValue("anti_dopamine_mode") == "true" settings := &model.Settings{ DefaultVisibility: visibility, @@ -609,192 +490,123 @@ func NewHandler(s Service, staticDir string) http.Handler { err := s.SaveSettings(c, settings) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, "/") + return nil + }, CSRF, HTML) - w.Header().Add("Location", "/") - w.WriteHeader(http.StatusFound) - } - - muteConversation := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + muteConversation := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.MuteConversation(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unMuteConversation := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unMuteConversation := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnMuteConversation(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - delete := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + delete := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.Delete(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - readNotifications := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - maxID := req.URL.Query().Get("max_id") - + readNotifications := handle(func(c *client) error { + q := c.Req.URL.Query() + maxID := q.Get("max_id") err := s.ReadNotifications(c, maxID) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - bookmark := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") - + bookmark := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") err := s.Bookmark(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - unBookmark := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + unBookmark := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") err := s.UnBookmark(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - signout := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + signout := handle(func(c *client) error { s.Signout(c) + setSessionCookie(c, "", 0) + redirect(c, "/") + return nil + }, CSRF, HTML) - setSessionCookie(w, "", 0) - w.Header().Add("Location", "/") - w.WriteHeader(http.StatusFound) - } - - fLike := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + fLike := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] count, err := s.Like(c, id) if err != nil { - serveJsonError(w, err) - return + return err } + return writeJson(c, count) + }, CSRF, JSON) - err = serveJson(w, count) - if err != nil { - serveJsonError(w, err) - return - } - } - - fUnlike := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + fUnlike := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] count, err := s.UnLike(c, id) if err != nil { - serveJsonError(w, err) - return - } - - err = serveJson(w, count) - if err != nil { - serveJsonError(w, err) - return + return err } - } - - fRetweet := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] + return writeJson(c, count) + }, CSRF, JSON) + fRetweet := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] count, err := s.Retweet(c, id) if err != nil { - serveJsonError(w, err) - return - } - - err = serveJson(w, count) - if err != nil { - serveJsonError(w, err) - return + return err } - } - - fUnretweet := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] + return writeJson(c, count) + }, CSRF, JSON) + fUnretweet := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] count, err := s.UnRetweet(c, id) if err != nil { - serveJsonError(w, err) - return + return err } - - err = serveJson(w, count) - if err != nil { - serveJsonError(w, err) - return - } - } + return writeJson(c, count) + }, CSRF, JSON) r.HandleFunc("/", rootPage).Methods(http.MethodGet) r.HandleFunc("/nav", navPage).Methods(http.MethodGet) -- cgit v1.2.3