Skip to content

Commit 5b0d193

Browse files
h3n4lclaude
andauthored
feat: add db.collection.distinct() operation support (#10)
Add support for the distinct() method which returns an array of unique values for a specified field across a collection. Syntax: - db.collection.distinct("field") - db.collection.distinct("field", { filter }) Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 8a58a0c commit 5b0d193

3 files changed

Lines changed: 267 additions & 0 deletions

File tree

executor.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package gomongo
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67

78
"github.com/antlr4-go/antlr/v4"
@@ -79,6 +80,8 @@ func executeOperation(ctx context.Context, client *mongo.Client, database string
7980
return executeCountDocuments(ctx, client, database, op)
8081
case opEstimatedDocumentCount:
8182
return executeEstimatedDocumentCount(ctx, client, database, op)
83+
case opDistinct:
84+
return executeDistinct(ctx, client, database, op)
8285
default:
8386
return nil, &UnsupportedOperationError{
8487
Operation: statement,
@@ -376,3 +379,51 @@ func executeEstimatedDocumentCount(ctx context.Context, client *mongo.Client, da
376379
RowCount: 1,
377380
}, nil
378381
}
382+
383+
// executeDistinct executes a db.collection.distinct() command.
384+
func executeDistinct(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) {
385+
collection := client.Database(database).Collection(op.collection)
386+
387+
filter := op.filter
388+
if filter == nil {
389+
filter = bson.D{}
390+
}
391+
392+
result := collection.Distinct(ctx, op.distinctField, filter)
393+
if err := result.Err(); err != nil {
394+
return nil, fmt.Errorf("distinct failed: %w", err)
395+
}
396+
397+
var values []any
398+
if err := result.Decode(&values); err != nil {
399+
return nil, fmt.Errorf("decode failed: %w", err)
400+
}
401+
402+
var rows []string
403+
for _, val := range values {
404+
jsonBytes, err := marshalValue(val)
405+
if err != nil {
406+
return nil, fmt.Errorf("marshal failed: %w", err)
407+
}
408+
rows = append(rows, string(jsonBytes))
409+
}
410+
411+
return &Result{
412+
Rows: rows,
413+
RowCount: len(rows),
414+
}, nil
415+
}
416+
417+
// marshalValue marshals a value to JSON.
418+
// bson.MarshalExtJSONIndent only works for documents/arrays at top level,
419+
// so we use encoding/json for primitive values (strings, numbers, booleans).
420+
func marshalValue(val any) ([]byte, error) {
421+
switch v := val.(type) {
422+
case bson.M, bson.D, map[string]any:
423+
return bson.MarshalExtJSONIndent(v, false, false, "", " ")
424+
case bson.A, []any:
425+
return bson.MarshalExtJSONIndent(v, false, false, "", " ")
426+
default:
427+
return json.Marshal(v)
428+
}
429+
}

executor_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,3 +1755,138 @@ func TestEstimatedDocumentCountWithEmptyOptions(t *testing.T) {
17551755
require.NotNil(t, result)
17561756
require.Equal(t, "2", result.Rows[0])
17571757
}
1758+
1759+
func TestDistinct(t *testing.T) {
1760+
client, cleanup := setupTestContainer(t)
1761+
defer cleanup()
1762+
1763+
ctx := context.Background()
1764+
1765+
// Create a collection with documents
1766+
collection := client.Database("testdb").Collection("users")
1767+
_, err := collection.InsertMany(ctx, []any{
1768+
bson.M{"name": "alice", "status": "active"},
1769+
bson.M{"name": "bob", "status": "inactive"},
1770+
bson.M{"name": "charlie", "status": "active"},
1771+
bson.M{"name": "diana", "status": "active"},
1772+
})
1773+
require.NoError(t, err)
1774+
1775+
gc := gomongo.NewClient(client)
1776+
1777+
// Test distinct on status field
1778+
result, err := gc.Execute(ctx, "testdb", `db.users.distinct("status")`)
1779+
require.NoError(t, err)
1780+
require.NotNil(t, result)
1781+
require.Equal(t, 2, result.RowCount)
1782+
1783+
// Verify both values are present
1784+
values := make(map[string]bool)
1785+
for _, row := range result.Rows {
1786+
values[row] = true
1787+
}
1788+
require.True(t, values[`"active"`] || values[`"inactive"`])
1789+
}
1790+
1791+
func TestDistinctWithFilter(t *testing.T) {
1792+
client, cleanup := setupTestContainer(t)
1793+
defer cleanup()
1794+
1795+
ctx := context.Background()
1796+
1797+
// Create a collection with documents
1798+
collection := client.Database("testdb").Collection("products")
1799+
_, err := collection.InsertMany(ctx, []any{
1800+
bson.M{"category": "electronics", "brand": "Apple", "price": 999},
1801+
bson.M{"category": "electronics", "brand": "Samsung", "price": 799},
1802+
bson.M{"category": "electronics", "brand": "Apple", "price": 1299},
1803+
bson.M{"category": "clothing", "brand": "Nike", "price": 99},
1804+
bson.M{"category": "clothing", "brand": "Adidas", "price": 89},
1805+
})
1806+
require.NoError(t, err)
1807+
1808+
gc := gomongo.NewClient(client)
1809+
1810+
// Test distinct with filter
1811+
result, err := gc.Execute(ctx, "testdb", `db.products.distinct("brand", { category: "electronics" })`)
1812+
require.NoError(t, err)
1813+
require.NotNil(t, result)
1814+
require.Equal(t, 2, result.RowCount)
1815+
1816+
// Verify only electronics brands are returned
1817+
values := make(map[string]bool)
1818+
for _, row := range result.Rows {
1819+
values[row] = true
1820+
}
1821+
require.True(t, values[`"Apple"`])
1822+
require.True(t, values[`"Samsung"`])
1823+
require.False(t, values[`"Nike"`])
1824+
require.False(t, values[`"Adidas"`])
1825+
}
1826+
1827+
func TestDistinctEmptyCollection(t *testing.T) {
1828+
client, cleanup := setupTestContainer(t)
1829+
defer cleanup()
1830+
1831+
ctx := context.Background()
1832+
1833+
gc := gomongo.NewClient(client)
1834+
1835+
// Test distinct on empty/non-existent collection
1836+
result, err := gc.Execute(ctx, "testdb", `db.users.distinct("status")`)
1837+
require.NoError(t, err)
1838+
require.NotNil(t, result)
1839+
require.Equal(t, 0, result.RowCount)
1840+
require.Empty(t, result.Rows)
1841+
}
1842+
1843+
func TestDistinctBracketNotation(t *testing.T) {
1844+
client, cleanup := setupTestContainer(t)
1845+
defer cleanup()
1846+
1847+
ctx := context.Background()
1848+
1849+
// Create a collection with hyphenated name
1850+
collection := client.Database("testdb").Collection("user-logs")
1851+
_, err := collection.InsertMany(ctx, []any{
1852+
bson.M{"level": "info"},
1853+
bson.M{"level": "warn"},
1854+
bson.M{"level": "error"},
1855+
bson.M{"level": "info"},
1856+
})
1857+
require.NoError(t, err)
1858+
1859+
gc := gomongo.NewClient(client)
1860+
1861+
// Test with bracket notation
1862+
result, err := gc.Execute(ctx, "testdb", `db["user-logs"].distinct("level")`)
1863+
require.NoError(t, err)
1864+
require.NotNil(t, result)
1865+
require.Equal(t, 3, result.RowCount)
1866+
}
1867+
1868+
func TestDistinctNumericValues(t *testing.T) {
1869+
client, cleanup := setupTestContainer(t)
1870+
defer cleanup()
1871+
1872+
ctx := context.Background()
1873+
1874+
// Create a collection with numeric values
1875+
collection := client.Database("testdb").Collection("scores")
1876+
_, err := collection.InsertMany(ctx, []any{
1877+
bson.M{"score": 100},
1878+
bson.M{"score": 85},
1879+
bson.M{"score": 100},
1880+
bson.M{"score": 90},
1881+
bson.M{"score": 85},
1882+
})
1883+
require.NoError(t, err)
1884+
1885+
gc := gomongo.NewClient(client)
1886+
1887+
// Test distinct on numeric field
1888+
result, err := gc.Execute(ctx, "testdb", `db.scores.distinct("score")`)
1889+
require.NoError(t, err)
1890+
require.NotNil(t, result)
1891+
require.Equal(t, 3, result.RowCount) // 100, 85, 90
1892+
}

translator.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ const (
2424
opGetIndexes
2525
opCountDocuments
2626
opEstimatedDocumentCount
27+
opDistinct
2728
)
2829

2930
// mongoOperation represents a parsed MongoDB operation.
@@ -40,6 +41,8 @@ type mongoOperation struct {
4041
pipeline bson.A
4142
// countDocuments options
4243
hint any // string (index name) or document (index spec)
44+
// distinct field name
45+
distinctField string
4346
}
4447

4548
// mongoShellVisitor extracts operations from a parse tree.
@@ -177,6 +180,81 @@ func (v *mongoShellVisitor) extractGetCollectionInfosArgs(ctx *mongodb.GetCollec
177180
v.operation.filter = filter
178181
}
179182

183+
func (v *mongoShellVisitor) extractDistinctArgs(ctx *mongodb.GenericMethodContext) {
184+
args := ctx.Arguments()
185+
if args == nil {
186+
v.err = fmt.Errorf("distinct() requires a field name argument")
187+
return
188+
}
189+
190+
argsCtx, ok := args.(*mongodb.ArgumentsContext)
191+
if !ok {
192+
v.err = fmt.Errorf("distinct() requires a field name argument")
193+
return
194+
}
195+
196+
allArgs := argsCtx.AllArgument()
197+
if len(allArgs) == 0 {
198+
v.err = fmt.Errorf("distinct() requires a field name argument")
199+
return
200+
}
201+
202+
// First argument is the field name (required)
203+
firstArg, ok := allArgs[0].(*mongodb.ArgumentContext)
204+
if !ok {
205+
v.err = fmt.Errorf("distinct() requires a field name argument")
206+
return
207+
}
208+
209+
valueCtx := firstArg.Value()
210+
if valueCtx == nil {
211+
v.err = fmt.Errorf("distinct() requires a field name argument")
212+
return
213+
}
214+
215+
literalValue, ok := valueCtx.(*mongodb.LiteralValueContext)
216+
if !ok {
217+
v.err = fmt.Errorf("distinct() field name must be a string")
218+
return
219+
}
220+
221+
stringLiteral, ok := literalValue.Literal().(*mongodb.StringLiteralValueContext)
222+
if !ok {
223+
v.err = fmt.Errorf("distinct() field name must be a string")
224+
return
225+
}
226+
227+
v.operation.distinctField = unquoteString(stringLiteral.StringLiteral().GetText())
228+
229+
// Second argument is the filter (optional)
230+
if len(allArgs) < 2 {
231+
return
232+
}
233+
234+
secondArg, ok := allArgs[1].(*mongodb.ArgumentContext)
235+
if !ok {
236+
return
237+
}
238+
239+
filterValueCtx := secondArg.Value()
240+
if filterValueCtx == nil {
241+
return
242+
}
243+
244+
docValue, ok := filterValueCtx.(*mongodb.DocumentValueContext)
245+
if !ok {
246+
v.err = fmt.Errorf("distinct() filter must be a document")
247+
return
248+
}
249+
250+
filter, err := convertDocument(docValue.Document())
251+
if err != nil {
252+
v.err = fmt.Errorf("invalid filter: %w", err)
253+
return
254+
}
255+
v.operation.filter = filter
256+
}
257+
180258
func (v *mongoShellVisitor) extractCountDocumentsArgs(ctx *mongodb.GenericMethodContext) {
181259
args := ctx.Arguments()
182260
if args == nil {
@@ -527,6 +605,9 @@ func (v *mongoShellVisitor) visitMethodCall(ctx mongodb.IMethodCallContext) {
527605
v.extractCountDocumentsArgs(gmCtx)
528606
case "estimatedDocumentCount":
529607
v.operation.opType = opEstimatedDocumentCount
608+
case "distinct":
609+
v.operation.opType = opDistinct
610+
v.extractDistinctArgs(gmCtx)
530611
default:
531612
v.err = &UnsupportedOperationError{
532613
Operation: methodName,

0 commit comments

Comments
 (0)