Skip to content

Commit

Permalink
Improve variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Aug 31, 2024
1 parent 17a5e13 commit e207ad7
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions src/transformers/models/llava/image_processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,40 +177,40 @@ def pad_to_square(
Returns:
`np.ndarray`: The padded image.
"""
h, w = get_image_size(image, input_data_format)
c = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
height, width = get_image_size(image, input_data_format)
num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]

if h == w:
if height == width:
return image

max_dim = max(h, w)
max_dim = max(height, width)

# Ensure background_color is the correct shape
if isinstance(background_color, (int, float)):
background_color = [background_color] * c
elif len(background_color) != c:
raise ValueError(f"background_color must have {c} elements to match the number of channels")
background_color = [background_color] * num_channels
elif len(background_color) != num_channels:
raise ValueError(f"background_color must have {num_channels} elements to match the number of channels")

if input_data_format == ChannelDimension.FIRST:
result = np.zeros((c, max_dim, max_dim), dtype=image.dtype)
result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
for i, color in enumerate(background_color):
result[i, :, :] = color
if w > h:
start = (max_dim - h) // 2
result[:, start : start + h, :] = image
if width > height:
start = (max_dim - height) // 2
result[:, start : start + height, :] = image
else:
start = (max_dim - w) // 2
result[:, :, start : start + w] = image
start = (max_dim - width) // 2
result[:, :, start : start + width] = image
else:
result = np.zeros((max_dim, max_dim, c), dtype=image.dtype)
result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
for i, color in enumerate(background_color):
result[:, :, i] = color
if w > h:
start = (max_dim - h) // 2
result[start : start + h, :, :] = image
if width > height:
start = (max_dim - height) // 2
result[start : start + height, :, :] = image
else:
start = (max_dim - w) // 2
result[:, start : start + w, :] = image
start = (max_dim - width) // 2
result[:, start : start + width, :] = image

return result

Expand Down

0 comments on commit e207ad7

Please sign in to comment.