Skip to content

Commit 63ac36e

Browse files
Copilotalexarje
andauthored
Fix from_numpy path bug and add numpy/array tests
Agent-Logs-Url: https://github.com/fourMs/MGT-python/sessions/804afced-8cfa-4b53-b671-348fedf3ab55 Co-authored-by: alexarje <114316+alexarje@users.noreply.github.com>
1 parent f7fd040 commit 63ac36e

4 files changed

Lines changed: 194 additions & 8 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ tests/pytest.ini
3737
tests/htmlcov
3838
musicalgestures/pose/body_25/pose_iter_584000.caffemodel
3939
musicalgestures/pose/body_25/.wget-hsts
40+
dancer_grid.png

musicalgestures/_video.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,14 @@ def __repr__(self):
276276
return f"MgVideo('{self.filename}')"
277277

278278
def numpy(self):
279-
"Pipe all video frames from FFmpeg to numpy array"
279+
"""
280+
Read all video frames into a numpy array using FFmpeg.
280281
282+
Returns:
283+
tuple: A tuple ``(array, fps)`` where ``array`` is a ``numpy.ndarray``
284+
of shape ``(N, H, W, 3)`` in BGR format (uint8) containing all N
285+
frames, and ``fps`` is the frame rate of the video.
286+
"""
281287
# Define ffmpeg command and load all the video frames in memory
282288
cmd = ["ffmpeg", "-y", "-i", self.filename]
283289
process = ffmpeg_cmd(cmd, total_time=self.length, pipe="load")
@@ -289,13 +295,26 @@ def numpy(self):
289295
return array, self.fps
290296

291297
def from_numpy(self, array, fps, target_name=None):
292-
if target_name is not None:
293-
self.filename = os.path.splitext(target_name)[0] + self.fex
298+
"""
299+
Writes a numpy array of video frames to a video file using FFmpeg.
294300
295-
if self.path is not None:
296-
target_name = os.path.join(self.path, self.filename)
301+
After writing, updates ``self.filename``, ``self.of``, and ``self.fex`` to
302+
reflect the actual output path so that subsequent operations on this object
303+
refer to the newly created file.
304+
305+
Args:
306+
array (np.ndarray): Video frames array with shape (N, H, W, 3) in BGR format.
307+
fps (float): Frames per second for the output video.
308+
target_name (str, optional): Full path for the output file. If None, uses
309+
``self.path/self.filename`` (or just ``self.filename`` if path is None).
310+
Defaults to None.
311+
"""
312+
if target_name is not None:
313+
write_path = os.path.splitext(target_name)[0] + self.fex
314+
elif self.path is not None:
315+
write_path = os.path.join(self.path, self.filename)
297316
else:
298-
target_name = self.filename
317+
write_path = self.filename
299318

300319
process = None
301320
for frame in array:
@@ -319,14 +338,16 @@ def from_numpy(self, array, fps, target_name=None):
319338
"libx264",
320339
"-pix_fmt",
321340
"yuv420p",
322-
target_name,
341+
write_path,
323342
]
324343
process = ffmpeg_cmd(cmd, total_time=array.shape[0], pipe="write")
325344
process.stdin.write(frame.astype(np.uint8))
326345
process.stdin.close()
327346
process.wait()
328347

329-
return
348+
# Update self.filename to the actual written path so that get_video() can find the file
349+
self.filename = write_path
350+
self.of, self.fex = os.path.splitext(write_path)
330351

