Add support to generate_lut for writing CSP
[OpenColorIO-Configs.git] / aces_1.0.0 / python / aces_ocio / generate_lut.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 """
Defines objects to generate various kinds of 1D, 2D and 3D LUTs in various file
formats.
7 """
8
9 from __future__ import division
10
11 import array
12 import os
13 import sys
14
15 import OpenImageIO as oiio
16
17 from aces_ocio.process import Process
18
__author__ = 'ACES Developers'
__copyright__ = 'Copyright (C) 2014 - 2015 - ACES Developers'
__license__ = ''
__maintainer__ = 'ACES Developers'
__email__ = 'aces@oscars.org'
__status__ = 'Production'

# Public API of the module; keep in sync with the public functions below.
__all__ = ['generate_1d_LUT_image',
           'write_SPI_1d',
           'write_CSP_1d',
           'write_1d',
           'generate_1d_LUT_from_image',
           'generate_3d_LUT_image',
           'generate_3d_LUT_from_image',
           'apply_CTL_to_image',
           'convert_bit_depth',
           'generate_1d_LUT_from_CTL',
           'correct_LUT_image',
           'generate_3d_LUT_from_CTL',
           'main']
37
38
def generate_1d_LUT_image(ramp_1d_path,
                          resolution=1024,
                          min_value=0,
                          max_value=1):
    """
    Writes a linear ramp as a one pixel high, 3 channels, float image.

    Parameters
    ----------
    ramp_1d_path : unicode
        Path of the image to write, its file format is picked by
        *OpenImageIO*.
    resolution : int, optional
        Number of ramp samples, i.e. the image width. Must be at least 2 as
        the samples are spaced by ``1 / (resolution - 1)``.
    min_value : float, optional
        Value of the first sample.
    max_value : float, optional
        Value of the last sample.

    Raises
    ------
    IOError
        If no image writer can be created for ``ramp_1d_path`` or the file
        cannot be opened for writing.
    """

    ramp = oiio.ImageOutput.create(ramp_1d_path)
    if ramp is None:
        raise IOError('Could not create an image writer for "%s".' %
                      ramp_1d_path)

    spec = oiio.ImageSpec()
    spec.set_format(oiio.FLOAT)
    spec.width = resolution
    spec.height = 1
    spec.nchannels = 3

    if not ramp.open(ramp_1d_path, spec, oiio.Create):
        raise IOError('Could not open "%s" for writing.' % ramp_1d_path)

    # Build the interleaved pixel buffer sample by sample: seeding an 'f'
    # array from a string of NUL characters only works on Python 2.
    data = array.array('f')
    for i in range(resolution):
        value = float(i) / (resolution - 1) * (
            max_value - min_value) + min_value
        data.extend([value] * spec.nchannels)

    try:
        ramp.write_image(spec.format, data)
    finally:
        ramp.close()
79
80
def write_SPI_1d(filename,
                 from_min,
                 from_max,
                 data,
                 entries,
                 channels,
                 components=3):
    """
    Writes a Sony Pictures Imageworks *.spi1d* 1D LUT file.

    Credit to *Alex Fry* for the original single channel version of the spi1d
    writer.

    Parameters
    ----------
    filename : unicode
        Path of the LUT file to write.
    from_min : float
        Input value of the first LUT entry, written in the ``From`` line.
    from_max : float
        Input value of the last LUT entry, written in the ``From`` line.
    data : sequence of float
        Interleaved samples, ``channels`` values per entry.
    entries : int
        Number of LUT entries to write.
    channels : int
        Number of interleaved values per entry in ``data``.
    components : int, optional
        Number of values written per entry, clamped to
        ``min(3, components, channels)`` so that a single channel LUT can be
        written from multi-channel data.
    """

    component_count = min(3, components, channels)

    with open(filename, 'w') as lut_file:
        lut_file.write(('Version 1\n'
                        'From %f %f\n'
                        'Length %d\n'
                        'Components %d\n'
                        '{\n') % (from_min, from_max, entries,
                                  component_count))

        for entry_index in range(entries):
            offset = entry_index * channels
            row = ''.join(' %s' % data[offset + k]
                          for k in range(component_count))
            lut_file.write('        %s\n' % row)

        lut_file.write('}\n')
