diff --git a/psnr_score_for_challenge.py b/psnr_score_for_challenge.py new file mode 100644 index 0000000..6223db6 --- /dev/null +++ b/psnr_score_for_challenge.py @@ -0,0 +1,89 @@ +import os +import glob +import numpy as np +import json +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from tqdm import tqdm +from moviepy.video.io.VideoFileClip import VideoFileClip +import PIL.Image + + +def calculate_psnr(img1, img2): + mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2) + if mse == 0: + return float('inf') + max_pixel = 255.0 + psnr = 20 * np.log10(max_pixel / np.sqrt(mse)) + return psnr + + +def process_video_psnr(gt_path, pred_path): + try: + clip_gt = VideoFileClip(gt_path) + clip_pred = VideoFileClip(pred_path) + + fps = min(clip_gt.fps, clip_pred.fps) + duration = min(clip_gt.duration, clip_pred.duration) + + time_points = np.arange(0, duration, 1.0 / fps) + + video_psnrs = [] + + for t in time_points: + frame_gt = clip_gt.get_frame(t) + frame_pred = clip_pred.get_frame(t) + + img_gt = PIL.Image.fromarray(frame_gt).resize((256, 256), PIL.Image.Resampling.BILINEAR) + img_pred = PIL.Image.fromarray(frame_pred).resize((256, 256), PIL.Image.Resampling.BILINEAR) + + psnr = calculate_psnr(np.array(img_gt), np.array(img_pred)) + video_psnrs.append(psnr) + + clip_gt.close() + clip_pred.close() + + return np.mean(video_psnrs) if video_psnrs else 0.0 + + except Exception as e: + print(f"Error processing {os.path.basename(gt_path)}: {e}") + return None + + +def main(): + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument('--gt_video', type=str, required=True, help='path to reference videos') + parser.add_argument('--pred_video', type=str, required=True, help='path to pred videos') + parser.add_argument('--output_file', type=str, default=None, help='path to output file') + args = parser.parse_args() + + if not os.path.exists(args.gt_video): + print(f"Error: GT video not found at {args.gt_video}") + return + if not os.path.exists(args.pred_video): + print(f"Error: Pred video not found at {args.pred_video}") + return + + print(f"Comparing:\nRef: {args.gt_video}\nPred: {args.pred_video}") + + v_psnr = process_video_psnr(args.gt_video, args.pred_video) + + if v_psnr is not None: + print("-" * 30) + print(f"Video PSNR: {v_psnr:.4f} dB") + print("-" * 30) + + if args.output_file: + result = { + "gt_video": args.gt_video, + "pred_video": args.pred_video, + "psnr": v_psnr + } + with open(args.output_file, 'w') as f: + json.dump(result, f, indent=4) + print(f"Result saved to {args.output_file}") + else: + print("Failed to calculate PSNR.") + + +if __name__ == '__main__': + main()