1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
|
class ModelConfig(PretrainedConfig):
    """Configuration for the toy two-layer dense ``Model``.

    Args:
        size: Number of units in each dense layer; also the feature width
            used for inputs elsewhere in this file. Defaults to 1024.
        **kwargs: Passed through untouched to ``PretrainedConfig.__init__``.
    """

    model_type = "tfmodel"

    def __init__(self, size: int = 1024, **kwargs):
        # Let the base class consume its own kwargs first, then attach ours.
        super().__init__(**kwargs)
        self.size = size


# Module-level instance. ``Model.serving``'s ``tf.function`` input signature is
# evaluated at class-definition time and reads ``config.size`` from this name,
# so it must be bound before ``Model`` is defined.
config = ModelConfig(size=1024)
@dataclass
class ModelOut(ModelOutput):
    """Output container produced by ``Model.serving_output``.

    Attributes:
        logits: The raw tensor returned by ``Model.call``; ``None`` if unset.
    """

    # Defaults to None so the container can be built without any fields set.
    logits: tf.Tensor = None
class Model(TFPreTrainedModel):
    """Toy TF model made of two stacked ``Dense`` layers.

    Both layers have ``config.size`` units. The serving signature expects
    float32 inputs of shape ``(batch, config.size)``.
    """

    @property
    def dummy_inputs(self):
        """Return a ones tensor suitable for building the model's weights.

        Returns:
            A float32 tensor of shape ``(config.size, config.size)``.
        """
        # tf.ones already yields a tf.Tensor, so no tf.constant() wrapper is needed.
        return tf.ones((self.config.size, self.config.size), dtype=tf.float32)

    def __init__(self, config: ModelConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.config = config
        self.dense1 = tf.keras.layers.Dense(self.config.size, name="dense1")
        self.dense2 = tf.keras.layers.Dense(self.config.size, name="dense2")

    def call(self, inputs, training=None, mask=None):
        """Apply ``dense1`` then ``dense2`` to ``inputs``.

        ``training`` and ``mask`` are accepted for Keras API compatibility
        but are not used.
        """
        hidden = self.dense1(inputs)
        return self.dense2(hidden)

    # NOTE(review): this decorator runs at class-definition time, so the
    # signature reads the *module-level* ``config``, not ``self.config``.
    # A Model constructed with a different ``size`` will not match it.
    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=(None, config.size), dtype=tf.float32, name="inputs")
        ]
    )
    def serving(self, inputs):
        """Traced serving entry point: run ``call`` and wrap the result."""
        output = self.call(inputs)
        return self.serving_output(output)

    def serving_output(self, output):
        """Wrap a raw output tensor in a ``ModelOut``."""
        return ModelOut(logits=output)
if __name__ == "__main__":
    saved_path = "saved_tf"
    os.makedirs(saved_path, exist_ok=True)

    config = ModelConfig(size=1024)
    # Input width must match the dense layers / serving signature.
    data = tf.ones((config.size, config.size))

    # Run a freshly initialised model, then save it in small shards so the
    # sharded-checkpoint path of save_pretrained is exercised.
    model = Model(config=config)
    out = model(data)
    print("model output:", out)
    model.save_pretrained(saved_path, max_shard_size="6MB")
    print(model.trainable_variables)

    # from_pretrained is a classmethod; call it on the class rather than on a
    # throwaway Model instance that would never be used.
    model1 = Model.from_pretrained(saved_path, config=config)
    out1 = model1(data)
    # Print the reloaded model's own output so it can be compared with `out`.
    print("model1 output:", out1)
    print(model1.trainable_variables)
    # A non-zero difference means the reloaded weights differ from the saved ones.
    print("max abs diff:", float(tf.reduce_max(tf.abs(out - out1))))
|