[Github page]
python 3.8
pytorch 2.0.1
cuda 11.7
torchvision 0.15.2
einops 0.6.1
kornia 0.7.0
xformers 0.0.21
opencv-python
matplotlib
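A quick way to confirm the environment matches the versions above (a minimal sketch; it only prints whatever is installed):

import sys
import torch, torchvision, einops, kornia, xformers, cv2, matplotlib

print("python:", sys.version.split()[0], "| pytorch:", torch.__version__, "| cuda:", torch.version.cuda)
print("torchvision:", torchvision.__version__, "| einops:", einops.__version__, "| kornia:", kornia.__version__)
print("xformers:", xformers.__version__, "| opencv:", cv2.__version__, "| matplotlib:", matplotlib.__version__)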
Code
import warnings
import torch
import torch.nn as nn
from roma.models.matcher import *
from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
from roma.models.encoders import *
def roma_model(resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs):
    # roma weights and dinov2 weights are loaded separately, as dinov2 weights are not parameters
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,
    )
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius=7,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "8": ConvRefiner(
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius=3,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "4": ConvRefiner(
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius=2,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=6,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
        }
    )
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict({
        "16": proj16,
        "8": proj8,
        "4": proj4,
        "2": proj2,
        "1": proj1,
    })
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(
        coordinate_decoder,
        gps,
        proj,
        conv_refiner,
        detach=True,
        scales=["16", "8", "4", "2", "1"],
        displacement_dropout_p=displacement_dropout_p,
        gm_warp_dropout_p=gm_warp_dropout_p,
    )
    encoder = CNNandDinov2(
        cnn_kwargs=dict(pretrained=False, amp=True),
        amp=True,
        use_vgg=True,
        dinov2_weights=dinov2_weights,
    )
    h, w = resolution
    symmetric = True
    attenuate_cert = True
    matcher = RegressionMatcher(
        encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
        symmetric=symmetric, attenuate_cert=attenuate_cert, **kwargs
    ).to(device)
    matcher.load_state_dict(weights)
    return matcher
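For reference, a minimal sketch of calling roma_model directly, assuming the RoMa and DINOv2 checkpoints are already on disk (the file names below are placeholders; in practice roma_indoor / roma_outdoor build the model and load the weights for you):

import torch

# Hypothetical local checkpoint files -- replace with your own paths.
weights = torch.load("roma_outdoor.pth", map_location="cpu")
dinov2_weights = torch.load("dinov2_vitl14_pretrain.pth", map_location="cpu")

model = roma_model(
    resolution=(560, 560),   # coarse matching resolution (H, W), a multiple of the DINOv2 patch size 14
    upsample_preds=False,    # skip the extra upsampling refinement pass
    device="cuda",
    weights=weights,
    dinov2_weights=dinov2_weights,
)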
Code based on the example in the demo folder
from roma import roma_indoor, roma_outdoor
from PIL import Image
import PIL.ImageDraw  # needed for PIL.ImageDraw.Draw below
import numpy as np
import torch
import cv2
# ====================================================================================================== #
def draw_keypoints_on_image(image, keypoints, color='blue', radius=2, use_normalized_coordinates=False):
    """Draws keypoints on an image.
    Args:
        image: a PIL.Image object.
        keypoints: a numpy array with shape [num_keypoints, 2], in (x, y) order.
        color: color to draw the keypoints with. Default is blue.
        radius: keypoint radius. Default value is 2.
        use_normalized_coordinates: if True, treat keypoint values as relative
            to the image size. Otherwise treat them as absolute pixel
            coordinates. Default is False.
    """
    draw = PIL.ImageDraw.Draw(image)
    im_width, im_height = image.size
    keypoints_x = [k[0] for k in keypoints]
    keypoints_y = [k[1] for k in keypoints]
    if use_normalized_coordinates:
        keypoints_x = tuple([im_width * x for x in keypoints_x])
        keypoints_y = tuple([im_height * y for y in keypoints_y])
    for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y):
        draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
                      (keypoint_x + radius, keypoint_y + radius)],
                     outline=color, fill=color)
# ====================================================================================================== #
# ====================================================================================================== #
def draw_line_on_image(image, kptsA, kptsB, color='red', size=0):
    draw = PIL.ImageDraw.Draw(image)
    length = kptsA.shape[0]
    for i in range(length):
        correspondence = [(kptsA[i][0], kptsA[i][1]), (kptsB[i][0], kptsB[i][1])]
        draw.line(correspondence, fill=color, width=size)
# ====================================================================================================== #
# ============================================= MAIN =================================================== #
# ================= Set image path and cuda device ================= #
query_path = "~/src/RoMa/query.png"
cand_path = "~/src/RoMa/cand.png"
device = "cuda"
# ====================== Create RoMa Model ====================== #
# Create Model
roma_model = roma_indoor(device=device)
# ====================== Get Output Resolution ====================== #
# Output: 560 x 560
H, W = roma_model.get_output_resolution()
# ============ Change image size to fit output resolution ============ #
im1 = Image.open(query_path).resize((W, H))
im2 = Image.open(cand_path).resize((W, H))
# ============ Create image to show the correspondence pairs ============ #
# Create a new output image that concatenates the two images together
output_img = Image.new("RGB", (im1.width + im2.width, im1.height))
output_img.paste(im1, (0, 0))
output_img.paste(im2, (im1.width, 0))
# ====================== Match Two Images ====================== #
# Get the dense warp and certainty
# warp size: torch.Size([560, 1120, 4]) & type: <class 'torch.Tensor'>
# certainty size: torch.Size([560, 1120]) & type: <class 'torch.Tensor'>
warp, certainty = roma_model.match(query_path, cand_path, device=device)
# ====================== Sample Matches ====================== #
# Sample matches for estimation
# matches size: torch.Size([10000, 4]) & type: <class 'torch.Tensor'>
# certainty size: torch.Size([10000]) & type: <class 'torch.Tensor'>
matches, certainty = roma_model.sample(warp, certainty)
# ====================== Get Feature Points ====================== #
# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
# kptsA size: torch.Size([10000, 2]) & type: <class 'torch.Tensor'>
# kptsB size: torch.Size([10000, 2]) & type: <class 'torch.Tensor'>
kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H, W, H, W)
# ================== Estimate Fundamental Matrix ================== #
# Find a fundamental matrix (or anything else of interest)
F, mask = cv2.findFundamentalMat(
    kptsA.cpu().numpy(), kptsB.cpu().numpy(),
    ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC,
    confidence=0.999999, maxIters=10000
)
# ================= Convert PIL image channel to RGB ================= #
image1_pil = Image.fromarray(np.uint8(im1)).convert('RGB')
image2_pil = Image.fromarray(np.uint8(im2)).convert('RGB')
# ================= Select inlier points ================= #
# We select only inlier points
kptsA = kptsA[mask.ravel()==1]
kptsB = kptsB[mask.ravel()==1]
# ================= Draw Feature Points ================= #
draw_keypoints_on_image(image1_pil, np.array(kptsA.cpu()))
draw_keypoints_on_image(image2_pil, np.array(kptsB.cpu()))
# use numpy to convert the pil_image into a numpy array
numpy_image1 = np.array(image1_pil)
numpy_image2 = np.array(image2_pil)
# convert to an OpenCV image, i.e. from RGB to BGR format
cv_image1 = cv2.cvtColor(numpy_image1, cv2.COLOR_RGB2BGR)
cv_image2 = cv2.cvtColor(numpy_image2, cv2.COLOR_RGB2BGR)
# ================= Move feature points for plotting correspondence pair ================= #
# Get each image row & column
rows1, cols1 = cv_image1.shape[:2]
rows2, cols2 = cv_image2.shape[:2]
kptsC = np.array(kptsB.cpu()) + [cols1, 0]
# ================= Draw Matching Pairs ================= #
# Draw matching pair
draw_line_on_image(output_img, np.array(kptsA.cpu()), kptsC)
output_img.show()
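The inlier counts reported in the result tables below can be read straight off the RANSAC mask. A small follow-up sketch, using only the variables defined above, that also saves the visualization instead of opening a viewer window:

# mask comes from cv2.findFundamentalMat above: an (N, 1) uint8 array, 1 = inlier.
num_inliers = int(mask.sum())
print(f"# of inliers: {num_inliers}")

# Save the side-by-side correspondence image (file name is arbitrary).
output_img.save("roma_correspondences.png")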
Reference Site
[Python PIL | ImageDraw.Draw.line() - GeeksforGeeks](https://www.geeksforgeeks.org/python-pil-imagedraw-draw-line/)
[Concatenate images with Python, Pillow | note.nkmk.me](https://note.nkmk.me/en/python-pillow-concat-images/)
(Result images omitted: each table paired a query image with candidate images, with inlier keypoints and correspondence lines drawn on them. The recorded inlier counts were:)

# of inliers: 2978 | # of inliers: 512 | # of inliers: 697
# of inliers: 1596 | # of inliers: 312 | # of inliers: 323
# of inliers: 855 | # of inliers: 158 | # of inliers: 175
# of inliers: 146 | # of inliers: 31 | # of inliers: 25
# of inliers: 89 | # of inliers: 15 | # of inliers: 19