121
122
def write_CSP_1d(filename,
                 from_min,
                 from_max,
                 data,
                 entries,
                 channels,
                 components=3):
    """
    Writes a Cinespace *.csp* 1D LUT file.

    Each channel gets a two point pre-LUT mapping ``[from_min, from_max]``
    onto ``[0, 1]``, followed by ``entries`` rows of three values.

    Parameters
    ----------
    filename : unicode
        Path of the LUT file to write.
    from_min : float
        Input value mapped to 0 by the pre-LUT.
    from_max : float
        Input value mapped to 1 by the pre-LUT.
    data : sequence of float
        Interleaved samples, ``channels`` values per entry.
    entries : int
        Number of LUT entries to write.
    channels : int
        Number of interleaved values per entry in ``data``.
    components : int, optional
        Number of distinct values taken from ``data`` per entry, clamped to
        ``min(3, components, channels)`` (and to at least 1). CSP 1D rows
        always hold three values, so when fewer components are used the last
        one is repeated to fill the row, e.g. a single channel LUT is
        written as ``v v v``.
    """

    # May want to use fewer components than there are channels in the data,
    # most commonly used for single channel LUTs.
    components = max(1, min(3, components, channels))

    with open(filename, 'w') as fp:
        fp.write('CSPLUTV100\n')
        fp.write('1D\n')
        fp.write('\n')
        # Each metadata marker must be on its own line.
        fp.write('BEGIN METADATA\n')
        fp.write('END METADATA\n')
        fp.write('\n')

        # Pre-LUT: for each of the three channels, a 2 point linear mapping
        # of [from_min, from_max] onto [0, 1].
        for _ in range(3):
            fp.write('2\n')
            fp.write('%f %f\n' % (from_min, from_max))
            fp.write('0.0 1.0\n')

        fp.write('\n')

        fp.write('%d\n' % entries)
        for i in range(entries):
            offset = i * channels
            # Always three values per row; missing components repeat the
            # last available one.
            entry = ''.join(' %s' % data[offset + min(j, components - 1)]
                            for j in range(3))
            fp.write('%s\n' % entry)
        fp.write('\n')
183
def write_1d(filename,
             from_min,
             from_max,
             data,
             data_entries,
             data_channels,
             lut_components=3,
             format='spi1d'):
    """
    Writes a 1D LUT file in the requested format.

    The *cinespace* format is written as a *.csp* file. Every other format
    is written as *spi1d*. Formats known to *OpenColorIO* that have no 1D
    writer here (flame, icc, houdini, lustre) fall back to *spi1d* with a
    warning instead of silently producing no file.

    Parameters
    ----------
    filename : unicode
        Path of the LUT file to write.
    from_min : float
        Input value of the first LUT entry.
    from_max : float
        Input value of the last LUT entry.
    data : sequence of float
        Interleaved samples, ``data_channels`` values per entry.
    data_entries : int
        Number of LUT entries.
    data_channels : int
        Number of interleaved values per entry in ``data``.
    lut_components : int, optional
        Number of components to write per entry.
    format : unicode, optional
        *OpenColorIO* format name, e.g. ``'spi1d'`` or ``'cinespace'``.
    """

    ocio_formats_to_extensions = {'cinespace': 'csp',
                                  'flame': '3dl',
                                  'icc': 'icc',
                                  'houdini': 'lut',
                                  'lustre': '3dl'}

    extension = ocio_formats_to_extensions.get(format)
    if extension == 'csp':
        write_CSP_1d(filename,
                     from_min,
                     from_max,
                     data,
                     data_entries,
                     data_channels,
                     lut_components)
        return

    if extension is not None:
        print(('1D LUT export to the "%s" format is not supported, '
               'writing a spi1d LUT instead.') % format)

    write_SPI_1d(filename,
                 from_min,
                 from_max,
                 data,
                 data_entries,
                 data_channels,
                 lut_components)
