#!/usr/bin/env python3

# Copyright (c) 2017-2023 California Institute of Technology ("Caltech"). U.S.
# Government sponsorship acknowledged. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

r'''Stereo processing

SYNOPSIS

  $ mrcal-stereo                        \
     --az-fov-deg               90      \
     --el-fov-deg               70      \
     --pixels-per-deg           -0.5    \
     --disparity-range          0 100   \
     --sgbm-block-size          5       \
     --sgbm-p1                  600     \
     --sgbm-p2                  2400    \
     --sgbm-uniqueness-ratio    5       \
     --sgbm-disp12-max-diff     1       \
     --sgbm-speckle-window-size 200     \
     --sgbm-speckle-range       2       \
     --outdir /tmp                      \
     left.cameramodel right.cameramodel \
     left.jpg         right.jpg

  Processing left.jpg and right.jpg
  Wrote '/tmp/rectified0.cameramodel'
  Wrote '/tmp/rectified1.cameramodel'
  Wrote '/tmp/left-rectified.png'
  Wrote '/tmp/right-rectified.png'
  Wrote '/tmp/left-disparity.png'
  Wrote '/tmp/left-range.png'
  Wrote '/tmp/points-cam0.vnl'

Given a pair of calibrated cameras and pairs of images captured by these
cameras, this tool runs the whole stereo processing sequence to produce
disparity and range images and a point cloud array.

mrcal functions are used to construct the rectified system. Currently only the
OpenCV SGBM routine is available to perform stereo matching, but more options
will be made available with time.

The commandline arguments to configure the SGBM matcher (--sgbm-...) map to the
corresponding OpenCV APIs. Omitting an --sgbm-... argument will result in the
defaults being used in the cv2.StereoSGBM_create() call. Usually the
cv2.StereoSGBM_create() defaults are terrible, and produce a disparity map that
isn't great. The --sgbm-... arguments in the synopsis above are a good start to
get usable stereo.

The rectified system is constructed with the axes

- x: from the origin of the first camera to the origin of the second camera (the
  baseline direction)

- y: completes the system from x,z

- z: the mean "forward" direction of the two input cameras, with the component
  parallel to the baseline subtracted off

The active window in this system is specified using a few parameters. These
refer to

- the "azimuth" (or "az"): the direction along the baseline: rectified x axis

- the "elevation" (or "el"): the direction across the baseline: rectified y axis

The rectified field of view is given by the arguments --az-fov-deg and
--el-fov-deg. At this time there's no auto-detection logic, and these must be
given. Changing these is a "zoom" operation.

To pan the stereo system, pass --az0-deg and/or --el0-deg. These specify the
center of the rectified images, and are optional.

Finally, the resolution of the rectified images is given with --pixels-per-deg.
This is optional, and defaults to the resolution of the first input image. If we
want to scale the input resolution, pass a value <0. For instance, to generate
rectified images at half the first-input-image resolution, pass
--pixels-per-deg=-0.5. Note that the Python argparse has a problem with negative
numbers, so "--pixels-per-deg -0.5" does not work.

The input images are specified by a pair of globs, so we can process many images
with a single call. Each glob is expanded, and the filenames are sorted. The
resulting lists of files are assumed to match up.

There are several modes of operation:

- No images given: we compute the rectified system only, writing the models to
  disk

- No --viz argument given: we compute the rectified system and the disparity,
  and we write all output as images on disk

- --viz geometry: we compute the rectified system, and display its geometry as a
  plot. No rectification is computed, and the images aren't used, and don't need
  to be passed in

- --viz stereo: compute the rectified system and the disparity. We don't write
  anything to disk initially, but we invoke an interactive visualization tool to
  display the results. Requires pyFLTK (homepage: https://pyfltk.sourceforge.io)
  and GL_image_display (homepage: https://github.com/dkogan/GL_image_display)

It is often desired to compute dense stereo for lots of images in bulk. To make
this go faster, this tool supports the -j JOBS option. This works just like in
Make: the work will be parallelized among JOBS simultaneous processes. Unlike
make, the JOBS value must be specified.

'''

import sys
import argparse
import re
import os

