Skip to content

Commit 909fff4

Browse files
robelkinrobelkin-rational-partnersclaude
authored
Add account filter to MCP search, list, and aggregate tools (#110)
## Summary - Adds optional `account` parameter to `search_messages`, `list_messages`, and `aggregate` MCP tools - Allows filtering results to a specific archived Gmail account when multiple accounts are synced - Account parameter accepts an email address (e.g., "alice@gmail.com") - Users can discover available accounts via the `get_stats` tool ## Test plan - [x] Added `TestAccountFilter` with 7 test cases covering valid/invalid accounts for all three tools - [x] All existing MCP tests pass - [x] `go vet` passes 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Rob Elkin <rob@rational.partners> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 63b11df commit 909fff4

File tree

3 files changed

+146
-4
lines changed

3 files changed

+146
-4
lines changed

internal/mcp/handlers.go

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@ type handlers struct {
2525
attachmentsDir string
2626
}
2727

28+
// getAccountID looks up a source ID by email address.
29+
// Returns nil if account is empty (no filter), or an error if not found.
30+
func (h *handlers) getAccountID(ctx context.Context, account string) (*int64, error) {
31+
if account == "" {
32+
return nil, nil
33+
}
34+
accounts, err := h.engine.ListAccounts(ctx)
35+
if err != nil {
36+
return nil, fmt.Errorf("failed to list accounts: %w", err)
37+
}
38+
for _, acc := range accounts {
39+
if acc.Identifier == account {
40+
return &acc.ID, nil
41+
}
42+
}
43+
return nil, fmt.Errorf("account not found: %s", account)
44+
}
45+
2846
// getIDArg extracts a required positive integer ID from the arguments map.
2947
func getIDArg(args map[string]any, key string) (int64, error) {
3048
v, ok := args[key].(float64)
@@ -94,10 +112,20 @@ func (h *handlers) searchMessages(ctx context.Context, req mcp.CallToolRequest)
94112
limit := limitArg(args, "limit", 20)
95113
offset := limitArg(args, "offset", 0)
96114

115+
// Look up account filter
116+
account, _ := args["account"].(string)
117+
sourceID, err := h.getAccountID(ctx, account)
118+
if err != nil {
119+
return mcp.NewToolResultError(err.Error()), nil
120+
}
121+
97122
q := search.Parse(queryStr)
123+
q.AccountID = sourceID
124+
125+
filter := query.MessageFilter{SourceID: sourceID}
98126

99127
// Try fast search first (metadata only), fall back to full FTS.
100-
results, err := h.engine.SearchFast(ctx, q, query.MessageFilter{}, limit, offset)
128+
results, err := h.engine.SearchFast(ctx, q, filter, limit, offset)
101129
if err != nil {
102130
return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", err)), nil
103131
}
@@ -276,7 +304,15 @@ func (h *handlers) exportAttachment(ctx context.Context, req mcp.CallToolRequest
276304
func (h *handlers) listMessages(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
277305
args := req.GetArguments()
278306

307+
// Look up account filter
308+
account, _ := args["account"].(string)
309+
sourceID, err := h.getAccountID(ctx, account)
310+
if err != nil {
311+
return mcp.NewToolResultError(err.Error()), nil
312+
}
313+
279314
filter := query.MessageFilter{
315+
SourceID: sourceID,
280316
Pagination: query.Pagination{
281317
Limit: limitArg(args, "limit", 20),
282318
Offset: limitArg(args, "offset", 0),
@@ -295,7 +331,6 @@ func (h *handlers) listMessages(ctx context.Context, req mcp.CallToolRequest) (*
295331
if v, ok := args["has_attachment"].(bool); ok && v {
296332
filter.WithAttachmentsOnly = true
297333
}
298-
var err error
299334
if filter.After, err = getDateArg(args, "after"); err != nil {
300335
return mcp.NewToolResultError(err.Error()), nil
301336
}
@@ -341,11 +376,18 @@ func (h *handlers) aggregate(ctx context.Context, req mcp.CallToolRequest) (*mcp
341376
return mcp.NewToolResultError("group_by parameter is required"), nil
342377
}
343378

379+
// Look up account filter
380+
account, _ := args["account"].(string)
381+
sourceID, err := h.getAccountID(ctx, account)
382+
if err != nil {
383+
return mcp.NewToolResultError(err.Error()), nil
384+
}
385+
344386
opts := query.AggregateOptions{
345-
Limit: limitArg(args, "limit", 50),
387+
SourceID: sourceID,
388+
Limit: limitArg(args, "limit", 50),
346389
}
347390

348-
var err error
349391
if opts.After, err = getDateArg(args, "after"); err != nil {
350392
return mcp.NewToolResultError(err.Error()), nil
351393
}

internal/mcp/server.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ func withBefore() mcp.ToolOption {
4646
)
4747
}
4848

49+
func withAccount() mcp.ToolOption {
50+
return mcp.WithString("account",
51+
mcp.Description("Filter by account email address (use get_stats to list available accounts)"),
52+
)
53+
}
54+
4955
// Serve creates an MCP server with email archive tools and serves over stdio.
5056
// It blocks until stdin is closed or the context is cancelled.
5157
func Serve(ctx context.Context, engine query.Engine, attachmentsDir string) error {
@@ -77,6 +83,7 @@ func searchMessagesTool() mcp.Tool {
7783
mcp.Required(),
7884
mcp.Description("Gmail-style search query (e.g. 'from:alice subject:meeting after:2024-01-01')"),
7985
),
86+
withAccount(),
8087
withLimit("20"),
8188
withOffset(),
8289
)
@@ -121,6 +128,7 @@ func listMessagesTool() mcp.Tool {
121128
return mcp.NewTool(ToolListMessages,
122129
mcp.WithDescription("List messages with optional filters. Returns message summaries sorted by date."),
123130
mcp.WithReadOnlyHintAnnotation(true),
131+
withAccount(),
124132
mcp.WithString("from",
125133
mcp.Description("Filter by sender email address"),
126134
),
@@ -156,6 +164,7 @@ func aggregateTool() mcp.Tool {
156164
mcp.Description("Dimension to group by"),
157165
mcp.Enum("sender", "recipient", "domain", "label", "time"),
158166
),
167+
withAccount(),
159168
withLimit("50"),
160169
withAfter(),
161170
withBefore(),

internal/mcp/server_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,3 +769,94 @@ func TestLimitArgClamping(t *testing.T) {
769769
})
770770
}
771771
}
772+
773+
func TestAccountFilter(t *testing.T) {
774+
eng := &querytest.MockEngine{
775+
Accounts: []query.AccountInfo{
776+
{ID: 1, Identifier: "alice@gmail.com"},
777+
{ID: 2, Identifier: "bob@gmail.com"},
778+
},
779+
SearchFastResults: []query.MessageSummary{
780+
testutil.NewMessageSummary(1).WithSubject("Test").WithFromEmail("alice@gmail.com").Build(),
781+
},
782+
ListResults: []query.MessageSummary{
783+
testutil.NewMessageSummary(2).WithSubject("List Test").WithFromEmail("bob@gmail.com").Build(),
784+
},
785+
AggregateRows: []query.AggregateRow{
786+
{Key: "alice@gmail.com", Count: 100},
787+
},
788+
}
789+
h := newTestHandlers(eng)
790+
791+
t.Run("search with valid account", func(t *testing.T) {
792+
msgs := runTool[[]query.MessageSummary](t, "search_messages", h.searchMessages, map[string]any{
793+
"query": "test",
794+
"account": "alice@gmail.com",
795+
})
796+
if len(msgs) != 1 {
797+
t.Fatalf("expected 1 message, got %d", len(msgs))
798+
}
799+
})
800+
801+
t.Run("search with invalid account", func(t *testing.T) {
802+
r := runToolExpectError(t, "search_messages", h.searchMessages, map[string]any{
803+
"query": "test",
804+
"account": "unknown@gmail.com",
805+
})
806+
txt := resultText(t, r)
807+
if !strings.Contains(txt, "account not found") {
808+
t.Fatalf("expected 'account not found' error, got: %s", txt)
809+
}
810+
})
811+
812+
t.Run("list with valid account", func(t *testing.T) {
813+
msgs := runTool[[]query.MessageSummary](t, "list_messages", h.listMessages, map[string]any{
814+
"account": "bob@gmail.com",
815+
})
816+
if len(msgs) != 1 {
817+
t.Fatalf("expected 1 message, got %d", len(msgs))
818+
}
819+
})
820+
821+
t.Run("list with invalid account", func(t *testing.T) {
822+
r := runToolExpectError(t, "list_messages", h.listMessages, map[string]any{
823+
"account": "unknown@gmail.com",
824+
})
825+
txt := resultText(t, r)
826+
if !strings.Contains(txt, "account not found") {
827+
t.Fatalf("expected 'account not found' error, got: %s", txt)
828+
}
829+
})
830+
831+
t.Run("aggregate with valid account", func(t *testing.T) {
832+
rows := runTool[[]query.AggregateRow](t, "aggregate", h.aggregate, map[string]any{
833+
"group_by": "sender",
834+
"account": "alice@gmail.com",
835+
})
836+
if len(rows) != 1 {
837+
t.Fatalf("expected 1 row, got %d", len(rows))
838+
}
839+
})
840+
841+
t.Run("aggregate with invalid account", func(t *testing.T) {
842+
r := runToolExpectError(t, "aggregate", h.aggregate, map[string]any{
843+
"group_by": "sender",
844+
"account": "unknown@gmail.com",
845+
})
846+
txt := resultText(t, r)
847+
if !strings.Contains(txt, "account not found") {
848+
t.Fatalf("expected 'account not found' error, got: %s", txt)
849+
}
850+
})
851+
852+
t.Run("empty account means no filter", func(t *testing.T) {
853+
// Empty string should not filter - return all results
854+
msgs := runTool[[]query.MessageSummary](t, "search_messages", h.searchMessages, map[string]any{
855+
"query": "test",
856+
"account": "",
857+
})
858+
if len(msgs) != 1 {
859+
t.Fatalf("expected 1 message, got %d", len(msgs))
860+
}
861+
})
862+
}

0 commit comments

Comments
 (0)