229
def generate_1d_LUT_from_image(ramp_1d_path,
                               output_path=None,
                               min_value=0,
                               max_value=1,
                               channels=3,
                               format='spi1d'):
    """
    Writes a 1D LUT from the pixels of a one pixel high ramp image.

    Parameters
    ----------
    ramp_1d_path : unicode
        Path of the (transformed) ramp image to read.
    output_path : unicode, optional
        Path of the LUT to write. Defaults to ``ramp_1d_path`` with a
        ``.csp`` extension for the *cinespace* format, ``.spi1d`` otherwise.
    min_value : float, optional
        Input value of the first LUT entry.
    max_value : float, optional
        Input value of the last LUT entry.
    channels : int, optional
        Number of components to write per LUT entry.
    format : unicode, optional
        *OpenColorIO* format name, see :func:`write_1d`.

    Raises
    ------
    IOError
        If the ramp image cannot be opened.
    """

    if output_path is None:
        extension = 'csp' if format == 'cinespace' else 'spi1d'
        output_path = '%s.%s' % (ramp_1d_path, extension)

    ramp = oiio.ImageInput.open(ramp_1d_path)
    if ramp is None:
        raise IOError('Could not open "%s" for reading.' % ramp_1d_path)

    try:
        ramp_spec = ramp.spec()
        ramp_width = ramp_spec.width
        ramp_channels = ramp_spec.nchannels

        # Forcibly read data as float, the Python API doesn't handle
        # half-float well yet.
        ramp_data = ramp.read_image(oiio.FLOAT)
    finally:
        ramp.close()

    write_1d(output_path, min_value, max_value,
             ramp_data, ramp_width, ramp_channels, channels, format)
266
267
def generate_3d_LUT_image(ramp_3d_path, resolution=32):
    """
    Generates an identity 3D LUT lattice image with *ociolutimage*.

    Parameters
    ----------
    ramp_3d_path : unicode
        Path of the lattice image to write.
    resolution : int, optional
        Cube size of the lattice; ``resolution * resolution`` is passed as
        *ociolutimage* ``--maxwidth``.
    """

    cube_size = str(resolution)
    max_width = str(resolution * resolution)

    Process(description='generate a 3d LUT image',
            cmd='ociolutimage',
            args=['--generate',
                  '--cubesize', cube_size,
                  '--maxwidth', max_width,
                  '--output', ramp_3d_path]).execute()
294
295
def _extract_spi3d_LUT(ramp_3d_path, output_path, resolution):
    """
    Extracts a *spi3d* LUT from a 3D LUT lattice image with *ociolutimage*.

    Parameters
    ----------
    ramp_3d_path : unicode
        Path of the lattice image to read.
    output_path : unicode
        Path of the *spi3d* LUT to write.
    resolution : int
        Cube size of the lattice.
    """

    args = ['--extract',
            '--cubesize',
            str(resolution),
            '--maxwidth',
            str(resolution * resolution),
            '--input',
            ramp_3d_path,
            '--output',
            output_path]
    lut_extract = Process(description='extract a 3d LUT',
                          cmd='ociolutimage',
                          args=args)
    lut_extract.execute()


def generate_3d_LUT_from_image(ramp_3d_path,
                               output_path=None,
                               resolution=32,
                               format='spi3d'):
    """
    Writes a 3D LUT from a (transformed) 3D LUT lattice image.

    *spi3d* and unknown formats are extracted directly into ``output_path``.
    Other *OpenColorIO* formats are first extracted to
    ``output_path + '.spi3d'``, then baked with *ociobakelut*. That
    intermediate file is left in place for the caller to remove.

    Parameters
    ----------
    ramp_3d_path : unicode
        Path of the lattice image to read.
    output_path : unicode, optional
        Path of the LUT to write. Defaults to ``ramp_3d_path`` with the
        extension of ``format``, ``.spi3d`` for unknown formats.
    resolution : int, optional
        Cube size of the lattice. Also passed to *ociobakelut* so the baked
        LUT keeps this resolution rather than the baker's default.
    format : unicode, optional
        *OpenColorIO* format name, e.g. ``'spi3d'`` or ``'cinespace'``.
    """

    ocio_formats_to_extensions = {'cinespace': 'csp',
                                  'flame': '3dl',
                                  'icc': 'icc',
                                  'houdini': 'lut',
                                  'lustre': '3dl'}

    if output_path is None:
        output_path = '%s.%s' % (
            ramp_3d_path, ocio_formats_to_extensions.get(format, 'spi3d'))

    if format == 'spi3d' or format not in ocio_formats_to_extensions:
        # Native format, or one we don't know how to convert to.
        _extract_spi3d_LUT(ramp_3d_path, output_path, resolution)
        return

    output_path_spi3d = '%s.%s' % (output_path, 'spi3d')
    _extract_spi3d_LUT(ramp_3d_path, output_path_spi3d, resolution)

    # Convert to a different format.
    args = ['--lut',
            output_path_spi3d,
            '--cubesize',
            str(resolution),
            '--format',
            format,
            output_path]
    lut_convert = Process(description='convert a 3d LUT',
                          cmd='ociobakelut',
                          args=args)
    lut_convert.execute()