331352
def extract_frame(self, **kwargs):
332353
"""
-464 KB
Binary file not shown.

tests/test_numpy.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Tests for numpy array read/write and memory-based processing flow (issue #294)."""
2+
import os
3+
import numpy as np
4+
import pytest
5+
import musicalgestures
6+
from musicalgestures._utils import extract_subclip
7+
8+
9+
@pytest.fixture(scope="module")
10+
def testvideo_avi(tmp_path_factory):
11+
target_name = os.path.join(str(tmp_path_factory.mktemp("data")), "testvideo.avi")
12+
return extract_subclip(musicalgestures.examples.dance, 5, 6, target_name=target_name)
13+
14+
15+
class Test_MgVideo_numpy:
16+
"""Tests for MgVideo.numpy() – read video frames as numpy array."""
17+
18+
def test_returns_tuple(self, testvideo_avi):
19+
mg = musicalgestures.MgVideo(testvideo_avi)
20+
result = mg.numpy()
21+
assert isinstance(result, tuple)
22+
assert len(result) == 2
23+
24+
def test_array_shape(self, testvideo_avi):
25+
mg = musicalgestures.MgVideo(testvideo_avi)
26+
array, fps = mg.numpy()
27+
# shape should be (N_frames, height, width, 3)
28+
assert array.ndim == 4
29+
assert array.shape[1] == mg.height
30+
assert array.shape[2] == mg.width
31+
assert array.shape[3] == 3
32+
33+
def test_array_dtype(self, testvideo_avi):
34+
mg = musicalgestures.MgVideo(testvideo_avi)
35+
array, fps = mg.numpy()
36+
assert array.dtype == np.uint8
37+
38+
def test_fps_matches(self, testvideo_avi):
39+
mg = musicalgestures.MgVideo(testvideo_avi)
40+
array, fps = mg.numpy()
41+
assert fps == mg.fps
42+
43+
def test_frame_count(self, testvideo_avi):
44+
mg = musicalgestures.MgVideo(testvideo_avi)
45+
array, fps = mg.numpy()
46+
from musicalgestures._utils import get_framecount
47+
expected_frames = get_framecount(testvideo_avi)
48+
assert array.shape[0] == expected_frames
49+
50+
51+
class Test_MgAudio_numpy:
52+
"""Tests for MgAudio.numpy() – read audio as numpy array."""
53+
54+
def test_returns_array(self, testvideo_avi):
55+
mg = musicalgestures.MgVideo(testvideo_avi)
56+
result = mg.audio.numpy()
57+
assert isinstance(result, np.ndarray)
58+
59+
def test_array_1d(self, testvideo_avi):
60+
mg = musicalgestures.MgVideo(testvideo_avi)
61+
result = mg.audio.numpy()
62+
assert result.ndim == 1
63+
64+
def test_sample_rate_set(self, testvideo_avi):
65+
mg = musicalgestures.MgVideo(testvideo_avi)
66+
mg.audio.numpy()
67+
assert mg.audio.sr > 0
68+
69+
def test_array_length_matches_duration(self, testvideo_avi):
70+
mg = musicalgestures.MgVideo(testvideo_avi)
71+
result = mg.audio.numpy()
72+
# Audio duration = n_samples / sr, should be roughly 1 second (we extracted 5-6 s)
73+
duration = len(result) / mg.audio.sr
74+
assert 0.5 < duration < 2.0
75+
76+
77+
class Test_MgVideo_from_numpy:
78+
"""Tests for creating MgVideo from a numpy array (via __init__ array parameter)."""
79+
80+
def test_init_with_array_no_path(self, testvideo_avi, tmp_path):
81+
mg = musicalgestures.MgVideo(testvideo_avi)
82+
array, fps = mg.numpy()
83+
out_file = str(tmp_path / "from_arr.avi")
84+
new_mg = musicalgestures.MgVideo(
85+
filename=out_file,
86+
array=array[:30],
87+
fps=fps,
88+
)
89+
assert os.path.isfile(new_mg.filename)
90+
assert new_mg.fps == fps
91+
assert new_mg.width == array.shape[2]
92+
assert new_mg.height == array.shape[1]
93+
94+
def test_init_with_array_and_path(self, testvideo_avi, tmp_path):
95+
mg = musicalgestures.MgVideo(testvideo_avi)
96+
array, fps = mg.numpy()
97+
new_mg = musicalgestures.MgVideo(
98+
filename="arr_output.avi",
99+
array=array[:30],
100+
fps=fps,
101+
path=str(tmp_path),
102+
)
103+
expected_path = os.path.join(str(tmp_path), "arr_output.avi")
104+
assert new_mg.filename == expected_path
105+
assert os.path.isfile(new_mg.filename)
106+
107+
def test_from_numpy_direct_call(self, testvideo_avi, tmp_path):
108+
mg = musicalgestures.MgVideo(testvideo_avi)
109+
array, fps = mg.numpy()
110+
target = str(tmp_path / "direct.avi")
111+
mg.from_numpy(array[:30], fps, target_name=target)
112+
assert os.path.isfile(target)
113+
114+
def test_roundtrip_frame_count(self, testvideo_avi, tmp_path):
115+
"""Array written to disk should have the same number of frames."""
116+
mg = musicalgestures.MgVideo(testvideo_avi)
117+
array, fps = mg.numpy()
118+
n_frames = 20
119+
out_file = str(tmp_path / "roundtrip.avi")
120+
new_mg = musicalgestures.MgVideo(
121+
filename=out_file,
122+
array=array[:n_frames],
123+
fps=fps,
124+
)
125+
from musicalgestures._utils import get_framecount
126+
assert get_framecount(new_mg.filename) == n_frames
127+
128+
129+
class Test_mg_grid_return_array:
130+
"""Tests for mg_grid() memory-based flow (return_array=True)."""
131+
132+
def test_return_array_type(self, testvideo_avi):
133+
mg = musicalgestures.MgVideo(testvideo_avi)
134+
result = mg.grid(height=100, rows=2, cols=2, return_array=True)
135+
assert isinstance(result, np.ndarray)
136+
137+
def test_return_array_dtype(self, testvideo_avi):
138+
mg = musicalgestures.MgVideo(testvideo_avi)
139+
result = mg.grid(height=100, rows=2, cols=2, return_array=True)
140+
assert result.dtype == np.uint8
141+
142+
def test_return_array_shape(self, testvideo_avi):
143+
mg = musicalgestures.MgVideo(testvideo_avi)
144+
rows, cols, height = 2, 3, 100
145+
result = mg.grid(height=height, rows=rows, cols=cols, return_array=True)
146+
assert result.ndim == 3
147+
assert result.shape[0] == height * rows
148+
assert result.shape[2] == 3 # RGB channels
149+
150+
def test_no_file_written(self, testvideo_avi, tmp_path):
151+
"""return_array=True should not write any file to disk."""
152+
mg = musicalgestures.MgVideo(testvideo_avi)
153+
of = os.path.splitext(testvideo_avi)[0]
154+
expected_file = of + "_grid.png"
155+
if os.path.exists(expected_file):
156+
os.remove(expected_file)
157+
mg.grid(height=100, rows=2, cols=2, return_array=True)
158+
assert not os.path.exists(expected_file)
159+
160+
def test_return_mgimage_when_no_array(self, testvideo_avi):
161+
mg = musicalgestures.MgVideo(testvideo_avi)
162+
result = mg.grid(height=100, rows=2, cols=2, return_array=False)
163+
assert isinstance(result, musicalgestures.MgImage)
164+
assert os.path.isfile(result.filename)

0 commit comments

Comments
 (0)