目录
sam2 抠图生成png
缩放到512
sam2 抠图生成png
demo_image_png.py
import argparse
import json
import os.path as osp
import time
import numpy as np
import gc
import sys
from PIL import Image

sys.path.append("./sam2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import os
from glob import glob
import supervision as sv
import torch
import cv2
from scipy.spatial import cKDTree


# Substring of the checkpoint path -> config file, checked in this order.
_MODEL_CONFIGS = (
    ("large", "configs/samurai/sam2.1_hiera_l.yaml"),
    ("base_plus", "configs/samurai/sam2.1_hiera_b+.yaml"),
    ("small", "configs/samurai/sam2.1_hiera_s.yaml"),
    ("tiny", "configs/samurai/sam2.1_hiera_t.yaml"),
)


def determine_model_cfg(model_path):
    """Return the config path matching the model size named in *model_path*.

    The size is detected by substring ("large", "base_plus", "small",
    "tiny"), tested in that order against the whole path.

    Raises:
        ValueError: if none of the known size tags occurs in the path.
    """
    for size_tag, cfg_path in _MODEL_CONFIGS:
        if size_tag in model_path:
            return cfg_path
    raise ValueError("Unknown model size in path!")


def top_red_points(img, top_k=5):
    """Return up to *top_k* red pixel positions of a BGR image as (x, y) tuples.

    Returns an empty list when no pixel falls inside the red HSV ranges.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    # Red is covered by two hue ranges (low end and high end of the hue axis).
    mask_low = cv2.inRange(hsv, np.array([0, 100, 100]), np.array([10, 255, 255]))
    mask_high = cv2.inRange(hsv, np.array([170, 100, 100]), np.array([180, 255, 255]))
    mask = cv2.bitwise_or(mask_low, mask_high)

    red_pixels = np.column_stack(np.where(mask > 0))  # rows are (y, x)
    if red_pixels.size == 0:
        return []

    # NOTE(review): the "score" is the mask value itself, which should be the
    # same for every matching pixel, so this ordering does not rank by redness.
    scores = mask[red_pixels[:, 0], red_pixels[:, 1]]
    sorted_idx = np.argsort(scores)[::-1]

    # Swap (y, x) -> (x, y) for the caller.
    return [
        (int(red_pixels[idx][1]), int(red_pixels[idx][0]))
        for idx in sorted_idx[:top_k]
    ]


def min_distance_point_to_contour(point, contours):
    """Return the smallest Euclidean distance from *point* to any contour vertex.

    Args:
        point: (x, y) coordinates.
        contours: list of arrays, each of shape (N, 1, 2).

    Returns:
        The minimum distance, or ``float('inf')`` when there are no vertices.
    """
    min_dist = float('inf')
    for cnt in contours:
        pts = cnt[:, 0, :]  # shape (N, 2)
        if pts.size == 0:
            # An empty contour has no vertices; .min() would raise on it.
            continue
        dists = np.linalg.norm(pts - point, axis=1)
        min_dist = min(min_dist, dists.min())
    return min_dist


class MultiPointDrawer:
    """Interactive window for collecting positive/negative prompt points."""

    # Single definition of the title so every cv2 call targets the same window.
    WINDOW_NAME = "Click Points (L = + , R = - , Enter = OK)"

    def __init__(self, img):
        self.img = img.copy()
        self.display = img.copy()  # copy that the clicked points are drawn on
        self.points = []  # (x, y)
        self.labels = []  # 1 = positive, 0 = negative

        cv2.namedWindow(self.WINDOW_NAME)
        cv2.setMouseCallback(self.WINDOW_NAME, self.mouse_callback)

    def mouse_callback(self, event, x, y, flags, param):
        """Record a click: left button = positive point, right = negative."""
        if event == cv2.EVENT_LBUTTONDOWN:
            # Positive sample, drawn as a red dot (BGR).
            self.points.append((x, y))
            self.labels.append(1)
            cv2.circle(self.display, (x, y), 4, (0, 0, 255), -1)
        elif event == cv2.EVENT_RBUTTONDOWN:
            # Negative sample, drawn as a blue dot (BGR).
            self.points.append((x, y))
            self.labels.append(0)
            cv2.circle(self.display, (x, y), 4, (255, 0, 0), -1)

    def run(self):
        """Show the window until Enter (accept) or Esc (discard) is pressed.

        Returns:
            (points, labels); both lists are empty after Esc.
        """
        while True:
            cv2.imshow(self.WINDOW_NAME, self.display)
            # Mask to the low byte: some platforms set extra high bits.
            key = cv2.waitKey(1) & 0xFF
            if key == 13:  # Enter
                break
            if key == 27:  # Esc: clear the selection and quit
                self.points.clear()
                self.labels.clear()
                break

        cv2.destroyWindow(self.WINDOW_NAME)
        return self.points, self.labels


def _select_single_mask(masks):
    """Return the first 2-D mask from a predictor output of rank 2, 3 or 4."""
    masks = np.asarray(masks)
    if masks.ndim == 2:
        return masks
    if masks.ndim == 3:
        return masks[0]
    if masks.ndim == 4:
        # The original code squeezed axis 1 but never picked a mask here.
        return masks.squeeze(1)[0]
    raise ValueError(f"Unexpected mask array shape: {masks.shape}")


def _build_rgba_cutout(frame_bgr, mask_bool):
    """Return an RGBA uint8 image: RGB from the frame, alpha 255 inside the mask."""
    height, width = frame_bgr.shape[:2]
    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    rgba[:, :, :3] = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    rgba[:, :, 3] = mask_bool.astype(np.uint8) * 255
    return rgba


def main(args):
    """Segment each input image from user-clicked points and save a cut-out PNG.

    For every image: collect prompt points interactively, run the SAM2 image
    predictor, write an RGBA PNG (named after the input file) into the current
    directory, and show the mask overlaid on the image.
    """
    model_cfg = determine_model_cfg(args.model_path)
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    sam2_image_predictor = SAM2ImagePredictor(
        build_sam2(model_cfg, args.model_path, device=device)
    )
    # sam2_image_predictor.set_image_size(1024)

    # dir_a=r"D:\data\pred_res\1107_2107"
    # files=glob(os.path.join(dir_a, "*.jpg"))
    files = glob(r"D:\project_2025\live2d\mm01.jpg")

    for img_path in files:
        new_img_name = os.path.basename(img_path)
        # fromfile + imdecode is kept from the original reading code.
        frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
        if frame is None:
            print(f"Could not decode image, skipping: {img_path}")
            continue

        # SAM2ImagePredictor.set_image documents its input as RGB, while
        # OpenCV decodes to BGR.
        sam2_image_predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        drawer = MultiPointDrawer(frame)
        clicked_points, clicked_labels = drawer.run()
        if len(clicked_points) == 0:
            print("⚠ 未选择任何点,跳过此图")
            continue

        point_coords = np.array(clicked_points, dtype=np.float32)
        point_labels = np.array(clicked_labels, dtype=np.int32)
        masks, _scores, _logits = sam2_image_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=None,
            multimask_output=False,
        )

        mask_bool = _select_single_mask(masks) > 0

        # Paint the mask green in one vectorized assignment instead of one
        # cv2.circle call per mask pixel.
        vis = frame.copy()
        vis[mask_bool] = (0, 255, 0)

        # Name the output after the input so several images do not overwrite
        # each other (mm01.jpg still produces mm01.png).
        out_name = os.path.splitext(new_img_name)[0] + ".png"
        Image.fromarray(_build_rgba_cutout(frame, mask_bool)).save(out_name, 'PNG')

        cv2.imshow("Mask Points", vis)
        cv2.waitKey(0)

    cv2.destroyAllWindows()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_large.pt",)
    parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_small.pt",)
    # parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_base_plus.pt",
    parser.add_argument("--save_to_video", default=True, help="Save results to a video.")
    args = parser.parse_args()
    main(args)

缩放到512
resize_5_512.py
import cv2
import numpy as np
from PIL import Image


def _fit_within(width, height, target_size):
    """Return (new_width, new_height) with the longer side equal to target_size.

    Integer arithmetic is used on purpose: ``int(side * (target / longest))``
    can round down to ``target - 1`` for some sizes because of floating-point
    error (e.g. a longest side of 1568 px), and truncates the shorter side to 0
    for extreme aspect ratios.
    """
    if width <= 0 or height <= 0 or target_size <= 0:
        raise ValueError(
            f"Invalid size: image {width}x{height}, target {target_size}"
        )
    longest = max(width, height)
    new_width = max(1, width * target_size // longest)
    new_height = max(1, height * target_size // longest)
    return new_width, new_height


def resize_with_padding_to_512(input_path, output_path, target_size=512):
    """Scale an image so its longest side is *target_size* and pad it to a square.

    The aspect ratio is kept and the shorter dimension is centred on a fully
    transparent canvas. The result is written to *output_path* as PNG.

    Args:
        input_path: path of the input image.
        output_path: path of the square PNG to write.
        target_size: side length of the output square (default 512).

    Returns:
        The resulting RGBA image as a numpy array of shape
        (target_size, target_size, 4).
    """
    # Convert inside the context manager so the source file handle is closed;
    # convert() returns an independent RGBA copy.
    with Image.open(input_path) as source:
        rgba = source.convert("RGBA")

    new_width, new_height = _fit_within(rgba.width, rgba.height, target_size)
    resized = rgba.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Transparent square canvas with the resized image pasted in the centre.
    canvas = Image.new("RGBA", (target_size, target_size), (0, 0, 0, 0))
    x_offset = (target_size - new_width) // 2
    y_offset = (target_size - new_height) // 2
    canvas.paste(resized, (x_offset, y_offset))

    canvas.save(output_path, "PNG")
    return np.array(canvas)


def resize_with_padding_to_512_opencv(input_path, output_path, target_size=512):
    """OpenCV variant of :func:`resize_with_padding_to_512`.

    Args:
        input_path: path of the input image.
        output_path: path of the square PNG to write.
        target_size: side length of the output square (default 512).

    Returns:
        The resulting RGBA image as a numpy array of shape
        (target_size, target_size, 4).

    Raises:
        OSError: if OpenCV cannot read the file.
        ValueError: if the image has an unsupported channel count.
    """
    # IMREAD_UNCHANGED keeps the alpha channel when the file has one.
    img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise OSError(f"Could not read image (missing or unsupported): {input_path}")

    if img.dtype == np.uint16:
        # Scale 16-bit samples to 8 bits; writing them into the uint8 canvas
        # unchanged would wrap around.
        img = (img // 257).astype(np.uint8)

    # Normalise every supported layout to RGBA. The original left 3-channel
    # input in BGR order, which swapped red and blue in the saved file.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGBA)
    elif img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
    else:
        raise ValueError(f"Unsupported channel count: {img.shape[2]}")

    height, width = img.shape[:2]
    new_width, new_height = _fit_within(width, height, target_size)
    resized = cv2.resize(
        img, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4
    )

    # Transparent square canvas with the resized image placed in the centre.
    canvas = np.zeros((target_size, target_size, 4), dtype=np.uint8)
    x_offset = (target_size - new_width) // 2
    y_offset = (target_size - new_height) // 2
    canvas[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized

    # The canvas is already RGBA; save through PIL to keep transparency.
    Image.fromarray(canvas).save(output_path, "PNG")
    return canvas


def batch_process_pngs(input_folder, output_folder, max_size=512):
    """Resize every PNG in *input_folder* into *output_folder*.

    Failures are reported per file and do not stop the batch.

    Args:
        input_folder: folder searched for ``*.png`` / ``*.PNG`` files.
        output_folder: destination folder, created if missing.
        max_size: side length of the square outputs (default 512).

    Returns:
        The number of PNG files found (including ones that failed).
    """
    import os
    from glob import glob

    os.makedirs(output_folder, exist_ok=True)

    # On case-insensitive filesystems both patterns match the same files, so
    # de-duplicate; sorting makes the processing order deterministic.
    png_files = sorted(
        set(glob(os.path.join(input_folder, "*.png")))
        | set(glob(os.path.join(input_folder, "*.PNG")))
    )

    for png_path in png_files:
        filename = os.path.basename(png_path)
        output_path = os.path.join(output_folder, filename)
        try:
            resize_with_padding_to_512(png_path, output_path, target_size=max_size)
        except Exception as e:
            # Best-effort batch: report the failure and continue.
            print(f"✗ 处理失败 {filename}: {str(e)}")
        else:
            print(f"✓ 已处理: {filename} -> {output_path}")

    return len(png_files)


# Usage example
if __name__ == "__main__":
    # Single-image example
    input_path = r"mm01.png"
    output_path = r"mm01_512.png"
    resized_img = resize_with_padding_to_512(input_path, output_path)
    print(f"图像已保存到: {output_path}")
    print(f"图像尺寸: {resized_img.shape}")

    # # Batch example
    # input_folder = r"D:\project_2025\live2d\transparent_outputs"
    # output_folder = r"D:\project_2025\live2d\resized_512_outputs"
    # num_processed = batch_process_pngs(input_folder, output_folder)
    # print(f"批量处理完成,共处理 {num_processed} 个文件")