RFC: lingvo.jax exception flag mechanism #333

Open
drpngx opened this issue Sep 12, 2023 · 0 comments
drpngx commented Sep 12, 2023

Runtime asserts are nice to have, but we can't use them inside a pure function.

On x86, when the processor encounters a floating-point error, it sets the corresponding exception flag and keeps going. You can check that flag at the end of the computation.

I prototyped a similar mechanism in my code. An error flag is raised during the computation. At the end of the computation, we check whether the error flag has been raised.
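A minimal sketch of the idea (using plain numpy to stand in for jax.numpy; the function and names here are illustrative, not the lingvo code): each check folds its condition into a sticky flag that flows through the pure function, and the caller inspects the flag once at the end instead of asserting mid-computation.

```python
import numpy as np

def pure_computation(x, errflag):
    # Pure function: instead of asserting, fold each error condition
    # into a sticky boolean flag and keep computing.
    errflag = np.logical_or(errflag, np.any(x <= 0))        # check: inputs positive
    y = np.log(np.clip(x, 1e-9, None))                      # proceed regardless
    errflag = np.logical_or(errflag, np.any(np.isnan(y)))   # check: result finite
    return y, errflag

# The caller inspects the flag once, after the whole computation.
y, flag = pure_computation(np.array([1.0, -2.0, 3.0]), np.array(False))
if flag:
    print('error flag was raised during the computation')
```

Because the flag is just another array threaded through the function, the function stays pure and compilable.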

Unlike asserts, this method:

  • Can be enabled in prod on the same graph. You could place the checks on a device of your choosing.
  • Does not know which error came first. If there is a cascade of errors, you will not be able to tell which one triggered the rest. I don't think you can know unless you annotate the graph or provide a jax op to timestamp the execution.
  • Does not interrupt errors that cause infinite loops.

The second issue is the annoying one. My assumption is that errors should generally occur in the order in which the checks are constructed.

I've implemented it in my model.py but it should go in layers.

model.py:
from typing import Any, List
import traceback

import jax.numpy as jnp

# JTensor, AssertShape and base_model come from the existing lingvo.jax code.

def Check(ok_cond: JTensor, errflag: JTensor, errinfo: List[Any]) -> JTensor:
  # Appends one error bit to errflag and records the call site in errinfo,
  # so that errflag[k] corresponds to errinfo[k].
  AssertShape(ok_cond, [])
  errinfo.append(traceback.extract_stack())
  err = jnp.logical_not(ok_cond)
  return jnp.concatenate([errflag, err[jnp.newaxis]])


class Model(base_model.BaseModel):

  def __init__(...):
    self.errinfo = []
    self.errflag = jnp.zeros(shape=[0], dtype=bool)

  def compute_loss(...):
    # ok_cond must be True when the invariant holds, i.e. every z > 0.
    invariant = jnp.all(z > 0)
    self.errflag = Check(invariant, self.errflag, self.errinfo)
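To illustrate the bookkeeping (a standalone numpy mock, not the lingvo code): the k-th call to the check helper appends the k-th entry of both the flag vector and the traceback list, so a raised bit at index k maps straight back to the call site recorded at index k.

```python
import traceback
import numpy as np

def check(ok_cond, errflag, errinfo):
    # numpy stand-in for the jnp-based Check above: each call appends
    # exactly one bit to errflag and one traceback to errinfo, so index
    # k in errflag maps back to entry k in errinfo.
    errinfo.append(traceback.extract_stack())
    err = np.logical_not(np.asarray(ok_cond))
    return np.concatenate([errflag, err[np.newaxis]])

errinfo = []
errflag = np.zeros(shape=[0], dtype=bool)
errflag = check(True, errflag, errinfo)    # check 0 passes
errflag = check(False, errflag, errinfo)   # check 1 fails -> errflag[1] is True
for k, bit in enumerate(errflag):
    if bit:
        print('check %d failed; recorded call site:' % k)
        print(errinfo[k].format()[-1])
```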

Then, in lingvo/jax/train.py:

import numpy as np

def SafePmap(errsrc, step, **pmap_kwargs):
  # errsrc is jax_task.model, which holds errflag and errinfo.

  def StepWithErrflag(*args, **kwargs):
    # The step constructs errflag and errinfo while tracing.
    # This is pure and can be compiled.
    ret = list(step(*args, **kwargs))
    ret += [errsrc.errflag]
    return ret

  compiled_step = jax.pmap(StepWithErrflag, **pmap_kwargs)

  def RunCheckCompileStep(*args, **kwargs):
    ret = compiled_step(*args, **kwargs)
    # This runs in Python land, so we can access Python objects.
    errinfo = errsrc.errinfo
    errflag = ret[-1][0]  # The flags are replicated; read them off device 0.
    if np.sum(errflag):
      logging.info('==== ERRORS found: %s', np.sum(errflag))
      for k, flag in enumerate(errflag):
        if flag:
          logging.info('== ERR[%d] at %s', k, '\n'.join(errinfo[k].format()))
      raise ValueError('Exception flag raised')
    return ret[:-1]

  return RunCheckCompileStep

...
def train_and_evaluate_pmap(...):
  ...
  p_train_step = SafePmap(jax_task.model, train_step, donate_argnums=(0,), axis_name='batch')

There is also the pjit.pjit path in trainer_lib.py, but it is not called in my code path.
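The wrapper pattern itself can be shown without jax at all. Below is a toy, single-device version of the SafePmap idea (the pmap/compile step is stubbed out, and `safe_wrap`/`step` are hypothetical names): the step returns its outputs plus the flag vector, and the wrapper checks the flag in Python land, raising if any bit is set and stripping the flag from the result otherwise.

```python
import numpy as np

def safe_wrap(step):
    # Toy version of the SafePmap pattern (no pmap, no compilation):
    # `step` returns [outputs..., errflag]; the wrapper checks the flag
    # in Python land and strips it from the returned values.
    def run(*args, **kwargs):
        ret = step(*args, **kwargs)
        errflag = ret[-1]
        if np.sum(errflag):
            raise ValueError('Exception flag raised (%d checks failed)'
                             % int(np.sum(errflag)))
        return ret[:-1]
    return run

def step(x):
    errflag = np.array([np.any(x < 0)])  # one check: inputs non-negative
    return [x * 2.0, errflag]

p_step = safe_wrap(step)
out, = p_step(np.array([1.0, 2.0]))  # flag clear: returns the doubled input
# p_step(np.array([-1.0])) would raise ValueError instead of returning
```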
