diff options
Diffstat (limited to 'service')
-rw-r--r-- | service/client.go | 118 | ||||
-rw-r--r-- | service/service.go | 339 | ||||
-rw-r--r-- | service/transport.go | 269 |
3 files changed, 453 insertions, 273 deletions
diff --git a/service/client.go b/service/client.go new file mode 100644 index 0000000..18ebb52 --- /dev/null +++ b/service/client.go @@ -0,0 +1,118 @@ +package service + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "net/http" + "strings" + "time" + + "bloat/mastodon" + "bloat/model" + "bloat/renderer" +) + +type client struct { + *mastodon.Client + w http.ResponseWriter + r *http.Request + s *model.Session + csrf string + ctx context.Context + rctx *renderer.Context +} + +func (c *client) setSession(sess *model.Session) error { + var sb strings.Builder + bw := base64.NewEncoder(base64.URLEncoding, &sb) + err := json.NewEncoder(bw).Encode(sess) + bw.Close() + if err != nil { + return err + } + http.SetCookie(c.w, &http.Cookie{ + Name: "session", + Path: "/", + HttpOnly: true, + Value: sb.String(), + Expires: time.Now().Add(365 * 24 * time.Hour), + }) + return nil +} + +func (c *client) getSession() (sess *model.Session, err error) { + cookie, _ := c.r.Cookie("session") + if cookie == nil { + return nil, errInvalidSession + } + br := base64.NewDecoder(base64.URLEncoding, strings.NewReader(cookie.Value)) + err = json.NewDecoder(br).Decode(&sess) + return +} + +func (c *client) unsetSession() { + http.SetCookie(c.w, &http.Cookie{ + Name: "session", + Path: "/", + Value: "", + Expires: time.Now(), + }) +} + +func (c *client) writeJson(data interface{}) error { + return json.NewEncoder(c.w).Encode(map[string]interface{}{ + "data": data, + }) +} + +func (c *client) redirect(url string) { + c.w.Header().Add("Location", url) + c.w.WriteHeader(http.StatusFound) +} + +func (c *client) authenticate(t int, instance string) (err error) { + csrf := c.r.FormValue("csrf_token") + ref := c.r.URL.RequestURI() + defer func() { + if c.s == nil { + c.s = &model.Session{ + Settings: *model.NewSettings(), + } + } + c.rctx = &renderer.Context{ + HideAttachments: c.s.Settings.HideAttachments, + MaskNSFW: c.s.Settings.MaskNSFW, + ThreadInNewTab: c.s.Settings.ThreadInNewTab, + FluorideMode: c.s.Settings.FluorideMode, + DarkMode: c.s.Settings.DarkMode, + CSRFToken: c.s.CSRFToken, + UserID: c.s.UserID, + AntiDopamineMode: c.s.Settings.AntiDopamineMode, + UserCSS: c.s.Settings.CSS, + Referrer: ref, + } + }() + if t < SESSION { + return + } + sess, err := c.getSession() + if err != nil { + return err + } + c.s = sess + if len(instance) > 0 && c.s.Instance != instance { + return errors.New("invalid instance") + } + c.Client = mastodon.NewClient(&mastodon.Config{ + Server: "https://" + c.s.Instance, + ClientID: c.s.ClientID, + ClientSecret: c.s.ClientSecret, + AccessToken: c.s.AccessToken, + }) + if t >= CSRF && (len(csrf) < 1 || csrf != c.s.CSRFToken) { + return errInvalidCSRFToken + } + return +} diff --git a/service/service.go b/service/service.go index cda42f8..90a28d5 100644 --- a/service/service.go +++ b/service/service.go @@ -1,6 +1,8 @@ package service import ( + "crypto/sha256" + "encoding/base64" "errors" "fmt" "mime/multipart" @@ -27,14 +29,11 @@ type service struct { instance string postFormats []model.PostFormat renderer renderer.Renderer - sessionRepo model.SessionRepo - appRepo model.AppRepo } func NewService(cname string, cscope string, cwebsite string, css string, instance string, postFormats []model.PostFormat, - renderer renderer.Renderer, sessionRepo model.SessionRepo, - appRepo model.AppRepo) *service { + renderer renderer.Renderer) *service { return &service{ cname: cname, cscope: cscope, @@ -43,61 +42,18 @@ func NewService(cname string, cscope string, cwebsite string, instance: instance, postFormats: postFormats, renderer: renderer, - sessionRepo: sessionRepo, - appRepo: appRepo, } } -func (s *service) authenticate(c *client, sid string, csrf string, ref string, t int) (err error) { - var sett *model.Settings - defer func() { - if sett == nil { - sett = model.NewSettings() - } - c.rctx = &renderer.Context{ - HideAttachments: sett.HideAttachments, - MaskNSFW: sett.MaskNSFW, - ThreadInNewTab: sett.ThreadInNewTab, - FluorideMode: sett.FluorideMode, - DarkMode: sett.DarkMode, - CSRFToken: c.s.CSRFToken, - UserID: c.s.UserID, - AntiDopamineMode: sett.AntiDopamineMode, - UserCSS: sett.CSS, - Referrer: ref, - } - }() - if t < SESSION { - return - } - if len(sid) < 1 { - return errInvalidSession - } - c.s, err = s.sessionRepo.Get(sid) - if err != nil { - return errInvalidSession - } - sett = &c.s.Settings - app, err := s.appRepo.Get(c.s.InstanceDomain) - if err != nil { - return err - } - c.Client = mastodon.NewClient(&mastodon.Config{ - Server: app.InstanceURL, - ClientID: app.ClientID, - ClientSecret: app.ClientSecret, - AccessToken: c.s.AccessToken, - }) - if t >= CSRF && (len(csrf) < 1 || csrf != c.s.CSRFToken) { - return errInvalidCSRFToken - } - return -} - func (s *service) cdata(c *client, title string, count int, rinterval int, target string) (data *renderer.CommonData) { + if title == "" { + title = s.cname + } else { + title += " - " + s.cname + } data = &renderer.CommonData{ - Title: title + " - " + s.cname, + Title: title, CustomCSS: s.css, Count: count, RefreshInterval: rinterval, @@ -105,6 +61,7 @@ func (s *service) cdata(c *client, title string, count int, rinterval int, } if c != nil && c.s.IsLoggedIn() { data.CSRFToken = c.s.CSRFToken + data.Title += " - " + c.s.Instance } return } @@ -130,7 +87,7 @@ func (s *service) ErrorPage(c *client, err error, retry bool) error { } func (s *service) SigninPage(c *client) (err error) { - cdata := s.cdata(nil, "signin", 0, 0, "") + cdata := s.cdata(nil, "Signin", 0, 0, "") data := &renderer.SigninData{ CommonData: cdata, } @@ -138,8 +95,9 @@ func (s *service) SigninPage(c *client) (err error) { } func (s *service) RootPage(c *client) (err error) { + cdata := s.cdata(c, "", 0, 0, "") data := &renderer.RootData{ - Title: s.cname, + CommonData: cdata, } return s.renderer.Render(c.rctx, c.w, renderer.RootPage, data) } @@ -154,7 +112,7 @@ func (s *service) NavPage(c *client) (err error) { DefaultFormat: c.s.Settings.DefaultFormat, Formats: s.postFormats, } - cdata := s.cdata(c, "nav", 0, 0, "main") + cdata := s.cdata(c, "Nav", 0, 0, "main") data := &renderer.NavData{ User: u, CommonData: cdata, @@ -163,7 +121,7 @@ func (s *service) NavPage(c *client) (err error) { return s.renderer.Render(c.rctx, c.w, renderer.NavPage, data) } -func (s *service) TimelinePage(c *client, tType, instance, listId, maxID, +func (s *service) TimelinePage(c *client, tType, q, listId, maxID, minID string) (err error) { var nextLink, prevLink, title string @@ -173,6 +131,7 @@ func (s *service) TimelinePage(c *client, tType, instance, listId, maxID, MinID: minID, Limit: 20, } + var refreshLink = "/timeline/" + tType switch tType { default: @@ -196,11 +155,12 @@ func (s *service) TimelinePage(c *client, tType, instance, listId, maxID, } title = "Local Timeline" case "remote": - if len(instance) > 0 { - statuses, err = c.GetTimelinePublic(c.ctx, false, instance, &pg) + if len(q) > 0 { + statuses, err = c.GetTimelinePublic(c.ctx, false, q, &pg) if err != nil { return err } + refreshLink += "?q=" + url.QueryEscape(q) } title = "Remote Timeline" case "twkn": @@ -219,6 +179,17 @@ func (s *service) TimelinePage(c *client, tType, instance, listId, maxID, return err } title = "List Timeline - " + list.Title + refreshLink += "?list=" + listId + case "hashtag": + q = strings.TrimPrefix(q, "#") + if len(q) > 0 { + statuses, err = c.GetTimelineHashtag(c.ctx, q, false, &pg) + if err != nil { + return err + } + refreshLink += "?q=" + url.QueryEscape(q) + } + title = "Hashtag Timeline" } for i := range statuses { @@ -230,8 +201,8 @@ func (s *service) TimelinePage(c *client, tType, instance, listId, maxID, if (len(maxID) > 0 || len(minID) > 0) && len(statuses) > 0 { v := make(url.Values) v.Set("min_id", statuses[0].ID) - if len(instance) > 0 { - v.Set("instance", instance) + if len(q) > 0 { + v.Set("q", q) } if len(listId) > 0 { v.Set("list", listId) @@ -242,8 +213,8 @@ func (s *service) TimelinePage(c *client, tType, instance, listId, maxID, if len(minID) > 0 || (len(pg.MaxID) > 0 && len(statuses) == 20) { v := make(url.Values) v.Set("max_id", pg.MaxID) - if len(instance) > 0 { - v.Set("instance", instance) + if len(q) > 0 { + v.Set("q", q) } if len(listId) > 0 { v.Set("list", listId) @@ -251,15 +222,16 @@ func (s *service) TimelinePage(c *client, tType, instance, listId, maxID, nextLink = "/timeline/" + tType + "?" + v.Encode() } - cdata := s.cdata(c, tType+" timeline ", 0, 0, "") + cdata := s.cdata(c, title, 0, 0, "") data := &renderer.TimelineData{ - Title: title, - Type: tType, - Instance: instance, - Statuses: statuses, - NextLink: nextLink, - PrevLink: prevLink, - CommonData: cdata, + Title: title, + Type: tType, + Q: q, + Statuses: statuses, + NextLink: nextLink, + PrevLink: prevLink, + RefreshLink: refreshLink, + CommonData: cdata, } return s.renderer.Render(c.rctx, c.w, renderer.TimelinePage, data) } @@ -319,7 +291,7 @@ func (s *service) ListPage(c *client, id string, q string) (err error) { } var searchAccounts []*mastodon.Account if len(q) > 0 { - result, err := c.Search(c.ctx, q, "accounts", 20, true, 0, id, true) + result, err := c.Search(c.ctx, q, "accounts", 20, true, 0, id, false) if err != nil { return err } @@ -404,7 +376,7 @@ func (s *service) ThreadPage(c *client, id string, reply bool) (err error) { addToReplyMap(replies, statuses[i].InReplyToID, statuses[i].ID, i+1) } - cdata := s.cdata(c, "post by "+status.Account.DisplayName, 0, 0, "") + cdata := s.cdata(c, "Post by "+status.Account.DisplayName, 0, 0, "") data := &renderer.ThreadData{ Statuses: statuses, PostContext: pctx, @@ -460,7 +432,7 @@ func (s *service) QuickReplyPage(c *client, id string) (err error) { }, } - cdata := s.cdata(c, "post by "+status.Account.DisplayName, 0, 0, "") + cdata := s.cdata(c, "Post by "+status.Account.DisplayName, 0, 0, "") data := &renderer.QuickReplyData{ Ancestor: ancestor, Status: status, @@ -475,7 +447,7 @@ func (s *service) LikedByPage(c *client, id string) (err error) { if err != nil { return } - cdata := s.cdata(c, "likes", 0, 0, "") + cdata := s.cdata(c, "Likes", 0, 0, "") data := &renderer.LikedByData{ CommonData: cdata, Users: likers, @@ -488,7 +460,7 @@ func (s *service) RetweetedByPage(c *client, id string) (err error) { if err != nil { return } - cdata := s.cdata(c, "retweets", 0, 0, "") + cdata := s.cdata(c, "Retweets", 0, 0, "") data := &renderer.RetweetedByData{ CommonData: cdata, Users: retweeters, @@ -513,7 +485,7 @@ func (s *service) NotificationPage(c *client, maxID string, // Explicitly include the supported types. // For now, only Pleroma supports this option, Mastadon // will simply ignore the unknown params. - includes = []string{"follow", "follow_request", "mention", "reblog", "favourite"} + includes = []string{"follow", "follow_request", "mention", "reblog", "favourite", "status"} } if c.s.Settings.AntiDopamineMode { excludes = append(excludes, "follow", "favourite", "reblog") @@ -537,7 +509,7 @@ func (s *service) NotificationPage(c *client, maxID string, nextLink = "/notifications?max_id=" + pg.MaxID } - cdata := s.cdata(c, "notifications", unreadCount, + cdata := s.cdata(c, "Notifications", unreadCount, c.s.Settings.NotificationInterval, "main") data := &renderer.NotificationData{ Notifications: notifications, @@ -560,12 +532,19 @@ func (s *service) UserPage(c *client, id string, pageType string, MinID: minID, Limit: 20, } + isCurrent := c.s.UserID == id - user, err := c.GetAccount(c.ctx, id) + // Some fields like AccountSource are only available in the + // CurrentUser API + var user *mastodon.Account + if isCurrent { + user, err = c.GetAccountCurrentUser(c.ctx) + } else { + user, err = c.GetAccount(c.ctx, id) + } if err != nil { return } - isCurrent := c.s.UserID == user.ID switch pageType { case "": @@ -677,7 +656,6 @@ func (s *service) UserPage(c *client, id string, pageType string, cdata := s.cdata(c, user.DisplayName+" @"+user.Acct, 0, 0, "") data := &renderer.UserData{ User: user, - IsCurrent: isCurrent, Type: pageType, Users: users, Statuses: statuses, @@ -691,7 +669,7 @@ func (s *service) UserSearchPage(c *client, id string, q string, offset int) (err error) { var nextLink string - var title = "search" + var title = "Search" user, err := c.GetAccount(c.ctx, id) if err != nil { @@ -711,7 +689,7 @@ func (s *service) UserSearchPage(c *client, if len(results.Statuses) == 20 { offset += 20 nextLink = fmt.Sprintf("/usersearch/%s?q=%s&offset=%d", id, - q, offset) + url.QueryEscape(q), offset) } if len(q) > 0 { @@ -729,8 +707,21 @@ func (s *service) UserSearchPage(c *client, return s.renderer.Render(c.rctx, c.w, renderer.UserSearchPage, data) } +func (s *service) MutePage(c *client, id string) (err error) { + user, err := c.GetAccount(c.ctx, id) + if err != nil { + return + } + cdata := s.cdata(c, "Mute "+user.DisplayName+" @"+user.Acct, 0, 0, "") + data := &renderer.UserData{ + User: user, + CommonData: cdata, + } + return s.renderer.Render(c.rctx, c.w, renderer.MutePage, data) +} + func (s *service) AboutPage(c *client) (err error) { - cdata := s.cdata(c, "about", 0, 0, "") + cdata := s.cdata(c, "About", 0, 0, "") data := &renderer.AboutData{ CommonData: cdata, } @@ -742,7 +733,7 @@ func (s *service) EmojiPage(c *client) (err error) { if err != nil { return } - cdata := s.cdata(c, "emojis", 0, 0, "") + cdata := s.cdata(c, "Emojis", 0, 0, "") data := &renderer.EmojiData{ Emojis: emojis, CommonData: cdata, @@ -750,17 +741,21 @@ func (s *service) EmojiPage(c *client) (err error) { return s.renderer.Render(c.rctx, c.w, renderer.EmojiPage, data) } -func (s *service) SearchPage(c *client, - q string, qType string, offset int) (err error) { +func (s *service) SearchPage(c *client, q string, qType string, offset int) ( + rurl string, err error) { var nextLink string - var title = "search" + var title = "Search" var results *mastodon.Results if len(q) > 0 { + if qType == "hashtags" { + rurl = "/timeline/hashtag?q=" + url.QueryEscape(q) + return + } results, err = c.Search(c.ctx, q, qType, 20, true, offset, "", false) if err != nil { - return err + return "", err } } else { results = &mastodon.Results{} @@ -770,7 +765,7 @@ func (s *service) SearchPage(c *client, (qType == "statuses" && len(results.Statuses) == 20) { offset += 20 nextLink = fmt.Sprintf("/search?q=%s&type=%s&offset=%d", - q, qType, offset) + url.QueryEscape(q), qType, offset) } if len(q) > 0 { @@ -786,11 +781,11 @@ func (s *service) SearchPage(c *client, Statuses: results.Statuses, NextLink: nextLink, } - return s.renderer.Render(c.rctx, c.w, renderer.SearchPage, data) + return "", s.renderer.Render(c.rctx, c.w, renderer.SearchPage, data) } func (s *service) SettingsPage(c *client) (err error) { - cdata := s.cdata(c, "settings", 0, 0, "") + cdata := s.cdata(c, "Settings", 0, 0, "") data := &renderer.SettingsData{ CommonData: cdata, Settings: &c.s.Settings, @@ -804,7 +799,7 @@ func (svc *service) FiltersPage(c *client) (err error) { if err != nil { return } - cdata := svc.cdata(c, "filters", 0, 0, "") + cdata := svc.cdata(c, "Filters", 0, 0, "") data := &renderer.FiltersData{ CommonData: cdata, Filters: filters, @@ -812,6 +807,55 @@ func (svc *service) FiltersPage(c *client) (err error) { return svc.renderer.Render(c.rctx, c.w, renderer.FiltersPage, data) } +func (svc *service) ProfilePage(c *client) (err error) { + u, err := c.GetAccountCurrentUser(c.ctx) + if err != nil { + return + } + // Some instances allow more than 4 fields, but make sure that there are + // at least 4 fields in the slice because the template depends on it. + if u.Source.Fields == nil { + u.Source.Fields = new([]mastodon.Field) + } + for len(*u.Source.Fields) < 4 { + *u.Source.Fields = append(*u.Source.Fields, mastodon.Field{}) + } + cdata := svc.cdata(c, "Edit profile", 0, 0, "") + data := &renderer.ProfileData{ + CommonData: cdata, + User: u, + } + return svc.renderer.Render(c.rctx, c.w, renderer.ProfilePage, data) +} + +func (s *service) ProfileUpdate(c *client, name, bio string, avatar, banner *multipart.FileHeader, + fields []mastodon.Field, locked bool) (err error) { + // Need to pass empty data to clear fields + if len(fields) == 0 { + fields = append(fields, mastodon.Field{}) + } + p := &mastodon.Profile{ + DisplayName: &name, + Note: &bio, + Avatar: avatar, + Header: banner, + Fields: &fields, + Locked: &locked, + } + _, err = c.AccountUpdate(c.ctx, p) + return err +} + +func (s *service) ProfileDelAvatar(c *client) (err error) { + _, err = c.AccountDeleteAvatar(c.ctx) + return +} + +func (s *service) ProfileDelBanner(c *client) (err error) { + _, err = c.AccountDeleteHeader(c.ctx) + return err +} + func (s *service) SingleInstance() (instance string, ok bool) { if len(s.instance) > 0 { instance = s.instance @@ -820,7 +864,7 @@ func (s *service) SingleInstance() (instance string, ok bool) { return } -func (s *service) NewSession(c *client, instance string) (rurl string, sid string, err error) { +func (s *service) NewSession(c *client, instance string) (rurl string, sess *model.Session, err error) { var instanceURL string if strings.HasPrefix(instance, "https://") { instanceURL = instance @@ -829,66 +873,29 @@ func (s *service) NewSession(c *client, instance string) (rurl string, sid strin instanceURL = "https://" + instance } - sid, err = util.NewSessionID() - if err != nil { - return - } csrf, err := util.NewCSRFToken() if err != nil { return } - sess := model.Session{ - ID: sid, - InstanceDomain: instance, - CSRFToken: csrf, - Settings: *model.NewSettings(), - } - err = s.sessionRepo.Add(sess) + app, err := mastodon.RegisterApp(c.ctx, &mastodon.AppConfig{ + Server: instanceURL, + ClientName: s.cname, + Scopes: s.cscope, + Website: s.cwebsite, + RedirectURIs: s.cwebsite + "/oauth_callback", + }) if err != nil { return } - - app, err := s.appRepo.Get(instance) - if err != nil { - if err != model.ErrAppNotFound { - return - } - mastoApp, err := mastodon.RegisterApp(c.ctx, &mastodon.AppConfig{ - Server: instanceURL, - ClientName: s.cname, - Scopes: s.cscope, - Website: s.cwebsite, - RedirectURIs: s.cwebsite + "/oauth_callback", - }) - if err != nil { - return "", "", err - } - app = model.App{ - InstanceDomain: instance, - InstanceURL: instanceURL, - ClientID: mastoApp.ClientID, - ClientSecret: mastoApp.ClientSecret, - } - err = s.appRepo.Add(app) - if err != nil { - return "", "", err - } - } - - u, err := url.Parse("/oauth/authorize") - if err != nil { - return + rurl = app.AuthURI + sess = &model.Session{ + Instance: instance, + ClientID: app.ClientID, + ClientSecret: app.ClientSecret, + CSRFToken: csrf, + Settings: *model.NewSettings(), } - - q := make(url.Values) - q.Set("scope", "read write follow") - q.Set("client_id", app.ClientID) - q.Set("response_type", "code") - q.Set("redirect_uri", s.cwebsite+"/oauth_callback") - u.RawQuery = q.Encode() - - rurl = instanceURL + u.String() return } @@ -907,16 +914,15 @@ func (s *service) Signin(c *client, code string) (err error) { } c.s.AccessToken = c.GetAccessToken(c.ctx) c.s.UserID = u.ID - return s.sessionRepo.Add(c.s) + return c.setSession(c.s) } func (s *service) Signout(c *client) (err error) { - s.sessionRepo.Remove(c.s.ID) - return + return c.RevokeToken(c.ctx) } func (s *service) Post(c *client, content string, replyToID string, - format string, visibility string, isNSFW bool, + format string, visibility string, isNSFW bool, isQuote bool, files []*multipart.FileHeader) (id string, err error) { var mediaIDs []string @@ -928,9 +934,16 @@ func (s *service) Post(c *client, content string, replyToID string, mediaIDs = append(mediaIDs, a.ID) } + var quoteID string + if isQuote { + quoteID = replyToID + replyToID = "" + } + tweet := &mastodon.Toot{ Status: content, InReplyToID: replyToID, + QuoteID: quoteID, MediaIDs: mediaIDs, ContentType: format, Visibility: visibility, @@ -1005,8 +1018,8 @@ func (s *service) Reject(c *client, id string) (err error) { return c.FollowRequestReject(c.ctx, id) } -func (s *service) Mute(c *client, id string, notifications *bool) (err error) { - _, err = c.AccountMute(c.ctx, id, notifications) +func (s *service) Mute(c *client, id string, notifications bool, duration int) (err error) { + _, err = c.AccountMute(c.ctx, id, notifications, duration) return } @@ -1041,15 +1054,21 @@ func (s *service) SaveSettings(c *client, settings *model.Settings) (err error) default: return errInvalidArgument } - if len(settings.CSS) > 1<<20 { - return errInvalidArgument - } - sess, err := s.sessionRepo.Get(c.s.ID) - if err != nil { - return + if len(settings.CSS) > 0 { + if len(settings.CSS) > 1<<20 { + return errInvalidArgument + } + // For some reason, browsers convert CRLF to LF before calculating + // the hash of the inline resources. + settings.CSS = strings.Replace(settings.CSS, "\x0d\x0a", "\x0a", -1) + + h := sha256.Sum256([]byte(settings.CSS)) + settings.CSSHash = base64.StdEncoding.EncodeToString(h[:]) + } else { + settings.CSSHash = "" } - sess.Settings = *settings - return s.sessionRepo.Add(sess) + c.s.Settings = *settings + return c.setSession(c.s) } func (s *service) MuteConversation(c *client, id string) (err error) { @@ -1082,7 +1101,7 @@ func (s *service) UnBookmark(c *client, id string) (err error) { func (svc *service) Filter(c *client, phrase string, wholeWord bool) (err error) { fctx := []string{"home", "notifications", "public", "thread"} - return c.AddFilter(c.ctx, phrase, fctx, true, wholeWord, nil) + return c.AddFilter(c.ctx, phrase, fctx, false, wholeWord, nil) } func (svc *service) UnFilter(c *client, id string) (err error) { diff --git a/service/transport.go b/service/transport.go index 4518b1a..e0372bd 100644 --- a/service/transport.go +++ b/service/transport.go @@ -1,25 +1,21 @@ package service import ( - "context" "encoding/json" + "fmt" "log" + "mime/multipart" "net/http" "strconv" "time" "bloat/mastodon" "bloat/model" - "bloat/renderer" "github.com/gorilla/mux" ) const ( - sessionExp = 365 * 24 * time.Hour -) - -const ( HTML int = iota JSON ) @@ -30,36 +26,16 @@ const ( CSRF ) -type client struct { - *mastodon.Client - w http.ResponseWriter - r *http.Request - s model.Session - csrf string - ctx context.Context - rctx *renderer.Context -} +const csp = "default-src 'none';" + + " img-src *;" + + " media-src *;" + + " font-src *;" + + " child-src *;" + + " connect-src 'self';" + + " script-src 'self';" + + " style-src 'self'" -func setSessionCookie(w http.ResponseWriter, sid string, exp time.Duration) { - http.SetCookie(w, &http.Cookie{ - Name: "session_id", - Value: sid, - Expires: time.Now().Add(exp), - }) -} - -func writeJson(c *client, data interface{}) error { - return json.NewEncoder(c.w).Encode(map[string]interface{}{ - "data": data, - }) -} - -func redirect(c *client, url string) { - c.w.Header().Add("Location", url) - c.w.WriteHeader(http.StatusFound) -} - -func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { +func NewHandler(s *service, verbose bool, staticDir string) http.Handler { r := mux.NewRouter() writeError := func(c *client, err error, t int, retry bool) { @@ -75,16 +51,6 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } } - authenticate := func(c *client, t int) error { - var sid string - if cookie, _ := c.r.Cookie("session_id"); cookie != nil { - sid = cookie.Value - } - csrf := c.r.FormValue("csrf_token") - ref := c.r.URL.RequestURI() - return s.authenticate(c, sid, csrf, ref, t) - } - handle := func(f func(c *client) error, at int, rt int) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { var err error @@ -94,26 +60,35 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { r: req, } - defer func(begin time.Time) { - logger.Printf("path=%s, err=%v, took=%v\n", - req.URL.Path, err, time.Since(begin)) - }(time.Now()) + if verbose { + defer func(begin time.Time) { + log.Printf("path=%s, err=%v, took=%v\n", + req.URL.Path, err, time.Since(begin)) + }(time.Now()) + } - var ct string + h := c.w.Header() switch rt { case HTML: - ct = "text/html; charset=utf-8" + h.Set("Content-Type", "text/html; charset=utf-8") + h.Set("Content-Security-Policy", csp) case JSON: - ct = "application/json" + h.Set("Content-Type", "application/json") } - c.w.Header().Add("Content-Type", ct) - err = authenticate(c, at) + err = c.authenticate(at, s.instance) if err != nil { writeError(c, err, rt, req.Method == http.MethodGet) return } + // Override the CSP header to allow custom CSS + if rt == HTML && len(c.s.Settings.CSS) > 0 && + len(c.s.Settings.CSSHash) > 0 { + v := fmt.Sprintf("%s 'sha256-%s'", csp, c.s.Settings.CSSHash) + h.Set("Content-Security-Policy", v) + } + err = f(c) if err != nil { writeError(c, err, rt, req.Method == http.MethodGet) @@ -123,16 +98,16 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } rootPage := handle(func(c *client) error { - err := authenticate(c, SESSION) + err := c.authenticate(SESSION, "") if err != nil { if err == errInvalidSession { - redirect(c, "/signin") + c.redirect("/signin") return nil } return err } if !c.s.IsLoggedIn() { - redirect(c, "/signin") + c.redirect("/signin") return nil } return s.RootPage(c) @@ -147,27 +122,27 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if !ok { return s.SigninPage(c) } - url, sid, err := s.NewSession(c, instance) + url, sess, err := s.NewSession(c, instance) if err != nil { return err } - setSessionCookie(c.w, sid, sessionExp) - redirect(c, url) + c.setSession(sess) + c.redirect(url) return nil }, NOAUTH, HTML) timelinePage := handle(func(c *client) error { tType, _ := mux.Vars(c.r)["type"] q := c.r.URL.Query() - instance := q.Get("instance") + query := q.Get("q") list := q.Get("list") maxID := q.Get("max_id") minID := q.Get("min_id") - return s.TimelinePage(c, tType, instance, list, maxID, minID) + return s.TimelinePage(c, tType, query, list, maxID, minID) }, SESSION, HTML) defaultTimelinePage := handle(func(c *client) error { - redirect(c, "/timeline/home") + c.redirect("/timeline/home") return nil }, SESSION, HTML) @@ -217,6 +192,11 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { return s.UserSearchPage(c, id, sq, offset) }, SESSION, HTML) + mutePage := handle(func(c *client) error { + id, _ := mux.Vars(c.r)["id"] + return s.MutePage(c, id) + }, SESSION, HTML) + aboutPage := handle(func(c *client) error { return s.AboutPage(c) }, SESSION, HTML) @@ -230,7 +210,14 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { sq := q.Get("q") qType := q.Get("type") offset, _ := strconv.Atoi(q.Get("offset")) - return s.SearchPage(c, sq, qType, offset) + rurl, err := s.SearchPage(c, sq, qType, offset) + if err != nil { + return err + } + if len(rurl) > 0 { + c.redirect(rurl) + } + return nil }, SESSION, HTML) settingsPage := handle(func(c *client) error { @@ -241,14 +228,65 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { return s.FiltersPage(c) }, SESSION, HTML) + profilePage := handle(func(c *client) error { + return s.ProfilePage(c) + }, SESSION, HTML) + + profileUpdate := handle(func(c *client) error { + name := c.r.FormValue("name") + bio := c.r.FormValue("bio") + var avatar, banner *multipart.FileHeader + if f := c.r.MultipartForm.File["avatar"]; len(f) > 0 { + avatar = f[0] + } + if f := c.r.MultipartForm.File["banner"]; len(f) > 0 { + banner = f[0] + } + var fields []mastodon.Field + for i := 0; i < 16; i++ { + n := c.r.FormValue(fmt.Sprintf("field-name-%d", i)) + v := c.r.FormValue(fmt.Sprintf("field-value-%d", i)) + if len(n) == 0 { + continue + } + f := mastodon.Field{Name: n, Value: v} + fields = append(fields, f) + } + locked := c.r.FormValue("locked") == "true" + err := s.ProfileUpdate(c, name, bio, avatar, banner, fields, locked) + if err != nil { + return err + } + c.redirect("/") + return nil + }, CSRF, HTML) + + profileDelAvatar := handle(func(c *client) error { + err := s.ProfileDelAvatar(c) + if err != nil { + return err + } + c.redirect(c.r.FormValue("referrer")) + return nil + }, CSRF, HTML) + + profileDelBanner := handle(func(c *client) error { + err := s.ProfileDelBanner(c) + if err != nil { + return err + } + c.redirect(c.r.FormValue("referrer")) + return nil + }, CSRF, HTML) + signin := handle(func(c *client) error { instance := c.r.FormValue("instance") - url, sid, err := s.NewSession(c, instance) + url, sess, err := s.NewSession(c, instance) if err != nil { return err } - setSessionCookie(c.w, sid, sessionExp) - redirect(c, url) + c.setSession(sess) + c.redirect(url) return nil }, NOAUTH, HTML) @@ -259,7 +297,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, "/") + c.redirect("/") return nil }, SESSION, HTML) @@ -269,10 +307,11 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { format := c.r.FormValue("format") visibility := c.r.FormValue("visibility") isNSFW := c.r.FormValue("is_nsfw") == "true" + isQuote := c.r.FormValue("is_quote") == "true" quickReply := c.r.FormValue("quickreply") == "true" files := c.r.MultipartForm.File["attachments"] - id, err := s.Post(c, content, replyToID, format, visibility, isNSFW, files) + id, err := s.Post(c, content, replyToID, format, visibility, isNSFW, isQuote, files) if err != nil { return err } @@ -287,7 +326,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { } else { location = c.r.FormValue("referrer") } - redirect(c, location) + c.redirect(location) return nil }, CSRF, HTML) @@ -301,7 +340,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -315,7 +354,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -329,7 +368,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -343,7 +382,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -355,7 +394,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")+"#status-"+statusID) + c.redirect(c.r.FormValue("referrer") + "#status-" + statusID) return nil }, CSRF, HTML) @@ -371,7 +410,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -381,7 +420,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -391,7 +430,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -401,23 +440,19 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) mute := handle(func(c *client) error { id, _ := mux.Vars(c.r)["id"] - q := c.r.URL.Query() - var notifications *bool - if r, ok := q["notifications"]; ok && len(r) > 0 { - notifications = new(bool) - *notifications = r[0] == "true" - } - err := s.Mute(c, id, notifications) + notifications, _ := strconv.ParseBool(c.r.FormValue("notifications")) + duration, _ := strconv.Atoi(c.r.FormValue("duration")) + err := s.Mute(c, id, notifications, duration) if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect("/user/" + id) return nil }, CSRF, HTML) @@ -427,7 +462,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -437,7 +472,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -447,7 +482,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -457,7 +492,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -467,7 +502,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -504,7 +539,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, "/") + c.redirect("/") return nil }, CSRF, HTML) @@ -514,7 +549,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -524,7 +559,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -534,7 +569,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -545,7 +580,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -559,7 +594,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -573,7 +608,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if len(rid) > 0 { id = rid } - redirect(c, c.r.FormValue("referrer")+"#status-"+id) + c.redirect(c.r.FormValue("referrer") + "#status-" + id) return nil }, CSRF, HTML) @@ -584,7 +619,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -594,7 +629,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -608,7 +643,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -618,7 +653,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -629,7 +664,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -648,7 +683,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) @@ -660,14 +695,17 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - redirect(c, c.r.FormValue("referrer")) + c.redirect(c.r.FormValue("referrer")) return nil }, CSRF, HTML) signout := handle(func(c *client) error { - s.Signout(c) - setSessionCookie(c.w, "", 0) - redirect(c, "/") + err := s.Signout(c) + if err != nil { + return err + } + c.unsetSession() + c.redirect("/") return nil }, CSRF, HTML) @@ -677,7 +715,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - return writeJson(c, count) + return c.writeJson(count) }, CSRF, JSON) fUnlike := handle(func(c *client) error { @@ -686,7 +724,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - return writeJson(c, count) + return c.writeJson(count) }, CSRF, JSON) fRetweet := handle(func(c *client) error { @@ -695,7 +733,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - return writeJson(c, count) + return c.writeJson(count) }, CSRF, JSON) fUnretweet := handle(func(c *client) error { @@ -704,7 +742,7 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { if err != nil { return err } - return writeJson(c, count) + return c.writeJson(count) }, CSRF, JSON) r.HandleFunc("/", rootPage).Methods(http.MethodGet) @@ -720,11 +758,16 @@ func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { r.HandleFunc("/user/{id}", userPage).Methods(http.MethodGet) r.HandleFunc("/user/{id}/{type}", userPage).Methods(http.MethodGet) r.HandleFunc("/usersearch/{id}", userSearchPage).Methods(http.MethodGet) + r.HandleFunc("/mute/{id}", mutePage).Methods(http.MethodGet) r.HandleFunc("/about", aboutPage).Methods(http.MethodGet) r.HandleFunc("/emojis", emojisPage).Methods(http.MethodGet) r.HandleFunc("/search", searchPage).Methods(http.MethodGet) r.HandleFunc("/settings", settingsPage).Methods(http.MethodGet) r.HandleFunc("/filters", filtersPage).Methods(http.MethodGet) + r.HandleFunc("/profile", profilePage).Methods(http.MethodGet) + r.HandleFunc("/profile", profileUpdate).Methods(http.MethodPost) + r.HandleFunc("/profile/delavatar", profileDelAvatar).Methods(http.MethodPost) + r.HandleFunc("/profile/delbanner", profileDelBanner).Methods(http.MethodPost) r.HandleFunc("/signin", signin).Methods(http.MethodPost) r.HandleFunc("/oauth_callback", oauthCallback).Methods(http.MethodGet) r.HandleFunc("/post", post).Methods(http.MethodPost) |