367
368
def apply_CTL_to_image(input_image,
                       output_image,
                       ctl_paths=None,
                       input_scale=1,
                       output_scale=1,
                       global_params=None,
                       aces_ctl_directory=None):
    """
    Applies a chain of CTL transforms to an image with *ctlrender*.

    Nothing is done, and ``output_image`` is not written, when ``ctl_paths``
    is empty or None.

    Parameters
    ----------
    input_image : unicode
        Path of the image to transform.
    output_image : unicode
        Path of the transformed image to write.
    ctl_paths : list of unicode, optional
        CTL transforms, applied in the given order.
    input_scale : float, optional
        Value passed to *ctlrender* ``-input_scale``.
    output_scale : float, optional
        Value passed to *ctlrender* ``-output_scale``.
    global_params : dict, optional
        Extra ``-global_param1`` name / value pairs. ``aIn`` is always set
        to ``1.0``.
    aces_ctl_directory : unicode, optional
        ACES CTL release directory. Its ``utilities`` sub-directory, or the
        directory itself if already named ``utilities``, is exported as
        *CTL_MODULE_PATH* to *ctlrender*. Ignored when None or empty.
    """

    if not ctl_paths:
        return

    if global_params is None:
        global_params = {}

    # Work on a copy so that setting CTL_MODULE_PATH for the child process
    # does not leak into this process' own environment.
    ctlenv = dict(os.environ)
    if aces_ctl_directory:
        if os.path.split(aces_ctl_directory)[1] != 'utilities':
            ctl_module_path = os.path.join(aces_ctl_directory, 'utilities')
        else:
            ctl_module_path = aces_ctl_directory
        ctlenv['CTL_MODULE_PATH'] = ctl_module_path

    args = []
    for ctl in ctl_paths:
        args += ['-ctl', ctl]
    args += ['-force']
    args += ['-input_scale', str(input_scale)]
    args += ['-output_scale', str(output_scale)]
    args += ['-global_param1', 'aIn', '1.0']
    for key, value in global_params.items():
        args += ['-global_param1', key, str(value)]
    args += [input_image]
    args += [output_image]

    ctlp = Process(description='a ctlrender process',
                   cmd='ctlrender',
                   args=args, env=ctlenv)

    ctlp.execute()
421
422
def convert_bit_depth(input_image, output_image, depth):
    """
    Re-writes an image with another pixel data format using *oiiotool*.

    Parameters
    ----------
    input_image : unicode
        Path of the image to read.
    output_image : unicode
        Path of the converted image to write.
    depth : unicode
        *oiiotool* ``-d`` data format name, such as ``'uint16'``.
    """

    Process(description='convert image bit depth',
            cmd='oiiotool',
            args=[input_image, '-d', depth, '-o', output_image]).execute()
