
I'm trying to define a gradient method for my custom TF operation. Most of the solutions I have found online seem to be based on a gist by harpone. I'm reluctant to use that approach because it uses py_func, which won't run on GPU. I found another solution here that uses tf.identity(), which looks more elegant and, I think, will run on GPU. However, I have trouble accessing the inputs of the ops in my custom gradient function. Here's my code:

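import tensorflow as tf  # TensorFlow 1.x graph-mode API


# Gradient function registered under the name 'MyCustomGradient'
@tf.RegisterGradient('MyCustomGradient')
def _custom_gradient(op, gradients):
    x = op.inputs[0]
    return x


def my_op(w):
    return tf.pow(w, 3)


var_foo = tf.Variable(5, dtype=tf.float32)
bar = my_op(var_foo)

g = tf.get_default_graph()
# Identity ops created inside this block use 'MyCustomGradient'
with g.gradient_override_map({'Identity': 'MyCustomGradient'}):
    bar = tf.identity(bar)
    grads = tf.gradients(bar, var_foo)  # separate name so the graph `g` is not shadowed

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads))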

I was expecting _custom_gradient() to return the input to the op (5 in this example), but instead it seems to return the op's output multiplied by the gradient (the code prints 9375). My custom my_op will contain non-differentiable operations like tf.sign, and I'd like to define my custom gradient based on the inputs. What am I doing wrong?

Milad
  • I think what's going on is that the custom gradient is attached to the `identity()` op and not the `my_op()` function as I had hoped. – Milad Dec 28 '17 at 20:35

1 Answer


There is no problem with your code:

Let's first do the forward pass:

var_foo = 5 -> bar = 125 -> tf.identity(bar) = 125

Now let's backpropagate:

By your definition, the gradient of tf.identity(bar) with respect to its argument bar is bar itself, that is, 125 (your function ignores the incoming gradient). The gradient of bar with respect to var_foo is 3 times the square of var_foo, which is 75. Multiply the two and you get 9375, which is indeed the output of your code.
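The same chain-rule arithmetic can be replayed in plain Python (no TensorFlow; the variable names only mirror the question's code), with the default Identity gradient, which simply passes the incoming gradient through, shown for comparison:

```python
# Forward pass
var_foo = 5.0
bar = var_foo ** 3          # tf.pow(var_foo, 3) -> 125.0
out = bar                   # tf.identity(bar)   -> 125.0

# Backward pass, seeded with d(out)/d(out) = 1.0
upstream = 1.0

# Overridden Identity gradient: returns op.inputs[0] (bar) and ignores upstream
grad_bar_custom = bar                      # 125.0
# Default Identity gradient: passes upstream through unchanged
grad_bar_default = upstream                # 1.0

# Pow's gradient with respect to its base: 3 * var_foo**2 = 75.0
local_pow_grad = 3 * var_foo ** 2

grad_custom = grad_bar_custom * local_pow_grad    # 9375.0
grad_default = grad_bar_default * local_pow_grad  # 75.0
print(grad_custom, grad_default)  # 9375.0 75.0
```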

op.inputs[0] is the first input tensor of the op whose gradient you overrode, evaluated with its forward-pass value. Here that op is the identity op, whose input is bar, so op.inputs[0] evaluates to 125.
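If the goal is for op.inputs[0] to be var_foo, the override has to be attached to the op that actually consumes var_foo (the Pow op created inside my_op) rather than to an Identity op added afterwards. Below is a minimal sketch, assuming TensorFlow graph mode through `tf.compat.v1` (so it also runs on TF 2.x). The gradient name `'PowGradFromInput'` and the rule "incoming gradient times the input" are hypothetical choices, made only so the input value shows up in the result.

```python
import tensorflow.compat.v1 as tf  # TF 1.x graph-mode API, also available in TF 2.x
tf.disable_v2_behavior()


@tf.RegisterGradient('PowGradFromInput')  # hypothetical gradient name
def _pow_grad_from_input(op, grad):
    w = op.inputs[0]  # base of tf.pow, i.e. var_foo (5.0 in the forward pass)
    # Pow has two inputs (base, exponent), so return one gradient per input;
    # None means "no gradient" for the constant exponent.
    return [grad * w, None]


def my_op(w):
    return tf.pow(w, 3.0)


var_foo = tf.Variable(5.0)
g = tf.get_default_graph()
with g.gradient_override_map({'Pow': 'PowGradFromInput'}):
    bar = my_op(var_foo)  # the Pow op is created here, so it picks up the override
grads = tf.gradients(bar, var_foo)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(grads)
print(result)
```

Because the seed gradient is 1.0, the returned gradient equals the input, 5.0, which is what the question expected from op.inputs[0].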

Lior
  • Thanks, that makes total sense. I was hoping to define a custom gradient for the `my_op` function, but instead I've defined it for tf.identity. I guess I have to fall back to py_func. – Milad Dec 28 '17 at 21:03
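For a non-differentiable op such as tf.sign, one option that stays entirely within TensorFlow ops (unlike py_func) is to register a gradient for the Sign op type itself and create my_op's ops inside gradient_override_map; op.inputs[0] is then the input of tf.sign. This is a minimal sketch, assuming graph mode via `tf.compat.v1`. The name `'SignStraightThrough'` and the clipping rule (a straight-through estimator) are hypothetical choices, not something from this thread.

```python
import tensorflow.compat.v1 as tf  # TF 1.x graph-mode API, also available in TF 2.x
tf.disable_v2_behavior()


@tf.RegisterGradient('SignStraightThrough')  # hypothetical gradient name
def _sign_straight_through(op, grad):
    x = op.inputs[0]  # the tensor passed to tf.sign, i.e. the op's own input
    # Illustrative straight-through rule: pass the incoming gradient where
    # |x| <= 1 and block it elsewhere, instead of Sign's default zero gradient.
    mask = tf.cast(tf.abs(x) <= 1.0, grad.dtype)
    return grad * mask


def my_op(w):
    return tf.sign(w)


w = tf.Variable([0.5, -0.3, 5.0])
g = tf.get_default_graph()
with g.gradient_override_map({'Sign': 'SignStraightThrough'}):
    out = my_op(w)  # the Sign op is created here, so it picks up the override
grads = tf.gradients(out, w)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(grads)[0]
print(result)  # [1. 1. 0.]
```

The only requirement is that the ops whose gradient you want to replace are created inside the `with` block; wrapping the output in tf.identity afterwards only changes the Identity gradient, as the answer explains.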