Skip to content

Commit

Permalink
Refactor remaining tables
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Dec 25, 2023
1 parent 4148c7b commit b453b08
Show file tree
Hide file tree
Showing 24 changed files with 534 additions and 334 deletions.
7 changes: 6 additions & 1 deletion commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"strings"

"github.com/google/uuid"
"github.com/skip2/go-qrcode"
"maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/event"
Expand Down Expand Up @@ -223,7 +224,11 @@ func fnLogin(ce *WrappedCommandEvent) {

// Update user with SignalID
if signalID != "" {
ce.User.SignalID = signalID
ce.User.SignalID, err = uuid.Parse(signalID)
if err != nil {
ce.Reply("Problem logging in - SignalID is not a valid UUID")
return
}
ce.User.SignalUsername = signalUsername
} else {
ce.Reply("Problem logging in - No SignalID received")
Expand Down
26 changes: 13 additions & 13 deletions database/disappearingmessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,23 @@ type DisappearingMessageQuery struct {
type DisappearingMessage struct {
qh *dbutil.QueryHelper[*DisappearingMessage]

RoomID id.RoomID
EventID id.EventID
ExpireInSeconds int64 // TODO change to time.Duration
ExpireAt time.Time
RoomID id.RoomID
EventID id.EventID
ExpireIn time.Duration
ExpireAt time.Time
}

func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage {
return &DisappearingMessage{qh: qh}
}

func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireInSeconds int64, expireAt time.Time) *DisappearingMessage {
func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage {
return &DisappearingMessage{
qh: dmq.QueryHelper,
RoomID: roomID,
EventID: eventID,
ExpireInSeconds: expireInSeconds,
ExpireAt: expireAt,
qh: dmq.QueryHelper,
RoomID: roomID,
EventID: eventID,
ExpireIn: expireIn,
ExpireAt: expireAt,
}
}

Expand All @@ -96,7 +96,7 @@ func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage
if err != nil {
return nil, err
}
msg.ExpireInSeconds = expireIn
msg.ExpireIn = time.Duration(expireIn) * time.Second
if expireAt.Valid {
msg.ExpireAt = time.Unix(expireAt.Int64, 0)
}
Expand All @@ -109,15 +109,15 @@ func (msg *DisappearingMessage) sqlVariables() []any {
expireAt.Valid = true
expireAt.Int64 = msg.ExpireAt.Unix()
}
return []any{msg.RoomID, msg.EventID, msg.ExpireInSeconds, expireAt}
return []any{msg.RoomID, msg.EventID, int64(msg.ExpireIn.Seconds()), expireAt}
}

func (msg *DisappearingMessage) Insert(ctx context.Context) error {
return msg.qh.Exec(ctx, insertDisappearingMessageQuery, msg.sqlVariables()...)
}

func (msg *DisappearingMessage) StartExpirationTimer(ctx context.Context) error {
msg.ExpireAt = time.Now().Add(time.Duration(msg.ExpireInSeconds) * time.Second)
msg.ExpireAt = time.Now().Add(msg.ExpireIn)
return msg.qh.Exec(ctx, updateDisappearingMessageQuery, msg.EventID, msg.ExpireAt.Unix())
}

Expand Down
15 changes: 8 additions & 7 deletions database/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"strings"

"github.com/google/uuid"
"maunium.net/go/mautrix/id"

"go.mau.fi/util/dbutil"
Expand Down Expand Up @@ -79,12 +80,12 @@ type MessageQuery struct {
type Message struct {
qh *dbutil.QueryHelper[*Message]

Sender string
Sender uuid.UUID
Timestamp uint64
PartIndex int

SignalChatID string
SignalReceiver string
SignalReceiver uuid.UUID

MXID id.EventID
RoomID id.RoomID
Expand All @@ -98,23 +99,23 @@ func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Messag
return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid)
}

func (mq *MessageQuery) GetBySignalIDWithUnknownReceiver(ctx context.Context, sender string, timestamp uint64, partIndex int, receiver string) (*Message, error) {
func (mq *MessageQuery) GetBySignalIDWithUnknownReceiver(ctx context.Context, sender uuid.UUID, timestamp uint64, partIndex int, receiver uuid.UUID) (*Message, error) {
return mq.QueryOne(ctx, getMessagePartBySignalIDWithUnknownReceiverQuery, sender, timestamp, partIndex, receiver)
}

func (mq *MessageQuery) GetBySignalID(ctx context.Context, sender string, timestamp uint64, partIndex int, receiver string) (*Message, error) {
func (mq *MessageQuery) GetBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, partIndex int, receiver uuid.UUID) (*Message, error) {
return mq.QueryOne(ctx, getMessagePartBySignalIDQuery, sender, timestamp, partIndex, receiver)
}

func (mq *MessageQuery) GetLastPartBySignalID(ctx context.Context, sender string, timestamp uint64, receiver string) (*Message, error) {
func (mq *MessageQuery) GetLastPartBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, receiver uuid.UUID) (*Message, error) {
return mq.QueryOne(ctx, getLastMessagePartBySignalIDQuery, sender, timestamp, receiver)
}

func (mq *MessageQuery) GetAllPartsBySignalID(ctx context.Context, sender string, timestamp uint64, receiver string) ([]*Message, error) {
func (mq *MessageQuery) GetAllPartsBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, receiver uuid.UUID) ([]*Message, error) {
return mq.QueryMany(ctx, getAllMessagePartsBySignalIDQuery, sender, timestamp, receiver)
}

func (mq *MessageQuery) GetManyBySignalID(ctx context.Context, sender string, timestamps []uint64, receiver string) ([]*Message, error) {
func (mq *MessageQuery) GetManyBySignalID(ctx context.Context, sender uuid.UUID, timestamps []uint64, receiver uuid.UUID) ([]*Message, error) {
if mq.GetDB().Dialect == dbutil.Postgres {
return mq.QueryMany(ctx, getManyMessagesBySignalIDQueryPostgres, sender, receiver, timestamps)
} else {
Expand Down
53 changes: 27 additions & 26 deletions database/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package database
import (
"context"
"database/sql"
"fmt"

"github.com/google/uuid"
"maunium.net/go/mautrix/id"

"go.mau.fi/util/dbutil"

"go.mau.fi/mautrix-signal/pkg/signalmeow"
)

const (
Expand Down Expand Up @@ -56,11 +58,23 @@ type PortalQuery struct {
}

type PortalKey struct {
ChatID string // TODO use some kind of union type between *uuid.UUID and a group ID as bytes?
Receiver string // TODO change to *uuid.UUID?
ChatID string
Receiver uuid.UUID
}

func (pk *PortalKey) UserID() uuid.UUID {
parsed, _ := uuid.Parse(pk.ChatID)
return parsed
}

func NewPortalKey(chatID, receiver string) PortalKey {
func (pk *PortalKey) GroupID() signalmeow.GroupIdentifier {
if len(pk.ChatID) == 44 {
return signalmeow.GroupIdentifier(pk.ChatID)
}
return ""
}

func NewPortalKey(chatID string, receiver uuid.UUID) PortalKey {
return PortalKey{
ChatID: chatID,
Receiver: receiver,
Expand Down Expand Up @@ -96,7 +110,7 @@ func (pq *PortalQuery) GetByChatID(ctx context.Context, pk PortalKey) (*Portal,
return pq.QueryOne(ctx, getPortalByChatIDQuery, pk.ChatID, pk.Receiver)
}

func (pq *PortalQuery) FindPrivateChatsOf(ctx context.Context, receiver string) ([]*Portal, error) {
func (pq *PortalQuery) FindPrivateChatsOf(ctx context.Context, receiver uuid.UUID) ([]*Portal, error) {
return pq.QueryMany(ctx, getPortalsByReceiver, receiver)
}

Expand All @@ -105,39 +119,26 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) {
}

func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
var mxid, name, topic, avatarHash, avatarURL, relayUserID sql.NullString
var expirationTime sql.NullInt64
var mxid sql.NullString
err := row.Scan(
&p.ChatID,
&p.Receiver,
&mxid,
&name,
&topic,
&avatarHash,
&avatarURL,
&p.Name,
&p.Topic,
&p.AvatarHash,
&p.AvatarURL,
&p.NameSet,
&p.AvatarSet,
&p.Revision,
&p.Encrypted,
&relayUserID,
&expirationTime,
&p.RelayUserID,
&p.ExpirationTime,
)
if err != nil {
return nil, err
}
p.MXID = id.RoomID(mxid.String)
p.Name = name.String
p.Topic = topic.String
p.AvatarHash = avatarHash.String
p.RelayUserID = id.UserID(relayUserID.String)
p.ExpirationTime = int(expirationTime.Int64)
if len(avatarURL.String) > 0 {
parsedAvatarURL, err := id.ParseContentURI(avatarURL.String)
if err != nil {
return nil, fmt.Errorf("failed to parse avatar URL: %w", err)
}
p.AvatarURL = parsedAvatarURL
}
return p, nil
}

Expand All @@ -149,7 +150,7 @@ func (p *Portal) sqlVariables() []any {
p.Name,
p.Topic,
p.AvatarHash,
p.AvatarURL.String(),
p.AvatarURL,
p.NameSet,
p.AvatarSet,
p.Revision,
Expand Down
31 changes: 10 additions & 21 deletions database/puppet.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ package database
import (
"context"
"database/sql"
"fmt"

"github.com/google/uuid"
"maunium.net/go/mautrix/id"

"go.mau.fi/util/dbutil"
Expand Down Expand Up @@ -62,7 +62,7 @@ type PuppetQuery struct {
type Puppet struct {
qh *dbutil.QueryHelper[*Puppet]

SignalID string // TODO change to uuid.UUID
SignalID uuid.UUID
Number string
Name string
NameQuality int
Expand All @@ -82,7 +82,7 @@ func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet {
return &Puppet{qh: qh}
}

func (pq *PuppetQuery) GetBySignalID(ctx context.Context, signalID string) (*Puppet, error) {
func (pq *PuppetQuery) GetBySignalID(ctx context.Context, signalID uuid.UUID) (*Puppet, error) {
return pq.QueryOne(ctx, getPuppetBySignalIDQuery, signalID)
}

Expand All @@ -99,37 +99,26 @@ func (pq *PuppetQuery) GetAllWithCustomMXID(ctx context.Context) ([]*Puppet, err
}

func (p *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
var number, name, avatarHash, avatarURL, customMXID, accessToken sql.NullString
var number, customMXID sql.NullString
err := row.Scan(
&p.SignalID,
&number,
&name,
&p.Name,
&p.NameQuality,
&avatarHash,
&avatarURL,
&p.AvatarHash,
&p.AvatarURL,
&p.NameSet,
&p.AvatarSet,
&p.ContactInfoSet,
&p.IsRegistered,
&customMXID,
&accessToken,
&p.AccessToken,
)
if err != nil {
return nil, nil
}
if len(avatarURL.String) > 0 {
parsedAvatarURL, err := id.ParseContentURI(avatarURL.String)
if err != nil {
return nil, fmt.Errorf("failed to parse avatar URL: %w", err)
}
p.AvatarURL = parsedAvatarURL
}

p.Number = number.String
p.Name = name.String
p.AvatarHash = avatarHash.String
p.CustomMXID = id.UserID(customMXID.String)
p.AccessToken = accessToken.String
return p, nil
}

Expand All @@ -140,12 +129,12 @@ func (p *Puppet) sqlVariables() []any {
p.Name,
p.NameQuality,
p.AvatarHash,
p.AvatarURL.String(),
p.AvatarURL,
p.NameSet,
p.AvatarSet,
p.ContactInfoSet,
p.IsRegistered,
p.CustomMXID.String(),
dbutil.StrPtr(p.CustomMXID),
p.AccessToken,
}
}
Expand Down
9 changes: 5 additions & 4 deletions database/reaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package database
import (
"context"

"github.com/google/uuid"
"maunium.net/go/mautrix/id"

"go.mau.fi/util/dbutil"
Expand Down Expand Up @@ -47,13 +48,13 @@ func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction {
type Reaction struct {
qh *dbutil.QueryHelper[*Reaction]

MsgAuthor string // TODO change to uuid.UUID
MsgAuthor uuid.UUID
MsgTimestamp uint64
Author string // TODO change to uuid.UUID
Author uuid.UUID
Emoji string

SignalChatID string
SignalReceiver string // TODO change to uuid.UUID
SignalReceiver uuid.UUID

MXID id.EventID
RoomID id.RoomID
Expand All @@ -63,7 +64,7 @@ func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*React
return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid)
}

func (rq *ReactionQuery) GetBySignalID(ctx context.Context, msgAuthor string, msgTimestamp uint64, author, signalReceiver string) (*Reaction, error) {
func (rq *ReactionQuery) GetBySignalID(ctx context.Context, msgAuthor uuid.UUID, msgTimestamp uint64, author, signalReceiver uuid.UUID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionBySignalIDQuery, msgAuthor, msgTimestamp, author, signalReceiver)
}

Expand Down
Loading

0 comments on commit b453b08

Please sign in to comment.