Skip to content

Commit 5b5182e

Browse files
committed
feat(vector): implement MagnitudeVector and add test cases
1 parent fb294b0 commit 5b5182e

2 files changed

Lines changed: 116 additions & 2 deletions

File tree

geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,80 @@
1919

2020
package org.apache.geaflow.ai.index.vector;
2121

22+
import java.util.Objects;
23+
2224
public class MagnitudeVector implements IVector {
2325

26+
private final double magnitude;
27+
28+
public MagnitudeVector() {
29+
this.magnitude = 0.0;
30+
}
31+
32+
public MagnitudeVector(double magnitude) {
33+
this.magnitude = magnitude;
34+
}
35+
36+
public double getMagnitude() {
37+
return magnitude;
38+
}
39+
2440
@Override
2541
public double match(IVector other) {
26-
return 0;
42+
if (!(other instanceof MagnitudeVector)) {
43+
throw new IllegalArgumentException("Other vector must be a MagnitudeVector");
44+
}
45+
46+
MagnitudeVector otherVec = (MagnitudeVector) other;
47+
double otherMagnitude = otherVec.magnitude;
48+
49+
return computeSimilarity(otherMagnitude);
50+
2751
}
2852

53+
private double computeSimilarity(double otherMagnitude) {
54+
if (this.magnitude == 0.0 && otherMagnitude == 0.0) {
55+
return 1.0;
56+
}
57+
58+
if (this.magnitude == 0.0 || otherMagnitude == 0.0) {
59+
return 0.0;
60+
}
61+
62+
double diff = Math.abs(this.magnitude - otherMagnitude);
63+
double max = Math.max(Math.abs(this.magnitude), Math.abs(otherMagnitude));
64+
65+
if (max == 0.0) {
66+
return 1.0;
67+
}
68+
69+
return 1.0 - (diff / max);
70+
}
71+
2972
@Override
3073
public VectorType getType() {
3174
return VectorType.MagnitudeVector;
3275
}
3376

77+
@Override
78+
public boolean equals(Object o) {
79+
if (this == o) {
80+
return true;
81+
}
82+
if (o == null || getClass() != o.getClass()) {
83+
return false;
84+
}
85+
MagnitudeVector that = (MagnitudeVector) o;
86+
return Double.compare(that.magnitude, magnitude) == 0;
87+
}
88+
89+
@Override
90+
public int hashCode() {
91+
return Objects.hash(magnitude);
92+
}
93+
3494
@Override
3595
public String toString() {
36-
return "MagnitudeVector{}";
96+
return "MagnitudeVector{magnitude=" + magnitude + '}';
3797
}
3898
}

geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
import org.apache.geaflow.ai.index.EntityAttributeIndexStore;
2828
import org.apache.geaflow.ai.index.IndexStore;
2929
import org.apache.geaflow.ai.index.vector.EmbeddingVector;
30+
import org.apache.geaflow.ai.index.vector.IVector;
3031
import org.apache.geaflow.ai.index.vector.KeywordVector;
3132
import org.apache.geaflow.ai.index.vector.MagnitudeVector;
3233
import org.apache.geaflow.ai.index.vector.TraversalVector;
34+
import org.apache.geaflow.ai.index.vector.VectorType;
3335
import org.apache.geaflow.ai.search.VectorSearch;
3436
import org.apache.geaflow.ai.verbalization.Context;
3537
import org.apache.geaflow.ai.verbalization.SubgraphSemanticPromptFunction;
@@ -38,6 +40,9 @@
3840
import org.slf4j.Logger;
3941
import org.slf4j.LoggerFactory;
4042

43+
import static org.junit.jupiter.api.Assertions.assertEquals;
44+
import static org.junit.jupiter.api.Assertions.assertNotEquals;
45+
4146
public class GraphMemoryTest {
4247

4348
private static final Logger LOGGER = LoggerFactory.getLogger(GraphMemoryTest.class);
@@ -52,6 +57,55 @@ public void testVectorSearch() {
5257
LOGGER.info(String.valueOf(search));
5358
}
5459

60+
// ========== MagnitudeVector Tests ==========
61+
62+
@Test
63+
public void testMagnitudeVectorConstructorAndGetter() {
64+
MagnitudeVector vector = new MagnitudeVector(0.85);
65+
assertEquals(vector.getMagnitude(), 0.85, 0.0001);
66+
}
67+
68+
@Test
69+
public void testMagnitudeVectorMatchExactSameValue() {
70+
MagnitudeVector v1 = new MagnitudeVector(5.0);
71+
MagnitudeVector v2 = new MagnitudeVector(5.0);
72+
73+
assertEquals(v1.match(v2), 1.0, 0.0001);
74+
}
75+
76+
@Test
77+
public void testMagnitudeVectorMatchDifferentValues() {
78+
MagnitudeVector v1 = new MagnitudeVector(10.0);
79+
MagnitudeVector v2 = new MagnitudeVector(5.0);
80+
81+
// Expected: 1 - |10-5|/max(10,5) = 1 - 5/10 = 0.5
82+
assertEquals(v1.match(v2), 0.5, 0.0001);
83+
}
84+
@Test
85+
public void testMagnitudeVectorEqualsAndHashCode() {
86+
MagnitudeVector v1 = new MagnitudeVector(5.0);
87+
MagnitudeVector v2 = new MagnitudeVector(5.0);
88+
MagnitudeVector v3 = new MagnitudeVector(10.0);
89+
90+
assertEquals(v1, v2);
91+
assertEquals(v1.hashCode(), v2.hashCode());
92+
assertNotEquals(v1, v3);
93+
}
94+
95+
@Test
96+
public void testMagnitudeVectorToString() {
97+
MagnitudeVector vector = new MagnitudeVector(0.75);
98+
String str = vector.toString();
99+
100+
assertEquals(str, "MagnitudeVector{magnitude=0.75}");
101+
}
102+
103+
@Test
104+
public void testMagnitudeVectorGetType() {
105+
MagnitudeVector vector = new MagnitudeVector(1.0);
106+
assertEquals(vector.getType(), VectorType.MagnitudeVector);
107+
}
108+
55109
@Test
56110
public void testEmptyMainPipeline() {
57111
GraphMemoryServer server = new GraphMemoryServer();

0 commit comments

Comments
 (0)