447
448
def generate_1d_LUT_from_CTL(lut_path,
                             ctl_paths,
                             lut_resolution=1024,
                             identity_LUT_bit_depth='half',
                             input_scale=1,
                             output_scale=1,
                             global_params=None,
                             cleanup=True,
                             aces_ctl_directory=None,
                             min_value=0,
                             max_value=1,
                             channels=3,
                             format='spi1d'):
    """
    Generates a 1D LUT by running an identity ramp image through CTL
    transforms.

    Parameters
    ----------
    lut_path : unicode
        Path of the LUT file to write; intermediate images are written next
        to it.
    ctl_paths : list of unicode
        CTL transforms applied, in order, by *ctlrender*.
    lut_resolution : int, optional
        Number of LUT entries.
    identity_LUT_bit_depth : unicode, optional
        Pixel data format of the identity ramp fed to *ctlrender*. Values
        other than ``'half'`` or ``'float'`` trigger a conversion with
        *oiiotool*, and the converted image is named after that data format.
    input_scale : float, optional
        Scale forwarded to *ctlrender* ``-input_scale``.
    output_scale : float, optional
        Scale forwarded to *ctlrender* ``-output_scale``.
    global_params : dict, optional
        Extra *ctlrender* global parameters.
    cleanup : bool, optional
        Whether to delete the intermediate images.
    aces_ctl_directory : unicode, optional
        ACES CTL release directory used to set *CTL_MODULE_PATH*.
    min_value : float, optional
        Input value of the first LUT entry.
    max_value : float, optional
        Input value of the last LUT entry.
    channels : int, optional
        Number of components written per LUT entry.
    format : unicode, optional
        *OpenColorIO* format name of the LUT, see :func:`write_1d`.
    """

    if global_params is None:
        global_params = {}

    lut_path_base = os.path.splitext(lut_path)[0]

    identity_LUT_image_float = '%s.%s.%s' % (lut_path_base, 'float', 'tiff')
    generate_1d_LUT_image(identity_LUT_image_float,
                          lut_resolution,
                          min_value,
                          max_value)

    if identity_LUT_bit_depth not in ['half', 'float']:
        # Name the converted ramp after its actual data format.
        identity_LUT_image = '%s.%s.%s' % (lut_path_base,
                                           identity_LUT_bit_depth,
                                           'tiff')
        convert_bit_depth(identity_LUT_image_float,
                          identity_LUT_image,
                          identity_LUT_bit_depth)
    else:
        identity_LUT_image = identity_LUT_image_float

    transformed_LUT_image = '%s.%s.%s' % (lut_path_base, 'transformed', 'exr')
    apply_CTL_to_image(identity_LUT_image,
                       transformed_LUT_image,
                       ctl_paths,
                       input_scale,
                       output_scale,
                       global_params,
                       aces_ctl_directory)

    generate_1d_LUT_from_image(transformed_LUT_image,
                               lut_path,
                               min_value,
                               max_value,
                               channels,
                               format)

    if cleanup:
        os.remove(identity_LUT_image)
        if identity_LUT_image != identity_LUT_image_float:
            os.remove(identity_LUT_image_float)
        os.remove(transformed_LUT_image)
516
517
def correct_LUT_image(transformed_LUT_image,
                      corrected_LUT_image,
                      lut_resolution):
    """
    Rewrites a transformed 3D LUT image with swapped dimensions when its
    size does not match the expected layout.

    The expected layout is ``lut_resolution * lut_resolution`` pixels wide by
    ``lut_resolution`` pixels high.

    Parameters
    ----------
    transformed_LUT_image : unicode
        Path of the image to check.
    corrected_LUT_image : unicode
        Path of the float image written when a correction is needed.
    lut_resolution : int
        Cube size of the LUT lattice.

    Returns
    -------
    unicode
        ``corrected_LUT_image`` when a corrected image was written, otherwise
        ``transformed_LUT_image``.
    """

    transformed = oiio.ImageInput.open(transformed_LUT_image)

    transformed_spec = transformed.spec()
    width = transformed_spec.width
    height = transformed_spec.height
    channels = transformed_spec.nchannels

    # NOTE(review): any size mismatch is handled by swapping width and
    # height, even one that is not a simple swap - confirm this is intended.
    if width != lut_resolution * lut_resolution or height != lut_resolution:
        print(('Correcting image as resolution is off. '
               'Found %d x %d. Expected %d x %d') % (
                  width,
                  height,
                  lut_resolution * lut_resolution,
                  lut_resolution))
        print('Generating %s' % corrected_LUT_image)

        # Forcibly read data as float, the Python API doesn't handle half-float
        # well yet.
        type = oiio.FLOAT
        source_data = transformed.read_image(type)

        correct = oiio.ImageOutput.create(corrected_LUT_image)

        # The corrected image takes the source's height as width and vice
        # versa, with the same channel count, stored as float.
        correct_spec = oiio.ImageSpec()
        correct_spec.set_format(oiio.FLOAT)
        correct_spec.width = height
        correct_spec.height = width
        correct_spec.nchannels = channels

        correct.open(corrected_LUT_image, correct_spec, oiio.Create)

        # Zero-filled float buffer: 4 NUL characters per float.
        # NOTE(review): seeding an 'f' array from a str only works on
        # Python 2; Python 3 raises TypeError here.
        dest_data = array.array('f',
                                ('\0' * correct_spec.width *
                                 correct_spec.height *
                                 correct_spec.nchannels * 4))
        # Source and destination indices are the same expression, so this
        # copies the buffer element for element in scanline order: pixels are
        # re-wrapped to the new width, not transposed.
        # NOTE(review): confirm a re-wrap (rather than a transpose) is what
        # the downstream ociolutimage extraction expects.
        for j in range(0, correct_spec.height):
            for i in range(0, correct_spec.width):
                for c in range(0, correct_spec.nchannels):
                    dest_data[(correct_spec.nchannels *
                               correct_spec.width * j +
                               correct_spec.nchannels * i + c)] = (
                        source_data[correct_spec.nchannels *
                                    correct_spec.width * j +
                                    correct_spec.nchannels * i + c])

        correct.write_image(correct_spec.format, dest_data)
        correct.close()
    else:
        # shutil.copy(transformedLUTImage, correctedLUTImage)
        # Size already matches: the transformed image is used as is.
        corrected_LUT_image = transformed_LUT_image

    transformed.close()

    return corrected_LUT_image
