diff options
Diffstat (limited to 'service')
-rw-r--r-- | service/client.go | 111 | ||||
-rw-r--r-- | service/service.go | 119 | ||||
-rw-r--r-- | service/transport.go | 143 |
3 files changed, 179 insertions, 194 deletions
diff --git a/service/client.go b/service/client.go new file mode 100644 index 0000000..3affd57 --- /dev/null +++ b/service/client.go @@ -0,0 +1,111 @@ +package service + +import ( + "context" + "encoding/base64" + "encoding/json" + "net/http" + "strings" + "time" + + "bloat/mastodon" + "bloat/model" + "bloat/renderer" +) + +type client struct { + *mastodon.Client + w http.ResponseWriter + r *http.Request + s *model.Session + csrf string + ctx context.Context + rctx *renderer.Context +} + +func (c *client) setSession(sess *model.Session) error { + var sb strings.Builder + bw := base64.NewEncoder(base64.URLEncoding, &sb) + err := json.NewEncoder(bw).Encode(sess) + bw.Close() + if err != nil { + return err + } + http.SetCookie(c.w, &http.Cookie{ + Name: "session", + Value: sb.String(), + Expires: time.Now().Add(365 * 24 * time.Hour), + }) + return nil +} + +func (c *client) getSession() (sess *model.Session, err error) { + cookie, _ := c.r.Cookie("session") + if cookie == nil { + return nil, errInvalidSession + } + br := base64.NewDecoder(base64.URLEncoding, strings.NewReader(cookie.Value)) + err = json.NewDecoder(br).Decode(&sess) + return +} + +func (c *client) unsetSession() { + http.SetCookie(c.w, &http.Cookie{ + Name: "session", + Value: "", + Expires: time.Now(), + }) +} + +func (c *client) writeJson(data interface{}) error { + return json.NewEncoder(c.w).Encode(map[string]interface{}{ + "data": data, + }) +} + +func (c *client) redirect(url string) { + c.w.Header().Add("Location", url) + c.w.WriteHeader(http.StatusFound) +} + +func (c *client) authenticate(t int) (err error) { + csrf := c.r.FormValue("csrf_token") + ref := c.r.URL.RequestURI() + defer func() { + if c.s == nil { + c.s = &model.Session{ + Settings: *model.NewSettings(), + } + } + c.rctx = &renderer.Context{ + HideAttachments: c.s.Settings.HideAttachments, + MaskNSFW: c.s.Settings.MaskNSFW, + ThreadInNewTab: c.s.Settings.ThreadInNewTab, + FluorideMode: c.s.Settings.FluorideMode, + DarkMode: c.s.Settings.DarkMode, + CSRFToken: c.s.CSRFToken, + UserID: c.s.UserID, + AntiDopamineMode: c.s.Settings.AntiDopamineMode, + UserCSS: c.s.Settings.CSS, + Referrer: ref, + } + }() + if t < SESSION { + return + } + sess, err := c.getSession() + if err != nil { + return err + } + c.s = sess + c.Client = mastodon.NewClient(&mastodon.Config{ + Server: "https://" + c.s.Instance, + ClientID: c.s.ClientID, + ClientSecret: c.s.ClientSecret, + AccessToken: c.s.AccessToken, + }) + if t >= CSRF && (len(csrf) < 1 || csrf != c.s.CSRFToken) { + return errInvalidCSRFToken + } + return +} diff --git a/service/service.go b/service/service.go index cda42f8..432f938 100644 --- a/service/service.go +++ b/service/service.go @@ -27,14 +27,11 @@ type service struct { instance string postFormats []model.PostFormat renderer renderer.Renderer - sessionRepo model.SessionRepo - appRepo model.AppRepo } func NewService(cname string, cscope string, cwebsite string, css string, instance string, postFormats []model.PostFormat, - renderer renderer.Renderer, sessionRepo model.SessionRepo, - appRepo model.AppRepo) *service { + renderer renderer.Renderer) *service { return &service{ cname: cname, cscope: cscope, @@ -43,57 +40,9 @@ func NewService(cname string, cscope string, cwebsite string, instance: instance, postFormats: postFormats, renderer: renderer, - sessionRepo: sessionRepo, - appRepo: appRepo, } } -func (s *service) authenticate(c *client, sid string, csrf string, ref string, t int) (err error) { - var sett *model.Settings - defer func() { - if sett == nil { - sett = model.NewSettings() - } - c.rctx = &renderer.Context{ - HideAttachments: sett.HideAttachments, - MaskNSFW: sett.MaskNSFW, - ThreadInNewTab: sett.ThreadInNewTab, - FluorideMode: sett.FluorideMode, - DarkMode: sett.DarkMode, - CSRFToken: c.s.CSRFToken, - UserID: c.s.UserID, - AntiDopamineMode: sett.AntiDopamineMode, - UserCSS: sett.CSS, - Referrer: ref, - } - }() - if t < SESSION { - return - } - if len(sid) < 1 { - return errInvalidSession - } - c.s, err = s.sessionRepo.Get(sid) - if err != nil { - return errInvalidSession - } - sett = &c.s.Settings - app, err := s.appRepo.Get(c.s.InstanceDomain) - if err != nil { - return err - } - c.Client = mastodon.NewClient(&mastodon.Config{ - Server: app.InstanceURL, - ClientID: app.ClientID, - ClientSecret: app.ClientSecret, - AccessToken: c.s.AccessToken, - }) - if t >= CSRF && (len(csrf) < 1 || csrf != c.s.CSRFToken) { - return errInvalidCSRFToken - } - return -} - func (s *service) cdata(c *client, title string, count int, rinterval int, target string) (data *renderer.CommonData) { data = &renderer.CommonData{ @@ -820,7 +769,7 @@ func (s *service) SingleInstance() (instance string, ok bool) { return } -func (s *service) NewSession(c *client, instance string) (rurl string, sid string, err error) { +func (s *service) NewSession(c *client, instance string) (rurl string, sess *model.Session, err error) { var instanceURL string if strings.HasPrefix(instance, "https://") { instanceURL = instance @@ -829,7 +778,7 @@ func (s *service) NewSession(c *client, instance string) (rurl string, sid strin instanceURL = "https://" + instance } - sid, err = util.NewSessionID() + sid, err := util.NewSessionID() if err != nil { return } @@ -838,42 +787,23 @@ func (s *service) NewSession(c *client, instance string) (rurl string, sid strin return } - sess := model.Session{ - ID: sid, - InstanceDomain: instance, - CSRFToken: csrf, - Settings: *model.NewSettings(), - } - err = s.sessionRepo.Add(sess) + app, err := mastodon.RegisterApp(c.ctx, &mastodon.AppConfig{ + Server: instanceURL, + ClientName: s.cname, + Scopes: s.cscope, + Website: s.cwebsite, + RedirectURIs: s.cwebsite + "/oauth_callback", + }) if err != nil { return } - - app, err := s.appRepo.Get(instance) - if err != nil { - if err != model.ErrAppNotFound { - return - } - mastoApp, err := mastodon.RegisterApp(c.ctx, &mastodon.AppConfig{ - Server: instanceURL, - ClientName: s.cname, - Scopes: s.cscope, - Website: s.cwebsite, - RedirectURIs: s.cwebsite + "/oauth_callback", - }) - if err != nil { - return "", "", err - } - app = model.App{ - InstanceDomain: instance, - InstanceURL: instanceURL, - ClientID: mastoApp.ClientID, - ClientSecret: mastoApp.ClientSecret, - } - err = s.appRepo.Add(app) - if err != nil { - return "", "", err - } + sess = &model.Session{ + ID: sid, + Instance: instance, + ClientID: app.ClientID, + ClientSecret: app.ClientSecret, + CSRFToken: csrf, + Settings: *model.NewSettings(), } u, err := url.Parse("/oauth/authorize") @@ -907,12 +837,7 @@ func (s *service) Signin(c *client, code string) (err error) { } c.s.AccessToken = c.GetAccessToken(c.ctx) c.s.UserID = u.ID - return s.sessionRepo.Add(c.s) -} - -func (s *service) Signout(c *client) (err error) { - s.sessionRepo.Remove(c.s.ID) - return + return c.setSession(c.s) } func (s *service) Post(c *client, content string, replyToID string, @@ -1044,12 +969,8 @@ func (s *service) SaveSettings(c *client, settings *model.Settings) (err error) if len(settings.CSS) > 1<<20 { return errInvalidArgument } - sess, err := s.sessionRepo.Get(c.s.ID) - if err != nil { - return - } - sess.Settings = *settings - return s.sessionRepo.Add(sess) + c.s.Settings = *settings + return c.setSession(c.s) } func (s *service) MuteConversation(c *client, id string) (err error) { diff --git a/service/transport.go b/service/transport.go index 4518b1a..471a7d4 100644 --- a/service/transport.go +++ b/service/transport.go @@ -1,25 +1,18 @@ package service import ( - "context" "encoding/json" "log" "net/http" "strconv" "time" - "bloat/mastodon" "bloat/model" - "bloat/renderer" "github.com/gorilla/mux" ) const ( - sessionExp = 365 * 24 * time.Hour -) - -const ( HTML int = iota JSON ) @@ -30,35 +23,6 @@ const ( CSRF ) -type client struct { - *mastodon.Client - w http.ResponseWriter - r *http.Request - s model.Session - csrf string - ctx context.Context - rctx *renderer.Context -} - -func setSessionCookie(w http.ResponseWriter, sid string, exp time.Duration) { - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: sid, - Expires: time.Now().Add(exp), - }) -} - -func writeJson(c *client, data interface{}) error { - return json.NewEncoder(c.w).Encode(map[string]interface{}{ - "data": data, - }) -} - -func redirect(c *client, url string) { - c.w.Header().Add("Location", url) - c.w.WriteHeader(http.StatusFound) -} - func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { r := mux.NewRouter() @@ -75,16 +39,6 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } } - authenticate := func(c *client, t int) error { - var sid string - if cookie, _ := c.r.Cookie("session_id"); cookie != nil { - sid = cookie.Value - } - csrf := c.r.FormValue("csrf_token") - ref := c.r.URL.RequestURI() - return s.authenticate(c, sid, csrf, ref, t) - } - handle := func(f func(c *client) error, at int, rt int) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { var err error @@ -108,7 +62,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } c.w.Header().Add("Content-Type", ct) - err = authenticate(c, at) + err = c.authenticate(at) if err != nil { writeError(c, err, rt, req.Method == http.MethodGet) return @@ -123,16 +77,16 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } rootPage := handle(func(c *client) error { - err := authenticate(c, SESSION) + err := c.authenticate(SESSION) if err != nil { if err == errInvalidSession { - redirect(c, "/signin") + c.redirect("/signin") return nil } return err } if !c.s.IsLoggedIn() { - redirect(c, "/signin") + c.redirect("/signin") return nil } return s.RootPage(c) @@ -147,12 +101,12 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if !ok { return s.SigninPage(c) } - url, sid, err := s.NewSession(c, instance) + url, sess, err := s.NewSession(c, instance) if err != nil { return err } - setSessionCookie(c.w, sid, sessionExp) - redirect(c, url) + c.setSession(sess) + c.redirect(url) return nil }, NOAUTH, HTML) @@ -167,7 +121,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { }, SESSION, HTML) defaultTimelinePage := handle(func(c *client) error { - redirect(c, "/timeline/home") + c.redirect("/timeline/home") return nil }, SESSION, HTML) @@ -243,12 +197,12 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { signin := handle(func(c *client) error { instance := c.r.FormValue("instance") - url, sid, err := s.NewSession(c, instance) + url, sess, err := s.NewSession(c, instance) if err != nil { return err } - setSessionCookie(c.w, sid, sessionExp) - redirect(c, url) + c.setSession(sess) + c.redirect(url) return nil }, NOAUTH, HTML) @@ -259,7 +213,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, "/") + c.redirect("/") return nil }, SESSION, HTML) @@ -287,7 +241,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } else { location = c.r.FormValue("referrer") } - redirect(c, location) + c.redirect(location) return nil }, CSRF, HTML) @@ -301,7 +255,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -315,7 +269,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -329,7 +283,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -343,7 +297,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -355,7 +309,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")+"#status-"+statusID) + c.redirect(c.r.FormValue("referrer") + "#status-" + statusID) return nil }, CSRF, HTML) @@ -371,7 +325,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -381,7 +335,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -391,7 +345,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -401,7 +355,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -417,7 +371,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -427,7 +381,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -437,7 +391,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -447,7 +401,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -457,7 +411,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -467,7 +421,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -504,7 +458,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, "/") + c.redirect("/") return nil }, CSRF, HTML) @@ -514,7 +468,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -524,7 +478,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -534,7 +488,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -545,7 +499,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -559,7 +513,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -573,7 +527,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -584,7 +538,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -594,7 +548,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -608,7 +562,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -618,7 +572,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -629,7 +583,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -648,7 +602,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -660,14 +614,13 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) signout := handle(func(c *client) error { - s.Signout(c) - setSessionCookie(c.w, "", 0) - redirect(c, "/") + c.unsetSession() + c.redirect("/") return nil }, CSRF, HTML) @@ -677,7 +630,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - return writeJson(c, count) + return c.writeJson(count) }, CSRF, JSON) fUnlike := handle(func(c *client) error { @@ -686,7 +639,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - return writeJson(c, count) + return c.writeJson(count) }, CSRF, JSON) fRetweet := handle(func(c *client) error { @@ -695,7 +648,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - return writeJson(c, count) + return c.writeJson(count) }, CSRF, JSON) fUnretweet := handle(func(c *client) error { @@ -704,7 +657,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - return writeJson(c, count) + return c.writeJson(count) }, CSRF, JSON) r.HandleFunc("/", rootPage).Methods(http.MethodGet) |