# 参考 https://blog.csdn.net/qq_42178122/article/details/122787261 博主的博文
# (Based on the blog post by the author at the URL above.)
import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader
import utils.transforms as tf

# Colours passed to cv2.circle, one per drawn lane (cycled if there are
# more lanes than colours).
color_list = [
    (255, 0, 0),
    (255, 225, 0),
    (255, 0, 255),
    (125, 125, 125),
    (255, 125, 125),
    (0, 125, 0),
]

# Default input video (overridable with --video).
DEFAULT_VIDEO = "/media/gooddz/新加卷/检测视频/极弯场景.mp4"

# Every frame is resized to this size before cropping/inference; lane
# coordinates are reported in this resolution.
FRAME_W, FRAME_H = 1280, 720


def _build_val_transform():
    """Return the preprocessing pipeline applied to each cropped frame.

    Uses ``tf.SampleResize((640, 368))`` followed by a per-channel mean
    subtraction (std of 1). NOTE(review): the (640, 368) argument is
    presumably (width, height); confirm in utils.transforms.
    """
    return torchvision.transforms.Compose([
        tf.SampleResize((640, 368)),
        tf.GroupNormalize(mean=([103.939, 116.779, 123.68], (0, )),
                          std=([1., 1., 1.], (1, ))),
    ])


def _is_short(lane):
    """Return True if *lane* contains no positive x coordinate."""
    return not any(x > 0 for x in lane)


def _fix_gap(coordinate):
    """Linearly interpolate interior gaps of a lane, in place.

    A gap is a run of negative x values strictly between the first and the
    last positive entry. It is filled only when both of its neighbours are
    positive detections. Leading and trailing negatives are left untouched.

    Args:
        coordinate: 1-D sequence (numpy array in practice) of x values,
            negative for rows without a detection.

    Returns:
        *coordinate*, modified in place.
    """
    positive = [i for i, x in enumerate(coordinate) if x > 0]
    if not positive:
        return coordinate
    start, end = positive[0], positive[-1]
    lane = coordinate[start:end + 1]
    if not any(x < 0 for x in lane):
        return coordinate

    n = len(lane)
    idx = 0
    while idx < n:
        if lane[idx] >= 0:
            idx += 1
            continue
        run_start = idx
        while idx < n and lane[idx] < 0:
            idx += 1
        # lane[0] and lane[-1] are positive, so both neighbours exist.
        left, right = run_start - 1, idx
        if lane[left] > 0 and lane[right] > 0:
            width = float(right - left)
            for k in range(run_start, idx):
                lane[k] = int((k - left) / width * lane[right]
                              + (right - k) / width * lane[left])

    if not all(x > 0 for x in lane):
        print("Gaps still exist!")
    # No-op for numpy views; needed if *coordinate* is a list (slice = copy).
    coordinate[start:end + 1] = lane
    return coordinate


def _get_lane(prob_map, y_px_gap, pts, thresh, resize_shape, cut_height):
    """Sample x coordinates of one lane from its probability map.

    Rows are sampled bottom-up every *y_px_gap* pixels of the cropped
    (full height minus *cut_height*) image. In each row, the column with the
    highest probability is kept if it exceeds *thresh*.

    Args:
        prob_map: 2-D (h, w) probability map for a single lane.
        y_px_gap: vertical sampling step in output pixels.
        pts: number of rows to sample.
        thresh: minimum probability for a detection.
        resize_shape: (H, W) of the full output frame.
        cut_height: rows removed from the top of the frame before inference.

    Returns:
        1-D numpy array of length *pts*. It holds x in output pixels, or -1
        where nothing was detected. It is all zeros if fewer than two points
        were found.
    """
    h, w = prob_map.shape
    full_h, out_w = resize_shape
    crop_h = full_h - cut_height
    coords = np.full(pts, -1.0)
    for i in range(pts):
        y = int((crop_h - 10 - i * y_px_gap) * h / crop_h)
        if y < 0:
            break
        row = prob_map[y, :]
        col = np.argmax(row)
        if row[col] > thresh:
            coords[i] = int(col / w * out_w)
    if (coords > 0).sum() < 2:
        coords = np.zeros(pts)
    _fix_gap(coords)
    return coords