def parse_args():
    r'''Parse and validate the commandline

    Reads sys.argv via argparse, then applies the cross-argument checks that
    argparse cannot express on its own:

    - exactly 0 or exactly 2 image globs must be given
    - --pixels-per-deg is parsed from its "RES" or "RESAZ,RESEL" string form
      into a list of 1 or 2 nonzero floats; if omitted it becomes (-1, -1)
    - --az-fov-deg and --el-fov-deg are required unless --already-rectified
    - --clahe requires --force-grayscale

    Returns the argparse namespace. On any validation failure a message is
    printed to stderr and the process exits with status 1

    '''

    def positive_float(string):
        # argparse "type" callback: accept only floating-point values > 0
        try:
            value = float(string)
        except ValueError:
            value = None
        if value is None or value <= 0:
            print(f"argument MUST be a positive floating-point number. Got '{string}'",
                  file=sys.stderr)
            sys.exit(1)
        return value
    def positive_int(string):
        # argparse "type" callback: accept only integer values > 0
        try:
            value = int(string)
        except ValueError:
            value = None
        if value is None or \
           value <= 0 or abs(value-float(string)) > 1e-6:
            print(f"argument MUST be a positive integer. Got '{string}'",
                  file=sys.stderr)
            sys.exit(1)
        return value


    parser = \
        argparse.ArgumentParser(description = __doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)


    ######## geometry and rectification system parameters
    parser.add_argument('--az-fov-deg',
                        type=float,
                        help='''The field of view in the azimuth direction, in
                        degrees. There's no auto-detection at this time, so this
                        argument is required (unless --already-rectified)''')
    parser.add_argument('--el-fov-deg',
                        type=float,
                        help='''The field of view in the elevation direction, in
                        degrees. There's no auto-detection at this time, so this
                        argument is required (unless --already-rectified)''')
    parser.add_argument('--az0-deg',
                        default = None,
                        type=float,
                        help='''The azimuth center of the rectified images. "0"
                        means "the horizontal center of the rectified system is
                        the mean forward direction of the two cameras projected
                        to lie perpendicular to the baseline". If omitted, we
                        align the center of the rectified system with the center
                        of the two cameras' views''')
    parser.add_argument('--el0-deg',
                        default = 0,
                        type=float,
                        help='''The elevation center of the rectified system.
                        "0" means "the vertical center of the rectified system
                        lies along the mean forward direction of the two
                        cameras" Defaults to 0.''')
    parser.add_argument('--pixels-per-deg',
                        help='''The resolution of the rectified images. This is
                        either a whitespace-less, comma-separated list of two
                        values (az,el) or a single value to be applied to both
                        axes. If a resolution of >0 is requested, the value is
                        used as is. If a resolution of <0 is requested, we use
                        this as a scale factor on the resolution of the first
                        input image. For instance, to downsample by a factor of
                        2, pass -0.5. By default, we use -1 for both axes: the
                        resolution of the input image at the center of the
                        rectified system.''')
    parser.add_argument('--rectification',
                        choices=('LENSMODEL_PINHOLE', 'LENSMODEL_LATLON'),
                        default = 'LENSMODEL_LATLON',
                        help='''The lens model to use for rectification.
                        Currently two models are supported: LENSMODEL_LATLON
                        (the default) and LENSMODEL_PINHOLE. Pinhole stereo
                        works badly for wide lenses and suffers from varying
                        angular resolution across the image. LENSMODEL_LATLON
                        rectification uses a transverse equirectangular
                        projection, and does not suffer from these effects. It
                        is thus the recommended model''')

    parser.add_argument('--already-rectified',
                        action='store_true',
                        help='''If given, assume the given models and images
                        already represent a rectified system. This will be
                        checked, and the models will be used as-is if the checks
                        pass''')

    ######## image pre-filtering
    parser.add_argument('--clahe',
                        action='store_true',
                        help='''If given, apply CLAHE equalization to the images
                        prior to the stereo matching. If --already-rectified, we
                        still apply this equalization, if requested. Requires
                        --force-grayscale''')

    parser.add_argument('--force-grayscale',
                        action='store_true',
                        help='''If given, convert the images to grayscale prior
                        to doing anything else with them. By default, read the
                        images in their default format, and pass those
                        posibly-color images to all the processing steps.
                        Required if --clahe''')

    ######## --viz
    parser.add_argument('--viz',
                        choices=('geometry', 'stereo'),
                        default='',
                        help='''If given, we visualize either the rectified
                        geometry or the stereo results. If --viz geometry: we
                        construct the rectified stereo system, but instead of
                        continuing with the stereo processing, we render the
                        geometry of the stereo world; the images are ignored in
                        this mode. If --viz stereo: we launch an interactive
                        graphical tool to examine the rectification and stereo
                        matching results; the Fl_Gl_Image_Widget Python library
                        must be available''')
    parser.add_argument('--axis-scale',
                        type=float,
                        help='''Used if --viz geometry. Scale for the camera
                        axes. By default a reasonable default is chosen (see
                        mrcal.show_geometry() for the logic)''')
    parser.add_argument('--title',
                        type=str,
                        default = None,
                        help='''Used if --viz geometry. Title string for the plot''')
    parser.add_argument('--hardcopy',
                        type=str,
                        help='''Used if --viz geometry. Write the output to
                        disk, instead of making an interactive plot. The output
                        filename is given in the option''')
    parser.add_argument('--terminal',
                        type=str,
                        help=r'''Used if --viz geometry. The gnuplotlib
                        terminal. The default is almost always right, so most
                        people don't need this option''')
    parser.add_argument('--set',
                        type=str,
                        action='append',
                        help='''Used if --viz geometry. Extra 'set' directives
                        to pass to gnuplotlib. May be given multiple times''')
    parser.add_argument('--unset',
                        type=str,
                        action='append',
                        help='''Used if --viz geometry. Extra 'unset'
                        directives to pass to gnuplotlib. May be given multiple
                        times''')


    ######## stereo processing
    parser.add_argument('--force', '-f',
                        action='store_true',
                        default=False,
                        help='''By default existing files are not overwritten. Pass --force to overwrite them
                        without complaint''')
    parser.add_argument('--outdir',
                        default='.',
                        type=lambda d: d if os.path.isdir(d) else \
                        parser.error(f"--outdir requires an existing directory as the arg, but got '{d}'"),
                        help='''Directory to write the output into. If omitted,
                        we user the current directory''')
    parser.add_argument('--tag',
                        help='''String to use in the output filenames.
                        Non-specific output filenames if omitted ''')
    parser.add_argument('--disparity-range',
                        type=int,
                        nargs=2,
                        default=(0,100),
                        help='''The disparity limits to use in the search, in
                        pixels. Two integers are expected: MIN_DISPARITY
                        MAX_DISPARITY. Completely arbitrarily, we default to
                        MIN_DISPARITY=0 and MAX_DISPARITY=100''')
    parser.add_argument('--valid-intrinsics-region',
                        action='store_true',
                        help='''If given, annotate the image with its
                        valid-intrinsics region. This will end up in the
                        rectified images, and make it clear where successful
                        matching shouldn't be expected''')
    parser.add_argument('--range-image-limits',
                        type=positive_float,
                        nargs=2,
                        default=(1,1000),
                        help='''The nearest,furthest range to encode in the range image.
                        Defaults to 1,1000, arbitrarily''')

    parser.add_argument('--stereo-matcher',
                        choices=( 'SGBM', 'ELAS' ),
                        default = 'SGBM',
                        help='''The stereo-matching method. By default we use
                        the "SGBM" method from OpenCV. libelas isn't always
                        available, and must be enabled at compile-time by
                        setting USE_LIBELAS=1 during the build''')

    parser.add_argument('--sgbm-block-size',
                        type=int,
                        default = 5,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, 5 is used''')
    parser.add_argument('--sgbm-p1',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-p2',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-disp12-max-diff',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-pre-filter-cap',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-uniqueness-ratio',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-speckle-window-size',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-speckle-range',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-mode',
                        choices=('SGBM','HH','HH4','SGBM_3WAY'),
                        help='''A parameter for the OpenCV SGBM matcher. Must be
                        one of ('SGBM','HH','HH4','SGBM_3WAY'). If omitted, the
                        OpenCV default (SGBM) is used''')

    parser.add_argument('--write-point-cloud',
                        action='store_true',
                        help='''If given, we write out the point cloud as a .ply
                        file. Each point is reported in the reference coordinate
                        system, colored with the nearest-neighbor color of the
                        camera0 image. This is disabled by default because this
                        is potentially a very large file''')
    parser.add_argument('--jobs', '-j',
                        type=int,
                        required=False,
                        default=1,
                        help='''parallelize the processing JOBS-ways. This is
                        like Make, except you're required to explicitly specify
                        a job count. This applies when processing multiple sets
                        of images with the same set of models''')
    parser.add_argument('models',
                        type=str,
                        nargs = 2,
                        help='''Camera models representing cameras used to
                        capture the images. Both intrinsics and extrinsics are
                        used''')
    parser.add_argument('images',
                        type=str,
                        nargs='*',
                        help='''The image globs to use for the stereo. If
                        omitted, we only write out the rectified models. If
                        given, exactly two image globs must be given''')

    args = parser.parse_args()

    # argparse can't express "0 or 2", so the count is checked here
    if len(args.images) not in (0, 2):
        print("Argument-parsing error: exactly 0 or exactly 2 images should have been given",
              file=sys.stderr)
        sys.exit(1)

    if args.pixels_per_deg is None:
        args.pixels_per_deg = (-1, -1)
    else:
        # Accept "RES" or "RESAZ,RESEL". Every value must parse as a float and
        # must be nonzero. Failures are signalled with an explicit ValueError:
        # float() raises the same type for unparseable input, so one handler
        # covers everything
        try:
            l = [float(x) for x in args.pixels_per_deg.split(',')]
            if not 1 <= len(l) <= 2:
                raise ValueError("expected 1 or 2 comma-separated values")
            if any(x == 0 for x in l):
                raise ValueError("resolution values must be nonzero")
        except ValueError:
            print("""Argument-parsing error:
  --pixels-per-deg requires "RESX,RESY" or "RESXY", where RES... is a value <0 or >0""",
                  file=sys.stderr)
            sys.exit(1)
        args.pixels_per_deg = l

    if not args.already_rectified and \
       (args.az_fov_deg is None or \
        args.el_fov_deg is None ):
        print("""Argument-parsing error:
  --az-fov-deg and --el-fov-deg are required if not --already-rectified""",
              file=sys.stderr)
        sys.exit(1)

    if args.clahe and not args.force_grayscale:
        print("--clahe requires --force-grayscale",
              file=sys.stderr)
        sys.exit(1)

    return args

# Parse the commandline right away. parse_args() exits the process on invalid
# input, so past this point "args" is known to have passed its checks
args = parse_args()

# arg-parsing is done before the imports so that --help works without building
# stuff, so that I can generate the manpages and README






import numpy as np
import numpysane as nps
import glob
import mrcal


# The ELAS matcher is only usable if this build of mrcal exposes
# stereo_matching_libelas. Complain early if it was asked for, but isn't there
if args.stereo_matcher == 'ELAS' and \
   not hasattr(mrcal, 'stereo_matching_libelas'):
    print("ERROR: the ELAS stereo matcher isn't available. libelas must be installed, and enabled at compile-time with USE_LIBELAS=1. Pass '--stereo-matcher SGBM' instead", file=sys.stderr)
    sys.exit(1)

# The interactive visualizer needs two GUI libraries. These are pulled in only
# if --viz stereo, so the other modes work without them. The two imports are
# tried separately so that the error says which one is missing
if args.viz == 'stereo':

    try:
        from fltk import *
    except:
        print("The visualizer needs the pyFLTK tool. See https://pyfltk.sourceforge.io",
              file=sys.stderr)
        sys.exit(1)

    try:
        from Fl_Gl_Image_Widget import *
    except:
        print("The visualizer needs the GL_image_display library. See https://github.com/dkogan/GL_image_display",
              file=sys.stderr)
        sys.exit(1)

# args.pixels_per_deg has either one value (used for both axes) or two values
# (az,el). Taking the first and the last element handles both cases
pixels_per_deg_az = args.pixels_per_deg[ 0]
pixels_per_deg_el = args.pixels_per_deg[-1]

# The tag ends up in filenames, so any run of characters outside of the
# allowed set is collapsed to a single underscore
if args.tag is not None:
    args.tag = re.sub('[^a-zA-Z0-9_+-]+', '_', args.tag)


def openmodel(f):
    r'''Load the camera model in the file f

    Returns the mrcal.cameramodel object. If the model cannot be loaded for any
    reason, a message is printed to stderr and the process exits with status 1

    '''
    try:
        model = mrcal.cameramodel(f)
    except Exception as e:
        print(f"Couldn't load camera model '{f}': {e}",
              file=sys.stderr)
        sys.exit(1)
    return model

# The two input models, in the order given on the commandline
models = list(map(openmodel, args.models))

if args.already_rectified:
    # The given models are claimed to be rectified already. Use them as they
    # are, after checking that claim
    models_rectified = models
    mrcal.stereo._validate_models_rectified(models_rectified)
else:
    models_rectified = \
        mrcal.rectified_system(models,
                               az_fov_deg          = args.az_fov_deg,
                               el_fov_deg          = args.el_fov_deg,
                               az0_deg             = args.az0_deg,
                               el0_deg             = args.el0_deg,
                               pixels_per_deg_az   = pixels_per_deg_az,
                               pixels_per_deg_el   = pixels_per_deg_el,
                               rectification_model = args.rectification)

# For each camera: the transform from its rectified coordinate system to its
# input-camera coordinate system, composed through the reference
Rt_cam0_rect0, Rt_cam1_rect1 = \
    [ mrcal.compose_Rt( model_input    .extrinsics_Rt_fromref(),
                        model_rectified.extrinsics_Rt_toref() ) \
      for model_input,model_rectified in zip(models, models_rectified) ]


# --viz geometry: plot the input cameras, the rectified cameras and the outline
# of the rectified field of view, then exit. No images are read in this mode
if args.viz == 'geometry':
    # Display the geometry of the two cameras in the stereo pair, and of the
    # rectified system
    plot_options = dict(terminal = args.terminal,
                        hardcopy = args.hardcopy,
                        _set     = args.set,
                        unset    = args.unset,
                        wait     = args.hardcopy is None)
    if args.title is not None:
        plot_options['title'] = args.title

    # return_plot_args = True: the plot isn't made here. I get the plot
    # arguments back instead, so that I can add my own data to them below
    data_tuples, plot_options = \
        mrcal.show_geometry( list(models) + list(models_rectified),
                             cameranames      = ( "camera0", "camera1", "camera0-rectified", "camera1-rectified" ),
                             show_calobjects  = False,
                             return_plot_args = True,
                             axis_scale       = args.axis_scale,
                             **plot_options)

    # Sample the perimeter of the rectified camera0 imager, with Nside
    # intervals along each edge. The four pieces below are, in order: the top
    # edge (y=0) with x increasing, the right edge (x=w-1) with y increasing,
    # the bottom edge (y=h-1) with x decreasing, and the left edge (x=0) with y
    # decreasing. The [1:] and [-2::-1] slices skip the corner that the
    # previous piece already ended on
    Nside = 8
    model = models_rectified[0]
    w,h = model.imagersize()
    linspace_w = np.linspace(0,w-1, Nside+1)
    linspace_h = np.linspace(0,h-1, Nside+1)
    # shape (Nloop, 2)
    qloop = \
       nps.transpose( nps.glue( nps.cat( linspace_w,
                                         0*linspace_w),
                                nps.cat( 0*linspace_h[1:] + w-1,
                                         linspace_h[1:]),
                                nps.cat( linspace_w[-2::-1],
                                         0*linspace_w[-2::-1] + h-1),
                                nps.cat( 0*linspace_h[-2::-1],
                                         linspace_h[-2::-1]),
                                axis=-1) )

    # shape (Nloop,3): unit vectors in cam-local coords
    v = mrcal.unproject(np.ascontiguousarray(qloop),
                        *model.intrinsics(),
                        normalize = True)

    # Scale the vectors to a nice-looking length. This is intended to work with
    # the defaults to mrcal.show_geometry(). That function sets the right,down
    # axis lengths to baseline/3. Here I default to a bit less: baseline/4
    baseline = \
        nps.mag(mrcal.compose_Rt( models_rectified[1].extrinsics_Rt_fromref(),
                                  models_rectified[0].extrinsics_Rt_toref())[3,:])
    v *= baseline/4

    # shape (Nloop,3): ref-coord system points
    p = mrcal.transform_point_Rt(model.extrinsics_Rt_toref(), v)

    # shape (Nloop,2,3). Nloop-length different lines with 2 points in each one
    p = nps.xchg(nps.cat(p,p), -2,-3)
    # The second point of each line is the translation row of the rect0-to-ref
    # transform, so each line runs from a perimeter point to that common point
    p[:,1,:] = model.extrinsics_Rt_toref()[3,:]

    data_tuples.append( (p, dict(tuplesize = -3,
                                 _with     = 'lines lt 1',)))

    # gnuplotlib is imported here, so it's needed only in this mode
    import gnuplotlib as gp
    gp.plot(*data_tuples, **plot_options)

    sys.exit()




def write_output_one(tag, outdir, force,
                     func, filename):
    r'''Write one output file, unless it exists already

    The output path is outdir/filename, with "-TAG" inserted before the
    extension if tag is not None. func(path) is called to do the actual
    writing. If the path exists already and force is false, nothing is written,
    and a warning is printed instead

    '''
    suffix   = '' if tag is None else '-' + tag
    stem,ext = os.path.splitext(filename)
    path     = f"{outdir}/{stem}{suffix}{ext}"

    if not force and os.path.exists(path):
        print(f"WARNING: '{path}' already exists. Not overwriting this file. Pass -f to overwrite")
        return

    func(path)
    print(f"Wrote '{path}'")


def write_output_all(args,
                     models,
                     images,
                     image_filenames,
                     models_rectified,
                     images_rectified,
                     disparity,
                     disparity_scale,
                     disparity_colored,
                     ranges,
                     ranges_colored,
                     force_override = False):
    r'''Write all the on-disk products for one pair of images

    Everything is written with write_output_one(), using args.tag and
    args.outdir. Existing files are overwritten only if args.force or
    force_override. The image-specific filenames are based on the input image
    filenames, with the directory and the extension removed.

    We write

    - unless args.already_rectified: the two rectified models and the two
      rectified images
    - the colored disparity image, the disparity as a uint16 image (with
      disparity_scale noted in the filename) and the colored range image
    - if args.write_point_cloud: a binary .ply point cloud, colored from the
      camera0 image

    '''

    overwrite = args.force or force_override

    def emit(writer, filename):
        # all the outputs share the tag, the directory and the overwrite policy
        write_output_one(args.tag, args.outdir, overwrite, writer, filename)

    basenames = \
        [os.path.splitext(os.path.basename(f))[0] for f in image_filenames]

    if not args.already_rectified:
        for i in range(2):
            emit(lambda filename, i=i: models_rectified[i].write(filename),
                 f'rectified{i}.cameramodel')
        for i in range(2):
            emit(lambda filename, i=i: mrcal.save_image(filename, images_rectified[i]),
                 basenames[i] + '-rectified.png')

    emit(lambda filename: mrcal.save_image(filename, disparity_colored),
         basenames[0] + '-disparity.png')
    emit(lambda filename: mrcal.save_image(filename, disparity.astype(np.uint16)),
         basenames[0] + f'-disparity-uint16-scale{disparity_scale}.png')
    emit(lambda filename: mrcal.save_image(filename, ranges_colored),
         basenames[0] + '-range.png')

    if not args.write_point_cloud:
        return

    # The point cloud. This is computed from the ranges in the rectified-cam0
    # coordinate system, and then moved to the ref coordinate system
    p_rect0 = \
        mrcal.stereo_unproject(disparity        = None,
                               models_rectified = models_rectified,
                               ranges           = ranges)

    # shape (H,W,3) collapsed to (N,3)
    p_ref = \
        nps.clump( mrcal.transform_point_Rt(models_rectified[0].extrinsics_Rt_toref(),
                                            p_rect0),
                   n=2 )

    # Each point gets the color of the nearest pixel in the original camera0
    # image, so I project the points back to camera0
    q_cam0 = \
        mrcal.project(mrcal.transform_point_Rt(models[0].extrinsics_Rt_fromref(),
                                               p_ref),
                      *models[0].intrinsics())

    # q_cam0 can contain nan. Those points are thrown out here to avoid ugly
    # warnings later. If either qx or qy is non-finite, the whole point goes
    finite = np.all(np.isfinite(q_cam0), axis=-1)
    p_ref  = p_ref [finite,:]
    q_cam0 = q_cam0[finite,:]

    H,W = images[0].shape[:2]

    # round to the nearest pixel
    q_cam0_integer = (q_cam0 + 0.5).astype(int)

    # Keep only the points that land inside the image. They SHOULD all be
    # in-bounds, but I make sure
    qx = q_cam0_integer[:,0]
    qy = q_cam0_integer[:,1]
    inbounds = (qx >= 0) & (qy >= 0) & (qx < W) & (qy < H)
    p_ref          = p_ref[inbounds]
    q_cam0_integer = q_cam0_integer[inbounds]

    image0 = images[0]
    if len(image0.shape) == 2 and args.force_grayscale:
        # The image is grayscale, but only because grayscale was requested for
        # the stereo processing. Reload it without that request: if the file
        # has color, the point cloud should use it
        image0 = mrcal.load_image(image_filenames[0])

    # shape (N,)  if grayscale or
    #       (N,3) if bgr color
    color = image0[q_cam0_integer[:,1], q_cam0_integer[:,0]]

    emit(lambda filename: \
           mrcal.write_point_cloud_as_ply(filename,
                                          p_ref,
                                          color  = color,
                                          binary = True),
         basenames[0] + '-points.ply')


if not args.images:
    # No images are given. The rectified models are the only thing to write
    # out, and then we're done
    for i in range(2):
        write_output_one(args.tag,
                         args.outdir,
                         args.force,
                         lambda filename, i=i: models_rectified[i].write(filename),
                         f'rectified{i}.cameramodel')
    sys.exit(0)


# The maps are needed only if we're going to be rectifying the images ourselves
if not args.already_rectified:
    rectification_maps = mrcal.rectification_maps(models, models_rectified)




# Expand each of the two globs into a sorted list of filenames. The two lists
# are assumed to correspond to each other, element by element
image_filenames_all = [sorted(glob.glob(os.path.expanduser(g))) for g in args.images]

Nimages = len(image_filenames_all[0])

if Nimages != len(image_filenames_all[1]):
    print(f"First glob matches {Nimages} files, but the second glob matches {len(image_filenames_all[1])} files. These must be identical",
          file=sys.stderr)
    sys.exit(1)

if Nimages == 0:
    print("The given globs matched 0 images. Nothing to do",
          file=sys.stderr)
    sys.exit(1)

# The interactive tool looks at one pair of images only
if Nimages > 1 and args.viz == 'stereo':
    print(f"The given globs matched {Nimages} images. But '--viz stereo' was given, so exactly ONE set of images is expected",
          file=sys.stderr)
    sys.exit(1)


# OpenCV is imported only if something is going to use it: the CLAHE
# equalization or the SGBM matcher
if args.clahe or args.stereo_matcher == 'SGBM':
    import cv2

    if args.clahe:
        clahe = cv2.createCLAHE()
        clahe.setClipLimit(8)


# Extra arguments for mrcal.load_image(). With --force-grayscale we ask for
# 8-bit, single-channel images. Otherwise we take the images as they are
kwargs_load_image = \
    dict(bits_per_pixel = 8,
         channels       = 1) \
    if args.force_grayscale else \
    dict()

# Done with all the preliminaries. Run the stereo matching
disparity_min,disparity_max = args.disparity_range

# Set up the selected stereo matcher. This defines disparity_scale for either
# matcher, and the stereo_sgbm matcher object if we're using SGBM
if args.stereo_matcher == 'SGBM':

    # This is a hard-coded property of the OpenCV StereoSGBM implementation
    disparity_scale = 16

    # OpenCV needs numDisparities > 0. An empty or inverted range can't work,
    # so I say so clearly here instead of letting OpenCV fail later
    if disparity_max <= disparity_min:
        print(f"--disparity-range requires MIN_DISPARITY < MAX_DISPARITY, but got {disparity_min} {disparity_max}",
              file=sys.stderr)
        sys.exit(1)

    # round to nearest multiple of disparity_scale. The OpenCV StereoSGBM
    # implementation requires this. A narrow range would round down to 0, which
    # OpenCV rejects, so I use at least one multiple
    num_disparities = disparity_max - disparity_min
    num_disparities = max(disparity_scale,
                          disparity_scale*round(num_disparities/disparity_scale))

    # I only add non-default args. StereoSGBM_create() doesn't like being given
    # None args
    kwargs_sgbm = dict()
    if args.sgbm_p1 is not None:
        kwargs_sgbm['P1']                = args.sgbm_p1
    if args.sgbm_p2 is not None:
        kwargs_sgbm['P2']                = args.sgbm_p2
    if args.sgbm_disp12_max_diff is not None:
        kwargs_sgbm['disp12MaxDiff']     = args.sgbm_disp12_max_diff
    # --sgbm-pre-filter-cap is accepted by the argument parser, so it must be
    # passed on here. Otherwise the option would be ignored silently
    if args.sgbm_pre_filter_cap is not None:
        kwargs_sgbm['preFilterCap']      = args.sgbm_pre_filter_cap
    if args.sgbm_uniqueness_ratio is not None:
        kwargs_sgbm['uniquenessRatio']   = args.sgbm_uniqueness_ratio
    if args.sgbm_speckle_window_size is not None:
        kwargs_sgbm['speckleWindowSize'] = args.sgbm_speckle_window_size
    if args.sgbm_speckle_range is not None:
        kwargs_sgbm['speckleRange']      = args.sgbm_speckle_range
    if args.sgbm_mode is not None:
        # The cv2 constants are looked up only for the requested mode
        if   args.sgbm_mode == 'SGBM':
            kwargs_sgbm['mode'] = cv2.StereoSGBM_MODE_SGBM
        elif args.sgbm_mode == 'HH':
            kwargs_sgbm['mode'] = cv2.StereoSGBM_MODE_HH
        elif args.sgbm_mode == 'HH4':
            kwargs_sgbm['mode'] = cv2.StereoSGBM_MODE_HH4
        elif args.sgbm_mode == 'SGBM_3WAY':
            kwargs_sgbm['mode'] = cv2.StereoSGBM_MODE_SGBM_3WAY
        else:
            # The argument parser restricts the choices, so this can't happen
            raise Exception("arg-parsing error. This is a bug. Please report")

    # blocksize is required, so I always pass it. There's a default set in
    # the argument parser, so this is never None
    kwargs_sgbm['blockSize'] = args.sgbm_block_size

    stereo_sgbm = \
        cv2.StereoSGBM_create(minDisparity      = disparity_min,
                              numDisparities    = num_disparities,
                              **kwargs_sgbm)
elif args.stereo_matcher == 'ELAS':
    disparity_scale = 1
else:
    raise Exception("Getting here is a bug")

def process(i_image):
    r'''Run the stereo pipeline on the i_image-th pair of input images

    The pair of filenames is looked up in image_filenames_all. The images are
    loaded, checked against the imager dimensions of the camera models,
    optionally equalized and annotated, rectified (unless they already are),
    and stereo-matched with the matcher selected by args.stereo_matcher. The
    disparity and range images are then either shown in the interactive
    visualizer or written to disk.

    Returns False if the inputs were rejected (an error is printed to stderr),
    or True if the pair was processed

    '''

    image_filenames = tuple(image_filenames_all[icam][i_image]
                            for icam in range(2))

    print(f"Processing {os.path.split(image_filenames[0])[1]} and {os.path.split(image_filenames[1])[1]}")

    images = [mrcal.load_image(filename, **kwargs_load_image)
              for filename in image_filenames]

    # The input imagersize isn't used for anything downstream. But if it
    # doesn't match the model, the user probably passed in the wrong thing, so I
    # complain early to save them time
    for icam in range(2):
        imagersize_image = np.array((images[icam].shape[1], images[icam].shape[0]))
        imagersize_model = models[icam].imagersize()
        if np.any(imagersize_image - imagersize_model):
            print(f"Image '{image_filenames[icam]}' dimensions {imagersize_image} don't match the model '{args.models[icam]}' dimensions {imagersize_model}",
                  file=sys.stderr)
            return False

    if args.clahe:
        images = list(map(clahe.apply, images))

    # A 3-dimensional image array has a color channel axis
    if args.stereo_matcher == 'ELAS' and len(images[0].shape) == 3:
        print("The ELAS matcher requires grayscale images. Pass --force-grayscale",
              file=sys.stderr)
        return False

    if args.valid_intrinsics_region:
        for image, model in zip(images, models):
            mrcal.annotate_image__valid_intrinsics_region(image, model)

    if args.already_rectified:
        images_rectified = images
    else:
        images_rectified = [mrcal.transform_image(images[icam],
                                                  rectification_maps[icam])
                            for icam in range(2)]

    if args.stereo_matcher == 'SGBM':

        # opencv barfs if one image is color and the other is monochrome. In
        # that case I feed it a grayscale version of the color one. The
        # images_rectified list itself is left untouched: it is reported to
        # the user later
        ndims = [len(image.shape) for image in images_rectified]
        if ndims == [2, 3] or ndims == [3, 2]:
            images_matcher = [image if ndim == 2 else
                              cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                              for image, ndim in zip(images_rectified, ndims)]
        else:
            images_matcher = images_rectified

        disparity = stereo_sgbm.compute(*images_matcher)

    elif args.stereo_matcher == 'ELAS':

        # Only the first of the returned disparity images is used
        disparity = \
            mrcal.stereo_matching_libelas(*images_rectified,
                                          disparity_min = disparity_min,
                                          disparity_max = disparity_max)[0]

    # The disparity image holds values scaled by disparity_scale, so the
    # color-map limits are scaled to match
    disparity_colored = \
        mrcal.apply_color_map(disparity,
                              a_min = disparity_min*disparity_scale,
                              a_max = disparity_max*disparity_scale)

    ranges = mrcal.stereo_range( disparity,
                                 models_rectified,
                                 disparity_scale = disparity_scale,
                                 disparity_min   = disparity_min)

    range_min, range_max = args.range_image_limits[0], args.range_image_limits[1]
    ranges_colored = mrcal.apply_color_map(ranges,
                                           a_min = range_min,
                                           a_max = range_max)

    if args.viz != 'stereo':
        write_output_all(args,
                         models,
                         images,
                         image_filenames,
                         models_rectified,
                         images_rectified,
                         disparity,
                         disparity_scale,
                         disparity_colored,
                         ranges,
                         ranges_colored)
    else:
        # Earlier I confirmed that in this path Nimages==1
        gui_stereo(images,
                   image_filenames,
                   images_rectified,
                   disparity,
                   disparity_colored,
                   ranges,
                   ranges_colored)

    return True


def gui_stereo(images,
               image_filenames,
               images_rectified,
               disparity,
               disparity_colored,
               ranges,
               ranges_colored):
    r'''Show the stereo result in an interactive FLTK window

    The window contains three image widgets that pan and zoom in unison (the
    two rectified images and a result view showing either the colorized
    disparity or the colorized range), a status text area and a button to
    write the results to disk.

    ARGUMENTS

    - images, image_filenames: the input images and their filenames. Only
      passed on to write_output_all()

    - images_rectified: the pair of rectified images that is displayed

    - disparity: the disparity image. Indexed as [row, column]; values are
      scaled by the module-level disparity_scale

    - disparity_colored, ranges_colored: the images displayed in the result
      widget

    - ranges: the range image. Only passed on to write_output_all()

    The module-level args, models, models_rectified, disparity_scale,
    disparity_min, Rt_cam0_rect0 and Rt_cam1_rect1 are used as well.

    This function does not return: after the FLTK event loop ends, it calls
    sys.exit(0)

    '''

    UI_usage_message = r'''Usage:

    Left mouse button click/drag: pan
    Mouse wheel up/down/left/right: pan
    Ctrl-mouse wheel up/down: zoom
    'u': reset view: zoom out, pan to the center
    'd': show a colorized disparity image (default)
    'r': show a colorized range image

    Right mouse button click: examine stereo at pixel
    TAB: transpose windows
'''

    def set_status(string,
                   q_rect   = None,
                   q_input0 = None,
                   q_input1 = None):
        r'''Update the text in the status widget

        The text is the usage message, followed by the pixel coordinates under
        the cursor, followed by the free-form string. Both pieces are
        remembered as attributes of this function, so passing None for string
        or for q_rect reuses the previously-given value

        '''
        # None means "use previous"
        if string is not None:
            set_status.string = string
        else:
            if not hasattr(set_status, 'string'):
                set_status.string = ''

        if q_rect is not None:
            set_status.qstring = \
                f"q_rectified0 = {q_rect[0]:.1f},{q_rect[1]:.1f}\n"
            if q_input0 is not None:
                set_status.qstring += f"q_input_cam0 = {q_input0[0]:.1f},{q_input0[1]:.1f}\n"
            if q_input1 is not None:
                set_status.qstring += f"q_input_cam1 = {q_input1[0]:.1f},{q_input1[1]:.1f}\n"
            set_status.qstring += "\n"
        else:
            if not hasattr(set_status, 'qstring'):
                set_status.qstring = 'No current pixel\n\n'

        widget_status.value(UI_usage_message + "\n" + set_status.qstring + set_status.string)

    def write_output_interactive(widget):
        r'''Button callback: write all the outputs, overwriting existing files'''
        # cursor-setting needs threading. I don't bother with it yet
        # window.cursor(FL_CURSOR_WAIT)

        write_output_all(args,
                         models,
                         images,
                         image_filenames,
                         models_rectified,
                         images_rectified,
                         disparity,
                         disparity_scale,
                         disparity_colored,
                         ranges, ranges_colored,
                         force_override = True)

        # window.cursor(FL_CURSOR_DEFAULT)

    def process_tab():
        r'''Swap the on-screen positions of the second image and result widgets'''
        x_image1    = widget_image1.x()
        y_image1    = widget_image1.y()
        x_disparity = widget_result.x()
        y_disparity = widget_result.y()

        widget_image1.position(x_disparity, y_disparity)
        widget_result.position(x_image1,    y_image1)


    class Fl_Gl_Image_Widget_Derived(Fl_Gl_Image_Widget):
        r'''Image widget with synchronized pan/zoom and stereo-probing clicks'''

        def set_panzoom(self,
                        x_centerpixel, y_centerpixel,
                        visible_width_pixels):
            r'''Pan/zoom the image

            This is an override of the function to do this: any request to
            pan/zoom the widget will come here first. I dispatch any
            pan/zoom commands to all the widgets, so that they all work in
            unison. visible_width_pixels < 0 means: this is the redirected
            call; just call the base class

            '''
            if visible_width_pixels < 0:
                return super().set_panzoom(x_centerpixel, y_centerpixel,
                                           -visible_width_pixels)

            # All the widgets should pan/zoom together. The negated width marks
            # these calls as the redirected ones. Note that all() stops at the
            # first widget that reports a false result
            result = \
                all( w.set_panzoom(x_centerpixel, y_centerpixel,
                                   -visible_width_pixels) \
                     for w in (widget_image0, widget_image1, widget_result) )

            # Switch back to THIS widget
            self.make_current()
            return result

        def set_cross_at(self, q, no_vertical = False):
            r'''Draw a crosshair through the pixel q, or clear it if q is None

            The lines span the full extent of the rectified image. With
            no_vertical only the horizontal line is drawn, in red; otherwise
            both lines are drawn, in blue

            '''
            if q is None:
                return \
                    self.set_lines()

            x,y = q
            W,H = models_rectified[0].imagersize()
            if no_vertical:
                points = np.array( ((( -0.5, y),
                                     (W-0.5, y),),),
                                   dtype=np.float32)
                color_rgb = np.array((1,0,0),
                                     dtype=np.float32)
            else:
                points = np.array( ((( -0.5, y),
                                     (W-0.5, y)),
                                    ((x,       -0.5),
                                     (x,      H-0.5))),
                                   dtype=np.float32)
                color_rgb = np.array((0,0,1),
                                     dtype=np.float32)

            return \
                self.set_lines( dict(points    = points,
                                     color_rgb = color_rgb ))

        def handle(self, event):
            r'''FLTK event handler

            - Right mouse button click: report the disparity and range at the
              clicked pixel, and draw crosshairs at the corresponding pixels
            - TAB, 'd', 'r' key presses: see the usage message
            - Mouse motion: report the pixel coordinates under the cursor

            Everything else is passed on to the base class

            '''
            if event == FL_PUSH:

                if Fl.event_button() != FL_RIGHT_MOUSE:
                    return super().handle(event)

                try:
                    q0_rectified = \
                        np.array( self.map_pixel_image_from_viewport( (Fl.event_x(),Fl.event_y()), ),
                                  dtype=float )
                except Exception:
                    set_status("Error converting pixel coordinates")
                    widget_image0.set_cross_at(None)
                    widget_image1.set_cross_at(None)
                    widget_result.set_cross_at(None)
                    return super().handle(event)

                if self is widget_image1:
                    set_status("Please click in the left or disparity windows")
                    widget_image0.set_cross_at(q0_rectified, no_vertical = True)
                    widget_image1.set_cross_at(q0_rectified, no_vertical = True)
                    widget_result.set_cross_at(q0_rectified, no_vertical = True)
                    return super().handle(event)

                if not (q0_rectified[0] >= -0.5 and q0_rectified[0] <= images_rectified[0].shape[1]-0.5 and \
                        q0_rectified[1] >= -0.5 and q0_rectified[1] <= images_rectified[0].shape[0]-0.5):
                    set_status("Out of bounds")
                    widget_image0.set_cross_at(q0_rectified, no_vertical = True)
                    widget_image1.set_cross_at(q0_rectified, no_vertical = True)
                    widget_result.set_cross_at(q0_rectified, no_vertical = True)
                    return super().handle(event)

                # The bounds check above accepts a click exactly on the far
                # edge (W-0.5 or H-0.5), which can round up to W or H. I clamp
                # to the last valid row/column so that the lookup can't run
                # off the end of the array
                i_row = min(int(round(q0_rectified[-1])), disparity.shape[0]-1)
                i_col = min(int(round(q0_rectified[-2])), disparity.shape[1]-1)
                d = disparity[i_row, i_col]
                if d < disparity_min*disparity_scale:
                    set_status("No valid disparity at the clicked location")
                    widget_image0.set_cross_at(q0_rectified)
                    widget_image1.set_cross_at(q0_rectified, no_vertical = True)
                    widget_result.set_cross_at(q0_rectified)
                    return super().handle(event)
                # A disparity of exactly 0 is a point at infinity: the feature
                # appears at the same pixel in both rectified images. This is
                # a test against 0, NOT against disparity_min: the minimum
                # searched disparity doesn't have to be 0
                if d == 0:
                    set_status("Disparity: 0pixels\n" +
                               "range: infinity\n")
                    widget_image0.set_cross_at(q0_rectified)
                    widget_image1.set_cross_at(q0_rectified)
                    widget_result.set_cross_at(q0_rectified)
                    return super().handle(event)

                # The corresponding pixel in the other rectified image is on
                # the same row, shifted by the disparity (converted to pixels)
                widget_image0.set_cross_at(q0_rectified)
                widget_image1.set_cross_at(q0_rectified -
                                           np.array((d/disparity_scale, 0)))
                widget_result.set_cross_at(q0_rectified)

                # Range sensitivity to disparity errors, from a forward
                # difference with a step of delta pixels
                delta = 1e-3
                _range   = mrcal.stereo_range( d,
                                               models_rectified,
                                               disparity_scale = disparity_scale,
                                               disparity_min   = disparity_min,
                                               qrect0          = q0_rectified)
                _range_d = mrcal.stereo_range( d + delta*disparity_scale,
                                               models_rectified,
                                               disparity_scale = disparity_scale,
                                               disparity_min   = disparity_min,
                                               qrect0          = q0_rectified)
                drange_ddisparity = np.abs(_range_d - _range) / delta


                # TODO: mrcal-triangulate tool: guts into a function. Call
                # those guts here, and display them in the window
                set_status(f"Disparity: {d/disparity_scale:.2f}pixels\n" +
                           f"range: {_range:.2f}m\n" +
                           f"drange/ddisparity: {drange_ddisparity:.5f}m/pixel")

                # Disabled debugging code: cross-check the reported range
                # against a triangulation in the input cameras
                if 0:
                    p0_rectified = \
                        mrcal.unproject(q0_rectified,
                                        *models_rectified[0].intrinsics(),
                                        normalize = True) * _range
                    p0 = mrcal.transform_point_Rt( Rt_cam0_rect0,
                                                   p0_rectified )
                    p1 = mrcal.transform_point_Rt( mrcal.compose_Rt( models[1].extrinsics_Rt_fromref(),
                                                                     models[0].extrinsics_Rt_toref()),
                                                   p0 )

                    q0  = mrcal.project(p0, *models[0].intrinsics())
                    q1  = mrcal.project(p1, *models[1].intrinsics())

                    pt = \
                        mrcal.triangulate( nps.cat(q0, q1),
                                           models,
                                           stabilize_coords = True )
                    print(f"range = {_range:.2f}, {nps.mag(pt):.2f}")

                return 0

            if event == FL_KEYDOWN:
                if Fl.event_key() == fltk.FL_Tab:
                    process_tab()
                    return 1

                if Fl.event_key() == ord('d') or \
                   Fl.event_key() == ord('D'):
                    widget_result. \
                      update_image(decimation_level = 0,
                                   image_data       = disparity_colored)
                    return 1

                if Fl.event_key() == ord('r') or \
                   Fl.event_key() == ord('R'):
                    widget_result. \
                      update_image(decimation_level = 0,
                                   image_data       = ranges_colored)
                    return 1

            if event == FL_MOVE:
                q_rect   = None
                q_input0 = None
                q_input1 = None

                # If the cursor can't be mapped to an image pixel, I report no
                # coordinates; set_status() then keeps the previous ones
                try:
                    q_rect = self.map_pixel_image_from_viewport( (Fl.event_x(),Fl.event_y()), )
                except Exception:
                    pass

                if q_rect is not None:
                    q_input0 = \
                        mrcal.project( mrcal.transform_point_Rt( Rt_cam0_rect0,
                                                                 mrcal.unproject(q_rect, *models_rectified[0].intrinsics())),
                                       *models[0].intrinsics())
                    q_input1 = \
                        mrcal.project( mrcal.transform_point_Rt( Rt_cam1_rect1,
                                                                 mrcal.unproject(q_rect, *models_rectified[1].intrinsics())),
                                       *models[1].intrinsics())

                set_status(string  = None, # keep old string
                           q_rect   = q_rect,
                           q_input0 = q_input0,
                           q_input1 = q_input1)
                # fall through to let parent handlers run

            return super().handle(event)

    class Fl_Multiline_Output_pass_keys(Fl_Multiline_Output):
        r'''Text output widget that handles the TAB and Escape keys itself'''
        def handle(self, event):
            if event == FL_KEYDOWN:
                if Fl.event_key() == fltk.FL_Tab:
                    process_tab()
                    return 1

                if Fl.event_key() == fltk.FL_Escape:
                    window.hide()

            return super().handle(event)

    class Fl_Button_pass_keys(Fl_Button):
        r'''Button widget that handles the TAB and Escape keys itself'''
        def handle(self, event):
            if event == FL_KEYDOWN:
                if Fl.event_key() == fltk.FL_Tab:
                    process_tab()
                    return 1

                if Fl.event_key() == fltk.FL_Escape:
                    window.hide()

            return super().handle(event)


    # Layout: a 2x2 grid of 400x300 cells. The two images are stacked on the
    # left, the result image is at the top-right, and the status text with the
    # write button under it is at the bottom-right
    window        = Fl_Window(800, 600, "mrcal-stereo")
    widget_image0 = Fl_Gl_Image_Widget_Derived(   0,    0,        400,300)
    widget_image1 = Fl_Gl_Image_Widget_Derived(   0,  300,        400,300)
    widget_result = Fl_Gl_Image_Widget_Derived(   400,  0,        400,300)

    group_status     = Fl_Group(                     400,300,        400,300)
    widget_status    = Fl_Multiline_Output_pass_keys(400,300,        400,300-30)
    widget_dowrite   = Fl_Button_pass_keys(          400,300+300-30, 400,30,
                                                     f'Write output to "{args.outdir}" (will overwrite)')
    widget_dowrite.callback(write_output_interactive)
    group_status.resizable(widget_status)
    group_status.end()

    set_status('')

    window.resizable(window)
    window.end()
    window.show()

    widget_image0. \
      update_image(decimation_level = 0,
                   image_data       = images_rectified[0])
    widget_image1. \
      update_image(decimation_level = 0,
                   image_data       = images_rectified[1])
    widget_result. \
      update_image(decimation_level = 0,
                   image_data       = disparity_colored)
    Fl.run()

    sys.exit(0)




if Nimages == 1 or args.jobs <= 1:

    # Serial path. Useful for debugging or if I only have one set of images
    for i in range(Nimages):
        process(i)
    sys.exit(0)

####### Parallel path

import multiprocessing
import signal

# I do the same thing in mrcal-reproject-image. Please consolidate
#
# weird business to handle weird signal handling in multiprocessing. I want
# things like the user hitting C-c to work properly. So I ignore SIGINT for the
# children. And I want the parent's blocking wait for results to respond to
# signals. Which means map_async() instead of map(), and wait(big number)
# instead of wait()
#
# The order matters here: the pool must be created WHILE SIGINT is being
# ignored, so that the worker processes start out ignoring it. Only after that
# do I restore the original handler in the parent
signal_handler_sigint = signal.signal(signal.SIGINT, signal.SIG_IGN)
pool = multiprocessing.Pool(args.jobs)
signal.signal(signal.SIGINT, signal_handler_sigint)

try:
    mapresult = pool.map_async(process, np.arange(Nimages))

    # like wait(), but will barf if something goes wrong. I don't actually care
    # about the results
    mapresult.get(1000000)
except BaseException:
    # Anything at all went wrong: a C-c in the parent (KeyboardInterrupt), a
    # timeout, or an exception raised in a worker. I stop the workers instead
    # of waiting for them to finish
    pool.terminate()

pool.close()
pool.join()