From fa27d9c6eb525b2e55c3faab5dd8a3e0e9658536 Mon Sep 17 00:00:00 2001 From: r Date: Sun, 22 Nov 2020 17:29:58 +0000 Subject: Refactor things - Remove separate auth/logging and merge them into transport.go - Add helper function for http handlers --- config/config.go | 2 +- kv/kv.go | 91 ----- main.go | 9 +- model/client.go | 19 - model/post.go | 1 - repo/appRepo.go | 6 +- repo/sessionRepo.go | 6 +- service/auth.go | 480 ----------------------- service/logging.go | 341 ----------------- service/service.go | 313 ++++++--------- service/transport.go | 1026 +++++++++++++++++++++----------------------------- util/kv.go | 91 +++++ 12 files changed, 634 insertions(+), 1751 deletions(-) delete mode 100644 kv/kv.go delete mode 100644 model/client.go delete mode 100644 service/auth.go delete mode 100644 service/logging.go create mode 100644 util/kv.go diff --git a/config/config.go b/config/config.go index d6140e5..8678f52 100644 --- a/config/config.go +++ b/config/config.go @@ -101,7 +101,7 @@ func Parse(r io.Reader) (c *config, err error) { case "log_file": c.LogFile = val default: - return nil, errors.New("invliad config key " + key) + return nil, errors.New("invalid config key " + key) } } diff --git a/kv/kv.go b/kv/kv.go deleted file mode 100644 index 0f51e07..0000000 --- a/kv/kv.go +++ /dev/null @@ -1,91 +0,0 @@ -package kv - -import ( - "errors" - "io/ioutil" - "os" - "path/filepath" - "strings" - "sync" -) - -var ( - errInvalidKey = errors.New("invalid key") - errNoSuchKey = errors.New("no such key") -) - -type Database struct { - cache map[string][]byte - basedir string - m sync.RWMutex -} - -func NewDatabse(basedir string) (db *Database, err error) { - err = os.Mkdir(basedir, 0755) - if err != nil && !os.IsExist(err) { - return - } - - return &Database{ - cache: make(map[string][]byte), - basedir: basedir, - }, nil -} - -func (db *Database) Set(key string, val []byte) (err error) { - if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) { - return errInvalidKey - } - - err = ioutil.WriteFile(filepath.Join(db.basedir, key), val, 0644) - if err != nil { - return - } - - db.m.Lock() - db.cache[key] = val - db.m.Unlock() - - return -} - -func (db *Database) Get(key string) (val []byte, err error) { - if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) { - return nil, errInvalidKey - } - - db.m.RLock() - data, ok := db.cache[key] - db.m.RUnlock() - - if !ok { - data, err = ioutil.ReadFile(filepath.Join(db.basedir, key)) - if err != nil { - err = errNoSuchKey - return nil, err - } - - db.m.Lock() - db.cache[key] = data - db.m.Unlock() - } - - val = make([]byte, len(data)) - copy(val, data) - - return -} - -func (db *Database) Remove(key string) { - if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) { - return - } - - os.Remove(filepath.Join(db.basedir, key)) - - db.m.Lock() - delete(db.cache, key) - db.m.Unlock() - - return -} diff --git a/main.go b/main.go index 80baa81..636c59c 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,6 @@ import ( "time" "bloat/config" - "bloat/kv" "bloat/renderer" "bloat/repo" "bloat/service" @@ -76,13 +75,13 @@ func main() { } sessionDBPath := filepath.Join(config.DatabasePath, "session") - sessionDB, err := kv.NewDatabse(sessionDBPath) + sessionDB, err := util.NewDatabse(sessionDBPath) if err != nil { errExit(err) } appDBPath := filepath.Join(config.DatabasePath, "app") - appDB, err := kv.NewDatabse(appDBPath) + appDB, err := util.NewDatabse(appDBPath) if err != nil { errExit(err) } @@ -114,9 +113,7 @@ func main() { s := service.NewService(config.ClientName, config.ClientScope, config.ClientWebsite, customCSS, config.PostFormats, renderer, sessionRepo, appRepo, config.SingleInstance) - s = service.NewAuthService(sessionRepo, appRepo, s) - s = service.NewLoggingService(logger, s) - handler := service.NewHandler(s, config.StaticDirectory) + handler := service.NewHandler(s, logger, config.StaticDirectory) logger.Println("listening on", config.ListenAddress) err = http.ListenAndServe(config.ListenAddress, handler) diff --git a/model/client.go b/model/client.go deleted file mode 100644 index 931ddaa..0000000 --- a/model/client.go +++ /dev/null @@ -1,19 +0,0 @@ -package model - -import ( - "io" - - "bloat/mastodon" -) - -type ClientCtx struct { - SessionID string - CSRFToken string -} - -type Client struct { - *mastodon.Client - Writer io.Writer - Ctx ClientCtx - Session Session -} diff --git a/model/post.go b/model/post.go index 831f74f..40118ed 100644 --- a/model/post.go +++ b/model/post.go @@ -10,7 +10,6 @@ type PostContext struct { DefaultFormat string ReplyContext *ReplyContext Formats []PostFormat - DarkMode bool } type ReplyContext struct { diff --git a/repo/appRepo.go b/repo/appRepo.go index 6338c4a..d97ac1f 100644 --- a/repo/appRepo.go +++ b/repo/appRepo.go @@ -3,15 +3,15 @@ package repo import ( "encoding/json" - "bloat/kv" + "bloat/util" "bloat/model" ) type appRepo struct { - db *kv.Database + db *util.Database } -func NewAppRepo(db *kv.Database) *appRepo { +func NewAppRepo(db *util.Database) *appRepo { return &appRepo{ db: db, } diff --git a/repo/sessionRepo.go b/repo/sessionRepo.go index 15e3d31..2097c3e 100644 --- a/repo/sessionRepo.go +++ b/repo/sessionRepo.go @@ -3,15 +3,15 @@ package repo import ( "encoding/json" - "bloat/kv" + "bloat/util" "bloat/model" ) type sessionRepo struct { - db *kv.Database + db *util.Database } -func NewSessionRepo(db *kv.Database) *sessionRepo { +func NewSessionRepo(db *util.Database) *sessionRepo { return &sessionRepo{ db: db, } diff --git a/service/auth.go b/service/auth.go deleted file mode 100644 index 7845675..0000000 --- a/service/auth.go +++ /dev/null @@ -1,480 +0,0 @@ -package service - -import ( - "errors" - "mime/multipart" - - "bloat/mastodon" - "bloat/model" -) - -var ( - errInvalidSession = errors.New("invalid session") - errInvalidAccessToken = errors.New("invalid access token") - errInvalidCSRFToken = errors.New("invalid csrf token") -) - -type as struct { - sessionRepo model.SessionRepo - appRepo model.AppRepo - Service -} - -func NewAuthService(sessionRepo model.SessionRepo, appRepo model.AppRepo, s Service) Service { - return &as{sessionRepo, appRepo, s} -} - -func (s *as) initClient(c *model.Client) (err error) { - if len(c.Ctx.SessionID) < 1 { - return errInvalidSession - } - session, err := s.sessionRepo.Get(c.Ctx.SessionID) - if err != nil { - return errInvalidSession - } - app, err := s.appRepo.Get(session.InstanceDomain) - if err != nil { - return - } - mc := mastodon.NewClient(&mastodon.Config{ - Server: app.InstanceURL, - ClientID: app.ClientID, - ClientSecret: app.ClientSecret, - AccessToken: session.AccessToken, - }) - c.Client = mc - c.Session = session - return nil -} - -func (s *as) authenticateClient(c *model.Client) (err error) { - err = s.initClient(c) - if err != nil { - return - } - if len(c.Session.AccessToken) < 1 { - return errInvalidAccessToken - } - return nil -} - -func checkCSRF(c *model.Client) (err error) { - if c.Ctx.CSRFToken != c.Session.CSRFToken { - return errInvalidCSRFToken - } - return nil -} - -func (s *as) ServeErrorPage(c *model.Client, err error) { - s.authenticateClient(c) - s.Service.ServeErrorPage(c, err) -} - -func (s *as) ServeSigninPage(c *model.Client) (err error) { - return s.Service.ServeSigninPage(c) -} - -func (s *as) ServeRootPage(c *model.Client) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeRootPage(c) -} - -func (s *as) ServeNavPage(c *model.Client) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeNavPage(c) -} - -func (s *as) ServeTimelinePage(c *model.Client, tType string, - maxID string, minID string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeTimelinePage(c, tType, maxID, minID) -} - -func (s *as) ServeThreadPage(c *model.Client, id string, reply bool) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeThreadPage(c, id, reply) -} - -func (s *as) ServeLikedByPage(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeLikedByPage(c, id) -} - -func (s *as) ServeRetweetedByPage(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeRetweetedByPage(c, id) -} - -func (s *as) ServeNotificationPage(c *model.Client, - maxID string, minID string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeNotificationPage(c, maxID, minID) -} - -func (s *as) ServeUserPage(c *model.Client, id string, - pageType string, maxID string, minID string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeUserPage(c, id, pageType, maxID, minID) -} - -func (s *as) ServeAboutPage(c *model.Client) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeAboutPage(c) -} - -func (s *as) ServeEmojiPage(c *model.Client) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeEmojiPage(c) -} - -func (s *as) ServeSearchPage(c *model.Client, q string, - qType string, offset int) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeSearchPage(c, q, qType, offset) -} - -func (s *as) ServeUserSearchPage(c *model.Client, - id string, q string, offset int) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeUserSearchPage(c, id, q, offset) -} - -func (s *as) ServeSettingsPage(c *model.Client) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - return s.Service.ServeSettingsPage(c) -} - -func (s *as) NewSession(instance string) (redirectUrl string, - sessionID string, err error) { - return s.Service.NewSession(instance) -} - -func (s *as) Signin(c *model.Client, sessionID string, - code string) (token string, userID string, err error) { - err = s.authenticateClient(c) - if err != nil && err != errInvalidAccessToken { - return - } - - token, userID, err = s.Service.Signin(c, c.Session.ID, code) - if err != nil { - return - } - - c.Session.AccessToken = token - c.Session.UserID = userID - - err = s.sessionRepo.Add(c.Session) - if err != nil { - return - } - - return -} - -func (s *as) Signout(c *model.Client) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - s.Service.Signout(c) - return -} - -func (s *as) Post(c *model.Client, content string, - replyToID string, format string, visibility string, isNSFW bool, - files []*multipart.FileHeader) (id string, err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Post(c, content, replyToID, format, visibility, isNSFW, files) -} - -func (s *as) Like(c *model.Client, id string) (count int64, err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Like(c, id) -} - -func (s *as) UnLike(c *model.Client, id string) (count int64, err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.UnLike(c, id) -} - -func (s *as) Retweet(c *model.Client, id string) (count int64, err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Retweet(c, id) -} - -func (s *as) UnRetweet(c *model.Client, id string) (count int64, err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.UnRetweet(c, id) -} - -func (s *as) Vote(c *model.Client, id string, - choices []string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Vote(c, id, choices) -} - -func (s *as) Follow(c *model.Client, id string, reblogs *bool) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Follow(c, id, reblogs) -} - -func (s *as) UnFollow(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.UnFollow(c, id) -} - -func (s *as) Mute(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Mute(c, id) -} - -func (s *as) UnMute(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.UnMute(c, id) -} - -func (s *as) Block(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Block(c, id) -} - -func (s *as) UnBlock(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.UnBlock(c, id) -} - -func (s *as) Subscribe(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Subscribe(c, id) -} - -func (s *as) UnSubscribe(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.UnSubscribe(c, id) -} - -func (s *as) SaveSettings(c *model.Client, settings *model.Settings) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.SaveSettings(c, settings) -} - -func (s *as) MuteConversation(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.MuteConversation(c, id) -} - -func (s *as) UnMuteConversation(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.UnMuteConversation(c, id) -} - -func (s *as) Delete(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Delete(c, id) -} - -func (s *as) ReadNotifications(c *model.Client, maxID string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.ReadNotifications(c, maxID) -} - -func (s *as) Bookmark(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.Bookmark(c, id) -} - -func (s *as) UnBookmark(c *model.Client, id string) (err error) { - err = s.authenticateClient(c) - if err != nil { - return - } - err = checkCSRF(c) - if err != nil { - return - } - return s.Service.UnBookmark(c, id) -} diff --git a/service/logging.go b/service/logging.go deleted file mode 100644 index 3cb99bf..0000000 --- a/service/logging.go +++ /dev/null @@ -1,341 +0,0 @@ -package service - -import ( - "log" - "mime/multipart" - "time" - - "bloat/model" -) - -type ls struct { - logger *log.Logger - Service -} - -func NewLoggingService(logger *log.Logger, s Service) Service { - return &ls{logger, s} -} - -func (s *ls) ServeErrorPage(c *model.Client, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, err=%v, took=%v\n", - "ServeErrorPage", err, time.Since(begin)) - }(time.Now()) - s.Service.ServeErrorPage(c, err) -} - -func (s *ls) ServeSigninPage(c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeSigninPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeSigninPage(c) -} - -func (s *ls) ServeRootPage(c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeRootPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeRootPage(c) -} - -func (s *ls) ServeNavPage(c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeNavPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeNavPage(c) -} - -func (s *ls) ServeTimelinePage(c *model.Client, tType string, - maxID string, minID string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, type=%v, took=%v, err=%v\n", - "ServeTimelinePage", tType, time.Since(begin), err) - }(time.Now()) - return s.Service.ServeTimelinePage(c, tType, maxID, minID) -} - -func (s *ls) ServeThreadPage(c *model.Client, id string, - reply bool) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "ServeThreadPage", id, time.Since(begin), err) - }(time.Now()) - return s.Service.ServeThreadPage(c, id, reply) -} - -func (s *ls) ServeLikedByPage(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "ServeLikedByPage", id, time.Since(begin), err) - }(time.Now()) - return s.Service.ServeLikedByPage(c, id) -} - -func (s *ls) ServeRetweetedByPage(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "ServeRetweetedByPage", id, time.Since(begin), err) - }(time.Now()) - return s.Service.ServeRetweetedByPage(c, id) -} - -func (s *ls) ServeNotificationPage(c *model.Client, - maxID string, minID string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeNotificationPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeNotificationPage(c, maxID, minID) -} - -func (s *ls) ServeUserPage(c *model.Client, id string, - pageType string, maxID string, minID string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, type=%v, took=%v, err=%v\n", - "ServeUserPage", id, pageType, time.Since(begin), err) - }(time.Now()) - return s.Service.ServeUserPage(c, id, pageType, maxID, minID) -} - -func (s *ls) ServeAboutPage(c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeAboutPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeAboutPage(c) -} - -func (s *ls) ServeEmojiPage(c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeEmojiPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeEmojiPage(c) -} - -func (s *ls) ServeSearchPage(c *model.Client, q string, - qType string, offset int) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeSearchPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeSearchPage(c, q, qType, offset) -} - -func (s *ls) ServeUserSearchPage(c *model.Client, - id string, q string, offset int) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeUserSearchPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeUserSearchPage(c, id, q, offset) -} - -func (s *ls) ServeSettingsPage(c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "ServeSettingsPage", time.Since(begin), err) - }(time.Now()) - return s.Service.ServeSettingsPage(c) -} - -func (s *ls) NewSession(instance string) (redirectUrl string, - sessionID string, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, instance=%v, took=%v, err=%v\n", - "NewSession", instance, time.Since(begin), err) - }(time.Now()) - return s.Service.NewSession(instance) -} - -func (s *ls) Signin(c *model.Client, sessionID string, - code string) (token string, userID string, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, session_id=%v, took=%v, err=%v\n", - "Signin", sessionID, time.Since(begin), err) - }(time.Now()) - return s.Service.Signin(c, sessionID, code) -} - -func (s *ls) Signout(c *model.Client) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "Signout", time.Since(begin), err) - }(time.Now()) - return s.Service.Signout(c) -} - -func (s *ls) Post(c *model.Client, content string, - replyToID string, format string, visibility string, isNSFW bool, - files []*multipart.FileHeader) (id string, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "Post", time.Since(begin), err) - }(time.Now()) - return s.Service.Post(c, content, replyToID, format, - visibility, isNSFW, files) -} - -func (s *ls) Like(c *model.Client, id string) (count int64, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Like", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Like(c, id) -} - -func (s *ls) UnLike(c *model.Client, id string) (count int64, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnLike", id, time.Since(begin), err) - }(time.Now()) - return s.Service.UnLike(c, id) -} - -func (s *ls) Retweet(c *model.Client, id string) (count int64, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Retweet", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Retweet(c, id) -} - -func (s *ls) UnRetweet(c *model.Client, id string) (count int64, err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnRetweet", id, time.Since(begin), err) - }(time.Now()) - return s.Service.UnRetweet(c, id) -} - -func (s *ls) Vote(c *model.Client, id string, choices []string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Vote", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Vote(c, id, choices) -} - -func (s *ls) Follow(c *model.Client, id string, reblogs *bool) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Follow", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Follow(c, id, reblogs) -} - -func (s *ls) UnFollow(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnFollow", id, time.Since(begin), err) - }(time.Now()) - return s.Service.UnFollow(c, id) -} - -func (s *ls) Mute(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Mute", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Mute(c, id) -} - -func (s *ls) UnMute(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnMute", id, time.Since(begin), err) - }(time.Now()) - return s.Service.UnMute(c, id) -} - -func (s *ls) Block(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Block", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Block(c, id) -} - -func (s *ls) UnBlock(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnBlock", id, time.Since(begin), err) - }(time.Now()) - return s.Service.UnBlock(c, id) -} - -func (s *ls) Subscribe(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Subscribe", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Subscribe(c, id) -} - -func (s *ls) UnSubscribe(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnSubscribe", id, time.Since(begin), err) - }(time.Now()) - return s.Service.UnSubscribe(c, id) -} - -func (s *ls) SaveSettings(c *model.Client, settings *model.Settings) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, took=%v, err=%v\n", - "SaveSettings", time.Since(begin), err) - }(time.Now()) - return s.Service.SaveSettings(c, settings) -} - -func (s *ls) MuteConversation(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "MuteConversation", id, time.Since(begin), err) - }(time.Now()) - return s.Service.MuteConversation(c, id) -} - -func (s *ls) UnMuteConversation(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnMuteConversation", id, time.Since(begin), err) - }(time.Now()) - return s.Service.UnMuteConversation(c, id) -} - -func (s *ls) Delete(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Delete", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Delete(c, id) -} - -func (s *ls) ReadNotifications(c *model.Client, maxID string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, max_id=%v, took=%v, err=%v\n", - "ReadNotifications", maxID, time.Since(begin), err) - }(time.Now()) - return s.Service.ReadNotifications(c, maxID) -} - -func (s *ls) Bookmark(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "Bookmark", id, time.Since(begin), err) - }(time.Now()) - return s.Service.Bookmark(c, id) -} - -func (s *ls) UnBookmark(c *model.Client, id string) (err error) { - defer func(begin time.Time) { - s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", - "UnBookmark", id, time.Since(begin), err) - }(time.Now()) - return s.Service.UnBookmark(c, id) -} diff --git a/service/service.go b/service/service.go index 1ba99da..db44e10 100644 --- a/service/service.go +++ b/service/service.go @@ -20,53 +20,6 @@ var ( errInvalidArgument = errors.New("invalid argument") ) -type Service interface { - ServeErrorPage(c *model.Client, err error) - ServeSigninPage(c *model.Client) (err error) - ServeRootPage(c *model.Client) (err error) - ServeNavPage(c *model.Client) (err error) - ServeTimelinePage(c *model.Client, tType string, maxID string, - minID string) (err error) - ServeThreadPage(c *model.Client, id string, reply bool) (err error) - ServeLikedByPage(c *model.Client, id string) (err error) - ServeRetweetedByPage(c *model.Client, id string) (err error) - ServeNotificationPage(c *model.Client, maxID string, minID string) (err error) - ServeUserPage(c *model.Client, id string, pageType string, maxID string, - minID string) (err error) - ServeAboutPage(c *model.Client) (err error) - ServeEmojiPage(c *model.Client) (err error) - ServeSearchPage(c *model.Client, q string, qType string, offset int) (err error) - ServeUserSearchPage(c *model.Client, id string, q string, offset int) (err error) - ServeSettingsPage(c *model.Client) (err error) - SingleInstance() (instance string, ok bool) - NewSession(instance string) (redirectUrl string, sessionID string, err error) - Signin(c *model.Client, sessionID string, code string) (token string, - userID string, err error) - Signout(c *model.Client) (err error) - Post(c *model.Client, content string, replyToID string, format string, visibility string, - isNSFW bool, files []*multipart.FileHeader) (id string, err error) - Like(c *model.Client, id string) (count int64, err error) - UnLike(c *model.Client, id string) (count int64, err error) - Retweet(c *model.Client, id string) (count int64, err error) - UnRetweet(c *model.Client, id string) (count int64, err error) - Vote(c *model.Client, id string, choices []string) (err error) - Follow(c *model.Client, id string, reblogs *bool) (err error) - UnFollow(c *model.Client, id string) (err error) - Mute(c *model.Client, id string) (err error) - UnMute(c *model.Client, id string) (err error) - Block(c *model.Client, id string) (err error) - UnBlock(c *model.Client, id string) (err error) - Subscribe(c *model.Client, id string) (err error) - UnSubscribe(c *model.Client, id string) (err error) - SaveSettings(c *model.Client, settings *model.Settings) (err error) - MuteConversation(c *model.Client, id string) (err error) - UnMuteConversation(c *model.Client, id string) (err error) - Delete(c *model.Client, id string) (err error) - ReadNotifications(c *model.Client, maxID string) (err error) - Bookmark(c *model.Client, id string) (err error) - UnBookmark(c *model.Client, id string) (err error) -} - type service struct { clientName string clientScope string @@ -88,7 +41,7 @@ func NewService(clientName string, sessionRepo model.SessionRepo, appRepo model.AppRepo, singleInstance string, -) Service { +) *service { return &service{ clientName: clientName, clientScope: clientScope, @@ -102,7 +55,7 @@ func NewService(clientName string, } } -func getRendererContext(c *model.Client) *renderer.Context { +func getRendererContext(c *client) *renderer.Context { var settings model.Settings var session model.Session if c != nil { @@ -128,26 +81,21 @@ func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, if key == nil { return } - keyStr, ok := key.(string) if !ok { return } - _, ok = m[keyStr] if !ok { m[keyStr] = []mastodon.ReplyInfo{} } - m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number}) } -func (svc *service) getCommonData(c *model.Client, - title string) (data *renderer.CommonData) { - +func (s *service) getCommonData(c *client, title string) (data *renderer.CommonData) { data = &renderer.CommonData{ - Title: title + " - " + svc.clientName, - CustomCSS: svc.customCSS, + Title: title + " - " + s.clientName, + CustomCSS: s.customCSS, } if c != nil && c.Session.IsLoggedIn() { data.CSRFToken = c.Session.CSRFToken @@ -155,66 +103,59 @@ func (svc *service) getCommonData(c *model.Client, return } -func (svc *service) ServeErrorPage(c *model.Client, err error) { +func (s *service) ErrorPage(c *client, err error) { var errStr string if err != nil { errStr = err.Error() } - - commonData := svc.getCommonData(nil, "error") + commonData := s.getCommonData(nil, "error") data := &renderer.ErrorData{ CommonData: commonData, Error: errStr, } - rCtx := getRendererContext(c) - svc.renderer.Render(rCtx, c.Writer, renderer.ErrorPage, data) + s.renderer.Render(rCtx, c, renderer.ErrorPage, data) } -func (svc *service) ServeSigninPage(c *model.Client) (err error) { - commonData := svc.getCommonData(nil, "signin") +func (s *service) SigninPage(c *client) (err error) { + commonData := s.getCommonData(nil, "signin") data := &renderer.SigninData{ CommonData: commonData, } - rCtx := getRendererContext(nil) - return svc.renderer.Render(rCtx, c.Writer, renderer.SigninPage, data) + return s.renderer.Render(rCtx, c, renderer.SigninPage, data) } -func (svc *service) ServeRootPage(c *model.Client) (err error) { +func (s *service) RootPage(c *client) (err error) { data := &renderer.RootData{ - Title: svc.clientName, + Title: s.clientName, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.RootPage, data) + return s.renderer.Render(rCtx, c, renderer.RootPage, data) } -func (svc *service) ServeNavPage(c *model.Client) (err error) { +func (s *service) NavPage(c *client) (err error) { u, err := c.GetAccountCurrentUser(ctx) if err != nil { return } - postContext := model.PostContext{ DefaultVisibility: c.Session.Settings.DefaultVisibility, DefaultFormat: c.Session.Settings.DefaultFormat, - Formats: svc.postFormats, + Formats: s.postFormats, } - - commonData := svc.getCommonData(c, "Nav") + commonData := s.getCommonData(c, "nav") commonData.Target = "main" data := &renderer.NavData{ User: u, CommonData: commonData, PostContext: postContext, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.NavPage, data) + return s.renderer.Render(rCtx, c, renderer.NavPage, data) } -func (svc *service) ServeTimelinePage(c *model.Client, tType string, +func (s *service) TimelinePage(c *client, tType string, maxID string, minID string) (err error) { var nextLink, prevLink, title string @@ -279,7 +220,7 @@ func (svc *service) ServeTimelinePage(c *model.Client, tType string, nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", tType, pg.MaxID) } - commonData := svc.getCommonData(c, tType+" timeline ") + commonData := s.getCommonData(c, tType+" timeline ") data := &renderer.TimelineData{ Title: title, Statuses: statuses, @@ -289,10 +230,10 @@ func (svc *service) ServeTimelinePage(c *model.Client, tType string, } rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.TimelinePage, data) + return s.renderer.Render(rCtx, c, renderer.TimelinePage, data) } -func (svc *service) ServeThreadPage(c *model.Client, id string, reply bool) (err error) { +func (s *service) ThreadPage(c *client, id string, reply bool) (err error) { var postContext model.PostContext status, err := c.GetStatus(ctx, id) @@ -323,14 +264,13 @@ func (svc *service) ServeThreadPage(c *model.Client, id string, reply bool) (err postContext = model.PostContext{ DefaultVisibility: visibility, DefaultFormat: c.Session.Settings.DefaultFormat, - Formats: svc.postFormats, + Formats: s.postFormats, ReplyContext: &model.ReplyContext{ InReplyToID: id, InReplyToName: status.Account.Acct, ReplyContent: content, ForceVisibility: isDirect, }, - DarkMode: c.Session.Settings.DarkMode, } } @@ -353,7 +293,7 @@ func (svc *service) ServeThreadPage(c *model.Client, id string, reply bool) (err addToReplyMap(replies, statuses[i].InReplyToID, statuses[i].ID, i+1) } - commonData := svc.getCommonData(c, "post by "+status.Account.DisplayName) + commonData := s.getCommonData(c, "post by "+status.Account.DisplayName) data := &renderer.ThreadData{ Statuses: statuses, PostContext: postContext, @@ -362,42 +302,38 @@ func (svc *service) ServeThreadPage(c *model.Client, id string, reply bool) (err } rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.ThreadPage, data) + return s.renderer.Render(rCtx, c, renderer.ThreadPage, data) } -func (svc *service) ServeLikedByPage(c *model.Client, id string) (err error) { +func (s *service) LikedByPage(c *client, id string) (err error) { likers, err := c.GetFavouritedBy(ctx, id, nil) if err != nil { return } - - commonData := svc.getCommonData(c, "likes") + commonData := s.getCommonData(c, "likes") data := &renderer.LikedByData{ CommonData: commonData, Users: likers, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.LikedByPage, data) + return s.renderer.Render(rCtx, c, renderer.LikedByPage, data) } -func (svc *service) ServeRetweetedByPage(c *model.Client, id string) (err error) { +func (s *service) RetweetedByPage(c *client, id string) (err error) { retweeters, err := c.GetRebloggedBy(ctx, id, nil) if err != nil { return } - - commonData := svc.getCommonData(c, "retweets") + commonData := s.getCommonData(c, "retweets") data := &renderer.RetweetedByData{ CommonData: commonData, Users: retweeters, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.RetweetedByPage, data) + return s.renderer.Render(rCtx, c, renderer.RetweetedByPage, data) } -func (svc *service) ServeNotificationPage(c *model.Client, maxID string, +func (s *service) NotificationPage(c *client, maxID string, minID string) (err error) { var nextLink string @@ -428,12 +364,11 @@ func (svc *service) ServeNotificationPage(c *model.Client, maxID string, if unreadCount > 0 { readID = notifications[0].ID } - if len(notifications) == 20 && len(pg.MaxID) > 0 { nextLink = "/notifications?max_id=" + pg.MaxID } - commonData := svc.getCommonData(c, "notifications") + commonData := s.getCommonData(c, "notifications") commonData.RefreshInterval = c.Session.Settings.NotificationInterval commonData.Target = "main" commonData.Count = unreadCount @@ -445,10 +380,10 @@ func (svc *service) ServeNotificationPage(c *model.Client, maxID string, CommonData: commonData, } rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.NotificationPage, data) + return s.renderer.Render(rCtx, c, renderer.NotificationPage, data) } -func (svc *service) ServeUserPage(c *model.Client, id string, pageType string, +func (s *service) UserPage(c *client, id string, pageType string, maxID string, minID string) (err error) { var nextLink string @@ -561,7 +496,7 @@ func (svc *service) ServeUserPage(c *model.Client, id string, pageType string, } } - commonData := svc.getCommonData(c, user.DisplayName) + commonData := s.getCommonData(c, user.DisplayName) data := &renderer.UserData{ User: user, IsCurrent: isCurrent, @@ -572,10 +507,10 @@ func (svc *service) ServeUserPage(c *model.Client, id string, pageType string, CommonData: commonData, } rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.UserPage, data) + return s.renderer.Render(rCtx, c, renderer.UserPage, data) } -func (svc *service) ServeUserSearchPage(c *model.Client, +func (s *service) UserSearchPage(c *client, id string, q string, offset int) (err error) { var nextLink string @@ -598,7 +533,8 @@ func (svc *service) ServeUserSearchPage(c *model.Client, if len(results.Statuses) == 20 { offset += 20 - nextLink = fmt.Sprintf("/usersearch/%s?q=%s&offset=%d", id, url.QueryEscape(q), offset) + nextLink = fmt.Sprintf("/usersearch/%s?q=%s&offset=%d", id, + url.QueryEscape(q), offset) } qq := template.HTMLEscapeString(q) @@ -606,7 +542,7 @@ func (svc *service) ServeUserSearchPage(c *model.Client, title += " \"" + qq + "\"" } - commonData := svc.getCommonData(c, title) + commonData := s.getCommonData(c, title) data := &renderer.UserSearchData{ CommonData: commonData, User: user, @@ -614,38 +550,34 @@ func (svc *service) ServeUserSearchPage(c *model.Client, Statuses: results.Statuses, NextLink: nextLink, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.UserSearchPage, data) + return s.renderer.Render(rCtx, c, renderer.UserSearchPage, data) } -func (svc *service) ServeAboutPage(c *model.Client) (err error) { - commonData := svc.getCommonData(c, "about") +func (s *service) AboutPage(c *client) (err error) { + commonData := s.getCommonData(c, "about") data := &renderer.AboutData{ CommonData: commonData, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.AboutPage, data) + return s.renderer.Render(rCtx, c, renderer.AboutPage, data) } -func (svc *service) ServeEmojiPage(c *model.Client) (err error) { +func (s *service) EmojiPage(c *client) (err error) { emojis, err := c.GetInstanceEmojis(ctx) if err != nil { return } - - commonData := svc.getCommonData(c, "emojis") + commonData := s.getCommonData(c, "emojis") data := &renderer.EmojiData{ Emojis: emojis, CommonData: commonData, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.EmojiPage, data) + return s.renderer.Render(rCtx, c, renderer.EmojiPage, data) } -func (svc *service) ServeSearchPage(c *model.Client, +func (s *service) SearchPage(c *client, q string, qType string, offset int) (err error) { var nextLink string @@ -664,7 +596,8 @@ func (svc *service) ServeSearchPage(c *model.Client, if (qType == "accounts" && len(results.Accounts) == 20) || (qType == "statuses" && len(results.Statuses) == 20) { offset += 20 - nextLink = fmt.Sprintf("/search?q=%s&type=%s&offset=%d", url.QueryEscape(q), qType, offset) + nextLink = fmt.Sprintf("/search?q=%s&type=%s&offset=%d", + url.QueryEscape(q), qType, offset) } qq := template.HTMLEscapeString(q) @@ -672,7 +605,7 @@ func (svc *service) ServeSearchPage(c *model.Client, title += " \"" + qq + "\"" } - commonData := svc.getCommonData(c, title) + commonData := s.getCommonData(c, title) data := &renderer.SearchData{ CommonData: commonData, Q: qq, @@ -681,34 +614,30 @@ func (svc *service) ServeSearchPage(c *model.Client, Statuses: results.Statuses, NextLink: nextLink, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.SearchPage, data) + return s.renderer.Render(rCtx, c, renderer.SearchPage, data) } -func (svc *service) ServeSettingsPage(c *model.Client) (err error) { - commonData := svc.getCommonData(c, "settings") +func (s *service) SettingsPage(c *client) (err error) { + commonData := s.getCommonData(c, "settings") data := &renderer.SettingsData{ CommonData: commonData, Settings: &c.Session.Settings, - PostFormats: svc.postFormats, + PostFormats: s.postFormats, } - rCtx := getRendererContext(c) - return svc.renderer.Render(rCtx, c.Writer, renderer.SettingsPage, data) + return s.renderer.Render(rCtx, c, renderer.SettingsPage, data) } -func (svc *service) SingleInstance() (instance string, ok bool) { - if len(svc.singleInstance) > 0 { - instance = svc.singleInstance +func (s *service) SingleInstance() (instance string, ok bool) { + if len(s.singleInstance) > 0 { + instance = s.singleInstance ok = true } return } -func (svc *service) NewSession(instance string) ( - redirectUrl string, sessionID string, err error) { - +func (s *service) NewSession(instance string) (rurl string, sid string, err error) { var instanceURL string if strings.HasPrefix(instance, "https://") { instanceURL = instance @@ -717,53 +646,48 @@ func (svc *service) NewSession(instance string) ( instanceURL = "https://" + instance } - sessionID, err = util.NewSessionID() + sid, err = util.NewSessionID() if err != nil { return } - csrfToken, err := util.NewCSRFToken() if err != nil { return } session := model.Session{ - ID: sessionID, + ID: sid, InstanceDomain: instance, CSRFToken: csrfToken, Settings: *model.NewSettings(), } - - err = svc.sessionRepo.Add(session) + err = s.sessionRepo.Add(session) if err != nil { return } - app, err := svc.appRepo.Get(instance) + app, err := s.appRepo.Get(instance) if err != nil { if err != model.ErrAppNotFound { return } - mastoApp, err := mastodon.RegisterApp(ctx, &mastodon.AppConfig{ Server: instanceURL, - ClientName: svc.clientName, - Scopes: svc.clientScope, - Website: svc.clientWebsite, - RedirectURIs: svc.clientWebsite + "/oauth_callback", + ClientName: s.clientName, + Scopes: s.clientScope, + Website: s.clientWebsite, + RedirectURIs: s.clientWebsite + "/oauth_callback", }) if err != nil { return "", "", err } - app = model.App{ InstanceDomain: instance, InstanceURL: instanceURL, ClientID: mastoApp.ClientID, ClientSecret: mastoApp.ClientSecret, } - - err = svc.appRepo.Add(app) + err = s.appRepo.Add(app) if err != nil { return "", "", err } @@ -778,23 +702,21 @@ func (svc *service) NewSession(instance string) ( q.Set("scope", "read write follow") q.Set("client_id", app.ClientID) q.Set("response_type", "code") - q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback") + q.Set("redirect_uri", s.clientWebsite+"/oauth_callback") u.RawQuery = q.Encode() - redirectUrl = instanceURL + u.String() - + rurl = instanceURL + u.String() return } -func (svc *service) Signin(c *model.Client, sessionID string, - code string) (token string, userID string, err error) { +func (s *service) Signin(c *client, code string) (token string, + userID string, err error) { if len(code) < 1 { err = errInvalidArgument return } - - err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback") + err = c.AuthenticateToken(ctx, code, s.clientWebsite+"/oauth_callback") if err != nil { return } @@ -805,17 +727,16 @@ func (svc *service) Signin(c *model.Client, sessionID string, return } userID = u.ID - return } -func (svc *service) Signout(c *model.Client) (err error) { - svc.sessionRepo.Remove(c.Session.ID) +func (s *service) Signout(c *client) (err error) { + s.sessionRepo.Remove(c.Session.ID) return } -func (svc *service) Post(c *model.Client, content string, - replyToID string, format string, visibility string, isNSFW bool, +func (s *service) Post(c *client, content string, replyToID string, + format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { var mediaIDs []string @@ -835,141 +756,135 @@ func (svc *service) Post(c *model.Client, content string, Visibility: visibility, Sensitive: isNSFW, } - - s, err := c.PostStatus(ctx, tweet) + st, err := c.PostStatus(ctx, tweet) if err != nil { return } - - return s.ID, nil + return st.ID, nil } -func (svc *service) Like(c *model.Client, id string) (count int64, err error) { - s, err := c.Favourite(ctx, id) +func (s *service) Like(c *client, id string) (count int64, err error) { + st, err := c.Favourite(ctx, id) if err != nil { return } - count = s.FavouritesCount + count = st.FavouritesCount return } -func (svc *service) UnLike(c *model.Client, id string) (count int64, err error) { - s, err := c.Unfavourite(ctx, id) +func (s *service) UnLike(c *client, id string) (count int64, err error) { + st, err := c.Unfavourite(ctx, id) if err != nil { return } - count = s.FavouritesCount + count = st.FavouritesCount return } -func (svc *service) Retweet(c *model.Client, id string) (count int64, err error) { - s, err := c.Reblog(ctx, id) +func (s *service) Retweet(c *client, id string) (count int64, err error) { + st, err := c.Reblog(ctx, id) if err != nil { return } - if s.Reblog != nil { - count = s.Reblog.ReblogsCount + if st.Reblog != nil { + count = st.Reblog.ReblogsCount } return } -func (svc *service) UnRetweet(c *model.Client, id string) ( +func (s *service) UnRetweet(c *client, id string) ( count int64, err error) { - s, err := c.Unreblog(ctx, id) + st, err := c.Unreblog(ctx, id) if err != nil { return } - count = s.ReblogsCount + count = st.ReblogsCount return } -func (svc *service) Vote(c *model.Client, id string, choices []string) (err error) { +func (s *service) Vote(c *client, id string, choices []string) (err error) { _, err = c.Vote(ctx, id, choices) - if err != nil { - return - } return } -func (svc *service) Follow(c *model.Client, id string, reblogs *bool) (err error) { +func (s *service) Follow(c *client, id string, reblogs *bool) (err error) { _, err = c.AccountFollow(ctx, id, reblogs) return } -func (svc *service) UnFollow(c *model.Client, id string) (err error) { +func (s *service) UnFollow(c *client, id string) (err error) { _, err = c.AccountUnfollow(ctx, id) return } -func (svc *service) Mute(c *model.Client, id string) (err error) { +func (s *service) Mute(c *client, id string) (err error) { _, err = c.AccountMute(ctx, id) return } -func (svc *service) UnMute(c *model.Client, id string) (err error) { +func (s *service) UnMute(c *client, id string) (err error) { _, err = c.AccountUnmute(ctx, id) return } -func (svc *service) Block(c *model.Client, id string) (err error) { +func (s *service) Block(c *client, id string) (err error) { _, err = c.AccountBlock(ctx, id) return } -func (svc *service) UnBlock(c *model.Client, id string) (err error) { +func (s *service) UnBlock(c *client, id string) (err error) { _, err = c.AccountUnblock(ctx, id) return } -func (svc *service) Subscribe(c *model.Client, id string) (err error) { +func (s *service) Subscribe(c *client, id string) (err error) { _, err = c.Subscribe(ctx, id) return } -func (svc *service) UnSubscribe(c *model.Client, id string) (err error) { +func (s *service) UnSubscribe(c *client, id string) (err error) { _, err = c.UnSubscribe(ctx, id) return } -func (svc *service) SaveSettings(c *model.Client, s *model.Settings) (err error) { - switch s.NotificationInterval { +func (s *service) SaveSettings(c *client, settings *model.Settings) (err error) { + switch settings.NotificationInterval { case 0, 30, 60, 120, 300, 600: default: return errInvalidArgument } - session, err := svc.sessionRepo.Get(c.Session.ID) + session, err := s.sessionRepo.Get(c.Session.ID) if err != nil { return } - - session.Settings = *s - return svc.sessionRepo.Add(session) + session.Settings = *settings + return s.sessionRepo.Add(session) } -func (svc *service) MuteConversation(c *model.Client, id string) (err error) { +func (s *service) MuteConversation(c *client, id string) (err error) { _, err = c.MuteConversation(ctx, id) return } -func (svc *service) UnMuteConversation(c *model.Client, id string) (err error) { +func (s *service) UnMuteConversation(c *client, id string) (err error) { _, err = c.UnmuteConversation(ctx, id) return } -func (svc *service) Delete(c *model.Client, id string) (err error) { +func (s *service) Delete(c *client, id string) (err error) { return c.DeleteStatus(ctx, id) } -func (svc *service) ReadNotifications(c *model.Client, maxID string) (err error) { +func (s *service) ReadNotifications(c *client, maxID string) (err error) { return c.ReadNotifications(ctx, maxID) } -func (svc *service) Bookmark(c *model.Client, id string) (err error) { +func (s *service) Bookmark(c *client, id string) (err error) { _, err = c.Bookmark(ctx, id) return } -func (svc *service) UnBookmark(c *model.Client, id string) (err error) { +func (s *service) UnBookmark(c *client, id string) (err error) { _, err = c.Unbookmark(ctx, id) return } diff --git a/service/transport.go b/service/transport.go index 3c9392a..80ad7f1 100644 --- a/service/transport.go +++ b/service/transport.go @@ -2,597 +2,478 @@ package service import ( "encoding/json" - "io" - "mime/multipart" + "errors" + "log" "net/http" "strconv" "time" + "bloat/mastodon" "bloat/model" "github.com/gorilla/mux" ) +var ( + errInvalidSession = errors.New("invalid session") + errInvalidCSRFToken = errors.New("invalid csrf token") +) + const ( sessionExp = 365 * 24 * time.Hour ) -func newClient(w io.Writer, req *http.Request, csrfToken string) *model.Client { - var sessionID string - if req != nil { - c, err := req.Cookie("session_id") - if err == nil { - sessionID = c.Value - } - } - return &model.Client{ - Writer: w, - Ctx: model.ClientCtx{ - SessionID: sessionID, - CSRFToken: csrfToken, - }, - } +type respType int + +const ( + HTML respType = iota + JSON +) + +type authType int + +const ( + NOAUTH authType = iota + SESSION + CSRF +) + +type client struct { + *mastodon.Client + http.ResponseWriter + Req *http.Request + CSRFToken string + Session model.Session } -func setSessionCookie(w http.ResponseWriter, sessionID string, exp time.Duration) { +func setSessionCookie(w http.ResponseWriter, sid string, exp time.Duration) { http.SetCookie(w, &http.Cookie{ Name: "session_id", - Value: sessionID, + Value: sid, Expires: time.Now().Add(exp), }) } -func getMultipartFormValue(mf *multipart.Form, key string) (val string) { - vals, ok := mf.Value[key] - if !ok { - return "" - } - if len(vals) < 1 { - return "" - } - return vals[0] -} - -func serveJson(w io.Writer, data interface{}) (err error) { - var d = make(map[string]interface{}) - d["data"] = data - return json.NewEncoder(w).Encode(d) +func writeJson(c *client, data interface{}) error { + return json.NewEncoder(c).Encode(map[string]interface{}{ + "data": data, + }) } -func serveJsonError(w http.ResponseWriter, err error) { - var d = make(map[string]interface{}) - d["error"] = err.Error() - w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(d) - return +func redirect(c *client, url string) { + c.Header().Add("Location", url) + c.WriteHeader(http.StatusFound) } -func NewHandler(s Service, staticDir string) http.Handler { +func NewHandler(s *service, logger *log.Logger, staticDir string) http.Handler { r := mux.NewRouter() - rootPage := func(w http.ResponseWriter, req *http.Request) { - sessionID, _ := req.Cookie("session_id") - if sessionID != nil && len(sessionID.Value) > 0 { - c := newClient(w, req, "") - err := s.ServeRootPage(c) - if err != nil { - if err == errInvalidAccessToken { - w.Header().Add("Location", "/signin") - w.WriteHeader(http.StatusFound) - return - } - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } else { - w.Header().Add("Location", "/signin") - w.WriteHeader(http.StatusFound) - } - } - - navPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - err := s.ServeNavPage(c) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + writeError := func(c *client, err error, t respType) { + switch t { + case HTML: + c.WriteHeader(http.StatusInternalServerError) + s.ErrorPage(c, err) + case JSON: + c.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(c).Encode(map[string]string{ + "error": err.Error(), + }) } } - signinPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, nil, "") - instance, ok := s.SingleInstance() - if ok { - url, sessionID, err := s.NewSession(instance) + authenticate := func(c *client, t authType) error { + if t >= SESSION { + cookie, err := c.Req.Cookie("session_id") + if err != nil || len(cookie.Value) < 1 { + return errInvalidSession + } + c.Session, err = s.sessionRepo.Get(cookie.Value) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return errInvalidSession } - - setSessionCookie(w, sessionID, sessionExp) - w.Header().Add("Location", url) - w.WriteHeader(http.StatusFound) - } else { - err := s.ServeSigninPage(c) + app, err := s.appRepo.Get(c.Session.InstanceDomain) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err + } + c.Client = mastodon.NewClient(&mastodon.Config{ + Server: app.InstanceURL, + ClientID: app.ClientID, + ClientSecret: app.ClientSecret, + AccessToken: c.Session.AccessToken, + }) + } + if t >= CSRF { + c.CSRFToken = c.Req.FormValue("csrf_token") + if len(c.CSRFToken) < 1 || c.CSRFToken != c.Session.CSRFToken { + return errInvalidCSRFToken } } + return nil } - timelinePage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - tType, _ := mux.Vars(req)["type"] - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeTimelinePage(c, tType, maxID, minID) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - defaultTimelinePage := func(w http.ResponseWriter, req *http.Request) { - w.Header().Add("Location", "/timeline/home") - w.WriteHeader(http.StatusFound) - } - - threadPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - reply := req.URL.Query().Get("reply") - - err := s.ServeThreadPage(c, id, len(reply) > 1) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - likedByPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - - err := s.ServeLikedByPage(c, id) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - retweetedByPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - - err := s.ServeRetweetedByPage(c, id) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - notificationsPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") - - err := s.ServeNotificationPage(c, maxID, minID) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } + handle := func(f func(c *client) error, at authType, rt respType) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + var err error + c := &client{Req: req, ResponseWriter: w} - userPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - pageType, _ := mux.Vars(req)["type"] - maxID := req.URL.Query().Get("max_id") - minID := req.URL.Query().Get("min_id") + defer func(begin time.Time) { + logger.Printf("path=%s, err=%v, took=%v\n", + req.URL.Path, err, time.Since(begin)) + }(time.Now()) - err := s.ServeUserPage(c, id, pageType, maxID, minID) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - userSearchPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - id, _ := mux.Vars(req)["id"] - q := req.URL.Query().Get("q") - offsetStr := req.URL.Query().Get("offset") + var ct string + switch rt { + case HTML: + ct = "text/html; charset=utf-8" + case JSON: + ct = "application/json" + } + c.Header().Add("Content-Type", ct) - var offset int - var err error - if len(offsetStr) > 1 { - offset, err = strconv.Atoi(offsetStr) + err = authenticate(c, at) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) + writeError(c, err, rt) return } - } - err = s.ServeUserSearchPage(c, id, q, offset) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - aboutPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - - err := s.ServeAboutPage(c) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - emojisPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - - err := s.ServeEmojiPage(c) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } - } - - searchPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - q := req.URL.Query().Get("q") - qType := req.URL.Query().Get("type") - offsetStr := req.URL.Query().Get("offset") - - var offset int - var err error - if len(offsetStr) > 1 { - offset, err = strconv.Atoi(offsetStr) + err = f(c) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) + writeError(c, err, rt) return } } - - err = s.ServeSearchPage(c, q, qType, offset) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } } - settingsPage := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - - err := s.ServeSettingsPage(c) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + rootPage := handle(func(c *client) error { + sid, _ := c.Req.Cookie("session_id") + if sid == nil || len(sid.Value) < 0 { + redirect(c, "/signin") + return nil } - } - - signin := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, nil, "") - instance := req.FormValue("instance") - - url, sessionID, err := s.NewSession(instance) + session, err := s.sessionRepo.Get(sid.Value) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + if err == errInvalidSession { + redirect(c, "/signin") + return nil + } + return err } - - setSessionCookie(w, sessionID, sessionExp) - w.Header().Add("Location", url) - w.WriteHeader(http.StatusFound) - } - - oauthCallback := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - token := req.URL.Query().Get("code") - - _, _, err := s.Signin(c, "", token) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + if len(session.AccessToken) < 1 { + redirect(c, "/signin") + return nil } + return s.RootPage(c) + }, NOAUTH, HTML) - w.Header().Add("Location", "/") - w.WriteHeader(http.StatusFound) - } - - post := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, "") - err := req.ParseMultipartForm(4 << 20) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return - } + navPage := handle(func(c *client) error { + return s.NavPage(c) + }, SESSION, HTML) - c = newClient(w, req, - getMultipartFormValue(req.MultipartForm, "csrf_token")) - content := getMultipartFormValue(req.MultipartForm, "content") - replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id") - format := getMultipartFormValue(req.MultipartForm, "format") - visibility := getMultipartFormValue(req.MultipartForm, "visibility") - isNSFW := "on" == getMultipartFormValue(req.MultipartForm, "is_nsfw") - files := req.MultipartForm.File["attachments"] + signinPage := handle(func(c *client) error { + instance, ok := s.SingleInstance() + if !ok { + return s.SigninPage(c) + } + url, sid, err := s.NewSession(instance) + if err != nil { + return err + } + setSessionCookie(c, sid, sessionExp) + redirect(c, url) + return nil + }, NOAUTH, HTML) + + timelinePage := handle(func(c *client) error { + tType, _ := mux.Vars(c.Req)["type"] + q := c.Req.URL.Query() + maxID := q.Get("max_id") + minID := q.Get("min_id") + return s.TimelinePage(c, tType, maxID, minID) + return nil + }, SESSION, HTML) + + defaultTimelinePage := handle(func(c *client) error { + redirect(c, "/timeline/home") + return nil + }, SESSION, HTML) + + threadPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + q := c.Req.URL.Query() + reply := q.Get("reply") + return s.ThreadPage(c, id, len(reply) > 1) + }, SESSION, HTML) + + likedByPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + return s.LikedByPage(c, id) + }, SESSION, HTML) + + retweetedByPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + return s.RetweetedByPage(c, id) + }, SESSION, HTML) + + notificationsPage := handle(func(c *client) error { + q := c.Req.URL.Query() + maxID := q.Get("max_id") + minID := q.Get("min_id") + return s.NotificationPage(c, maxID, minID) + }, SESSION, HTML) + + userPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + pageType, _ := mux.Vars(c.Req)["type"] + q := c.Req.URL.Query() + maxID := q.Get("max_id") + minID := q.Get("min_id") + return s.UserPage(c, id, pageType, maxID, minID) + }, SESSION, HTML) + + userSearchPage := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + q := c.Req.URL.Query() + sq := q.Get("q") + offset, _ := strconv.Atoi(q.Get("offset")) + return s.UserSearchPage(c, id, sq, offset) + }, SESSION, HTML) + + aboutPage := handle(func(c *client) error { + return s.AboutPage(c) + }, SESSION, HTML) + + emojisPage := handle(func(c *client) error { + return s.EmojiPage(c) + }, SESSION, HTML) + + searchPage := handle(func(c *client) error { + q := c.Req.URL.Query() + sq := q.Get("q") + qType := q.Get("type") + offset, _ := strconv.Atoi(q.Get("offset")) + return s.SearchPage(c, sq, qType, offset) + }, SESSION, HTML) + + settingsPage := handle(func(c *client) error { + return s.SettingsPage(c) + }, SESSION, HTML) + + signin := handle(func(c *client) error { + instance := c.Req.FormValue("instance") + url, sid, err := s.NewSession(instance) + if err != nil { + return err + } + setSessionCookie(c, sid, sessionExp) + redirect(c, url) + return nil + }, NOAUTH, HTML) + + oauthCallback := handle(func(c *client) error { + q := c.Req.URL.Query() + token := q.Get("code") + token, userID, err := s.Signin(c, token) + if err != nil { + return err + } + + c.Session.AccessToken = token + c.Session.UserID = userID + err = s.sessionRepo.Add(c.Session) + if err != nil { + return err + } + + redirect(c, "/") + return nil + }, SESSION, HTML) + + post := handle(func(c *client) error { + content := c.Req.FormValue("content") + replyToID := c.Req.FormValue("reply_to_id") + format := c.Req.FormValue("format") + visibility := c.Req.FormValue("visibility") + isNSFW := c.Req.FormValue("is_nsfw") == "on" + files := c.Req.MultipartForm.File["attachments"] id, err := s.Post(c, content, replyToID, format, visibility, isNSFW, files) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - location := req.Header.Get("Referer") + location := c.Req.Header.Get("Referer") if len(replyToID) > 0 { location = "/thread/" + replyToID + "#status-" + id } - w.Header().Add("Location", location) - w.WriteHeader(http.StatusFound) - } - - like := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, location) + return nil + }, CSRF, HTML) + like := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") _, err := s.Like(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - unlike := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + unlike := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") _, err := s.UnLike(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - retweet := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + retweet := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") _, err := s.Retweet(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - unretweet := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + unretweet := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") _, err := s.UnRetweet(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - vote := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - statusID := req.FormValue("status_id") - choices, _ := req.PostForm["choices"] - + vote := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + statusID := c.Req.FormValue("status_id") + choices, _ := c.Req.PostForm["choices"] err := s.Vote(c, id, choices) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")+"#status-"+statusID) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+statusID) - w.WriteHeader(http.StatusFound) - } - - follow := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + follow := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + q := c.Req.URL.Query() var reblogs *bool - r, ok := req.URL.Query()["reblogs"] - if ok && len(r) > 0 { + if r, ok := q["reblogs"]; ok && len(r) > 0 { reblogs = new(bool) *reblogs = r[0] == "true" } - err := s.Follow(c, id, reblogs) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unfollow := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unfollow := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnFollow(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - mute := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + mute := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.Mute(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unMute := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unMute := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnMute(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - block := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + block := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.Block(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unBlock := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unBlock := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnBlock(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - subscribe := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + subscribe := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.Subscribe(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unSubscribe := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unSubscribe := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnSubscribe(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - settings := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - visibility := req.FormValue("visibility") - format := req.FormValue("format") - copyScope := req.FormValue("copy_scope") == "true" - threadInNewTab := req.FormValue("thread_in_new_tab") == "true" - hideAttachments := req.FormValue("hide_attachments") == "true" - maskNSFW := req.FormValue("mask_nsfw") == "true" - ni, _ := strconv.Atoi(req.FormValue("notification_interval")) - fluorideMode := req.FormValue("fluoride_mode") == "true" - darkMode := req.FormValue("dark_mode") == "true" - antiDopamineMode := req.FormValue("anti_dopamine_mode") == "true" + settings := handle(func(c *client) error { + visibility := c.Req.FormValue("visibility") + format := c.Req.FormValue("format") + copyScope := c.Req.FormValue("copy_scope") == "true" + threadInNewTab := c.Req.FormValue("thread_in_new_tab") == "true" + hideAttachments := c.Req.FormValue("hide_attachments") == "true" + maskNSFW := c.Req.FormValue("mask_nsfw") == "true" + ni, _ := strconv.Atoi(c.Req.FormValue("notification_interval")) + fluorideMode := c.Req.FormValue("fluoride_mode") == "true" + darkMode := c.Req.FormValue("dark_mode") == "true" + antiDopamineMode := c.Req.FormValue("anti_dopamine_mode") == "true" settings := &model.Settings{ DefaultVisibility: visibility, @@ -609,192 +490,123 @@ func NewHandler(s Service, staticDir string) http.Handler { err := s.SaveSettings(c, settings) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, "/") + return nil + }, CSRF, HTML) - w.Header().Add("Location", "/") - w.WriteHeader(http.StatusFound) - } - - muteConversation := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + muteConversation := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.MuteConversation(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - unMuteConversation := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + unMuteConversation := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.UnMuteConversation(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - delete := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + delete := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] err := s.Delete(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - readNotifications := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - maxID := req.URL.Query().Get("max_id") - + readNotifications := handle(func(c *client) error { + q := c.Req.URL.Query() + maxID := q.Get("max_id") err := s.ReadNotifications(c, maxID) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } + redirect(c, c.Req.Header.Get("Referer")) + return nil + }, CSRF, HTML) - w.Header().Add("Location", req.Header.Get("Referer")) - w.WriteHeader(http.StatusFound) - } - - bookmark := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") - + bookmark := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") err := s.Bookmark(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - unBookmark := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - retweetedByID := req.FormValue("retweeted_by_id") + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + unBookmark := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] + rid := c.Req.FormValue("retweeted_by_id") err := s.UnBookmark(c, id) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - s.ServeErrorPage(c, err) - return + return err } - - rID := id - if len(retweetedByID) > 0 { - rID = retweetedByID + if len(rid) > 0 { + id = rid } - w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) - w.WriteHeader(http.StatusFound) - } - - signout := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) + redirect(c, c.Req.Header.Get("Referer")+"#status-"+id) + return nil + }, CSRF, HTML) + signout := handle(func(c *client) error { s.Signout(c) + setSessionCookie(c, "", 0) + redirect(c, "/") + return nil + }, CSRF, HTML) - setSessionCookie(w, "", 0) - w.Header().Add("Location", "/") - w.WriteHeader(http.StatusFound) - } - - fLike := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + fLike := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] count, err := s.Like(c, id) if err != nil { - serveJsonError(w, err) - return + return err } + return writeJson(c, count) + }, CSRF, JSON) - err = serveJson(w, count) - if err != nil { - serveJsonError(w, err) - return - } - } - - fUnlike := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] - + fUnlike := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] count, err := s.UnLike(c, id) if err != nil { - serveJsonError(w, err) - return - } - - err = serveJson(w, count) - if err != nil { - serveJsonError(w, err) - return + return err } - } - - fRetweet := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] + return writeJson(c, count) + }, CSRF, JSON) + fRetweet := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] count, err := s.Retweet(c, id) if err != nil { - serveJsonError(w, err) - return - } - - err = serveJson(w, count) - if err != nil { - serveJsonError(w, err) - return + return err } - } - - fUnretweet := func(w http.ResponseWriter, req *http.Request) { - c := newClient(w, req, req.FormValue("csrf_token")) - id, _ := mux.Vars(req)["id"] + return writeJson(c, count) + }, CSRF, JSON) + fUnretweet := handle(func(c *client) error { + id, _ := mux.Vars(c.Req)["id"] count, err := s.UnRetweet(c, id) if err != nil { - serveJsonError(w, err) - return + return err } - - err = serveJson(w, count) - if err != nil { - serveJsonError(w, err) - return - } - } + return writeJson(c, count) + }, CSRF, JSON) r.HandleFunc("/", rootPage).Methods(http.MethodGet) r.HandleFunc("/nav", navPage).Methods(http.MethodGet) diff --git a/util/kv.go b/util/kv.go new file mode 100644 index 0000000..df61654 --- /dev/null +++ b/util/kv.go @@ -0,0 +1,91 @@ +package util + +import ( + "errors" + "io/ioutil" + "os" + "path/filepath" + "strings" + "sync" +) + +var ( + errInvalidKey = errors.New("invalid key") + errNoSuchKey = errors.New("no such key") +) + +type Database struct { + cache map[string][]byte + basedir string + m sync.RWMutex +} + +func NewDatabse(basedir string) (db *Database, err error) { + err = os.Mkdir(basedir, 0755) + if err != nil && !os.IsExist(err) { + return + } + + return &Database{ + cache: make(map[string][]byte), + basedir: basedir, + }, nil +} + +func (db *Database) Set(key string, val []byte) (err error) { + if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) { + return errInvalidKey + } + + err = ioutil.WriteFile(filepath.Join(db.basedir, key), val, 0644) + if err != nil { + return + } + + db.m.Lock() + db.cache[key] = val + db.m.Unlock() + + return +} + +func (db *Database) Get(key string) (val []byte, err error) { + if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) { + return nil, errInvalidKey + } + + db.m.RLock() + data, ok := db.cache[key] + db.m.RUnlock() + + if !ok { + data, err = ioutil.ReadFile(filepath.Join(db.basedir, key)) + if err != nil { + err = errNoSuchKey + return nil, err + } + + db.m.Lock() + db.cache[key] = data + db.m.Unlock() + } + + val = make([]byte, len(data)) + copy(val, data) + + return +} + +func (db *Database) Remove(key string) { + if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) { + return + } + + os.Remove(filepath.Join(db.basedir, key)) + + db.m.Lock() + delete(db.cache, key) + db.m.Unlock() + + return +} -- cgit v1.2.3