589
590
def generate_3d_LUT_from_CTL(lut_path,
                             ctl_paths,
                             lut_resolution=64,
                             identity_LUT_bit_depth='half',
                             input_scale=1,
                             output_scale=1,
                             global_params=None,
                             cleanup=True,
                             aces_ctl_directory=None,
                             format='spi3d'):
    """
    Generates a 3D LUT by running an identity lattice image through CTL
    transforms.

    Parameters
    ----------
    lut_path : unicode
        Path of the LUT file to write; intermediate images are written next
        to it.
    ctl_paths : list of unicode
        CTL transforms applied, in order, by *ctlrender*.
    lut_resolution : int, optional
        Cube size of the LUT.
    identity_LUT_bit_depth : unicode, optional
        Pixel data format of the identity lattice fed to *ctlrender*. Values
        other than ``'half'`` or ``'float'`` trigger a conversion with
        *oiiotool*.
    input_scale : float, optional
        Scale forwarded to *ctlrender* ``-input_scale``.
    output_scale : float, optional
        Scale forwarded to *ctlrender* ``-output_scale``.
    global_params : dict, optional
        Extra *ctlrender* global parameters.
    cleanup : bool, optional
        Whether to delete the intermediate images and LUTs.
    aces_ctl_directory : unicode, optional
        ACES CTL release directory used to set *CTL_MODULE_PATH*.
    format : unicode, optional
        *OpenColorIO* format name of the LUT, see
        :func:`generate_3d_LUT_from_image`.
    """

    if global_params is None:
        global_params = {}

    lut_path_base = os.path.splitext(lut_path)[0]

    identity_LUT_image_float = '%s.%s.%s' % (lut_path_base, 'float', 'tiff')
    generate_3d_LUT_image(identity_LUT_image_float, lut_resolution)

    if identity_LUT_bit_depth not in ['half', 'float']:
        identity_LUT_image = '%s.%s.%s' % (lut_path_base,
                                           identity_LUT_bit_depth,
                                           'tiff')
        convert_bit_depth(identity_LUT_image_float,
                          identity_LUT_image,
                          identity_LUT_bit_depth)
    else:
        identity_LUT_image = identity_LUT_image_float

    transformed_LUT_image = '%s.%s.%s' % (lut_path_base, 'transformed', 'exr')
    apply_CTL_to_image(identity_LUT_image,
                       transformed_LUT_image,
                       ctl_paths,
                       input_scale,
                       output_scale,
                       global_params,
                       aces_ctl_directory)

    corrected_LUT_image = '%s.%s.%s' % (lut_path_base, 'correct', 'exr')
    corrected_LUT_image = correct_LUT_image(transformed_LUT_image,
                                            corrected_LUT_image,
                                            lut_resolution)

    generate_3d_LUT_from_image(corrected_LUT_image,
                               lut_path,
                               lut_resolution,
                               format)

    if cleanup:
        os.remove(identity_LUT_image)
        if identity_LUT_image != identity_LUT_image_float:
            os.remove(identity_LUT_image_float)
        os.remove(transformed_LUT_image)
        if corrected_LUT_image != transformed_LUT_image:
            os.remove(corrected_LUT_image)
        # The intermediate spi3d LUT only exists when the LUT was converted
        # to another known format; formats that are extracted directly (e.g.
        # the command line's default empty format) never create it.
        lut_path_spi3d = '%s.%s' % (lut_path, 'spi3d')
        if format != 'spi3d' and os.path.exists(lut_path_spi3d):
            os.remove(lut_path_spi3d)
