com.johnsnowlabs.ml.tensorflow

TensorflowDistilBert

class TensorflowDistilBert extends Serializable

The DistilBERT model was proposed in the paper "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter" (https://arxiv.org/abs/1910.01108). DistilBERT is a small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% fewer parameters than bert-base-uncased and runs 60% faster while preserving over 95% of BERT's performance as measured on the GLUE language understanding benchmark.

The abstract from the paper is the following:

As Transfer Learning from large-scale pre-trained models becomes more prevalent in Natural Language Processing (NLP), operating these large models in on-the-edge and/or under constrained computational training or inference budgets remains challenging. In this work, we propose a method to pre-train a smaller general-purpose language representation model, called DistilBERT, which can then be fine-tuned with good performances on a wide range of tasks like its larger counterparts. While most prior work investigated the use of distillation for building task-specific models, we leverage knowledge distillation during the pretraining phase and show that it is possible to reduce the size of a BERT model by 40%, while retaining 97% of its language understanding capabilities and being 60% faster. To leverage the inductive biases learned by larger models during pretraining, we introduce a triple loss combining language modeling, distillation and cosine-distance losses. Our smaller, faster and lighter model is cheaper to pre-train and we demonstrate its capabilities for on-device computations in a proof-of-concept experiment and a comparative on-device study.
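The triple loss named in the abstract can be illustrated with a small sketch. The following is only an assumption-laden illustration in plain Scala, not the paper's training code and not part of this class: the object name, the function names, the temperature value and the three weights are all hypothetical, and the losses are computed for a single token position.

```scala
object TripleLossSketch {
  // Softmax with temperature; subtracting the maximum keeps exp() numerically stable.
  def softmax(logits: Array[Double], temperature: Double): Array[Double] = {
    val scaled = logits.map(_ / temperature)
    val maxValue = scaled.max
    val exps = scaled.map(v => math.exp(v - maxValue))
    val total = exps.sum
    exps.map(_ / total)
  }

  // Distillation loss: cross-entropy between the teacher's softened
  // distribution and the student's softened distribution.
  def distillationLoss(
      teacherLogits: Array[Double],
      studentLogits: Array[Double],
      temperature: Double): Double = {
    require(teacherLogits.length == studentLogits.length, "logit sizes must match")
    val teacher = softmax(teacherLogits, temperature)
    val student = softmax(studentLogits, temperature)
    -teacher.zip(student).map { case (t, s) => t * math.log(s) }.sum
  }

  // Language modeling loss: negative log-likelihood of the true token.
  def maskedLmLoss(studentLogits: Array[Double], targetIndex: Int): Double =
    -math.log(softmax(studentLogits, 1.0)(targetIndex))

  // Cosine-distance loss between teacher and student hidden states.
  def cosineLoss(teacherHidden: Array[Double], studentHidden: Array[Double]): Double = {
    require(teacherHidden.length == studentHidden.length, "hidden sizes must match")
    val dot = teacherHidden.zip(studentHidden).map { case (t, s) => t * s }.sum
    val teacherNorm = math.sqrt(teacherHidden.map(v => v * v).sum)
    val studentNorm = math.sqrt(studentHidden.map(v => v * v).sum)
    1.0 - dot / (teacherNorm * studentNorm)
  }

  // Weighted sum of the three terms; the weights are hypothetical hyperparameters.
  def tripleLoss(
      distillation: Double,
      maskedLm: Double,
      cosine: Double,
      alphaCe: Double,
      alphaMlm: Double,
      alphaCos: Double): Double =
    alphaCe * distillation + alphaMlm * maskedLm + alphaCos * cosine
}
```

In this sketch, a student that matches the teacher's hidden-state direction exactly contributes a cosine term of zero, so only the two cross-entropy terms remain.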

Tips:

- DistilBERT doesn't have token_type_ids, so you don't need to indicate which token belongs to which segment. Just separate your segments with the separation token tokenizer.sep_token (or [SEP]).

- DistilBERT doesn't have options to select the input positions (the position_ids input). This could be added if necessary; just let us know if you need this option.
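The first tip can be sketched as follows. This is a minimal illustration, assuming already-tokenized segments and the BERT-style convention of a leading [CLS] token; the object and function names are hypothetical and do not belong to this class. Note that no segment-ID array is produced, because the separator tokens alone mark the segment boundaries.

```scala
object SegmentPairSketch {
  // Special tokens following the BERT-style convention (assumed here).
  val clsToken = "[CLS]"
  val sepToken = "[SEP]"

  // Joins any number of tokenized segments into one input sequence,
  // closing each segment with the separator token.
  def joinSegments(segments: Seq[Seq[String]]): Seq[String] =
    clsToken +: segments.flatMap(segment => segment :+ sepToken)
}
```

For example, `SegmentPairSketch.joinSegments(Seq(Seq("how", "are", "you"), Seq("fine", "thanks")))` yields `[CLS] how are you [SEP] fine thanks [SEP]`.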

Linear Supertypes
Serializable, Serializable, AnyRef, Any

Instance Constructors

  1. new TensorflowDistilBert(tensorflowWrapper: TensorflowWrapper, sentenceStartTokenId: Int, sentenceEndTokenId: Int, configProtoBytes: Option[Array[Byte]] = None, signatures: Option[Map[String, String]] = None)

    tensorflowWrapper

    DistilBERT model wrapped in a TensorflowWrapper

    sentenceStartTokenId

    ID of the sentence start token

    sentenceEndTokenId

    ID of the sentence end token

    configProtoBytes

    Configuration for TensorFlow session

Value Members

  1. final def !=(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  2. final def ##(): Int
    Definition Classes
    AnyRef → Any
  3. final def ==(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  4. val _tfBertSignatures: Map[String, String]
  5. final def asInstanceOf[T0]: T0
    Definition Classes
    Any
  6. def calculateEmbeddings(sentences: Seq[WordpieceTokenizedSentence], originalTokenSentences: Seq[TokenizedSentence], batchSize: Int, maxSentenceLength: Int, caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence]
  7. def calculateSentenceEmbeddings(tokens: Seq[WordpieceTokenizedSentence], sentences: Seq[Sentence], batchSize: Int, maxSentenceLength: Int): Seq[Annotation]
  8. def clone(): AnyRef
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  9. def encode(sentences: Seq[(WordpieceTokenizedSentence, Int)], maxSequenceLength: Int): Seq[Array[Int]]

    Encodes the input sequences as token IDs, adding padding where necessary

  10. final def eq(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  11. def equals(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  12. def finalize(): Unit
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  13. final def getClass(): Class[_]
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  14. def hashCode(): Int
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  15. final def isInstanceOf[T0]: Boolean
    Definition Classes
    Any
  16. final def ne(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  17. final def notify(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  18. final def notifyAll(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  19. final def synchronized[T0](arg0: ⇒ T0): T0
    Definition Classes
    AnyRef
  20. def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]]
  21. def tagSentence(batch: Seq[Array[Int]]): Array[Array[Float]]

    batch

    batches of sentences

    returns

    batches of vectors for each sentence

  22. val tensorflowWrapper: TensorflowWrapper
  23. def toString(): String
    Definition Classes
    AnyRef → Any
  24. final def wait(): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  25. final def wait(arg0: Long, arg1: Int): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  26. final def wait(arg0: Long): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
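The encode member listed above is described as turning input sequences into token IDs with padding. The following is a minimal sketch of that idea under stated assumptions, not the library's implementation: it works on plain arrays of word-piece IDs instead of WordpieceTokenizedSentence, and the special-token IDs (101, 102) and the padding ID (0) are hypothetical placeholders for values that would come from the model vocabulary and the constructor arguments.

```scala
object EncodeSketch {
  // Hypothetical IDs; in practice these come from the vocabulary
  // (compare the sentenceStartTokenId and sentenceEndTokenId constructor parameters).
  val sentenceStartTokenId = 101
  val sentenceEndTokenId = 102
  val padTokenId = 0

  // Wraps each sentence in start/end tokens, truncates it to fit
  // maxSequenceLength, and pads shorter sentences so that every
  // row in the batch has the same length.
  def encode(sentences: Seq[Array[Int]], maxSequenceLength: Int): Seq[Array[Int]] = {
    require(maxSequenceLength >= 2, "need room for the start and end tokens")
    if (sentences.isEmpty) Seq.empty
    else {
      // Two positions are reserved for the start and end tokens.
      val longest = sentences.map(_.length).max
      val maxPieces = math.min(maxSequenceLength - 2, longest)

      sentences.map { pieceIds =>
        val kept = pieceIds.take(maxPieces)
        val padding = Array.fill(maxPieces - kept.length)(padTokenId)
        Array(sentenceStartTokenId) ++ kept ++ Array(sentenceEndTokenId) ++ padding
      }
    }
  }
}
```

Padding only up to the longest sentence in the batch, instead of always up to maxSequenceLength, keeps the batch tensor as small as possible; whether the actual class does the same is an assumption of this sketch.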