def _probmap2lane(seg_pred, num_classes, cut_height, resize_shape=(720, 1280),
                  smooth=True, y_px_gap=10, pts=56, thresh=0.6):
    """Convert per-class segmentation probabilities to lane point lists.

    Args:
        seg_pred: (num_classes, h, w) softmax output for one image; channel 0
            is the background.
        num_classes: number of channels in *seg_pred*, background included.
        cut_height: rows cropped from the top of the frame before inference.
        resize_shape: (H, W) target frame size; None means use the map size.
        smooth: blur each probability map with a 9x9 box filter first.
        y_px_gap: vertical sampling step in output pixels.
        pts: number of points sampled per lane.
        thresh: probability threshold.

    Returns:
        List of lanes. Each lane is a list of ``[x, y]`` pairs in output
        pixels, with x == -1 where there is no detection. If no lane is
        found, a single all-(-1) lane is returned.
    """
    if resize_shape is None:
        resize_shape = seg_pred.shape[1:]
    full_h = resize_shape[0]
    ys = [full_h - 10 - j * y_px_gap for j in range(pts)]

    coordinates = []
    for cls in range(1, num_classes):  # channel 0 is the background
        prob_map = seg_pred[cls]
        if smooth:
            prob_map = cv2.blur(prob_map, (9, 9),
                                borderType=cv2.BORDER_REPLICATE)
        coords = _get_lane(prob_map, y_px_gap, pts, thresh, resize_shape,
                           cut_height)
        if _is_short(coords):
            continue
        coordinates.append([[x, y] if x > 0 else [-1, y]
                            for x, y in zip(coords, ys)])

    if not coordinates:
        coordinates.append([[-1, y] for y in ys])
    return coordinates


def _draw_lanes(img, lanes):
    """Draw each lane's valid points (x > 0 and y > 0) as circles on *img*."""
    for lane_idx, lane in enumerate(lanes):
        color = color_list[lane_idx % len(color_list)]
        for x, y in lane:
            if x <= 0 or y <= 0:
                continue
            cv2.circle(img, (int(x), int(y)), 4, color, 2)


def main():
    """Run lane detection on a video and display the annotated frames.

    Stops at the end of the video, on a read error, or when ``q`` is pressed
    in the display window. Prints a smoothed FPS per frame and the total
    elapsed time at the end.
    """
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)

    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)
    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view
    cfg.work_dirs = osp.join(args.work_dirs, cfg.dataset.train.type)

    cudnn.benchmark = True
    cudnn.fastest = True

    runner = Runner(cfg)
    runner.net.eval()

    cut_height = cfg.cut_height
    transform = _build_val_transform()  # built once, not per frame

    capture = cv2.VideoCapture(args.video)
    if not capture.isOpened():
        raise IOError(f"cannot open video: {args.video}")

    time_start = time.perf_counter()
    fps = 0.0
    try:
        while True:
            t1 = time.perf_counter()
            ok, frame = capture.read()
            if not ok:  # end of stream or read failure: frame is None
                break
            frame = cv2.resize(frame, (FRAME_W, FRAME_H))
            frame_copy = frame.copy()  # full-size copy to draw on
            # Crop exactly cfg.cut_height rows: _get_lane maps sampled rows
            # back to full-frame coordinates assuming this crop.
            cropped = frame[cut_height:, :, :]
            sample = transform((cropped,))
            inp = torch.from_numpy(sample[0]).permute(2, 0, 1).contiguous().float()
            inp = inp.unsqueeze(0).cuda()

            with torch.no_grad():
                output = runner.net(inp)
            seg_pred = F.softmax(output['seg'], dim=1).detach().cpu().numpy()

            for seg in seg_pred:
                lanes = _probmap2lane(seg, cfg.num_classes, cut_height,
                                      resize_shape=(FRAME_H, FRAME_W),
                                      thresh=0.6)
                lanes = [sorted(lane, key=lambda pair: pair[1])
                         for lane in lanes]
                _draw_lanes(frame_copy, lanes)

            fps = (fps + 1.0 / (time.perf_counter() - t1)) / 2
            cv2.imshow('imshow', frame_copy)
            print("fps:", fps)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        capture.release()
        cv2.destroyAllWindows()
    print(time.perf_counter() - time_start)


def parse_args():
    """Parse command-line arguments for the video demo."""
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs', help='work dirs')
    parser.add_argument(
        '--load_from',
        default='/home/llgj/桌面/ldz/resa-main_原/work_dirs/TuSimple/20220120_083126_lr_2e-02_b_4/ckpt/best.pth',
        help='checkpoint to load')
    parser.add_argument(
        '--finetune_from', default=None,
        help='whether to finetune from the checkpoint')
    parser.add_argument(
        '--validate', action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view', action='store_true',
        help='whether to show visualization result')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--video', type=str, default=DEFAULT_VIDEO,
        help='path of the video to run lane detection on')
    return parser.parse_args()


if __name__ == '__main__':
    main()

# Example invocations:
#   configs/tusimple.py --gpus 0
#   configs/tusimple.py --validate --load_from /media/gooddz/学习/culane_resnet50.pth --gpus 0 --view