662
def main():
    """
    Command line entry point generating a 1D or a 3D LUT from CTL transforms.

    Options are parsed from :data:`sys.argv`. ``--generate1d`` takes
    precedence over ``--generate3d``, and without either nothing is
    generated.
    """

    import optparse

    p = optparse.OptionParser(
        description='A utility to generate LUTs from CTL',
        prog='generateLUT',
        version='0.01',
        usage='%prog [options]')

    p.add_option('--lut', '-l', type='string', default='')
    p.add_option('--format', '-f', type='string', default='')
    p.add_option('--ctl', '-c', type='string', action='append')
    p.add_option('--lutResolution1d', '', type='int', default=1024)
    p.add_option('--lutResolution3d', '', type='int', default=33)
    p.add_option('--ctlReleasePath', '-r', type='string', default='')
    p.add_option('--bitDepth', '-b', type='string', default='float')
    p.add_option('--keepTempImages', '', action='store_true')
    p.add_option('--minValue', '', type='float', default=0)
    p.add_option('--maxValue', '', type='float', default=1)
    p.add_option('--inputScale', '', type='float', default=1)
    p.add_option('--outputScale', '', type='float', default=1)
    p.add_option('--ctlRenderParam', '-p', type='string', nargs=2,
                 action='append')

    p.add_option('--generate1d', '', action='store_true')
    p.add_option('--generate3d', '', action='store_true')

    options, arguments = p.parse_args()

    lut = options.lut
    format = options.format
    ctls = options.ctl
    lut_resolution_1d = options.lutResolution1d
    lut_resolution_3d = options.lutResolution3d
    min_value = options.minValue
    max_value = options.maxValue
    input_scale = options.inputScale
    output_scale = options.outputScale
    ctl_release_path = options.ctlReleasePath
    generate_1d = options.generate1d is True
    generate_3d = options.generate3d is True
    bit_depth = options.bitDepth
    cleanup = not options.keepTempImages

    params = {}
    if options.ctlRenderParam is not None:
        for param in options.ctlRenderParam:
            params[param[0]] = float(param[1])

    # Arguments following a literal '--'; currently unused.
    try:
        args_start = sys.argv.index('--') + 1
        args = sys.argv[args_start:]
    except ValueError:
        args_start = len(sys.argv) + 1
        args = []

    if generate_1d:
        print('1D LUT generation options')
    else:
        print('3D LUT generation options')

    print('lut                 : %s' % lut)
    print('format              : %s' % format)
    print('ctls                : %s' % ctls)
    print('lut res 1d          : %s' % lut_resolution_1d)
    print('lut res 3d          : %s' % lut_resolution_3d)
    print('min value           : %s' % min_value)
    print('max value           : %s' % max_value)
    print('input scale         : %s' % input_scale)
    print('output scale        : %s' % output_scale)
    print('ctl render params   : %s' % params)
    print('ctl release path    : %s' % ctl_release_path)
    print('bit depth of input  : %s' % bit_depth)
    print('cleanup temp images : %s' % cleanup)

    # An empty --ctlReleasePath means the option was not given: pass None
    # rather than '' so no CTL module path is derived from it.
    aces_ctl_directory = ctl_release_path or None

    if generate_1d:
        generate_1d_LUT_from_CTL(lut,
                                 ctls,
                                 lut_resolution_1d,
                                 bit_depth,
                                 input_scale,
                                 output_scale,
                                 params,
                                 cleanup,
                                 aces_ctl_directory,
                                 min_value,
                                 max_value,
                                 format=format)

    elif generate_3d:
        generate_3d_LUT_from_CTL(lut,
                                 ctls,
                                 lut_resolution_3d,
                                 bit_depth,
                                 input_scale,
                                 output_scale,
                                 params,
                                 cleanup,
                                 aces_ctl_directory,
                                 format=format)
    else:
        print(('\n\nNo LUT generated. '
               'You must choose either 1D or 3D LUT generation\n\n'))
780
781
if __name__ == '__main__':
    # Run the command line tool when the module is executed as a script.
    main()
784