Predictor
- class pai.predictor.Predictor(service_name: str, endpoint_type: str = 'INTERNET', serializer: Optional[SerializerBase] = None, session: Optional[Session] = None)
Bases: PredictorBase, _ServicePredictorMixin
Predictor is responsible for making predictions with an online prediction service.
The predict method sends the input data to the online prediction service and returns the prediction result. The predictor's serializer handles data transformation when predict is invoked: the input data is serialized with serializer.serialize before it is sent, and the response is deserialized with serializer.deserialize before the prediction result is returned.
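The serialize, send, deserialize flow described above can be sketched in plain Python. This is a minimal, self-contained illustration under stated assumptions, not the PAI SDK implementation: JsonSerializer, fake_service, and predict_with are hypothetical names, and the HTTP call to the service is replaced by a local function so the example runs offline.

```python
import json
from typing import Any, Callable


class JsonSerializer:
    """Hypothetical serializer exposing serialize/deserialize, as described above."""

    def serialize(self, data: Any) -> bytes:
        # Convert the Python object into bytes for transmission.
        return json.dumps(data).encode("utf-8")

    def deserialize(self, data: bytes) -> Any:
        # Convert the response bytes back into a Python object.
        return json.loads(data.decode("utf-8"))


def fake_service(body: bytes) -> bytes:
    # Hypothetical stand-in for the online service: sums each row of the payload.
    rows = json.loads(body)
    return json.dumps([sum(row) for row in rows]).encode("utf-8")


def predict_with(
    serializer: JsonSerializer, send: Callable[[bytes], bytes], data: Any
) -> Any:
    # 1. Serialize the input, 2. send it, 3. deserialize the response.
    request_body = serializer.serialize(data)
    response_body = send(request_body)
    return serializer.deserialize(response_body)


result = predict_with(JsonSerializer(), fake_service, [[22, 33, 44], [19, 22, 33]])
print(result)  # [99, 74]
```

Because the serializer is a separate object, the same request flow works with any payload format, provided serialize and deserialize agree on it.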
Examples:
    import numpy
    from pai.predictor import Predictor

    # Initialize a predictor object from an existing service that uses the
    # PyTorch processor.
    torch_predictor = Predictor(service_name="example_torch_service")
    result = torch_predictor.predict(numpy.asarray([[22, 33, 44], [19, 22, 33]]))
    assert isinstance(result, numpy.ndarray)
Construct a Predictor object using an existing prediction service.
- Parameters
service_name (str) – Name of the existing prediction service.
endpoint_type (str) – Endpoint used by the predictor, which should be either INTERNET or INTRANET (Default 'INTERNET'). With INTERNET, the predictor calls the service over a public endpoint; with INTRANET, it calls the service over a VPC endpoint.
serializer (SerializerBase, optional) – A serializer object that serializes the input Python object for transmission and deserializes the response data into a Python object.
session (Session, optional) – A PAI session object used to communicate with the PAI service.
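The effect of endpoint_type can be pictured as choosing between two base URLs for the same service. The sketch below is purely illustrative: endpoint_url and BASE_URLS are hypothetical, the example.com addresses are placeholders (the real endpoints come from the deployed service), and raising ValueError for other values is an assumption of this sketch, not documented SDK behavior.

```python
from typing import Dict

# Hypothetical placeholder base URLs; real endpoints come from the deployed service.
BASE_URLS: Dict[str, str] = {
    "INTERNET": "https://public.example.com",  # public endpoint
    "INTRANET": "https://vpc.example.com",  # VPC endpoint
}


def endpoint_url(service_name: str, endpoint_type: str = "INTERNET") -> str:
    """Hypothetical helper: pick the base URL for the requested endpoint type."""
    if endpoint_type not in BASE_URLS:
        # Assumption: only the two documented values are accepted.
        raise ValueError(
            f"endpoint_type must be INTERNET or INTRANET, got {endpoint_type!r}"
        )
    return f"{BASE_URLS[endpoint_type]}/{service_name}"


print(endpoint_url("example_torch_service"))  # https://public.example.com/example_torch_service
print(endpoint_url("example_torch_service", "INTRANET"))  # https://vpc.example.com/example_torch_service
```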
- predict(data)
Make a prediction with the online prediction service.
The predictor's serializer handles data transformation when the ‘predict’ method is invoked: the input data is serialized with serializer.serialize before it is sent, and the response is deserialized with serializer.deserialize before the prediction result is returned.
- Parameters
data – The input data for the prediction. It is serialized using the predictor's serializer before it is transmitted to the prediction service.
- Returns
Prediction result.
- Return type
object
- Raises
PredictionException – Raised if the status code of the prediction response is not 2xx.
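The "not 2xx" rule amounts to a range check on the HTTP status code. In this minimal sketch, PredictionException is a local stand-in defined for illustration (the SDK's own class may carry different attributes), and check_status is a hypothetical helper.

```python
class PredictionException(Exception):
    """Local stand-in for the SDK's PredictionException (hypothetical definition)."""

    def __init__(self, code: int, message: str) -> None:
        super().__init__(f"{code}: {message}")
        self.code = code
        self.message = message


def check_status(status_code: int, body: str) -> str:
    # Any status in [200, 300) counts as success; everything else raises.
    if not 200 <= status_code < 300:
        raise PredictionException(status_code, body)
    return body


try:
    check_status(503, "service unavailable")
except PredictionException as exc:
    print(exc.code)  # 503
```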
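- raw_predict(data: Optional[Any] = None, path: Optional[str] = None, headers: Optional[Dict[str, str]] = None, method: str = 'POST', timeout: Optional[Union[float, Tuple[float, float]]] = None, **kwargs) → RawResponse
Make a prediction by sending a raw request to the online prediction service. The input is encoded as described for data below rather than through the predictor's serializer, and the response is returned as a RawResponse object.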
- Parameters
data (Any) – Input data to be sent to the prediction service. A file-like object, bytes, or a string is sent as the request body; any other object is treated as JSON-serializable and sent as JSON.
path (str, optional) – Request path. If provided, it is appended to the endpoint URL (Default None).
headers (dict, optional) – Request headers.
method (str, optional) – Request method (Default 'POST').
timeout (float, tuple(float, float), optional) – Timeout setting for the request (Default 10).
**kwargs – Additional keyword arguments for the request.
- Returns
Prediction response from the service.
- Return type
RawResponse
- Raises
PredictionException – Raised if the status code of the prediction response is not 2xx.
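The data and path rules of raw_predict can be sketched as a small request-building function. This is a hypothetical helper written for illustration, not the SDK's code: build_request is an invented name, and treating None as an empty body and normalizing the slash between URL and path are assumptions of this sketch.

```python
import json
from typing import Any, Optional, Tuple


def build_request(
    endpoint_url: str, data: Any = None, path: Optional[str] = None
) -> Tuple[str, Any]:
    """Hypothetical helper mirroring the raw_predict rules for data and path."""
    url = endpoint_url
    if path:
        # Append the path to the endpoint URL, keeping exactly one slash
        # between them (the slash handling is an assumption of this sketch).
        url = endpoint_url.rstrip("/") + "/" + path.lstrip("/")

    if data is None:
        # Assumption: no data means an empty request body.
        body = None
    elif isinstance(data, (bytes, str)) or hasattr(data, "read"):
        # Bytes, strings, and file-like objects are sent as the body unchanged.
        body = data
    else:
        # Any other object is treated as JSON-serializable and sent as JSON text.
        body = json.dumps(data)
    return url, body


url, body = build_request(
    "https://svc.example.com/api/predict/demo", {"x": [1, 2]}, path="v1/predict"
)
print(url)   # https://svc.example.com/api/predict/demo/v1/predict
print(body)  # {"x": [1, 2]}
```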