Content Overview
- Effects of turning on the new type promotion
- More consistent and predictable promotion type
- Reduced risk of bit-widening
- tf.Tensor mathematical dunder methods
- tf.Variable in-place ops
- tf.constant implicit conversions
- TF-NumPy Array
- Input Type Inference
- WeakTensor-supporting APIs
Below is a non-exhaustive list of changes that result from turning on the new type promotion.
- More consistent and predictable promotion results.
- Reduced risk of bit-widening.
- tf.Tensor mathematical dunder methods use the new type promotion.
- tf.constant can return WeakTensor.
- tf.constant allows implicit conversions when a Tensor input with a dtype different from the dtype arg is passed in.
- tf.Variable in-place ops (assign, assign_add, assign_sub) allow implicit conversions.
- tnp.array(1) and tnp.array(1.0) return 32-bit WeakTensor.
- WeakTensors will be created and used for WeakTensor-supporting unary and binary APIs.
More consistent and predictable promotion results
Using a lattice-based system allows the new type promotion to produce consistent and predictable type promotion results.
Old Type Promotion
Changing the order of operations produces inconsistent results using old type promotion.
# Setup
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c throws an InvalidArgumentError.
try:
  tf.add(tf.add(a, b), c)
except tf.errors.InvalidArgumentError as e:
  print(f'{type(e)}: {e}')  # InvalidArgumentError
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute AddV2 as input #1(zero-based) was expected to be a int8 tensor but is a int32 tensor [Op:AddV2] name:
# (b + a) + c returns an i32 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=int32, numpy=3>
<tf.Tensor: shape=(), dtype=int32, numpy=3>
New Type Promotion
New type promotion produces consistent results regardless of the order.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
# (a + b) + c returns an f16 result.
tf.add(tf.add(a, b), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
# (b + a) + c also returns an f16 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
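Why the order stops mattering can be sketched with a small lattice model. The code below assumes JAX's published type promotion lattice, writing TF's weak types i32*, f32*, and c128* in place of JAX's weak int, float, and complex; TensorFlow's real promotion table may differ in some cells, so treat this as an illustration rather than the library's implementation. Promoting two dtypes means taking their join, the lowest dtype that both can reach along lattice edges, and because a join is commutative and associative, (a + b) + c and (b + a) + c must agree.

```python
# Sketch of lattice-based promotion, ASSUMING a JAX-style lattice.
# Each dtype lists the dtypes directly above it.
LATTICE = {
    "b": ["i32*"],                      # bool sits below everything
    "i32*": ["u8", "i8"],               # weak int (Python int)
    "u8": ["i16", "u16"],
    "u16": ["i32", "u32"],
    "u32": ["i64", "u64"],
    "u64": ["f32*"],
    "i8": ["i16"],
    "i16": ["i32"],
    "i32": ["i64"],
    "i64": ["f32*"],
    "f32*": ["bf16", "f16", "c128*"],   # weak float (Python float)
    "bf16": ["f32"],
    "f16": ["f32"],
    "f32": ["f64", "c64"],
    "f64": ["c128"],
    "c128*": ["c64"],                   # weak complex (Python complex)
    "c64": ["c128"],
    "c128": [],
}


def upper_set(t: str) -> frozenset:
    """All dtypes reachable from t along lattice edges, including t itself."""
    seen, stack = {t}, [t]
    while stack:
        for nxt in LATTICE[stack.pop()]:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return frozenset(seen)


UPPER = {t: upper_set(t) for t in LATTICE}


def promote(x: str, y: str) -> str:
    """Join (least upper bound) of x and y: the lowest dtype both can reach."""
    common = UPPER[x] & UPPER[y]
    # Exactly one common bound can reach all the others; that one is the join.
    (join,) = [t for t in common if common <= UPPER[t]]
    return join


# The example above: a is int8, b = tf.constant(1) is i32*, c is float16.
print(promote(promote("i8", "i32*"), "f16"))  # f16
print(promote(promote("i32*", "i8"), "f16"))  # f16

# Joins are commutative and associative, so no ordering changes the result.
dtypes = list(LATTICE)
assert all(
    promote(promote(x, y), z) == promote(x, promote(y, z))
    and promote(x, y) == promote(y, x)
    for x in dtypes for y in dtypes for z in dtypes
)
```

The old promotion had no such structure, which is why swapping a and b in the earlier example changed the outcome from an error to an int32 result.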
Reduced risk of bit-widening
Old Type Promotion
Old type promotion often resulted in 64-bit results.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
<tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
New Type Promotion
New type promotion returns results with the minimal number of bits necessary.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float16, numpy=54.1875>
<tf.Tensor: shape=(), dtype=float16, numpy=54.1875>
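The float16 result is 54.1875 rather than 54.2 because float16 keeps only 11 significant bits, so values near 54 are spaced 1/32 apart. The sketch below reproduces both outputs in plain Python by rounding to IEEE half precision with the struct module's 'e' format. It assumes each float16 addition is computed exactly and then rounded back to float16, which is enough to match the values shown here.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


a = to_float16(3.2)  # value stored by np.array(3.2, np.float16): 3.19921875

# Legacy mode: the result is float64, which represents this sum exactly.
legacy = a + 1 + 50

# New mode: each intermediate result is rounded back to float16. Near 54,
# float16 values are 1/32 = 0.03125 apart, so 54.19921875 rounds to 54.1875.
new = to_float16(to_float16(a + 1) + 50)

print(a, legacy, new)  # 3.19921875 54.19921875 54.1875
```

Both results start from the same float16 input 3.19921875; the only difference is whether the final sum is rounded to float16 or kept in float64.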
tf.Tensor mathematical dunder methods
All tf.Tensor mathematical dunder methods will follow the new type promotion.
-tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
tf.constant(5, tf.int16) - tf.constant(1, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>
tf.Variable in-place ops
Implicit conversions will be allowed in tf.Variable in-place ops.
Note: Any promotion that results in a dtype different from the variable's original dtype is not allowed, because a tf.Variable cannot change its dtype.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.Variable(10, dtype=tf.int32)
a.assign_add(tf.constant(5, tf.int16)) # <tf.Variable shape=() dtype=int32, numpy=15>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=15>
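The note above amounts to a simple check: promote the variable's dtype with the value's dtype, then reject the op if the result differs from the variable's dtype. The sketch below models that rule in plain Python. It is hypothetical: promote, check_in_place, and the dtype labels are illustrative names, and promote only covers signed ints, f16/f32/f64, and the weak int i32*, following a JAX-style lattice rather than TensorFlow's actual code.

```python
# Hypothetical model of the in-place rule; not TensorFlow's implementation.
SIGNED_INTS = ["i8", "i16", "i32", "i64"]   # narrowest to widest
FLOATS = ["f16", "f32", "f64"]              # narrowest to widest
WEAK_INT = "i32*"                           # dtype inferred for a Python int
SUPPORTED = set(SIGNED_INTS + FLOATS + [WEAK_INT])


def promote(x: str, y: str) -> str:
    """Promote two dtypes, restricted to a small subset of the lattice."""
    if x not in SUPPORTED or y not in SUPPORTED:
        raise ValueError(f"dtype outside this sketch's subset: {x}, {y}")
    if x == y:
        return x
    # A weak int defers to the other dtype.
    if x == WEAK_INT:
        return y
    if y == WEAK_INT:
        return x
    x_is_float, y_is_float = x in FLOATS, y in FLOATS
    # Mixing an int and a float yields the float dtype.
    if x_is_float != y_is_float:
        return x if x_is_float else y
    # Same kind: the wider dtype wins.
    order = FLOATS if x_is_float else SIGNED_INTS
    return max(x, y, key=order.index)


def check_in_place(var_dtype: str, value_dtype: str) -> str:
    """Result dtype of an in-place op such as assign_add; raises if it differs from var_dtype."""
    result = promote(var_dtype, value_dtype)
    if result != var_dtype:
        raise TypeError(
            f"promotion of {var_dtype} and {value_dtype} gives {result}, "
            f"but the variable must stay {var_dtype}"
        )
    return result


print(check_in_place("i32", "i16"))   # i32, like a.assign_add(tf.constant(5, tf.int16))
print(check_in_place("i16", "i32*"))  # i16: a Python int does not widen the variable
try:
    check_in_place("i32", "f32")      # i32 + f32 promotes to f32, so it is rejected
except TypeError as err:
    print("rejected:", err)
```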
tf.constant implicit conversions
In the old type promotion, tf.constant required an input Tensor to have the same dtype as the dtype argument. However, in the new type promotion, we implicitly convert Tensor to the specified dtype.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, tf.int16)
tf.constant(a, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>
TF-NumPy Array
tnp.array defaults to i32* and f32* for python inputs using the new type promotion.
tnp.array(1) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
tnp.array(1.0) # <tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>
Input Type Inference
This is how different inputs' types are inferred in the new type promotion.
- tf.Tensor: Since tf.Tensor has a dtype property, we don't do further inference.
- NumPy types: This includes types like np.array(1), np.int16(1), and np.float64(1.0). Since NumPy inputs also have a dtype property, we take the dtype property as the result inference type. Note that NumPy defaults to i64 and f64.
- Python scalars/Nested types: This includes types like 1, [1, 2, 3], and (1.0, 2.0).
  - Python int is inferred as i32*.
  - Python float is inferred as f32*.
  - Python complex is inferred as c128*.
- If the input doesn't fall into any of the above categories but has a dtype property, we take the dtype property as the result inference type.
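The rules above can be summarized as a small function. This is a hypothetical sketch: infer_dtype returns string labels rather than TF dtypes, bool is left out because the rules do not mention it, and resolving a mixed list such as [1, 2.0] to the most general weak kind (f32* here) is an assumption.

```python
from typing import Any

# Hypothetical labels for the weak dtypes of Python scalars, per the rules above.
PYTHON_SCALAR_DTYPES = {int: "i32*", float: "f32*", complex: "c128*"}
# Weak kinds from least to most general.
WEAK_ORDER = ["i32*", "f32*", "c128*"]


def infer_dtype(x: Any) -> Any:
    """Infer the promotion input type of x, following the rules listed above."""
    # tf.Tensor, NumPy values, and any other object with a dtype property:
    # take the dtype property as-is.
    if hasattr(x, "dtype"):
        return x.dtype
    # Python scalars. type() is used so that bool (a subclass of int),
    # which the rules do not cover, falls through to the error below.
    if type(x) in PYTHON_SCALAR_DTYPES:
        return PYTHON_SCALAR_DTYPES[type(x)]
    # Nested Python types: infer each element, then keep the most general
    # weak kind (an assumption for mixed lists such as [1, 2.0]).
    if isinstance(x, (list, tuple)) and x:
        kinds = {infer_dtype(item) for item in x}
        if kinds <= set(WEAK_ORDER):
            return max(kinds, key=WEAK_ORDER.index)
    raise TypeError(f"no inference rule for {x!r}")


class HasDtype:
    """Stand-in for a tf.Tensor or NumPy array."""
    dtype = "int16"


print(infer_dtype(1), infer_dtype(1.0), infer_dtype(1j))  # i32* f32* c128*
print(infer_dtype([1, 2, 3]), infer_dtype((1.0, 2.0)))     # i32* f32*
print(infer_dtype(HasDtype()))                             # int16
```

Checking for a dtype property first covers the tf.Tensor, NumPy, and fallback rules in one branch, since all three simply take that property as the inferred type.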
Further Reading
The new type promotion closely resembles JAX-NumPy's type promotion. If you want to know more details about the new type promotion and the design choices, check out the resources below.
- JAX Type Promotion Semantics
- Design of Type Promotion Semantics for JAX
- Old TF-NumPy Promotion Semantics
WeakTensor-supporting APIs
Below is a list of APIs that support WeakTensor.
For a unary op, this means that if an input with no user-specified type is passed in, it returns a WeakTensor.
For a binary op, it follows the new type promotion table and may or may not return a WeakTensor, depending on the promotion result of the two inputs.
Note: All mathematical operations (+, -, *, ...) are supported.
tf.bitwise.invert, tf.clip_by_value, tf.debugging.check_numerics, tf.expand_dims, tf.identity, tf.image.adjust_brightness, tf.image.adjust_gamma, tf.image.extract_patches, tf.image.random_brightness, tf.image.stateless_random_brightness, tf.linalg.diag, tf.linalg.diag_part, tf.linalg.matmul, tf.linalg.matrix_transpose, tf.linalg.tensor_diag_part, tf.linalg.trace, tf.math.abs, tf.math.acos, tf.math.acosh, tf.math.add, tf.math.angle, tf.math.asin, tf.math.asinh, tf.math.atan, tf.math.atanh, tf.math.ceil, tf.math.conj, tf.math.cos, tf.math.cosh, tf.math.digamma, tf.math.divide_no_nan, tf.math.divide, tf.math.erf, tf.math.erfc, tf.math.erfcinv, tf.math.erfinv, tf.math.exp, tf.math.expm1, tf.math.floor, tf.math.floordiv, tf.math.floormod, tf.math.imag, tf.math.lgamma, tf.math.log1p, tf.math.log_sigmoid, tf.math.log, tf.math.multiply_no_nan, tf.math.multiply, tf.math.ndtri, tf.math.negative, tf.math.pow, tf.math.real, tf.math.reciprocal_no_nan, tf.math.reciprocal, tf.math.reduce_euclidean_norm, tf.math.reduce_logsumexp, tf.math.reduce_max, tf.math.reduce_mean, tf.math.reduce_min, tf.math.reduce_prod, tf.math.reduce_std, tf.math.reduce_sum, tf.math.reduce_variance, tf.math.rint, tf.math.round, tf.math.rsqrt, tf.math.scalar_mul, tf.math.sigmoid, tf.math.sign, tf.math.sin, tf.math.sinh, tf.math.softplus, tf.math.special.bessel_i0, tf.math.special.bessel_i0e, tf.math.special.bessel_i1, tf.math.special.bessel_i1e, tf.math.special.bessel_j0, tf.math.special.bessel_j1, tf.math.special.bessel_k0, tf.math.special.bessel_k0e, tf.math.special.bessel_k1, tf.math.special.bessel_k1e, tf.math.special.bessel_y0, tf.math.special.bessel_y1, tf.math.special.dawsn, tf.math.special.expint, tf.math.special.fresnel_cos, tf.math.special.fresnel_sin, tf.math.special.spence, tf.math.sqrt, tf.math.square, tf.math.subtract, tf.math.tan, tf.math.tanh, tf.nn.depth_to_space, tf.nn.elu, tf.nn.gelu, tf.nn.leaky_relu, tf.nn.log_softmax, tf.nn.relu6, tf.nn.relu, tf.nn.selu, tf.nn.softsign, tf.nn.space_to_depth, tf.nn.swish, tf.ones_like, tf.realdiv, tf.reshape, tf.squeeze, tf.stop_gradient, tf.transpose, tf.truncatediv, tf.truncatemod, tf.zeros_like

tf.experimental.numpy.abs, tf.experimental.numpy.absolute, tf.experimental.numpy.amax, tf.experimental.numpy.amin, tf.experimental.numpy.angle, tf.experimental.numpy.arange, tf.experimental.numpy.arccos, tf.experimental.numpy.arccosh, tf.experimental.numpy.arcsin, tf.experimental.numpy.arcsinh, tf.experimental.numpy.arctan, tf.experimental.numpy.arctanh, tf.experimental.numpy.around, tf.experimental.numpy.array, tf.experimental.numpy.asanyarray, tf.experimental.numpy.asarray, tf.experimental.numpy.ascontiguousarray, tf.experimental.numpy.average, tf.experimental.numpy.bitwise_not, tf.experimental.numpy.cbrt, tf.experimental.numpy.ceil, tf.experimental.numpy.conj, tf.experimental.numpy.conjugate, tf.experimental.numpy.copy, tf.experimental.numpy.cos, tf.experimental.numpy.cosh, tf.experimental.numpy.cumprod, tf.experimental.numpy.cumsum, tf.experimental.numpy.deg2rad, tf.experimental.numpy.diag, tf.experimental.numpy.diagflat, tf.experimental.numpy.diagonal, tf.experimental.numpy.diff, tf.experimental.numpy.empty_like, tf.experimental.numpy.exp2, tf.experimental.numpy.exp, tf.experimental.numpy.expand_dims, tf.experimental.numpy.expm1, tf.experimental.numpy.fabs, tf.experimental.numpy.fix, tf.experimental.numpy.flatten, tf.experimental.numpy.flip, tf.experimental.numpy.fliplr, tf.experimental.numpy.flipud, tf.experimental.numpy.floor, tf.experimental.numpy.full_like, tf.experimental.numpy.imag, tf.experimental.numpy.log10, tf.experimental.numpy.log1p, tf.experimental.numpy.log2, tf.experimental.numpy.log, tf.experimental.numpy.max, tf.experimental.numpy.mean, tf.experimental.numpy.min, tf.experimental.numpy.moveaxis, tf.experimental.numpy.nanmean, tf.experimental.numpy.negative, tf.experimental.numpy.ones_like, tf.experimental.numpy.positive, tf.experimental.numpy.prod, tf.experimental.numpy.rad2deg, tf.experimental.numpy.ravel, tf.experimental.numpy.real, tf.experimental.numpy.reciprocal, tf.experimental.numpy.repeat, tf.experimental.numpy.reshape, tf.experimental.numpy.rot90, tf.experimental.numpy.round, tf.experimental.numpy.signbit, tf.experimental.numpy.sin, tf.experimental.numpy.sinc, tf.experimental.numpy.sinh, tf.experimental.numpy.sort, tf.experimental.numpy.sqrt, tf.experimental.numpy.square, tf.experimental.numpy.squeeze, tf.experimental.numpy.std, tf.experimental.numpy.sum, tf.experimental.numpy.swapaxes, tf.experimental.numpy.tan, tf.experimental.numpy.tanh, tf.experimental.numpy.trace, tf.experimental.numpy.transpose, tf.experimental.numpy.triu, tf.experimental.numpy.vander, tf.experimental.numpy.var, tf.experimental.numpy.zeros_like
Originally published on the
