Skip to content

Commit

Permalink
Merge pull request #7437 from radarhere/apng
Browse files Browse the repository at this point in the history
  • Loading branch information
hugovk committed Oct 13, 2023
2 parents 5666c05 + c9ba107 commit 101154e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
10 changes: 8 additions & 2 deletions Tests/test_file_apng.py
Expand Up @@ -673,10 +673,16 @@ def test_seek_after_close():


@pytest.mark.parametrize("mode", ("RGBA", "RGB", "P"))
def test_different_modes_in_later_frames(mode, tmp_path):
@pytest.mark.parametrize("default_image", (True, False))
def test_different_modes_in_later_frames(mode, default_image, tmp_path):
test_file = str(tmp_path / "temp.png")

im = Image.new("L", (1, 1))
im.save(test_file, save_all=True, append_images=[Image.new(mode, (1, 1))])
im.save(
test_file,
save_all=True,
default_image=default_image,
append_images=[Image.new(mode, (1, 1))],
)
with Image.open(test_file) as reloaded:
assert reloaded.mode == mode
13 changes: 4 additions & 9 deletions src/PIL/PngImagePlugin.py
Expand Up @@ -1105,10 +1105,7 @@ def _write_multiple_frames(im, fp, chunk, rawmode, default_image, append_images)
if im_frame.mode == rawmode:
im_frame = im_frame.copy()
else:
if rawmode == "P":
im_frame = im_frame.convert(rawmode, palette=im.palette)
else:
im_frame = im_frame.convert(rawmode)
im_frame = im_frame.convert(rawmode)
encoderinfo = im.encoderinfo.copy()
if isinstance(duration, (list, tuple)):
encoderinfo["duration"] = duration[frame_count]
Expand Down Expand Up @@ -1167,6 +1164,8 @@ def _write_multiple_frames(im, fp, chunk, rawmode, default_image, append_images)

# default image IDAT (if it exists)
if default_image:
if im.mode != rawmode:
im = im.convert(rawmode)
ImageFile._save(im, _idat(fp, chunk), [("zip", (0, 0) + im.size, 0, rawmode)])

seq_num = 0
Expand Down Expand Up @@ -1228,11 +1227,7 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False):
)
modes = set()
append_images = im.encoderinfo.get("append_images", [])
if default_image:
chain = itertools.chain(append_images)
else:
chain = itertools.chain([im], append_images)
for im_seq in chain:
for im_seq in itertools.chain([im], append_images):
for im_frame in ImageSequence.Iterator(im_seq):
modes.add(im_frame.mode)
for mode in ("RGBA", "RGB", "P"):
Expand Down

0 comments on commit 101154e

Please sign in to comment.